From 0b16a473daeddc88dc14ba657514c7e51e9718b1 Mon Sep 17 00:00:00 2001 From: Igor Postelnik Date: Thu, 4 Feb 2021 18:07:49 -0600 Subject: [PATCH] TUN-3869: Improve reliability of graceful shutdown. - Don't rely on edge to close connection on graceful shutdown in h2mux, start muxer shutdown from cloudflared. - Don't retry failed connections after graceful shutdown has started. - After graceful shutdown channel is closed we stop waiting for retry timer and don't try to restart tunnel loop. - Use readonly channel for graceful shutdown in functions that only consume the signal --- connection/h2mux.go | 10 +++- connection/http2.go | 4 +- connection/http2_test.go | 5 +- origin/supervisor.go | 16 ++++-- origin/tunnel.go | 102 +++++++++++++++++++-------------------- origin/tunnel_test.go | 41 +++++++--------- 6 files changed, 95 insertions(+), 83 deletions(-) diff --git a/connection/h2mux.go b/connection/h2mux.go index 09d9b556..5d9ec068 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -30,7 +30,7 @@ type h2muxConnection struct { connIndex uint8 observer *Observer - gracefulShutdownC chan struct{} + gracefulShutdownC <-chan struct{} stoppedGracefully bool // newRPCClientFunc allows us to mock RPCs during testing @@ -63,7 +63,7 @@ func NewH2muxConnection( edgeConn net.Conn, connIndex uint8, observer *Observer, - gracefulShutdownC chan struct{}, + gracefulShutdownC <-chan struct{}, ) (*h2muxConnection, error, bool) { h := &h2muxConnection{ config: config, @@ -168,6 +168,7 @@ func (h *h2muxConnection) serveMuxer(ctx context.Context) error { func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) { updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq) + var shutdownCompleted <-chan struct{} for { select { case <-h.gracefulShutdownC: @@ -176,6 +177,10 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect } h.stoppedGracefully = true h.gracefulShutdownC = nil + shutdownCompleted = h.muxer.Shutdown() + + case <-shutdownCompleted: + return case <-ctx.Done(): // UnregisterTunnel blocks until the RPC call returns @@ -183,6 +188,7 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect h.unregister(isNamedTunnel) } h.muxer.Shutdown() + // don't wait for shutdown to finish when context is closed, this is the hard termination path return case <-updateMetricsTickC: diff --git a/connection/http2.go b/connection/http2.go index 9c4e4522..ab303294 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -39,7 +39,7 @@ type http2Connection struct { activeRequestsWG sync.WaitGroup connectedFuse ConnectedFuse - gracefulShutdownC chan struct{} + gracefulShutdownC <-chan struct{} stoppedGracefully bool controlStreamErr error // result of running control stream handler } @@ -52,7 +52,7 @@ func NewHTTP2Connection( observer *Observer, connIndex uint8, connectedFuse ConnectedFuse, - gracefulShutdownC chan struct{}, + gracefulShutdownC <-chan struct{}, ) *http2Connection { return &http2Connection{ conn: conn, diff --git a/connection/http2_test.go b/connection/http2_test.go index 96b705cf..13f9da00 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -257,7 +257,8 @@ func TestGracefulShutdownHTTP2(t *testing.T) { unregistered: make(chan struct{}), } http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient - http2Conn.gracefulShutdownC = make(chan struct{}) + shutdownC := make(chan struct{}) + http2Conn.gracefulShutdownC = shutdownC ctx, cancel := 
context.WithCancel(context.Background()) var wg sync.WaitGroup @@ -288,7 +289,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) { } // signal graceful shutdown - close(http2Conn.gracefulShutdownC) + close(shutdownC) select { case <-rpcClientFactory.unregistered: diff --git a/origin/supervisor.go b/origin/supervisor.go index 8397e2fa..39394e95 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -58,16 +58,18 @@ type Supervisor struct { useReconnectToken bool reconnectCh chan ReconnectSignal - gracefulShutdownC chan struct{} + gracefulShutdownC <-chan struct{} } +var errEarlyShutdown = errors.New("shutdown started") + type tunnelError struct { index int addr *net.TCPAddr err error } -func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}) (*Supervisor, error) { +func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { cloudflaredUUID, err := uuid.NewRandom() if err != nil { return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) @@ -108,6 +110,9 @@ func (s *Supervisor) Run( connectedSignal *signal.Signal, ) error { if err := s.initialize(ctx, connectedSignal); err != nil { + if err == errEarlyShutdown { + return nil + } return err } var tunnelsWaiting []int @@ -130,6 +135,7 @@ func (s *Supervisor) Run( } } + shuttingDown := false for { select { // Context cancelled @@ -143,7 +149,7 @@ func (s *Supervisor) Run( // (note that this may also be caused by context cancellation) case tunnelError := <-s.tunnelErrors: tunnelsActive-- - if tunnelError.err != nil { + if tunnelError.err != nil && !shuttingDown { s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated") tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) s.waitForNextTunnel(tunnelError.index) @@ -181,6 +187,8 @@ func (s *Supervisor) Run( // No more tunnels outstanding, clear backoff timer backoff.SetGracePeriod() } + case <-s.gracefulShutdownC: + shuttingDown = true } } } @@ -203,6 +211,8 @@ func (s *Supervisor) initialize( return ctx.Err() case tunnelError := <-s.tunnelErrors: return tunnelError.err + case <-s.gracefulShutdownC: + return errEarlyShutdown case <-connectedSignal.Wait(): } // At least one successful connection, so start the rest diff --git a/origin/tunnel.go b/origin/tunnel.go index 7802cfc3..260942ad 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -113,7 +113,7 @@ func StartTunnelDaemon( config *TunnelConfig, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal, - graceShutdownC chan struct{}, + graceShutdownC <-chan struct{}, ) error { s, err := NewSupervisor(config, reconnectCh, graceShutdownC) if err != nil { @@ -131,14 +131,14 @@ func ServeTunnelLoop( connectedSignal *signal.Signal, cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, - gracefulShutdownC chan struct{}, + gracefulShutdownC <-chan struct{}, ) error { haConnections.Inc() defer haConnections.Dec() connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger() - protocallFallback := &protocallFallback{ + protocolFallback := &protocolFallback{ BackoffHandler{MaxRetries: config.Retries}, config.ProtocolSelector.Current(), false, @@ -162,82 +162,82 @@ func ServeTunnelLoop( addr, connIndex, connectedFuse, - protocallFallback, + protocolFallback, cloudflaredUUID, reconnectCh, - protocallFallback.protocol, + protocolFallback.protocol, gracefulShutdownC, ) if !recoverable { return err } - err = 
waitForBackoff(ctx, &connLog, protocallFallback, config, connIndex, err) - if err != nil { + config.Observer.SendReconnect(connIndex) + + duration, ok := protocolFallback.GetBackoffDuration(ctx) + if !ok { return err } + connLog.Info().Msgf("Retrying connection in %s seconds", duration) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-gracefulShutdownC: + return nil + case <-protocolFallback.BackoffTimer(): + if !selectNextProtocol(&connLog, protocolFallback, config.ProtocolSelector) { + return err + } + } } } -// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches +// protocolFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches // max retries -type protocallFallback struct { +type protocolFallback struct { BackoffHandler protocol connection.Protocol inFallback bool } -func (pf *protocallFallback) reset() { +func (pf *protocolFallback) reset() { pf.resetNow() pf.inFallback = false } -func (pf *protocallFallback) fallback(fallback connection.Protocol) { +func (pf *protocolFallback) fallback(fallback connection.Protocol) { pf.resetNow() pf.protocol = fallback pf.inFallback = true } -// Expect err to always be non nil -func waitForBackoff( - ctx context.Context, - log *zerolog.Logger, - protobackoff *protocallFallback, - config *TunnelConfig, - connIndex uint8, - err error, -) error { - duration, ok := protobackoff.GetBackoffDuration(ctx) - if !ok { - return err - } - - config.Observer.SendReconnect(connIndex) - log.Info(). - Err(err). - Uint8(connection.LogFieldConnIndex, connIndex). - Msgf("Retrying connection in %s seconds", duration) - protobackoff.Backoff(ctx) - - if protobackoff.ReachedMaxRetries() { - fallback, hasFallback := config.ProtocolSelector.Fallback() +// selectNextProtocol picks connection protocol for the next retry iteration, +// returns true if it was able to pick the protocol, false if we are out of options and should stop retrying +func selectNextProtocol( + connLog *zerolog.Logger, + protocolBackoff *protocolFallback, + selector connection.ProtocolSelector, +) bool { + if protocolBackoff.ReachedMaxRetries() { + fallback, hasFallback := selector.Fallback() if !hasFallback { - return err + return false } // Already using fallback protocol, no point to retry - if protobackoff.protocol == fallback { - return err + if protocolBackoff.protocol == fallback { + return false } - log.Info().Msgf("Fallback to use %s", fallback) - protobackoff.fallback(fallback) - } else if !protobackoff.inFallback { - current := config.ProtocolSelector.Current() - if protobackoff.protocol != current { - protobackoff.protocol = current - config.Log.Info().Msgf("Change protocol to %s", current) + connLog.Info().Msgf("Switching to fallback protocol %s", fallback) + protocolBackoff.fallback(fallback) + } else if !protocolBackoff.inFallback { + current := selector.Current() + if protocolBackoff.protocol != current { + protocolBackoff.protocol = current + connLog.Info().Msgf("Changing protocol to %s", current) } } - return nil + return true } // ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown, @@ -250,11 +250,11 @@ func ServeTunnel( addr *net.TCPAddr, connIndex uint8, fuse *h2mux.BooleanFuse, - backoff *protocallFallback, + backoff *protocolFallback, cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, protocol connection.Protocol, - gracefulShutdownC chan struct{}, + gracefulShutdownC <-chan struct{}, ) (err error, recoverable bool) { // Treat panics as 
recoverable errors defer func() { @@ -358,7 +358,7 @@ func ServeH2mux( connectedFuse *connectedFuse, cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, - gracefulShutdownC chan struct{}, + gracefulShutdownC <-chan struct{}, ) error { connLog.Debug().Msgf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors @@ -404,7 +404,7 @@ func ServeHTTP2( connIndex uint8, connectedFuse connection.ConnectedFuse, reconnectCh chan ReconnectSignal, - gracefulShutdownC chan struct{}, + gracefulShutdownC <-chan struct{}, ) error { connLog.Debug().Msgf("Connecting via http2") h2conn := connection.NewHTTP2Connection( @@ -435,7 +435,7 @@ func ServeHTTP2( return errGroup.Wait() } -func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) error { +func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error { select { case reconnect := <-reconnectCh: return reconnect @@ -448,7 +448,7 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gr type connectedFuse struct { fuse *h2mux.BooleanFuse - backoff *protocallFallback + backoff *protocolFallback } func (cf *connectedFuse) Connected() { diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index 2484b44a..b7e284ab 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -1,8 +1,6 @@ package origin import ( - "context" - "fmt" "testing" "time" @@ -25,13 +23,13 @@ func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { return dmf.percentage, nil } } + func TestWaitForBackoffFallback(t *testing.T) { maxRetries := uint(3) backoff := BackoffHandler{ MaxRetries: maxRetries, BaseTime: time.Millisecond * 10, } - ctx := context.Background() log := zerolog.Nop() resolveTTL := time.Duration(0) namedTunnel := &connection.NamedTunnelConfig{ @@ -50,18 +48,11 @@ func TestWaitForBackoffFallback(t *testing.T) { &log, ) assert.NoError(t, err) - config := &TunnelConfig{ - Log: &log, - LogTransport: &log, - ProtocolSelector: protocolSelector, - Observer: connection.NewObserver(&log, &log, false), - } - connIndex := uint8(1) initProtocol := protocolSelector.Current() assert.Equal(t, connection.HTTP2, initProtocol) - protocallFallback := &protocallFallback{ + protocolFallback := &protocolFallback{ backoff, initProtocol, false, @@ -69,29 +60,33 @@ func TestWaitForBackoffFallback(t *testing.T) { // Retry #0 and #1. 
At retry #2, we switch protocol, so the fallback loop has one more retry than this for i := 0; i < int(maxRetries-1); i++ { - err := waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error")) - assert.NoError(t, err) - assert.Equal(t, initProtocol, protocallFallback.protocol) + protocolFallback.BackoffTimer() // simulate retry + ok := selectNextProtocol(&log, protocolFallback, protocolSelector) + assert.True(t, ok) + assert.Equal(t, initProtocol, protocolFallback.protocol) } // Retry fallback protocol for i := 0; i < int(maxRetries); i++ { - err := waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error")) - assert.NoError(t, err) + protocolFallback.BackoffTimer() // simulate retry + ok := selectNextProtocol(&log, protocolFallback, protocolSelector) + assert.True(t, ok) fallback, ok := protocolSelector.Fallback() assert.True(t, ok) - assert.Equal(t, fallback, protocallFallback.protocol) + assert.Equal(t, fallback, protocolFallback.protocol) } currentGlobalProtocol := protocolSelector.Current() assert.Equal(t, initProtocol, currentGlobalProtocol) // No protocol to fallback, return error - err = waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error")) - assert.Error(t, err) + protocolFallback.BackoffTimer() // simulate retry + ok := selectNextProtocol(&log, protocolFallback, protocolSelector) + assert.False(t, ok) - protocallFallback.reset() - err = waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("new error")) - assert.NoError(t, err) - assert.Equal(t, initProtocol, protocallFallback.protocol) + protocolFallback.reset() + protocolFallback.BackoffTimer() // simulate retry + ok = selectNextProtocol(&log, protocolFallback, protocolSelector) + assert.True(t, ok) + assert.Equal(t, initProtocol, protocolFallback.protocol) }
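
A minimal sketch (illustrative only, not part of the diff above) of the retry-loop shape ServeTunnelLoop takes after this change: each failed connection attempt is followed by a select over the context, the read-only graceful shutdown channel, and the backoff timer, and only a fired timer leads to another attempt. The names connect, maxRetries, and the fixed exponential delay below are hypothetical stand-ins for ServeTunnel, protocolFallback, and BackoffHandler, not the real cloudflared APIs.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// serveLoop mirrors the control flow introduced in ServeTunnelLoop: try a
// connection, then wait on hard cancellation, graceful shutdown, or the
// backoff timer before deciding whether to retry.
func serveLoop(ctx context.Context, gracefulShutdownC <-chan struct{}) error {
	retries := 0
	const maxRetries = 3

	for {
		err := connect(ctx) // stand-in for a single ServeTunnel attempt
		if err == nil {
			return nil // connection ended cleanly (e.g. shutdown handled inside it)
		}
		if retries >= maxRetries {
			return err // out of retries (and, in cloudflared, out of fallback protocols)
		}

		delay := time.Duration(1<<retries) * time.Second
		fmt.Printf("retrying connection in %s\n", delay)

		select {
		case <-ctx.Done():
			return ctx.Err() // hard termination path
		case <-gracefulShutdownC:
			return nil // graceful shutdown: stop retrying, exit without error
		case <-time.After(delay):
			retries++ // timer fired: fall through and attempt the next connection
		}
	}
}

func connect(ctx context.Context) error {
	// Stand-in for a tunnel connection attempt that keeps failing.
	return errors.New("connection dropped")
}

func main() {
	shutdownC := make(chan struct{})
	// Simulate an operator-initiated graceful shutdown partway through the retries.
	go func() {
		time.Sleep(1500 * time.Millisecond)
		close(shutdownC)
	}()
	if err := serveLoop(context.Background(), shutdownC); err != nil {
		fmt.Println("tunnel loop ended with error:", err)
	} else {
		fmt.Println("tunnel loop ended cleanly")
	}
}

Returning nil when gracefulShutdownC fires mirrors the intent of the patch: an operator-initiated shutdown is a success path, so the supervisor should neither log it as a terminated connection nor schedule another retry, which is what the shuttingDown flag and errEarlyShutdown in supervisor.go enforce.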