TUN-8730: implement diag/configuration

Implements the endpoint that retrieves the configuration of a running instance.

The configuration consists in a map of cli flag to the provided value along with the uid that of the user that started the process
This commit is contained in:
Luis Neto 2024-11-25 11:24:51 -08:00
parent 4b0b6dc8c6
commit f85c0f1cc0
6 changed files with 277 additions and 24 deletions

View File

@ -127,6 +127,92 @@ var (
"most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+ "most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+
"any existing DNS records for this hostname.", overwriteDNSFlag) "any existing DNS records for this hostname.", overwriteDNSFlag)
deprecatedClassicTunnelErr = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)") deprecatedClassicTunnelErr = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)")
nonSecretFlagsList = []string{
"config",
"autoupdate-freq",
"no-autoupdate",
"metrics",
"pidfile",
"url",
"hello-world",
"socks5",
"proxy-connect-timeout",
"proxy-tls-timeout",
"proxy-tcp-keepalive",
"proxy-no-happy-eyeballs",
"proxy-keepalive-connections",
"proxy-keepalive-timeout",
"proxy-connection-timeout",
"proxy-expect-continue-timeout",
"http-host-header",
"origin-server-name",
"unix-socket",
"origin-ca-pool",
"no-tls-verify",
"no-chunked-encoding",
"http2-origin",
"management-hostname",
"service-op-ip",
"local-ssh-port",
"ssh-idle-timeout",
"ssh-max-timeout",
"bucket-name",
"region-name",
"s3-url-host",
"host-key-path",
"ssh-server",
"bastion",
"proxy-address",
"proxy-port",
"loglevel",
"transport-loglevel",
"logfile",
"log-directory",
"trace-output",
"proxy-dns",
"proxy-dns-port",
"proxy-dns-address",
"proxy-dns-upstream",
"proxy-dns-max-upstream-conns",
"proxy-dns-bootstrap",
"is-autoupdated",
"edge",
"region",
"edge-ip-version",
"edge-bind-address",
"cacert",
"hostname",
"id",
"lb-pool",
"api-url",
"metrics-update-freq",
"tag",
"heartbeat-interval",
"heartbeat-count",
"max-edge-addr-retries",
"retries",
"ha-connections",
"rpc-timeout",
"write-stream-timeout",
"quic-disable-pmtu-discovery",
"quic-connection-level-flow-control-limit",
"quic-stream-level-flow-control-limit",
"label",
"grace-period",
"compression-quality",
"use-reconnect-token",
"dial-edge-timeout",
"stdin-control",
"name",
"ui",
"quick-service",
"max-fetch-size",
"post-quantum",
"management-diagnostics",
"protocol",
"overwrite-dns",
"help",
}
) )
func Flags() []cli.Flag { func Flags() []cli.Flag {
@ -465,7 +551,16 @@ func StartServer(
observer.RegisterSink(tracker) observer.RegisterSink(tracker)
readinessServer := metrics.NewReadyServer(clientID, tracker) readinessServer := metrics.NewReadyServer(clientID, tracker)
diagnosticHandler := diagnostic.NewDiagnosticHandler(log, 0, diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion), tunnelConfig.NamedTunnel.Credentials.TunnelID, clientID, tracker) diagnosticHandler := diagnostic.NewDiagnosticHandler(
log,
0,
diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion),
tunnelConfig.NamedTunnel.Credentials.TunnelID,
clientID,
tracker,
c,
nonSecretFlagsList,
)
metricsConfig := metrics.Config{ metricsConfig := metrics.Config{
ReadyServer: readinessServer, ReadyServer: readinessServer,
DiagnosticHandler: diagnosticHandler, DiagnosticHandler: diagnosticHandler,

View File

@ -3,8 +3,10 @@ package diagnostic
import "time" import "time"
const ( const (
defaultCollectorTimeout = time.Second * 10 // This const define the timeout value of a collector operation. defaultCollectorTimeout = time.Second * 10 // This const define the timeout value of a collector operation.
collectorField = "collector" // used for logging purposes collectorField = "collector" // used for logging purposes
systemCollectorName = "system" // used for logging purposes systemCollectorName = "system" // used for logging purposes
tunnelStateCollectorName = "tunnelState" // used for logging purposes tunnelStateCollectorName = "tunnelState" // used for logging purposes
configurationCollectorName = "configuration" // used for logging purposes
configurationKeyUid = "uid"
) )

View File

@ -4,21 +4,28 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"os"
"path/filepath"
"strconv"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelstate" "github.com/cloudflare/cloudflared/tunnelstate"
) )
type Handler struct { type Handler struct {
log *zerolog.Logger log *zerolog.Logger
timeout time.Duration timeout time.Duration
systemCollector SystemCollector systemCollector SystemCollector
tunnelID uuid.UUID tunnelID uuid.UUID
connectorID uuid.UUID connectorID uuid.UUID
tracker *tunnelstate.ConnTracker tracker *tunnelstate.ConnTracker
cli *cli.Context
flagInclusionList []string
} }
func NewDiagnosticHandler( func NewDiagnosticHandler(
@ -28,6 +35,8 @@ func NewDiagnosticHandler(
tunnelID uuid.UUID, tunnelID uuid.UUID,
connectorID uuid.UUID, connectorID uuid.UUID,
tracker *tunnelstate.ConnTracker, tracker *tunnelstate.ConnTracker,
cli *cli.Context,
flagInclusionList []string,
) *Handler { ) *Handler {
logger := log.With().Logger() logger := log.With().Logger()
if timeout == 0 { if timeout == 0 {
@ -35,12 +44,14 @@ func NewDiagnosticHandler(
} }
return &Handler{ return &Handler{
log: &logger, log: &logger,
timeout: timeout, timeout: timeout,
systemCollector: systemCollector, systemCollector: systemCollector,
tunnelID: tunnelID, tunnelID: tunnelID,
connectorID: connectorID, connectorID: connectorID,
tracker: tracker, tracker: tracker,
cli: cli,
flagInclusionList: flagInclusionList,
} }
} }
@ -110,8 +121,77 @@ func (handler *Handler) TunnelStateHandler(writer http.ResponseWriter, _ *http.R
} }
} }
func writeResponse(writer http.ResponseWriter, bytes []byte, logger *zerolog.Logger) { func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http.Request) {
bytesWritten, err := writer.Write(bytes) log := handler.log.With().Str(collectorField, configurationCollectorName).Logger()
log.Info().Msg("Collection started")
defer func() {
log.Info().Msg("Collection finished")
}()
flagsNames := handler.cli.FlagNames()
flags := make(map[string]string, len(flagsNames))
for _, flag := range flagsNames {
value := handler.cli.String(flag)
// empty values are not relevant
if value == "" {
continue
}
// exclude flags that are sensitive
isIncluded := handler.isFlagIncluded(flag)
if !isIncluded {
continue
}
switch flag {
case logger.LogDirectoryFlag:
case logger.LogFileFlag:
{
// the log directory may be relative to the instance thus it must be resolved
absolute, err := filepath.Abs(value)
if err != nil {
handler.log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
} else {
flags[flag] = absolute
}
}
default:
flags[flag] = value
}
}
// The UID is included to help the
// diagnostic tool to understand
// if this instance is managed or not.
flags[configurationKeyUid] = strconv.Itoa(os.Getuid())
encoder := json.NewEncoder(writer)
err := encoder.Encode(flags)
if err != nil {
handler.log.Error().Err(err).Msgf("error occurred whilst serializing response")
writer.WriteHeader(http.StatusInternalServerError)
}
}
func (handler *Handler) isFlagIncluded(flag string) bool {
isIncluded := false
for _, include := range handler.flagInclusionList {
if include == flag {
isIncluded = true
break
}
}
return isIncluded
}
func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) {
bytesWritten, err := w.Write(bytes)
if err != nil { if err != nil {
logger.Error().Err(err).Msg("error occurred writing response") logger.Error().Err(err).Msg("error occurred writing response")
} else if bytesWritten != len(bytes) { } else if bytesWritten != len(bytes) {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"flag"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -14,6 +15,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/diagnostic" "github.com/cloudflare/cloudflared/diagnostic"
@ -28,6 +30,21 @@ const (
errorKey = "errkey" errorKey = "errkey"
) )
func buildCliContext(t *testing.T, flags map[string]string) *cli.Context {
t.Helper()
flagSet := flag.NewFlagSet("", flag.PanicOnError)
ctx := cli.NewContext(cli.NewApp(), flagSet, nil)
for k, v := range flags {
flagSet.String(k, v, "")
err := ctx.Set(k, v)
require.NoError(t, err)
}
return ctx
}
func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker { func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
t.Helper() t.Helper()
@ -45,6 +62,7 @@ func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnecti
return tracker return tracker
} }
func setCtxValuesForSystemCollector( func setCtxValuesForSystemCollector(
systemInfo *diagnostic.SystemInformation, systemInfo *diagnostic.SystemInformation,
rawInfo string, rawInfo string,
@ -104,7 +122,8 @@ func TestSystemHandler(t *testing.T) {
for _, tCase := range tests { for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
t.Parallel() t.Parallel()
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil)
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, nil, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err) ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err)
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil) request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil)
@ -162,7 +181,7 @@ func TestTunnelStateHandler(t *testing.T) {
t.Run(tCase.name, func(t *testing.T) { t.Run(tCase.name, func(t *testing.T) {
t.Parallel() t.Parallel()
tracker := newTrackerFromConns(t, tCase.connections) tracker := newTrackerFromConns(t, tCase.connections)
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tCase.tunnelID, tCase.clientID, tracker) handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tCase.tunnelID, tCase.clientID, tracker, nil, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.TunnelStateHandler(recorder, nil) handler.TunnelStateHandler(recorder, nil)
decoder := json.NewDecoder(recorder.Body) decoder := json.NewDecoder(recorder.Body)
@ -182,3 +201,59 @@ func TestTunnelStateHandler(t *testing.T) {
}) })
} }
} }
func TestConfigurationHandler(t *testing.T) {
t.Parallel()
log := zerolog.Nop()
tests := []struct {
name string
flags map[string]string
expected map[string]string
}{
{
name: "empty cli",
flags: make(map[string]string),
expected: map[string]string{
"uid": "0",
},
},
{
name: "cli with flags",
flags: map[string]string{
"a": "a",
"b": "a",
"c": "a",
"d": "a",
},
expected: map[string]string{
"b": "a",
"c": "a",
"d": "a",
"uid": "0",
},
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
var response map[string]string
t.Parallel()
ctx := buildCliContext(t, tCase.flags)
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, ctx, []string{"b", "c", "d"})
recorder := httptest.NewRecorder()
handler.ConfigurationHandler(recorder, nil)
decoder := json.NewDecoder(recorder.Body)
err := decoder.Decode(&response)
require.NoError(t, err)
_, ok := response["uid"]
assert.True(t, ok)
delete(tCase.expected, "uid")
delete(response, "uid")
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, tCase.expected, response)
})
}
}

View File

@ -76,10 +76,10 @@ func CreateConfig(
var file *FileConfig var file *FileConfig
var rolling *RollingConfig var rolling *RollingConfig
if rollingLogPath != "" { if nonRollingLogFilePath != "" {
rolling = createRollingConfig(rollingLogPath)
} else if nonRollingLogFilePath != "" {
file = createFileConfig(nonRollingLogFilePath) file = createFileConfig(nonRollingLogFilePath)
} else if rollingLogPath != "" {
rolling = createRollingConfig(rollingLogPath)
} }
if minLevel == "" { if minLevel == "" {

View File

@ -94,6 +94,7 @@ func newMetricsHandler(
}) })
} }
router.HandleFunc("/diag/configuration", config.DiagnosticHandler.ConfigurationHandler)
router.HandleFunc("/diag/tunnel", config.DiagnosticHandler.TunnelStateHandler) router.HandleFunc("/diag/tunnel", config.DiagnosticHandler.TunnelStateHandler)
router.HandleFunc("/diag/system", config.DiagnosticHandler.SystemHandler) router.HandleFunc("/diag/system", config.DiagnosticHandler.SystemHandler)