TUN-8427: Fix BackoffHandler's internally shared clock structure

A clock structure was used to help support unit testing timetravel
but it is a globally shared object and is likely unsafe to share
across tests. Reordering of the tests seemed to have intermittent
failures for the TestWaitForBackoffFallback specifically on windows
builds.

Adjusting this to be a shim inside the BackoffHandler struct should
resolve shared object overrides in unit testing.

Additionally, added the reset retries functionality to be inline with
the ResetNow function of the BackoffHandler to align better with
expected functionality of the method.

Removes unused reconnectCredentialManager.
This commit is contained in:
Devin Carr 2024-05-23 09:48:34 -07:00
parent 2db00211f5
commit 8184bc457d
8 changed files with 83 additions and 335 deletions

View File

@ -6,17 +6,16 @@ import (
"time"
)
const (
DefaultBaseTime time.Duration = time.Second
)
// Redeclare time functions so they can be overridden in tests.
type clock struct {
type Clock struct {
Now func() time.Time
After func(d time.Duration) <-chan time.Time
}
var Clock = clock{
Now: time.Now,
After: time.After,
}
// BackoffHandler manages exponential backoff and limits the maximum number of retries.
// The base time period is 1 second, doubling with each retry.
// After initial success, a grace period can be set to reset the backoff timer if
@ -25,15 +24,26 @@ var Clock = clock{
type BackoffHandler struct {
// MaxRetries sets the maximum number of retries to perform. The default value
// of 0 disables retry completely.
MaxRetries uint
maxRetries uint
// RetryForever caps the exponential backoff period according to MaxRetries
// but allows you to retry indefinitely.
RetryForever bool
retryForever bool
// BaseTime sets the initial backoff period.
BaseTime time.Duration
baseTime time.Duration
retries uint
resetDeadline time.Time
Clock Clock
}
func NewBackoff(maxRetries uint, baseTime time.Duration, retryForever bool) BackoffHandler {
return BackoffHandler{
maxRetries: maxRetries,
baseTime: baseTime,
retryForever: retryForever,
Clock: Clock{Now: time.Now, After: time.After},
}
}
func (b BackoffHandler) GetMaxBackoffDuration(ctx context.Context) (time.Duration, bool) {
@ -44,11 +54,11 @@ func (b BackoffHandler) GetMaxBackoffDuration(ctx context.Context) (time.Duratio
return time.Duration(0), false
default:
}
if !b.resetDeadline.IsZero() && Clock.Now().After(b.resetDeadline) {
if !b.resetDeadline.IsZero() && b.Clock.Now().After(b.resetDeadline) {
// b.retries would be set to 0 at this point
return time.Second, true
}
if b.retries >= b.MaxRetries && !b.RetryForever {
if b.retries >= b.maxRetries && !b.retryForever {
return time.Duration(0), false
}
maxTimeToWait := b.GetBaseTime() * 1 << (b.retries + 1)
@ -58,12 +68,12 @@ func (b BackoffHandler) GetMaxBackoffDuration(ctx context.Context) (time.Duratio
// BackoffTimer returns a channel that sends the current time when the exponential backoff timeout expires.
// Returns nil if the maximum number of retries have been used.
func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
if !b.resetDeadline.IsZero() && Clock.Now().After(b.resetDeadline) {
if !b.resetDeadline.IsZero() && b.Clock.Now().After(b.resetDeadline) {
b.retries = 0
b.resetDeadline = time.Time{}
}
if b.retries >= b.MaxRetries {
if !b.RetryForever {
if b.retries >= b.maxRetries {
if !b.retryForever {
return nil
}
} else {
@ -71,7 +81,7 @@ func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
}
maxTimeToWait := time.Duration(b.GetBaseTime() * 1 << (b.retries))
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
return Clock.After(timeToWait)
return b.Clock.After(timeToWait)
}
// Backoff is used to wait according to exponential backoff. Returns false if the
@ -94,16 +104,16 @@ func (b *BackoffHandler) Backoff(ctx context.Context) bool {
func (b *BackoffHandler) SetGracePeriod() time.Duration {
maxTimeToWait := b.GetBaseTime() * 2 << (b.retries + 1)
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
b.resetDeadline = Clock.Now().Add(timeToWait)
b.resetDeadline = b.Clock.Now().Add(timeToWait)
return timeToWait
}
func (b BackoffHandler) GetBaseTime() time.Duration {
if b.BaseTime == 0 {
return time.Second
if b.baseTime == 0 {
return DefaultBaseTime
}
return b.BaseTime
return b.baseTime
}
// Retries returns the number of retries consumed so far.
@ -112,9 +122,10 @@ func (b *BackoffHandler) Retries() int {
}
func (b *BackoffHandler) ReachedMaxRetries() bool {
return b.retries == b.MaxRetries
return b.retries == b.maxRetries
}
func (b *BackoffHandler) ResetNow() {
b.resetDeadline = time.Now()
b.resetDeadline = b.Clock.Now()
b.retries = 0
}

View File

@ -13,10 +13,9 @@ func immediateTimeAfter(time.Duration) <-chan time.Time {
}
func TestBackoffRetries(t *testing.T) {
// make backoff return immediately
Clock.After = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
// make backoff return immediately
backoff := BackoffHandler{maxRetries: 3, Clock: Clock{time.Now, immediateTimeAfter}}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed immediately")
}
@ -32,10 +31,10 @@ func TestBackoffRetries(t *testing.T) {
}
func TestBackoffCancel(t *testing.T) {
// prevent backoff from returning normally
Clock.After = func(time.Duration) <-chan time.Time { return make(chan time.Time) }
ctx, cancelFunc := context.WithCancel(context.Background())
backoff := BackoffHandler{MaxRetries: 3}
// prevent backoff from returning normally
after := func(time.Duration) <-chan time.Time { return make(chan time.Time) }
backoff := BackoffHandler{maxRetries: 3, Clock: Clock{time.Now, after}}
cancelFunc()
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after cancel")
@ -46,13 +45,12 @@ func TestBackoffCancel(t *testing.T) {
}
func TestBackoffGracePeriod(t *testing.T) {
ctx := context.Background()
currentTime := time.Now()
// make Clock.Now return whatever we like
Clock.Now = func() time.Time { return currentTime }
now := func() time.Time { return currentTime }
// make backoff return immediately
Clock.After = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 1}
backoff := BackoffHandler{maxRetries: 1, Clock: Clock{now, immediateTimeAfter}}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed immediately")
}
@ -70,10 +68,9 @@ func TestBackoffGracePeriod(t *testing.T) {
}
func TestGetMaxBackoffDurationRetries(t *testing.T) {
// make backoff return immediately
Clock.After = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
// make backoff return immediately
backoff := BackoffHandler{maxRetries: 3, Clock: Clock{time.Now, immediateTimeAfter}}
if _, ok := backoff.GetMaxBackoffDuration(ctx); !ok {
t.Fatalf("backoff failed immediately")
}
@ -95,10 +92,9 @@ func TestGetMaxBackoffDurationRetries(t *testing.T) {
}
func TestGetMaxBackoffDuration(t *testing.T) {
// make backoff return immediately
Clock.After = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
// make backoff return immediately
backoff := BackoffHandler{maxRetries: 3, Clock: Clock{time.Now, immediateTimeAfter}}
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 {
t.Fatalf("backoff (%s) didn't return < 2 seconds on first retry", duration)
}
@ -117,10 +113,9 @@ func TestGetMaxBackoffDuration(t *testing.T) {
}
func TestBackoffRetryForever(t *testing.T) {
// make backoff return immediately
Clock.After = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3, RetryForever: true}
// make backoff return immediately
backoff := BackoffHandler{maxRetries: 3, retryForever: true, Clock: Clock{time.Now, immediateTimeAfter}}
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 {
t.Fatalf("backoff (%s) didn't return < 2 seconds on first retry", duration)
}

View File

@ -1,138 +0,0 @@
package supervisor
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/cloudflare/cloudflared/retry"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
var (
errJWTUnset = errors.New("JWT unset")
)
// reconnectTunnelCredentialManager is invoked by functions in tunnel.go to
// get/set parameters for ReconnectTunnel RPC calls.
type reconnectCredentialManager struct {
mu sync.RWMutex
jwt []byte
eventDigest map[uint8][]byte
connDigest map[uint8][]byte
authSuccess prometheus.Counter
authFail *prometheus.CounterVec
}
func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
authSuccess := prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "tunnel_authenticate_success",
Help: "Count of successful tunnel authenticate",
},
)
authFail := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "tunnel_authenticate_fail",
Help: "Count of tunnel authenticate errors by type",
},
[]string{"error"},
)
prometheus.MustRegister(authSuccess, authFail)
return &reconnectCredentialManager{
eventDigest: make(map[uint8][]byte, haConnections),
connDigest: make(map[uint8][]byte, haConnections),
authSuccess: authSuccess,
authFail: authFail,
}
}
func (cm *reconnectCredentialManager) ReconnectToken() ([]byte, error) {
cm.mu.RLock()
defer cm.mu.RUnlock()
if cm.jwt == nil {
return nil, errJWTUnset
}
return cm.jwt, nil
}
func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.jwt = jwt
}
func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
cm.mu.RLock()
defer cm.mu.RUnlock()
digest, ok := cm.eventDigest[connID]
if !ok {
return nil, fmt.Errorf("no event digest for connection %v", connID)
}
return digest, nil
}
func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.eventDigest[connID] = digest
}
func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
cm.mu.RLock()
defer cm.mu.RUnlock()
digest, ok := cm.connDigest[connID]
if !ok {
return nil, fmt.Errorf("no connection digest for connection %v", connID)
}
return digest, nil
}
func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.connDigest[connID] = digest
}
func (cm *reconnectCredentialManager) RefreshAuth(
ctx context.Context,
backoff *retry.BackoffHandler,
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
authOutcome, err := authenticate(ctx, backoff.Retries())
if err != nil {
cm.authFail.WithLabelValues(err.Error()).Inc()
if _, ok := backoff.GetMaxBackoffDuration(ctx); ok {
return backoff.BackoffTimer(), nil
}
return nil, err
}
// clear backoff timer
backoff.SetGracePeriod()
switch outcome := authOutcome.(type) {
case tunnelpogs.AuthSuccess:
cm.SetReconnectToken(outcome.JWT())
cm.authSuccess.Inc()
return retry.Clock.After(outcome.RefreshAfter()), nil
case tunnelpogs.AuthUnknown:
duration := outcome.RefreshAfter()
cm.authFail.WithLabelValues(outcome.Error()).Inc()
return retry.Clock.After(duration), nil
case tunnelpogs.AuthFail:
cm.authFail.WithLabelValues(outcome.Error()).Inc()
return nil, outcome
default:
err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
cm.authFail.WithLabelValues(err.Error()).Inc()
return nil, err
}
}

View File

@ -1,120 +0,0 @@
package supervisor
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/retry"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
func TestRefreshAuthBackoff(t *testing.T) {
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
var wait time.Duration
retry.Clock.After = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
backoff := &retry.BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return nil, fmt.Errorf("authentication failure")
}
// authentication failures should consume the backoff
for i := uint(0); i < backoff.MaxRetries; i++ {
retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
require.NoError(t, err)
require.NotNil(t, retryChan)
require.Greater(t, wait.Seconds(), 0.0)
require.Less(t, wait.Seconds(), float64((1<<(i+1))*time.Second))
}
retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
require.Error(t, err)
require.Nil(t, retryChan)
// now we actually make contact with the remote server
_, _ = rcm.RefreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
})
// The backoff timer should have been reset. To confirm this, make timeNow
// return a value after the backoff timer's grace period
retry.Clock.Now = func() time.Time {
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
return time.Now().Add(expectedGracePeriod * 2)
}
_, ok := backoff.GetMaxBackoffDuration(context.Background())
require.True(t, ok)
}
func TestRefreshAuthSuccess(t *testing.T) {
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
var wait time.Duration
retry.Clock.After = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
backoff := &retry.BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
}
retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
assert.NoError(t, err)
assert.NotNil(t, retryChan)
assert.Equal(t, 19*time.Hour, wait)
token, err := rcm.ReconnectToken()
assert.NoError(t, err)
assert.Equal(t, []byte("jwt"), token)
}
func TestRefreshAuthUnknown(t *testing.T) {
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
var wait time.Duration
retry.Clock.After = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
backoff := &retry.BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
}
retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
assert.NoError(t, err)
assert.NotNil(t, retryChan)
assert.Equal(t, 19*time.Hour, wait)
token, err := rcm.ReconnectToken()
assert.Equal(t, errJWTUnset, err)
assert.Nil(t, token)
}
func TestRefreshAuthFail(t *testing.T) {
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
backoff := &retry.BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
}
retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
assert.Error(t, err)
assert.Nil(t, retryChan)
token, err := rcm.ReconnectToken()
assert.Equal(t, errJWTUnset, err)
assert.Nil(t, token)
}

