TUN-3869: Improve reliability of graceful shutdown.

- Don't rely on edge to close connection on graceful shutdown in h2mux, start muxer shutdown from cloudflared.
- Don't retry failed connections after graceful shutdown has started.
- After graceful shutdown channel is closed we stop waiting for retry timer and don't try to restart tunnel loop.
- Use readonly channel for graceful shutdown in functions that only consume the signal
This commit is contained in:
Igor Postelnik 2021-02-04 18:07:49 -06:00
parent dbd90f270e
commit 0b16a473da
6 changed files with 95 additions and 83 deletions

View File

@ -30,7 +30,7 @@ type h2muxConnection struct {
connIndex uint8 connIndex uint8
observer *Observer observer *Observer
gracefulShutdownC chan struct{} gracefulShutdownC <-chan struct{}
stoppedGracefully bool stoppedGracefully bool
// newRPCClientFunc allows us to mock RPCs during testing // newRPCClientFunc allows us to mock RPCs during testing
@ -63,7 +63,7 @@ func NewH2muxConnection(
edgeConn net.Conn, edgeConn net.Conn,
connIndex uint8, connIndex uint8,
observer *Observer, observer *Observer,
gracefulShutdownC chan struct{}, gracefulShutdownC <-chan struct{},
) (*h2muxConnection, error, bool) { ) (*h2muxConnection, error, bool) {
h := &h2muxConnection{ h := &h2muxConnection{
config: config, config: config,
@ -168,6 +168,7 @@ func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) { func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq) updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
var shutdownCompleted <-chan struct{}
for { for {
select { select {
case <-h.gracefulShutdownC: case <-h.gracefulShutdownC:
@ -176,6 +177,10 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect
} }
h.stoppedGracefully = true h.stoppedGracefully = true
h.gracefulShutdownC = nil h.gracefulShutdownC = nil
shutdownCompleted = h.muxer.Shutdown()
case <-shutdownCompleted:
return
case <-ctx.Done(): case <-ctx.Done():
// UnregisterTunnel blocks until the RPC call returns // UnregisterTunnel blocks until the RPC call returns
@ -183,6 +188,7 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect
h.unregister(isNamedTunnel) h.unregister(isNamedTunnel)
} }
h.muxer.Shutdown() h.muxer.Shutdown()
// don't wait for shutdown to finish when context is closed, this is the hard termination path
return return
case <-updateMetricsTickC: case <-updateMetricsTickC:

View File

@ -39,7 +39,7 @@ type http2Connection struct {
activeRequestsWG sync.WaitGroup activeRequestsWG sync.WaitGroup
connectedFuse ConnectedFuse connectedFuse ConnectedFuse
gracefulShutdownC chan struct{} gracefulShutdownC <-chan struct{}
stoppedGracefully bool stoppedGracefully bool
controlStreamErr error // result of running control stream handler controlStreamErr error // result of running control stream handler
} }
@ -52,7 +52,7 @@ func NewHTTP2Connection(
observer *Observer, observer *Observer,
connIndex uint8, connIndex uint8,
connectedFuse ConnectedFuse, connectedFuse ConnectedFuse,
gracefulShutdownC chan struct{}, gracefulShutdownC <-chan struct{},
) *http2Connection { ) *http2Connection {
return &http2Connection{ return &http2Connection{
conn: conn, conn: conn,

View File

@ -257,7 +257,8 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
unregistered: make(chan struct{}), unregistered: make(chan struct{}),
} }
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
http2Conn.gracefulShutdownC = make(chan struct{}) shutdownC := make(chan struct{})
http2Conn.gracefulShutdownC = shutdownC
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup var wg sync.WaitGroup
@ -288,7 +289,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
} }
// signal graceful shutdown // signal graceful shutdown
close(http2Conn.gracefulShutdownC) close(shutdownC)
select { select {
case <-rpcClientFactory.unregistered: case <-rpcClientFactory.unregistered:

View File

@ -58,16 +58,18 @@ type Supervisor struct {
useReconnectToken bool useReconnectToken bool
reconnectCh chan ReconnectSignal reconnectCh chan ReconnectSignal
gracefulShutdownC chan struct{} gracefulShutdownC <-chan struct{}
} }
var errEarlyShutdown = errors.New("shutdown started")
type tunnelError struct { type tunnelError struct {
index int index int
addr *net.TCPAddr addr *net.TCPAddr
err error err error
} }
func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}) (*Supervisor, error) { func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
cloudflaredUUID, err := uuid.NewRandom() cloudflaredUUID, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
@ -108,6 +110,9 @@ func (s *Supervisor) Run(
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
) error { ) error {
if err := s.initialize(ctx, connectedSignal); err != nil { if err := s.initialize(ctx, connectedSignal); err != nil {
if err == errEarlyShutdown {
return nil
}
return err return err
} }
var tunnelsWaiting []int var tunnelsWaiting []int
@ -130,6 +135,7 @@ func (s *Supervisor) Run(
} }
} }
shuttingDown := false
for { for {
select { select {
// Context cancelled // Context cancelled
@ -143,7 +149,7 @@ func (s *Supervisor) Run(
// (note that this may also be caused by context cancellation) // (note that this may also be caused by context cancellation)
case tunnelError := <-s.tunnelErrors: case tunnelError := <-s.tunnelErrors:
tunnelsActive-- tunnelsActive--
if tunnelError.err != nil { if tunnelError.err != nil && !shuttingDown {
s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated") s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index) s.waitForNextTunnel(tunnelError.index)
@ -181,6 +187,8 @@ func (s *Supervisor) Run(
// No more tunnels outstanding, clear backoff timer // No more tunnels outstanding, clear backoff timer
backoff.SetGracePeriod() backoff.SetGracePeriod()
} }
case <-s.gracefulShutdownC:
shuttingDown = true
} }
} }
} }
@ -203,6 +211,8 @@ func (s *Supervisor) initialize(
return ctx.Err() return ctx.Err()
case tunnelError := <-s.tunnelErrors: case tunnelError := <-s.tunnelErrors:
return tunnelError.err return tunnelError.err
case <-s.gracefulShutdownC:
return errEarlyShutdown
case <-connectedSignal.Wait(): case <-connectedSignal.Wait():
} }
// At least one successful connection, so start the rest // At least one successful connection, so start the rest

