package connection

import (
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"

	"github.com/cloudflare/cloudflared/edgediscovery"
)

const (
	// testNoTTL is passed as the selector TTL. With it, the auto selector picks up
	// new fetcher results on every Current() call (TestAutoProtocolSelectorRefresh
	// relies on this).
	testNoTTL            = 0
	noWarpRoutingEnabled = false
)

// testNamedTunnelProperties is a minimal named-tunnel configuration; only the
// account tag is populated.
var testNamedTunnelProperties = &NamedTunnelProperties{
	Credentials: Credentials{
		AccountTag: "testAccountTag",
	},
}

// mockFetcher returns a PercentageFetcher that always yields protocolPercent,
// or always fails when getError is true.
func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher {
	return func() (edgediscovery.ProtocolPercents, error) {
		if getError {
			return nil, fmt.Errorf("failed to fetch percentage")
		}
		return protocolPercent, nil
	}
}

// dynamicMockFetcher serves whatever protocolPercents and err currently hold.
// Tests use it to change the fetch result between calls to a selector.
type dynamicMockFetcher struct {
	protocolPercents edgediscovery.ProtocolPercents
	err              error
}

// fetch returns a PercentageFetcher bound to dmf. Because the closure captures
// the pointer, later writes to dmf's fields are seen by the returned function.
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
	return func() (edgediscovery.ProtocolPercents, error) {
		return dmf.protocolPercents, dmf.err
	}
}

func TestNewProtocolSelector(t *testing.T) {
	tests := []struct {
		name               string
		protocol           string
		expectedProtocol   Protocol
		hasFallback        bool
		expectedFallback   Protocol
		warpRoutingEnabled bool
		namedTunnelConfig  *NamedTunnelProperties
		fetchFunc          PercentageFetcher
		wantErr            bool
	}{
		{
			name:              "classic tunnel",
			protocol:          "h2mux",
			expectedProtocol:  H2mux,
			namedTunnelConfig: nil,
		},
		{
			name:             "named tunnel over h2mux",
			protocol:         "h2mux",
			expectedProtocol: H2mux,
			fetchFunc: func() (edgediscovery.ProtocolPercents, error) {
				return nil, nil
			},
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:              "named tunnel over http2",
			protocol:          "http2",
			expectedProtocol:  HTTP2,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:              "named tunnel http2 disabled still gets http2 because it is manually picked",
			protocol:          "http2",
			expectedProtocol:  HTTP2,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:              "named tunnel quic disabled still gets quic because it is manually picked",
			protocol:          "quic",
			expectedProtocol:  QUIC,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:              "named tunnel quic and http2 disabled",
			protocol:          AutoSelectFlag,
			expectedProtocol:  H2mux,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:             "named tunnel quic disabled",
			protocol:         AutoSelectFlag,
			expectedProtocol: HTTP2,
			// hasFallback is true because, if http2 fails, the selector falls back further to h2mux.
			hasFallback:       true,
			expectedFallback:  H2mux,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:              "named tunnel auto all http2 disabled",
			protocol:          AutoSelectFlag,
			expectedProtocol:  H2mux,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:              "named tunnel auto to h2mux",
			protocol:          AutoSelectFlag,
			expectedProtocol:  H2mux,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:              "named tunnel auto to http2",
			protocol:          AutoSelectFlag,
			expectedProtocol:  HTTP2,
			hasFallback:       true,
			expectedFallback:  H2mux,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:              "named tunnel auto to quic",
			protocol:          AutoSelectFlag,
			expectedProtocol:  QUIC,
			hasFallback:       true,
			expectedFallback:  HTTP2,
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
			namedTunnelConfig: testNamedTunnelProperties,
		},
		{
			name:               "warp routing requesting h2mux",
			protocol:           "h2mux",
			expectedProtocol:   HTTP2Warp,
			hasFallback:        false,
			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelProperties,
		},
		{
			name:               "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1",
			protocol:           "h2mux",
			expectedProtocol:   HTTP2Warp,
			hasFallback:        false,
			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelProperties,
		},
		{
			name:               "warp routing http2",
			protocol:           "http2",
			expectedProtocol:   HTTP2Warp,
			hasFallback:        false,
			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelProperties,
		},
		{
			name:               "warp routing quic",
			protocol:           AutoSelectFlag,
			expectedProtocol:   QUICWarp,
			hasFallback:        true,
			expectedFallback:   HTTP2Warp,
			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelProperties,
		},
		{
			name:               "warp routing auto",
			protocol:           AutoSelectFlag,
			expectedProtocol:   HTTP2Warp,
			hasFallback:        false,
			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelProperties,
		},
		{
			name:               "warp routing auto- quic",
			protocol:           AutoSelectFlag,
			expectedProtocol:   QUICWarp,
			hasFallback:        true,
			expectedFallback:   HTTP2Warp,
			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
			warpRoutingEnabled: true,
			namedTunnelConfig:  testNamedTunnelProperties,
		},
		{
			// Classic (non-named) tunnels can only use h2mux, so an unknown protocol is not an error.
			name:             "classic tunnel unknown protocol",
			protocol:         "unknown",
			expectedProtocol: H2mux,
		},
		{
			name:              "named tunnel unknown protocol",
			protocol:          "unknown",
			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
			namedTunnelConfig: testNamedTunnelProperties,
			wantErr:           true,
		},
		{
			name:              "named tunnel fetch error",
			protocol:          AutoSelectFlag,
			fetchFunc:         mockFetcher(true),
			namedTunnelConfig: testNamedTunnelProperties,
			expectedProtocol:  HTTP2,
			wantErr:           false,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log, false)
			if test.wantErr {
				assert.Error(t, err)
				return
			}
			// Stop on error: the selector may be nil, and calling its methods would
			// panic and hide the real failure.
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, test.expectedProtocol, selector.Current())
			fallback, ok := selector.Fallback()
			assert.Equal(t, test.hasFallback, ok)
			if test.hasFallback {
				assert.Equal(t, test.expectedFallback, fallback)
			}
		})
	}
}

func TestAutoProtocolSelectorRefresh(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	selector, err := NewProtocolSelector(AutoSelectFlag, noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log, false)
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, H2mux, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
	assert.Equal(t, H2mux, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
	assert.Equal(t, HTTP2, selector.Current())

	// A failed fetch keeps the previously selected protocol.
	fetcher.err = fmt.Errorf("failed to fetch")
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
	fetcher.err = nil
	assert.Equal(t, H2mux, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
	assert.Equal(t, H2mux, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
	assert.Equal(t, QUIC, selector.Current())
}

func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	// Since the user chooses http2 on purpose, we always stick to it.
	selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log, false)
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.err = fmt.Errorf("failed to fetch")
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
	fetcher.err = nil
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
	assert.Equal(t, HTTP2, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
	assert.Equal(t, HTTP2, selector.Current())
}

func TestProtocolSelectorRefreshTTL(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
	selector, err := NewProtocolSelector(AutoSelectFlag, noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), time.Hour, &log, false)
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, QUIC, selector.Current())

	// With a one-hour TTL the changed percentage must not be picked up yet.
	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}}
	assert.Equal(t, QUIC, selector.Current())
}