From ba9f28ef43fdae3b2361f2fdcd49cf690f772bf4 Mon Sep 17 00:00:00 2001 From: Luis Neto Date: Wed, 11 Dec 2024 01:29:20 -0800 Subject: [PATCH] TUN-8786: calculate cli flags once for the diagnostic procedure ## Summary The flags were always being computed when their value is static. Closes TUN-8786 --- cmd/cloudflared/tunnel/cmd.go | 48 ++++++++++++++- diagnostic/diagnostic_utils_test.go | 2 +- diagnostic/handlers.go | 96 ++++++----------------------- diagnostic/handlers_test.go | 38 +++--------- 4 files changed, 75 insertions(+), 109 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 52b75fc9..cdf08bdf 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "os" + "path/filepath" "runtime/trace" "strings" "sync" @@ -560,6 +561,7 @@ func StartServer( } readinessServer := metrics.NewReadyServer(clientID, tracker) + cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList) diagnosticHandler := diagnostic.NewDiagnosticHandler( log, 0, @@ -567,8 +569,7 @@ func StartServer( tunnelConfig.NamedTunnel.Credentials.TunnelID, clientID, tracker, - c, - nonSecretFlagsList, + cliFlags, sources, ) metricsConfig := metrics.Config{ @@ -1309,3 +1310,46 @@ reconnect [delay] } } } + +func nonSecretCliFlags(log *zerolog.Logger, cli *cli.Context, flagInclusionList []string) map[string]string { + flagsNames := cli.FlagNames() + flags := make(map[string]string, len(flagsNames)) + + for _, flag := range flagsNames { + value := cli.String(flag) + + if value == "" { + continue + } + + isIncluded := isFlagIncluded(flagInclusionList, flag) + if !isIncluded { + continue + } + + switch flag { + case logger.LogDirectoryFlag, logger.LogFileFlag: + { + absolute, err := filepath.Abs(value) + if err != nil { + log.Error().Err(err).Msgf("could not convert %s path to absolute", flag) + } else { + flags[flag] = absolute + } + } + default: + flags[flag] = value + } + } + return flags +} + +func isFlagIncluded(flagInclusionList []string, flag string) bool { + for _, include := range flagInclusionList { + if include == flag { + return true + } + } + + return false +} diff --git a/diagnostic/diagnostic_utils_test.go b/diagnostic/diagnostic_utils_test.go index ecd2001a..f0f5a6a3 100644 --- a/diagnostic/diagnostic_utils_test.go +++ b/diagnostic/diagnostic_utils_test.go @@ -25,7 +25,7 @@ func helperCreateServer(t *testing.T, listeners *gracenet.Net, tunnelID uuid.UUI require.NoError(t, err) log := zerolog.Nop() tracker := tunnelstate.NewConnTracker(&log) - handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, nil, []string{}, []string{}) + handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, map[string]string{}, []string{}) router := http.NewServeMux() router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler) server := &http.Server{ diff --git a/diagnostic/handlers.go b/diagnostic/handlers.go index a49d4aa3..1d9ef4f6 100644 --- a/diagnostic/handlers.go +++ b/diagnostic/handlers.go @@ -5,28 +5,24 @@ import ( "encoding/json" "net/http" "os" - "path/filepath" "strconv" "time" "github.com/google/uuid" "github.com/rs/zerolog" - "github.com/urfave/cli/v2" - "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/tunnelstate" ) type Handler struct { - log *zerolog.Logger - timeout time.Duration - systemCollector SystemCollector - tunnelID uuid.UUID - connectorID uuid.UUID - tracker *tunnelstate.ConnTracker - cli *cli.Context - flagInclusionList []string - icmpSources []string + log *zerolog.Logger + timeout time.Duration + systemCollector SystemCollector + tunnelID uuid.UUID + connectorID uuid.UUID + tracker *tunnelstate.ConnTracker + cliFlags map[string]string + icmpSources []string } func NewDiagnosticHandler( @@ -36,8 +32,7 @@ func NewDiagnosticHandler( tunnelID uuid.UUID, connectorID uuid.UUID, tracker *tunnelstate.ConnTracker, - cli *cli.Context, - flagInclusionList []string, + cliFlags map[string]string, icmpSources []string, ) *Handler { logger := log.With().Logger() @@ -45,16 +40,16 @@ func NewDiagnosticHandler( timeout = defaultCollectorTimeout } + cliFlags[configurationKeyUID] = strconv.Itoa(os.Getuid()) return &Handler{ - log: &logger, - timeout: timeout, - systemCollector: systemCollector, - tunnelID: tunnelID, - connectorID: connectorID, - tracker: tracker, - cli: cli, - flagInclusionList: flagInclusionList, - icmpSources: icmpSources, + log: &logger, + timeout: timeout, + systemCollector: systemCollector, + tunnelID: tunnelID, + connectorID: connectorID, + tracker: tracker, + cliFlags: cliFlags, + icmpSources: icmpSources, } } @@ -140,68 +135,15 @@ func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http log.Info().Msg("Collection finished") }() - flagsNames := handler.cli.FlagNames() - flags := make(map[string]string, len(flagsNames)) - - for _, flag := range flagsNames { - value := handler.cli.String(flag) - - // empty values are not relevant - if value == "" { - continue - } - - // exclude flags that are sensitive - isIncluded := handler.isFlagIncluded(flag) - if !isIncluded { - continue - } - - switch flag { - case logger.LogDirectoryFlag: - fallthrough - case logger.LogFileFlag: - { - // the log directory may be relative to the instance thus it must be resolved - absolute, err := filepath.Abs(value) - if err != nil { - handler.log.Error().Err(err).Msgf("could not convert %s path to absolute", flag) - } else { - flags[flag] = absolute - } - } - default: - flags[flag] = value - } - } - - // The UID is included to help the - // diagnostic tool to understand - // if this instance is managed or not. - flags[configurationKeyUID] = strconv.Itoa(os.Getuid()) encoder := json.NewEncoder(writer) - err := encoder.Encode(flags) + err := encoder.Encode(handler.cliFlags) if err != nil { handler.log.Error().Err(err).Msgf("error occurred whilst serializing response") writer.WriteHeader(http.StatusInternalServerError) } } -func (handler *Handler) isFlagIncluded(flag string) bool { - isIncluded := false - - for _, include := range handler.flagInclusionList { - if include == flag { - isIncluded = true - - break - } - } - - return isIncluded -} - func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) { bytesWritten, err := w.Write(bytes) if err != nil { diff --git a/diagnostic/handlers_test.go b/diagnostic/handlers_test.go index fd2b9c27..3a300bff 100644 --- a/diagnostic/handlers_test.go +++ b/diagnostic/handlers_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "flag" "io" "net" "net/http" @@ -15,7 +14,6 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/urfave/cli/v2" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/diagnostic" @@ -30,21 +28,6 @@ const ( errorKey = "errkey" ) -func buildCliContext(t *testing.T, flags map[string]string) *cli.Context { - t.Helper() - - flagSet := flag.NewFlagSet("", flag.PanicOnError) - ctx := cli.NewContext(cli.NewApp(), flagSet, nil) - - for k, v := range flags { - flagSet.String(k, v, "") - err := ctx.Set(k, v) - require.NoError(t, err) - } - - return ctx -} - func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker { t.Helper() @@ -80,7 +63,6 @@ func (*SystemCollectorMock) Collect(ctx context.Context) (*diagnostic.SystemInfo si, _ := ctx.Value(systemInformationKey).(*diagnostic.SystemInformation) ri, _ := ctx.Value(rawInformationKey).(string) err, _ := ctx.Value(errorKey).(error) - return si, ri, err } @@ -122,8 +104,7 @@ func TestSystemHandler(t *testing.T) { for _, tCase := range tests { t.Run(tCase.name, func(t *testing.T) { t.Parallel() - - handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, nil, nil, nil) + handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, map[string]string{}, nil) recorder := httptest.NewRecorder() ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err) request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil) @@ -190,8 +171,7 @@ func TestTunnelStateHandler(t *testing.T) { tCase.tunnelID, tCase.clientID, tracker, - nil, - nil, + map[string]string{}, tCase.icmpSources, ) recorder := httptest.NewRecorder() @@ -230,10 +210,10 @@ func TestConfigurationHandler(t *testing.T) { { name: "cli with flags", flags: map[string]string{ - "a": "a", - "b": "a", - "c": "a", - "d": "a", + "b": "a", + "c": "a", + "d": "a", + "uid": "0", }, expected: map[string]string{ "b": "a", @@ -246,11 +226,11 @@ func TestConfigurationHandler(t *testing.T) { for _, tCase := range tests { t.Run(tCase.name, func(t *testing.T) { + t.Parallel() + var response map[string]string - t.Parallel() - ctx := buildCliContext(t, tCase.flags) - handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, ctx, []string{"b", "c", "d"}, nil) + handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil) recorder := httptest.NewRecorder() handler.ConfigurationHandler(recorder, nil) decoder := json.NewDecoder(recorder.Body)