diff --git a/origin/tunnel.go b/origin/tunnel.go index 7f8ff02c..bcfea720 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -185,12 +185,11 @@ func ServeTunnelLoop( case <-gracefulShutdownC: return nil case <-protocolFallback.BackoffTimer(): - var idleTimeoutError *quic.IdleTimeoutError if !selectNextProtocol( connLog.Logger(), protocolFallback, config.ProtocolSelector, - errors.As(err, &idleTimeoutError), + err, ) { return err } @@ -223,9 +222,13 @@ func selectNextProtocol( connLog *zerolog.Logger, protocolBackoff *protocolFallback, selector connection.ProtocolSelector, - isNetworkActivityTimeout bool, + cause error, ) bool { - if protocolBackoff.ReachedMaxRetries() || isNetworkActivityTimeout { + var idleTimeoutError *quic.IdleTimeoutError + isNetworkActivityTimeout := errors.As(cause, &idleTimeoutError) + _, hasFallback := selector.Fallback() + + if protocolBackoff.ReachedMaxRetries() || (hasFallback && isNetworkActivityTimeout) { fallback, hasFallback := selector.Fallback() if !hasFallback { return false diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index 58108240..870a5049 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/lucas-clemente/quic-go" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -53,7 +54,7 @@ func TestWaitForBackoffFallback(t *testing.T) { initProtocol := protocolSelector.Current() assert.Equal(t, connection.HTTP2, initProtocol) - protocolFallback := &protocolFallback{ + protoFallback := &protocolFallback{ backoff, initProtocol, false, @@ -61,40 +62,63 @@ func TestWaitForBackoffFallback(t *testing.T) { // Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this for i := 0; i < int(maxRetries-1); i++ { - protocolFallback.BackoffTimer() // simulate retry - ok := selectNextProtocol(&log, protocolFallback, protocolSelector, false) + protoFallback.BackoffTimer() // simulate retry + ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil) assert.True(t, ok) - assert.Equal(t, initProtocol, protocolFallback.protocol) + assert.Equal(t, initProtocol, protoFallback.protocol) } // Retry fallback protocol for i := 0; i < int(maxRetries); i++ { - protocolFallback.BackoffTimer() // simulate retry - ok := selectNextProtocol(&log, protocolFallback, protocolSelector, false) + protoFallback.BackoffTimer() // simulate retry + ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil) assert.True(t, ok) fallback, ok := protocolSelector.Fallback() assert.True(t, ok) - assert.Equal(t, fallback, protocolFallback.protocol) + assert.Equal(t, fallback, protoFallback.protocol) } currentGlobalProtocol := protocolSelector.Current() assert.Equal(t, initProtocol, currentGlobalProtocol) // No protocol to fallback, return error - protocolFallback.BackoffTimer() // simulate retry - ok := selectNextProtocol(&log, protocolFallback, protocolSelector, false) + protoFallback.BackoffTimer() // simulate retry + ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil) assert.False(t, ok) - protocolFallback.reset() - protocolFallback.BackoffTimer() // simulate retry - ok = selectNextProtocol(&log, protocolFallback, protocolSelector, false) + protoFallback.reset() + protoFallback.BackoffTimer() // simulate retry + ok = selectNextProtocol(&log, protoFallback, protocolSelector, nil) assert.True(t, ok) - assert.Equal(t, initProtocol, protocolFallback.protocol) + assert.Equal(t, initProtocol, protoFallback.protocol) - protocolFallback.reset() - protocolFallback.BackoffTimer() // simulate retry - ok = selectNextProtocol(&log, protocolFallback, protocolSelector, true) + protoFallback.reset() + protoFallback.BackoffTimer() // simulate retry + ok = selectNextProtocol(&log, protoFallback, protocolSelector, &quic.IdleTimeoutError{}) // Check that we get a true after the first try itself when this flag is true. This allows us to immediately - // switch protocols. + // switch protocols when there is a fallback. assert.True(t, ok) + + // But if there is no fallback available, then we exhaust the retries despite the type of error. + // The reason why there's no fallback available is because we pick a specific protocol instead of letting it be auto. + protocolSelector, err = connection.NewProtocolSelector( + "quic", + warpRoutingEnabled, + namedTunnel, + mockFetcher.fetch(), + resolveTTL, + &log, + ) + assert.NoError(t, err) + protoFallback = &protocolFallback{backoff, protocolSelector.Current(), false} + for i := 0; i < int(maxRetries-1); i++ { + protoFallback.BackoffTimer() // simulate retry + ok := selectNextProtocol(&log, protoFallback, protocolSelector, &quic.IdleTimeoutError{}) + assert.True(t, ok) + assert.Equal(t, connection.QUIC, protoFallback.protocol) + } + // And finally it fails as it should, with no fallback. + protoFallback.BackoffTimer() + ok = selectNextProtocol(&log, protoFallback, protocolSelector, &quic.IdleTimeoutError{}) + assert.False(t, ok) }