cloudflared-mirror/tunnelrpc/pogs/config_test.go

415 lines
14 KiB
Go

package pogs
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/tunnelrpc"
capnp "zombiezen.com/go/capnproto2"
)
// Compile-time checks that every origin config type used in these tests
// implements OriginConfig; the build fails if one stops satisfying it.
var (
	_ OriginConfig = (*HTTPOriginConfig)(nil)
	_ OriginConfig = (*WebSocketOriginConfig)(nil)
	_ OriginConfig = (*HelloWorldOriginConfig)(nil)
)
// TestVersion checks IsNewerOrEqual: the initial version is not
// newer-or-equal to Version(1), while Version(1) is newer-or-equal to both
// the initial version and itself.
func TestVersion(t *testing.T) {
	older, newer := InitVersion(), Version(1)

	assert.False(t, older.IsNewerOrEqual(newer))
	assert.True(t, newer.IsNewerOrEqual(older))
	assert.True(t, newer.IsNewerOrEqual(newer))
}
// TestClientConfig round-trips several ClientConfig samples (with and
// without DoH and reverse-proxy configs) through MarshalClientConfig and
// UnmarshalClientConfig, and checks the result equals the original.
func TestClientConfig(t *testing.T) {
	addDoHProxyConfigs := func(c *ClientConfig) {
		c.DoHProxyConfigs = []*DoHProxyConfig{
			sampleDoHProxyConfig(),
		}
	}
	addReverseProxyConfigs := func(c *ClientConfig) {
		// One reverse-proxy config per origin type, so every origin
		// variant is exercised inside a ClientConfig.
		c.ReverseProxyConfigs = []*ReverseProxyConfig{
			sampleReverseProxyConfig(),
			sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
				c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()}
			}),
			sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
				c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()}
			}),
			sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
				c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()}
			}),
		}
	}
	testCases := []*ClientConfig{
		sampleClientConfig(),
		sampleClientConfig(addDoHProxyConfigs),
		sampleClientConfig(addReverseProxyConfigs),
		sampleClientConfig(addDoHProxyConfigs, addReverseProxyConfigs),
	}
	for i, testCase := range testCases {
		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
		// Check the allocation error before err is reassigned below;
		// otherwise a failure here would be silently discarded.
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't create a new capnp message")
		}
		capnpEntity, err := tunnelrpc.NewClientConfig(seg)
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't initialize a new message")
		}
		err = MarshalClientConfig(capnpEntity, testCase)
		if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
			continue
		}
		result, err := UnmarshalClientConfig(capnpEntity)
		if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
			continue
		}
		assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
	}
}
// TestUseConfigurationResult round-trips UseConfigurationResult values (a
// success and a failure listing failed configs) through
// MarshalUseConfigurationResult and UnmarshalUseConfigurationResult, and
// checks the result equals the original.
func TestUseConfigurationResult(t *testing.T) {
	testCases := []*UseConfigurationResult{
		{
			Success: true,
		},
		{
			Success: false,
			FailedConfigs: []*FailedConfig{
				{
					Config: sampleReverseProxyConfig(),
					Reason: "Invalid certificate",
				},
				{
					Config: sampleDoHProxyConfig(),
					Reason: "Cannot listen on port 53",
				},
			},
		},
	}
	for i, testCase := range testCases {
		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
		// Check the allocation error before err is reassigned below;
		// otherwise a failure here would be silently discarded.
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't create a new capnp message")
		}
		capnpEntity, err := tunnelrpc.NewUseConfigurationResult(seg)
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't initialize a new message")
		}
		err = MarshalUseConfigurationResult(capnpEntity, testCase)
		if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
			continue
		}
		result, err := UnmarshalUseConfigurationResult(capnpEntity)
		if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
			continue
		}
		assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
	}
}
// TestDoHProxyConfig round-trips DoHProxyConfig samples (including one with
// nil Upstreams) through MarshalDoHProxyConfig and UnmarshalDoHProxyConfig,
// and checks the result equals the original.
func TestDoHProxyConfig(t *testing.T) {
	testCases := []*DoHProxyConfig{
		sampleDoHProxyConfig(),
		sampleDoHProxyConfig(func(c *DoHProxyConfig) {
			c.Upstreams = nil
		}),
	}
	for i, testCase := range testCases {
		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
		// Check the allocation error before err is reassigned below;
		// otherwise a failure here would be silently discarded.
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't create a new capnp message")
		}
		capnpEntity, err := tunnelrpc.NewDoHProxyConfig(seg)
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't initialize a new message")
		}
		err = MarshalDoHProxyConfig(capnpEntity, testCase)
		if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
			continue
		}
		result, err := UnmarshalDoHProxyConfig(capnpEntity)
		if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
			continue
		}
		assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
	}
}
// TestReverseProxyConfig round-trips a ReverseProxyConfig for each origin
// type through MarshalReverseProxyConfig and UnmarshalReverseProxyConfig,
// and checks the result equals the original.
func TestReverseProxyConfig(t *testing.T) {
	testCases := []*ReverseProxyConfig{
		sampleReverseProxyConfig(),
		sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
			c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()}
		}),
		sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
			c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()}
		}),
		sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
			c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()}
		}),
	}
	for i, testCase := range testCases {
		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
		// Check the allocation error before err is reassigned below;
		// otherwise a failure here would be silently discarded.
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't create a new capnp message")
		}
		capnpEntity, err := tunnelrpc.NewReverseProxyConfig(seg)
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't initialize a new message")
		}
		err = MarshalReverseProxyConfig(capnpEntity, testCase)
		if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
			continue
		}
		result, err := UnmarshalReverseProxyConfig(capnpEntity)
		if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
			continue
		}
		assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
	}
}
// TestHTTPOriginConfig round-trips HTTPOriginConfig samples (a URL origin
// and a unix: socket-path origin) through MarshalHTTPOriginConfig and
// UnmarshalHTTPOriginConfig, and checks the result equals the original.
func TestHTTPOriginConfig(t *testing.T) {
	testCases := []*HTTPOriginConfig{
		sampleHTTPOriginConfig(),
		// Also cover the unix: URL form directly, not only via ReverseProxyConfig.
		sampleHTTPOriginConfigUnixPath(),
	}
	for i, testCase := range testCases {
		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
		// Check the allocation error before err is reassigned below;
		// otherwise a failure here would be silently discarded.
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't create a new capnp message")
		}
		capnpEntity, err := tunnelrpc.NewHTTPOriginConfig(seg)
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't initialize a new message")
		}
		err = MarshalHTTPOriginConfig(capnpEntity, testCase)
		if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
			continue
		}
		result, err := UnmarshalHTTPOriginConfig(capnpEntity)
		if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
			continue
		}
		assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
	}
}
// TestWebSocketOriginConfig round-trips a WebSocketOriginConfig sample
// through MarshalWebSocketOriginConfig and UnmarshalWebSocketOriginConfig,
// and checks the result equals the original.
func TestWebSocketOriginConfig(t *testing.T) {
	testCases := []*WebSocketOriginConfig{
		sampleWebSocketOriginConfig(),
	}
	for i, testCase := range testCases {
		_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
		// Check the allocation error before err is reassigned below;
		// otherwise a failure here would be silently discarded.
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't create a new capnp message")
		}
		capnpEntity, err := tunnelrpc.NewWebSocketOriginConfig(seg)
		if !assert.NoError(t, err) {
			t.Fatal("Couldn't initialize a new message")
		}
		err = MarshalWebSocketOriginConfig(capnpEntity, testCase)
		if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
			continue
		}
		result, err := UnmarshalWebSocketOriginConfig(capnpEntity)
		if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
			continue
		}
		assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
	}
}
// TestOriginConfigInvalidURL checks that Service() on HTTP and WebSocket
// origin configs whose URL has no scheme returns an error and a nil service.
func TestOriginConfigInvalidURL(t *testing.T) {
	// Host and port only; deliberately missing a scheme.
	const schemelessURL = "127.0.0.1:36192"

	invalidConfigs := []OriginConfig{
		&HTTPOriginConfig{URLString: schemelessURL},
		&WebSocketOriginConfig{URLString: schemelessURL},
	}
	for idx := range invalidConfigs {
		svc, err := invalidConfigs[idx].Service()
		assert.Error(t, err)
		assert.Nil(t, svc)
	}
}
//////////////////////////////////////////////////////////////////////////////
// Functions to generate sample data for ease of testing
//
// There is at least one "sample" function per struct type. Each works like this:
// 1. Initialize an instance of the relevant struct.
// 2. Ensure the instance has no zero-valued fields. (This catches the
//    error case where a field was added, but we forgot to add code to
//    marshal/unmarshal that field in Cap'n Proto.)
// 3. Apply zero or more "override" functions (which accept a
//    pointer-to-struct, so they can mutate the instance).
//
// Because step 2 runs before step 3, an override may deliberately reset a
// field to its zero value (for example, to test a nil slice).

