From dd540af695448ea1bfef28762060fb98455d455d Mon Sep 17 00:00:00 2001 From: Devin Carr Date: Fri, 17 Jun 2022 17:24:37 -0700 Subject: [PATCH] TUN-6388: Fix first tunnel connection not retrying --- connection/control.go | 6 ++- connection/http2_test.go | 5 +++ connection/rpc.go | 6 ++- edgediscovery/edgediscovery.go | 15 ++++--- supervisor/supervisor.go | 73 +++++++++++++++++++++++++++------- supervisor/tunnel.go | 14 +++---- 6 files changed, 89 insertions(+), 30 deletions(-) diff --git a/connection/control.go b/connection/control.go index 1b28ceb1..a7d49772 100644 --- a/connection/control.go +++ b/connection/control.go @@ -3,6 +3,7 @@ package connection import ( "context" "io" + "net" "time" "github.com/rs/zerolog" @@ -19,6 +20,7 @@ type controlStream struct { connectedFuse ConnectedFuse namedTunnelProperties *NamedTunnelProperties connIndex uint8 + edgeAddress net.IP newRPCClientFunc RPCClientFunc @@ -45,6 +47,7 @@ func NewControlStream( connectedFuse ConnectedFuse, namedTunnelConfig *NamedTunnelProperties, connIndex uint8, + edgeAddress net.IP, newRPCClientFunc RPCClientFunc, gracefulShutdownC <-chan struct{}, gracePeriod time.Duration, @@ -58,6 +61,7 @@ func NewControlStream( namedTunnelProperties: namedTunnelConfig, newRPCClientFunc: newRPCClientFunc, connIndex: connIndex, + edgeAddress: edgeAddress, gracefulShutdownC: gracefulShutdownC, gracePeriod: gracePeriod, } @@ -71,7 +75,7 @@ func (c *controlStream) ServeControlStream( ) error { rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) - registrationDetails, err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.observer) + registrationDetails, err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.edgeAddress, c.observer) if err != nil { rpcClient.Close() return err diff --git a/connection/http2_test.go b/connection/http2_test.go index 18e688eb..82368d67 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -41,6 +41,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { connIndex, nil, nil, + nil, 1*time.Second, ) return NewHTTP2Connection( @@ -176,6 +177,7 @@ func (mc mockNamedTunnelRPCClient) RegisterConnection( properties *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, + edgeAddress net.IP, observer *Observer, ) (*tunnelpogs.ConnectionDetails, error) { if mc.shouldFail != nil { @@ -360,6 +362,7 @@ func TestServeControlStream(t *testing.T) { mockConnectedFuse{}, &NamedTunnelProperties{}, 1, + nil, rpcClientFactory.newMockRPCClient, nil, 1*time.Second, @@ -410,6 +413,7 @@ func TestFailRegistration(t *testing.T) { mockConnectedFuse{}, &NamedTunnelProperties{}, http2Conn.connIndex, + nil, rpcClientFactory.newMockRPCClient, nil, 1*time.Second, @@ -456,6 +460,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) { mockConnectedFuse{}, &NamedTunnelProperties{}, http2Conn.connIndex, + nil, rpcClientFactory.newMockRPCClient, shutdownC, 1*time.Second, diff --git a/connection/rpc.go b/connection/rpc.go index a4444ea3..b288a0f8 100644 --- a/connection/rpc.go +++ b/connection/rpc.go @@ -58,6 +58,7 @@ type NamedTunnelRPCClient interface { config *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, + edgeAddress net.IP, observer *Observer, ) (*tunnelpogs.ConnectionDetails, error) SendLocalConfiguration( @@ -95,6 +96,7 @@ func (rsc *registrationServerClient) RegisterConnection( properties *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, + edgeAddress net.IP, observer *Observer, ) (*tunnelpogs.ConnectionDetails, error) { conn, err := rsc.client.RegisterConnection( @@ -115,7 +117,7 @@ func (rsc *registrationServerClient) RegisterConnection( observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() - observer.logServerInfo(connIndex, conn.Location, options.OriginLocalIP, fmt.Sprintf("Connection %s registered", conn.UUID)) + observer.logServerInfo(connIndex, conn.Location, edgeAddress, fmt.Sprintf("Connection %s registered", conn.UUID)) observer.sendConnectedEvent(connIndex, conn.Location) return conn, nil @@ -291,7 +293,7 @@ func (h *h2muxConnection) registerNamedTunnel( rpcClient := h.newRPCClientFunc(ctx, stream, h.observer.log) defer rpcClient.Close() - if _, err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil { + if _, err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, nil, h.observer); err != nil { return err } return nil diff --git a/edgediscovery/edgediscovery.go b/edgediscovery/edgediscovery.go index b5f338e8..88788b10 100644 --- a/edgediscovery/edgediscovery.go +++ b/edgediscovery/edgediscovery.go @@ -1,7 +1,6 @@ package edgediscovery import ( - "fmt" "sync" "github.com/rs/zerolog" @@ -14,7 +13,13 @@ const ( LogFieldIPAddress = "ip" ) -var ErrNoAddressesLeft = fmt.Errorf("there are no free edge addresses left") +var errNoAddressesLeft = ErrNoAddressesLeft{} + +type ErrNoAddressesLeft struct{} + +func (e ErrNoAddressesLeft) Error() string { + return "there are no free edge addresses left to resolve to" +} // Edge finds addresses on the Cloudflare edge and hands them out to connections. type Edge struct { @@ -62,7 +67,7 @@ func (ed *Edge) GetAddrForRPC() (*allregions.EdgeAddr, error) { defer ed.Unlock() addr := ed.regions.GetAnyAddress() if addr == nil { - return nil, ErrNoAddressesLeft + return nil, errNoAddressesLeft } return addr, nil } @@ -83,7 +88,7 @@ func (ed *Edge) GetAddr(connIndex int) (*allregions.EdgeAddr, error) { addr := ed.regions.GetUnusedAddr(nil, connIndex) if addr == nil { log.Debug().Msg("edgediscovery - GetAddr: No addresses left to give proxy connection") - return nil, ErrNoAddressesLeft + return nil, errNoAddressesLeft } log = ed.log.With(). Int(LogFieldConnIndex, connIndex). @@ -107,7 +112,7 @@ func (ed *Edge) GetDifferentAddr(connIndex int, hasConnectivityError bool) (*all if addr == nil { log.Debug().Msg("edgediscovery - GetDifferentAddr: No addresses left to give proxy connection") // note: if oldAddr were not nil, it will become available on the next iteration - return nil, ErrNoAddressesLeft + return nil, errNoAddressesLeft } log = ed.log.With(). Int(LogFieldConnIndex, connIndex). diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index 4bec3f56..c3f08844 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -4,9 +4,11 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/google/uuid" + "github.com/lucas-clemente/quic-go" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/connection" @@ -37,13 +39,14 @@ const ( // Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and // reconnects them if they disconnect. type Supervisor struct { - cloudflaredUUID uuid.UUID - config *TunnelConfig - orchestrator *orchestration.Orchestrator - edgeIPs *edgediscovery.Edge - edgeTunnelServer EdgeTunnelServer - tunnelErrors chan tunnelError - tunnelsConnecting map[int]chan struct{} + cloudflaredUUID uuid.UUID + config *TunnelConfig + orchestrator *orchestration.Orchestrator + edgeIPs *edgediscovery.Edge + edgeTunnelServer EdgeTunnelServer + tunnelErrors chan tunnelError + tunnelsConnecting map[int]chan struct{} + tunnelsProtocolFallback map[int]*protocolFallback // nextConnectedIndex and nextConnectedSignal are used to wait for all // currently-connecting tunnels to finish connecting so we can reset backoff timer nextConnectedIndex int @@ -72,8 +75,10 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) } + isStaticEdge := len(config.EdgeAddrs) > 0 + var edgeIPs *edgediscovery.Edge - if len(config.EdgeAddrs) > 0 { + if isStaticEdge { // static edge addresses edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs) } else { edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion) @@ -86,7 +91,9 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato log := NewConnAwareLogger(config.Log, config.Observer) var edgeAddrHandler EdgeAddrHandler - if config.EdgeIPVersion == allregions.IPv6Only || config.EdgeIPVersion == allregions.Auto { + if isStaticEdge { // static edge addresses + edgeAddrHandler = &IPAddrFallback{} + } else if config.EdgeIPVersion == allregions.IPv6Only || config.EdgeIPVersion == allregions.Auto { edgeAddrHandler = &IPAddrFallback{} } else { // IPv4Only edgeAddrHandler = &DefaultAddrFallback{} @@ -117,6 +124,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato edgeTunnelServer: edgeTunnelServer, tunnelErrors: make(chan tunnelError), tunnelsConnecting: map[int]chan struct{}{}, + tunnelsProtocolFallback: map[int]*protocolFallback{}, log: log, logTransport: config.LogTransport, reconnectCredentialManager: reconnectCredentialManager, @@ -178,6 +186,10 @@ func (s *Supervisor) Run( tunnelsActive++ continue } + // Make sure we don't continue if there is no more fallback allowed + if _, retry := s.tunnelsProtocolFallback[tunnelError.index].GetMaxBackoffDuration(ctx); !retry { + continue + } s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated") tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) s.waitForNextTunnel(tunnelError.index) @@ -232,6 +244,11 @@ func (s *Supervisor) initialize( s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs) s.config.HAConnections = availableAddrs } + s.tunnelsProtocolFallback[0] = &protocolFallback{ + retry.BackoffHandler{MaxRetries: s.config.Retries}, + s.config.ProtocolSelector.Current(), + false, + } go s.startFirstTunnel(ctx, connectedSignal) @@ -249,6 +266,11 @@ func (s *Supervisor) initialize( // At least one successful connection, so start the rest for i := 1; i < s.config.HAConnections; i++ { + s.tunnelsProtocolFallback[i] = &protocolFallback{ + retry.BackoffHandler{MaxRetries: s.config.Retries}, + s.config.ProtocolSelector.Current(), + false, + } ch := signal.New(make(chan struct{})) go s.startTunnel(ctx, i, ch) time.Sleep(registrationInterval) @@ -266,21 +288,44 @@ func (s *Supervisor) startFirstTunnel( err error ) const firstConnIndex = 0 + isStaticEdge := len(s.config.EdgeAddrs) > 0 defer func() { s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err} }() - err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, connectedSignal) - // If the first tunnel disconnects, keep restarting it. - for s.unusedIPs() { + for { + err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, s.tunnelsProtocolFallback[firstConnIndex], connectedSignal) if ctx.Err() != nil { return } if err == nil { return } - err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, connectedSignal) + // Make sure we don't continue if there is no more fallback allowed + if _, retry := s.tunnelsProtocolFallback[firstConnIndex].GetMaxBackoffDuration(ctx); !retry { + return + } + // Try again for Unauthorized errors because we hope them to be + // transient due to edge propagation lag on new Tunnels. + if strings.Contains(err.Error(), "Unauthorized") { + continue + } + switch err.(type) { + case edgediscovery.ErrNoAddressesLeft: + // If your provided addresses are not available, we will keep trying regardless. + if !isStaticEdge { + return + } + case connection.DupConnRegisterTunnelError, + *quic.IdleTimeoutError, + edgediscovery.DialError, + *connection.EdgeQuicDialError: + // Try again for these types of errors + default: + // Uncaught errors should bail startup + return + } } } @@ -298,7 +343,7 @@ func (s *Supervisor) startTunnel( s.tunnelErrors <- tunnelError{index: index, err: err} }() - err = s.edgeTunnelServer.Serve(ctx, uint8(index), connectedSignal) + err = s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal) } func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index fcbace54..a92113e9 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -194,15 +194,10 @@ type EdgeTunnelServer struct { connAwareLogger *ConnAwareLogger } -func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, connectedSignal *signal.Signal) error { +func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error { haConnections.Inc() defer haConnections.Dec() - protocolFallback := &protocolFallback{ - retry.BackoffHandler{MaxRetries: e.config.Retries}, - e.config.ProtocolSelector.Current(), - false, - } connectedFuse := h2mux.NewBooleanFuse() go func() { if connectedFuse.Await() { @@ -214,7 +209,7 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, connectedS // Fetch IP address to associated connection index addr, err := e.edgeAddrs.GetAddr(int(connIndex)) - switch err { + switch err.(type) { case nil: // no error case edgediscovery.ErrNoAddressesLeft: return err @@ -262,7 +257,9 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, connectedS // establishing a connection to the edge and if so, rotate the IP address. yes, hasConnectivityError := e.edgeAddrHandler.ShouldGetNewAddress(err) if yes { - e.edgeAddrs.GetDifferentAddr(int(connIndex), hasConnectivityError) + if _, err := e.edgeAddrs.GetDifferentAddr(int(connIndex), hasConnectivityError); err != nil { + return err + } } select { @@ -461,6 +458,7 @@ func serveTunnel( connectedFuse, config.NamedTunnel, connIndex, + addr.UDP.IP, nil, gracefulShutdownC, config.GracePeriod,