package diagnostic_test

import (
	"context"
	"encoding/json"
	"errors"
	"net"
	"net/http"
	"net/http/httptest"
	"runtime"
	"testing"

	"github.com/google/uuid"
	"github.com/rs/zerolog"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/cloudflare/cloudflared/connection"
	"github.com/cloudflare/cloudflared/diagnostic"
	"github.com/cloudflare/cloudflared/tunnelstate"
)

// SystemCollectorMock is a test double for the system collector passed to
// diagnostic.NewDiagnosticHandler. Collect hands back the canned values
// stored in its fields.
type SystemCollectorMock struct {
	systemInfo *diagnostic.SystemInformation
	err        error
}

const (
	systemInformationKey = "sikey"
	errorKey             = "errkey"
)

// newTrackerFromConns builds a ConnTracker backed by a no-op logger and feeds
// it one connection.Connected event per entry in connections, copying the
// entry's index, protocol and edge address into the event.
func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
	t.Helper()

	logger := zerolog.Nop()
	tracker := tunnelstate.NewConnTracker(&logger)

	for i := range connections {
		info := connections[i]
		event := connection.Event{
			Index:       info.Index,
			EventType:   connection.Connected,
			Protocol:    info.Protocol,
			EdgeAddress: info.EdgeAddress,
		}
		tracker.OnTunnelEvent(event)
	}

	return tracker
}

// Collect returns the preconfigured system information and error unchanged;
// the context argument is ignored.
func (m *SystemCollectorMock) Collect(context.Context) (*diagnostic.SystemInformation, error) {
	return m.systemInfo, m.err
}

// TestSystemHandler drives SystemHandler with a mocked collector and checks
// the response status code. When the case supplies system information, the
// JSON body is also decoded and its Info field compared with that input.
func TestSystemHandler(t *testing.T) {
	t.Parallel()

	logger := zerolog.Nop()

	cases := []struct {
		name       string
		info       *diagnostic.SystemInformation
		collectErr error
		wantStatus int
	}{
		{
			name: "happy path",
			info: diagnostic.NewSystemInformation(
				0, 0, 0, 0,
				"string", "string", "string", "string",
				"string", "string",
				runtime.Version(), runtime.GOARCH, nil,
			),
			collectErr: nil,
			wantStatus: http.StatusOK,
		},
		{
			// The table expects 200 even though the collector reports an error.
			name:       "on error and no raw info",
			info:       nil,
			collectErr: errors.New("an error"),
			wantStatus: http.StatusOK,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			collector := &SystemCollectorMock{
				systemInfo: tc.info,
				err:        tc.collectErr,
			}
			handler := diagnostic.NewDiagnosticHandler(
				&logger, 0, collector,
				uuid.New(), uuid.New(),
				nil, map[string]string{}, nil,
			)

			recorder := httptest.NewRecorder()
			request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/diag/system", nil)
			require.NoError(t, err)

			handler.SystemHandler(recorder, request)

			assert.Equal(t, tc.wantStatus, recorder.Code)

			// The body is only inspected for successful cases that carry system information.
			if tc.wantStatus != http.StatusOK || tc.info == nil {
				return
			}

			var response diagnostic.SystemInformationResponse
			decodeErr := json.NewDecoder(recorder.Body).Decode(&response)
			require.NoError(t, decodeErr)
			assert.Equal(t, tc.info, response.Info)
		})
	}
}

// TestTunnelStateHandler checks that TunnelStateHandler answers 200 and that
// the decoded TunnelState carries the tunnel ID, connector ID, connections and
// ICMP sources the handler was constructed with.
func TestTunnelStateHandler(t *testing.T) {
	t.Parallel()

	logger := zerolog.Nop()

	cases := []struct {
		name        string
		tunnelID    uuid.UUID
		clientID    uuid.UUID
		connections []tunnelstate.IndexedConnectionInfo
		icmpSources []string
	}{
		{
			name:     "case1",
			tunnelID: uuid.New(),
			clientID: uuid.New(),
		},
		{
			name:        "case2",
			tunnelID:    uuid.New(),
			clientID:    uuid.New(),
			icmpSources: []string{"172.17.0.3", "::1"},
			connections: []tunnelstate.IndexedConnectionInfo{
				{
					ConnectionInfo: tunnelstate.ConnectionInfo{
						IsConnected: true,
						Protocol:    connection.QUIC,
						EdgeAddress: net.IPv4(100, 100, 100, 100),
					},
					Index: 0,
				},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Seed the tracker with a Connected event for each expected connection.
			tracker := newTrackerFromConns(t, tc.connections)
			handler := diagnostic.NewDiagnosticHandler(
				&logger, 0, nil,
				tc.tunnelID, tc.clientID,
				tracker, map[string]string{}, tc.icmpSources,
			)

			recorder := httptest.NewRecorder()
			handler.TunnelStateHandler(recorder, nil)

			var state diagnostic.TunnelState
			decodeErr := json.NewDecoder(recorder.Body).Decode(&state)
			require.NoError(t, decodeErr)

			assert.Equal(t, http.StatusOK, recorder.Code)
			assert.Equal(t, tc.tunnelID, state.TunnelID)
			assert.Equal(t, tc.clientID, state.ConnectorID)
			assert.Equal(t, tc.connections, state.Connections)
			assert.Equal(t, tc.icmpSources, state.ICMPSources)
		})
	}
}

// TestConfigurationHandler checks that ConfigurationHandler answers 200 with a
// JSON object that contains a "uid" key and otherwise matches the expected
// flags.
func TestConfigurationHandler(t *testing.T) {
	t.Parallel()

	logger := zerolog.Nop()

	cases := []struct {
		name  string
		flags map[string]string
		want  map[string]string
	}{
		{
			name:  "empty cli",
			flags: map[string]string{},
			want: map[string]string{
				"uid": "0",
			},
		},
		{
			name: "cli with flags",
			flags: map[string]string{
				"b":   "a",
				"c":   "a",
				"d":   "a",
				"uid": "0",
			},
			want: map[string]string{
				"b":   "a",
				"c":   "a",
				"d":   "a",
				"uid": "0",
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			var got map[string]string

			handler := diagnostic.NewDiagnosticHandler(
				&logger, 0, nil,
				uuid.New(), uuid.New(),
				nil, tc.flags, nil,
			)

			recorder := httptest.NewRecorder()
			handler.ConfigurationHandler(recorder, nil)

			decodeErr := json.NewDecoder(recorder.Body).Decode(&got)
			require.NoError(t, decodeErr)

			// "uid" is checked for presence only; its value is dropped from both
			// maps before they are compared.
			// NOTE(review): presumably the value depends on the user running the
			// tests — confirm against the handler.
			_, hasUID := got["uid"]
			assert.True(t, hasUID)
			delete(tc.want, "uid")
			delete(got, "uid")

			assert.Equal(t, http.StatusOK, recorder.Code)
			assert.Equal(t, tc.want, got)
		})
	}
}