diff --git a/cmd/cloudflared/config/configuration.go b/cmd/cloudflared/config/configuration.go index 035f556d..82ce2255 100644 --- a/cmd/cloudflared/config/configuration.go +++ b/cmd/cloudflared/config/configuration.go @@ -263,8 +263,16 @@ func (c *configFileSettings) String(name string) (string, error) { func (c *configFileSettings) StringSlice(name string) ([]string, error) { if raw, ok := c.Settings[name]; ok { - if v, ok := raw.([]string); ok { - return v, nil + if slice, ok := raw.([]interface{}); ok { + strSlice := make([]string, len(slice)) + for i, v := range slice { + str, ok := v.(string) + if !ok { + return nil, fmt.Errorf("expected string, found %T for %v", i, v) + } + strSlice[i] = str + } + return strSlice, nil } return nil, fmt.Errorf("expected string slice found %T for %s", raw, name) } @@ -273,6 +281,17 @@ func (c *configFileSettings) StringSlice(name string) ([]string, error) { func (c *configFileSettings) IntSlice(name string) ([]int, error) { if raw, ok := c.Settings[name]; ok { + if slice, ok := raw.([]interface{}); ok { + intSlice := make([]int, len(slice)) + for i, v := range slice { + str, ok := v.(int) + if !ok { + return nil, fmt.Errorf("expected int, found %T for %v ", v, v) + } + intSlice[i] = str + } + return intSlice, nil + } if v, ok := raw.([]int); ok { return v, nil } @@ -322,7 +341,6 @@ func ReadConfigFile(c *cli.Context, log logger.Service) (*configFileSettings, er if err := yaml.NewDecoder(file).Decode(&configuration); err != nil { return nil, err } - configuration.sourceFile = configFile return &configuration, nil } diff --git a/cmd/cloudflared/config/configuration_test.go b/cmd/cloudflared/config/configuration_test.go new file mode 100644 index 00000000..37357539 --- /dev/null +++ b/cmd/cloudflared/config/configuration_test.go @@ -0,0 +1,76 @@ +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" +) + +func TestConfigFileSettings(t *testing.T) { + var ( + firstIngress = UnvalidatedIngressRule{ + Hostname: "tunnel1.example.com", + Path: "/id", + Service: "https://localhost:8000", + } + secondIngress = UnvalidatedIngressRule{ + Hostname: "*", + Path: "", + Service: "https://localhost:8001", + } + ) + rawYAML := ` +tunnel: config-file-test +ingress: + - hostname: tunnel1.example.com + path: /id + service: https://localhost:8000 + - hostname: "*" + service: https://localhost:8001 +retries: 5 +grace-period: 30s +percentage: 3.14 +hostname: example.com +tag: + - test + - central-1 +counters: + - 123 + - 456 +` + var config configFileSettings + err := yaml.Unmarshal([]byte(rawYAML), &config) + assert.NoError(t, err) + + assert.Equal(t, "config-file-test", config.TunnelID) + assert.Equal(t, firstIngress, config.Ingress[0]) + assert.Equal(t, secondIngress, config.Ingress[1]) + + retries, err := config.Int("retries") + assert.NoError(t, err) + assert.Equal(t, 5, retries) + + gracePeriod, err := config.Duration("grace-period") + assert.NoError(t, err) + assert.Equal(t, time.Second*30, gracePeriod) + + percentage, err := config.Float64("percentage") + assert.NoError(t, err) + assert.Equal(t, 3.14, percentage) + + hostname, err := config.String("hostname") + assert.NoError(t, err) + assert.Equal(t, "example.com", hostname) + + tags, err := config.StringSlice("tag") + assert.NoError(t, err) + assert.Equal(t, "test", tags[0]) + assert.Equal(t, "central-1", tags[1]) + + counters, err := config.IntSlice("counters") + assert.NoError(t, err) + assert.Equal(t, 123, counters[0]) + assert.Equal(t, 456, counters[1]) +} diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 71aa8660..0de66fcd 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -537,25 +537,24 @@ func forceSetFlag(c *cli.Context, name, value string) { } func SetFlagsFromConfigFile(c *cli.Context) error { + const exitCode = 1 log, err := createLogger(c, false, false) if err != nil { return cliutil.PrintLoggerSetupError("error setting up logger", err) } - inputSource, err := config.ReadConfigFile(c, log) if err != nil { if err == config.ErrNoConfigFile { return nil } - return err + return cli.Exit(err, exitCode) } targetFlags := c.Command.Flags if c.Command.Name == "" { targetFlags = c.App.Flags } if err := altsrc.ApplyInputSourceValues(c, inputSource, targetFlags); err != nil { - log.Errorf("Cannot load configuration from %s: %v", inputSource.Source(), err) - return err + return cli.Exit(err, exitCode) } return nil }