diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 5cbb7a0b..9bf43d25 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -231,7 +231,13 @@ func prepareTunnelConfig( } } - protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log) + var warpRoutingService *ingress.WarpRoutingService + warpRoutingEnabled := isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel) + if warpRoutingEnabled { + warpRoutingService = ingress.NewWarpRoutingService() + } + + protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log) if err != nil { return nil, ingress.Ingress{}, err } @@ -246,11 +252,6 @@ func prepareTunnelConfig( edgeTLSConfigs[p] = edgeTLSConfig } - var warpRoutingService *ingress.WarpRoutingService - if isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel, protocolSelector.Current()) { - warpRoutingService = ingress.NewWarpRoutingService() - } - originProxy := origin.NewOriginProxy(ingressRules, warpRoutingService, tags, log) connectionConfig := &connection.Config{ OriginProxy: originProxy, @@ -292,8 +293,8 @@ func prepareTunnelConfig( }, ingressRules, nil } -func isWarpRoutingEnabled(warpConfig config.WarpRoutingConfig, isNamedTunnel bool, protocol connection.Protocol) bool { - return warpConfig.Enabled && isNamedTunnel && protocol == connection.HTTP2 +func isWarpRoutingEnabled(warpConfig config.WarpRoutingConfig, isNamedTunnel bool) bool { + return warpConfig.Enabled && isNamedTunnel } func isRunningFromTerminal() bool { diff --git a/connection/protocol.go b/connection/protocol.go index fd7a95f4..2f6dea92 100644 --- a/connection/protocol.go +++ b/connection/protocol.go @@ -141,16 +141,29 @@ type PercentageFetcher func() (int32, error) func NewProtocolSelector( protocolFlag string, + warpRoutingEnabled bool, namedTunnel *NamedTunnelConfig, fetchFunc PercentageFetcher, ttl time.Duration, log *zerolog.Logger, ) (ProtocolSelector, error) { + // Classic tunnel is only supported with h2mux if namedTunnel == nil { return &staticProtocolSelector{ current: H2mux, }, nil } + + // warp routing can only be served over http2 connections + if warpRoutingEnabled { + if protocolFlag == H2mux.String() { + log.Warn().Msg("Warp routing is only supported by http2 protocol. Upgrading protocol to http2") + } + return &staticProtocolSelector{ + current: HTTP2, + }, nil + } + if protocolFlag == H2mux.String() { return &staticProtocolSelector{ current: H2mux, diff --git a/connection/protocol_test.go b/connection/protocol_test.go index 2100a32e..a119ec42 100644 --- a/connection/protocol_test.go +++ b/connection/protocol_test.go @@ -9,7 +9,8 @@ import ( ) const ( - testNoTTL = 0 + testNoTTL = 0 + noWarpRoutingEnabled = false ) var ( @@ -48,14 +49,15 @@ func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { func TestNewProtocolSelector(t *testing.T) { tests := []struct { - name string - protocol string - expectedProtocol Protocol - hasFallback bool - expectedFallback Protocol - namedTunnelConfig *NamedTunnelConfig - fetchFunc PercentageFetcher - wantErr bool + name string + protocol string + expectedProtocol Protocol + hasFallback bool + expectedFallback Protocol + warpRoutingEnabled bool + namedTunnelConfig *NamedTunnelConfig + fetchFunc PercentageFetcher + wantErr bool }{ { name: "classic tunnel", @@ -108,6 +110,36 @@ func TestNewProtocolSelector(t *testing.T) { fetchFunc: mockFetcher(100), namedTunnelConfig: testNamedTunnelConfig, }, + { + name: "warp routing requesting h2mux", + protocol: "h2mux", + expectedProtocol: HTTP2, + hasFallback: false, + expectedFallback: H2mux, + fetchFunc: mockFetcher(100), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "warp routing http2", + protocol: "http2", + expectedProtocol: HTTP2, + hasFallback: false, + expectedFallback: H2mux, + fetchFunc: mockFetcher(100), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "warp routing auto", + protocol: "auto", + expectedProtocol: HTTP2, + hasFallback: false, + expectedFallback: H2mux, + fetchFunc: mockFetcher(100), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, { // None named tunnel can only use h2mux, so specifying an unknown protocol is not an error name: "classic tunnel unknown protocol", @@ -131,7 +163,7 @@ func TestNewProtocolSelector(t *testing.T) { } for _, test := range tests { - selector, err := NewProtocolSelector(test.protocol, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) + selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) if test.wantErr { assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) } else { @@ -148,7 +180,7 @@ func TestNewProtocolSelector(t *testing.T) { func TestAutoProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} - selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) + selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, H2mux, selector.Current()) @@ -177,7 +209,7 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} - selector, err := NewProtocolSelector("http2", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) + selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) @@ -206,7 +238,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { func TestProtocolSelectorRefreshTTL(t *testing.T) { fetcher := dynamicMockFetcher{percentage: 100} - selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) + selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) assert.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index b7e284ab..d5696c41 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -40,8 +40,10 @@ func TestWaitForBackoffFallback(t *testing.T) { mockFetcher := dynamicMockFetcher{ percentage: 0, } + warpRoutingEnabled := false protocolSelector, err := connection.NewProtocolSelector( connection.HTTP2.String(), + warpRoutingEnabled, namedTunnel, mockFetcher.fetch(), resolveTTL,