From 5352b3cf0489e6568e148f40cd6e89660a7233cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Oliveirinha?= Date: Tue, 8 Mar 2022 16:10:24 +0000 Subject: [PATCH] TUN-5801: Add custom wrapper for OriginConfig for JSON serde --- config/configuration.go | 45 +++++++++++++++++++---- config/configuration_test.go | 58 ++++++++++++++++++++++++------ ingress/config.go | 16 ++++----- ingress/config_test.go | 24 ++++++------- orchestration/orchestrator_test.go | 10 +++--- 5 files changed, 111 insertions(+), 42 deletions(-) diff --git a/config/configuration.go b/config/configuration.go index 8b16d4fe..ffd3eb7c 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -1,12 +1,14 @@ package config import ( + "encoding/json" "fmt" "io" "net/url" "os" "path/filepath" "runtime" + "strconv" "time" homedir "github.com/mitchellh/go-homedir" @@ -49,7 +51,7 @@ func DefaultConfigDirectory() string { path := os.Getenv("CFDPATH") if path == "" { path = filepath.Join(os.Getenv("ProgramFiles(x86)"), "cloudflared") - if _, err := os.Stat(path); os.IsNotExist(err) { //doesn't exist, so return an empty failure string + if _, err := os.Stat(path); os.IsNotExist(err) { // doesn't exist, so return an empty failure string return "" } } @@ -138,7 +140,7 @@ func FindOrCreateConfigPath() string { defer file.Close() logDir := DefaultLogDirectory() - _ = os.MkdirAll(logDir, os.ModePerm) //try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs + _ = os.MkdirAll(logDir, os.ModePerm) // try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs c := Root{ LogDirectory: logDir, @@ -190,17 +192,17 @@ type UnvalidatedIngressRule struct { // - To specify a time.Duration in json, use int64 of the nanoseconds type OriginRequestConfig struct { // HTTP proxy timeout for establishing a new connection - ConnectTimeout *time.Duration `yaml:"connectTimeout" json:"connectTimeout"` + ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout"` // HTTP proxy timeout for completing a TLS handshake - TLSTimeout *time.Duration `yaml:"tlsTimeout" json:"tlsTimeout"` + TLSTimeout *CustomDuration `yaml:"tlsTimeout" json:"tlsTimeout"` // HTTP proxy TCP keepalive duration - TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive" json:"tcpKeepAlive"` + TCPKeepAlive *CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive"` // HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback NoHappyEyeballs *bool `yaml:"noHappyEyeballs" json:"noHappyEyeballs"` // HTTP proxy maximum keepalive connection pool size KeepAliveConnections *int `yaml:"keepAliveConnections" json:"keepAliveConnections"` // HTTP proxy timeout for closing an idle connection - KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout" json:"keepAliveTimeout"` + KeepAliveTimeout *CustomDuration `yaml:"keepAliveTimeout" json:"keepAliveTimeout"` // Sets the HTTP Host header for the local webserver. HTTPHostHeader *string `yaml:"httpHostHeader" json:"httpHostHeader"` // Hostname on the origin server certificate. @@ -399,3 +401,34 @@ func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (settings *configFileSe return &configuration, warnings, nil } + +// A CustomDuration is a Duration that has custom serialization for JSON. +// JSON in Javascript assumes that int fields are 32 bits and Duration fields are deserialized assuming that numbers +// are in nanoseconds, which in 32bit integers limits to just 2 seconds. +// This type assumes that when serializing/deserializing from JSON, that the number is in seconds, while it maintains +// the YAML serde assumptions. +type CustomDuration struct { + time.Duration +} + +func (s *CustomDuration) MarshalJSON() ([]byte, error) { + return json.Marshal(s.Duration.Seconds()) +} + +func (s *CustomDuration) UnmarshalJSON(data []byte) error { + seconds, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return err + } + + s.Duration = time.Duration(seconds * int64(time.Second)) + return nil +} + +func (s *CustomDuration) MarshalYAML() (interface{}, error) { + return s.Duration.String(), nil +} + +func (s *CustomDuration) UnmarshalYAML(unmarshal func(interface{}) error) error { + return unmarshal(&s.Duration) +} diff --git a/config/configuration_test.go b/config/configuration_test.go index 11db35db..d870913d 100644 --- a/config/configuration_test.go +++ b/config/configuration_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" yaml "gopkg.in/yaml.v2" ) @@ -110,14 +111,13 @@ counters: } -func TestUnmarshalOriginRequestConfig(t *testing.T) { - raw := []byte(` +var rawConfig = []byte(` { - "connectTimeout": 10000000000, - "tlsTimeout": 30000000000, - "tcpKeepAlive": 30000000000, + "connectTimeout": 10, + "tlsTimeout": 30, + "tcpKeepAlive": 30, "noHappyEyeballs": true, - "keepAliveTimeout": 60000000000, + "keepAliveTimeout": 60, "keepAliveConnections": 10, "httpHostHeader": "app.tunnel.com", "originServerName": "app.tunnel.com", @@ -142,13 +142,41 @@ func TestUnmarshalOriginRequestConfig(t *testing.T) { ] } `) + +func TestMarshalUnmarshalOriginRequest(t *testing.T) { + testCases := []struct { + name string + marshalFunc func(in interface{}) (out []byte, err error) + unMarshalFunc func(in []byte, out interface{}) (err error) + baseUnit time.Duration + }{ + {"json", json.Marshal, json.Unmarshal, time.Second}, + {"yaml", yaml.Marshal, yaml.Unmarshal, time.Nanosecond}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assertConfig(t, tc.marshalFunc, tc.unMarshalFunc, tc.baseUnit) + }) + } +} + +func assertConfig( + t *testing.T, + marshalFunc func(in interface{}) (out []byte, err error), + unMarshalFunc func(in []byte, out interface{}) (err error), + baseUnit time.Duration, +) { var config OriginRequestConfig - assert.NoError(t, json.Unmarshal(raw, &config)) - assert.Equal(t, time.Second*10, *config.ConnectTimeout) - assert.Equal(t, time.Second*30, *config.TLSTimeout) - assert.Equal(t, time.Second*30, *config.TCPKeepAlive) + var config2 OriginRequestConfig + + assert.NoError(t, unMarshalFunc(rawConfig, &config)) + + assert.Equal(t, baseUnit*10, config.ConnectTimeout.Duration) + assert.Equal(t, baseUnit*30, config.TLSTimeout.Duration) + assert.Equal(t, baseUnit*30, config.TCPKeepAlive.Duration) assert.Equal(t, true, *config.NoHappyEyeballs) - assert.Equal(t, time.Second*60, *config.KeepAliveTimeout) + assert.Equal(t, baseUnit*60, config.KeepAliveTimeout.Duration) assert.Equal(t, 10, *config.KeepAliveConnections) assert.Equal(t, "app.tunnel.com", *config.HTTPHostHeader) assert.Equal(t, "app.tunnel.com", *config.OriginServerName) @@ -176,4 +204,12 @@ func TestUnmarshalOriginRequestConfig(t *testing.T) { }, } assert.Equal(t, ipRules, config.IPRules) + + // validate that serializing and deserializing again matches the deserialization from raw string + result, err := marshalFunc(config) + require.NoError(t, err) + err = unMarshalFunc(result, &config2) + require.NoError(t, err) + + require.Equal(t, config2, config) } diff --git a/ingress/config.go b/ingress/config.go index b389bc9b..f8284b43 100644 --- a/ingress/config.go +++ b/ingress/config.go @@ -158,13 +158,13 @@ func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig { ProxyAddress: defaultProxyAddress, } if c.ConnectTimeout != nil { - out.ConnectTimeout = *c.ConnectTimeout + out.ConnectTimeout = c.ConnectTimeout.Duration } if c.TLSTimeout != nil { - out.TLSTimeout = *c.TLSTimeout + out.TLSTimeout = c.TLSTimeout.Duration } if c.TCPKeepAlive != nil { - out.TCPKeepAlive = *c.TCPKeepAlive + out.TCPKeepAlive = c.TCPKeepAlive.Duration } if c.NoHappyEyeballs != nil { out.NoHappyEyeballs = *c.NoHappyEyeballs @@ -173,7 +173,7 @@ func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig { out.KeepAliveConnections = *c.KeepAliveConnections } if c.KeepAliveTimeout != nil { - out.KeepAliveTimeout = *c.KeepAliveTimeout + out.KeepAliveTimeout = c.KeepAliveTimeout.Duration } if c.HTTPHostHeader != nil { out.HTTPHostHeader = *c.HTTPHostHeader @@ -257,13 +257,13 @@ type OriginRequestConfig struct { func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) { if val := overrides.ConnectTimeout; val != nil { - defaults.ConnectTimeout = *val + defaults.ConnectTimeout = val.Duration } } func (defaults *OriginRequestConfig) setTLSTimeout(overrides config.OriginRequestConfig) { if val := overrides.TLSTimeout; val != nil { - defaults.TLSTimeout = *val + defaults.TLSTimeout = val.Duration } } @@ -281,13 +281,13 @@ func (defaults *OriginRequestConfig) setKeepAliveConnections(overrides config.Or func (defaults *OriginRequestConfig) setKeepAliveTimeout(overrides config.OriginRequestConfig) { if val := overrides.KeepAliveTimeout; val != nil { - defaults.KeepAliveTimeout = *val + defaults.KeepAliveTimeout = val.Duration } } func (defaults *OriginRequestConfig) setTCPKeepAlive(overrides config.OriginRequestConfig) { if val := overrides.TCPKeepAlive; val != nil { - defaults.TCPKeepAlive = *val + defaults.TCPKeepAlive = val.Duration } } diff --git a/ingress/config_test.go b/ingress/config_test.go index ff0a3c0a..e520c9bf 100644 --- a/ingress/config_test.go +++ b/ingress/config_test.go @@ -191,12 +191,12 @@ ingress: rawConfig := []byte(` { "originRequest": { - "connectTimeout": 60000000000, - "tlsTimeout": 1000000000, + "connectTimeout": 60, + "tlsTimeout": 1, "noHappyEyeballs": true, - "tcpKeepAlive": 1000000000, + "tcpKeepAlive": 1, "keepAliveConnections": 1, - "keepAliveTimeout": 1000000000, + "keepAliveTimeout": 1, "httpHostHeader": "abc", "originServerName": "a1", "caPool": "/tmp/path0", @@ -228,12 +228,12 @@ ingress: "hostname": "*", "service": "https://localhost:8001", "originRequest": { - "connectTimeout": 120000000000, - "tlsTimeout": 2000000000, + "connectTimeout": 120, + "tlsTimeout": 2, "noHappyEyeballs": false, - "tcpKeepAlive": 2000000000, + "tcpKeepAlive": 2, "keepAliveConnections": 2, - "keepAliveTimeout": 2000000000, + "keepAliveTimeout": 2, "httpHostHeader": "def", "originServerName": "b2", "caPool": "/tmp/path1", @@ -360,12 +360,12 @@ ingress: "hostname": "*", "service": "https://localhost:8001", "originRequest": { - "connectTimeout": 120000000000, - "tlsTimeout": 2000000000, + "connectTimeout": 120, + "tlsTimeout": 2, "noHappyEyeballs": false, - "tcpKeepAlive": 2000000000, + "tcpKeepAlive": 2, "keepAliveConnections": 2, - "keepAliveTimeout": 2000000000, + "keepAliveTimeout": 2, "httpHostHeader": "def", "originServerName": "b2", "caPool": "/tmp/path1", diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index 4fe6c7b1..78de169a 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -58,7 +58,7 @@ func TestUpdateConfiguration(t *testing.T) { { "unknown_field": "not_deserialized", "originRequest": { - "connectTimeout": 90000000000, + "connectTimeout": 90, "noHappyEyeballs": true }, "ingress": [ @@ -68,7 +68,7 @@ func TestUpdateConfiguration(t *testing.T) { "service": "http://192.16.19.1:443", "originRequest": { "noTLSVerify": true, - "connectTimeout": 10000000000 + "connectTimeout": 10 } }, { @@ -76,7 +76,7 @@ func TestUpdateConfiguration(t *testing.T) { "service": "http://172.32.20.6:80", "originRequest": { "noTLSVerify": true, - "connectTimeout": 30000000000 + "connectTimeout": 30 } }, { @@ -192,7 +192,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) { configJSONV1 = []byte(fmt.Sprintf(` { "originRequest": { - "connectTimeout": 90000000000, + "connectTimeout": 90, "noHappyEyeballs": true }, "ingress": [ @@ -201,7 +201,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) { "service": "%s", "originRequest": { "httpHostHeader": "%s", - "connectTimeout": 10000000000 + "connectTimeout": 10 } }, {