Revert "TUN-6617: Dont fallback to http2 if QUIC conn was successful."

This reverts commit 679a89c7df.
This commit is contained in:
Sudarsan Reddy 2022-08-11 20:27:22 +01:00
parent 68d370af19
commit d3fd581b7b
11 changed files with 87 additions and 133 deletions

View File

@ -2,7 +2,6 @@ package connection
import ( import (
"context" "context"
"fmt"
"io" "io"
"net" "net"
"time" "time"
@ -22,7 +21,6 @@ type controlStream struct {
namedTunnelProperties *NamedTunnelProperties namedTunnelProperties *NamedTunnelProperties
connIndex uint8 connIndex uint8
edgeAddress net.IP edgeAddress net.IP
protocol Protocol
newRPCClientFunc RPCClientFunc newRPCClientFunc RPCClientFunc
@ -53,7 +51,6 @@ func NewControlStream(
newRPCClientFunc RPCClientFunc, newRPCClientFunc RPCClientFunc,
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
gracePeriod time.Duration, gracePeriod time.Duration,
protocol Protocol,
) ControlStreamHandler { ) ControlStreamHandler {
if newRPCClientFunc == nil { if newRPCClientFunc == nil {
newRPCClientFunc = newRegistrationRPCClient newRPCClientFunc = newRegistrationRPCClient
@ -67,7 +64,6 @@ func NewControlStream(
edgeAddress: edgeAddress, edgeAddress: edgeAddress,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
gracePeriod: gracePeriod, gracePeriod: gracePeriod,
protocol: protocol,
} }
} }
@ -84,9 +80,6 @@ func (c *controlStream) ServeControlStream(
rpcClient.Close() rpcClient.Close()
return err return err
} }
c.observer.logServerInfo(c.connIndex, registrationDetails.Location, c.edgeAddress, fmt.Sprintf("Connection %s registered", registrationDetails.UUID))
c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location)
c.connectedFuse.Connected() c.connectedFuse.Connected()
// if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration // if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration

View File

@ -5,7 +5,6 @@ type Event struct {
Index uint8 Index uint8
EventType Status EventType Status
Location string Location string
Protocol Protocol
URL string URL string
} }

View File

@ -43,7 +43,6 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
nil, nil,
nil, nil,
1*time.Second, 1*time.Second,
HTTP2,
) )
return NewHTTP2Connection( return NewHTTP2Connection(
cfdConn, cfdConn,
@ -367,7 +366,6 @@ func TestServeControlStream(t *testing.T) {
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
nil, nil,
1*time.Second, 1*time.Second,
HTTP2,
) )
http2Conn.controlStreamHandler = controlStream http2Conn.controlStreamHandler = controlStream
@ -419,7 +417,6 @@ func TestFailRegistration(t *testing.T) {
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
nil, nil,
1*time.Second, 1*time.Second,
HTTP2,
) )
http2Conn.controlStreamHandler = controlStream http2Conn.controlStreamHandler = controlStream
@ -467,7 +464,6 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
shutdownC, shutdownC,
1*time.Second, 1*time.Second,
HTTP2,
) )
http2Conn.controlStreamHandler = controlStream http2Conn.controlStreamHandler = controlStream

View File

