TUN-8786: calculate cli flags once for the diagnostic procedure

## Summary

The flags were always being computed when their value is static.

 Closes TUN-8786
This commit is contained in:
Luis Neto 2024-12-11 01:29:20 -08:00
parent 77b99cf5fe
commit ba9f28ef43
4 changed files with 75 additions and 109 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
"path/filepath"
"runtime/trace" "runtime/trace"
"strings" "strings"
"sync" "sync"
@ -560,6 +561,7 @@ func StartServer(
} }
readinessServer := metrics.NewReadyServer(clientID, tracker) readinessServer := metrics.NewReadyServer(clientID, tracker)
cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList)
diagnosticHandler := diagnostic.NewDiagnosticHandler( diagnosticHandler := diagnostic.NewDiagnosticHandler(
log, log,
0, 0,
@ -567,8 +569,7 @@ func StartServer(
tunnelConfig.NamedTunnel.Credentials.TunnelID, tunnelConfig.NamedTunnel.Credentials.TunnelID,
clientID, clientID,
tracker, tracker,
c, cliFlags,
nonSecretFlagsList,
sources, sources,
) )
metricsConfig := metrics.Config{ metricsConfig := metrics.Config{
@ -1309,3 +1310,46 @@ reconnect [delay]
} }
} }
} }
func nonSecretCliFlags(log *zerolog.Logger, cli *cli.Context, flagInclusionList []string) map[string]string {
flagsNames := cli.FlagNames()
flags := make(map[string]string, len(flagsNames))
for _, flag := range flagsNames {
value := cli.String(flag)
if value == "" {
continue
}
isIncluded := isFlagIncluded(flagInclusionList, flag)
if !isIncluded {
continue
}
switch flag {
case logger.LogDirectoryFlag, logger.LogFileFlag:
{
absolute, err := filepath.Abs(value)
if err != nil {
log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
} else {
flags[flag] = absolute
}
}
default:
flags[flag] = value
}
}
return flags
}
func isFlagIncluded(flagInclusionList []string, flag string) bool {
for _, include := range flagInclusionList {
if include == flag {
return true
}
}
return false
}

View File

