diff --git a/streamhandler/stream_handler.go b/streamhandler/stream_handler.go index 9f4a79f2..99c49657 100644 --- a/streamhandler/stream_handler.go +++ b/streamhandler/stream_handler.go @@ -86,7 +86,7 @@ func (s *StreamHandler) UpdateConfig(newConfig []*pogs.ReverseProxyConfig) (fail s.tunnelHostnameMapper.DeleteAll() for _, tunnelConfig := range newConfig { tunnelHostname := tunnelConfig.TunnelHostname - originSerice, err := tunnelConfig.OriginConfigUnmarshaler.OriginConfig.Service() + originSerice, err := tunnelConfig.OriginConfigJSONHandler.OriginConfig.Service() if err != nil { s.logger.WithField("tunnelHostname", tunnelHostname).WithError(err).Error("Invalid origin service config") failedConfigs = append(failedConfigs, &pogs.FailedConfig{ diff --git a/streamhandler/stream_handler_test.go b/streamhandler/stream_handler_test.go index acdcfda5..c8146127 100644 --- a/streamhandler/stream_handler_test.go +++ b/streamhandler/stream_handler_test.go @@ -49,7 +49,7 @@ func TestServeRequest(t *testing.T) { reverseProxyConfigs := []*pogs.ReverseProxyConfig{ { TunnelHostname: testTunnelHostname, - OriginConfigUnmarshaler: &pogs.OriginConfigUnmarshaler{ + OriginConfigJSONHandler: &pogs.OriginConfigJSONHandler{ OriginConfig: &pogs.HTTPOriginConfig{ URLString: httpServer.URL, }, @@ -99,7 +99,7 @@ func TestServeBadRequest(t *testing.T) { reverseProxyConfigs := []*pogs.ReverseProxyConfig{ { TunnelHostname: testTunnelHostname, - OriginConfigUnmarshaler: &pogs.OriginConfigUnmarshaler{ + OriginConfigJSONHandler: &pogs.OriginConfigJSONHandler{ OriginConfig: &pogs.HTTPOriginConfig{ URLString: "", }, diff --git a/tunnelrpc/pogs/config.go b/tunnelrpc/pogs/config.go index 644d40f2..76d079a0 100644 --- a/tunnelrpc/pogs/config.go +++ b/tunnelrpc/pogs/config.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/json" "fmt" "net" "net/http" @@ -53,6 +54,7 @@ func (v Version) String() string { // FallibleConfig is an interface implemented by configs that cloudflared might not be able to apply type FallibleConfig interface { FailReason(err error) string + jsonType() string } // SupervisorConfig specifies config of components managed by Supervisor other than ConnectionManager @@ -67,6 +69,16 @@ func (sc *SupervisorConfig) FailReason(err error) string { return fmt.Sprintf("Cannot apply SupervisorConfig, err: %v", err) } +func (sc *SupervisorConfig) MarshalJSON() ([]byte, error) { + marshaler := make(map[string]SupervisorConfig, 1) + marshaler[sc.jsonType()] = *sc + return json.Marshal(marshaler) +} + +func (sc *SupervisorConfig) jsonType() string { + return "supervisor_config" +} + // EdgeConnectionConfig specifies what parameters and how may connections should ConnectionManager establish with edge type EdgeConnectionConfig struct { NumHAConnections uint8 `json:"num_ha_connections"` @@ -81,6 +93,16 @@ func (cmc *EdgeConnectionConfig) FailReason(err error) string { return fmt.Sprintf("Cannot apply EdgeConnectionConfig, err: %v", err) } +func (cmc *EdgeConnectionConfig) MarshalJSON() ([]byte, error) { + marshaler := make(map[string]EdgeConnectionConfig, 1) + marshaler[cmc.jsonType()] = *cmc + return json.Marshal(marshaler) +} + +func (cmc *EdgeConnectionConfig) jsonType() string { + return "edge_connection_config" +} + // DoHProxyConfig is configuration for DNS over HTTPS service type DoHProxyConfig struct { ListenHost string `json:"listen_host"` @@ -93,10 +115,20 @@ func (dpc *DoHProxyConfig) FailReason(err error) string { return fmt.Sprintf("Cannot apply DoHProxyConfig, err: %v", err) } +func (dpc *DoHProxyConfig) MarshalJSON() ([]byte, error) { + marshaler := make(map[string]DoHProxyConfig, 1) + marshaler[dpc.jsonType()] = *dpc + return json.Marshal(marshaler) +} + +func (dpc *DoHProxyConfig) jsonType() string { + return "doh_proxy_config" +} + // ReverseProxyConfig how and for what hostnames can this cloudflared proxy type ReverseProxyConfig struct { TunnelHostname h2mux.TunnelHostname `json:"tunnel_hostname"` - OriginConfigUnmarshaler *OriginConfigUnmarshaler `json:"origin_config"` + OriginConfigJSONHandler *OriginConfigJSONHandler `json:"origin_config"` Retries uint64 `json:"retries"` ConnectionTimeout time.Duration `json:"connection_timeout"` CompressionQuality uint64 `json:"compression_quality"` @@ -114,7 +146,7 @@ func NewReverseProxyConfig( } return &ReverseProxyConfig{ TunnelHostname: h2mux.TunnelHostname(tunnelHostname), - OriginConfigUnmarshaler: &OriginConfigUnmarshaler{originConfig}, + OriginConfigJSONHandler: &OriginConfigJSONHandler{originConfig}, Retries: retries, ConnectionTimeout: connectionTimeout, CompressionQuality: compressionQuality, @@ -126,12 +158,22 @@ func (rpc *ReverseProxyConfig) FailReason(err error) string { return fmt.Sprintf("Cannot apply ReverseProxyConfig, err: %v", err) } +func (rpc *ReverseProxyConfig) MarshalJSON() ([]byte, error) { + marshaler := make(map[string]ReverseProxyConfig, 1) + marshaler[rpc.jsonType()] = *rpc + return json.Marshal(marshaler) +} + +func (rpc *ReverseProxyConfig) jsonType() string { + return "reverse_proxy_config" +} + //go-sumtype:decl OriginConfig type OriginConfig interface { // Service returns a OriginService used to proxy to the origin Service() (originservice.OriginService, error) // go-sumtype requires at least one unexported method, otherwise it will complain that interface is not sealed - isOriginConfig() + jsonType() string } type originType int @@ -156,18 +198,18 @@ func (ot originType) String() string { } type HTTPOriginConfig struct { - URLString string `capnp:"urlString" mapstructure:"url_string"` - TCPKeepAlive time.Duration `capnp:"tcpKeepAlive" mapstructure:"tcp_keep_alive"` - DialDualStack bool `mapstructure:"dial_dual_stack"` - TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout" mapstructure:"tls_handshake_timeout"` - TLSVerify bool `capnp:"tlsVerify" mapstructure:"tls_verify"` - OriginCAPool string `mapstructure:"origin_ca_pool"` - OriginServerName string `mapstructure:"origin_server_name"` - MaxIdleConnections uint64 `mapstructure:"max_idle_connections"` - IdleConnectionTimeout time.Duration `mapstructure:"idle_connection_timeout"` - ProxyConnectionTimeout time.Duration `mapstructure:"proxy_connection_timeout"` - ExpectContinueTimeout time.Duration `mapstructure:"expect_continue_timeout"` - ChunkedEncoding bool `mapstructure:"chunked_encoding"` + URLString string `capnp:"urlString" json:"url_string" mapstructure:"url_string"` + TCPKeepAlive time.Duration `capnp:"tcpKeepAlive" json:"tcp_keep_alive" mapstructure:"tcp_keep_alive"` + DialDualStack bool `json:"dial_dual_stack" mapstructure:"dial_dual_stack"` + TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout" json:"tls_handshake_timeout" mapstructure:"tls_handshake_timeout"` + TLSVerify bool `capnp:"tlsVerify" json:"tls_verify" mapstructure:"tls_verify"` + OriginCAPool string `json:"origin_ca_pool" mapstructure:"origin_ca_pool"` + OriginServerName string `json:"origin_server_name" mapstructure:"origin_server_name"` + MaxIdleConnections uint64 `json:"max_idle_connections" mapstructure:"max_idle_connections"` + IdleConnectionTimeout time.Duration `json:"idle_connection_timeout" mapstructure:"idle_connection_timeout"` + ProxyConnectionTimeout time.Duration `json:"proxy_connection_timeout" mapstructure:"proxy_connection_timeout"` + ExpectContinueTimeout time.Duration `json:"expect_continue_timeout" mapstructure:"expect_continue_timeout"` + ChunkedEncoding bool `json:"chunked_encoding" mapstructure:"chunked_encoding"` } func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) { @@ -206,13 +248,15 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) { return originservice.NewHTTPService(transport, url, hc.ChunkedEncoding), nil } -func (_ *HTTPOriginConfig) isOriginConfig() {} +func (_ *HTTPOriginConfig) jsonType() string { + return httpType.String() +} type WebSocketOriginConfig struct { - URLString string `capnp:"urlString" mapstructure:"url_string"` - TLSVerify bool `capnp:"tlsVerify" mapstructure:"tls_verify"` - OriginCAPool string `mapstructure:"origin_ca_pool"` - OriginServerName string `mapstructure:"origin_server_name"` + URLString string `capnp:"urlString" json:"url_string" mapstructure:"url_string"` + TLSVerify bool `capnp:"tlsVerify" json:"tls_verify" mapstructure:"tls_verify"` + OriginCAPool string `json:"origin_ca_pool" mapstructure:"origin_ca_pool"` + OriginServerName string `json:"origin_server_name" mapstructure:"origin_server_name"` } func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) { @@ -233,7 +277,9 @@ func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) return originservice.NewWebSocketService(tlsConfig, url) } -func (_ *WebSocketOriginConfig) isOriginConfig() {} +func (_ *WebSocketOriginConfig) jsonType() string { + return wsType.String() +} type HelloWorldOriginConfig struct{} @@ -262,7 +308,9 @@ func (_ *HelloWorldOriginConfig) Service() (originservice.OriginService, error) return originservice.NewHelloWorldService(transport) } -func (_ *HelloWorldOriginConfig) isOriginConfig() {} +func (_ *HelloWorldOriginConfig) jsonType() string { + return helloWorldType.String() +} /* * Boilerplate to convert between these structs and the primitive structs @@ -471,7 +519,7 @@ func UnmarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig) (*DoHProxyConfig, error func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyConfig) error { s.SetTunnelHostname(p.TunnelHostname.String()) - switch config := p.OriginConfigUnmarshaler.OriginConfig.(type) { + switch config := p.OriginConfigJSONHandler.OriginConfig.(type) { case *HTTPOriginConfig: ss, err := s.Origin().NewHttp() if err != nil { @@ -522,7 +570,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC if err != nil { return nil, err } - p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config} + p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config} case tunnelrpc.ReverseProxyConfig_origin_Which_websocket: ss, err := s.Origin().Websocket() if err != nil { @@ -532,7 +580,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC if err != nil { return nil, err } - p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config} + p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config} case tunnelrpc.ReverseProxyConfig_origin_Which_helloWorld: ss, err := s.Origin().HelloWorld() if err != nil { @@ -542,7 +590,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC if err != nil { return nil, err } - p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config} + p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config} } p.Retries = s.Retries() p.ConnectionTimeout = time.Duration(s.ConnectionTimeout()) @@ -642,13 +690,13 @@ func (i ClientService_PogsImpl) UseConfiguration(p tunnelrpc.ClientService_useCo } type UseConfigurationResult struct { - Success bool - FailedConfigs []*FailedConfig + Success bool `json:"success"` + FailedConfigs []*FailedConfig `json:"failed_configs"` } type FailedConfig struct { - Config FallibleConfig - Reason string + Config FallibleConfig `json:"config"` + Reason string `json:"reason"` } func MarshalFailedConfig(s tunnelrpc.FailedConfig, p *FailedConfig) error { @@ -663,7 +711,7 @@ func MarshalFailedConfig(s tunnelrpc.FailedConfig, p *FailedConfig) error { return err } case *EdgeConnectionConfig: - ss, err := s.Config().EdgeConnection() + ss, err := s.Config().NewEdgeConnection() if err != nil { return err } diff --git a/tunnelrpc/pogs/config_test.go b/tunnelrpc/pogs/config_test.go index 8ede50f1..d74fa62c 100644 --- a/tunnelrpc/pogs/config_test.go +++ b/tunnelrpc/pogs/config_test.go @@ -41,13 +41,13 @@ func TestClientConfig(t *testing.T) { sampleReverseProxyConfig(func(c *ReverseProxyConfig) { }), sampleReverseProxyConfig(func(c *ReverseProxyConfig) { - c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()} + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()} }), sampleReverseProxyConfig(func(c *ReverseProxyConfig) { - c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()} + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfigUnixPath()} }), sampleReverseProxyConfig(func(c *ReverseProxyConfig) { - c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()} + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()} }), } } @@ -142,13 +142,13 @@ func TestReverseProxyConfig(t *testing.T) { testCases := []*ReverseProxyConfig{ sampleReverseProxyConfig(), sampleReverseProxyConfig(func(c *ReverseProxyConfig) { - c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()} + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()} }), sampleReverseProxyConfig(func(c *ReverseProxyConfig) { - c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()} + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfigUnixPath()} }), sampleReverseProxyConfig(func(c *ReverseProxyConfig) { - c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()} + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()} }), } for i, testCase := range testCases { @@ -246,18 +246,9 @@ func TestOriginConfigInvalidURL(t *testing.T) { // applies any number of overrides to it, and returns it. func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig { sample := &ClientConfig{ - Version: Version(1337), - SupervisorConfig: &SupervisorConfig{ - AutoUpdateFrequency: 21 * time.Hour, - MetricsUpdateFrequency: 11 * time.Minute, - GracePeriod: 31 * time.Second, - }, - EdgeConnectionConfig: &EdgeConnectionConfig{ - NumHAConnections: 49, - Timeout: 9 * time.Second, - HeartbeatInterval: 5 * time.Second, - MaxFailedHeartbeats: 9001, - }, + Version: Version(1337), + SupervisorConfig: sampleSupervisorConfig(), + EdgeConnectionConfig: sampleEdgeConnectionConfig(), } sample.ensureNoZeroFields() for _, f := range overrides { @@ -266,13 +257,35 @@ func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig { return sample } +func sampleSupervisorConfig() *SupervisorConfig { + sample := &SupervisorConfig{ + AutoUpdateFrequency: 21 * time.Hour, + MetricsUpdateFrequency: 11 * time.Minute, + GracePeriod: 31 * time.Second, + } + sample.ensureNoZeroFields() + return sample +} + +func sampleEdgeConnectionConfig() *EdgeConnectionConfig { + sample := &EdgeConnectionConfig{ + NumHAConnections: 49, + HeartbeatInterval: 5 * time.Second, + Timeout: 9 * time.Second, + MaxFailedHeartbeats: 9001, + UserCredentialPath: "/Users/example/.cloudflared/cert.pem", + } + sample.ensureNoZeroFields() + return sample +} + // sampleDoHProxyConfig initializes a new DoHProxyConfig struct, // applies any number of overrides to it, and returns it. func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig { sample := &DoHProxyConfig{ ListenHost: "127.0.0.1", ListenPort: 53, - Upstreams: []string{"https://1.example.com", "https://2.example.com"}, + Upstreams: []string{"1.1.1.1", "1.0.0.1"}, } sample.ensureNoZeroFields() for _, f := range overrides { @@ -285,11 +298,11 @@ func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig { // applies any number of overrides to it, and returns it. func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig { sample := &ReverseProxyConfig{ - TunnelHostname: "hijk.example.com", - OriginConfigUnmarshaler: &OriginConfigUnmarshaler{&HelloWorldOriginConfig{}}, + TunnelHostname: "mock-non-lb-tunnel.example.com", + OriginConfigJSONHandler: &OriginConfigJSONHandler{&HelloWorldOriginConfig{}}, Retries: 18, ConnectionTimeout: 5 * time.Second, - CompressionQuality: 4, + CompressionQuality: 3, } sample.ensureNoZeroFields() for _, f := range overrides { @@ -360,6 +373,14 @@ func (c *ClientConfig) ensureNoZeroFields() { ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{"DoHProxyConfigs", "ReverseProxyConfigs"}) } +func (c *SupervisorConfig) ensureNoZeroFields() { + ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{}) +} + +func (c *EdgeConnectionConfig) ensureNoZeroFields() { + ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{}) +} + func (c *DoHProxyConfig) ensureNoZeroFields() { ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{}) } diff --git a/tunnelrpc/pogs/unmarshal.go b/tunnelrpc/pogs/json.go similarity index 71% rename from tunnelrpc/pogs/unmarshal.go rename to tunnelrpc/pogs/json.go index 001c8c2a..497ee4d8 100644 --- a/tunnelrpc/pogs/unmarshal.go +++ b/tunnelrpc/pogs/json.go @@ -41,11 +41,19 @@ func (su *ScopeUnmarshaler) UnmarshalJSON(b []byte) error { return fmt.Errorf("JSON should have been an object with one root key, either 'system_name' or 'group'") } -type OriginConfigUnmarshaler struct { +// OriginConfigJSONHandler is a wrapper to serialize OriginConfig with type information, and deserialize JSON +// into an OriginConfig. +type OriginConfigJSONHandler struct { OriginConfig OriginConfig } -func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error { +func (ocjh *OriginConfigJSONHandler) MarshalJSON() ([]byte, error) { + marshaler := make(map[string]OriginConfig, 1) + marshaler[ocjh.OriginConfig.jsonType()] = ocjh.OriginConfig + return json.Marshal(marshaler) +} + +func (ocjh *OriginConfigJSONHandler) UnmarshalJSON(b []byte) error { var originJSON map[string]interface{} if err := json.Unmarshal(b, &originJSON); err != nil { return errors.Wrapf(err, "cannot unmarshal %s into originJSON", string(b)) @@ -56,7 +64,7 @@ func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error { if err := mapstructure.Decode(originConfig, httpOriginConfig); err != nil { return errors.Wrapf(err, "cannot decode %+v into HTTPOriginConfig", originConfig) } - ocu.OriginConfig = httpOriginConfig + ocjh.OriginConfig = httpOriginConfig return nil } @@ -65,7 +73,7 @@ func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error { if err := mapstructure.Decode(originConfig, wsOriginConfig); err != nil { return errors.Wrapf(err, "cannot decode %+v into WebSocketOriginConfig", originConfig) } - ocu.OriginConfig = wsOriginConfig + ocjh.OriginConfig = wsOriginConfig return nil } @@ -74,9 +82,20 @@ func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error { if err := mapstructure.Decode(originConfig, helloWorldOriginConfig); err != nil { return errors.Wrapf(err, "cannot decode %+v into HelloWorldOriginConfig", originConfig) } - ocu.OriginConfig = helloWorldOriginConfig + ocjh.OriginConfig = helloWorldOriginConfig return nil } return fmt.Errorf("cannot unmarshal %s into OriginConfig", string(b)) -} \ No newline at end of file +} + +// FallibleConfigMarshaler is a wrapper for FallibleConfig to implement custom marshal logic +type FallibleConfigMarshaler struct { + FallibleConfig FallibleConfig +} + +func (fcm *FallibleConfigMarshaler) MarshalJSON() ([]byte, error) { + marshaler := make(map[string]FallibleConfig, 1) + marshaler[fcm.FallibleConfig.jsonType()] = fcm.FallibleConfig + return json.Marshal(marshaler) +} diff --git a/tunnelrpc/pogs/unmarshal_test.go b/tunnelrpc/pogs/json_test.go similarity index 54% rename from tunnelrpc/pogs/unmarshal_test.go rename to tunnelrpc/pogs/json_test.go index 8e4a2981..70a896a4 100644 --- a/tunnelrpc/pogs/unmarshal_test.go +++ b/tunnelrpc/pogs/json_test.go @@ -70,50 +70,32 @@ func TestUnmarshalOrigin(t *testing.T) { { jsonLiteral: `{ "Http":{ - "url_string":"https://127.0.0.1:8080", - "tcp_keep_alive":30000000000, + "url_string":"https.example.com", + "tcp_keep_alive":7000000000, "dial_dual_stack":true, - "tls_handshake_timeout":10000000000, + "tls_handshake_timeout":11000000000, "tls_verify":true, - "origin_ca_pool":"", - "origin_server_name":"", - "max_idle_connections":100, - "idle_connection_timeout":90000000000, - "proxy_connection_timeout":90000000000, - "expect_continue_timeout":90000000000, + "origin_ca_pool":"/etc/cert.pem", + "origin_server_name":"secure.example.com", + "max_idle_connections":19, + "idle_connection_timeout":17000000000, + "proxy_connection_timeout":15000000000, + "expect_continue_timeout":21000000000, "chunked_encoding":true } }`, - exceptedOriginConfig: &HTTPOriginConfig{ - URLString: "https://127.0.0.1:8080", - TCPKeepAlive: time.Second * 30, - DialDualStack: true, - TLSHandshakeTimeout: time.Second * 10, - TLSVerify: true, - OriginCAPool: "", - OriginServerName: "", - MaxIdleConnections: 100, - IdleConnectionTimeout: time.Second * 90, - ProxyConnectionTimeout: time.Second * 90, - ExpectContinueTimeout: time.Second * 90, - ChunkedEncoding: true, - }, + exceptedOriginConfig: sampleHTTPOriginConfig(), }, { jsonLiteral: `{ "WebSocket":{ - "url_string":"https://127.0.0.1:9090", + "url_string":"ssh://example.com", "tls_verify":true, - "origin_ca_pool":"", - "origin_server_name":"ws.example.com" + "origin_ca_pool":"/etc/cert.pem", + "origin_server_name":"secure.example.com" } }`, - exceptedOriginConfig: &WebSocketOriginConfig{ - URLString: "https://127.0.0.1:9090", - TLSVerify: true, - OriginCAPool: "", - OriginServerName: "ws.example.com", - }, + exceptedOriginConfig: sampleWebSocketOriginConfig(), }, { jsonLiteral: `{ @@ -124,11 +106,11 @@ func TestUnmarshalOrigin(t *testing.T) { } for _, test := range tests { - originConfigJSON := strings.ReplaceAll(strings.ReplaceAll(test.jsonLiteral, "\n", ""), "\t", "") - var OriginConfigUnmarshaler OriginConfigUnmarshaler - err := json.Unmarshal([]byte(originConfigJSON), &OriginConfigUnmarshaler) + originConfigJSON := prettyToValidJSON(test.jsonLiteral) + var OriginConfigJSONHandler OriginConfigJSONHandler + err := json.Unmarshal([]byte(originConfigJSON), &OriginConfigJSONHandler) assert.NoError(t, err) - assert.Equal(t, test.exceptedOriginConfig, OriginConfigUnmarshaler.OriginConfig) + assert.Equal(t, test.exceptedOriginConfig, OriginConfigJSONHandler.OriginConfig) } } @@ -176,7 +158,7 @@ func TestUnmarshalClientConfig(t *testing.T) { }] }` // replace new line and tab - clientConfigJSON := strings.ReplaceAll(strings.ReplaceAll(prettyClientConfigJSON, "\n", ""), "\t", "") + clientConfigJSON := prettyToValidJSON(prettyClientConfigJSON) var clientConfig ClientConfig err := json.Unmarshal([]byte(clientConfigJSON), &clientConfig) @@ -211,7 +193,7 @@ func TestUnmarshalClientConfig(t *testing.T) { reverseProxyConfig := ReverseProxyConfig{ TunnelHostname: "sdfjadk33.cftunnel.com", - OriginConfigUnmarshaler: &OriginConfigUnmarshaler{ + OriginConfigJSONHandler: &OriginConfigJSONHandler{ OriginConfig: &HTTPOriginConfig{ URLString: "https://127.0.0.1:8080", TCPKeepAlive: time.Second * 30, @@ -236,6 +218,125 @@ func TestUnmarshalClientConfig(t *testing.T) { assert.Equal(t, reverseProxyConfig, *clientConfig.ReverseProxyConfigs[0]) } +func TestMarshalFallibleConfig(t *testing.T) { + tests := []struct { + fallibleConfig FallibleConfig + expctedJSONLiteral string + }{ + { + fallibleConfig: sampleSupervisorConfig(), + expctedJSONLiteral: `{ + "supervisor_config":{ + "auto_update_frequency":75600000000000, + "metrics_update_frequency":660000000000, + "grace_period":31000000000 + } + }`, + }, + { + fallibleConfig: sampleEdgeConnectionConfig(), + expctedJSONLiteral: `{ + "edge_connection_config":{ + "num_ha_connections":49, + "heartbeat_interval":5000000000, + "timeout":9000000000, + "max_failed_heartbeats":9001, + "user_credential_path":"/Users/example/.cloudflared/cert.pem" + } + }`, + }, + { + fallibleConfig: sampleDoHProxyConfig(), + expctedJSONLiteral: `{ + "doh_proxy_config":{ + "listen_host":"127.0.0.1", + "listen_port":53, + "upstreams":["1.1.1.1","1.0.0.1"] + } + }`, + }, + { + fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) { + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()} + }), + expctedJSONLiteral: `{ + "reverse_proxy_config":{ + "tunnel_hostname":"mock-non-lb-tunnel.example.com", + "origin_config":{ + "Http":{ + "url_string":"https.example.com", + "tcp_keep_alive":7000000000, + "dial_dual_stack":true, + "tls_handshake_timeout":11000000000, + "tls_verify":true, + "origin_ca_pool":"/etc/cert.pem", + "origin_server_name":"secure.example.com", + "max_idle_connections":19, + "idle_connection_timeout":17000000000, + "proxy_connection_timeout":15000000000, + "expect_continue_timeout":21000000000, + "chunked_encoding":true + } + }, + "retries":18, + "connection_timeout":5000000000, + "compression_quality":3 + } + }`, + }, + { + fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) { + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()} + }), + expctedJSONLiteral: `{ + "reverse_proxy_config":{ + "tunnel_hostname":"mock-non-lb-tunnel.example.com", + "origin_config":{ + "WebSocket":{ + "url_string":"ssh://example.com", + "tls_verify":true, + "origin_ca_pool":"/etc/cert.pem", + "origin_server_name":"secure.example.com" + } + }, + "retries":18, + "connection_timeout":5000000000, + "compression_quality":3 + } + }`, + }, + { + fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) { + c.OriginConfigJSONHandler = &OriginConfigJSONHandler{&HelloWorldOriginConfig{}} + }), + expctedJSONLiteral: `{ + "reverse_proxy_config":{ + "tunnel_hostname":"mock-non-lb-tunnel.example.com", + "origin_config":{ + "HelloWorld":{} + }, + "retries":18, + "connection_timeout":5000000000, + "compression_quality":3 + } + }`, + }, + } + + for _, test := range tests { + b, err := json.Marshal(test.fallibleConfig) + assert.NoError(t, err) + assert.Equal(t, prettyToValidJSON(test.expctedJSONLiteral), string(b)) + } + +} + +type prettyJSON string + +func prettyToValidJSON(prettyJSON string) string { + return strings.ReplaceAll(strings.ReplaceAll(prettyJSON, "\n", ""), "\t", "") +} + func eqScope(s1, s2 Scope) bool { return s1.Value() == s2.Value() && s1.PostgresType() == s2.PostgresType() }