diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 8f757ce9..508e71d0 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -127,6 +127,92 @@ var ( "most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+ "any existing DNS records for this hostname.", overwriteDNSFlag) deprecatedClassicTunnelErr = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)") + nonSecretFlagsList = []string{ + "config", + "autoupdate-freq", + "no-autoupdate", + "metrics", + "pidfile", + "url", + "hello-world", + "socks5", + "proxy-connect-timeout", + "proxy-tls-timeout", + "proxy-tcp-keepalive", + "proxy-no-happy-eyeballs", + "proxy-keepalive-connections", + "proxy-keepalive-timeout", + "proxy-connection-timeout", + "proxy-expect-continue-timeout", + "http-host-header", + "origin-server-name", + "unix-socket", + "origin-ca-pool", + "no-tls-verify", + "no-chunked-encoding", + "http2-origin", + "management-hostname", + "service-op-ip", + "local-ssh-port", + "ssh-idle-timeout", + "ssh-max-timeout", + "bucket-name", + "region-name", + "s3-url-host", + "host-key-path", + "ssh-server", + "bastion", + "proxy-address", + "proxy-port", + "loglevel", + "transport-loglevel", + "logfile", + "log-directory", + "trace-output", + "proxy-dns", + "proxy-dns-port", + "proxy-dns-address", + "proxy-dns-upstream", + "proxy-dns-max-upstream-conns", + "proxy-dns-bootstrap", + "is-autoupdated", + "edge", + "region", + "edge-ip-version", + "edge-bind-address", + "cacert", + "hostname", + "id", + "lb-pool", + "api-url", + "metrics-update-freq", + "tag", + "heartbeat-interval", + "heartbeat-count", + "max-edge-addr-retries", + "retries", + "ha-connections", + "rpc-timeout", + "write-stream-timeout", + "quic-disable-pmtu-discovery", + "quic-connection-level-flow-control-limit", + "quic-stream-level-flow-control-limit", + "label", + "grace-period", + "compression-quality", + "use-reconnect-token", + "dial-edge-timeout", + "stdin-control", + "name", + "ui", + "quick-service", + "max-fetch-size", + "post-quantum", + "management-diagnostics", + "protocol", + "overwrite-dns", + "help", + } ) func Flags() []cli.Flag { @@ -465,7 +551,16 @@ func StartServer( observer.RegisterSink(tracker) readinessServer := metrics.NewReadyServer(clientID, tracker) - diagnosticHandler := diagnostic.NewDiagnosticHandler(log, 0, diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion), tunnelConfig.NamedTunnel.Credentials.TunnelID, clientID, tracker) + diagnosticHandler := diagnostic.NewDiagnosticHandler( + log, + 0, + diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion), + tunnelConfig.NamedTunnel.Credentials.TunnelID, + clientID, + tracker, + c, + nonSecretFlagsList, + ) metricsConfig := metrics.Config{ ReadyServer: readinessServer, DiagnosticHandler: diagnosticHandler, diff --git a/diagnostic/consts.go b/diagnostic/consts.go index 07cd8a7e..0fd2b574 100644 --- a/diagnostic/consts.go +++ b/diagnostic/consts.go @@ -3,8 +3,10 @@ package diagnostic import "time" const ( - defaultCollectorTimeout = time.Second * 10 // This const define the timeout value of a collector operation. - collectorField = "collector" // used for logging purposes - systemCollectorName = "system" // used for logging purposes - tunnelStateCollectorName = "tunnelState" // used for logging purposes + defaultCollectorTimeout = time.Second * 10 // This const define the timeout value of a collector operation. + collectorField = "collector" // used for logging purposes + systemCollectorName = "system" // used for logging purposes + tunnelStateCollectorName = "tunnelState" // used for logging purposes + configurationCollectorName = "configuration" // used for logging purposes + configurationKeyUid = "uid" ) diff --git a/diagnostic/handlers.go b/diagnostic/handlers.go index a3faef1d..6eb2ed74 100644 --- a/diagnostic/handlers.go +++ b/diagnostic/handlers.go @@ -4,21 +4,28 @@ import ( "context" "encoding/json" "net/http" + "os" + "path/filepath" + "strconv" "time" "github.com/google/uuid" "github.com/rs/zerolog" + "github.com/urfave/cli/v2" + "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/tunnelstate" ) type Handler struct { - log *zerolog.Logger - timeout time.Duration - systemCollector SystemCollector - tunnelID uuid.UUID - connectorID uuid.UUID - tracker *tunnelstate.ConnTracker + log *zerolog.Logger + timeout time.Duration + systemCollector SystemCollector + tunnelID uuid.UUID + connectorID uuid.UUID + tracker *tunnelstate.ConnTracker + cli *cli.Context + flagInclusionList []string } func NewDiagnosticHandler( @@ -28,6 +35,8 @@ func NewDiagnosticHandler( tunnelID uuid.UUID, connectorID uuid.UUID, tracker *tunnelstate.ConnTracker, + cli *cli.Context, + flagInclusionList []string, ) *Handler { logger := log.With().Logger() if timeout == 0 { @@ -35,12 +44,14 @@ func NewDiagnosticHandler( } return &Handler{ - log: &logger, - timeout: timeout, - systemCollector: systemCollector, - tunnelID: tunnelID, - connectorID: connectorID, - tracker: tracker, + log: &logger, + timeout: timeout, + systemCollector: systemCollector, + tunnelID: tunnelID, + connectorID: connectorID, + tracker: tracker, + cli: cli, + flagInclusionList: flagInclusionList, } } @@ -110,8 +121,77 @@ func (handler *Handler) TunnelStateHandler(writer http.ResponseWriter, _ *http.R } } -func writeResponse(writer http.ResponseWriter, bytes []byte, logger *zerolog.Logger) { - bytesWritten, err := writer.Write(bytes) +func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http.Request) { + log := handler.log.With().Str(collectorField, configurationCollectorName).Logger() + log.Info().Msg("Collection started") + + defer func() { + log.Info().Msg("Collection finished") + }() + + flagsNames := handler.cli.FlagNames() + flags := make(map[string]string, len(flagsNames)) + + for _, flag := range flagsNames { + value := handler.cli.String(flag) + + // empty values are not relevant + if value == "" { + continue + } + + // exclude flags that are sensitive + isIncluded := handler.isFlagIncluded(flag) + if !isIncluded { + continue + } + + switch flag { + case logger.LogDirectoryFlag: + case logger.LogFileFlag: + { + // the log directory may be relative to the instance thus it must be resolved + absolute, err := filepath.Abs(value) + if err != nil { + handler.log.Error().Err(err).Msgf("could not convert %s path to absolute", flag) + } else { + flags[flag] = absolute + } + } + default: + flags[flag] = value + } + } + + // The UID is included to help the + // diagnostic tool to understand + // if this instance is managed or not. + flags[configurationKeyUid] = strconv.Itoa(os.Getuid()) + encoder := json.NewEncoder(writer) + + err := encoder.Encode(flags) + if err != nil { + handler.log.Error().Err(err).Msgf("error occurred whilst serializing response") + writer.WriteHeader(http.StatusInternalServerError) + } +} + +func (handler *Handler) isFlagIncluded(flag string) bool { + isIncluded := false + + for _, include := range handler.flagInclusionList { + if include == flag { + isIncluded = true + + break + } + } + + return isIncluded +} + +func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) { + bytesWritten, err := w.Write(bytes) if err != nil { logger.Error().Err(err).Msg("error occurred writing response") } else if bytesWritten != len(bytes) { diff --git a/diagnostic/handlers_test.go b/diagnostic/handlers_test.go index 04ec60db..7e8222da 100644 --- a/diagnostic/handlers_test.go +++ b/diagnostic/handlers_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "flag" "io" "net" "net/http" @@ -14,6 +15,7 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/diagnostic" @@ -28,6 +30,21 @@ const ( errorKey = "errkey" ) +func buildCliContext(t *testing.T, flags map[string]string) *cli.Context { + t.Helper() + + flagSet := flag.NewFlagSet("", flag.PanicOnError) + ctx := cli.NewContext(cli.NewApp(), flagSet, nil) + + for k, v := range flags { + flagSet.String(k, v, "") + err := ctx.Set(k, v) + require.NoError(t, err) + } + + return ctx +} + func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker { t.Helper() @@ -45,6 +62,7 @@ func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnecti return tracker } + func setCtxValuesForSystemCollector( systemInfo *diagnostic.SystemInformation, rawInfo string, @@ -104,7 +122,8 @@ func TestSystemHandler(t *testing.T) { for _, tCase := range tests { t.Run(tCase.name, func(t *testing.T) { t.Parallel() - handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil) + + handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, nil, nil) recorder := httptest.NewRecorder() ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err) request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil) @@ -162,7 +181,7 @@ func TestTunnelStateHandler(t *testing.T) { t.Run(tCase.name, func(t *testing.T) { t.Parallel() tracker := newTrackerFromConns(t, tCase.connections) - handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tCase.tunnelID, tCase.clientID, tracker) + handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tCase.tunnelID, tCase.clientID, tracker, nil, nil) recorder := httptest.NewRecorder() handler.TunnelStateHandler(recorder, nil) decoder := json.NewDecoder(recorder.Body) @@ -182,3 +201,59 @@ func TestTunnelStateHandler(t *testing.T) { }) } } + +func TestConfigurationHandler(t *testing.T) { + t.Parallel() + + log := zerolog.Nop() + + tests := []struct { + name string + flags map[string]string + expected map[string]string + }{ + { + name: "empty cli", + flags: make(map[string]string), + expected: map[string]string{ + "uid": "0", + }, + }, + { + name: "cli with flags", + flags: map[string]string{ + "a": "a", + "b": "a", + "c": "a", + "d": "a", + }, + expected: map[string]string{ + "b": "a", + "c": "a", + "d": "a", + "uid": "0", + }, + }, + } + + for _, tCase := range tests { + t.Run(tCase.name, func(t *testing.T) { + var response map[string]string + + t.Parallel() + ctx := buildCliContext(t, tCase.flags) + handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, ctx, []string{"b", "c", "d"}) + recorder := httptest.NewRecorder() + handler.ConfigurationHandler(recorder, nil) + decoder := json.NewDecoder(recorder.Body) + err := decoder.Decode(&response) + require.NoError(t, err) + _, ok := response["uid"] + assert.True(t, ok) + delete(tCase.expected, "uid") + delete(response, "uid") + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, tCase.expected, response) + }) + } +} diff --git a/logger/configuration.go b/logger/configuration.go index 79dc6220..d406a220 100644 --- a/logger/configuration.go +++ b/logger/configuration.go @@ -76,10 +76,10 @@ func CreateConfig( var file *FileConfig var rolling *RollingConfig - if rollingLogPath != "" { - rolling = createRollingConfig(rollingLogPath) - } else if nonRollingLogFilePath != "" { + if nonRollingLogFilePath != "" { file = createFileConfig(nonRollingLogFilePath) + } else if rollingLogPath != "" { + rolling = createRollingConfig(rollingLogPath) } if minLevel == "" { diff --git a/metrics/metrics.go b/metrics/metrics.go index 5b8ae2ff..4fdbc2b2 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -94,6 +94,7 @@ func newMetricsHandler( }) } + router.HandleFunc("/diag/configuration", config.DiagnosticHandler.ConfigurationHandler) router.HandleFunc("/diag/tunnel", config.DiagnosticHandler.TunnelStateHandler) router.HandleFunc("/diag/system", config.DiagnosticHandler.SystemHandler)