@ -25,7 +25,7 @@ func helperCreateServer(t *testing.T, listeners *gracenet.Net, tunnelID uuid.UUI
require.NoError(t, err) require.NoError(t, err)
log := zerolog.Nop() log := zerolog.Nop()
tracker := tunnelstate.NewConnTracker(&log) tracker := tunnelstate.NewConnTracker(&log)
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, nil, []string{}, []string{}) handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, map[string]string{}, []string{})
router := http.NewServeMux() router := http.NewServeMux()
router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler) router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler)
server := &http.Server{ server := &http.Server{

View File

@ -5,28 +5,24 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"os" "os"
"path/filepath"
"strconv" "strconv"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelstate" "github.com/cloudflare/cloudflared/tunnelstate"
) )
type Handler struct { type Handler struct {
log *zerolog.Logger log *zerolog.Logger
timeout time.Duration timeout time.Duration
systemCollector SystemCollector systemCollector SystemCollector
tunnelID uuid.UUID tunnelID uuid.UUID
connectorID uuid.UUID connectorID uuid.UUID
tracker *tunnelstate.ConnTracker tracker *tunnelstate.ConnTracker
cli *cli.Context cliFlags map[string]string
flagInclusionList []string icmpSources []string
icmpSources []string
} }
func NewDiagnosticHandler( func NewDiagnosticHandler(
@ -36,8 +32,7 @@ func NewDiagnosticHandler(
tunnelID uuid.UUID, tunnelID uuid.UUID,
connectorID uuid.UUID, connectorID uuid.UUID,
tracker *tunnelstate.ConnTracker, tracker *tunnelstate.ConnTracker,
cli *cli.Context, cliFlags map[string]string,
flagInclusionList []string,
icmpSources []string, icmpSources []string,
) *Handler { ) *Handler {
logger := log.With().Logger() logger := log.With().Logger()
@ -45,16 +40,16 @@ func NewDiagnosticHandler(
timeout = defaultCollectorTimeout timeout = defaultCollectorTimeout
} }
cliFlags[configurationKeyUID] = strconv.Itoa(os.Getuid())
return &Handler{ return &Handler{
log: &logger, log: &logger,
timeout: timeout, timeout: timeout,
systemCollector: systemCollector, systemCollector: systemCollector,
tunnelID: tunnelID, tunnelID: tunnelID,
connectorID: connectorID, connectorID: connectorID,
tracker: tracker, tracker: tracker,
cli: cli, cliFlags: cliFlags,
flagInclusionList: flagInclusionList, icmpSources: icmpSources,
icmpSources: icmpSources,
} }
} }
@ -140,68 +135,15 @@ func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http
log.Info().Msg("Collection finished") log.Info().Msg("Collection finished")
}() }()
flagsNames := handler.cli.FlagNames()
flags := make(map[string]string, len(flagsNames))
for _, flag := range flagsNames {
value := handler.cli.String(flag)
// empty values are not relevant
if value == "" {
continue
}
// exclude flags that are sensitive
isIncluded := handler.isFlagIncluded(flag)
if !isIncluded {
continue
}
switch flag {
case logger.LogDirectoryFlag:
fallthrough
case logger.LogFileFlag:
{
// the log directory may be relative to the instance thus it must be resolved
absolute, err := filepath.Abs(value)
if err != nil {
handler.log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
} else {
flags[flag] = absolute
}
}
default:
flags[flag] = value
}
}
// The UID is included to help the
// diagnostic tool to understand
// if this instance is managed or not.
flags[configurationKeyUID] = strconv.Itoa(os.Getuid())
encoder := json.NewEncoder(writer) encoder := json.NewEncoder(writer)
err := encoder.Encode(flags) err := encoder.Encode(handler.cliFlags)
if err != nil { if err != nil {
handler.log.Error().Err(err).Msgf("error occurred whilst serializing response") handler.log.Error().Err(err).Msgf("error occurred whilst serializing response")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
} }
} }
func (handler *Handler) isFlagIncluded(flag string) bool {
isIncluded := false
for _, include := range handler.flagInclusionList {
if include == flag {
isIncluded = true
break
}
}
return isIncluded
}
func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) { func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) {
bytesWritten, err := w.Write(bytes) bytesWritten, err := w.Write(bytes)
if err != nil { if err != nil {

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"flag"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -15,7 +14,6 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/diagnostic" "github.com/cloudflare/cloudflared/diagnostic"
@ -30,21 +28,6 @@ const (
errorKey = "errkey" errorKey = "errkey"
) )
func buildCliContext(t *testing.T, flags map[string]string) *cli.Context {
t.Helper()
flagSet := flag.NewFlagSet("", flag.PanicOnError)
ctx := cli.NewContext(cli.NewApp(), flagSet, nil)
for k, v := range flags {
flagSet.String(k, v, "")
err := ctx.Set(k, v)
require.NoError(t, err)
}
return ctx
}
func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker { func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
t.Helper() t.Helper()
@ -80,7 +63,6 @@ func (*SystemCollectorMock) Collect(ctx context.Context) (*diagnostic.SystemInfo
si, _ := ctx.Value(systemInformationKey).(*diagnostic.SystemInformation) si, _ := ctx.Value(systemInformationKey).(*diagnostic.SystemInformation)
ri, _ := ctx.Value(rawInformationKey).(string) ri, _ := ctx.Value(rawInformationKey).(string)
err, _ := ctx.Value(errorKey).(error) err, _ := ctx.Value(errorKey).(error)
return si, ri, err return si, ri, err
} }
@ -122,8 +104,7 @@ func TestSystemHandler(t *testing.T) {
for _, tCase := range tests { for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
t.Parallel() t.Parallel()
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, map[string]string{}, nil)
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, nil, nil, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err) ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err)
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil)
@ -190,8 +171,7 @@ func TestTunnelStateHandler(t *testing.T) {
tCase.tunnelID, tCase.tunnelID,
tCase.clientID, tCase.clientID,
tracker, tracker,
nil, map[string]string{},
nil,
tCase.icmpSources, tCase.icmpSources,
) )
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@ -230,10 +210,10 @@ func TestConfigurationHandler(t *testing.T) {
{ {
name: "cli with flags", name: "cli with flags",
flags: map[string]string{ flags: map[string]string{
"a": "a", "b": "a",
"b": "a", "c": "a",
"c": "a", "d": "a",
"d": "a", "uid": "0",
}, },
expected: map[string]string{ expected: map[string]string{
"b": "a", "b": "a",
@ -246,11 +226,11 @@ func TestConfigurationHandler(t *testing.T) {
for _, tCase := range tests { for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
var response map[string]string var response map[string]string
t.Parallel() handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil)
ctx := buildCliContext(t, tCase.flags)
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, ctx, []string{"b", "c", "d"}, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.ConfigurationHandler(recorder, nil) handler.ConfigurationHandler(recorder, nil)
decoder := json.NewDecoder(recorder.Body) decoder := json.NewDecoder(recorder.Body)