@ -55,8 +55,8 @@ func (o *Observer) sendRegisteringEvent(connIndex uint8) {
o.sendEvent(Event{Index: connIndex, EventType: RegisteringTunnel}) o.sendEvent(Event{Index: connIndex, EventType: RegisteringTunnel})
} }
func (o *Observer) sendConnectedEvent(connIndex uint8, protocol Protocol, location string) { func (o *Observer) sendConnectedEvent(connIndex uint8, location string) {
o.sendEvent(Event{Index: connIndex, EventType: Connected, Protocol: protocol, Location: location}) o.sendEvent(Event{Index: connIndex, EventType: Connected, Location: location})
} }
func (o *Observer) SendURL(url string) { func (o *Observer) SendURL(url string) {

View File

@ -81,63 +81,63 @@ func TestQUICServer(t *testing.T) {
}, },
expectedResponse: []byte("OK"), expectedResponse: []byte("OK"),
}, },
{ //{
desc: "test http body request streaming", // desc: "test http body request streaming",
dest: "/slow_echo_body", // dest: "/slow_echo_body",
connectionType: quicpogs.ConnectionTypeHTTP, // connectionType: quicpogs.ConnectionTypeHTTP,
metadata: []quicpogs.Metadata{ // metadata: []quicpogs.Metadata{
{ // {
Key: "HttpHeader:Cf-Ray", // Key: "HttpHeader:Cf-Ray",
Val: "123123123", // Val: "123123123",
}, // },
{ // {
Key: "HttpHost", // Key: "HttpHost",
Val: "cf.host", // Val: "cf.host",
}, // },
{ // {
Key: "HttpMethod", // Key: "HttpMethod",
Val: "POST", // Val: "POST",
}, // },
{ // {
Key: "HttpHeader:Content-Length", // Key: "HttpHeader:Content-Length",
Val: "24", // Val: "24",
}, // },
}, // },
message: []byte("This is the message body"), // message: []byte("This is the message body"),
expectedResponse: []byte("This is the message body"), // expectedResponse: []byte("This is the message body"),
}, //},
{ //{
desc: "test ws proxy", // desc: "test ws proxy",
dest: "/ws/echo", // dest: "/ws/echo",
connectionType: quicpogs.ConnectionTypeWebsocket, // connectionType: quicpogs.ConnectionTypeWebsocket,
metadata: []quicpogs.Metadata{ // metadata: []quicpogs.Metadata{
{ // {
Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade", // Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade",
Val: "Websocket", // Val: "Websocket",
}, // },
{ // {
Key: "HttpHeader:Another-Header", // Key: "HttpHeader:Another-Header",
Val: "Misc", // Val: "Misc",
}, // },
{ // {
Key: "HttpHost", // Key: "HttpHost",
Val: "cf.host", // Val: "cf.host",
}, // },
{ // {
Key: "HttpMethod", // Key: "HttpMethod",
Val: "get", // Val: "get",
}, // },
}, // },
message: wsBuf.Bytes(), // message: wsBuf.Bytes(),
expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, // expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
}, //},
{ //{
desc: "test tcp proxy", // desc: "test tcp proxy",
connectionType: quicpogs.ConnectionTypeTCP, // connectionType: quicpogs.ConnectionTypeTCP,
metadata: []quicpogs.Metadata{}, // metadata: []quicpogs.Metadata{},
message: []byte("Here is some tcp data"), // message: []byte("Here is some tcp data"),
expectedResponse: []byte("Here is some tcp data"), // expectedResponse: []byte("Here is some tcp data"),
}, //},
} }
for _, test := range tests { for _, test := range tests {

View File

@ -2,6 +2,7 @@ package connection
import ( import (
"context" "context"
"fmt"
"io" "io"
"net" "net"
"time" "time"
@ -116,6 +117,9 @@ func (rsc *registrationServerClient) RegisterConnection(
observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
observer.logServerInfo(connIndex, conn.Location, edgeAddress, fmt.Sprintf("Connection %s registered", conn.UUID))
observer.sendConnectedEvent(connIndex, conn.Location)
return conn, nil return conn, nil
} }

View File

@ -14,7 +14,7 @@ import (
func TestReadyServer_makeResponse(t *testing.T) { func TestReadyServer_makeResponse(t *testing.T) {
type fields struct { type fields struct {
isConnected map[uint8]tunnelstate.ConnectionInfo isConnected map[int]bool
} }
tests := []struct { tests := []struct {
name string name string
@ -25,11 +25,11 @@ func TestReadyServer_makeResponse(t *testing.T) {
{ {
name: "One connection online => HTTP 200", name: "One connection online => HTTP 200",
fields: fields{ fields: fields{
isConnected: map[uint8]tunnelstate.ConnectionInfo{ isConnected: map[int]bool{
0: {IsConnected: false}, 0: false,
1: {IsConnected: false}, 1: false,
2: {IsConnected: true}, 2: true,
3: {IsConnected: false}, 3: false,
}, },
}, },
wantOK: true, wantOK: true,
@ -38,11 +38,11 @@ func TestReadyServer_makeResponse(t *testing.T) {
{ {
name: "No connections online => no HTTP 200", name: "No connections online => no HTTP 200",
fields: fields{ fields: fields{
isConnected: map[uint8]tunnelstate.ConnectionInfo{ isConnected: map[int]bool{
0: {IsConnected: false}, 0: false,
1: {IsConnected: false}, 1: false,
2: {IsConnected: false}, 2: false,
3: {IsConnected: false}, 3: false,
}, },
}, },
wantReadyConnections: 0, wantReadyConnections: 0,

View File

@ -12,9 +12,9 @@ type ConnAwareLogger struct {
logger *zerolog.Logger logger *zerolog.Logger
} }
func NewConnAwareLogger(logger *zerolog.Logger, tracker *tunnelstate.ConnTracker, observer *connection.Observer) *ConnAwareLogger { func NewConnAwareLogger(logger *zerolog.Logger, observer *connection.Observer) *ConnAwareLogger {
connAwareLogger := &ConnAwareLogger{ connAwareLogger := &ConnAwareLogger{
tracker: tracker, tracker: tunnelstate.NewConnTracker(logger),
logger: logger, logger: logger,
} }

View File

@ -19,7 +19,6 @@ import (
"github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstate"
) )
const ( const (
@ -89,9 +88,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
} }
reconnectCredentialManager := newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections) reconnectCredentialManager := newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections)
log := NewConnAwareLogger(config.Log, config.Observer)
tracker := tunnelstate.NewConnTracker(config.Log)
log := NewConnAwareLogger(config.Log, tracker, config.Observer)
var edgeAddrHandler EdgeAddrHandler var edgeAddrHandler EdgeAddrHandler
if isStaticEdge { // static edge addresses if isStaticEdge { // static edge addresses
@ -109,7 +106,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
credentialManager: reconnectCredentialManager, credentialManager: reconnectCredentialManager,
edgeAddrs: edgeIPs, edgeAddrs: edgeIPs,
edgeAddrHandler: edgeAddrHandler, edgeAddrHandler: edgeAddrHandler,
tracker: tracker,
reconnectCh: reconnectCh, reconnectCh: reconnectCh,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
connAwareLogger: log, connAwareLogger: log,

View File

@ -26,7 +26,6 @@ import (
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstate"
) )
const ( const (
@ -191,7 +190,6 @@ type EdgeTunnelServer struct {
edgeAddrs *edgediscovery.Edge edgeAddrs *edgediscovery.Edge
reconnectCh chan ReconnectSignal reconnectCh chan ReconnectSignal
gracefulShutdownC <-chan struct{} gracefulShutdownC <-chan struct{}
tracker *tunnelstate.ConnTracker
connAwareLogger *ConnAwareLogger connAwareLogger *ConnAwareLogger
} }
@ -274,12 +272,6 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFa
return err return err
} }
// If a single connection has connected with the current protocol, we know we know we don't have to fallback
// to a different protocol.
if e.tracker.HasConnectedWith(e.config.ProtocolSelector.Current()) {
return err
}
if !selectNextProtocol( if !selectNextProtocol(
connLog.Logger(), connLog.Logger(),
protocolFallback, protocolFallback,
@ -469,7 +461,6 @@ func serveTunnel(
nil, nil,
gracefulShutdownC, gracefulShutdownC,
config.GracePeriod, config.GracePeriod,
protocol,
) )
switch protocol { switch protocol {

View File

@ -10,26 +10,20 @@ import (
type ConnTracker struct { type ConnTracker struct {
sync.RWMutex sync.RWMutex
// int is the connection Index isConnected map[int]bool
connectionInfo map[uint8]ConnectionInfo log *zerolog.Logger
log *zerolog.Logger
}
type ConnectionInfo struct {
IsConnected bool
Protocol connection.Protocol
} }
func NewConnTracker(log *zerolog.Logger) *ConnTracker { func NewConnTracker(log *zerolog.Logger) *ConnTracker {
return &ConnTracker{ return &ConnTracker{
connectionInfo: make(map[uint8]ConnectionInfo, 0), isConnected: make(map[int]bool, 0),
log: log, log: log,
} }
} }
func MockedConnTracker(mocked map[uint8]ConnectionInfo) *ConnTracker { func MockedConnTracker(mocked map[int]bool) *ConnTracker {
return &ConnTracker{ return &ConnTracker{
connectionInfo: mocked, isConnected: mocked,
} }
} }
@ -37,17 +31,11 @@ func (ct *ConnTracker) OnTunnelEvent(c connection.Event) {
switch c.EventType { switch c.EventType {
case connection.Connected: case connection.Connected:
ct.Lock() ct.Lock()
ci := ConnectionInfo{ ct.isConnected[int(c.Index)] = true
IsConnected: true,
Protocol: c.Protocol,
}
ct.connectionInfo[c.Index] = ci
ct.Unlock() ct.Unlock()
case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering: case connection.Disconnected, connection.Reconnecting, connection.RegisteringTunnel, connection.Unregistering:
ct.Lock() ct.Lock()
ci := ct.connectionInfo[c.Index] ct.isConnected[int(c.Index)] = false
ci.IsConnected = false
ct.connectionInfo[c.Index] = ci
ct.Unlock() ct.Unlock()
default: default:
ct.log.Error().Msgf("Unknown connection event case %v", c) ct.log.Error().Msgf("Unknown connection event case %v", c)
@ -58,23 +46,10 @@ func (ct *ConnTracker) CountActiveConns() uint {
ct.RLock() ct.RLock()
defer ct.RUnlock() defer ct.RUnlock()
active := uint(0) active := uint(0)
for _, ci := range ct.connectionInfo { for _, connected := range ct.isConnected {
if ci.IsConnected { if connected {
active++ active++
} }
} }
return active return active
} }
// HasConnectedWith checks if we've ever had a successful connection to the edge
// with said protocol.
func (ct *ConnTracker) HasConnectedWith(protocol connection.Protocol) bool {
ct.RLock()
defer ct.RUnlock()
for _, ci := range ct.connectionInfo {
if ci.Protocol == protocol {
return true
}
}
return false
}