// sampleClientConfig builds a ClientConfig with populated supervisor and
// edge-connection settings, checks that no field other than DoHProxyConfigs
// and ReverseProxyConfigs is zero-valued, then applies each override in
// order and returns the result.
func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig {
	supervisor := &SupervisorConfig{
		AutoUpdateFrequency:    21 * time.Hour,
		MetricsUpdateFrequency: 11 * time.Minute,
		GracePeriod:            31 * time.Second,
	}
	edge := &EdgeConnectionConfig{
		NumHAConnections:    49,
		Timeout:             9 * time.Second,
		HeartbeatInterval:   5 * time.Second,
		MaxFailedHeartbeats: 9001,
	}
	sample := &ClientConfig{
		Version:              Version(1337),
		SupervisorConfig:     supervisor,
		EdgeConnectionConfig: edge,
	}
	sample.ensureNoZeroFields()

	for idx := range overrides {
		overrides[idx](sample)
	}
	return sample
}
// sampleDoHProxyConfig builds a DoHProxyConfig listening on 127.0.0.1:53
// with two upstreams, checks that none of its fields is zero-valued, then
// applies each override in order and returns the result.
func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
	upstreams := []string{"https://1.example.com", "https://2.example.com"}
	sample := &DoHProxyConfig{
		ListenHost: "127.0.0.1",
		ListenPort: 53,
		Upstreams:  upstreams,
	}
	sample.ensureNoZeroFields()

	for idx := range overrides {
		overrides[idx](sample)
	}
	return sample
}
// sampleReverseProxyConfig builds a ReverseProxyConfig for hijk.example.com
// whose origin is a HelloWorldOriginConfig, checks that none of its fields
// is zero-valued, then applies each override in order and returns the
// result.
func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig {
	origin := &OriginConfigUnmarshaler{&HelloWorldOriginConfig{}}
	sample := &ReverseProxyConfig{
		TunnelHostname:          "hijk.example.com",
		OriginConfigUnmarshaler: origin,
		Retries:                 18,
		ConnectionTimeout:       5 * time.Second,
		CompressionQuality:      4,
	}
	sample.ensureNoZeroFields()

	for idx := range overrides {
		overrides[idx](sample)
	}
	return sample
}
// sampleHTTPOriginConfig builds an HTTPOriginConfig with every field set to
// a non-zero value, checks that none was left zero-valued, then applies each
// override in order and returns the result.
//
// NOTE(review): URLString "https.example.com" has no scheme (perhaps
// "https://example.com" was intended). The marshalling round-trip tests do
// not appear to need a parseable URL, but TestOriginConfigInvalidURL shows
// Service() rejects scheme-less URLs, so this sample presumably would too —
// confirm before reusing it elsewhere.
func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
	sample := &HTTPOriginConfig{
		URLString: "https.example.com",

		// Connection and transport settings.
		TCPKeepAlive:           7 * time.Second,
		DialDualStack:          true,
		MaxIdleConnections:     19,
		IdleConnectionTimeout:  17 * time.Second,
		ProxyConnectionTimeout: 15 * time.Second,
		ExpectContinueTimeout:  21 * time.Second,
		ChunkedEncoding:        true,

		// TLS settings.
		TLSHandshakeTimeout: 11 * time.Second,
		TLSVerify:           true,
		OriginCAPool:        "/etc/cert.pem",
		OriginServerName:    "secure.example.com",
	}
	sample.ensureNoZeroFields()

	for idx := range overrides {
		overrides[idx](sample)
	}
	return sample
}
// sampleHTTPOriginConfigUnixPath returns the same sample as
// sampleHTTPOriginConfig, except that URLString is the unix: path
// "unix:/var/lib/file.sock", then applies any further overrides in order.
//
// Deriving it from sampleHTTPOriginConfig, instead of repeating every field
// literal, keeps the two samples from drifting apart when HTTPOriginConfig
// gains new fields. The zero-field check still runs inside
// sampleHTTPOriginConfig before the URL override is applied.
func sampleHTTPOriginConfigUnixPath(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
	useUnixPath := func(c *HTTPOriginConfig) {
		c.URLString = "unix:/var/lib/file.sock"
	}
	// The URL override goes first so caller-supplied overrides can still
	// replace it.
	allOverrides := append([]func(*HTTPOriginConfig){useUnixPath}, overrides...)
	return sampleHTTPOriginConfig(allOverrides...)
}
// sampleWebSocketOriginConfig builds a WebSocketOriginConfig pointing at
// ssh://example.com with TLS verification enabled, checks that none of its
// fields is zero-valued, then applies each override in order and returns
// the result.
func sampleWebSocketOriginConfig(overrides ...func(*WebSocketOriginConfig)) *WebSocketOriginConfig {
	sample := new(WebSocketOriginConfig)
	sample.URLString = "ssh://example.com"
	sample.TLSVerify = true
	sample.OriginCAPool = "/etc/cert.pem"
	sample.OriginServerName = "secure.example.com"
	sample.ensureNoZeroFields()

	for idx := range overrides {
		overrides[idx](sample)
	}
	return sample
}
// ensureNoZeroFields panics if any field of this ClientConfig sample is
// zero-valued. DoHProxyConfigs and ReverseProxyConfigs are exempt because
// tests attach them through overrides.
func (c *ClientConfig) ensureNoZeroFields() {
	optionalFields := []string{"DoHProxyConfigs", "ReverseProxyConfigs"}
	ensureNoZeroFieldsInSample(reflect.ValueOf(c), optionalFields)
}

