TUN-8737: update metrics server port selection

## Summary
Update how the metrics server binds to a listener: when the default address is used, it now tries a known set of ports and falls back to a random port if all of those addresses are already in use. The default address is selected at compile time so that a different default address is bound when the final deliverable is a Docker image.

Refactor ReadyServer tests.

Closes TUN-8737
This commit is contained in:
Luis Neto 2024-11-22 07:23:46 -08:00
parent d779394748
commit e2c2b012f1
8 changed files with 194 additions and 93 deletions

View File

@ -5,7 +5,9 @@ FROM golang:1.22.5 as builder
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=0 \ CGO_ENABLED=0 \
TARGET_GOOS=${TARGET_GOOS} \ TARGET_GOOS=${TARGET_GOOS} \
TARGET_GOARCH=${TARGET_GOARCH} TARGET_GOARCH=${TARGET_GOARCH} \
CONTAINER_BUILD=1
WORKDIR /go/src/github.com/cloudflare/cloudflared/ WORKDIR /go/src/github.com/cloudflare/cloudflared/

View File

@ -30,6 +30,10 @@ ifdef PACKAGE_MANAGER
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)" VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
endif endif
ifdef CONTAINER_BUILD
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual"
endif
LINK_FLAGS := LINK_FLAGS :=
ifeq ($(FIPS), true) ifeq ($(FIPS), true)
LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS) LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS)

View File

