TUN-8737: update metrics server port selection

## Summary
Update how metrics server binds to a listener by using a known set of ports whenever the default address is used with the fallback to a random port in case all address are already in use. The default address changes at compile time in order to bind to a different default address when the final deliverable is a docker image.

Refactor ReadyServer tests.

Closes TUN-8737
This commit is contained in:
Luis Neto 2024-11-22 07:23:46 -08:00
parent d779394748
commit e2c2b012f1
8 changed files with 194 additions and 93 deletions

View File

@ -5,7 +5,9 @@ FROM golang:1.22.5 as builder
ENV GO111MODULE=on \
CGO_ENABLED=0 \
TARGET_GOOS=${TARGET_GOOS} \
TARGET_GOARCH=${TARGET_GOARCH}
TARGET_GOARCH=${TARGET_GOARCH} \
CONTAINER_BUILD=1
WORKDIR /go/src/github.com/cloudflare/cloudflared/

View File

@ -30,6 +30,10 @@ ifdef PACKAGE_MANAGER
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
endif
ifdef CONTAINER_BUILD
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual"
endif
LINK_FLAGS :=
ifeq ($(FIPS), true)
LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS)

View File

@ -39,6 +39,7 @@ import (
"github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/tunnelstate"
"github.com/cloudflare/cloudflared/validation"
)
@ -448,16 +449,19 @@ func StartServer(
return err
}
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics"))
if err != nil {
log.Err(err).Msg("Error opening metrics server listener")
return errors.Wrap(err, "Error opening metrics server listener")
}
defer metricsListener.Close()
wg.Add(1)
go func() {
defer wg.Done()
readinessServer := metrics.NewReadyServer(log, clientID)
readinessServer := metrics.NewReadyServer(clientID,
tunnelstate.NewConnTracker(log))
observer.RegisterSink(readinessServer)
metricsConfig := metrics.Config{
ReadyServer: readinessServer,
@ -857,9 +861,15 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "metrics",
Value: "localhost:",
Usage: "Listen address for metrics reporting.",
Name: "metrics",
Value: metrics.GetMetricsDefaultAddress(metrics.Runtime),
Usage: fmt.Sprintf(
`Listen address for metrics reporting. If no address is passed cloudflared will try to bind to %v.
If all are unavailable, a random port will be used. Note that when running cloudflared from an virtual
environment the default address binds to all interfaces, hence, it is important to isolate the host
and virtualized host network stacks from each other`,
metrics.GetMetricsKnownAddresses(metrics.Runtime),
),
EnvVars: []string{"TUNNEL_METRICS"},
Hidden: shouldHide,
}),

View File