View File

@ -113,7 +113,7 @@ func StartTunnelDaemon(
config *TunnelConfig, config *TunnelConfig,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
graceShutdownC chan struct{}, graceShutdownC <-chan struct{},
) error { ) error {
s, err := NewSupervisor(config, reconnectCh, graceShutdownC) s, err := NewSupervisor(config, reconnectCh, graceShutdownC)
if err != nil { if err != nil {
@ -131,14 +131,14 @@ func ServeTunnelLoop(
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
gracefulShutdownC chan struct{}, gracefulShutdownC <-chan struct{},
) error { ) error {
haConnections.Inc() haConnections.Inc()
defer haConnections.Dec() defer haConnections.Dec()
connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger() connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()
protocallFallback := &protocallFallback{ protocolFallback := &protocolFallback{
BackoffHandler{MaxRetries: config.Retries}, BackoffHandler{MaxRetries: config.Retries},
config.ProtocolSelector.Current(), config.ProtocolSelector.Current(),
false, false,
@ -162,82 +162,82 @@ func ServeTunnelLoop(
addr, addr,
connIndex, connIndex,
connectedFuse, connectedFuse,
protocallFallback, protocolFallback,
cloudflaredUUID, cloudflaredUUID,
reconnectCh, reconnectCh,
protocallFallback.protocol, protocolFallback.protocol,
gracefulShutdownC, gracefulShutdownC,
) )
if !recoverable { if !recoverable {
return err return err
} }
err = waitForBackoff(ctx, &connLog, protocallFallback, config, connIndex, err) config.Observer.SendReconnect(connIndex)
if err != nil {
duration, ok := protocolFallback.GetBackoffDuration(ctx)
if !ok {
return err
}
connLog.Info().Msgf("Retrying connection in %s seconds", duration)
select {
case <-ctx.Done():
return ctx.Err()
case <-gracefulShutdownC:
return nil
case <-protocolFallback.BackoffTimer():
if !selectNextProtocol(&connLog, protocolFallback, config.ProtocolSelector) {
return err return err
} }
} }
} }
}
// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches // protocolFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
// max retries // max retries
type protocallFallback struct { type protocolFallback struct {
BackoffHandler BackoffHandler
protocol connection.Protocol protocol connection.Protocol
inFallback bool inFallback bool
} }
func (pf *protocallFallback) reset() { func (pf *protocolFallback) reset() {
pf.resetNow() pf.resetNow()
pf.inFallback = false pf.inFallback = false
} }
func (pf *protocallFallback) fallback(fallback connection.Protocol) { func (pf *protocolFallback) fallback(fallback connection.Protocol) {
pf.resetNow() pf.resetNow()
pf.protocol = fallback pf.protocol = fallback
pf.inFallback = true pf.inFallback = true
} }
// Expect err to always be non nil // selectNextProtocol picks connection protocol for the next retry iteration,
func waitForBackoff( // returns true if it was able to pick the protocol, false if we are out of options and should stop retrying
ctx context.Context, func selectNextProtocol(
log *zerolog.Logger, connLog *zerolog.Logger,
protobackoff *protocallFallback, protocolBackoff *protocolFallback,
config *TunnelConfig, selector connection.ProtocolSelector,
connIndex uint8, ) bool {
err error, if protocolBackoff.ReachedMaxRetries() {
) error { fallback, hasFallback := selector.Fallback()
duration, ok := protobackoff.GetBackoffDuration(ctx)
if !ok {
return err
}
config.Observer.SendReconnect(connIndex)
log.Info().
Err(err).
Uint8(connection.LogFieldConnIndex, connIndex).
Msgf("Retrying connection in %s seconds", duration)
protobackoff.Backoff(ctx)
if protobackoff.ReachedMaxRetries() {
fallback, hasFallback := config.ProtocolSelector.Fallback()
if !hasFallback { if !hasFallback {
return err return false
} }
// Already using fallback protocol, no point to retry // Already using fallback protocol, no point to retry
if protobackoff.protocol == fallback { if protocolBackoff.protocol == fallback {
return err return false
} }
log.Info().Msgf("Fallback to use %s", fallback) connLog.Info().Msgf("Switching to fallback protocol %s", fallback)
protobackoff.fallback(fallback) protocolBackoff.fallback(fallback)
} else if !protobackoff.inFallback { } else if !protocolBackoff.inFallback {
current := config.ProtocolSelector.Current() current := selector.Current()
if protobackoff.protocol != current { if protocolBackoff.protocol != current {
protobackoff.protocol = current protocolBackoff.protocol = current
config.Log.Info().Msgf("Change protocol to %s", current) connLog.Info().Msgf("Changing protocol to %s", current)
} }
} }
return nil return true
} }
// ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown, // ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown,
@ -250,11 +250,11 @@ func ServeTunnel(
addr *net.TCPAddr, addr *net.TCPAddr,
connIndex uint8, connIndex uint8,
fuse *h2mux.BooleanFuse, fuse *h2mux.BooleanFuse,
backoff *protocallFallback, backoff *protocolFallback,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
protocol connection.Protocol, protocol connection.Protocol,
gracefulShutdownC chan struct{}, gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) { ) (err error, recoverable bool) {
// Treat panics as recoverable errors // Treat panics as recoverable errors
defer func() { defer func() {
@ -358,7 +358,7 @@ func ServeH2mux(
connectedFuse *connectedFuse, connectedFuse *connectedFuse,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
gracefulShutdownC chan struct{}, gracefulShutdownC <-chan struct{},
) error { ) error {
connLog.Debug().Msgf("Connecting via h2mux") connLog.Debug().Msgf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
@ -404,7 +404,7 @@ func ServeHTTP2(
connIndex uint8, connIndex uint8,
connectedFuse connection.ConnectedFuse, connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
gracefulShutdownC chan struct{}, gracefulShutdownC <-chan struct{},
) error { ) error {
connLog.Debug().Msgf("Connecting via http2") connLog.Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection( h2conn := connection.NewHTTP2Connection(
@ -435,7 +435,7 @@ func ServeHTTP2(
return errGroup.Wait() return errGroup.Wait()
} }
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) error { func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
select { select {
case reconnect := <-reconnectCh: case reconnect := <-reconnectCh:
return reconnect return reconnect
@ -448,7 +448,7 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gr
type connectedFuse struct { type connectedFuse struct {
fuse *h2mux.BooleanFuse fuse *h2mux.BooleanFuse
backoff *protocallFallback backoff *protocolFallback
} }
func (cf *connectedFuse) Connected() { func (cf *connectedFuse) Connected() {

View File

@ -1,8 +1,6 @@
package origin package origin
import ( import (
"context"
"fmt"
"testing" "testing"
"time" "time"
@ -25,13 +23,13 @@ func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
return dmf.percentage, nil return dmf.percentage, nil
} }
} }
func TestWaitForBackoffFallback(t *testing.T) { func TestWaitForBackoffFallback(t *testing.T) {
maxRetries := uint(3) maxRetries := uint(3)
backoff := BackoffHandler{ backoff := BackoffHandler{
MaxRetries: maxRetries, MaxRetries: maxRetries,
BaseTime: time.Millisecond * 10, BaseTime: time.Millisecond * 10,
} }
ctx := context.Background()
log := zerolog.Nop() log := zerolog.Nop()
resolveTTL := time.Duration(0) resolveTTL := time.Duration(0)
namedTunnel := &connection.NamedTunnelConfig{ namedTunnel := &connection.NamedTunnelConfig{
@ -50,18 +48,11 @@ func TestWaitForBackoffFallback(t *testing.T) {
&log, &log,
) )
assert.NoError(t, err) assert.NoError(t, err)
config := &TunnelConfig{
Log: &log,
LogTransport: &log,
ProtocolSelector: protocolSelector,
Observer: connection.NewObserver(&log, &log, false),
}
connIndex := uint8(1)
initProtocol := protocolSelector.Current() initProtocol := protocolSelector.Current()
assert.Equal(t, connection.HTTP2, initProtocol) assert.Equal(t, connection.HTTP2, initProtocol)
protocallFallback := &protocallFallback{ protocolFallback := &protocolFallback{
backoff, backoff,
initProtocol, initProtocol,
false, false,
@ -69,29 +60,33 @@ func TestWaitForBackoffFallback(t *testing.T) {
// Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this // Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this
for i := 0; i < int(maxRetries-1); i++ { for i := 0; i < int(maxRetries-1); i++ {
err := waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error")) protocolFallback.BackoffTimer() // simulate retry
assert.NoError(t, err) ok := selectNextProtocol(&log, protocolFallback, protocolSelector)
assert.Equal(t, initProtocol, protocallFallback.protocol) assert.True(t, ok)
assert.Equal(t, initProtocol, protocolFallback.protocol)
} }
// Retry fallback protocol // Retry fallback protocol
for i := 0; i < int(maxRetries); i++ { for i := 0; i < int(maxRetries); i++ {
err := waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error")) protocolFallback.BackoffTimer() // simulate retry
assert.NoError(t, err) ok := selectNextProtocol(&log, protocolFallback, protocolSelector)
assert.True(t, ok)
fallback, ok := protocolSelector.Fallback() fallback, ok := protocolSelector.Fallback()
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, fallback, protocallFallback.protocol) assert.Equal(t, fallback, protocolFallback.protocol)
} }
currentGlobalProtocol := protocolSelector.Current() currentGlobalProtocol := protocolSelector.Current()
assert.Equal(t, initProtocol, currentGlobalProtocol) assert.Equal(t, initProtocol, currentGlobalProtocol)
// No protocol to fallback, return error // No protocol to fallback, return error
err = waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error")) protocolFallback.BackoffTimer() // simulate retry
assert.Error(t, err) ok := selectNextProtocol(&log, protocolFallback, protocolSelector)
assert.False(t, ok)
protocallFallback.reset() protocolFallback.reset()
err = waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("new error")) protocolFallback.BackoffTimer() // simulate retry
assert.NoError(t, err) ok = selectNextProtocol(&log, protocolFallback, protocolSelector)
assert.Equal(t, initProtocol, protocallFallback.protocol) assert.True(t, ok)
assert.Equal(t, initProtocol, protocolFallback.protocol)
} }