diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 651bcb8d..8f757ce9 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -461,10 +461,11 @@ func StartServer( go func() { defer wg.Done() - readinessServer := metrics.NewReadyServer(clientID, - tunnelstate.NewConnTracker(log)) - observer.RegisterSink(readinessServer) - diagnosticHandler := diagnostic.NewDiagnosticHandler(log, 0, diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion)) + tracker := tunnelstate.NewConnTracker(log) + observer.RegisterSink(tracker) + + readinessServer := metrics.NewReadyServer(clientID, tracker) + diagnosticHandler := diagnostic.NewDiagnosticHandler(log, 0, diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion), tunnelConfig.NamedTunnel.Credentials.TunnelID, clientID, tracker) metricsConfig := metrics.Config{ ReadyServer: readinessServer, DiagnosticHandler: diagnosticHandler, diff --git a/connection/control.go b/connection/control.go index 94e0d66b..2e5f1e35 100644 --- a/connection/control.go +++ b/connection/control.go @@ -102,7 +102,7 @@ func (c *controlStream) ServeControlStream( c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol) - c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location) + c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location, c.edgeAddress) c.connectedFuse.Connected() // if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration diff --git a/connection/event.go b/connection/event.go index d10b92fc..f4078fe0 100644 --- a/connection/event.go +++ b/connection/event.go @@ -1,12 +1,15 @@ package connection +import "net" + // Event is something that happened to a connection, e.g. disconnection or registration. type Event struct { - Index uint8 - EventType Status - Location string - Protocol Protocol - URL string + Index uint8 + EventType Status + Location string + Protocol Protocol + URL string + EdgeAddress net.IP } // Status is the status of a connection. diff --git a/connection/observer.go b/connection/observer.go index c6cb895e..817e6d2e 100644 --- a/connection/observer.go +++ b/connection/observer.go @@ -47,7 +47,6 @@ func (o *Observer) RegisterSink(sink EventSink) { } func (o *Observer) logConnected(connectionID uuid.UUID, connIndex uint8, location string, address net.IP, protocol Protocol) { - o.sendEvent(Event{Index: connIndex, EventType: Connected, Location: location}) o.log.Info(). Int(management.EventTypeKey, int(management.Cloudflared)). Str(LogFieldConnectionID, connectionID.String()). @@ -63,8 +62,8 @@ func (o *Observer) sendRegisteringEvent(connIndex uint8) { o.sendEvent(Event{Index: connIndex, EventType: RegisteringTunnel}) } -func (o *Observer) sendConnectedEvent(connIndex uint8, protocol Protocol, location string) { - o.sendEvent(Event{Index: connIndex, EventType: Connected, Protocol: protocol, Location: location}) +func (o *Observer) sendConnectedEvent(connIndex uint8, protocol Protocol, location string, edgeAddress net.IP) { + o.sendEvent(Event{Index: connIndex, EventType: Connected, Protocol: protocol, Location: location, EdgeAddress: edgeAddress}) } func (o *Observer) SendURL(url string) { diff --git a/diagnostic/consts.go b/diagnostic/consts.go index 8081cca8..07cd8a7e 100644 --- a/diagnostic/consts.go +++ b/diagnostic/consts.go @@ -3,7 +3,8 @@ package diagnostic import "time" const ( - defaultCollectorTimeout = time.Second * 10 // This const define the timeout value of a collector operation. - collectorField = "collector" // used for logging purposes - systemCollectorName = "system" // used for logging purposes + defaultCollectorTimeout = time.Second * 10 // This const define the timeout value of a collector operation. + collectorField = "collector" // used for logging purposes + systemCollectorName = "system" // used for logging purposes + tunnelStateCollectorName = "tunnelState" // used for logging purposes ) diff --git a/diagnostic/handlers.go b/diagnostic/handlers.go index c9865795..a3faef1d 100644 --- a/diagnostic/handlers.go +++ b/diagnostic/handlers.go @@ -6,28 +6,41 @@ import ( "net/http" "time" + "github.com/google/uuid" "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/tunnelstate" ) type Handler struct { log *zerolog.Logger timeout time.Duration systemCollector SystemCollector + tunnelID uuid.UUID + connectorID uuid.UUID + tracker *tunnelstate.ConnTracker } func NewDiagnosticHandler( log *zerolog.Logger, timeout time.Duration, systemCollector SystemCollector, + tunnelID uuid.UUID, + connectorID uuid.UUID, + tracker *tunnelstate.ConnTracker, ) *Handler { + logger := log.With().Logger() if timeout == 0 { timeout = defaultCollectorTimeout } return &Handler{ - log, - timeout, - systemCollector, + log: &logger, + timeout: timeout, + systemCollector: systemCollector, + tunnelID: tunnelID, + connectorID: connectorID, + tracker: tracker, } } @@ -35,9 +48,7 @@ func (handler *Handler) SystemHandler(writer http.ResponseWriter, request *http. logger := handler.log.With().Str(collectorField, systemCollectorName).Logger() logger.Info().Msg("Collection started") - defer func() { - logger.Info().Msg("Collection finished") - }() + defer logger.Info().Msg("Collection finished") ctx, cancel := context.WithTimeout(request.Context(), handler.timeout) @@ -73,6 +84,32 @@ func (handler *Handler) SystemHandler(writer http.ResponseWriter, request *http. } } +type tunnelStateResponse struct { + TunnelID uuid.UUID `json:"tunnelID,omitempty"` + ConnectorID uuid.UUID `json:"connectorID,omitempty"` + Connections []tunnelstate.IndexedConnectionInfo `json:"connections,omitempty"` +} + +func (handler *Handler) TunnelStateHandler(writer http.ResponseWriter, _ *http.Request) { + log := handler.log.With().Str(collectorField, tunnelStateCollectorName).Logger() + log.Info().Msg("Collection started") + + defer log.Info().Msg("Collection finished") + + body := tunnelStateResponse{ + handler.tunnelID, + handler.connectorID, + handler.tracker.GetActiveConnections(), + } + encoder := json.NewEncoder(writer) + + err := encoder.Encode(body) + if err != nil { + handler.log.Error().Err(err).Msgf("error occurred whilst serializing information") + writer.WriteHeader(http.StatusInternalServerError) + } +} + func writeResponse(writer http.ResponseWriter, bytes []byte, logger *zerolog.Logger) { bytesWritten, err := writer.Write(bytes) if err != nil { diff --git a/diagnostic/handlers_test.go b/diagnostic/handlers_test.go index 984501f3..04ec60db 100644 --- a/diagnostic/handlers_test.go +++ b/diagnostic/handlers_test.go @@ -5,15 +5,19 @@ import ( "encoding/json" "errors" "io" + "net" "net/http" "net/http/httptest" "testing" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/diagnostic" + "github.com/cloudflare/cloudflared/tunnelstate" ) type SystemCollectorMock struct{} @@ -24,6 +28,23 @@ const ( errorKey = "errkey" ) +func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker { + t.Helper() + + log := zerolog.Nop() + tracker := tunnelstate.NewConnTracker(&log) + + for _, conn := range connections { + tracker.OnTunnelEvent(connection.Event{ + Index: conn.Index, + EventType: connection.Connected, + Protocol: conn.Protocol, + EdgeAddress: conn.EdgeAddress, + }) + } + + return tracker +} func setCtxValuesForSystemCollector( systemInfo *diagnostic.SystemInformation, rawInfo string, @@ -83,7 +104,7 @@ func TestSystemHandler(t *testing.T) { for _, tCase := range tests { t.Run(tCase.name, func(t *testing.T) { t.Parallel() - handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}) + handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil) recorder := httptest.NewRecorder() ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err) request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil) @@ -106,3 +127,58 @@ func TestSystemHandler(t *testing.T) { }) } } + +func TestTunnelStateHandler(t *testing.T) { + t.Parallel() + + log := zerolog.Nop() + tests := []struct { + name string + tunnelID uuid.UUID + clientID uuid.UUID + connections []tunnelstate.IndexedConnectionInfo + }{ + { + name: "case1", + tunnelID: uuid.New(), + clientID: uuid.New(), + }, + { + name: "case2", + tunnelID: uuid.New(), + clientID: uuid.New(), + connections: []tunnelstate.IndexedConnectionInfo{{ + ConnectionInfo: tunnelstate.ConnectionInfo{ + IsConnected: true, + Protocol: connection.QUIC, + EdgeAddress: net.IPv4(100, 100, 100, 100), + }, + Index: 0, + }}, + }, + } + + for _, tCase := range tests { + t.Run(tCase.name, func(t *testing.T) { + t.Parallel() + tracker := newTrackerFromConns(t, tCase.connections) + handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tCase.tunnelID, tCase.clientID, tracker) + recorder := httptest.NewRecorder() + handler.TunnelStateHandler(recorder, nil) + decoder := json.NewDecoder(recorder.Body) + + var response struct { + TunnelID uuid.UUID `json:"tunnelID,omitempty"` + ConnectorID uuid.UUID `json:"connectorID,omitempty"` + Connections []tunnelstate.IndexedConnectionInfo `json:"connections,omitempty"` + } + + err := decoder.Decode(&response) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, tCase.tunnelID, response.TunnelID) + assert.Equal(t, tCase.clientID, response.ConnectorID) + assert.Equal(t, tCase.connections, response.Connections) + }) + } +} diff --git a/metrics/metrics.go b/metrics/metrics.go index 77e5e9a6..5b8ae2ff 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -94,6 +94,7 @@ func newMetricsHandler( }) } + router.HandleFunc("/diag/tunnel", config.DiagnosticHandler.TunnelStateHandler) router.HandleFunc("/diag/system", config.DiagnosticHandler.SystemHandler) return router diff --git a/metrics/readiness.go b/metrics/readiness.go index e2de549a..0e5124f1 100644 --- a/metrics/readiness.go +++ b/metrics/readiness.go @@ -7,7 +7,6 @@ import ( "github.com/google/uuid" - conn "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/tunnelstate" ) @@ -28,10 +27,6 @@ func NewReadyServer( } } -func (rs *ReadyServer) OnTunnelEvent(c conn.Event) { - rs.tracker.OnTunnelEvent(c) -} - type body struct { Status int `json:"status"` ReadyConnections uint `json:"readyConnections"` diff --git a/metrics/readiness_test.go b/metrics/readiness_test.go index cd30bece..240f171e 100644 --- a/metrics/readiness_test.go +++ b/metrics/readiness_test.go @@ -44,7 +44,7 @@ func TestReadinessEventHandling(t *testing.T) { assert.Zero(t, readyConnections) // one connected => ok - rs.OnTunnelEvent(connection.Event{ + tracker.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.Connected, }) @@ -53,7 +53,7 @@ func TestReadinessEventHandling(t *testing.T) { assert.EqualValues(t, 1, readyConnections) // another connected => still ok - rs.OnTunnelEvent(connection.Event{ + tracker.OnTunnelEvent(connection.Event{ Index: 2, EventType: connection.Connected, }) @@ -62,7 +62,7 @@ func TestReadinessEventHandling(t *testing.T) { assert.EqualValues(t, 2, readyConnections) // one reconnecting => still ok - rs.OnTunnelEvent(connection.Event{ + tracker.OnTunnelEvent(connection.Event{ Index: 2, EventType: connection.Reconnecting, }) @@ -71,7 +71,7 @@ func TestReadinessEventHandling(t *testing.T) { assert.EqualValues(t, 1, readyConnections) // Regression test for TUN-3777 - rs.OnTunnelEvent(connection.Event{ + tracker.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.RegisteringTunnel, }) @@ -80,14 +80,14 @@ func TestReadinessEventHandling(t *testing.T) { assert.Zero(t, readyConnections) // other connected then unregistered => not ok - rs.OnTunnelEvent(connection.Event{ + tracker.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.Connected, }) code, readyConnections = mockRequest(t, rs) assert.EqualValues(t, http.StatusOK, code) assert.EqualValues(t, 1, readyConnections) - rs.OnTunnelEvent(connection.Event{ + tracker.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.Unregistering, }) @@ -96,7 +96,7 @@ func TestReadinessEventHandling(t *testing.T) { assert.Zero(t, readyConnections) // other disconnected => not ok - rs.OnTunnelEvent(connection.Event{ + tracker.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.Disconnected, }) diff --git a/tunnelstate/conntracker.go b/tunnelstate/conntracker.go index cd58a292..d0119f10 100644 --- a/tunnelstate/conntracker.go +++ b/tunnelstate/conntracker.go @@ -1,6 +1,7 @@ package tunnelstate import ( + "net" "sync" "github.com/rs/zerolog" @@ -16,8 +17,15 @@ type ConnTracker struct { } type ConnectionInfo struct { - IsConnected bool - Protocol connection.Protocol + IsConnected bool `json:"isConnected,omitempty"` + Protocol connection.Protocol `json:"protocol,omitempty"` + EdgeAddress net.IP `json:"edgeAddress,omitempty"` +} + +// Convinience struct to extend the connection with its index. +type IndexedConnectionInfo struct { + ConnectionInfo + Index uint8 `json:"index,omitempty"` } func NewConnTracker( @@ -36,6 +44,7 @@ func (ct *ConnTracker) OnTunnelEvent(c connection.Event) { ci := ConnectionInfo{ IsConnected: true, Protocol: c.Protocol, + EdgeAddress: c.EdgeAddress, } ct.connectionInfo[c.Index] = ci ct.mutex.Unlock() @@ -74,3 +83,21 @@ func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool { } return false } + +// Returns the connection information iff it is connected this +// also leverages the [IndexedConnectionInfo] to also provide the connection index +func (ct *ConnTracker) GetActiveConnections() []IndexedConnectionInfo { + ct.mutex.RLock() + defer ct.mutex.RUnlock() + + connections := make([]IndexedConnectionInfo, 0) + + for key, value := range ct.connectionInfo { + if value.IsConnected { + info := IndexedConnectionInfo{value, key} + connections = append(connections, info) + } + } + + return connections +}