TUN-8737: update metrics server port selection
## Summary Update how metrics server binds to a listener by using a known set of ports whenever the default address is used with the fallback to a random port in case all address are already in use. The default address changes at compile time in order to bind to a different default address when the final deliverable is a docker image. Refactor ReadyServer tests. Closes TUN-8737
This commit is contained in:
parent
d779394748
commit
e2c2b012f1
|
@ -5,7 +5,9 @@ FROM golang:1.22.5 as builder
|
||||||
ENV GO111MODULE=on \
|
ENV GO111MODULE=on \
|
||||||
CGO_ENABLED=0 \
|
CGO_ENABLED=0 \
|
||||||
TARGET_GOOS=${TARGET_GOOS} \
|
TARGET_GOOS=${TARGET_GOOS} \
|
||||||
TARGET_GOARCH=${TARGET_GOARCH}
|
TARGET_GOARCH=${TARGET_GOARCH} \
|
||||||
|
CONTAINER_BUILD=1
|
||||||
|
|
||||||
|
|
||||||
WORKDIR /go/src/github.com/cloudflare/cloudflared/
|
WORKDIR /go/src/github.com/cloudflare/cloudflared/
|
||||||
|
|
||||||
|
|
4
Makefile
4
Makefile
|
@ -30,6 +30,10 @@ ifdef PACKAGE_MANAGER
|
||||||
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
|
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
ifdef CONTAINER_BUILD
|
||||||
|
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual"
|
||||||
|
endif
|
||||||
|
|
||||||
LINK_FLAGS :=
|
LINK_FLAGS :=
|
||||||
ifeq ($(FIPS), true)
|
ifeq ($(FIPS), true)
|
||||||
LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS)
|
LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS)
|
||||||
|
|
|
@ -39,6 +39,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/supervisor"
|
"github.com/cloudflare/cloudflared/supervisor"
|
||||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
"github.com/cloudflare/cloudflared/tunneldns"
|
"github.com/cloudflare/cloudflared/tunneldns"
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||||
"github.com/cloudflare/cloudflared/validation"
|
"github.com/cloudflare/cloudflared/validation"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -448,16 +449,19 @@ func StartServer(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).Msg("Error opening metrics server listener")
|
log.Err(err).Msg("Error opening metrics server listener")
|
||||||
return errors.Wrap(err, "Error opening metrics server listener")
|
return errors.Wrap(err, "Error opening metrics server listener")
|
||||||
}
|
}
|
||||||
|
|
||||||
defer metricsListener.Close()
|
defer metricsListener.Close()
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
readinessServer := metrics.NewReadyServer(log, clientID)
|
readinessServer := metrics.NewReadyServer(clientID,
|
||||||
|
tunnelstate.NewConnTracker(log))
|
||||||
observer.RegisterSink(readinessServer)
|
observer.RegisterSink(readinessServer)
|
||||||
metricsConfig := metrics.Config{
|
metricsConfig := metrics.Config{
|
||||||
ReadyServer: readinessServer,
|
ReadyServer: readinessServer,
|
||||||
|
@ -857,9 +861,15 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "metrics",
|
Name: "metrics",
|
||||||
Value: "localhost:",
|
Value: metrics.GetMetricsDefaultAddress(metrics.Runtime),
|
||||||
Usage: "Listen address for metrics reporting.",
|
Usage: fmt.Sprintf(
|
||||||
|
`Listen address for metrics reporting. If no address is passed cloudflared will try to bind to %v.
|
||||||
|
If all are unavailable, a random port will be used. Note that when running cloudflared from an virtual
|
||||||
|
environment the default address binds to all interfaces, hence, it is important to isolate the host
|
||||||
|
and virtualized host network stacks from each other`,
|
||||||
|
metrics.GetMetricsKnownAddresses(metrics.Runtime),
|
||||||
|
),
|
||||||
EnvVars: []string{"TUNNEL_METRICS"},
|
EnvVars: []string{"TUNNEL_METRICS"},
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/facebookgo/grace/gracenet"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -21,6 +22,34 @@ const (
|
||||||
defaultShutdownTimeout = time.Second * 15
|
defaultShutdownTimeout = time.Second * 15
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// This variable is set at compile time to allow the default local address to change.
|
||||||
|
var Runtime = "host"
|
||||||
|
|
||||||
|
func GetMetricsDefaultAddress(runtimeType string) string {
|
||||||
|
// When issuing the diagnostic command we may have to reach a server that is
|
||||||
|
// running in a virtual enviroment and in that case we must bind to 0.0.0.0
|
||||||
|
// otherwise the server won't be reachable.
|
||||||
|
switch runtimeType {
|
||||||
|
case "virtual":
|
||||||
|
return "0.0.0.0:0"
|
||||||
|
default:
|
||||||
|
return "localhost:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMetricsKnownAddresses returns the addresses used by the metrics server to bind at
|
||||||
|
// startup time to allow a semi-deterministic approach to know where the server is listening at.
|
||||||
|
// The ports were selected because at the time we are in 2024 and they do not collide with any
|
||||||
|
// know/registered port according https://en.wikipedia.org/wiki/List_of_TCP_and_UDP_port_numbers.
|
||||||
|
func GetMetricsKnownAddresses(runtimeType string) [5]string {
|
||||||
|
switch Runtime {
|
||||||
|
case "virtual":
|
||||||
|
return [5]string{"0.0.0.0:20241", "0.0.0.0:20242", "0.0.0.0:20243", "0.0.0.0:20244", "0.0.0.0:20245"}
|
||||||
|
default:
|
||||||
|
return [5]string{"localhost:20241", "localhost:20242", "localhost:20243", "localhost:20244", "localhost:20245"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ReadyServer *ReadyServer
|
ReadyServer *ReadyServer
|
||||||
QuickTunnelHostname string
|
QuickTunnelHostname string
|
||||||
|
@ -65,6 +94,42 @@ func newMetricsHandler(
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateMetricsListener will create a new [net.Listener] by using an
|
||||||
|
// known set of ports when the default address is passed with the fallback
|
||||||
|
// of choosing a random port when none is available.
|
||||||
|
//
|
||||||
|
// In case the provided address is not the default one then it will be used
|
||||||
|
// as is.
|
||||||
|
func CreateMetricsListener(listeners *gracenet.Net, laddr string) (net.Listener, error) {
|
||||||
|
if laddr == GetMetricsDefaultAddress(Runtime) {
|
||||||
|
// On the presence of the default address select
|
||||||
|
// a port from the known set of addresses iteratively.
|
||||||
|
addresses := GetMetricsKnownAddresses(Runtime)
|
||||||
|
for _, address := range addresses {
|
||||||
|
listener, err := listeners.Listen("tcp", address)
|
||||||
|
if err == nil {
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// When no port is available then bind to a random one
|
||||||
|
listener, err := listeners.Listen("tcp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen to default metrics address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicitly got a local address then bind to it
|
||||||
|
listener, err := listeners.Listen("tcp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to bind to address (%s): %w", laddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ServeMetrics(
|
func ServeMetrics(
|
||||||
l net.Listener,
|
l net.Listener,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
package metrics_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/facebookgo/grace/gracenet"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMetricsListenerCreation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
listeners := gracenet.Net{}
|
||||||
|
listener1, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
|
||||||
|
assert.Equal(t, "127.0.0.1:20241", listener1.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
listener2, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
|
||||||
|
assert.Equal(t, "127.0.0.1:20242", listener2.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
listener3, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
|
||||||
|
assert.Equal(t, "127.0.0.1:20243", listener3.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
listener4, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
|
||||||
|
assert.Equal(t, "127.0.0.1:20244", listener4.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
listener5, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
|
||||||
|
assert.Equal(t, "127.0.0.1:20245", listener5.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
listener6, err := metrics.CreateMetricsListener(&listeners, metrics.GetMetricsDefaultAddress("host"))
|
||||||
|
addresses := [5]string{"127.0.0.1:20241", "127.0.0.1:20242", "127.0.0.1:20243", "127.0.0.1:20244", "127.0.0.1:20245"}
|
||||||
|
assert.NotContains(t, addresses, listener6.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
listener7, err := metrics.CreateMetricsListener(&listeners, "localhost:12345")
|
||||||
|
assert.Equal(t, "127.0.0.1:12345", listener7.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = listener1.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = listener2.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = listener3.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = listener4.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = listener5.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = listener6.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = listener7.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog"
|
|
||||||
|
|
||||||
conn "github.com/cloudflare/cloudflared/connection"
|
conn "github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/tunnelstate"
|
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||||
|
@ -19,10 +18,13 @@ type ReadyServer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReadyServer initializes a ReadyServer and starts listening for dis/connection events.
|
// NewReadyServer initializes a ReadyServer and starts listening for dis/connection events.
|
||||||
func NewReadyServer(log *zerolog.Logger, clientID uuid.UUID) *ReadyServer {
|
func NewReadyServer(
|
||||||
|
clientID uuid.UUID,
|
||||||
|
tracker *tunnelstate.ConnTracker,
|
||||||
|
) *ReadyServer {
|
||||||
return &ReadyServer{
|
return &ReadyServer{
|
||||||
clientID: clientID,
|
clientID,
|
||||||
tracker: tunnelstate.NewConnTracker(log),
|
tracker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,136 +1,106 @@
|
||||||
package metrics
|
package metrics_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/metrics"
|
||||||
"github.com/cloudflare/cloudflared/tunnelstate"
|
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReadyServer_makeResponse(t *testing.T) {
|
func mockRequest(t *testing.T, readyServer *metrics.ReadyServer) (int, uint) {
|
||||||
type fields struct {
|
t.Helper()
|
||||||
isConnected map[uint8]tunnelstate.ConnectionInfo
|
|
||||||
}
|
var readyreadyConnections struct {
|
||||||
tests := []struct {
|
Status int `json:"status"`
|
||||||
name string
|
ReadyConnections uint `json:"readyConnections"`
|
||||||
fields fields
|
ConnectorID uuid.UUID `json:"connectorId"`
|
||||||
wantOK bool
|
|
||||||
wantReadyConnections uint
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "One connection online => HTTP 200",
|
|
||||||
fields: fields{
|
|
||||||
isConnected: map[uint8]tunnelstate.ConnectionInfo{
|
|
||||||
0: {IsConnected: false},
|
|
||||||
1: {IsConnected: false},
|
|
||||||
2: {IsConnected: true},
|
|
||||||
3: {IsConnected: false},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantOK: true,
|
|
||||||
wantReadyConnections: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No connections online => no HTTP 200",
|
|
||||||
fields: fields{
|
|
||||||
isConnected: map[uint8]tunnelstate.ConnectionInfo{
|
|
||||||
0: {IsConnected: false},
|
|
||||||
1: {IsConnected: false},
|
|
||||||
2: {IsConnected: false},
|
|
||||||
3: {IsConnected: false},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantReadyConnections: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
rs := &ReadyServer{
|
|
||||||
tracker: tunnelstate.MockedConnTracker(tt.fields.isConnected),
|
|
||||||
}
|
|
||||||
gotStatusCode, gotReadyConnections := rs.makeResponse()
|
|
||||||
if tt.wantOK && gotStatusCode != http.StatusOK {
|
|
||||||
t.Errorf("ReadyServer.makeResponse() gotStatusCode = %v, want ok = %v", gotStatusCode, tt.wantOK)
|
|
||||||
}
|
|
||||||
if gotReadyConnections != tt.wantReadyConnections {
|
|
||||||
t.Errorf("ReadyServer.makeResponse() gotReadyConnections = %v, want %v", gotReadyConnections, tt.wantReadyConnections)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
readyServer.ServeHTTP(rec, nil)
|
||||||
|
|
||||||
|
decoder := json.NewDecoder(rec.Body)
|
||||||
|
err := decoder.Decode(&readyreadyConnections)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return rec.Code, readyreadyConnections.ReadyConnections
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadinessEventHandling(t *testing.T) {
|
func TestReadinessEventHandling(t *testing.T) {
|
||||||
nopLogger := zerolog.Nop()
|
nopLogger := zerolog.Nop()
|
||||||
rs := NewReadyServer(&nopLogger, uuid.Nil)
|
tracker := tunnelstate.NewConnTracker(&nopLogger)
|
||||||
|
rs := metrics.NewReadyServer(uuid.Nil, tracker)
|
||||||
|
|
||||||
// start not ok
|
// start not ok
|
||||||
code, ready := rs.makeResponse()
|
code, readyConnections := mockRequest(t, rs)
|
||||||
assert.NotEqualValues(t, http.StatusOK, code)
|
assert.NotEqualValues(t, http.StatusOK, code)
|
||||||
assert.Zero(t, ready)
|
assert.Zero(t, readyConnections)
|
||||||
|
|
||||||
// one connected => ok
|
// one connected => ok
|
||||||
rs.OnTunnelEvent(connection.Event{
|
rs.OnTunnelEvent(connection.Event{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
EventType: connection.Connected,
|
EventType: connection.Connected,
|
||||||
})
|
})
|
||||||
code, ready = rs.makeResponse()
|
code, readyConnections = mockRequest(t, rs)
|
||||||
assert.EqualValues(t, http.StatusOK, code)
|
assert.EqualValues(t, http.StatusOK, code)
|
||||||
assert.EqualValues(t, 1, ready)
|
assert.EqualValues(t, 1, readyConnections)
|
||||||
|
|
||||||
// another connected => still ok
|
// another connected => still ok
|
||||||
rs.OnTunnelEvent(connection.Event{
|
rs.OnTunnelEvent(connection.Event{
|
||||||
Index: 2,
|
Index: 2,
|
||||||
EventType: connection.Connected,
|
EventType: connection.Connected,
|
||||||
})
|
})
|
||||||
code, ready = rs.makeResponse()
|
code, readyConnections = mockRequest(t, rs)
|
||||||
assert.EqualValues(t, http.StatusOK, code)
|
assert.EqualValues(t, http.StatusOK, code)
|
||||||
assert.EqualValues(t, 2, ready)
|
assert.EqualValues(t, 2, readyConnections)
|
||||||
|
|
||||||
// one reconnecting => still ok
|
// one reconnecting => still ok
|
||||||
rs.OnTunnelEvent(connection.Event{
|
rs.OnTunnelEvent(connection.Event{
|
||||||
Index: 2,
|
Index: 2,
|
||||||
EventType: connection.Reconnecting,
|
EventType: connection.Reconnecting,
|
||||||
})
|
})
|
||||||
code, ready = rs.makeResponse()
|
code, readyConnections = mockRequest(t, rs)
|
||||||
assert.EqualValues(t, http.StatusOK, code)
|
assert.EqualValues(t, http.StatusOK, code)
|
||||||
assert.EqualValues(t, 1, ready)
|
assert.EqualValues(t, 1, readyConnections)
|
||||||
|
|
||||||
// Regression test for TUN-3777
|
// Regression test for TUN-3777
|
||||||
rs.OnTunnelEvent(connection.Event{
|
rs.OnTunnelEvent(connection.Event{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
EventType: connection.RegisteringTunnel,
|
EventType: connection.RegisteringTunnel,
|
||||||
})
|
})
|
||||||
code, ready = rs.makeResponse()
|
code, readyConnections = mockRequest(t, rs)
|
||||||
assert.NotEqualValues(t, http.StatusOK, code)
|
assert.NotEqualValues(t, http.StatusOK, code)
|
||||||
assert.Zero(t, ready)
|
assert.Zero(t, readyConnections)
|
||||||
|
|
||||||
// other connected then unregistered => not ok
|
// other connected then unregistered => not ok
|
||||||
rs.OnTunnelEvent(connection.Event{
|
rs.OnTunnelEvent(connection.Event{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
EventType: connection.Connected,
|
EventType: connection.Connected,
|
||||||
})
|
})
|
||||||
code, ready = rs.makeResponse()
|
code, readyConnections = mockRequest(t, rs)
|
||||||
assert.EqualValues(t, http.StatusOK, code)
|
assert.EqualValues(t, http.StatusOK, code)
|
||||||
assert.EqualValues(t, 1, ready)
|
assert.EqualValues(t, 1, readyConnections)
|
||||||
rs.OnTunnelEvent(connection.Event{
|
rs.OnTunnelEvent(connection.Event{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
EventType: connection.Unregistering,
|
EventType: connection.Unregistering,
|
||||||
})
|
})
|
||||||
code, ready = rs.makeResponse()
|
code, readyConnections = mockRequest(t, rs)
|
||||||
assert.NotEqualValues(t, http.StatusOK, code)
|
assert.NotEqualValues(t, http.StatusOK, code)
|
||||||
assert.Zero(t, ready)
|
assert.Zero(t, readyConnections)
|
||||||
|
|
||||||
// other disconnected => not ok
|
// other disconnected => not ok
|
||||||
rs.OnTunnelEvent(connection.Event{
|
rs.OnTunnelEvent(connection.Event{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
EventType: connection.Disconnected,
|
EventType: connection.Disconnected,
|
||||||
})
|
})
|
||||||
code, ready = rs.makeResponse()
|
code, readyConnections = mockRequest(t, rs)
|
||||||
assert.NotEqualValues(t, http.StatusOK, code)
|
assert.NotEqualValues(t, http.StatusOK, code)
|
||||||
assert.Zero(t, ready)
|
assert.Zero(t, readyConnections)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnTracker struct {
|
type ConnTracker struct {
|
||||||
sync.RWMutex
|
mutex sync.RWMutex
|
||||||
// int is the connection Index
|
// int is the connection Index
|
||||||
connectionInfo map[uint8]ConnectionInfo
|
connectionInfo map[uint8]ConnectionInfo
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
|
@ -20,43 +20,39 @@ type ConnectionInfo struct {
|
||||||
Protocol connection.Protocol
|
Protocol connection.Protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnTracker(log *zerolog.Logger) *ConnTracker {
|
func NewConnTracker(
|
||||||
|
log *zerolog.Logger,
|
||||||
|
) *ConnTracker {
|
||||||
return &ConnTracker{
|
return &ConnTracker{
|
||||||
connectionInfo: make(map[uint8]ConnectionInfo, 0),
|
connectionInfo: make(map[uint8]ConnectionInfo, 0),
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func MockedConnTracker(mocked map[uint8]ConnectionInfo) *ConnTracker {
|
|
||||||
return &ConnTracker{
|
|
||||||
connectionInfo: mocked,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ct *ConnTracker) OnTunnelEvent(c connection.Event) {
|
func (ct *ConnTracker) OnTunnelEvent(c connection.Event) {
|
||||||
switch c.EventType {
|
switch c.EventType {
|
||||||
case connection.Connected:
|
case connection.Connected:
|
||||||
ct.Lock()
|
ct.mutex.Lock()
|
||||||
ci := ConnectionInfo{
|
ci := ConnectionInfo{
|
||||||
IsConnected: true,
|
IsConnected: true,
|
||||||
Protocol: c.Protocol,
|
Protocol: c.Protocol,
|
||||||
}
|
}
|
||||||
ct.connectionInfo[c.Index] = ci
|
ct.connectionInfo[c.Index] = ci
|
||||||
ct.Unlock()
|
ct.mutex.Unlock()
|
||||||
case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering:
|
case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering:
|
||||||
ct.Lock()
|
ct.mutex.Lock()
|
||||||
ci := ct.connectionInfo[c.Index]
|
ci := ct.connectionInfo[c.Index]
|
||||||
ci.IsConnected = false
|
ci.IsConnected = false
|
||||||
ct.connectionInfo[c.Index] = ci
|
ct.connectionInfo[c.Index] = ci
|
||||||
ct.Unlock()
|
ct.mutex.Unlock()
|
||||||
default:
|
default:
|
||||||
ct.log.Error().Msgf("Unknown connection event case %v", c)
|
ct.log.Error().Msgf("Unknown connection event case %v", c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ct *ConnTracker) CountActiveConns() uint {
|
func (ct *ConnTracker) CountActiveConns() uint {
|
||||||
ct.RLock()
|
ct.mutex.RLock()
|
||||||
defer ct.RUnlock()
|
defer ct.mutex.RUnlock()
|
||||||
active := uint(0)
|
active := uint(0)
|
||||||
for _, ci := range ct.connectionInfo {
|
for _, ci := range ct.connectionInfo {
|
||||||
if ci.IsConnected {
|
if ci.IsConnected {
|
||||||
|
@ -69,8 +65,8 @@ func (ct *ConnTracker) CountActiveConns() uint {
|
||||||
// HasConnectedWith checks if we've ever had a successful connection to the edge
|
// HasConnectedWith checks if we've ever had a successful connection to the edge
|
||||||
// with said protocol.
|
// with said protocol.
|
||||||
func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool {
|
func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool {
|
||||||
ct.RLock()
|
ct.mutex.RLock()
|
||||||
defer ct.RUnlock()
|
defer ct.mutex.RUnlock()
|
||||||
for _, ci := range ct.connectionInfo {
|
for _, ci := range ct.connectionInfo {
|
||||||
if ci.Protocol == protocol {
|
if ci.Protocol == protocol {
|
||||||
return true
|
return true
|
||||||
|
|
Loading…
Reference in New Issue