package supervisor

import (
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/rs/zerolog"
	"github.com/stretchr/testify/assert"

	"github.com/cloudflare/cloudflared/connection"
	"github.com/cloudflare/cloudflared/edgediscovery"
	"github.com/cloudflare/cloudflared/retry"
)

// dynamicMockFetcher is a test double that serves a configurable set of
// protocol percentages (or an error) through an edgediscovery.PercentageFetcher.
type dynamicMockFetcher struct {
	protocolPercents edgediscovery.ProtocolPercents
	err              error
}

// fetch returns a PercentageFetcher that reads the mock's fields each time it
// is invoked, so changes made to the mock after fetch is called are visible to
// the returned function.
func (dmf *dynamicMockFetcher) fetch() edgediscovery.PercentageFetcher {
	return func() (edgediscovery.ProtocolPercents, error) {
		return dmf.protocolPercents, dmf.err
	}
}

// immediateTimeAfter is a stand-in for time.After that ignores the requested
// duration and returns a channel that already holds the current time, so a
// receive on it never blocks.
func immediateTimeAfter(time.Duration) <-chan time.Time {
	c := make(chan time.Time, 1)
	c <- time.Now()
	return c
}

// TestWaitForBackoffFallback drives selectNextProtocol through repeated
// simulated retries and checks when it keeps the current protocol, when it
// moves to the fallback protocol, and when it reports that nothing is left to
// try.
func TestWaitForBackoffFallback(t *testing.T) {
	maxRetries := uint(3)
	backoff := retry.NewBackoff(maxRetries, 40*time.Millisecond, false)
	// Replace the clock so that backoff timers fire without waiting.
	backoff.Clock.After = immediateTimeAfter
	log := zerolog.Nop()
	resolveTTL := 10 * time.Second
	mockFetcher := dynamicMockFetcher{
		protocolPercents: edgediscovery.ProtocolPercents{
			edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100},
		},
	}
	protocolSelector, err := connection.NewProtocolSelector(
		"auto",
		"",
		false,
		false,
		mockFetcher.fetch(),
		resolveTTL,
		&log,
	)
	// Stop here on failure: the selector is used immediately below.
	if err != nil {
		t.Fatalf("creating protocol selector: %v", err)
	}

	initProtocol := protocolSelector.Current()
	assert.Equal(t, connection.QUIC, initProtocol)

	protoFallback := &protocolFallback{
		backoff,
		initProtocol,
		false,
	}

	// Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this
	for i := 0; i < int(maxRetries-1); i++ {
		protoFallback.BackoffTimer() // simulate retry
		ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil)
		assert.True(t, ok)
		assert.Equal(t, initProtocol, protoFallback.protocol)
	}

	// Retry fallback protocol
	protoFallback.BackoffTimer() // simulate retry
	ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil)
	assert.True(t, ok)
	fallback, ok := protocolSelector.Fallback()
	assert.True(t, ok)
	assert.Equal(t, fallback, protoFallback.protocol)
	assert.Equal(t, connection.HTTP2, protoFallback.protocol)

	// The selector's own current protocol is expected to be unchanged by the fallback.
	currentGlobalProtocol := protocolSelector.Current()
	assert.Equal(t, initProtocol, currentGlobalProtocol)

	// Simulate max retries again (retries reset after protocol switch)
	for i := 0; i < int(maxRetries); i++ {
		protoFallback.BackoffTimer()
	}
	// No protocol to fallback, return error
	ok = selectNextProtocol(&log, protoFallback, protocolSelector, nil)
	assert.False(t, ok)

	protoFallback.reset()
	protoFallback.BackoffTimer() // simulate retry
	ok = selectNextProtocol(&log, protoFallback, protocolSelector, nil)
	assert.True(t, ok)
	assert.Equal(t, initProtocol, protoFallback.protocol)

	protoFallback.reset()
	protoFallback.BackoffTimer() // simulate retry
	ok = selectNextProtocol(&log, protoFallback, protocolSelector, &quic.IdleTimeoutError{})
	// Check that we get a true after the first try itself when the failure cause is a quic.IdleTimeoutError.
	// This allows us to immediately switch protocols when there is a fallback.
	assert.True(t, ok)

	// But if there is no fallback available, then we exhaust the retries despite the type of error.
	// The reason why there's no fallback available is because we pick a specific protocol instead of letting it be auto.
	protocolSelector, err = connection.NewProtocolSelector(
		"quic",
		"",
		false,
		false,
		mockFetcher.fetch(),
		resolveTTL,
		&log,
	)
	// Stop here on failure: the selector is used immediately below.
	if err != nil {
		t.Fatalf("creating protocol selector: %v", err)
	}
	protoFallback = &protocolFallback{backoff, protocolSelector.Current(), false}
	for i := 0; i < int(maxRetries-1); i++ {
		protoFallback.BackoffTimer() // simulate retry
		ok := selectNextProtocol(&log, protoFallback, protocolSelector, &quic.IdleTimeoutError{})
		assert.True(t, ok)
		assert.Equal(t, connection.QUIC, protoFallback.protocol)
	}
	// And finally it fails as it should, with no fallback.
	protoFallback.BackoffTimer()
	ok = selectNextProtocol(&log, protoFallback, protocolSelector, &quic.IdleTimeoutError{})
	assert.False(t, ok)
}