TUN-3869: Improve reliability of graceful shutdown.
- Don't rely on edge to close connection on graceful shutdown in h2mux, start muxer shutdown from cloudflared.
- Don't retry failed connections after graceful shutdown has started.
- After graceful shutdown channel is closed we stop waiting for retry timer and don't try to restart tunnel loop.
- Use readonly channel for graceful shutdown in functions that only consume the signal.
commit 0b16a473da
parent dbd90f270e
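The last bullet is the type-level half of the change: every function that only consumes the shutdown signal now takes a receive-only channel, so the compiler rejects an accidental close or send from the consumer side. A minimal standalone sketch of that pattern (the names here are illustrative, not code from this commit):

package main

import (
	"fmt"
	"sync"
	"time"
)

// consume only sees a receive-only view of the shutdown channel,
// so it cannot close it or send on it by mistake.
func consume(id int, gracefulShutdownC <-chan struct{}, wg *sync.WaitGroup) {
	defer wg.Done()
	for {
		select {
		case <-gracefulShutdownC:
			fmt.Printf("worker %d: shutting down gracefully\n", id)
			return
		case <-time.After(50 * time.Millisecond):
			fmt.Printf("worker %d: working\n", id)
		}
	}
}

func main() {
	// The owner keeps the bidirectional channel and is the only closer.
	gracefulShutdownC := make(chan struct{})
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go consume(i, gracefulShutdownC, &wg) // implicitly converts to <-chan struct{}
	}
	time.Sleep(120 * time.Millisecond)
	close(gracefulShutdownC) // broadcast: every receiver unblocks at once
	wg.Wait()
}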
@@ -30,7 +30,7 @@ type h2muxConnection struct {
 	connIndex         uint8
 
 	observer          *Observer
-	gracefulShutdownC chan struct{}
+	gracefulShutdownC <-chan struct{}
 	stoppedGracefully bool
 
 	// newRPCClientFunc allows us to mock RPCs during testing
@@ -63,7 +63,7 @@ func NewH2muxConnection(
 	edgeConn net.Conn,
 	connIndex uint8,
 	observer *Observer,
-	gracefulShutdownC chan struct{},
+	gracefulShutdownC <-chan struct{},
 ) (*h2muxConnection, error, bool) {
 	h := &h2muxConnection{
 		config: config,
@@ -168,6 +168,7 @@ func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
 
 func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
 	updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
+	var shutdownCompleted <-chan struct{}
 	for {
 		select {
 		case <-h.gracefulShutdownC:
@@ -176,6 +177,10 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect
 			}
 			h.stoppedGracefully = true
 			h.gracefulShutdownC = nil
+			shutdownCompleted = h.muxer.Shutdown()
+
+		case <-shutdownCompleted:
+			return
 
 		case <-ctx.Done():
 			// UnregisterTunnel blocks until the RPC call returns
@@ -183,6 +188,7 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect
 			h.unregister(isNamedTunnel)
 		}
 		h.muxer.Shutdown()
+		// don't wait for shutdown to finish when context is closed, this is the hard termination path
 		return
 
 	case <-updateMetricsTickC:
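Two Go idioms carry the new controlLoop. Setting h.gracefulShutdownC = nil after the signal fires makes that select case block forever (a receive from a nil channel never proceeds), so the shutdown branch runs exactly once; and the channel returned by muxer.Shutdown() lets cloudflared observe when shutdown has finished instead of relying on the edge to drop the connection. A minimal sketch of the same shape, with a stubbed muxer (all names hypothetical):

package main

import (
	"fmt"
	"time"
)

// stubMuxer stands in for the h2mux muxer; Shutdown returns a channel
// that is closed once the shutdown handshake finishes.
type stubMuxer struct{}

func (stubMuxer) Shutdown() <-chan struct{} {
	done := make(chan struct{})
	go func() {
		time.Sleep(30 * time.Millisecond) // pretend to flush streams
		close(done)
	}()
	return done
}

func controlLoop(gracefulShutdownC <-chan struct{}) {
	m := stubMuxer{}
	var shutdownCompleted <-chan struct{} // nil until shutdown starts, so its case blocks
	for {
		select {
		case <-gracefulShutdownC:
			gracefulShutdownC = nil // nil channel blocks forever: this case fires once
			shutdownCompleted = m.Shutdown()
		case <-shutdownCompleted:
			fmt.Println("muxer shutdown complete")
			return
		}
	}
}

func main() {
	shutdownC := make(chan struct{})
	go func() { close(shutdownC) }()
	controlLoop(shutdownC)
}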
@@ -39,7 +39,7 @@ type http2Connection struct {
 
 	activeRequestsWG  sync.WaitGroup
 	connectedFuse     ConnectedFuse
-	gracefulShutdownC chan struct{}
+	gracefulShutdownC <-chan struct{}
 	stoppedGracefully bool
 	controlStreamErr  error // result of running control stream handler
 }
@@ -52,7 +52,7 @@ func NewHTTP2Connection(
 	observer *Observer,
 	connIndex uint8,
 	connectedFuse ConnectedFuse,
-	gracefulShutdownC chan struct{},
+	gracefulShutdownC <-chan struct{},
 ) *http2Connection {
 	return &http2Connection{
 		conn: conn,
@@ -257,7 +257,8 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
 		unregistered: make(chan struct{}),
 	}
 	http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
-	http2Conn.gracefulShutdownC = make(chan struct{})
+	shutdownC := make(chan struct{})
+	http2Conn.gracefulShutdownC = shutdownC
 
 	ctx, cancel := context.WithCancel(context.Background())
 	var wg sync.WaitGroup
@@ -288,7 +289,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
 	}
 
 	// signal graceful shutdown
-	close(http2Conn.gracefulShutdownC)
+	close(shutdownC)
 
 	select {
 	case <-rpcClientFactory.unregistered:
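The test-side indirection exists because close is only legal on a channel that permits sends; once the struct field became <-chan struct{}, the test must keep its own bidirectional handle to close. A minimal sketch of that constraint (hypothetical names):

package main

type conn struct {
	gracefulShutdownC <-chan struct{} // receive-only view
}

func main() {
	shutdownC := make(chan struct{}) // keep the closable handle
	c := conn{gracefulShutdownC: shutdownC}
	// close(c.gracefulShutdownC) // compile error: cannot close receive-only channel
	close(shutdownC)      // legal: close through the bidirectional handle
	<-c.gracefulShutdownC // receive on the closed channel returns immediately
}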
@@ -58,16 +58,18 @@ type Supervisor struct {
 	useReconnectToken bool
 
 	reconnectCh       chan ReconnectSignal
-	gracefulShutdownC chan struct{}
+	gracefulShutdownC <-chan struct{}
 }
 
+var errEarlyShutdown = errors.New("shutdown started")
+
 type tunnelError struct {
 	index int
 	addr  *net.TCPAddr
 	err   error
 }
 
-func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}) (*Supervisor, error) {
+func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
 	cloudflaredUUID, err := uuid.NewRandom()
 	if err != nil {
 		return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
@@ -108,6 +110,9 @@ func (s *Supervisor) Run(
 	connectedSignal *signal.Signal,
 ) error {
 	if err := s.initialize(ctx, connectedSignal); err != nil {
+		if err == errEarlyShutdown {
+			return nil
+		}
 		return err
 	}
 	var tunnelsWaiting []int
@@ -130,6 +135,7 @@ func (s *Supervisor) Run(
 		}
 	}
 
+	shuttingDown := false
 	for {
 		select {
 		// Context cancelled
@@ -143,7 +149,7 @@ func (s *Supervisor) Run(
 		// (note that this may also be caused by context cancellation)
 		case tunnelError := <-s.tunnelErrors:
 			tunnelsActive--
-			if tunnelError.err != nil {
+			if tunnelError.err != nil && !shuttingDown {
 				s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
 				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
 				s.waitForNextTunnel(tunnelError.index)
@@ -181,6 +187,8 @@ func (s *Supervisor) Run(
 				// No more tunnels outstanding, clear backoff timer
 				backoff.SetGracePeriod()
 			}
+		case <-s.gracefulShutdownC:
+			shuttingDown = true
 		}
 	}
 }
@@ -203,6 +211,8 @@ func (s *Supervisor) initialize(
 		return ctx.Err()
 	case tunnelError := <-s.tunnelErrors:
 		return tunnelError.err
+	case <-s.gracefulShutdownC:
+		return errEarlyShutdown
 	case <-connectedSignal.Wait():
 	}
 	// At least one successful connection, so start the rest
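errEarlyShutdown is a sentinel error: initialize reports "shutdown began before the first connection registered" through the normal error path, and Run converts exactly that value into a clean nil return; meanwhile the shuttingDown flag stops tunnel errors from being retried once shutdown has started. A minimal sketch of the sentinel pattern with the supervisor logic stubbed out (names mirror the diff, surrounding code is hypothetical):

package main

import (
	"errors"
	"fmt"
)

var errEarlyShutdown = errors.New("shutdown started")

func initialize(gracefulShutdownC <-chan struct{}) error {
	select {
	case <-gracefulShutdownC:
		// Shutdown was requested before the first tunnel connected.
		return errEarlyShutdown
	default:
		return nil
	}
}

func run(gracefulShutdownC <-chan struct{}) error {
	if err := initialize(gracefulShutdownC); err != nil {
		if err == errEarlyShutdown {
			return nil // early shutdown is a clean exit, not a failure
		}
		return err
	}
	// ... supervise tunnels until shutdown ...
	return nil
}

func main() {
	shutdownC := make(chan struct{})
	close(shutdownC)            // shutdown already signalled
	fmt.Println(run(shutdownC)) // prints <nil>
}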
origin/tunnel.go
@@ -113,7 +113,7 @@ func StartTunnelDaemon(
 	config *TunnelConfig,
 	connectedSignal *signal.Signal,
 	reconnectCh chan ReconnectSignal,
-	graceShutdownC chan struct{},
+	graceShutdownC <-chan struct{},
 ) error {
 	s, err := NewSupervisor(config, reconnectCh, graceShutdownC)
 	if err != nil {
@@ -131,14 +131,14 @@ func ServeTunnelLoop(
 	connectedSignal *signal.Signal,
 	cloudflaredUUID uuid.UUID,
 	reconnectCh chan ReconnectSignal,
-	gracefulShutdownC chan struct{},
+	gracefulShutdownC <-chan struct{},
 ) error {
 	haConnections.Inc()
 	defer haConnections.Dec()
 
 	connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()
 
-	protocallFallback := &protocallFallback{
+	protocolFallback := &protocolFallback{
 		BackoffHandler{MaxRetries: config.Retries},
 		config.ProtocolSelector.Current(),
 		false,
@@ -162,82 +162,82 @@ func ServeTunnelLoop(
 			addr,
 			connIndex,
 			connectedFuse,
-			protocallFallback,
+			protocolFallback,
 			cloudflaredUUID,
 			reconnectCh,
-			protocallFallback.protocol,
+			protocolFallback.protocol,
 			gracefulShutdownC,
 		)
 		if !recoverable {
 			return err
 		}
 
-		err = waitForBackoff(ctx, &connLog, protocallFallback, config, connIndex, err)
-		if err != nil {
+		config.Observer.SendReconnect(connIndex)
+		duration, ok := protocolFallback.GetBackoffDuration(ctx)
+		if !ok {
+			return err
+		}
+		connLog.Info().Msgf("Retrying connection in %s seconds", duration)
+
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-gracefulShutdownC:
+			return nil
+		case <-protocolFallback.BackoffTimer():
+			if !selectNextProtocol(&connLog, protocolFallback, config.ProtocolSelector) {
 			return err
 		}
+		}
 	}
+}
 
-// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
+// protocolFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
 // max retries
-type protocallFallback struct {
+type protocolFallback struct {
 	BackoffHandler
 	protocol   connection.Protocol
 	inFallback bool
 }
 
-func (pf *protocallFallback) reset() {
+func (pf *protocolFallback) reset() {
 	pf.resetNow()
 	pf.inFallback = false
 }
 
-func (pf *protocallFallback) fallback(fallback connection.Protocol) {
+func (pf *protocolFallback) fallback(fallback connection.Protocol) {
 	pf.resetNow()
 	pf.protocol = fallback
 	pf.inFallback = true
}
 
-// Expect err to always be non nil
-func waitForBackoff(
-	ctx context.Context,
-	log *zerolog.Logger,
-	protobackoff *protocallFallback,
-	config *TunnelConfig,
-	connIndex uint8,
-	err error,
-) error {
-	duration, ok := protobackoff.GetBackoffDuration(ctx)
-	if !ok {
-		return err
-	}
-
-	config.Observer.SendReconnect(connIndex)
-	log.Info().
-		Err(err).
-		Uint8(connection.LogFieldConnIndex, connIndex).
-		Msgf("Retrying connection in %s seconds", duration)
-	protobackoff.Backoff(ctx)
-
-	if protobackoff.ReachedMaxRetries() {
-		fallback, hasFallback := config.ProtocolSelector.Fallback()
+// selectNextProtocol picks connection protocol for the next retry iteration,
+// returns true if it was able to pick the protocol, false if we are out of options and should stop retrying
+func selectNextProtocol(
+	connLog *zerolog.Logger,
+	protocolBackoff *protocolFallback,
+	selector connection.ProtocolSelector,
+) bool {
+	if protocolBackoff.ReachedMaxRetries() {
+		fallback, hasFallback := selector.Fallback()
 		if !hasFallback {
-			return err
+			return false
 		}
 		// Already using fallback protocol, no point to retry
-		if protobackoff.protocol == fallback {
-			return err
+		if protocolBackoff.protocol == fallback {
+			return false
 		}
-		log.Info().Msgf("Fallback to use %s", fallback)
-		protobackoff.fallback(fallback)
-	} else if !protobackoff.inFallback {
-		current := config.ProtocolSelector.Current()
-		if protobackoff.protocol != current {
-			protobackoff.protocol = current
-			config.Log.Info().Msgf("Change protocol to %s", current)
+		connLog.Info().Msgf("Switching to fallback protocol %s", fallback)
+		protocolBackoff.fallback(fallback)
+	} else if !protocolBackoff.inFallback {
+		current := selector.Current()
+		if protocolBackoff.protocol != current {
+			protocolBackoff.protocol = current
+			connLog.Info().Msgf("Changing protocol to %s", current)
 		}
 	}
-	return nil
+	return true
 }
 
 // ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown,
@@ -250,11 +250,11 @@ func ServeTunnel(
 	addr *net.TCPAddr,
 	connIndex uint8,
 	fuse *h2mux.BooleanFuse,
-	backoff *protocallFallback,
+	backoff *protocolFallback,
 	cloudflaredUUID uuid.UUID,
 	reconnectCh chan ReconnectSignal,
 	protocol connection.Protocol,
-	gracefulShutdownC chan struct{},
+	gracefulShutdownC <-chan struct{},
 ) (err error, recoverable bool) {
 	// Treat panics as recoverable errors
 	defer func() {
@@ -358,7 +358,7 @@ func ServeH2mux(
 	connectedFuse *connectedFuse,
 	cloudflaredUUID uuid.UUID,
 	reconnectCh chan ReconnectSignal,
-	gracefulShutdownC chan struct{},
+	gracefulShutdownC <-chan struct{},
 ) error {
 	connLog.Debug().Msgf("Connecting via h2mux")
 	// Returns error from parsing the origin URL or handshake errors
@@ -404,7 +404,7 @@ func ServeHTTP2(
 	connIndex uint8,
 	connectedFuse connection.ConnectedFuse,
 	reconnectCh chan ReconnectSignal,
-	gracefulShutdownC chan struct{},
+	gracefulShutdownC <-chan struct{},
 ) error {
 	connLog.Debug().Msgf("Connecting via http2")
 	h2conn := connection.NewHTTP2Connection(
@@ -435,7 +435,7 @@ func ServeHTTP2(
 	return errGroup.Wait()
 }
 
-func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) error {
+func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
 	select {
 	case reconnect := <-reconnectCh:
 		return reconnect
@@ -448,7 +448,7 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gr
 
 type connectedFuse struct {
 	fuse    *h2mux.BooleanFuse
-	backoff *protocallFallback
+	backoff *protocolFallback
 }
 
 func (cf *connectedFuse) Connected() {
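The retry wait that used to live inside waitForBackoff is now an explicit three-way select in ServeTunnelLoop, so a shutdown signal or context cancellation interrupts the backoff timer instead of the loop sleeping through it and reconnecting anyway. A minimal standalone sketch of that loop shape (the connect function and the fixed 10ms timer are stand-ins for the real dialing and backoff logic):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// serveLoop retries connect with a backoff, but gives up immediately on
// context cancellation and returns nil on graceful shutdown.
func serveLoop(ctx context.Context, gracefulShutdownC <-chan struct{}, connect func() error) error {
	for {
		err := connect()
		if err == nil {
			return nil
		}
		fmt.Println("retrying in 10ms:", err)
		select {
		case <-ctx.Done():
			return ctx.Err() // hard termination path
		case <-gracefulShutdownC:
			return nil // graceful shutdown: stop retrying, exit cleanly
		case <-time.After(10 * time.Millisecond): // backoff timer elapsed, retry
		}
	}
}

func main() {
	shutdownC := make(chan struct{})
	go func() {
		time.Sleep(25 * time.Millisecond)
		close(shutdownC)
	}()
	err := serveLoop(context.Background(), shutdownC, func() error {
		return errors.New("dial failed")
	})
	fmt.Println("loop returned:", err) // nil once shutdown is signalled
}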
@@ -1,8 +1,6 @@
 package origin
 
 import (
-	"context"
-	"fmt"
 	"testing"
 	"time"
 
@@ -25,13 +23,13 @@ func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
 		return dmf.percentage, nil
 	}
 }
 
 func TestWaitForBackoffFallback(t *testing.T) {
 	maxRetries := uint(3)
 	backoff := BackoffHandler{
 		MaxRetries: maxRetries,
 		BaseTime:   time.Millisecond * 10,
 	}
-	ctx := context.Background()
 	log := zerolog.Nop()
 	resolveTTL := time.Duration(0)
 	namedTunnel := &connection.NamedTunnelConfig{
@@ -50,18 +48,11 @@ func TestWaitForBackoffFallback(t *testing.T) {
 		&log,
 	)
 	assert.NoError(t, err)
-	config := &TunnelConfig{
-		Log:              &log,
-		LogTransport:     &log,
-		ProtocolSelector: protocolSelector,
-		Observer:         connection.NewObserver(&log, &log, false),
-	}
-	connIndex := uint8(1)
 
 	initProtocol := protocolSelector.Current()
 	assert.Equal(t, connection.HTTP2, initProtocol)
 
-	protocallFallback := &protocallFallback{
+	protocolFallback := &protocolFallback{
 		backoff,
 		initProtocol,
 		false,
@@ -69,29 +60,33 @@ func TestWaitForBackoffFallback(t *testing.T) {
 
 	// Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this
 	for i := 0; i < int(maxRetries-1); i++ {
-		err := waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error"))
-		assert.NoError(t, err)
-		assert.Equal(t, initProtocol, protocallFallback.protocol)
+		protocolFallback.BackoffTimer() // simulate retry
+		ok := selectNextProtocol(&log, protocolFallback, protocolSelector)
+		assert.True(t, ok)
+		assert.Equal(t, initProtocol, protocolFallback.protocol)
 	}
 
 	// Retry fallback protocol
 	for i := 0; i < int(maxRetries); i++ {
-		err := waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error"))
-		assert.NoError(t, err)
+		protocolFallback.BackoffTimer() // simulate retry
+		ok := selectNextProtocol(&log, protocolFallback, protocolSelector)
+		assert.True(t, ok)
 		fallback, ok := protocolSelector.Fallback()
 		assert.True(t, ok)
-		assert.Equal(t, fallback, protocallFallback.protocol)
+		assert.Equal(t, fallback, protocolFallback.protocol)
 	}
 
 	currentGlobalProtocol := protocolSelector.Current()
 	assert.Equal(t, initProtocol, currentGlobalProtocol)
 
 	// No protocol to fallback, return error
-	err = waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("some error"))
-	assert.Error(t, err)
+	protocolFallback.BackoffTimer() // simulate retry
+	ok := selectNextProtocol(&log, protocolFallback, protocolSelector)
+	assert.False(t, ok)
 
-	protocallFallback.reset()
-	err = waitForBackoff(ctx, &log, protocallFallback, config, connIndex, fmt.Errorf("new error"))
-	assert.NoError(t, err)
-	assert.Equal(t, initProtocol, protocallFallback.protocol)
+	protocolFallback.reset()
+	protocolFallback.BackoffTimer() // simulate retry
+	ok = selectNextProtocol(&log, protocolFallback, protocolSelector)
+	assert.True(t, ok)
+	assert.Equal(t, initProtocol, protocolFallback.protocol)
 }
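Splitting waitForBackoff into "wait" (the select in ServeTunnelLoop) and "decide" (selectNextProtocol) is what lets this test drop the context, the fake TunnelConfig, and the error plumbing: the decision half is pure apart from the backoff counter, so the test drives the counter with BackoffTimer() and asserts on the returned bool. A trimmed stand-in showing why a pure decision function is trivially testable (these types and names are illustrative, not the cloudflared ones):

package main

import "fmt"

// pickNext is a stand-in for selectNextProtocol: no timers, no I/O,
// just a decision over the retry state.
func pickNext(retries, maxRetries int, current, fallback string, hasFallback bool) (string, bool) {
	if retries < maxRetries {
		return current, true // keep retrying the current protocol
	}
	if !hasFallback || current == fallback {
		return current, false // out of options: caller stops retrying
	}
	return fallback, true // switch to the fallback protocol
}

func main() {
	proto, ok := pickNext(1, 3, "http2", "h2mux", true)
	fmt.Println(proto, ok) // http2 true: still under max retries
	proto, ok = pickNext(3, 3, "http2", "h2mux", true)
	fmt.Println(proto, ok) // h2mux true: fall back after max retries
	_, ok = pickNext(3, 3, "h2mux", "h2mux", true)
	fmt.Println(ok) // false: already on the fallback, stop retrying
}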