diff --git a/connection/quic.go b/connection/quic.go index 41266c2d..eb7513e8 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -49,12 +49,12 @@ func NewQUICConnection( ) (*QUICConnection, error) { session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) if err != nil { - return nil, errors.Wrap(err, "failed to dial to edge") + return nil, fmt.Errorf("failed to dial to edge: %w", err) } registrationStream, err := session.OpenStream() if err != nil { - return nil, errors.Wrap(err, "failed to open a registration stream") + return nil, fmt.Errorf("failed to open a registration stream: %w", err) } err = controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions, false) @@ -82,7 +82,7 @@ func (q *QUICConnection) Serve(ctx context.Context) error { if errors.Is(err, context.Canceled) { return nil } - return errors.Wrap(err, "failed to accept QUIC stream") + return fmt.Errorf("failed to accept QUIC stream: %w", err) } go func() { defer stream.Close() diff --git a/origin/tunnel.go b/origin/tunnel.go index 7d8581cd..3727ce50 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -190,7 +190,13 @@ func ServeTunnelLoop( case <-gracefulShutdownC: return nil case <-protocolFallback.BackoffTimer(): - if !selectNextProtocol(&connLog, protocolFallback, config.ProtocolSelector) { + var idleTimeoutError *quic.IdleTimeoutError + if !selectNextProtocol( + &connLog, + protocolFallback, + config.ProtocolSelector, + errors.As(err, &idleTimeoutError), + ) { return err } } @@ -222,8 +228,9 @@ func selectNextProtocol( connLog *zerolog.Logger, protocolBackoff *protocolFallback, selector connection.ProtocolSelector, + isNetworkActivityTimeout bool, ) bool { - if protocolBackoff.ReachedMaxRetries() { + if protocolBackoff.ReachedMaxRetries() || isNetworkActivityTimeout { fallback, hasFallback := selector.Fallback() if !hasFallback { return false diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index 28c2f9ee..90898abb 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -62,7 +62,7 @@ func TestWaitForBackoffFallback(t *testing.T) { // Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this for i := 0; i < int(maxRetries-1); i++ { protocolFallback.BackoffTimer() // simulate retry - ok := selectNextProtocol(&log, protocolFallback, protocolSelector) + ok := selectNextProtocol(&log, protocolFallback, protocolSelector, false) assert.True(t, ok) assert.Equal(t, initProtocol, protocolFallback.protocol) } @@ -70,7 +70,7 @@ func TestWaitForBackoffFallback(t *testing.T) { // Retry fallback protocol for i := 0; i < int(maxRetries); i++ { protocolFallback.BackoffTimer() // simulate retry - ok := selectNextProtocol(&log, protocolFallback, protocolSelector) + ok := selectNextProtocol(&log, protocolFallback, protocolSelector, false) assert.True(t, ok) fallback, ok := protocolSelector.Fallback() assert.True(t, ok) @@ -82,12 +82,19 @@ func TestWaitForBackoffFallback(t *testing.T) { // No protocol to fallback, return error protocolFallback.BackoffTimer() // simulate retry - ok := selectNextProtocol(&log, protocolFallback, protocolSelector) + ok := selectNextProtocol(&log, protocolFallback, protocolSelector, false) assert.False(t, ok) protocolFallback.reset() protocolFallback.BackoffTimer() // simulate retry - ok = selectNextProtocol(&log, protocolFallback, protocolSelector) + ok = selectNextProtocol(&log, protocolFallback, protocolSelector, false) assert.True(t, ok) assert.Equal(t, initProtocol, protocolFallback.protocol) + + protocolFallback.reset() + protocolFallback.BackoffTimer() // simulate retry + ok = selectNextProtocol(&log, protocolFallback, protocolSelector, true) + // Check that we get a true after the first try itself when this flag is true. This allows us to immediately + // switch protocols. + assert.True(t, ok) }