// ensureNoZeroFields panics if any field of this DoHProxyConfig sample is
// zero-valued.
func (c *DoHProxyConfig) ensureNoZeroFields() {
	ensureNoZeroFieldsInSample(reflect.ValueOf(c), nil)
}

// ensureNoZeroFields panics if any field of this ReverseProxyConfig sample
// is zero-valued.
func (c *ReverseProxyConfig) ensureNoZeroFields() {
	ensureNoZeroFieldsInSample(reflect.ValueOf(c), nil)
}

// ensureNoZeroFields panics if any field of this HTTPOriginConfig sample is
// zero-valued.
func (c *HTTPOriginConfig) ensureNoZeroFields() {
	ensureNoZeroFieldsInSample(reflect.ValueOf(c), nil)
}

// ensureNoZeroFields panics if any field of this WebSocketOriginConfig
// sample is zero-valued.
func (c *WebSocketOriginConfig) ensureNoZeroFields() {
	ensureNoZeroFieldsInSample(reflect.ValueOf(c), nil)
}
// ensureNoZeroFieldsInSample checks that all fields in the sample struct,
// except those listed in `allowedZeroFieldNames`, are initialized to nonzero
// values, and panics otherwise. The value must be a non-nil pointer to a
// struct:
//
//	e := &ExampleStruct{ ... }
//	ensureNoZeroFieldsInSample(reflect.ValueOf(e), []string{})
//
// It also panics if an allowed name is not a field of the struct, or if a
// non-exempt field is unexported (its value cannot be inspected through
// reflection's Interface()).
//
// Context:
// Our tests work by taking a sample struct and marshalling/unmarshalling it.
// This makes them easy to write, but introduces some risk: if we don't
// include a field in the sample value, it won't be covered under tests.
// This check reduces that risk by requiring fields to be either initialized
// or explicitly uninitialized.
func ensureNoZeroFieldsInSample(ptrToSampleValue reflect.Value, allowedZeroFieldNames []string) {
	// Validate the argument up front. Otherwise a non-pointer, a pointer to a
	// non-struct, or a nil pointer fails deep inside reflect with a message
	// that does not identify the misuse.
	if ptrToSampleValue.Kind() != reflect.Ptr {
		panic(fmt.Sprintf("ensureNoZeroFieldsInSample: expected a pointer to a struct, got kind %v", ptrToSampleValue.Kind()))
	}
	structType := ptrToSampleValue.Type().Elem()
	if structType.Kind() != reflect.Struct {
		panic(fmt.Sprintf("ensureNoZeroFieldsInSample: expected a pointer to a struct, got %v", ptrToSampleValue.Type()))
	}
	if ptrToSampleValue.IsNil() {
		panic(fmt.Sprintf("ensureNoZeroFieldsInSample: sample value for struct %v is a nil pointer", structType.Name()))
	}
	sampleValue := ptrToSampleValue.Elem()

	allowedZeroFieldSet := make(map[string]bool, len(allowedZeroFieldNames))
	for _, name := range allowedZeroFieldNames {
		// Reject typos in the allow-list so a field is never exempted by accident.
		if _, ok := structType.FieldByName(name); !ok {
			panic(fmt.Sprintf("struct %v has no field %v", structType.Name(), name))
		}
		allowedZeroFieldSet[name] = true
	}

	for i := 0; i < structType.NumField(); i++ {
		field := structType.Field(i)
		if allowedZeroFieldSet[field.Name] {
			continue
		}
		fieldValue := sampleValue.Field(i)
		// Interface() panics on unexported fields, so report them explicitly.
		if !fieldValue.CanInterface() {
			panic(fmt.Sprintf("In the sample value for struct %v, field %v is unexported and cannot be checked; list it in allowedZeroFieldNames", structType.Name(), field.Name))
		}
		zeroValue := reflect.Zero(field.Type)
		if reflect.DeepEqual(zeroValue.Interface(), fieldValue.Interface()) {
			panic(fmt.Sprintf("In the sample value for struct %v, field %v was not initialized", structType.Name(), field.Name))
		}
	}
}