diff --git a/connection/control.go b/connection/control.go index a7d49772..a7fe1ac9 100644 --- a/connection/control.go +++ b/connection/control.go @@ -2,6 +2,7 @@ package connection import ( "context" + "fmt" "io" "net" "time" @@ -21,6 +22,7 @@ type controlStream struct { namedTunnelProperties *NamedTunnelProperties connIndex uint8 edgeAddress net.IP + protocol Protocol newRPCClientFunc RPCClientFunc @@ -51,6 +53,7 @@ func NewControlStream( newRPCClientFunc RPCClientFunc, gracefulShutdownC <-chan struct{}, gracePeriod time.Duration, + protocol Protocol, ) ControlStreamHandler { if newRPCClientFunc == nil { newRPCClientFunc = newRegistrationRPCClient @@ -64,6 +67,7 @@ func NewControlStream( edgeAddress: edgeAddress, gracefulShutdownC: gracefulShutdownC, gracePeriod: gracePeriod, + protocol: protocol, } } @@ -80,6 +84,9 @@ func (c *controlStream) ServeControlStream( rpcClient.Close() return err } + + c.observer.logServerInfo(c.connIndex, registrationDetails.Location, c.edgeAddress, fmt.Sprintf("Connection %s registered", registrationDetails.UUID)) + c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location) c.connectedFuse.Connected() // if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration diff --git a/connection/event.go b/connection/event.go index ab6d0d33..d10b92fc 100644 --- a/connection/event.go +++ b/connection/event.go @@ -5,6 +5,7 @@ type Event struct { Index uint8 EventType Status Location string + Protocol Protocol URL string } diff --git a/connection/http2_test.go b/connection/http2_test.go index e962353f..b7f1c49c 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -43,6 +43,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { nil, nil, 1*time.Second, + HTTP2, ) return NewHTTP2Connection( cfdConn, @@ -366,6 +367,7 @@ func TestServeControlStream(t *testing.T) { rpcClientFactory.newMockRPCClient, nil, 1*time.Second, + HTTP2, ) http2Conn.controlStreamHandler = controlStream @@ -417,6 +419,7 @@ func TestFailRegistration(t *testing.T) { rpcClientFactory.newMockRPCClient, nil, 1*time.Second, + HTTP2, ) http2Conn.controlStreamHandler = controlStream @@ -464,6 +467,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) { rpcClientFactory.newMockRPCClient, shutdownC, 1*time.Second, + HTTP2, ) http2Conn.controlStreamHandler = controlStream diff --git a/connection/observer.go b/connection/observer.go index 3a855878..7429b32c 100644 --- a/connection/observer.go +++ b/connection/observer.go @@ -55,8 +55,8 @@ func (o *Observer) sendRegisteringEvent(connIndex uint8) { o.sendEvent(Event{Index: connIndex, EventType: RegisteringTunnel}) } -func (o *Observer) sendConnectedEvent(connIndex uint8, location string) { - o.sendEvent(Event{Index: connIndex, EventType: Connected, Location: location}) +func (o *Observer) sendConnectedEvent(connIndex uint8, protocol Protocol, location string) { + o.sendEvent(Event{Index: connIndex, EventType: Connected, Protocol: protocol, Location: location}) } func (o *Observer) SendURL(url string) { diff --git a/connection/quic_test.go b/connection/quic_test.go index c35c7d51..d82947c2 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -81,63 +81,63 @@ func TestQUICServer(t *testing.T) { }, expectedResponse: []byte("OK"), }, - //{ - // desc: "test http body request streaming", - // dest: "/slow_echo_body", - // connectionType: quicpogs.ConnectionTypeHTTP, - // metadata: []quicpogs.Metadata{ - // { - // Key: "HttpHeader:Cf-Ray", - // Val: "123123123", - // }, - // { - // Key: "HttpHost", - // Val: "cf.host", - // }, - // { - // Key: "HttpMethod", - // Val: "POST", - // }, - // { - // Key: "HttpHeader:Content-Length", - // Val: "24", - // }, - // }, - // message: []byte("This is the message body"), - // expectedResponse: []byte("This is the message body"), - //}, - //{ - // desc: "test ws proxy", - // dest: "/ws/echo", - // connectionType: quicpogs.ConnectionTypeWebsocket, - // metadata: []quicpogs.Metadata{ - // { - // Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade", - // Val: "Websocket", - // }, - // { - // Key: "HttpHeader:Another-Header", - // Val: "Misc", - // }, - // { - // Key: "HttpHost", - // Val: "cf.host", - // }, - // { - // Key: "HttpMethod", - // Val: "get", - // }, - // }, - // message: wsBuf.Bytes(), - // expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, - //}, - //{ - // desc: "test tcp proxy", - // connectionType: quicpogs.ConnectionTypeTCP, - // metadata: []quicpogs.Metadata{}, - // message: []byte("Here is some tcp data"), - // expectedResponse: []byte("Here is some tcp data"), - //}, + { + desc: "test http body request streaming", + dest: "/slow_echo_body", + connectionType: quicpogs.ConnectionTypeHTTP, + metadata: []quicpogs.Metadata{ + { + Key: "HttpHeader:Cf-Ray", + Val: "123123123", + }, + { + Key: "HttpHost", + Val: "cf.host", + }, + { + Key: "HttpMethod", + Val: "POST", + }, + { + Key: "HttpHeader:Content-Length", + Val: "24", + }, + }, + message: []byte("This is the message body"), + expectedResponse: []byte("This is the message body"), + }, + { + desc: "test ws proxy", + dest: "/ws/echo", + connectionType: quicpogs.ConnectionTypeWebsocket, + metadata: []quicpogs.Metadata{ + { + Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade", + Val: "Websocket", + }, + { + Key: "HttpHeader:Another-Header", + Val: "Misc", + }, + { + Key: "HttpHost", + Val: "cf.host", + }, + { + Key: "HttpMethod", + Val: "get", + }, + }, + message: wsBuf.Bytes(), + expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, + }, + { + desc: "test tcp proxy", + connectionType: quicpogs.ConnectionTypeTCP, + metadata: []quicpogs.Metadata{}, + message: []byte("Here is some tcp data"), + expectedResponse: []byte("Here is some tcp data"), + }, } for _, test := range tests { diff --git a/connection/rpc.go b/connection/rpc.go index b288a0f8..f602290b 100644 --- a/connection/rpc.go +++ b/connection/rpc.go @@ -2,7 +2,6 @@ package connection import ( "context" - "fmt" "io" "net" "time" @@ -117,9 +116,6 @@ func (rsc *registrationServerClient) RegisterConnection( observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() - observer.logServerInfo(connIndex, conn.Location, edgeAddress, fmt.Sprintf("Connection %s registered", conn.UUID)) - observer.sendConnectedEvent(connIndex, conn.Location) - return conn, nil } diff --git a/metrics/readiness_test.go b/metrics/readiness_test.go index 62e42f63..8e035f85 100644 --- a/metrics/readiness_test.go +++ b/metrics/readiness_test.go @@ -14,7 +14,7 @@ import ( func TestReadyServer_makeResponse(t *testing.T) { type fields struct { - isConnected map[int]bool + isConnected map[uint8]tunnelstate.ConnectionInfo } tests := []struct { name string @@ -25,11 +25,11 @@ func TestReadyServer_makeResponse(t *testing.T) { { name: "One connection online => HTTP 200", fields: fields{ - isConnected: map[int]bool{ - 0: false, - 1: false, - 2: true, - 3: false, + isConnected: map[uint8]tunnelstate.ConnectionInfo{ + 0: {IsConnected: false}, + 1: {IsConnected: false}, + 2: {IsConnected: true}, + 3: {IsConnected: false}, }, }, wantOK: true, @@ -38,11 +38,11 @@ func TestReadyServer_makeResponse(t *testing.T) { { name: "No connections online => no HTTP 200", fields: fields{ - isConnected: map[int]bool{ - 0: false, - 1: false, - 2: false, - 3: false, + isConnected: map[uint8]tunnelstate.ConnectionInfo{ + 0: {IsConnected: false}, + 1: {IsConnected: false}, + 2: {IsConnected: false}, + 3: {IsConnected: false}, }, }, wantReadyConnections: 0, diff --git a/supervisor/conn_aware_logger.go b/supervisor/conn_aware_logger.go index 6e717588..1311f20b 100644 --- a/supervisor/conn_aware_logger.go +++ b/supervisor/conn_aware_logger.go @@ -12,9 +12,9 @@ type ConnAwareLogger struct { logger *zerolog.Logger } -func NewConnAwareLogger(logger *zerolog.Logger, observer *connection.Observer) *ConnAwareLogger { +func NewConnAwareLogger(logger *zerolog.Logger, tracker *tunnelstate.ConnTracker, observer *connection.Observer) *ConnAwareLogger { connAwareLogger := &ConnAwareLogger{ - tracker: tunnelstate.NewConnTracker(logger), + tracker: tracker, logger: logger, } diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index c3f08844..7a100a59 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -19,6 +19,7 @@ import ( "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/tunnelstate" ) const ( @@ -88,7 +89,9 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato } reconnectCredentialManager := newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections) - log := NewConnAwareLogger(config.Log, config.Observer) + + tracker := tunnelstate.NewConnTracker(config.Log) + log := NewConnAwareLogger(config.Log, tracker, config.Observer) var edgeAddrHandler EdgeAddrHandler if isStaticEdge { // static edge addresses @@ -106,6 +109,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato credentialManager: reconnectCredentialManager, edgeAddrs: edgeIPs, edgeAddrHandler: edgeAddrHandler, + tracker: tracker, reconnectCh: reconnectCh, gracefulShutdownC: gracefulShutdownC, connAwareLogger: log, diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 2abf09f5..8ff22ca1 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -26,6 +26,7 @@ import ( "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/tunnelstate" ) const ( @@ -190,6 +191,7 @@ type EdgeTunnelServer struct { edgeAddrs *edgediscovery.Edge reconnectCh chan ReconnectSignal gracefulShutdownC <-chan struct{} + tracker *tunnelstate.ConnTracker connAwareLogger *ConnAwareLogger } @@ -272,6 +274,12 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFa return err } + // If a single connection has connected with the current protocol, we know we know we don't have to fallback + // to a different protocol. + if e.tracker.HasConnectedWith(e.config.ProtocolSelector.Current()) { + return err + } + if !selectNextProtocol( connLog.Logger(), protocolFallback, @@ -461,6 +469,7 @@ func serveTunnel( nil, gracefulShutdownC, config.GracePeriod, + protocol, ) switch protocol { diff --git a/tunnelstate/conntracker.go b/tunnelstate/conntracker.go index 06fb176f..426ba483 100644 --- a/tunnelstate/conntracker.go +++ b/tunnelstate/conntracker.go @@ -10,20 +10,26 @@ import ( type ConnTracker struct { sync.RWMutex - isConnected map[int]bool - log *zerolog.Logger + // int is the connection Index + connectionInfo map[uint8]ConnectionInfo + log *zerolog.Logger +} + +type ConnectionInfo struct { + IsConnected bool + Protocol connection.Protocol } func NewConnTracker(log *zerolog.Logger) *ConnTracker { return &ConnTracker{ - isConnected: make(map[int]bool, 0), - log: log, + connectionInfo: make(map[uint8]ConnectionInfo, 0), + log: log, } } -func MockedConnTracker(mocked map[int]bool) *ConnTracker { +func MockedConnTracker(mocked map[uint8]ConnectionInfo) *ConnTracker { return &ConnTracker{ - isConnected: mocked, + connectionInfo: mocked, } } @@ -31,11 +37,17 @@ func (ct *ConnTracker) OnTunnelEvent(c connection.Event) { switch c.EventType { case connection.Connected: ct.Lock() - ct.isConnected[int(c.Index)] = true + ci := ConnectionInfo{ + IsConnected: true, + Protocol: c.Protocol, + } + ct.connectionInfo[c.Index] = ci ct.Unlock() case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering: ct.Lock() - ct.isConnected[int(c.Index)] = false + ci := ct.connectionInfo[c.Index] + ci.IsConnected = false + ct.connectionInfo[c.Index] = ci ct.Unlock() default: ct.log.Error().Msgf("Unknown connection event case %v", c) @@ -46,10 +58,23 @@ func (ct *ConnTracker) CountActiveConns() uint { ct.RLock() defer ct.RUnlock() active := uint(0) - for _, connected := range ct.isConnected { - if connected { + for _, ci := range ct.connectionInfo { + if ci.IsConnected { active++ } } return active } + +// HasConnectedWith checks if we've ever had a successful connection to the edge +// with said protocol. +func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool { + ct.RLock() + defer ct.RUnlock() + for _, ci := range ct.connectionInfo { + if ci.Protocol == protocol { + return true + } + } + return false +}