diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 80c37145..8012f9af 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -21,11 +21,10 @@ import ( "golang.org/x/crypto/ssh/terminal" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" - "github.com/cloudflare/cloudflared/edgediscovery/allregions" - "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" + "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/orchestration" @@ -218,34 +217,13 @@ func prepareTunnelConfig( transportProtocol = connection.QUIC.String() } - protocolFetcher := edgediscovery.ProtocolPercentage - - features := append(c.StringSlice("features"), defaultFeatures...) + features := dedup(append(c.StringSlice("features"), defaultFeatures...)) if needPQ { features = append(features, supervisor.FeaturePostQuantum) } - if c.IsSet(TunnelTokenFlag) { - if transportProtocol == connection.AutoSelectFlag { - protocolFetcher = func() (edgediscovery.ProtocolPercents, error) { - // If the Tunnel is remotely managed and no protocol is set, we prefer QUIC, but still allow fall-back. - preferQuic := []edgediscovery.ProtocolPercent{ - { - Protocol: connection.QUIC.String(), - Percentage: 100, - }, - { - Protocol: connection.HTTP2.String(), - Percentage: 100, - }, - } - return preferQuic, nil - } - } - log.Info().Msg("Will be fetching remotely managed configuration from Cloudflare API. Defaulting to protocol: quic") - } namedTunnel.Client = tunnelpogs.ClientInfo{ ClientID: clientID[:], - Features: dedup(features), + Features: features, Version: info.Version(), Arch: info.OSArch(), } @@ -268,7 +246,7 @@ func prepareTunnelConfig( } } - protocolSelector, err := connection.NewProtocolSelector(transportProtocol, cfg.WarpRouting.Enabled, namedTunnel, protocolFetcher, supervisor.ResolveTTL, log, c.Bool("post-quantum")) + protocolSelector, err := connection.NewProtocolSelector(transportProtocol, namedTunnel.Credentials.AccountTag, c.IsSet(TunnelTokenFlag), c.Bool("post-quantum"), edgediscovery.ProtocolPercentage, connection.ResolveTTL, log) if err != nil { return nil, nil, err } diff --git a/connection/protocol.go b/connection/protocol.go index ddf9bb76..21907155 100644 --- a/connection/protocol.go +++ b/connection/protocol.go @@ -1,7 +1,6 @@ package connection import ( - "errors" "fmt" "hash/fnv" "sync" @@ -13,7 +12,7 @@ import ( ) const ( - AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge; 'h2mux' - Cloudflare's implementation of HTTP/2, deprecated" + AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge" // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge edgeH2muxTLSServerName = "cftunnel.com" // edgeH2TLSServerName is the server name to establish http2 connection with edge @@ -21,43 +20,31 @@ const ( // edgeQUICServerName is the server name to establish quic connection with edge. edgeQUICServerName = "quic.cftunnel.com" AutoSelectFlag = "auto" + // SRV and TXT record resolution TTL + ResolveTTL = time.Hour ) var ( // ProtocolList represents a list of supported protocols for communication with the edge. - ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp} + ProtocolList = []Protocol{HTTP2, QUIC} ) type Protocol int64 const ( - // H2mux protocol can be used both with Classic and Named Tunnels. . - H2mux Protocol = iota - // HTTP2 is used only with named tunnels. It's more efficient than H2mux for L4 proxying. - HTTP2 - // QUIC is used only with named tunnels. + // HTTP2 using golang HTTP2 library for edge connections. + HTTP2 Protocol = iota + // QUIC using quic-go for edge connections. QUIC - // HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to - // H2mux on HTTP2 failure to connect. - HTTP2Warp - //QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but - // don't want HTTP2 to fallback to H2mux - QUICWarp ) // Fallback returns the fallback protocol and whether the protocol has a fallback func (p Protocol) fallback() (Protocol, bool) { switch p { - case H2mux: - return 0, false case HTTP2: - return H2mux, true - case HTTP2Warp: return 0, false case QUIC: return HTTP2, true - case QUICWarp: - return HTTP2Warp, true default: return 0, false } @@ -65,11 +52,9 @@ func (p Protocol) fallback() (Protocol, bool) { func (p Protocol) String() string { switch p { - case H2mux: - return "h2mux" - case HTTP2, HTTP2Warp: + case HTTP2: return "http2" - case QUIC, QUICWarp: + case QUIC: return "quic" default: return fmt.Sprintf("unknown protocol") @@ -78,15 +63,11 @@ func (p Protocol) String() string { func (p Protocol) TLSSettings() *TLSSettings { switch p { - case H2mux: - return &TLSSettings{ - ServerName: edgeH2muxTLSServerName, - } - case HTTP2, HTTP2Warp: + case HTTP2: return &TLSSettings{ ServerName: edgeH2TLSServerName, } - case QUIC, QUICWarp: + case QUIC: return &TLSSettings{ ServerName: edgeQUICServerName, NextProtos: []string{"argotunnel"}, @@ -106,6 +87,7 @@ type ProtocolSelector interface { Fallback() (Protocol, bool) } +// staticProtocolSelector will not provide a different protocol for Fallback type staticProtocolSelector struct { current Protocol } @@ -115,10 +97,11 @@ func (s *staticProtocolSelector) Current() Protocol { } func (s *staticProtocolSelector) Fallback() (Protocol, bool) { - return 0, false + return s.current, false } -type autoProtocolSelector struct { +// remoteProtocolSelector will fetch a list of remote protocols to provide for edge discovery +type remoteProtocolSelector struct { lock sync.RWMutex current Protocol @@ -127,23 +110,21 @@ type autoProtocolSelector struct { protocolPool []Protocol switchThreshold int32 - fetchFunc PercentageFetcher + fetchFunc edgediscovery.PercentageFetcher refreshAfter time.Time ttl time.Duration log *zerolog.Logger - needPQ bool } -func newAutoProtocolSelector( +func newRemoteProtocolSelector( current Protocol, protocolPool []Protocol, switchThreshold int32, - fetchFunc PercentageFetcher, + fetchFunc edgediscovery.PercentageFetcher, ttl time.Duration, log *zerolog.Logger, - needPQ bool, -) *autoProtocolSelector { - return &autoProtocolSelector{ +) *remoteProtocolSelector { + return &remoteProtocolSelector{ current: current, protocolPool: protocolPool, switchThreshold: switchThreshold, @@ -151,11 +132,10 @@ func newAutoProtocolSelector( refreshAfter: time.Now().Add(ttl), ttl: ttl, log: log, - needPQ: needPQ, } } -func (s *autoProtocolSelector) Current() Protocol { +func (s *remoteProtocolSelector) Current() Protocol { s.lock.Lock() defer s.lock.Unlock() if time.Now().Before(s.refreshAfter) { @@ -173,7 +153,13 @@ func (s *autoProtocolSelector) Current() Protocol { return s.current } -func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) { +func (s *remoteProtocolSelector) Fallback() (Protocol, bool) { + s.lock.RLock() + defer s.lock.RUnlock() + return s.current.fallback() +} + +func getProtocol(protocolPool []Protocol, fetchFunc edgediscovery.PercentageFetcher, switchThreshold int32) (Protocol, error) { protocolPercentages, err := fetchFunc() if err != nil { return 0, err @@ -188,109 +174,74 @@ func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThr return protocolPool[len(protocolPool)-1], nil } -func (s *autoProtocolSelector) Fallback() (Protocol, bool) { +// defaultProtocolSelector will allow for a protocol to have a fallback +type defaultProtocolSelector struct { + lock sync.RWMutex + current Protocol +} + +func newDefaultProtocolSelector( + current Protocol, +) *defaultProtocolSelector { + return &defaultProtocolSelector{ + current: current, + } +} + +func (s *defaultProtocolSelector) Current() Protocol { + s.lock.Lock() + defer s.lock.Unlock() + return s.current +} + +func (s *defaultProtocolSelector) Fallback() (Protocol, bool) { s.lock.RLock() defer s.lock.RUnlock() - if s.needPQ { - return 0, false - } return s.current.fallback() } -type PercentageFetcher func() (edgediscovery.ProtocolPercents, error) - func NewProtocolSelector( protocolFlag string, - warpRoutingEnabled bool, - namedTunnel *NamedTunnelProperties, - fetchFunc PercentageFetcher, - ttl time.Duration, - log *zerolog.Logger, + accountTag string, + tunnelTokenProvided bool, needPQ bool, + protocolFetcher edgediscovery.PercentageFetcher, + resolveTTL time.Duration, + log *zerolog.Logger, ) (ProtocolSelector, error) { - // Classic tunnel is only supported with h2mux - if namedTunnel == nil { - if needPQ { - return nil, errors.New("Classic tunnel does not support post-quantum") - } - + // With --post-quantum, we force quic + if needPQ { return &staticProtocolSelector{ - current: H2mux, + current: QUIC, }, nil } - threshold := switchThreshold(namedTunnel.Credentials.AccountTag) - fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold) - if err != nil && protocolFlag == "auto" { - log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can attempt `--protocol quic` instead.") - if needPQ { - return nil, errors.New("http2 does not support post-quantum") - } - return &staticProtocolSelector{ - current: HTTP2, - }, nil - } - if warpRoutingEnabled { - if protocolFlag == H2mux.String() || fetchedProtocol == H2mux { - log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") - protocolFlag = HTTP2.String() - fetchedProtocol = HTTP2Warp - } - return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol, needPQ) + // When a --token is provided, we want to start with QUIC but have fallback to HTTP2 + if tunnelTokenProvided { + return newDefaultProtocolSelector(QUIC), nil } - return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol, needPQ) -} + threshold := switchThreshold(accountTag) + fetchedProtocol, err := getProtocol(ProtocolList, protocolFetcher, threshold) + if err != nil { + log.Warn().Msg("Unable to lookup protocol percentage.") + // Falling through here since 'auto' is handled in the switch and failing + // to do the protocol lookup isn't a failure since it can be triggered again + // after the TTL. + } -func selectNamedTunnelProtocols( - protocolFlag string, - fetchFunc PercentageFetcher, - ttl time.Duration, - log *zerolog.Logger, - threshold int32, - protocol Protocol, - needPQ bool, -) (ProtocolSelector, error) { // If the user picks a protocol, then we stick to it no matter what. switch protocolFlag { - case H2mux.String(): - return &staticProtocolSelector{current: H2mux}, nil + case "h2mux": + // Any users still requesting h2mux will be upgraded to http2 instead + log.Warn().Msg("h2mux is no longer a supported protocol: upgrading edge connection to http2. Please remove '--protocol h2mux' from runtime arguments to remove this warning.") + return &staticProtocolSelector{current: HTTP2}, nil case QUIC.String(): return &staticProtocolSelector{current: QUIC}, nil case HTTP2.String(): return &staticProtocolSelector{current: HTTP2}, nil - } - - // If the user does not pick (hopefully the majority) then we use the one derived from the TXT DNS record and - // fallback on failures. - if protocolFlag == AutoSelectFlag { - return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log, needPQ), nil - } - - return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) -} - -func selectWarpRoutingProtocols( - protocolFlag string, - fetchFunc PercentageFetcher, - ttl time.Duration, - log *zerolog.Logger, - threshold int32, - protocol Protocol, - needPQ bool, -) (ProtocolSelector, error) { - // If the user picks a protocol, then we stick to it no matter what. - switch protocolFlag { - case QUIC.String(): - return &staticProtocolSelector{current: QUICWarp}, nil - case HTTP2.String(): - return &staticProtocolSelector{current: HTTP2Warp}, nil - } - - // If the user does not pick (hopefully the majority) then we use the one derived from the TXT DNS record and - // fallback on failures. - if protocolFlag == AutoSelectFlag { - return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log, needPQ), nil + case AutoSelectFlag: + return newRemoteProtocolSelector(fetchedProtocol, ProtocolList, threshold, protocolFetcher, resolveTTL, log), nil } return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) diff --git a/connection/protocol_test.go b/connection/protocol_test.go index b3bda44a..12d238c5 100644 --- a/connection/protocol_test.go +++ b/connection/protocol_test.go @@ -3,7 +3,6 @@ package connection import ( "fmt" "testing" - "time" "github.com/stretchr/testify/assert" @@ -11,19 +10,11 @@ import ( ) const ( - testNoTTL = 0 - noWarpRoutingEnabled = false + testNoTTL = 0 + testAccountTag = "testAccountTag" ) -var ( - testNamedTunnelProperties = &NamedTunnelProperties{ - Credentials: Credentials{ - AccountTag: "testAccountTag", - }, - } -) - -func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher { +func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) edgediscovery.PercentageFetcher { return func() (edgediscovery.ProtocolPercents, error) { if getError { return nil, fmt.Errorf("failed to fetch percentage") @@ -37,7 +28,7 @@ type dynamicMockFetcher struct { err error } -func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { +func (dmf *dynamicMockFetcher) fetch() edgediscovery.PercentageFetcher { return func() (edgediscovery.ProtocolPercents, error) { return dmf.protocolPercents, dmf.err } @@ -45,181 +36,58 @@ func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { func TestNewProtocolSelector(t *testing.T) { tests := []struct { - name string - protocol string - expectedProtocol Protocol - hasFallback bool - expectedFallback Protocol - warpRoutingEnabled bool - namedTunnelConfig *NamedTunnelProperties - fetchFunc PercentageFetcher - wantErr bool + name string + protocol string + tunnelTokenProvided bool + needPQ bool + expectedProtocol Protocol + hasFallback bool + expectedFallback Protocol + wantErr bool }{ { - name: "classic tunnel", - protocol: "h2mux", - expectedProtocol: H2mux, - namedTunnelConfig: nil, + name: "named tunnel with unknown protocol", + protocol: "unknown", + wantErr: true, }, { - name: "named tunnel over h2mux", - protocol: "h2mux", - expectedProtocol: H2mux, - fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "named tunnel over http2", - protocol: "http2", - expectedProtocol: HTTP2, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "named tunnel http2 disabled still gets http2 because it is manually picked", - protocol: "http2", - expectedProtocol: HTTP2, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "named tunnel quic disabled still gets quic because it is manually picked", - protocol: "quic", - expectedProtocol: QUIC, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "named tunnel quic and http2 disabled", - protocol: AutoSelectFlag, - expectedProtocol: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "named tunnel quic disabled", - protocol: AutoSelectFlag, + name: "named tunnel with h2mux: force to http2", + protocol: "h2mux", expectedProtocol: HTTP2, - // Hasfallback true is because if http2 fails, then we further fallback to h2mux. - hasFallback: true, - expectedFallback: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelProperties, }, { - name: "named tunnel auto all http2 disabled", - protocol: AutoSelectFlag, - expectedProtocol: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), - namedTunnelConfig: testNamedTunnelProperties, + name: "named tunnel with http2: no fallback", + protocol: "http2", + expectedProtocol: HTTP2, }, { - name: "named tunnel auto to h2mux", - protocol: AutoSelectFlag, - expectedProtocol: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), - namedTunnelConfig: testNamedTunnelProperties, + name: "named tunnel with auto: quic", + protocol: AutoSelectFlag, + expectedProtocol: QUIC, + hasFallback: true, + expectedFallback: HTTP2, }, { - name: "named tunnel auto to http2", - protocol: AutoSelectFlag, - expectedProtocol: HTTP2, - hasFallback: true, - expectedFallback: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - namedTunnelConfig: testNamedTunnelProperties, + name: "named tunnel (post quantum)", + protocol: AutoSelectFlag, + needPQ: true, + expectedProtocol: QUIC, }, { - name: "named tunnel auto to quic", - protocol: AutoSelectFlag, - expectedProtocol: QUIC, - hasFallback: true, - expectedFallback: HTTP2, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "warp routing requesting h2mux", - protocol: "h2mux", - expectedProtocol: HTTP2Warp, - hasFallback: false, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", - protocol: "h2mux", - expectedProtocol: HTTP2Warp, - hasFallback: false, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "warp routing http2", - protocol: "http2", - expectedProtocol: HTTP2Warp, - hasFallback: false, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "warp routing quic", - protocol: AutoSelectFlag, - expectedProtocol: QUICWarp, - hasFallback: true, - expectedFallback: HTTP2Warp, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "warp routing auto", - protocol: AutoSelectFlag, - expectedProtocol: HTTP2Warp, - hasFallback: false, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelProperties, - }, - { - name: "warp routing auto- quic", - protocol: AutoSelectFlag, - expectedProtocol: QUICWarp, - hasFallback: true, - expectedFallback: HTTP2Warp, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelProperties, - }, - { - // None named tunnel can only use h2mux, so specifying an unknown protocol is not an error - name: "classic tunnel unknown protocol", - protocol: "unknown", - expectedProtocol: H2mux, - }, - { - name: "named tunnel unknown protocol", - protocol: "unknown", - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - namedTunnelConfig: testNamedTunnelProperties, - wantErr: true, - }, - { - name: "named tunnel fetch error", - protocol: AutoSelectFlag, - fetchFunc: mockFetcher(true), - namedTunnelConfig: testNamedTunnelProperties, - expectedProtocol: HTTP2, - wantErr: false, + name: "named tunnel (post quantum) w/http2", + protocol: "http2", + needPQ: true, + expectedProtocol: QUIC, }, } + fetcher := dynamicMockFetcher{ + protocolPercents: edgediscovery.ProtocolPercents{}, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log, false) + selector, err := NewProtocolSelector(test.protocol, testAccountTag, test.tunnelTokenProvided, test.needPQ, fetcher.fetch(), ResolveTTL, &log) if test.wantErr { assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) } else { @@ -237,15 +105,15 @@ func TestNewProtocolSelector(t *testing.T) { func TestAutoProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} - selector, err := NewProtocolSelector(AutoSelectFlag, noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log, false) + selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) - assert.Equal(t, H2mux, selector.Current()) + assert.Equal(t, QUIC, selector.Current()) fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} - assert.Equal(t, H2mux, selector.Current()) + assert.Equal(t, QUIC, selector.Current()) fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) @@ -255,10 +123,10 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} fetcher.err = nil - assert.Equal(t, H2mux, selector.Current()) + assert.Equal(t, QUIC, selector.Current()) fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} - assert.Equal(t, H2mux, selector.Current()) + assert.Equal(t, QUIC, selector.Current()) fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} assert.Equal(t, QUIC, selector.Current()) @@ -267,7 +135,7 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} // Since the user chooses http2 on purpose, we always stick to it. - selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log, false) + selector, err := NewProtocolSelector(HTTP2.String(), testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) @@ -294,13 +162,12 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { assert.Equal(t, HTTP2, selector.Current()) } -func TestProtocolSelectorRefreshTTL(t *testing.T) { +func TestAutoProtocolSelectorNoRefreshWithToken(t *testing.T) { fetcher := dynamicMockFetcher{} - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} - selector, err := NewProtocolSelector(AutoSelectFlag, noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), time.Hour, &log, false) + selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, true, false, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, QUIC, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}} + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, QUIC, selector.Current()) } diff --git a/connection/rpc.go b/connection/rpc.go index 384be0dc..a4a13e82 100644 --- a/connection/rpc.go +++ b/connection/rpc.go @@ -295,7 +295,7 @@ func (h *h2muxConnection) registerNamedTunnel( return err } h.observer.logServerInfo(h.connIndex, registrationDetails.Location, nil, fmt.Sprintf("Connection %s registered", registrationDetails.UUID)) - h.observer.sendConnectedEvent(h.connIndex, H2mux, registrationDetails.Location) + h.observer.sendConnectedEvent(h.connIndex, 0, registrationDetails.Location) return nil } diff --git a/edgediscovery/protocol.go b/edgediscovery/protocol.go index 5bbb1e91..2427294b 100644 --- a/edgediscovery/protocol.go +++ b/edgediscovery/protocol.go @@ -15,6 +15,8 @@ var ( errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) ) +type PercentageFetcher func() (ProtocolPercents, error) + // ProtocolPercent represents a single Protocol Percentage combination. type ProtocolPercent struct { Protocol string `json:"protocol"` diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index 67f93b92..5723a820 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -19,8 +19,6 @@ import ( ) const ( - // SRV and TXT record resolution TTL - ResolveTTL = time.Hour // Waiting time before retrying a failed tunnel connection tunnelRetryDuration = time.Second * 10 // Interval between registering new tunnels diff --git a/supervisor/supervisor_test.go b/supervisor/supervisor_test.go deleted file mode 100644 index 70e0f547..00000000 --- a/supervisor/supervisor_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package supervisor - -import ( - "context" - "testing" - - "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" - - "github.com/cloudflare/cloudflared/connection" - "github.com/cloudflare/cloudflared/edgediscovery" - "github.com/cloudflare/cloudflared/edgediscovery/allregions" - "github.com/cloudflare/cloudflared/signal" -) - -type mockProtocolSelector struct { - protocols []connection.Protocol - index int -} - -func (m *mockProtocolSelector) Current() connection.Protocol { - return m.protocols[m.index] -} - -func (m *mockProtocolSelector) Fallback() (connection.Protocol, bool) { - m.index++ - if m.index == len(m.protocols) { - return m.protocols[len(m.protocols)-1], false - } - - return m.protocols[m.index], true -} - -type mockEdgeTunnelServer struct { - config *TunnelConfig -} - -func (m *mockEdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error { - // This is to mock the first connection falling back because of connectivity issues. - protocolFallback.protocol, _ = m.config.ProtocolSelector.Fallback() - connectedSignal.Notify() - return nil -} - -// Test to check if initialize sets all the different connections to the same protocol should the first -// tunnel fall back. -func Test_Initialize_Same_Protocol(t *testing.T) { - edgeIPs, err := edgediscovery.ResolveEdge(&zerolog.Logger{}, "us", allregions.Auto) - assert.Nil(t, err) - s := Supervisor{ - edgeIPs: edgeIPs, - config: &TunnelConfig{ - ProtocolSelector: &mockProtocolSelector{protocols: []connection.Protocol{connection.QUIC, connection.HTTP2, connection.H2mux}}, - }, - tunnelsProtocolFallback: make(map[int]*protocolFallback), - edgeTunnelServer: &mockEdgeTunnelServer{ - config: &TunnelConfig{ - ProtocolSelector: &mockProtocolSelector{protocols: []connection.Protocol{connection.QUIC, connection.HTTP2, connection.H2mux}}, - }, - }, - } - - ctx := context.Background() - connectedSignal := signal.New(make(chan struct{})) - s.initialize(ctx, connectedSignal) - - // Make sure we fell back to http2 as the mock Serve is wont to do. - assert.Equal(t, s.tunnelsProtocolFallback[0].protocol, connection.HTTP2) - - // Ensure all the protocols we set to try are the same as what the first tunnel has fallen back to. - for _, protocolFallback := range s.tunnelsProtocolFallback { - assert.Equal(t, protocolFallback.protocol, s.tunnelsProtocolFallback[0].protocol) - } -} diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 1c22daf1..ac659426 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -487,7 +487,7 @@ func (e *EdgeTunnelServer) serveConnection( ) switch protocol { - case connection.QUIC, connection.QUICWarp: + case connection.QUIC: connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) return e.serveQUIC(ctx, addr.UDP, @@ -496,7 +496,7 @@ func (e *EdgeTunnelServer) serveConnection( controlStream, connIndex) - case connection.HTTP2, connection.HTTP2Warp: + case connection.HTTP2: edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP) if err != nil { connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") diff --git a/supervisor/tunnel_test.go b/supervisor/tunnel_test.go index 3f9ae62c..9a0e9653 100644 --- a/supervisor/tunnel_test.go +++ b/supervisor/tunnel_test.go @@ -18,7 +18,7 @@ type dynamicMockFetcher struct { err error } -func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { +func (dmf *dynamicMockFetcher) fetch() edgediscovery.PercentageFetcher { return func() (edgediscovery.ProtocolPercents, error) { return dmf.protocolPercents, dmf.err } @@ -32,24 +32,22 @@ func TestWaitForBackoffFallback(t *testing.T) { } log := zerolog.Nop() resolveTTL := time.Duration(0) - namedTunnel := &connection.NamedTunnelProperties{} mockFetcher := dynamicMockFetcher{ - protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}, + protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}, } - warpRoutingEnabled := false protocolSelector, err := connection.NewProtocolSelector( "auto", - warpRoutingEnabled, - namedTunnel, + "", + false, + false, mockFetcher.fetch(), resolveTTL, &log, - false, ) assert.NoError(t, err) initProtocol := protocolSelector.Current() - assert.Equal(t, connection.HTTP2, initProtocol) + assert.Equal(t, connection.QUIC, initProtocol) protoFallback := &protocolFallback{ backoff, @@ -100,12 +98,12 @@ func TestWaitForBackoffFallback(t *testing.T) { // The reason why there's no fallback available is because we pick a specific protocol instead of letting it be auto. protocolSelector, err = connection.NewProtocolSelector( "quic", - warpRoutingEnabled, - namedTunnel, + "", + false, + false, mockFetcher.fetch(), resolveTTL, &log, - false, ) assert.NoError(t, err) protoFallback = &protocolFallback{backoff, protocolSelector.Current(), false}