TUN-2111: Implement custom serialization logic for FallibleConfig and OriginConfig

This commit is contained in:
Chung-Ting Huang 2019-08-05 10:14:58 -05:00
parent 993a9bc4b9
commit fd4ab314dc
6 changed files with 289 additions and 100 deletions

View File

@ -86,7 +86,7 @@ func (s *StreamHandler) UpdateConfig(newConfig []*pogs.ReverseProxyConfig) (fail
s.tunnelHostnameMapper.DeleteAll() s.tunnelHostnameMapper.DeleteAll()
for _, tunnelConfig := range newConfig { for _, tunnelConfig := range newConfig {
tunnelHostname := tunnelConfig.TunnelHostname tunnelHostname := tunnelConfig.TunnelHostname
originSerice, err := tunnelConfig.OriginConfigUnmarshaler.OriginConfig.Service() originSerice, err := tunnelConfig.OriginConfigJSONHandler.OriginConfig.Service()
if err != nil { if err != nil {
s.logger.WithField("tunnelHostname", tunnelHostname).WithError(err).Error("Invalid origin service config") s.logger.WithField("tunnelHostname", tunnelHostname).WithError(err).Error("Invalid origin service config")
failedConfigs = append(failedConfigs, &pogs.FailedConfig{ failedConfigs = append(failedConfigs, &pogs.FailedConfig{

View File

@ -49,7 +49,7 @@ func TestServeRequest(t *testing.T) {
reverseProxyConfigs := []*pogs.ReverseProxyConfig{ reverseProxyConfigs := []*pogs.ReverseProxyConfig{
{ {
TunnelHostname: testTunnelHostname, TunnelHostname: testTunnelHostname,
OriginConfigUnmarshaler: &pogs.OriginConfigUnmarshaler{ OriginConfigJSONHandler: &pogs.OriginConfigJSONHandler{
OriginConfig: &pogs.HTTPOriginConfig{ OriginConfig: &pogs.HTTPOriginConfig{
URLString: httpServer.URL, URLString: httpServer.URL,
}, },
@ -99,7 +99,7 @@ func TestServeBadRequest(t *testing.T) {
reverseProxyConfigs := []*pogs.ReverseProxyConfig{ reverseProxyConfigs := []*pogs.ReverseProxyConfig{
{ {
TunnelHostname: testTunnelHostname, TunnelHostname: testTunnelHostname,
OriginConfigUnmarshaler: &pogs.OriginConfigUnmarshaler{ OriginConfigJSONHandler: &pogs.OriginConfigJSONHandler{
OriginConfig: &pogs.HTTPOriginConfig{ OriginConfig: &pogs.HTTPOriginConfig{
URLString: "", URLString: "",
}, },

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -53,6 +54,7 @@ func (v Version) String() string {
// FallibleConfig is an interface implemented by configs that cloudflared might not be able to apply // FallibleConfig is an interface implemented by configs that cloudflared might not be able to apply
type FallibleConfig interface { type FallibleConfig interface {
FailReason(err error) string FailReason(err error) string
jsonType() string
} }
// SupervisorConfig specifies config of components managed by Supervisor other than ConnectionManager // SupervisorConfig specifies config of components managed by Supervisor other than ConnectionManager
@ -67,6 +69,16 @@ func (sc *SupervisorConfig) FailReason(err error) string {
return fmt.Sprintf("Cannot apply SupervisorConfig, err: %v", err) return fmt.Sprintf("Cannot apply SupervisorConfig, err: %v", err)
} }
func (sc *SupervisorConfig) MarshalJSON() ([]byte, error) {
marshaler := make(map[string]SupervisorConfig, 1)
marshaler[sc.jsonType()] = *sc
return json.Marshal(marshaler)
}
func (sc *SupervisorConfig) jsonType() string {
return "supervisor_config"
}
// EdgeConnectionConfig specifies what parameters and how may connections should ConnectionManager establish with edge // EdgeConnectionConfig specifies what parameters and how may connections should ConnectionManager establish with edge
type EdgeConnectionConfig struct { type EdgeConnectionConfig struct {
NumHAConnections uint8 `json:"num_ha_connections"` NumHAConnections uint8 `json:"num_ha_connections"`
@ -81,6 +93,16 @@ func (cmc *EdgeConnectionConfig) FailReason(err error) string {
return fmt.Sprintf("Cannot apply EdgeConnectionConfig, err: %v", err) return fmt.Sprintf("Cannot apply EdgeConnectionConfig, err: %v", err)
} }
func (cmc *EdgeConnectionConfig) MarshalJSON() ([]byte, error) {
marshaler := make(map[string]EdgeConnectionConfig, 1)
marshaler[cmc.jsonType()] = *cmc
return json.Marshal(marshaler)
}
func (cmc *EdgeConnectionConfig) jsonType() string {
return "edge_connection_config"
}
// DoHProxyConfig is configuration for DNS over HTTPS service // DoHProxyConfig is configuration for DNS over HTTPS service
type DoHProxyConfig struct { type DoHProxyConfig struct {
ListenHost string `json:"listen_host"` ListenHost string `json:"listen_host"`
@ -93,10 +115,20 @@ func (dpc *DoHProxyConfig) FailReason(err error) string {
return fmt.Sprintf("Cannot apply DoHProxyConfig, err: %v", err) return fmt.Sprintf("Cannot apply DoHProxyConfig, err: %v", err)
} }
func (dpc *DoHProxyConfig) MarshalJSON() ([]byte, error) {
marshaler := make(map[string]DoHProxyConfig, 1)
marshaler[dpc.jsonType()] = *dpc
return json.Marshal(marshaler)
}
func (dpc *DoHProxyConfig) jsonType() string {
return "doh_proxy_config"
}
// ReverseProxyConfig how and for what hostnames can this cloudflared proxy // ReverseProxyConfig how and for what hostnames can this cloudflared proxy
type ReverseProxyConfig struct { type ReverseProxyConfig struct {
TunnelHostname h2mux.TunnelHostname `json:"tunnel_hostname"` TunnelHostname h2mux.TunnelHostname `json:"tunnel_hostname"`
OriginConfigUnmarshaler *OriginConfigUnmarshaler `json:"origin_config"` OriginConfigJSONHandler *OriginConfigJSONHandler `json:"origin_config"`
Retries uint64 `json:"retries"` Retries uint64 `json:"retries"`
ConnectionTimeout time.Duration `json:"connection_timeout"` ConnectionTimeout time.Duration `json:"connection_timeout"`
CompressionQuality uint64 `json:"compression_quality"` CompressionQuality uint64 `json:"compression_quality"`
@ -114,7 +146,7 @@ func NewReverseProxyConfig(
} }
return &ReverseProxyConfig{ return &ReverseProxyConfig{
TunnelHostname: h2mux.TunnelHostname(tunnelHostname), TunnelHostname: h2mux.TunnelHostname(tunnelHostname),
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{originConfig}, OriginConfigJSONHandler: &OriginConfigJSONHandler{originConfig},
Retries: retries, Retries: retries,
ConnectionTimeout: connectionTimeout, ConnectionTimeout: connectionTimeout,
CompressionQuality: compressionQuality, CompressionQuality: compressionQuality,
@ -126,12 +158,22 @@ func (rpc *ReverseProxyConfig) FailReason(err error) string {
return fmt.Sprintf("Cannot apply ReverseProxyConfig, err: %v", err) return fmt.Sprintf("Cannot apply ReverseProxyConfig, err: %v", err)
} }
func (rpc *ReverseProxyConfig) MarshalJSON() ([]byte, error) {
marshaler := make(map[string]ReverseProxyConfig, 1)
marshaler[rpc.jsonType()] = *rpc
return json.Marshal(marshaler)
}
func (rpc *ReverseProxyConfig) jsonType() string {
return "reverse_proxy_config"
}
//go-sumtype:decl OriginConfig //go-sumtype:decl OriginConfig
type OriginConfig interface { type OriginConfig interface {
// Service returns a OriginService used to proxy to the origin // Service returns a OriginService used to proxy to the origin
Service() (originservice.OriginService, error) Service() (originservice.OriginService, error)
// go-sumtype requires at least one unexported method, otherwise it will complain that interface is not sealed // go-sumtype requires at least one unexported method, otherwise it will complain that interface is not sealed
isOriginConfig() jsonType() string
} }
type originType int type originType int
@ -156,18 +198,18 @@ func (ot originType) String() string {
} }
type HTTPOriginConfig struct { type HTTPOriginConfig struct {
URLString string `capnp:"urlString" mapstructure:"url_string"` URLString string `capnp:"urlString" json:"url_string" mapstructure:"url_string"`
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive" mapstructure:"tcp_keep_alive"` TCPKeepAlive time.Duration `capnp:"tcpKeepAlive" json:"tcp_keep_alive" mapstructure:"tcp_keep_alive"`
DialDualStack bool `mapstructure:"dial_dual_stack"` DialDualStack bool `json:"dial_dual_stack" mapstructure:"dial_dual_stack"`
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout" mapstructure:"tls_handshake_timeout"` TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout" json:"tls_handshake_timeout" mapstructure:"tls_handshake_timeout"`
TLSVerify bool `capnp:"tlsVerify" mapstructure:"tls_verify"` TLSVerify bool `capnp:"tlsVerify" json:"tls_verify" mapstructure:"tls_verify"`
OriginCAPool string `mapstructure:"origin_ca_pool"` OriginCAPool string `json:"origin_ca_pool" mapstructure:"origin_ca_pool"`
OriginServerName string `mapstructure:"origin_server_name"` OriginServerName string `json:"origin_server_name" mapstructure:"origin_server_name"`
MaxIdleConnections uint64 `mapstructure:"max_idle_connections"` MaxIdleConnections uint64 `json:"max_idle_connections" mapstructure:"max_idle_connections"`
IdleConnectionTimeout time.Duration `mapstructure:"idle_connection_timeout"` IdleConnectionTimeout time.Duration `json:"idle_connection_timeout" mapstructure:"idle_connection_timeout"`
ProxyConnectionTimeout time.Duration `mapstructure:"proxy_connection_timeout"` ProxyConnectionTimeout time.Duration `json:"proxy_connection_timeout" mapstructure:"proxy_connection_timeout"`
ExpectContinueTimeout time.Duration `mapstructure:"expect_continue_timeout"` ExpectContinueTimeout time.Duration `json:"expect_continue_timeout" mapstructure:"expect_continue_timeout"`
ChunkedEncoding bool `mapstructure:"chunked_encoding"` ChunkedEncoding bool `json:"chunked_encoding" mapstructure:"chunked_encoding"`
} }
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) { func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
@ -206,13 +248,15 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
return originservice.NewHTTPService(transport, url, hc.ChunkedEncoding), nil return originservice.NewHTTPService(transport, url, hc.ChunkedEncoding), nil
} }
func (_ *HTTPOriginConfig) isOriginConfig() {} func (_ *HTTPOriginConfig) jsonType() string {
return httpType.String()
}
type WebSocketOriginConfig struct { type WebSocketOriginConfig struct {
URLString string `capnp:"urlString" mapstructure:"url_string"` URLString string `capnp:"urlString" json:"url_string" mapstructure:"url_string"`
TLSVerify bool `capnp:"tlsVerify" mapstructure:"tls_verify"` TLSVerify bool `capnp:"tlsVerify" json:"tls_verify" mapstructure:"tls_verify"`
OriginCAPool string `mapstructure:"origin_ca_pool"` OriginCAPool string `json:"origin_ca_pool" mapstructure:"origin_ca_pool"`
OriginServerName string `mapstructure:"origin_server_name"` OriginServerName string `json:"origin_server_name" mapstructure:"origin_server_name"`
} }
func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) { func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) {
@ -233,7 +277,9 @@ func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error)
return originservice.NewWebSocketService(tlsConfig, url) return originservice.NewWebSocketService(tlsConfig, url)
} }
func (_ *WebSocketOriginConfig) isOriginConfig() {} func (_ *WebSocketOriginConfig) jsonType() string {
return wsType.String()
}
type HelloWorldOriginConfig struct{} type HelloWorldOriginConfig struct{}
@ -262,7 +308,9 @@ func (_ *HelloWorldOriginConfig) Service() (originservice.OriginService, error)
return originservice.NewHelloWorldService(transport) return originservice.NewHelloWorldService(transport)
} }
func (_ *HelloWorldOriginConfig) isOriginConfig() {} func (_ *HelloWorldOriginConfig) jsonType() string {
return helloWorldType.String()
}
/* /*
* Boilerplate to convert between these structs and the primitive structs * Boilerplate to convert between these structs and the primitive structs
@ -471,7 +519,7 @@ func UnmarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig) (*DoHProxyConfig, error
func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyConfig) error { func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyConfig) error {
s.SetTunnelHostname(p.TunnelHostname.String()) s.SetTunnelHostname(p.TunnelHostname.String())
switch config := p.OriginConfigUnmarshaler.OriginConfig.(type) { switch config := p.OriginConfigJSONHandler.OriginConfig.(type) {
case *HTTPOriginConfig: case *HTTPOriginConfig:
ss, err := s.Origin().NewHttp() ss, err := s.Origin().NewHttp()
if err != nil { if err != nil {
@ -522,7 +570,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config} p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config}
case tunnelrpc.ReverseProxyConfig_origin_Which_websocket: case tunnelrpc.ReverseProxyConfig_origin_Which_websocket:
ss, err := s.Origin().Websocket() ss, err := s.Origin().Websocket()
if err != nil { if err != nil {
@ -532,7 +580,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config} p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config}
case tunnelrpc.ReverseProxyConfig_origin_Which_helloWorld: case tunnelrpc.ReverseProxyConfig_origin_Which_helloWorld:
ss, err := s.Origin().HelloWorld() ss, err := s.Origin().HelloWorld()
if err != nil { if err != nil {
@ -542,7 +590,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil { if err != nil {
return nil, err return nil, err
} }
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config} p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config}
} }
p.Retries = s.Retries() p.Retries = s.Retries()
p.ConnectionTimeout = time.Duration(s.ConnectionTimeout()) p.ConnectionTimeout = time.Duration(s.ConnectionTimeout())
@ -642,13 +690,13 @@ func (i ClientService_PogsImpl) UseConfiguration(p tunnelrpc.ClientService_useCo
} }
type UseConfigurationResult struct { type UseConfigurationResult struct {
Success bool Success bool `json:"success"`
FailedConfigs []*FailedConfig FailedConfigs []*FailedConfig `json:"failed_configs"`
} }
type FailedConfig struct { type FailedConfig struct {
Config FallibleConfig Config FallibleConfig `json:"config"`
Reason string Reason string `json:"reason"`
} }
func MarshalFailedConfig(s tunnelrpc.FailedConfig, p *FailedConfig) error { func MarshalFailedConfig(s tunnelrpc.FailedConfig, p *FailedConfig) error {
@ -663,7 +711,7 @@ func MarshalFailedConfig(s tunnelrpc.FailedConfig, p *FailedConfig) error {
return err return err
} }
case *EdgeConnectionConfig: case *EdgeConnectionConfig:
ss, err := s.Config().EdgeConnection() ss, err := s.Config().NewEdgeConnection()
if err != nil { if err != nil {
return err return err
} }

View File

@ -41,13 +41,13 @@ func TestClientConfig(t *testing.T) {
sampleReverseProxyConfig(func(c *ReverseProxyConfig) { sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
}), }),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) { sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()} c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()}
}), }),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) { sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()} c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfigUnixPath()}
}), }),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) { sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()} c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()}
}), }),
} }
} }
@ -142,13 +142,13 @@ func TestReverseProxyConfig(t *testing.T) {
testCases := []*ReverseProxyConfig{ testCases := []*ReverseProxyConfig{
sampleReverseProxyConfig(), sampleReverseProxyConfig(),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) { sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()} c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()}
}), }),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) { sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()} c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfigUnixPath()}
}), }),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) { sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()} c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()}
}), }),
} }
for i, testCase := range testCases { for i, testCase := range testCases {
@ -246,18 +246,9 @@ func TestOriginConfigInvalidURL(t *testing.T) {
// applies any number of overrides to it, and returns it. // applies any number of overrides to it, and returns it.
func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig { func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig {
sample := &ClientConfig{ sample := &ClientConfig{
Version: Version(1337), Version: Version(1337),
SupervisorConfig: &SupervisorConfig{ SupervisorConfig: sampleSupervisorConfig(),
AutoUpdateFrequency: 21 * time.Hour, EdgeConnectionConfig: sampleEdgeConnectionConfig(),
MetricsUpdateFrequency: 11 * time.Minute,
GracePeriod: 31 * time.Second,
},
EdgeConnectionConfig: &EdgeConnectionConfig{
NumHAConnections: 49,
Timeout: 9 * time.Second,
HeartbeatInterval: 5 * time.Second,
MaxFailedHeartbeats: 9001,
},
} }
sample.ensureNoZeroFields() sample.ensureNoZeroFields()
for _, f := range overrides { for _, f := range overrides {
@ -266,13 +257,35 @@ func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig {
return sample return sample
} }
func sampleSupervisorConfig() *SupervisorConfig {
sample := &SupervisorConfig{
AutoUpdateFrequency: 21 * time.Hour,
MetricsUpdateFrequency: 11 * time.Minute,
GracePeriod: 31 * time.Second,
}
sample.ensureNoZeroFields()
return sample
}
func sampleEdgeConnectionConfig() *EdgeConnectionConfig {
sample := &EdgeConnectionConfig{
NumHAConnections: 49,
HeartbeatInterval: 5 * time.Second,
Timeout: 9 * time.Second,
MaxFailedHeartbeats: 9001,
UserCredentialPath: "/Users/example/.cloudflared/cert.pem",
}
sample.ensureNoZeroFields()
return sample
}
// sampleDoHProxyConfig initializes a new DoHProxyConfig struct, // sampleDoHProxyConfig initializes a new DoHProxyConfig struct,
// applies any number of overrides to it, and returns it. // applies any number of overrides to it, and returns it.
func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig { func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
sample := &DoHProxyConfig{ sample := &DoHProxyConfig{
ListenHost: "127.0.0.1", ListenHost: "127.0.0.1",
ListenPort: 53, ListenPort: 53,
Upstreams: []string{"https://1.example.com", "https://2.example.com"}, Upstreams: []string{"1.1.1.1", "1.0.0.1"},
} }
sample.ensureNoZeroFields() sample.ensureNoZeroFields()
for _, f := range overrides { for _, f := range overrides {
@ -285,11 +298,11 @@ func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
// applies any number of overrides to it, and returns it. // applies any number of overrides to it, and returns it.
func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig { func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig {
sample := &ReverseProxyConfig{ sample := &ReverseProxyConfig{
TunnelHostname: "hijk.example.com", TunnelHostname: "mock-non-lb-tunnel.example.com",
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{&HelloWorldOriginConfig{}}, OriginConfigJSONHandler: &OriginConfigJSONHandler{&HelloWorldOriginConfig{}},
Retries: 18, Retries: 18,
ConnectionTimeout: 5 * time.Second, ConnectionTimeout: 5 * time.Second,
CompressionQuality: 4, CompressionQuality: 3,
} }
sample.ensureNoZeroFields() sample.ensureNoZeroFields()
for _, f := range overrides { for _, f := range overrides {
@ -360,6 +373,14 @@ func (c *ClientConfig) ensureNoZeroFields() {
ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{"DoHProxyConfigs", "ReverseProxyConfigs"}) ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{"DoHProxyConfigs", "ReverseProxyConfigs"})
} }
func (c *SupervisorConfig) ensureNoZeroFields() {
ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{})
}
func (c *EdgeConnectionConfig) ensureNoZeroFields() {
ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{})
}
func (c *DoHProxyConfig) ensureNoZeroFields() { func (c *DoHProxyConfig) ensureNoZeroFields() {
ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{}) ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{})
} }

View File

@ -41,11 +41,19 @@ func (su *ScopeUnmarshaler) UnmarshalJSON(b []byte) error {
return fmt.Errorf("JSON should have been an object with one root key, either 'system_name' or 'group'") return fmt.Errorf("JSON should have been an object with one root key, either 'system_name' or 'group'")
} }
type OriginConfigUnmarshaler struct { // OriginConfigJSONHandler is a wrapper to serialize OriginConfig with type information, and deserialize JSON
// into an OriginConfig.
type OriginConfigJSONHandler struct {
OriginConfig OriginConfig OriginConfig OriginConfig
} }
func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error { func (ocjh *OriginConfigJSONHandler) MarshalJSON() ([]byte, error) {
marshaler := make(map[string]OriginConfig, 1)
marshaler[ocjh.OriginConfig.jsonType()] = ocjh.OriginConfig
return json.Marshal(marshaler)
}
func (ocjh *OriginConfigJSONHandler) UnmarshalJSON(b []byte) error {
var originJSON map[string]interface{} var originJSON map[string]interface{}
if err := json.Unmarshal(b, &originJSON); err != nil { if err := json.Unmarshal(b, &originJSON); err != nil {
return errors.Wrapf(err, "cannot unmarshal %s into originJSON", string(b)) return errors.Wrapf(err, "cannot unmarshal %s into originJSON", string(b))
@ -56,7 +64,7 @@ func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error {
if err := mapstructure.Decode(originConfig, httpOriginConfig); err != nil { if err := mapstructure.Decode(originConfig, httpOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into HTTPOriginConfig", originConfig) return errors.Wrapf(err, "cannot decode %+v into HTTPOriginConfig", originConfig)
} }
ocu.OriginConfig = httpOriginConfig ocjh.OriginConfig = httpOriginConfig
return nil return nil
} }
@ -65,7 +73,7 @@ func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error {
if err := mapstructure.Decode(originConfig, wsOriginConfig); err != nil { if err := mapstructure.Decode(originConfig, wsOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into WebSocketOriginConfig", originConfig) return errors.Wrapf(err, "cannot decode %+v into WebSocketOriginConfig", originConfig)
} }
ocu.OriginConfig = wsOriginConfig ocjh.OriginConfig = wsOriginConfig
return nil return nil
} }
@ -74,9 +82,20 @@ func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error {
if err := mapstructure.Decode(originConfig, helloWorldOriginConfig); err != nil { if err := mapstructure.Decode(originConfig, helloWorldOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into HelloWorldOriginConfig", originConfig) return errors.Wrapf(err, "cannot decode %+v into HelloWorldOriginConfig", originConfig)
} }
ocu.OriginConfig = helloWorldOriginConfig ocjh.OriginConfig = helloWorldOriginConfig
return nil return nil
} }
return fmt.Errorf("cannot unmarshal %s into OriginConfig", string(b)) return fmt.Errorf("cannot unmarshal %s into OriginConfig", string(b))
} }
// FallibleConfigMarshaler is a wrapper for FallibleConfig to implement custom marshal logic
type FallibleConfigMarshaler struct {
FallibleConfig FallibleConfig
}
func (fcm *FallibleConfigMarshaler) MarshalJSON() ([]byte, error) {
marshaler := make(map[string]FallibleConfig, 1)
marshaler[fcm.FallibleConfig.jsonType()] = fcm.FallibleConfig
return json.Marshal(marshaler)
}

View File

@ -70,50 +70,32 @@ func TestUnmarshalOrigin(t *testing.T) {
{ {
jsonLiteral: `{ jsonLiteral: `{
"Http":{ "Http":{
"url_string":"https://127.0.0.1:8080", "url_string":"https.example.com",
"tcp_keep_alive":30000000000, "tcp_keep_alive":7000000000,
"dial_dual_stack":true, "dial_dual_stack":true,
"tls_handshake_timeout":10000000000, "tls_handshake_timeout":11000000000,
"tls_verify":true, "tls_verify":true,
"origin_ca_pool":"", "origin_ca_pool":"/etc/cert.pem",
"origin_server_name":"", "origin_server_name":"secure.example.com",
"max_idle_connections":100, "max_idle_connections":19,
"idle_connection_timeout":90000000000, "idle_connection_timeout":17000000000,
"proxy_connection_timeout":90000000000, "proxy_connection_timeout":15000000000,
"expect_continue_timeout":90000000000, "expect_continue_timeout":21000000000,
"chunked_encoding":true "chunked_encoding":true
} }
}`, }`,
exceptedOriginConfig: &HTTPOriginConfig{ exceptedOriginConfig: sampleHTTPOriginConfig(),
URLString: "https://127.0.0.1:8080",
TCPKeepAlive: time.Second * 30,
DialDualStack: true,
TLSHandshakeTimeout: time.Second * 10,
TLSVerify: true,
OriginCAPool: "",
OriginServerName: "",
MaxIdleConnections: 100,
IdleConnectionTimeout: time.Second * 90,
ProxyConnectionTimeout: time.Second * 90,
ExpectContinueTimeout: time.Second * 90,
ChunkedEncoding: true,
},
}, },
{ {
jsonLiteral: `{ jsonLiteral: `{
"WebSocket":{ "WebSocket":{
"url_string":"https://127.0.0.1:9090", "url_string":"ssh://example.com",
"tls_verify":true, "tls_verify":true,
"origin_ca_pool":"", "origin_ca_pool":"/etc/cert.pem",
"origin_server_name":"ws.example.com" "origin_server_name":"secure.example.com"
} }
}`, }`,
exceptedOriginConfig: &WebSocketOriginConfig{ exceptedOriginConfig: sampleWebSocketOriginConfig(),
URLString: "https://127.0.0.1:9090",
TLSVerify: true,
OriginCAPool: "",
OriginServerName: "ws.example.com",
},
}, },
{ {
jsonLiteral: `{ jsonLiteral: `{
@ -124,11 +106,11 @@ func TestUnmarshalOrigin(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
originConfigJSON := strings.ReplaceAll(strings.ReplaceAll(test.jsonLiteral, "\n", ""), "\t", "") originConfigJSON := prettyToValidJSON(test.jsonLiteral)
var OriginConfigUnmarshaler OriginConfigUnmarshaler var OriginConfigJSONHandler OriginConfigJSONHandler
err := json.Unmarshal([]byte(originConfigJSON), &OriginConfigUnmarshaler) err := json.Unmarshal([]byte(originConfigJSON), &OriginConfigJSONHandler)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, test.exceptedOriginConfig, OriginConfigUnmarshaler.OriginConfig) assert.Equal(t, test.exceptedOriginConfig, OriginConfigJSONHandler.OriginConfig)
} }
} }
@ -176,7 +158,7 @@ func TestUnmarshalClientConfig(t *testing.T) {
}] }]
}` }`
// replace new line and tab // replace new line and tab
clientConfigJSON := strings.ReplaceAll(strings.ReplaceAll(prettyClientConfigJSON, "\n", ""), "\t", "") clientConfigJSON := prettyToValidJSON(prettyClientConfigJSON)
var clientConfig ClientConfig var clientConfig ClientConfig
err := json.Unmarshal([]byte(clientConfigJSON), &clientConfig) err := json.Unmarshal([]byte(clientConfigJSON), &clientConfig)
@ -211,7 +193,7 @@ func TestUnmarshalClientConfig(t *testing.T) {
reverseProxyConfig := ReverseProxyConfig{ reverseProxyConfig := ReverseProxyConfig{
TunnelHostname: "sdfjadk33.cftunnel.com", TunnelHostname: "sdfjadk33.cftunnel.com",
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{ OriginConfigJSONHandler: &OriginConfigJSONHandler{
OriginConfig: &HTTPOriginConfig{ OriginConfig: &HTTPOriginConfig{
URLString: "https://127.0.0.1:8080", URLString: "https://127.0.0.1:8080",
TCPKeepAlive: time.Second * 30, TCPKeepAlive: time.Second * 30,
@ -236,6 +218,125 @@ func TestUnmarshalClientConfig(t *testing.T) {
assert.Equal(t, reverseProxyConfig, *clientConfig.ReverseProxyConfigs[0]) assert.Equal(t, reverseProxyConfig, *clientConfig.ReverseProxyConfigs[0])
} }
func TestMarshalFallibleConfig(t *testing.T) {
tests := []struct {
fallibleConfig FallibleConfig
expctedJSONLiteral string
}{
{
fallibleConfig: sampleSupervisorConfig(),
expctedJSONLiteral: `{
"supervisor_config":{
"auto_update_frequency":75600000000000,
"metrics_update_frequency":660000000000,
"grace_period":31000000000
}
}`,
},
{
fallibleConfig: sampleEdgeConnectionConfig(),
expctedJSONLiteral: `{
"edge_connection_config":{
"num_ha_connections":49,
"heartbeat_interval":5000000000,
"timeout":9000000000,
"max_failed_heartbeats":9001,
"user_credential_path":"/Users/example/.cloudflared/cert.pem"
}
}`,
},
{
fallibleConfig: sampleDoHProxyConfig(),
expctedJSONLiteral: `{
"doh_proxy_config":{
"listen_host":"127.0.0.1",
"listen_port":53,
"upstreams":["1.1.1.1","1.0.0.1"]
}
}`,
},
{
fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()}
}),
expctedJSONLiteral: `{
"reverse_proxy_config":{
"tunnel_hostname":"mock-non-lb-tunnel.example.com",
"origin_config":{
"Http":{
"url_string":"https.example.com",
"tcp_keep_alive":7000000000,
"dial_dual_stack":true,
"tls_handshake_timeout":11000000000,
"tls_verify":true,
"origin_ca_pool":"/etc/cert.pem",
"origin_server_name":"secure.example.com",
"max_idle_connections":19,
"idle_connection_timeout":17000000000,
"proxy_connection_timeout":15000000000,
"expect_continue_timeout":21000000000,
"chunked_encoding":true
}
},
"retries":18,
"connection_timeout":5000000000,
"compression_quality":3
}
}`,
},
{
fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()}
}),
expctedJSONLiteral: `{
"reverse_proxy_config":{
"tunnel_hostname":"mock-non-lb-tunnel.example.com",
"origin_config":{
"WebSocket":{
"url_string":"ssh://example.com",
"tls_verify":true,
"origin_ca_pool":"/etc/cert.pem",
"origin_server_name":"secure.example.com"
}
},
"retries":18,
"connection_timeout":5000000000,
"compression_quality":3
}
}`,
},
{
fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{&HelloWorldOriginConfig{}}
}),
expctedJSONLiteral: `{
"reverse_proxy_config":{
"tunnel_hostname":"mock-non-lb-tunnel.example.com",
"origin_config":{
"HelloWorld":{}
},
"retries":18,
"connection_timeout":5000000000,
"compression_quality":3
}
}`,
},
}
for _, test := range tests {
b, err := json.Marshal(test.fallibleConfig)
assert.NoError(t, err)
assert.Equal(t, prettyToValidJSON(test.expctedJSONLiteral), string(b))
}
}
type prettyJSON string
func prettyToValidJSON(prettyJSON string) string {
return strings.ReplaceAll(strings.ReplaceAll(prettyJSON, "\n", ""), "\t", "")
}
func eqScope(s1, s2 Scope) bool { func eqScope(s1, s2 Scope) bool {
return s1.Value() == s2.Value() && s1.PostgresType() == s2.PostgresType() return s1.Value() == s2.Value() && s1.PostgresType() == s2.PostgresType()
} }