View File

@ -49,8 +49,6 @@ type Supervisor struct {
log *ConnAwareLogger
logTransport *zerolog.Logger
reconnectCredentialManager *reconnectCredentialManager
reconnectCh chan ReconnectSignal
gracefulShutdownC <-chan struct{}
}
@ -76,8 +74,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
return nil, err
}
reconnectCredentialManager := newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections)
tracker := tunnelstate.NewConnTracker(config.Log)
log := NewConnAwareLogger(config.Log, tracker, config.Observer)
@ -87,7 +83,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
edgeTunnelServer := EdgeTunnelServer{
config: config,
orchestrator: orchestrator,
credentialManager: reconnectCredentialManager,
edgeAddrs: edgeIPs,
edgeAddrHandler: edgeAddrHandler,
edgeBindAddr: edgeBindAddr,
@ -107,7 +102,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
tunnelsProtocolFallback: map[int]*protocolFallback{},
log: log,
logTransport: config.LogTransport,
reconnectCredentialManager: reconnectCredentialManager,
reconnectCh: reconnectCh,
gracefulShutdownC: gracefulShutdownC,
}, nil
@ -138,7 +132,7 @@ func (s *Supervisor) Run(
var tunnelsWaiting []int
tunnelsActive := s.config.HAConnections
backoff := retry.BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
backoff := retry.NewBackoff(s.config.Retries, tunnelRetryDuration, true)
var backoffTimer <-chan time.Time
shuttingDown := false
@ -212,7 +206,7 @@ func (s *Supervisor) initialize(
s.config.HAConnections = availableAddrs
}
s.tunnelsProtocolFallback[0] = &protocolFallback{
retry.BackoffHandler{MaxRetries: s.config.Retries, RetryForever: true},
retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
s.config.ProtocolSelector.Current(),
false,
}
@ -234,7 +228,7 @@ func (s *Supervisor) initialize(
// At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ {
s.tunnelsProtocolFallback[i] = &protocolFallback{
retry.BackoffHandler{MaxRetries: s.config.Retries, RetryForever: true},
retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
// Set the protocol we know the first tunnel connected with.
s.tunnelsProtocolFallback[0].protocol,
false,

View File

@ -203,7 +203,6 @@ func (f *ipAddrFallback) ShouldGetNewAddress(connIndex uint8, err error) (needsN
type EdgeTunnelServer struct {
config *TunnelConfig
orchestrator *orchestration.Orchestrator
credentialManager *reconnectCredentialManager
edgeAddrHandler EdgeAddrHandler
edgeAddrs *edgediscovery.Edge
edgeBindAddr net.IP

View File

@ -24,14 +24,18 @@ func (dmf *dynamicMockFetcher) fetch() edgediscovery.PercentageFetcher {
}
}
func immediateTimeAfter(time.Duration) <-chan time.Time {
c := make(chan time.Time, 1)
c <- time.Now()
return c
}
func TestWaitForBackoffFallback(t *testing.T) {
maxRetries := uint(3)
backoff := retry.BackoffHandler{
MaxRetries: maxRetries,
BaseTime: time.Millisecond * 10,
}
backoff := retry.NewBackoff(maxRetries, 40*time.Millisecond, false)
backoff.Clock.After = immediateTimeAfter
log := zerolog.Nop()
resolveTTL := time.Duration(0)
resolveTTL := 10 * time.Second
mockFetcher := dynamicMockFetcher{
protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}},
}
@ -64,21 +68,23 @@ func TestWaitForBackoffFallback(t *testing.T) {
}
// Retry fallback protocol
for i := 0; i < int(maxRetries); i++ {
protoFallback.BackoffTimer() // simulate retry
ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil)
assert.True(t, ok)
fallback, ok := protocolSelector.Fallback()
assert.True(t, ok)
assert.Equal(t, fallback, protoFallback.protocol)
}
assert.Equal(t, connection.HTTP2, protoFallback.protocol)
currentGlobalProtocol := protocolSelector.Current()
assert.Equal(t, initProtocol, currentGlobalProtocol)
// Simulate max retries again (retries reset after protocol switch)
for i := 0; i < int(maxRetries); i++ {
protoFallback.BackoffTimer()
}
// No protocol to fallback, return error
protoFallback.BackoffTimer() // simulate retry
ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil)
ok = selectNextProtocol(&log, protoFallback, protocolSelector, nil)
assert.False(t, ok)
protoFallback.reset()

View File

@ -94,9 +94,10 @@ func errDeleteTokenFailed(lockFilePath string) error {
// newLock will get a new file lock
func newLock(path string) *lock {
lockPath := path + ".lock"
backoff := retry.NewBackoff(uint(7), retry.DefaultBaseTime, false)
return &lock{
lockFilePath: lockPath,
backoff: &retry.BackoffHandler{MaxRetries: 7},
backoff: &backoff,
sigHandler: &signalHandler{
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
},