@ -10,6 +10,7 @@ import (
"sync"
"time"
"github.com/facebookgo/grace/gracenet"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
@ -21,6 +22,34 @@ const (
defaultShutdownTimeout = time.Second * 15
)
// This variable is set at compile time to allow the default local address to change.
var Runtime = "host"
func GetMetricsDefaultAddress(runtimeType string) string {
// When issuing the diagnostic command we may have to reach a server that is
// running in a virtual enviroment and in that case we must bind to 0.0.0.0
// otherwise the server won't be reachable.
switch runtimeType {
case "virtual":
return "0.0.0.0:0"
default:
return "localhost:0"
}
}
// GetMetricsKnownAddresses returns the addresses used by the metrics server to bind at
// startup time to allow a semi-deterministic approach to know where the server is listening at.
// The ports were selected because at the time we are in 2024 and they do not collide with any
// know/registered port according https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers.
func GetMetricsKnownAddresses(runtimeType string) [5]string {
switch Runtime {
case "virtual":
return [5]string{"0.0.0.0:20241", "0.0.0.0:20242", "0.0.0.0:20243", "0.0.0.0:20244", "0.0.0.0:20245"}
default:
return [5]string{"localhost:20241", "localhost:20242", "localhost:20243", "localhost:20244", "localhost:20245"}
}
}
type Config struct {
ReadyServer *ReadyServer
QuickTunnelHostname string
@ -65,6 +94,42 @@ func newMetricsHandler(
return router
}
// CreateMetricsListener will create a new [net.Listener] by using an
// known set of ports when the default address is passed with the fallback
// of choosing a random port when none is available.
//
// In case the provided address is not the default one then it will be used
// as is.
func CreateMetricsListener(listeners *gracenet.Net, laddr string) (net.Listener, error) {
if laddr == GetMetricsDefaultAddress(Runtime) {
// On the presence of the default address select
// a port from the known set of addresses iteratively.
addresses := GetMetricsKnownAddresses(Runtime)
for _, address := range addresses {
listener, err := listeners.Listen("tcp", address)
if err == nil {
return listener, nil
}
}
// When no port is available then bind to a random one
listener, err := listeners.Listen("tcp", laddr)
if err != nil {
return nil, fmt.Errorf("failed to listen to default metrics address: %w", err)
}
return listener, nil
}
// Explicitly got a local address then bind to it
listener, err := listeners.Listen("tcp", laddr)
if err != nil {
return nil, fmt.Errorf("failed to bind to address (%s): %w", laddr, err)
}
return listener, nil
}
func ServeMetrics(
l net.Listener,
ctx context.Context,

52
metrics/metrics_test.go Normal file
View File

@ -0,0 +1,52 @@
package metrics_test
import (
"testing"
"github.com/facebookgo/grace/gracenet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/metrics"
)
func TestMetricsListenerCreation(t *testing.T) {
t.Parallel()
listeners := gracenet.Net{}
listener1, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
assert.Equal(t, "127.0.0.1:20241", listener1.Addr().String())
require.NoError(t, err)
listener2, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
assert.Equal(t, "127.0.0.1:20242", listener2.Addr().String())
require.NoError(t, err)
listener3, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
assert.Equal(t, "127.0.0.1:20243", listener3.Addr().String())
require.NoError(t, err)
listener4, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
assert.Equal(t, "127.0.0.1:20244", listener4.Addr().String())
require.NoError(t, err)
listener5, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
assert.Equal(t, "127.0.0.1:20245", listener5.Addr().String())
require.NoError(t, err)
listener6, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
addresses := [5]string{"127.0.0.1:20241", "127.0.0.1:20242", "127.0.0.1:20243", "127.0.0.1:20244", "127.0.0.1:20245"}
assert.NotContains(t, addresses, listener6.Addr().String())
require.NoError(t, err)
listener7, err := metrics.CreateMetricsListener(&listeners, "localhost:12345")
assert.Equal(t, "127.0.0.1:12345", listener7.Addr().String())
require.NoError(t, err)
err = listener1.Close()
require.NoError(t, err)
err = listener2.Close()
require.NoError(t, err)
err = listener3.Close()
require.NoError(t, err)
err = listener4.Close()
require.NoError(t, err)
err = listener5.Close()
require.NoError(t, err)
err = listener6.Close()
require.NoError(t, err)
err = listener7.Close()
require.NoError(t, err)
}

View File

@ -6,7 +6,6 @@ import (
"net/http"
"github.com/google/uuid"
"github.com/rs/zerolog"
conn "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/tunnelstate"
@ -19,10 +18,13 @@ type ReadyServer struct {
}
// NewReadyServer initializes a ReadyServer and starts listening for dis/connection events.
func NewReadyServer(log *zerolog.Logger, clientID uuid.UUID) *ReadyServer {
func NewReadyServer(
clientID uuid.UUID,
tracker *tunnelstate.ConnTracker,
) *ReadyServer {
return &ReadyServer{
clientID: clientID,
tracker: tunnelstate.NewConnTracker(log),
clientID,
tracker,
}
}

View File

@ -1,136 +1,106 @@
package metrics
package metrics_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/tunnelstate"
)
func TestReadyServer_makeResponse(t *testing.T) {
type fields struct {
isConnected map[uint8]tunnelstate.ConnectionInfo
}
tests := []struct {
name string
fields fields
wantOK bool
wantReadyConnections uint
}{
{
name: "One connection online => HTTP 200",
fields: fields{
isConnected: map[uint8]tunnelstate.ConnectionInfo{
0: {IsConnected: false},
1: {IsConnected: false},
2: {IsConnected: true},
3: {IsConnected: false},
},
},
wantOK: true,
wantReadyConnections: 1,
},
{
name: "No connections online => no HTTP 200",
fields: fields{
isConnected: map[uint8]tunnelstate.ConnectionInfo{
0: {IsConnected: false},
1: {IsConnected: false},
2: {IsConnected: false},
3: {IsConnected: false},
},
},
wantReadyConnections: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := &ReadyServer{
tracker: tunnelstate.MockedConnTracker(tt.fields.isConnected),
}
gotStatusCode, gotReadyConnections := rs.makeResponse()
if tt.wantOK && gotStatusCode != http.StatusOK {
t.Errorf("ReadyServer.makeResponse() gotStatusCode = %v, want ok = %v", gotStatusCode, tt.wantOK)
}
if gotReadyConnections != tt.wantReadyConnections {
t.Errorf("ReadyServer.makeResponse() gotReadyConnections = %v, want %v", gotReadyConnections, tt.wantReadyConnections)
}
})
func mockRequest(t *testing.T, readyServer *metrics.ReadyServer) (int, uint) {
t.Helper()
var readyreadyConnections struct {
Status int `json:"status"`
ReadyConnections uint `json:"readyConnections"`
ConnectorID uuid.UUID `json:"connectorId"`
}
rec := httptest.NewRecorder()
readyServer.ServeHTTP(rec, nil)
decoder := json.NewDecoder(rec.Body)
err := decoder.Decode(&readyreadyConnections)
require.NoError(t, err)
return rec.Code, readyreadyConnections.ReadyConnections
}
func TestReadinessEventHandling(t *testing.T) {
nopLogger := zerolog.Nop()
rs := NewReadyServer(&nopLogger, uuid.Nil)
tracker := tunnelstate.NewConnTracker(&nopLogger)
rs := metrics.NewReadyServer(uuid.Nil, tracker)
// start not ok
code, ready := rs.makeResponse()
code, readyConnections := mockRequest(t, rs)
assert.NotEqualValues(t, http.StatusOK, code)
assert.Zero(t, ready)
assert.Zero(t, readyConnections)
// one connected => ok
rs.OnTunnelEvent(connection.Event{
Index: 1,
EventType: connection.Connected,
})
code, ready = rs.makeResponse()
code, readyConnections = mockRequest(t, rs)
assert.EqualValues(t, http.StatusOK, code)
assert.EqualValues(t, 1, ready)
assert.EqualValues(t, 1, readyConnections)
// another connected => still ok
rs.OnTunnelEvent(connection.Event{
Index: 2,
EventType: connection.Connected,
})
code, ready = rs.makeResponse()
code, readyConnections = mockRequest(t, rs)
assert.EqualValues(t, http.StatusOK, code)
assert.EqualValues(t, 2, ready)
assert.EqualValues(t, 2, readyConnections)
// one reconnecting => still ok
rs.OnTunnelEvent(connection.Event{
Index: 2,
EventType: connection.Reconnecting,
})
code, ready = rs.makeResponse()
code, readyConnections = mockRequest(t, rs)
assert.EqualValues(t, http.StatusOK, code)
assert.EqualValues(t, 1, ready)
assert.EqualValues(t, 1, readyConnections)
// Regression test for TUN-3777
rs.OnTunnelEvent(connection.Event{
Index: 1,
EventType: connection.RegisteringTunnel,
})
code, ready = rs.makeResponse()
code, readyConnections = mockRequest(t, rs)
assert.NotEqualValues(t, http.StatusOK, code)
assert.Zero(t, ready)
assert.Zero(t, readyConnections)
// other connected then unregistered => not ok
rs.OnTunnelEvent(connection.Event{
Index: 1,
EventType: connection.Connected,
})
code, ready = rs.makeResponse()
code, readyConnections = mockRequest(t, rs)
assert.EqualValues(t, http.StatusOK, code)
assert.EqualValues(t, 1, ready)
assert.EqualValues(t, 1, readyConnections)
rs.OnTunnelEvent(connection.Event{
Index: 1,
EventType: connection.Unregistering,
})
code, ready = rs.makeResponse()
code, readyConnections = mockRequest(t, rs)
assert.NotEqualValues(t, http.StatusOK, code)
assert.Zero(t, ready)
assert.Zero(t, readyConnections)
// other disconnected => not ok
rs.OnTunnelEvent(connection.Event{
Index: 1,
EventType: connection.Disconnected,
})
code, ready = rs.makeResponse()
code, readyConnections = mockRequest(t, rs)
assert.NotEqualValues(t, http.StatusOK, code)
assert.Zero(t, ready)
assert.Zero(t, readyConnections)
}

View File

@ -9,7 +9,7 @@ import (
)
type ConnTracker struct {
sync.RWMutex
mutex sync.RWMutex
// int is the connection Index
connectionInfo map[uint8]ConnectionInfo
log *zerolog.Logger
@ -20,43 +20,39 @@ type ConnectionInfo struct {
Protocol connection.Protocol
}
func NewConnTracker(log *zerolog.Logger) *ConnTracker {
func NewConnTracker(
log *zerolog.Logger,
) *ConnTracker {
return &ConnTracker{
connectionInfo: make(map[uint8]ConnectionInfo, 0),
log: log,
}
}
func MockedConnTracker(mocked map[uint8]ConnectionInfo) *ConnTracker {
return &ConnTracker{
connectionInfo: mocked,
}
}
func (ct *ConnTracker) OnTunnelEvent(c connection.Event) {
switch c.EventType {
case connection.Connected:
ct.Lock()
ct.mutex.Lock()
ci := ConnectionInfo{
IsConnected: true,
Protocol: c.Protocol,
}
ct.connectionInfo[c.Index] = ci
ct.Unlock()
ct.mutex.Unlock()
case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering:
ct.Lock()
ct.mutex.Lock()
ci := ct.connectionInfo[c.Index]
ci.IsConnected = false
ct.connectionInfo[c.Index] = ci
ct.Unlock()
ct.mutex.Unlock()
default:
ct.log.Error().Msgf("Unknown connection event case %v", c)
}
}
func (ct *ConnTracker) CountActiveConns() uint {
ct.RLock()
defer ct.RUnlock()
ct.mutex.RLock()
defer ct.mutex.RUnlock()
active := uint(0)
for _, ci := range ct.connectionInfo {
if ci.IsConnected {
@ -69,8 +65,8 @@ func (ct *ConnTracker) CountActiveConns() uint {
// HasConnectedWith checks if we've ever had a successful connection to the edge
// with said protocol.
func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool {
ct.RLock()
defer ct.RUnlock()
ct.mutex.RLock()
defer ct.mutex.RUnlock()
for _, ci := range ct.connectionInfo {
if ci.Protocol == protocol {
return true