diff --git a/diagnostic/diagnostic_utils.go b/diagnostic/diagnostic_utils.go index a14f2efa..bc811eae 100644 --- a/diagnostic/diagnostic_utils.go +++ b/diagnostic/diagnostic_utils.go @@ -2,12 +2,17 @@ package diagnostic import ( "archive/zip" + "context" "fmt" "io" + "net/url" "os" "path/filepath" "strings" "time" + + "github.com/google/uuid" + "github.com/rs/zerolog" ) // CreateDiagnosticZipFile create a zip file with the contents from the all @@ -67,3 +72,69 @@ func CreateDiagnosticZipFile(base string, paths []string) (zipFileName string, e zipFileName = archive.Name() return zipFileName, nil } + +type AddressableTunnelState struct { + *TunnelState + URL *url.URL +} + +func findMetricsServerPredicate(tunnelID, connectorID uuid.UUID) func(state *TunnelState) bool { + if tunnelID != uuid.Nil && connectorID != uuid.Nil { + return func(state *TunnelState) bool { + return state.ConnectorID == connectorID && state.TunnelID == tunnelID + } + } else if tunnelID == uuid.Nil && connectorID != uuid.Nil { + return func(state *TunnelState) bool { + return state.ConnectorID == connectorID + } + } else if tunnelID != uuid.Nil && connectorID == uuid.Nil { + return func(state *TunnelState) bool { + return state.TunnelID == tunnelID + } + } + + return func(*TunnelState) bool { + return true + } +} + +// The FindMetricsServer will try to find the metrics server url. +// There are two possible error scenarios: +// 1. No instance is found which will only return ErrMetricsServerNotFound +// 2. Multiple instances are found which will return an array of state and ErrMultipleMetricsServerFound +// In case of success, only the state for the instance is returned. +func FindMetricsServer( + log *zerolog.Logger, + client *httpClient, + addresses []string, +) (*AddressableTunnelState, []*AddressableTunnelState, error) { + instances := make([]*AddressableTunnelState, 0) + + for _, address := range addresses { + url, err := url.Parse("http://" + address) + if err != nil { + log.Debug().Err(err).Msgf("error parsing address %s", address) + + continue + } + + client.SetBaseURL(url) + + state, err := client.GetTunnelState(context.Background()) + if err == nil { + instances = append(instances, &AddressableTunnelState{state, url}) + } else { + log.Debug().Err(err).Msgf("error getting tunnel state from address %s", address) + } + } + + if len(instances) == 0 { + return nil, nil, ErrMetricsServerNotFound + } + + if len(instances) == 1 { + return instances[0], nil, nil + } + + return nil, instances, ErrMultipleMetricsServerFound +} diff --git a/diagnostic/diagnostic_utils_test.go b/diagnostic/diagnostic_utils_test.go new file mode 100644 index 00000000..b068abd1 --- /dev/null +++ b/diagnostic/diagnostic_utils_test.go @@ -0,0 +1,147 @@ +package diagnostic_test + +import ( + "context" + "net/http" + "net/url" + "sync" + "testing" + "time" + + "github.com/facebookgo/grace/gracenet" + "github.com/google/uuid" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/diagnostic" + "github.com/cloudflare/cloudflared/metrics" + "github.com/cloudflare/cloudflared/tunnelstate" +) + +func helperCreateServer(t *testing.T, listeners *gracenet.Net, tunnelID uuid.UUID, connectorID uuid.UUID) func() { + t.Helper() + listener, err := metrics.CreateMetricsListener(listeners, "localhost:0") + require.NoError(t, err) + log := zerolog.Nop() + tracker := tunnelstate.NewConnTracker(&log) + handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, nil, []string{}) + router := http.NewServeMux() + router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler) + server := &http.Server{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + Handler: router, + } + + var wgroup sync.WaitGroup + + wgroup.Add(1) + + go func() { + defer wgroup.Done() + + _ = server.Serve(listener) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + + cleanUp := func() { + _ = server.Shutdown(ctx) + + cancel() + wgroup.Wait() + } + + return cleanUp +} + +func TestFindMetricsServer_WhenSingleServerIsRunning_ReturnState(t *testing.T) { + listeners := gracenet.Net{} + tid1 := uuid.New() + cid1 := uuid.New() + + cleanUp := helperCreateServer(t, &listeners, tid1, cid1) + defer cleanUp() + + log := zerolog.Nop() + client := diagnostic.NewHTTPClient() + addresses := metrics.GetMetricsKnownAddresses("host") + url1, err := url.Parse("http://localhost:20241") + require.NoError(t, err) + + tunnel1 := &diagnostic.AddressableTunnelState{ + TunnelState: &diagnostic.TunnelState{ + TunnelID: tid1, + ConnectorID: cid1, + Connections: nil, + }, + URL: url1, + } + + state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:]) + if err != nil { + require.ErrorIs(t, err, diagnostic.ErrMultipleMetricsServerFound) + } + + assert.Equal(t, tunnel1, state) + assert.Nil(t, tunnels) +} + +func TestFindMetricsServer_WhenMultipleServerAreRunning_ReturnError(t *testing.T) { + listeners := gracenet.Net{} + tid1 := uuid.New() + cid1 := uuid.New() + cid2 := uuid.New() + + cleanUp := helperCreateServer(t, &listeners, tid1, cid1) + defer cleanUp() + + cleanUp = helperCreateServer(t, &listeners, tid1, cid2) + defer cleanUp() + + log := zerolog.Nop() + client := diagnostic.NewHTTPClient() + addresses := metrics.GetMetricsKnownAddresses("host") + url1, err := url.Parse("http://localhost:20241") + require.NoError(t, err) + url2, err := url.Parse("http://localhost:20242") + require.NoError(t, err) + + tunnel1 := &diagnostic.AddressableTunnelState{ + TunnelState: &diagnostic.TunnelState{ + TunnelID: tid1, + ConnectorID: cid1, + Connections: nil, + }, + URL: url1, + } + tunnel2 := &diagnostic.AddressableTunnelState{ + TunnelState: &diagnostic.TunnelState{ + TunnelID: tid1, + ConnectorID: cid2, + Connections: nil, + }, + URL: url2, + } + + state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:]) + if err != nil { + require.ErrorIs(t, err, diagnostic.ErrMultipleMetricsServerFound) + } + + assert.Nil(t, state) + assert.Equal(t, []*diagnostic.AddressableTunnelState{tunnel1, tunnel2}, tunnels) +} + +func TestFindMetricsServer_WhenNoInstanceIsRuning_ReturnError(t *testing.T) { + log := zerolog.Nop() + client := diagnostic.NewHTTPClient() + addresses := metrics.GetMetricsKnownAddresses("host") + + state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:]) + require.ErrorIs(t, err, diagnostic.ErrMetricsServerNotFound) + + assert.Nil(t, state) + assert.Nil(t, tunnels) +} diff --git a/diagnostic/error.go b/diagnostic/error.go index cc601b7f..39aa55d6 100644 --- a/diagnostic/error.go +++ b/diagnostic/error.go @@ -19,4 +19,8 @@ var ( ErrNoVolumeFound = errors.New("no disk volume information found") // Error user when the base url of the diagnostic client is not provided. ErrNoBaseUrl = errors.New("no base url") + // Error used when no metrics server is found listening to the known addresses list (check [metrics.GetMetricsKnownAddresses]) + ErrMetricsServerNotFound = errors.New("metrics server not found") + // Error used when multiple metrics server are found listening to the known addresses list (check [metrics.GetMetricsKnownAddresses]) + ErrMultipleMetricsServerFound = errors.New("multiple metrics server found") )