@ -39,6 +39,7 @@ import (
"github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns" "github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/tunnelstate"
"github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/validation"
) )
@ -448,16 +449,19 @@ func StartServer(
return err return err
} }
metricsListener, err := listeners.Listen("tcp", c.String("metrics")) metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics"))
if err != nil { if err != nil {
log.Err(err).Msg("Error opening metrics server listener") log.Err(err).Msg("Error opening metrics server listener")
return errors.Wrap(err, "Error opening metrics server listener") return errors.Wrap(err, "Error opening metrics server listener")
} }
defer metricsListener.Close() defer metricsListener.Close()
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
readinessServer := metrics.NewReadyServer(log, clientID) readinessServer := metrics.NewReadyServer(clientID,
tunnelstate.NewConnTracker(log))
observer.RegisterSink(readinessServer) observer.RegisterSink(readinessServer)
metricsConfig := metrics.Config{ metricsConfig := metrics.Config{
ReadyServer: readinessServer, ReadyServer: readinessServer,
@ -857,9 +861,15 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "metrics", Name: "metrics",
Value: "localhost:", Value: metrics.GetMetricsDefaultAddress(metrics.Runtime),
Usage: "Listen address for metrics reporting.", Usage: fmt.Sprintf(
`Listen address for metrics reporting. If no address is passed cloudflared will try to bind to %v.
If all are unavailable, a random port will be used. Note that when running cloudflared from an virtual
environment the default address binds to all interfaces, hence, it is important to isolate the host
and virtualized host network stacks from each other`,
metrics.GetMetricsKnownAddresses(metrics.Runtime),
),
EnvVars: []string{"TUNNEL_METRICS"}, EnvVars: []string{"TUNNEL_METRICS"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),

View File

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/facebookgo/grace/gracenet"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -21,6 +22,34 @@ const (
defaultShutdownTimeout = time.Second * 15 defaultShutdownTimeout = time.Second * 15
) )
// Runtime names the kind of environment this binary was built for and is used
// to pick the default metrics listen address and the known metrics addresses.
// It defaults to "host"; container builds override it at link time with
// -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual".
var Runtime = "host"
// GetMetricsDefaultAddress returns the default listen address of the metrics
// server for the given runtime type.
//
// For the "virtual" runtime the address is "0.0.0.0:0": the diagnostic command
// may need to reach a server running inside a virtual environment, and the
// server would not be reachable unless it binds to all interfaces. Any other
// runtime type yields "localhost:0".
func GetMetricsDefaultAddress(runtimeType string) string {
	if runtimeType == "virtual" {
		return "0.0.0.0:0"
	}

	return "localhost:0"
}
// GetMetricsKnownAddresses returns the addresses the metrics server tries to
// bind to at startup, in order, so that there is a semi-deterministic way of
// knowing where the server is listening.
//
// The ports were chosen in 2024 because they do not collide with any
// known/registered port according to
// https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers.
//
// The result depends only on runtimeType: "virtual" yields addresses bound to
// all interfaces (0.0.0.0), any other value yields localhost addresses.
func GetMetricsKnownAddresses(runtimeType string) [5]string {
	// Switch on the argument, not on the package-level Runtime, so that the
	// function honours what the caller asked for.
	switch runtimeType {
	case "virtual":
		return [5]string{"0.0.0.0:20241", "0.0.0.0:20242", "0.0.0.0:20243", "0.0.0.0:20244", "0.0.0.0:20245"}
	default:
		return [5]string{"localhost:20241", "localhost:20242", "localhost:20243", "localhost:20244", "localhost:20245"}
	}
}
type Config struct { type Config struct {
ReadyServer *ReadyServer ReadyServer *ReadyServer
QuickTunnelHostname string QuickTunnelHostname string
@ -65,6 +94,42 @@ func newMetricsHandler(
return router return router
} }
// CreateMetricsListener opens the TCP [net.Listener] used by the metrics server.
//
// When laddr is anything other than the default address for the current
// Runtime, it is bound exactly as given and a bind failure is returned to the
// caller.
//
// When laddr is the default address, each of the known metrics addresses is
// tried in order and the first one that binds is returned; failures on the
// known addresses are intentionally ignored. If none of them can be bound, the
// default address itself is used, which lets the system pick a random port.
func CreateMetricsListener(listeners *gracenet.Net, laddr string) (net.Listener, error) {
	// An explicitly chosen address is honoured verbatim.
	if laddr != GetMetricsDefaultAddress(Runtime) {
		explicit, err := listeners.Listen("tcp", laddr)
		if err != nil {
			return nil, fmt.Errorf("failed to bind to address (%s): %w", laddr, err)
		}

		return explicit, nil
	}

	// Default address: walk the known addresses and keep the first that binds.
	for _, candidate := range GetMetricsKnownAddresses(Runtime) {
		if known, err := listeners.Listen("tcp", candidate); err == nil {
			return known, nil
		}
	}

	// Every known address is taken, so fall back to a random port.
	fallback, err := listeners.Listen("tcp", laddr)
	if err != nil {
		return nil, fmt.Errorf("failed to listen to default metrics address: %w", err)
	}

	return fallback, nil
}
func ServeMetrics( func ServeMetrics(
l net.Listener, l net.Listener,
ctx context.Context, ctx context.Context,

52
metrics/metrics_test.go Normal file
View File

@ -0,0 +1,52 @@
package metrics_test
import (
"testing"
"github.com/facebookgo/grace/gracenet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/metrics"
)
// TestMetricsListenerCreation checks the port selection of
// metrics.CreateMetricsListener: successive requests for the default address
// take the known ports in order, a further request falls back to a port outside
// the known set, and an explicit address is bound as given.
func TestMetricsListenerCreation(t *testing.T) {
	t.Parallel()

	listeners := gracenet.Net{}
	defaultAddress := metrics.GetMetricsDefaultAddress("host")
	knownAddresses := [5]string{"127.0.0.1:20241", "127.0.0.1:20242", "127.0.0.1:20243", "127.0.0.1:20244", "127.0.0.1:20245"}

	// listen binds laddr and returns the resolved address. The error is checked
	// before the listener is touched so a failed bind reports the error instead
	// of panicking on a nil listener, and the listener is closed through
	// t.Cleanup so it is released even when a later assertion aborts the test.
	listen := func(laddr string) string {
		t.Helper()

		listener, err := metrics.CreateMetricsListener(&listeners, laddr)
		require.NoError(t, err)
		t.Cleanup(func() {
			assert.NoError(t, listener.Close())
		})

		return listener.Addr().String()
	}

	// The known ports must be handed out in order while they are free.
	for _, want := range knownAddresses {
		assert.Equal(t, want, listen(defaultAddress))
	}

	// With every known port taken, the fallback must pick a different one.
	assert.NotContains(t, knownAddresses, listen(defaultAddress))

	// An explicit address bypasses the known set entirely.
	assert.Equal(t, "127.0.0.1:12345", listen("localhost:12345"))
}

View File

@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog"
conn "github.com/cloudflare/cloudflared/connection" conn "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/tunnelstate" "github.com/cloudflare/cloudflared/tunnelstate"
@ -19,10 +18,13 @@ type ReadyServer struct {
} }
// NewReadyServer initializes a ReadyServer and starts listening for dis/connection events. // NewReadyServer initializes a ReadyServer and starts listening for dis/connection events.
func NewReadyServer(log *zerolog.Logger, clientID uuid.UUID) *ReadyServer { func NewReadyServer(
clientID uuid.UUID,
tracker *tunnelstate.ConnTracker,
) *ReadyServer {
return &ReadyServer{ return &ReadyServer{
clientID: clientID, clientID,
tracker: tunnelstate.NewConnTracker(log), tracker,
} }
} }

View File

@ -1,136 +1,106 @@
package metrics package metrics_test
import ( import (
"encoding/json"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/tunnelstate" "github.com/cloudflare/cloudflared/tunnelstate"
) )
func TestReadyServer_makeResponse(t *testing.T) { func mockRequest(t *testing.T, readyServer *metrics.ReadyServer) (int, uint) {
type fields struct { t.Helper()
isConnected map[uint8]tunnelstate.ConnectionInfo
} var readyreadyConnections struct {
tests := []struct { Status int `json:"status"`
name string ReadyConnections uint `json:"readyConnections"`
fields fields ConnectorID uuid.UUID `json:"connectorId"`
wantOK bool
wantReadyConnections uint
}{
{
name: "One connection online => HTTP 200",
fields: fields{
isConnected: map[uint8]tunnelstate.ConnectionInfo{
0: {IsConnected: false},
1: {IsConnected: false},
2: {IsConnected: true},
3: {IsConnected: false},
},
},
wantOK: true,
wantReadyConnections: 1,
},
{
name: "No connections online => no HTTP 200",
fields: fields{
isConnected: map[uint8]tunnelstate.ConnectionInfo{
0: {IsConnected: false},
1: {IsConnected: false},
2: {IsConnected: false},
3: {IsConnected: false},
},
},
wantReadyConnections: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := &ReadyServer{
tracker: tunnelstate.MockedConnTracker(tt.fields.isConnected),
}
gotStatusCode, gotReadyConnections := rs.makeResponse()
if tt.wantOK && gotStatusCode != http.StatusOK {
t.Errorf("ReadyServer.makeResponse() gotStatusCode = %v, want ok = %v", gotStatusCode, tt.wantOK)
}
if gotReadyConnections != tt.wantReadyConnections {
t.Errorf("ReadyServer.makeResponse() gotReadyConnections = %v, want %v", gotReadyConnections, tt.wantReadyConnections)
}
})
} }
rec := httptest.NewRecorder()
readyServer.ServeHTTP(rec, nil)
decoder := json.NewDecoder(rec.Body)
err := decoder.Decode(&readyreadyConnections)
require.NoError(t, err)
return rec.Code, readyreadyConnections.ReadyConnections
} }
func TestReadinessEventHandling(t *testing.T) { func TestReadinessEventHandling(t *testing.T) {
nopLogger := zerolog.Nop() nopLogger := zerolog.Nop()
rs := NewReadyServer(&nopLogger, uuid.Nil) tracker := tunnelstate.NewConnTracker(&nopLogger)
rs := metrics.NewReadyServer(uuid.Nil, tracker)
// start not ok // start not ok
code, ready := rs.makeResponse() code, readyConnections := mockRequest(t, rs)
assert.NotEqualValues(t, http.StatusOK, code) assert.NotEqualValues(t, http.StatusOK, code)
assert.Zero(t, ready) assert.Zero(t, readyConnections)
// one connected => ok // one connected => ok
rs.OnTunnelEvent(connection.Event{ rs.OnTunnelEvent(connection.Event{
Index: 1, Index: 1,
EventType: connection.Connected, EventType: connection.Connected,
}) })
code, ready = rs.makeResponse() code, readyConnections = mockRequest(t, rs)
assert.EqualValues(t, http.StatusOK, code) assert.EqualValues(t, http.StatusOK, code)
assert.EqualValues(t, 1, ready) assert.EqualValues(t, 1, readyConnections)
// another connected => still ok // another connected => still ok
rs.OnTunnelEvent(connection.Event{ rs.OnTunnelEvent(connection.Event{
Index: 2, Index: 2,
EventType: connection.Connected, EventType: connection.Connected,
}) })
code, ready = rs.makeResponse() code, readyConnections = mockRequest(t, rs)
assert.EqualValues(t, http.StatusOK, code) assert.EqualValues(t, http.StatusOK, code)
assert.EqualValues(t, 2, ready) assert.EqualValues(t, 2, readyConnections)
// one reconnecting => still ok // one reconnecting => still ok
rs.OnTunnelEvent(connection.Event{ rs.OnTunnelEvent(connection.Event{
Index: 2, Index: 2,
EventType: connection.Reconnecting, EventType: connection.Reconnecting,
}) })
code, ready = rs.makeResponse() code, readyConnections = mockRequest(t, rs)
assert.EqualValues(t, http.StatusOK, code) assert.EqualValues(t, http.StatusOK, code)
assert.EqualValues(t, 1, ready) assert.EqualValues(t, 1, readyConnections)
// Regression test for TUN-3777 // Regression test for TUN-3777
rs.OnTunnelEvent(connection.Event{ rs.OnTunnelEvent(connection.Event{
Index: 1, Index: 1,
EventType: connection.RegisteringTunnel, EventType: connection.RegisteringTunnel,
}) })
code, ready = rs.makeResponse() code, readyConnections = mockRequest(t, rs)
assert.NotEqualValues(t, http.StatusOK, code) assert.NotEqualValues(t, http.StatusOK, code)
assert.Zero(t, ready) assert.Zero(t, readyConnections)
// other connected then unregistered => not ok // other connected then unregistered => not ok
rs.OnTunnelEvent(connection.Event{ rs.OnTunnelEvent(connection.Event{
Index: 1, Index: 1,
EventType: connection.Connected, EventType: connection.Connected,
}) })
code, ready = rs.makeResponse() code, readyConnections = mockRequest(t, rs)
assert.EqualValues(t, http.StatusOK, code) assert.EqualValues(t, http.StatusOK, code)
assert.EqualValues(t, 1, ready) assert.EqualValues(t, 1, readyConnections)
rs.OnTunnelEvent(connection.Event{ rs.OnTunnelEvent(connection.Event{
Index: 1, Index: 1,
EventType: connection.Unregistering, EventType: connection.Unregistering,
}) })
code, ready = rs.makeResponse() code, readyConnections = mockRequest(t, rs)
assert.NotEqualValues(t, http.StatusOK, code) assert.NotEqualValues(t, http.StatusOK, code)
assert.Zero(t, ready) assert.Zero(t, readyConnections)
// other disconnected => not ok // other disconnected => not ok
rs.OnTunnelEvent(connection.Event{ rs.OnTunnelEvent(connection.Event{
Index: 1, Index: 1,
EventType: connection.Disconnected, EventType: connection.Disconnected,
}) })
code, ready = rs.makeResponse() code, readyConnections = mockRequest(t, rs)
assert.NotEqualValues(t, http.StatusOK, code) assert.NotEqualValues(t, http.StatusOK, code)
assert.Zero(t, ready) assert.Zero(t, readyConnections)
} }

View File

@ -9,7 +9,7 @@ import (
) )
type ConnTracker struct { type ConnTracker struct {
sync.RWMutex mutex sync.RWMutex
// int is the connection Index // int is the connection Index
connectionInfo map[uint8]ConnectionInfo connectionInfo map[uint8]ConnectionInfo
log *zerolog.Logger log *zerolog.Logger
@ -20,43 +20,39 @@ type ConnectionInfo struct {
Protocol connection.Protocol Protocol connection.Protocol
} }
func NewConnTracker(log *zerolog.Logger) *ConnTracker { func NewConnTracker(
log *zerolog.Logger,
) *ConnTracker {
return &ConnTracker{ return &ConnTracker{
connectionInfo: make(map[uint8]ConnectionInfo, 0), connectionInfo: make(map[uint8]ConnectionInfo, 0),
log: log, log: log,
} }
} }
func MockedConnTracker(mocked map[uint8]ConnectionInfo) *ConnTracker {
return &ConnTracker{
connectionInfo: mocked,
}
}
func (ct *ConnTracker) OnTunnelEvent(c connection.Event) { func (ct *ConnTracker) OnTunnelEvent(c connection.Event) {
switch c.EventType { switch c.EventType {
case connection.Connected: case connection.Connected:
ct.Lock() ct.mutex.Lock()
ci := ConnectionInfo{ ci := ConnectionInfo{
IsConnected: true, IsConnected: true,
Protocol: c.Protocol, Protocol: c.Protocol,
} }
ct.connectionInfo[c.Index] = ci ct.connectionInfo[c.Index] = ci
ct.Unlock() ct.mutex.Unlock()
case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering: case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering:
ct.Lock() ct.mutex.Lock()
ci := ct.connectionInfo[c.Index] ci := ct.connectionInfo[c.Index]
ci.IsConnected = false ci.IsConnected = false
ct.connectionInfo[c.Index] = ci ct.connectionInfo[c.Index] = ci
ct.Unlock() ct.mutex.Unlock()
default: default:
ct.log.Error().Msgf("Unknown connection event case %v", c) ct.log.Error().Msgf("Unknown connection event case %v", c)
} }
} }
func (ct *ConnTracker) CountActiveConns() uint { func (ct *ConnTracker) CountActiveConns() uint {
ct.RLock() ct.mutex.RLock()
defer ct.RUnlock() defer ct.mutex.RUnlock()
active := uint(0) active := uint(0)
for _, ci := range ct.connectionInfo { for _, ci := range ct.connectionInfo {
if ci.IsConnected { if ci.IsConnected {
@ -69,8 +65,8 @@ func (ct *ConnTracker) CountActiveConns() uint {
// HasConnectedWith checks if we've ever had a successful connection to the edge // HasConnectedWith checks if we've ever had a successful connection to the edge
// with said protocol. // with said protocol.
func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool { func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool {
ct.RLock() ct.mutex.RLock()
defer ct.RUnlock() defer ct.mutex.RUnlock()
for _, ci := range ct.connectionInfo { for _, ci := range ct.connectionInfo {
if ci.Protocol == protocol { if ci.Protocol == protocol {
return true return true