TUN-8786: calculate cli flags once for the diagnostic procedure
## Summary The flags were always being computed when their value is static. Closes TUN-8786
This commit is contained in:
		
							parent
							
								
									77b99cf5fe
								
							
						
					
					
						commit
						ba9f28ef43
					
				|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
| 	"runtime/trace" | 	"runtime/trace" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
|  | @ -560,6 +561,7 @@ func StartServer( | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		readinessServer := metrics.NewReadyServer(clientID, tracker) | 		readinessServer := metrics.NewReadyServer(clientID, tracker) | ||||||
|  | 		cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList) | ||||||
| 		diagnosticHandler := diagnostic.NewDiagnosticHandler( | 		diagnosticHandler := diagnostic.NewDiagnosticHandler( | ||||||
| 			log, | 			log, | ||||||
| 			0, | 			0, | ||||||
|  | @ -567,8 +569,7 @@ func StartServer( | ||||||
| 			tunnelConfig.NamedTunnel.Credentials.TunnelID, | 			tunnelConfig.NamedTunnel.Credentials.TunnelID, | ||||||
| 			clientID, | 			clientID, | ||||||
| 			tracker, | 			tracker, | ||||||
| 			c, | 			cliFlags, | ||||||
| 			nonSecretFlagsList, |  | ||||||
| 			sources, | 			sources, | ||||||
| 		) | 		) | ||||||
| 		metricsConfig := metrics.Config{ | 		metricsConfig := metrics.Config{ | ||||||
|  | @ -1309,3 +1310,46 @@ reconnect [delay] | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func nonSecretCliFlags(log *zerolog.Logger, cli *cli.Context, flagInclusionList []string) map[string]string { | ||||||
|  | 	flagsNames := cli.FlagNames() | ||||||
|  | 	flags := make(map[string]string, len(flagsNames)) | ||||||
|  | 
 | ||||||
|  | 	for _, flag := range flagsNames { | ||||||
|  | 		value := cli.String(flag) | ||||||
|  | 
 | ||||||
|  | 		if value == "" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		isIncluded := isFlagIncluded(flagInclusionList, flag) | ||||||
|  | 		if !isIncluded { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		switch flag { | ||||||
|  | 		case logger.LogDirectoryFlag, logger.LogFileFlag: | ||||||
|  | 			{ | ||||||
|  | 				absolute, err := filepath.Abs(value) | ||||||
|  | 				if err != nil { | ||||||
|  | 					log.Error().Err(err).Msgf("could not convert %s path to absolute", flag) | ||||||
|  | 				} else { | ||||||
|  | 					flags[flag] = absolute | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		default: | ||||||
|  | 			flags[flag] = value | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return flags | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func isFlagIncluded(flagInclusionList []string, flag string) bool { | ||||||
|  | 	for _, include := range flagInclusionList { | ||||||
|  | 		if include == flag { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -25,7 +25,7 @@ func helperCreateServer(t *testing.T, listeners *gracenet.Net, tunnelID uuid.UUI | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 	log := zerolog.Nop() | 	log := zerolog.Nop() | ||||||
| 	tracker := tunnelstate.NewConnTracker(&log) | 	tracker := tunnelstate.NewConnTracker(&log) | ||||||
| 	handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, nil, []string{}, []string{}) | 	handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, map[string]string{}, []string{}) | ||||||
| 	router := http.NewServeMux() | 	router := http.NewServeMux() | ||||||
| 	router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler) | 	router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler) | ||||||
| 	server := &http.Server{ | 	server := &http.Server{ | ||||||
|  |  | ||||||
|  | @ -5,28 +5,24 @@ import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" |  | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| 	"github.com/rs/zerolog" | 	"github.com/rs/zerolog" | ||||||
| 	"github.com/urfave/cli/v2" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/cloudflare/cloudflared/logger" |  | ||||||
| 	"github.com/cloudflare/cloudflared/tunnelstate" | 	"github.com/cloudflare/cloudflared/tunnelstate" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Handler struct { | type Handler struct { | ||||||
| 	log               *zerolog.Logger | 	log             *zerolog.Logger | ||||||
| 	timeout           time.Duration | 	timeout         time.Duration | ||||||
| 	systemCollector   SystemCollector | 	systemCollector SystemCollector | ||||||
| 	tunnelID          uuid.UUID | 	tunnelID        uuid.UUID | ||||||
| 	connectorID       uuid.UUID | 	connectorID     uuid.UUID | ||||||
| 	tracker           *tunnelstate.ConnTracker | 	tracker         *tunnelstate.ConnTracker | ||||||
| 	cli               *cli.Context | 	cliFlags        map[string]string | ||||||
| 	flagInclusionList []string | 	icmpSources     []string | ||||||
| 	icmpSources       []string |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewDiagnosticHandler( | func NewDiagnosticHandler( | ||||||
|  | @ -36,8 +32,7 @@ func NewDiagnosticHandler( | ||||||
| 	tunnelID uuid.UUID, | 	tunnelID uuid.UUID, | ||||||
| 	connectorID uuid.UUID, | 	connectorID uuid.UUID, | ||||||
| 	tracker *tunnelstate.ConnTracker, | 	tracker *tunnelstate.ConnTracker, | ||||||
| 	cli *cli.Context, | 	cliFlags map[string]string, | ||||||
| 	flagInclusionList []string, |  | ||||||
| 	icmpSources []string, | 	icmpSources []string, | ||||||
| ) *Handler { | ) *Handler { | ||||||
| 	logger := log.With().Logger() | 	logger := log.With().Logger() | ||||||
|  | @ -45,16 +40,16 @@ func NewDiagnosticHandler( | ||||||
| 		timeout = defaultCollectorTimeout | 		timeout = defaultCollectorTimeout | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	cliFlags[configurationKeyUID] = strconv.Itoa(os.Getuid()) | ||||||
| 	return &Handler{ | 	return &Handler{ | ||||||
| 		log:               &logger, | 		log:             &logger, | ||||||
| 		timeout:           timeout, | 		timeout:         timeout, | ||||||
| 		systemCollector:   systemCollector, | 		systemCollector: systemCollector, | ||||||
| 		tunnelID:          tunnelID, | 		tunnelID:        tunnelID, | ||||||
| 		connectorID:       connectorID, | 		connectorID:     connectorID, | ||||||
| 		tracker:           tracker, | 		tracker:         tracker, | ||||||
| 		cli:               cli, | 		cliFlags:        cliFlags, | ||||||
| 		flagInclusionList: flagInclusionList, | 		icmpSources:     icmpSources, | ||||||
| 		icmpSources:       icmpSources, |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -140,68 +135,15 @@ func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http | ||||||
| 		log.Info().Msg("Collection finished") | 		log.Info().Msg("Collection finished") | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	flagsNames := handler.cli.FlagNames() |  | ||||||
| 	flags := make(map[string]string, len(flagsNames)) |  | ||||||
| 
 |  | ||||||
| 	for _, flag := range flagsNames { |  | ||||||
| 		value := handler.cli.String(flag) |  | ||||||
| 
 |  | ||||||
| 		// empty values are not relevant
 |  | ||||||
| 		if value == "" { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// exclude flags that are sensitive
 |  | ||||||
| 		isIncluded := handler.isFlagIncluded(flag) |  | ||||||
| 		if !isIncluded { |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		switch flag { |  | ||||||
| 		case logger.LogDirectoryFlag: |  | ||||||
| 			fallthrough |  | ||||||
| 		case logger.LogFileFlag: |  | ||||||
| 			{ |  | ||||||
| 				// the log directory may be relative to the instance thus it must be resolved
 |  | ||||||
| 				absolute, err := filepath.Abs(value) |  | ||||||
| 				if err != nil { |  | ||||||
| 					handler.log.Error().Err(err).Msgf("could not convert %s path to absolute", flag) |  | ||||||
| 				} else { |  | ||||||
| 					flags[flag] = absolute |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		default: |  | ||||||
| 			flags[flag] = value |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// The UID is included to help the
 |  | ||||||
| 	// diagnostic tool to understand
 |  | ||||||
| 	// if this instance is managed or not.
 |  | ||||||
| 	flags[configurationKeyUID] = strconv.Itoa(os.Getuid()) |  | ||||||
| 	encoder := json.NewEncoder(writer) | 	encoder := json.NewEncoder(writer) | ||||||
| 
 | 
 | ||||||
| 	err := encoder.Encode(flags) | 	err := encoder.Encode(handler.cliFlags) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		handler.log.Error().Err(err).Msgf("error occurred whilst serializing response") | 		handler.log.Error().Err(err).Msgf("error occurred whilst serializing response") | ||||||
| 		writer.WriteHeader(http.StatusInternalServerError) | 		writer.WriteHeader(http.StatusInternalServerError) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (handler *Handler) isFlagIncluded(flag string) bool { |  | ||||||
| 	isIncluded := false |  | ||||||
| 
 |  | ||||||
| 	for _, include := range handler.flagInclusionList { |  | ||||||
| 		if include == flag { |  | ||||||
| 			isIncluded = true |  | ||||||
| 
 |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return isIncluded |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) { | func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) { | ||||||
| 	bytesWritten, err := w.Write(bytes) | 	bytesWritten, err := w.Write(bytes) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -4,7 +4,6 @@ import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"flag" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | @ -15,7 +14,6 @@ import ( | ||||||
| 	"github.com/rs/zerolog" | 	"github.com/rs/zerolog" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
| 	"github.com/urfave/cli/v2" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/cloudflare/cloudflared/connection" | 	"github.com/cloudflare/cloudflared/connection" | ||||||
| 	"github.com/cloudflare/cloudflared/diagnostic" | 	"github.com/cloudflare/cloudflared/diagnostic" | ||||||
|  | @ -30,21 +28,6 @@ const ( | ||||||
| 	errorKey             = "errkey" | 	errorKey             = "errkey" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func buildCliContext(t *testing.T, flags map[string]string) *cli.Context { |  | ||||||
| 	t.Helper() |  | ||||||
| 
 |  | ||||||
| 	flagSet := flag.NewFlagSet("", flag.PanicOnError) |  | ||||||
| 	ctx := cli.NewContext(cli.NewApp(), flagSet, nil) |  | ||||||
| 
 |  | ||||||
| 	for k, v := range flags { |  | ||||||
| 		flagSet.String(k, v, "") |  | ||||||
| 		err := ctx.Set(k, v) |  | ||||||
| 		require.NoError(t, err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return ctx |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker { | func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker { | ||||||
| 	t.Helper() | 	t.Helper() | ||||||
| 
 | 
 | ||||||
|  | @ -80,7 +63,6 @@ func (*SystemCollectorMock) Collect(ctx context.Context) (*diagnostic.SystemInfo | ||||||
| 	si, _ := ctx.Value(systemInformationKey).(*diagnostic.SystemInformation) | 	si, _ := ctx.Value(systemInformationKey).(*diagnostic.SystemInformation) | ||||||
| 	ri, _ := ctx.Value(rawInformationKey).(string) | 	ri, _ := ctx.Value(rawInformationKey).(string) | ||||||
| 	err, _ := ctx.Value(errorKey).(error) | 	err, _ := ctx.Value(errorKey).(error) | ||||||
| 
 |  | ||||||
| 	return si, ri, err | 	return si, ri, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -122,8 +104,7 @@ func TestSystemHandler(t *testing.T) { | ||||||
| 	for _, tCase := range tests { | 	for _, tCase := range tests { | ||||||
| 		t.Run(tCase.name, func(t *testing.T) { | 		t.Run(tCase.name, func(t *testing.T) { | ||||||
| 			t.Parallel() | 			t.Parallel() | ||||||
| 
 | 			handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, map[string]string{}, nil) | ||||||
| 			handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, nil, nil, nil) |  | ||||||
| 			recorder := httptest.NewRecorder() | 			recorder := httptest.NewRecorder() | ||||||
| 			ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err) | 			ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err) | ||||||
| 			request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil) | 			request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil) | ||||||
|  | @ -190,8 +171,7 @@ func TestTunnelStateHandler(t *testing.T) { | ||||||
| 				tCase.tunnelID, | 				tCase.tunnelID, | ||||||
| 				tCase.clientID, | 				tCase.clientID, | ||||||
| 				tracker, | 				tracker, | ||||||
| 				nil, | 				map[string]string{}, | ||||||
| 				nil, |  | ||||||
| 				tCase.icmpSources, | 				tCase.icmpSources, | ||||||
| 			) | 			) | ||||||
| 			recorder := httptest.NewRecorder() | 			recorder := httptest.NewRecorder() | ||||||
|  | @ -230,10 +210,10 @@ func TestConfigurationHandler(t *testing.T) { | ||||||
| 		{ | 		{ | ||||||
| 			name: "cli with flags", | 			name: "cli with flags", | ||||||
| 			flags: map[string]string{ | 			flags: map[string]string{ | ||||||
| 				"a": "a", | 				"b":   "a", | ||||||
| 				"b": "a", | 				"c":   "a", | ||||||
| 				"c": "a", | 				"d":   "a", | ||||||
| 				"d": "a", | 				"uid": "0", | ||||||
| 			}, | 			}, | ||||||
| 			expected: map[string]string{ | 			expected: map[string]string{ | ||||||
| 				"b":   "a", | 				"b":   "a", | ||||||
|  | @ -246,11 +226,11 @@ func TestConfigurationHandler(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	for _, tCase := range tests { | 	for _, tCase := range tests { | ||||||
| 		t.Run(tCase.name, func(t *testing.T) { | 		t.Run(tCase.name, func(t *testing.T) { | ||||||
|  | 			t.Parallel() | ||||||
|  | 
 | ||||||
| 			var response map[string]string | 			var response map[string]string | ||||||
| 
 | 
 | ||||||
| 			t.Parallel() | 			handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil) | ||||||
| 			ctx := buildCliContext(t, tCase.flags) |  | ||||||
| 			handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, ctx, []string{"b", "c", "d"}, nil) |  | ||||||
| 			recorder := httptest.NewRecorder() | 			recorder := httptest.NewRecorder() | ||||||
| 			handler.ConfigurationHandler(recorder, nil) | 			handler.ConfigurationHandler(recorder, nil) | ||||||
| 			decoder := json.NewDecoder(recorder.Body) | 			decoder := json.NewDecoder(recorder.Body) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue