TUN-6388: Fix first tunnel connection not retrying

This commit is contained in:
Devin Carr 2022-06-17 17:24:37 -07:00
parent e921ab35d5
commit dd540af695
6 changed files with 89 additions and 30 deletions

View File

@ -3,6 +3,7 @@ package connection
import ( import (
"context" "context"
"io" "io"
"net"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -19,6 +20,7 @@ type controlStream struct {
connectedFuse ConnectedFuse connectedFuse ConnectedFuse
namedTunnelProperties *NamedTunnelProperties namedTunnelProperties *NamedTunnelProperties
connIndex uint8 connIndex uint8
edgeAddress net.IP
newRPCClientFunc RPCClientFunc newRPCClientFunc RPCClientFunc
@ -45,6 +47,7 @@ func NewControlStream(
connectedFuse ConnectedFuse, connectedFuse ConnectedFuse,
namedTunnelConfig *NamedTunnelProperties, namedTunnelConfig *NamedTunnelProperties,
connIndex uint8, connIndex uint8,
edgeAddress net.IP,
newRPCClientFunc RPCClientFunc, newRPCClientFunc RPCClientFunc,
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
gracePeriod time.Duration, gracePeriod time.Duration,
@ -58,6 +61,7 @@ func NewControlStream(
namedTunnelProperties: namedTunnelConfig, namedTunnelProperties: namedTunnelConfig,
newRPCClientFunc: newRPCClientFunc, newRPCClientFunc: newRPCClientFunc,
connIndex: connIndex, connIndex: connIndex,
edgeAddress: edgeAddress,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
gracePeriod: gracePeriod, gracePeriod: gracePeriod,
} }
@ -71,7 +75,7 @@ func (c *controlStream) ServeControlStream(
) error { ) error {
rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log)
registrationDetails, err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.observer) registrationDetails, err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.edgeAddress, c.observer)
if err != nil { if err != nil {
rpcClient.Close() rpcClient.Close()
return err return err

View File

@ -41,6 +41,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
connIndex, connIndex,
nil, nil,
nil, nil,
nil,
1*time.Second, 1*time.Second,
) )
return NewHTTP2Connection( return NewHTTP2Connection(
@ -176,6 +177,7 @@ func (mc mockNamedTunnelRPCClient) RegisterConnection(
properties *NamedTunnelProperties, properties *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions, options *tunnelpogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
edgeAddress net.IP,
observer *Observer, observer *Observer,
) (*tunnelpogs.ConnectionDetails, error) { ) (*tunnelpogs.ConnectionDetails, error) {
if mc.shouldFail != nil { if mc.shouldFail != nil {
@ -360,6 +362,7 @@ func TestServeControlStream(t *testing.T) {
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelProperties{}, &NamedTunnelProperties{},
1, 1,
nil,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
nil, nil,
1*time.Second, 1*time.Second,
@ -410,6 +413,7 @@ func TestFailRegistration(t *testing.T) {
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelProperties{}, &NamedTunnelProperties{},
http2Conn.connIndex, http2Conn.connIndex,
nil,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
nil, nil,
1*time.Second, 1*time.Second,
@ -456,6 +460,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelProperties{}, &NamedTunnelProperties{},
http2Conn.connIndex, http2Conn.connIndex,
nil,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
shutdownC, shutdownC,
1*time.Second, 1*time.Second,

View File

@ -58,6 +58,7 @@ type NamedTunnelRPCClient interface {
config *NamedTunnelProperties, config *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions, options *tunnelpogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
edgeAddress net.IP,
observer *Observer, observer *Observer,
) (*tunnelpogs.ConnectionDetails, error) ) (*tunnelpogs.ConnectionDetails, error)
SendLocalConfiguration( SendLocalConfiguration(
@ -95,6 +96,7 @@ func (rsc *registrationServerClient) RegisterConnection(
properties *NamedTunnelProperties, properties *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions, options *tunnelpogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
edgeAddress net.IP,
observer *Observer, observer *Observer,
) (*tunnelpogs.ConnectionDetails, error) { ) (*tunnelpogs.ConnectionDetails, error) {
conn, err := rsc.client.RegisterConnection( conn, err := rsc.client.RegisterConnection(
@ -115,7 +117,7 @@ func (rsc *registrationServerClient) RegisterConnection(
observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
observer.logServerInfo(connIndex, conn.Location, options.OriginLocalIP, fmt.Sprintf("Connection %s registered", conn.UUID)) observer.logServerInfo(connIndex, conn.Location, edgeAddress, fmt.Sprintf("Connection %s registered", conn.UUID))
observer.sendConnectedEvent(connIndex, conn.Location) observer.sendConnectedEvent(connIndex, conn.Location)
return conn, nil return conn, nil
@ -291,7 +293,7 @@ func (h *h2muxConnection) registerNamedTunnel(
rpcClient := h.newRPCClientFunc(ctx, stream, h.observer.log) rpcClient := h.newRPCClientFunc(ctx, stream, h.observer.log)
defer rpcClient.Close() defer rpcClient.Close()
if _, err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil { if _, err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, nil, h.observer); err != nil {
return err return err
} }
return nil return nil

View File

@ -1,7 +1,6 @@
package edgediscovery package edgediscovery
import ( import (
"fmt"
"sync" "sync"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -14,7 +13,13 @@ const (
LogFieldIPAddress = "ip" LogFieldIPAddress = "ip"
) )
var ErrNoAddressesLeft = fmt.Errorf("there are no free edge addresses left") var errNoAddressesLeft = ErrNoAddressesLeft{}
type ErrNoAddressesLeft struct{}
func (e ErrNoAddressesLeft) Error() string {
return "there are no free edge addresses left to resolve to"
}
// Edge finds addresses on the Cloudflare edge and hands them out to connections. // Edge finds addresses on the Cloudflare edge and hands them out to connections.
type Edge struct { type Edge struct {
@ -62,7 +67,7 @@ func (ed *Edge) GetAddrForRPC() (*allregions.EdgeAddr, error) {
defer ed.Unlock() defer ed.Unlock()
addr := ed.regions.GetAnyAddress() addr := ed.regions.GetAnyAddress()
if addr == nil { if addr == nil {
return nil, ErrNoAddressesLeft return nil, errNoAddressesLeft
} }
return addr, nil return addr, nil
} }
@ -83,7 +88,7 @@ func (ed *Edge) GetAddr(connIndex int) (*allregions.EdgeAddr, error) {
addr := ed.regions.GetUnusedAddr(nil, connIndex) addr := ed.regions.GetUnusedAddr(nil, connIndex)
if addr == nil { if addr == nil {
log.Debug().Msg("edgediscovery - GetAddr: No addresses left to give proxy connection") log.Debug().Msg("edgediscovery - GetAddr: No addresses left to give proxy connection")
return nil, ErrNoAddressesLeft return nil, errNoAddressesLeft
} }
log = ed.log.With(). log = ed.log.With().
Int(LogFieldConnIndex, connIndex). Int(LogFieldConnIndex, connIndex).
@ -107,7 +112,7 @@ func (ed *Edge) GetDifferentAddr(connIndex int, hasConnectivityError bool) (*all
if addr == nil { if addr == nil {
log.Debug().Msg("edgediscovery - GetDifferentAddr: No addresses left to give proxy connection") log.Debug().Msg("edgediscovery - GetDifferentAddr: No addresses left to give proxy connection")
// note: if oldAddr were not nil, it will become available on the next iteration // note: if oldAddr were not nil, it will become available on the next iteration
return nil, ErrNoAddressesLeft return nil, errNoAddressesLeft
} }
log = ed.log.With(). log = ed.log.With().
Int(LogFieldConnIndex, connIndex). Int(LogFieldConnIndex, connIndex).

View File

@ -4,9 +4,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
@ -37,13 +39,14 @@ const (
// Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and // Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
// reconnects them if they disconnect. // reconnects them if they disconnect.
type Supervisor struct { type Supervisor struct {
cloudflaredUUID uuid.UUID cloudflaredUUID uuid.UUID
config *TunnelConfig config *TunnelConfig
orchestrator *orchestration.Orchestrator orchestrator *orchestration.Orchestrator
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
edgeTunnelServer EdgeTunnelServer edgeTunnelServer EdgeTunnelServer
tunnelErrors chan tunnelError tunnelErrors chan tunnelError
tunnelsConnecting map[int]chan struct{} tunnelsConnecting map[int]chan struct{}
tunnelsProtocolFallback map[int]*protocolFallback
// nextConnectedIndex and nextConnectedSignal are used to wait for all // nextConnectedIndex and nextConnectedSignal are used to wait for all
// currently-connecting tunnels to finish connecting so we can reset backoff timer // currently-connecting tunnels to finish connecting so we can reset backoff timer
nextConnectedIndex int nextConnectedIndex int
@ -72,8 +75,10 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
} }
isStaticEdge := len(config.EdgeAddrs) > 0
var edgeIPs *edgediscovery.Edge var edgeIPs *edgediscovery.Edge
if len(config.EdgeAddrs) > 0 { if isStaticEdge { // static edge addresses
edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs) edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
} else { } else {
edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion) edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion)
@ -86,7 +91,9 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
log := NewConnAwareLogger(config.Log, config.Observer) log := NewConnAwareLogger(config.Log, config.Observer)
var edgeAddrHandler EdgeAddrHandler var edgeAddrHandler EdgeAddrHandler
if config.EdgeIPVersion == allregions.IPv6Only || config.EdgeIPVersion == allregions.Auto { if isStaticEdge { // static edge addresses
edgeAddrHandler = &IPAddrFallback{}
} else if config.EdgeIPVersion == allregions.IPv6Only || config.EdgeIPVersion == allregions.Auto {
edgeAddrHandler = &IPAddrFallback{} edgeAddrHandler = &IPAddrFallback{}
} else { // IPv4Only } else { // IPv4Only
edgeAddrHandler = &DefaultAddrFallback{} edgeAddrHandler = &DefaultAddrFallback{}
@ -117,6 +124,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
edgeTunnelServer: edgeTunnelServer, edgeTunnelServer: edgeTunnelServer,
tunnelErrors: make(chan tunnelError), tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{}, tunnelsConnecting: map[int]chan struct{}{},
tunnelsProtocolFallback: map[int]*protocolFallback{},
log: log, log: log,
logTransport: config.LogTransport, logTransport: config.LogTransport,
reconnectCredentialManager: reconnectCredentialManager, reconnectCredentialManager: reconnectCredentialManager,
@ -178,6 +186,10 @@ func (s *Supervisor) Run(
tunnelsActive++ tunnelsActive++
continue continue
} }
// Make sure we don't continue if there is no more fallback allowed
if _, retry := s.tunnelsProtocolFallback[tunnelError.index].GetMaxBackoffDuration(ctx); !retry {
continue
}
s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated") s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index) s.waitForNextTunnel(tunnelError.index)
@ -232,6 +244,11 @@ func (s *Supervisor) initialize(
s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs) s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.config.HAConnections = availableAddrs s.config.HAConnections = availableAddrs
} }
s.tunnelsProtocolFallback[0] = &protocolFallback{
retry.BackoffHandler{MaxRetries: s.config.Retries},
s.config.ProtocolSelector.Current(),
false,
}
go s.startFirstTunnel(ctx, connectedSignal) go s.startFirstTunnel(ctx, connectedSignal)
@ -249,6 +266,11 @@ func (s *Supervisor) initialize(
// At least one successful connection, so start the rest // At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ { for i := 1; i < s.config.HAConnections; i++ {
s.tunnelsProtocolFallback[i] = &protocolFallback{
retry.BackoffHandler{MaxRetries: s.config.Retries},
s.config.ProtocolSelector.Current(),
false,
}
ch := signal.New(make(chan struct{})) ch := signal.New(make(chan struct{}))
go s.startTunnel(ctx, i, ch) go s.startTunnel(ctx, i, ch)
time.Sleep(registrationInterval) time.Sleep(registrationInterval)
@ -266,21 +288,44 @@ func (s *Supervisor) startFirstTunnel(
err error err error
) )
const firstConnIndex = 0 const firstConnIndex = 0
isStaticEdge := len(s.config.EdgeAddrs) > 0
defer func() { defer func() {
s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err} s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err}
}() }()
err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, connectedSignal)
// If the first tunnel disconnects, keep restarting it. // If the first tunnel disconnects, keep restarting it.
for s.unusedIPs() { for {
err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, s.tunnelsProtocolFallback[firstConnIndex], connectedSignal)
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
if err == nil { if err == nil {
return return
} }
err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, connectedSignal) // Make sure we don't continue if there is no more fallback allowed
if _, retry := s.tunnelsProtocolFallback[firstConnIndex].GetMaxBackoffDuration(ctx); !retry {
return
}
// Try again for Unauthorized errors because we hope them to be
// transient due to edge propagation lag on new Tunnels.
if strings.Contains(err.Error(), "Unauthorized") {
continue
}
switch err.(type) {
case edgediscovery.ErrNoAddressesLeft:
// If your provided addresses are not available, we will keep trying regardless.
if !isStaticEdge {
return
}
case connection.DupConnRegisterTunnelError,
*quic.IdleTimeoutError,
edgediscovery.DialError,
*connection.EdgeQuicDialError:
// Try again for these types of errors
default:
// Uncaught errors should bail startup
return
}
} }
} }
@ -298,7 +343,7 @@ func (s *Supervisor) startTunnel(
s.tunnelErrors <- tunnelError{index: index, err: err} s.tunnelErrors <- tunnelError{index: index, err: err}
}() }()
err = s.edgeTunnelServer.Serve(ctx, uint8(index), connectedSignal) err = s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
} }
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {

View File

@ -194,15 +194,10 @@ type EdgeTunnelServer struct {
connAwareLogger *ConnAwareLogger connAwareLogger *ConnAwareLogger
} }
func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, connectedSignal *signal.Signal) error { func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error {
haConnections.Inc() haConnections.Inc()
defer haConnections.Dec() defer haConnections.Dec()
protocolFallback := &protocolFallback{
retry.BackoffHandler{MaxRetries: e.config.Retries},
e.config.ProtocolSelector.Current(),
false,
}
connectedFuse := h2mux.NewBooleanFuse() connectedFuse := h2mux.NewBooleanFuse()
go func() { go func() {
if connectedFuse.Await() { if connectedFuse.Await() {
@ -214,7 +209,7 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, connectedS
// Fetch IP address to associated connection index // Fetch IP address to associated connection index
addr, err := e.edgeAddrs.GetAddr(int(connIndex)) addr, err := e.edgeAddrs.GetAddr(int(connIndex))
switch err { switch err.(type) {
case nil: // no error case nil: // no error
case edgediscovery.ErrNoAddressesLeft: case edgediscovery.ErrNoAddressesLeft:
return err return err
@ -262,7 +257,9 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, connectedS
// establishing a connection to the edge and if so, rotate the IP address. // establishing a connection to the edge and if so, rotate the IP address.
yes, hasConnectivityError := e.edgeAddrHandler.ShouldGetNewAddress(err) yes, hasConnectivityError := e.edgeAddrHandler.ShouldGetNewAddress(err)
if yes { if yes {
e.edgeAddrs.GetDifferentAddr(int(connIndex), hasConnectivityError) if _, err := e.edgeAddrs.GetDifferentAddr(int(connIndex), hasConnectivityError); err != nil {
return err
}
} }
select { select {
@ -461,6 +458,7 @@ func serveTunnel(
connectedFuse, connectedFuse,
config.NamedTunnel, config.NamedTunnel,
connIndex, connIndex,
addr.UDP.IP,
nil, nil,
gracefulShutdownC, gracefulShutdownC,
config.GracePeriod, config.GracePeriod,