diff --git a/Dockerfile b/Dockerfile index 8dac4752..39307965 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,9 @@ FROM golang:1.22.5 as builder ENV GO111MODULE=on \ CGO_ENABLED=0 \ TARGET_GOOS=${TARGET_GOOS} \ - TARGET_GOARCH=${TARGET_GOARCH} + TARGET_GOARCH=${TARGET_GOARCH} \ + CONTAINER_BUILD=1 + WORKDIR /go/src/github.com/cloudflare/cloudflared/ diff --git a/Makefile b/Makefile index 116bcc8f..c572954a 100644 --- a/Makefile +++ b/Makefile @@ -30,6 +30,10 @@ ifdef PACKAGE_MANAGER VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)" endif +ifdef CONTAINER_BUILD + VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual" +endif + LINK_FLAGS := ifeq ($(FIPS), true) LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index f7482154..c900edb6 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -39,6 +39,7 @@ import ( "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tunneldns" + "github.com/cloudflare/cloudflared/tunnelstate" "github.com/cloudflare/cloudflared/validation" ) @@ -448,16 +449,19 @@ func StartServer( return err } - metricsListener, err := listeners.Listen("tcp", c.String("metrics")) + metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics")) if err != nil { log.Err(err).Msg("Error opening metrics server listener") return errors.Wrap(err, "Error opening metrics server listener") } + defer metricsListener.Close() wg.Add(1) + go func() { defer wg.Done() - readinessServer := metrics.NewReadyServer(log, clientID) + readinessServer := metrics.NewReadyServer(clientID, + tunnelstate.NewConnTracker(log)) observer.RegisterSink(readinessServer) metricsConfig := metrics.Config{ ReadyServer: readinessServer, @@ -857,9 +861,15 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag { Hidden: shouldHide, }), altsrc.NewStringFlag(&cli.StringFlag{ - Name: "metrics", - Value: "localhost:", - Usage: "Listen address for metrics reporting.", + Name: "metrics", + Value: metrics.GetMetricsDefaultAddress(metrics.Runtime), + Usage: fmt.Sprintf( + `Listen address for metrics reporting. If no address is passed cloudflared will try to bind to %v. +If all are unavailable, a random port will be used. Note that when running cloudflared from an virtual +environment the default address binds to all interfaces, hence, it is important to isolate the host +and virtualized host network stacks from each other`, + metrics.GetMetricsKnownAddresses(metrics.Runtime), + ), EnvVars: []string{"TUNNEL_METRICS"}, Hidden: shouldHide, }), diff --git a/metrics/metrics.go b/metrics/metrics.go index 1759451e..6aadffeb 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/facebookgo/grace/gracenet" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/zerolog" @@ -21,6 +22,34 @@ const ( defaultShutdownTimeout = time.Second * 15 ) +// This variable is set at compile time to allow the default local address to change. +var Runtime = "host" + +func GetMetricsDefaultAddress(runtimeType string) string { + // When issuing the diagnostic command we may have to reach a server that is + // running in a virtual enviroment and in that case we must bind to 0.0.0.0 + // otherwise the server won't be reachable. + switch runtimeType { + case "virtual": + return "0.0.0.0:0" + default: + return "localhost:0" + } +} + +// GetMetricsKnownAddresses returns the addresses used by the metrics server to bind at +// startup time to allow a semi-deterministic approach to know where the server is listening at. +// The ports were selected because at the time we are in 2024 and they do not collide with any +// know/registered port according https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers. +func GetMetricsKnownAddresses(runtimeType string) [5]string { + switch Runtime { + case "virtual": + return [5]string{"0.0.0.0:20241", "0.0.0.0:20242", "0.0.0.0:20243", "0.0.0.0:20244", "0.0.0.0:20245"} + default: + return [5]string{"localhost:20241", "localhost:20242", "localhost:20243", "localhost:20244", "localhost:20245"} + } +} + type Config struct { ReadyServer *ReadyServer QuickTunnelHostname string @@ -65,6 +94,42 @@ func newMetricsHandler( return router } +// CreateMetricsListener will create a new [net.Listener] by using an +// known set of ports when the default address is passed with the fallback +// of choosing a random port when none is available. +// +// In case the provided address is not the default one then it will be used +// as is. +func CreateMetricsListener(listeners *gracenet.Net, laddr string) (net.Listener, error) { + if laddr == GetMetricsDefaultAddress(Runtime) { + // On the presence of the default address select + // a port from the known set of addresses iteratively. + addresses := GetMetricsKnownAddresses(Runtime) + for _, address := range addresses { + listener, err := listeners.Listen("tcp", address) + if err == nil { + return listener, nil + } + } + + // When no port is available then bind to a random one + listener, err := listeners.Listen("tcp", laddr) + if err != nil { + return nil, fmt.Errorf("failed to listen to default metrics address: %w", err) + } + + return listener, nil + } + + // Explicitly got a local address then bind to it + listener, err := listeners.Listen("tcp", laddr) + if err != nil { + return nil, fmt.Errorf("failed to bind to address (%s): %w", laddr, err) + } + + return listener, nil +} + func ServeMetrics( l net.Listener, ctx context.Context, diff --git a/metrics/metrics_test.go b/metrics/metrics_test.go new file mode 100644 index 00000000..849076d7 --- /dev/null +++ b/metrics/metrics_test.go @@ -0,0 +1,52 @@ +package metrics_test + +import ( + "testing" + + "github.com/facebookgo/grace/gracenet" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/metrics" +) + +func TestMetricsListenerCreation(t *testing.T) { + t.Parallel() + listeners := gracenet.Net{} + listener1, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host")) + assert.Equal(t, "127.0.0.1:20241", listener1.Addr().String()) + require.NoError(t, err) + listener2, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host")) + assert.Equal(t, "127.0.0.1:20242", listener2.Addr().String()) + require.NoError(t, err) + listener3, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host")) + assert.Equal(t, "127.0.0.1:20243", listener3.Addr().String()) + require.NoError(t, err) + listener4, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host")) + assert.Equal(t, "127.0.0.1:20244", listener4.Addr().String()) + require.NoError(t, err) + listener5, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host")) + assert.Equal(t, "127.0.0.1:20245", listener5.Addr().String()) + require.NoError(t, err) + listener6, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host")) + addresses := [5]string{"127.0.0.1:20241", "127.0.0.1:20242", "127.0.0.1:20243", "127.0.0.1:20244", "127.0.0.1:20245"} + assert.NotContains(t, addresses, listener6.Addr().String()) + require.NoError(t, err) + listener7, err := metrics.CreateMetricsListener(&listeners, "localhost:12345") + assert.Equal(t, "127.0.0.1:12345", listener7.Addr().String()) + require.NoError(t, err) + err = listener1.Close() + require.NoError(t, err) + err = listener2.Close() + require.NoError(t, err) + err = listener3.Close() + require.NoError(t, err) + err = listener4.Close() + require.NoError(t, err) + err = listener5.Close() + require.NoError(t, err) + err = listener6.Close() + require.NoError(t, err) + err = listener7.Close() + require.NoError(t, err) +} diff --git a/metrics/readiness.go b/metrics/readiness.go index b4e2025d..e2de549a 100644 --- a/metrics/readiness.go +++ b/metrics/readiness.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/google/uuid" - "github.com/rs/zerolog" conn "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/tunnelstate" @@ -19,10 +18,13 @@ type ReadyServer struct { } // NewReadyServer initializes a ReadyServer and starts listening for dis/connection events. -func NewReadyServer(log *zerolog.Logger, clientID uuid.UUID) *ReadyServer { +func NewReadyServer( + clientID uuid.UUID, + tracker *tunnelstate.ConnTracker, +) *ReadyServer { return &ReadyServer{ - clientID: clientID, - tracker: tunnelstate.NewConnTracker(log), + clientID, + tracker, } } diff --git a/metrics/readiness_test.go b/metrics/readiness_test.go index 8e035f85..cd30bece 100644 --- a/metrics/readiness_test.go +++ b/metrics/readiness_test.go @@ -1,136 +1,106 @@ -package metrics +package metrics_test import ( + "encoding/json" "net/http" + "net/http/httptest" "testing" "github.com/google/uuid" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/metrics" "github.com/cloudflare/cloudflared/tunnelstate" ) -func TestReadyServer_makeResponse(t *testing.T) { - type fields struct { - isConnected map[uint8]tunnelstate.ConnectionInfo - } - tests := []struct { - name string - fields fields - wantOK bool - wantReadyConnections uint - }{ - { - name: "One connection online => HTTP 200", - fields: fields{ - isConnected: map[uint8]tunnelstate.ConnectionInfo{ - 0: {IsConnected: false}, - 1: {IsConnected: false}, - 2: {IsConnected: true}, - 3: {IsConnected: false}, - }, - }, - wantOK: true, - wantReadyConnections: 1, - }, - { - name: "No connections online => no HTTP 200", - fields: fields{ - isConnected: map[uint8]tunnelstate.ConnectionInfo{ - 0: {IsConnected: false}, - 1: {IsConnected: false}, - 2: {IsConnected: false}, - 3: {IsConnected: false}, - }, - }, - wantReadyConnections: 0, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rs := &ReadyServer{ - tracker: tunnelstate.MockedConnTracker(tt.fields.isConnected), - } - gotStatusCode, gotReadyConnections := rs.makeResponse() - if tt.wantOK && gotStatusCode != http.StatusOK { - t.Errorf("ReadyServer.makeResponse() gotStatusCode = %v, want ok = %v", gotStatusCode, tt.wantOK) - } - if gotReadyConnections != tt.wantReadyConnections { - t.Errorf("ReadyServer.makeResponse() gotReadyConnections = %v, want %v", gotReadyConnections, tt.wantReadyConnections) - } - }) +func mockRequest(t *testing.T, readyServer *metrics.ReadyServer) (int, uint) { + t.Helper() + + var readyreadyConnections struct { + Status int `json:"status"` + ReadyConnections uint `json:"readyConnections"` + ConnectorID uuid.UUID `json:"connectorId"` } + rec := httptest.NewRecorder() + readyServer.ServeHTTP(rec, nil) + + decoder := json.NewDecoder(rec.Body) + err := decoder.Decode(&readyreadyConnections) + require.NoError(t, err) + return rec.Code, readyreadyConnections.ReadyConnections } func TestReadinessEventHandling(t *testing.T) { nopLogger := zerolog.Nop() - rs := NewReadyServer(&nopLogger, uuid.Nil) + tracker := tunnelstate.NewConnTracker(&nopLogger) + rs := metrics.NewReadyServer(uuid.Nil, tracker) // start not ok - code, ready := rs.makeResponse() + code, readyConnections := mockRequest(t, rs) assert.NotEqualValues(t, http.StatusOK, code) - assert.Zero(t, ready) + assert.Zero(t, readyConnections) // one connected => ok rs.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.Connected, }) - code, ready = rs.makeResponse() + code, readyConnections = mockRequest(t, rs) assert.EqualValues(t, http.StatusOK, code) - assert.EqualValues(t, 1, ready) + assert.EqualValues(t, 1, readyConnections) // another connected => still ok rs.OnTunnelEvent(connection.Event{ Index: 2, EventType: connection.Connected, }) - code, ready = rs.makeResponse() + code, readyConnections = mockRequest(t, rs) assert.EqualValues(t, http.StatusOK, code) - assert.EqualValues(t, 2, ready) + assert.EqualValues(t, 2, readyConnections) // one reconnecting => still ok rs.OnTunnelEvent(connection.Event{ Index: 2, EventType: connection.Reconnecting, }) - code, ready = rs.makeResponse() + code, readyConnections = mockRequest(t, rs) assert.EqualValues(t, http.StatusOK, code) - assert.EqualValues(t, 1, ready) + assert.EqualValues(t, 1, readyConnections) // Regression test for TUN-3777 rs.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.RegisteringTunnel, }) - code, ready = rs.makeResponse() + code, readyConnections = mockRequest(t, rs) assert.NotEqualValues(t, http.StatusOK, code) - assert.Zero(t, ready) + assert.Zero(t, readyConnections) // other connected then unregistered => not ok rs.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.Connected, }) - code, ready = rs.makeResponse() + code, readyConnections = mockRequest(t, rs) assert.EqualValues(t, http.StatusOK, code) - assert.EqualValues(t, 1, ready) + assert.EqualValues(t, 1, readyConnections) rs.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.Unregistering, }) - code, ready = rs.makeResponse() + code, readyConnections = mockRequest(t, rs) assert.NotEqualValues(t, http.StatusOK, code) - assert.Zero(t, ready) + assert.Zero(t, readyConnections) // other disconnected => not ok rs.OnTunnelEvent(connection.Event{ Index: 1, EventType: connection.Disconnected, }) - code, ready = rs.makeResponse() + code, readyConnections = mockRequest(t, rs) assert.NotEqualValues(t, http.StatusOK, code) - assert.Zero(t, ready) + assert.Zero(t, readyConnections) } diff --git a/tunnelstate/conntracker.go b/tunnelstate/conntracker.go index 426ba483..cd58a292 100644 --- a/tunnelstate/conntracker.go +++ b/tunnelstate/conntracker.go @@ -9,7 +9,7 @@ import ( ) type ConnTracker struct { - sync.RWMutex + mutex sync.RWMutex // int is the connection Index connectionInfo map[uint8]ConnectionInfo log *zerolog.Logger @@ -20,43 +20,39 @@ type ConnectionInfo struct { Protocol connection.Protocol } -func NewConnTracker(log *zerolog.Logger) *ConnTracker { +func NewConnTracker( + log *zerolog.Logger, +) *ConnTracker { return &ConnTracker{ connectionInfo: make(map[uint8]ConnectionInfo, 0), log: log, } } -func MockedConnTracker(mocked map[uint8]ConnectionInfo) *ConnTracker { - return &ConnTracker{ - connectionInfo: mocked, - } -} - func (ct *ConnTracker) OnTunnelEvent(c connection.Event) { switch c.EventType { case connection.Connected: - ct.Lock() + ct.mutex.Lock() ci := ConnectionInfo{ IsConnected: true, Protocol: c.Protocol, } ct.connectionInfo[c.Index] = ci - ct.Unlock() + ct.mutex.Unlock() case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering: - ct.Lock() + ct.mutex.Lock() ci := ct.connectionInfo[c.Index] ci.IsConnected = false ct.connectionInfo[c.Index] = ci - ct.Unlock() + ct.mutex.Unlock() default: ct.log.Error().Msgf("Unknown connection event case %v", c) } } func (ct *ConnTracker) CountActiveConns() uint { - ct.RLock() - defer ct.RUnlock() + ct.mutex.RLock() + defer ct.mutex.RUnlock() active := uint(0) for _, ci := range ct.connectionInfo { if ci.IsConnected { @@ -69,8 +65,8 @@ func (ct *ConnTracker) CountActiveConns() uint { // HasConnectedWith checks if we've ever had a successful connection to the edge // with said protocol. func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool { - ct.RLock() - defer ct.RUnlock() + ct.mutex.RLock() + defer ct.mutex.RUnlock() for _, ci := range ct.connectionInfo { if ci.Protocol == protocol { return true