cloudflared-mirror/diagnostic/handlers_test.go

260 lines
6.5 KiB
Go

package diagnostic_test
import (
"context"
"encoding/json"
"errors"
"flag"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/diagnostic"
"github.com/cloudflare/cloudflared/tunnelstate"
)
type SystemCollectorMock struct{}
const (
systemInformationKey = "sikey"
rawInformationKey = "rikey"
errorKey = "errkey"
)
func buildCliContext(t *testing.T, flags map[string]string) *cli.Context {
t.Helper()
flagSet := flag.NewFlagSet("", flag.PanicOnError)
ctx := cli.NewContext(cli.NewApp(), flagSet, nil)
for k, v := range flags {
flagSet.String(k, v, "")
err := ctx.Set(k, v)
require.NoError(t, err)
}
return ctx
}
func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
t.Helper()
log := zerolog.Nop()
tracker := tunnelstate.NewConnTracker(&log)
for _, conn := range connections {
tracker.OnTunnelEvent(connection.Event{
Index: conn.Index,
EventType: connection.Connected,
Protocol: conn.Protocol,
EdgeAddress: conn.EdgeAddress,
})
}
return tracker
}
func setCtxValuesForSystemCollector(
systemInfo *diagnostic.SystemInformation,
rawInfo string,
err error,
) context.Context {
ctx := context.Background()
ctx = context.WithValue(ctx, systemInformationKey, systemInfo)
ctx = context.WithValue(ctx, rawInformationKey, rawInfo)
ctx = context.WithValue(ctx, errorKey, err)
return ctx
}
func (*SystemCollectorMock) Collect(ctx context.Context) (*diagnostic.SystemInformation, string, error) {
si, _ := ctx.Value(systemInformationKey).(*diagnostic.SystemInformation)
ri, _ := ctx.Value(rawInformationKey).(string)
err, _ := ctx.Value(errorKey).(error)
return si, ri, err
}
func TestSystemHandler(t *testing.T) {
t.Parallel()
log := zerolog.Nop()
tests := []struct {
name string
systemInfo *diagnostic.SystemInformation
rawInfo string
err error
statusCode int
}{
{
name: "happy path",
systemInfo: diagnostic.NewSystemInformation(
0, 0, 0, 0,
"string", "string", "string", "string",
"string", "string", nil,
),
rawInfo: "",
err: nil,
statusCode: http.StatusOK,
},
{
name: "on error and raw info", systemInfo: nil,
rawInfo: "raw info", err: errors.New("an error"), statusCode: http.StatusOK,
},
{
name: "on error and no raw info", systemInfo: nil,
rawInfo: "", err: errors.New("an error"), statusCode: http.StatusInternalServerError,
},
{
name: "malformed response", systemInfo: nil, rawInfo: "", err: nil, statusCode: http.StatusInternalServerError,
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, nil, nil)
recorder := httptest.NewRecorder()
ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err)
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil)
require.NoError(t, err)
handler.SystemHandler(recorder, request)
assert.Equal(t, tCase.statusCode, recorder.Code)
if tCase.statusCode == http.StatusOK && tCase.systemInfo != nil {
var response diagnostic.SystemInformation
decoder := json.NewDecoder(recorder.Body)
err = decoder.Decode(&response)
require.NoError(t, err)
assert.Equal(t, tCase.systemInfo, &response)
} else if tCase.statusCode == http.StatusOK && tCase.rawInfo != "" {
rawBytes, err := io.ReadAll(recorder.Body)
require.NoError(t, err)
assert.Equal(t, tCase.rawInfo, string(rawBytes))
}
})
}
}
func TestTunnelStateHandler(t *testing.T) {
t.Parallel()
log := zerolog.Nop()
tests := []struct {
name string
tunnelID uuid.UUID
clientID uuid.UUID
connections []tunnelstate.IndexedConnectionInfo
}{
{
name: "case1",
tunnelID: uuid.New(),
clientID: uuid.New(),
},
{
name: "case2",
tunnelID: uuid.New(),
clientID: uuid.New(),
connections: []tunnelstate.IndexedConnectionInfo{{
ConnectionInfo: tunnelstate.ConnectionInfo{
IsConnected: true,
Protocol: connection.QUIC,
EdgeAddress: net.IPv4(100, 100, 100, 100),
},
Index: 0,
}},
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
tracker := newTrackerFromConns(t, tCase.connections)
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tCase.tunnelID, tCase.clientID, tracker, nil, nil)
recorder := httptest.NewRecorder()
handler.TunnelStateHandler(recorder, nil)
decoder := json.NewDecoder(recorder.Body)
var response struct {
TunnelID uuid.UUID `json:"tunnelID,omitempty"`
ConnectorID uuid.UUID `json:"connectorID,omitempty"`
Connections []tunnelstate.IndexedConnectionInfo `json:"connections,omitempty"`
}
err := decoder.Decode(&response)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, tCase.tunnelID, response.TunnelID)
assert.Equal(t, tCase.clientID, response.ConnectorID)
assert.Equal(t, tCase.connections, response.Connections)
})
}
}
func TestConfigurationHandler(t *testing.T) {
t.Parallel()
log := zerolog.Nop()
tests := []struct {
name string
flags map[string]string
expected map[string]string
}{
{
name: "empty cli",
flags: make(map[string]string),
expected: map[string]string{
"uid": "0",
},
},
{
name: "cli with flags",
flags: map[string]string{
"a": "a",
"b": "a",
"c": "a",
"d": "a",
},
expected: map[string]string{
"b": "a",
"c": "a",
"d": "a",
"uid": "0",
},
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
var response map[string]string
t.Parallel()
ctx := buildCliContext(t, tCase.flags)
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, ctx, []string{"b", "c", "d"})
recorder := httptest.NewRecorder()
handler.ConfigurationHandler(recorder, nil)
decoder := json.NewDecoder(recorder.Body)
err := decoder.Decode(&response)
require.NoError(t, err)
_, ok := response["uid"]
assert.True(t, ok)
delete(tCase.expected, "uid")
delete(response, "uid")
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, tCase.expected, response)
})
}
}