From 2822fbe3db87f8f239ceaf12ff519615d437a7ec Mon Sep 17 00:00:00 2001 From: Sudarsan Reddy Date: Wed, 13 Oct 2021 19:06:31 +0100 Subject: [PATCH] TUN-5249: Revert "TUN-5138: Switch to QUIC on auto protocol based on threshold" This reverts commit e445fd92f78fc83382974f884301dbd3768fea59 --- cmd/cloudflared/tunnel/configuration.go | 2 +- connection/protocol.go | 158 ++++++++-------------- connection/protocol_test.go | 167 +++++++++--------------- edgediscovery/protocol.go | 53 ++++---- edgediscovery/protocol_test.go | 72 +++++++++- origin/tunnel_test.go | 14 +- 6 files changed, 217 insertions(+), 249 deletions(-) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 25eaf2a6..901d192c 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -238,7 +238,7 @@ func prepareTunnelConfig( log.Info().Msgf("Warp-routing is enabled") } - protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log) + protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log) if err != nil { return nil, ingress.Ingress{}, err } diff --git a/connection/protocol.go b/connection/protocol.go index b1d86aa6..1eb37cc1 100644 --- a/connection/protocol.go +++ b/connection/protocol.go @@ -7,8 +7,6 @@ import ( "time" "github.com/rs/zerolog" - - "github.com/cloudflare/cloudflared/edgediscovery" ) const ( @@ -26,7 +24,7 @@ const ( var ( // ProtocolList represents a list of supported protocols for communication with the edge. - ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp} + ProtocolList = []Protocol{H2mux, HTTP2, QUIC} ) type Protocol int64 @@ -38,12 +36,6 @@ const ( HTTP2 // QUIC is used only with named tunnels. QUIC - // HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to - // H2mux on HTTP2 failure to connect. - HTTP2Warp - //QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but - // dont' want HTTP2 to fallback to H2mux - QUICWarp ) // Fallback returns the fallback protocol and whether the protocol has a fallback @@ -53,12 +45,8 @@ func (p Protocol) fallback() (Protocol, bool) { return 0, false case HTTP2: return H2mux, true - case HTTP2Warp: - return 0, false case QUIC: return HTTP2, true - case QUICWarp: - return HTTP2Warp, true default: return 0, false } @@ -68,9 +56,9 @@ func (p Protocol) String() string { switch p { case H2mux: return "h2mux" - case HTTP2, HTTP2Warp: + case HTTP2: return "http2" - case QUIC, QUICWarp: + case QUIC: return "quic" default: return fmt.Sprintf("unknown protocol") @@ -83,11 +71,11 @@ func (p Protocol) TLSSettings() *TLSSettings { return &TLSSettings{ ServerName: edgeH2muxTLSServerName, } - case HTTP2, HTTP2Warp: + case HTTP2: return &TLSSettings{ ServerName: edgeH2TLSServerName, } - case QUIC, QUICWarp: + case QUIC: return &TLSSettings{ ServerName: edgeQUICServerName, NextProtos: []string{"argotunnel"}, @@ -120,36 +108,29 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) { } type autoProtocolSelector struct { - lock sync.RWMutex - - current Protocol - - // protocolPool is desired protocols in the order of priority they should be picked in. - protocolPool []Protocol - - switchThreshold int32 - fetchFunc PercentageFetcher - refreshAfter time.Time - ttl time.Duration - log *zerolog.Logger + lock sync.RWMutex + current Protocol + switchThrehold int32 + fetchFunc PercentageFetcher + refreshAfter time.Time + ttl time.Duration + log *zerolog.Logger } func newAutoProtocolSelector( current Protocol, - protocolPool []Protocol, - switchThreshold int32, + switchThrehold int32, fetchFunc PercentageFetcher, ttl time.Duration, log *zerolog.Logger, ) *autoProtocolSelector { return &autoProtocolSelector{ - current: current, - protocolPool: protocolPool, - switchThreshold: switchThreshold, - fetchFunc: fetchFunc, - refreshAfter: time.Now().Add(ttl), - ttl: ttl, - log: log, + current: current, + switchThrehold: switchThrehold, + fetchFunc: fetchFunc, + refreshAfter: time.Now().Add(ttl), + ttl: ttl, + log: log, } } @@ -160,39 +141,28 @@ func (s *autoProtocolSelector) Current() Protocol { return s.current } - protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold) + percentage, err := s.fetchFunc() if err != nil { s.log.Err(err).Msg("Failed to refresh protocol") return s.current } - s.current = protocol + if s.switchThrehold < percentage { + s.current = HTTP2 + } else { + s.current = H2mux + } s.refreshAfter = time.Now().Add(s.ttl) return s.current } -func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) { - protocolPercentages, err := fetchFunc() - if err != nil { - return 0, err - } - for _, protocol := range protocolPool { - protocolPercentage := protocolPercentages.GetPercentage(protocol.String()) - if protocolPercentage > switchThreshold { - return protocol, nil - } - } - - return protocolPool[len(protocolPool)-1], nil -} - func (s *autoProtocolSelector) Fallback() (Protocol, bool) { s.lock.RLock() defer s.lock.RUnlock() return s.current.fallback() } -type PercentageFetcher func() (edgediscovery.ProtocolPercents, error) +type PercentageFetcher func() (int32, error) func NewProtocolSelector( protocolFlag string, @@ -209,34 +179,22 @@ func NewProtocolSelector( }, nil } - threshold := switchThreshold(namedTunnel.Credentials.AccountTag) - fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold) - if err != nil { - log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") + // warp routing cannot be served over h2mux connections + if warpRoutingEnabled { + if protocolFlag == H2mux.String() { + log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") + } + + if protocolFlag == QUIC.String() { + return &staticProtocolSelector{ + current: QUIC, + }, nil + } return &staticProtocolSelector{ current: HTTP2, }, nil } - if warpRoutingEnabled { - if protocolFlag == H2mux.String() || fetchedProtocol == H2mux { - log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") - protocolFlag = HTTP2.String() - fetchedProtocol = HTTP2Warp - } - return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) - } - return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) -} - -func selectNamedTunnelProtocols( - protocolFlag string, - fetchFunc PercentageFetcher, - ttl time.Duration, - log *zerolog.Logger, - threshold int32, - protocol Protocol, -) (ProtocolSelector, error) { if protocolFlag == H2mux.String() { return &staticProtocolSelector{ current: H2mux, @@ -244,41 +202,31 @@ func selectNamedTunnelProtocols( } if protocolFlag == QUIC.String() { - return newAutoProtocolSelector(QUIC, []Protocol{QUIC, HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil } + http2Percentage, err := fetchFunc() + if err != nil { + log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") + return &staticProtocolSelector{ + current: HTTP2, + }, nil + } if protocolFlag == HTTP2.String() { - return newAutoProtocolSelector(HTTP2, []Protocol{HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + if http2Percentage < 0 { + return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + } + return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil } if protocolFlag != autoSelectFlag { return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) } - - return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil -} - -func selectWarpRoutingProtocols( - protocolFlag string, - fetchFunc PercentageFetcher, - ttl time.Duration, - log *zerolog.Logger, - threshold int32, - protocol Protocol, -) (ProtocolSelector, error) { - if protocolFlag == QUIC.String() { - return newAutoProtocolSelector(QUICWarp, []Protocol{QUICWarp, HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + threshold := switchThreshold(namedTunnel.Credentials.AccountTag) + if threshold < http2Percentage { + return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil } - - if protocolFlag == HTTP2.String() { - return newAutoProtocolSelector(HTTP2Warp, []Protocol{HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil - } - - if protocolFlag != autoSelectFlag { - return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) - } - - return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil + return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil } func switchThreshold(accountTag string) int32 { diff --git a/connection/protocol_test.go b/connection/protocol_test.go index 1139c657..b4a6299a 100644 --- a/connection/protocol_test.go +++ b/connection/protocol_test.go @@ -6,8 +6,6 @@ import ( "time" "github.com/stretchr/testify/assert" - - "github.com/cloudflare/cloudflared/edgediscovery" ) const ( @@ -23,23 +21,29 @@ var ( } ) -func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher { - return func() (edgediscovery.ProtocolPercents, error) { - if getError { - return nil, fmt.Errorf("failed to fetch precentage") - } - return protocolPercent, nil +func mockFetcher(percentage int32) PercentageFetcher { + return func() (int32, error) { + return percentage, nil + } +} + +func mockFetcherWithError() PercentageFetcher { + return func() (int32, error) { + return 0, fmt.Errorf("failed to fetch precentage") } } type dynamicMockFetcher struct { - protocolPercents edgediscovery.ProtocolPercents - err error + percentage int32 + err error } func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { - return func() (edgediscovery.ProtocolPercents, error) { - return dmf.protocolPercents, dmf.err + return func() (int32, error) { + if dmf.err != nil { + return 0, dmf.err + } + return dmf.percentage, nil } } @@ -65,7 +69,6 @@ func TestNewProtocolSelector(t *testing.T) { name: "named tunnel over h2mux", protocol: "h2mux", expectedProtocol: H2mux, - fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, namedTunnelConfig: testNamedTunnelConfig, }, { @@ -74,38 +77,28 @@ func TestNewProtocolSelector(t *testing.T) { expectedProtocol: HTTP2, hasFallback: true, expectedFallback: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), + fetchFunc: mockFetcher(0), namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel http2 disabled", protocol: "http2", expectedProtocol: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, - }, - { - name: "named tunnel quic disabled", - protocol: "quic", - expectedProtocol: HTTP2, - // Hasfallback true is because if http2 fails, then we further fallback to h2mux. - hasFallback: true, - expectedFallback: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), + fetchFunc: mockFetcher(-1), namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel auto all http2 disabled", protocol: "auto", expectedProtocol: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), + fetchFunc: mockFetcher(-1), namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel auto to h2mux", protocol: "auto", expectedProtocol: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), + fetchFunc: mockFetcher(0), namedTunnelConfig: testNamedTunnelConfig, }, { @@ -114,71 +107,36 @@ func TestNewProtocolSelector(t *testing.T) { expectedProtocol: HTTP2, hasFallback: true, expectedFallback: H2mux, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - namedTunnelConfig: testNamedTunnelConfig, - }, - { - name: "named tunnel auto to quic", - protocol: "auto", - expectedProtocol: QUIC, - hasFallback: true, - expectedFallback: HTTP2, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), + fetchFunc: mockFetcher(100), namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing requesting h2mux", protocol: "h2mux", - expectedProtocol: HTTP2Warp, + expectedProtocol: HTTP2, hasFallback: false, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, - }, - { - name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", - protocol: "h2mux", - expectedProtocol: HTTP2Warp, - hasFallback: false, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), + expectedFallback: H2mux, + fetchFunc: mockFetcher(100), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing http2", protocol: "http2", - expectedProtocol: HTTP2Warp, + expectedProtocol: HTTP2, hasFallback: false, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, - }, - { - name: "warp routing quic", - protocol: "quic", - expectedProtocol: QUICWarp, - hasFallback: true, - expectedFallback: HTTP2Warp, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), + expectedFallback: H2mux, + fetchFunc: mockFetcher(100), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing auto", protocol: "auto", - expectedProtocol: HTTP2Warp, + expectedProtocol: HTTP2, hasFallback: false, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, - }, - { - name: "warp routing auto- quic", - protocol: "auto", - expectedProtocol: QUICWarp, - hasFallback: true, - expectedFallback: HTTP2Warp, - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), + expectedFallback: H2mux, + fetchFunc: mockFetcher(100), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, @@ -191,14 +149,14 @@ func TestNewProtocolSelector(t *testing.T) { { name: "named tunnel unknown protocol", protocol: "unknown", - fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + fetchFunc: mockFetcher(100), namedTunnelConfig: testNamedTunnelConfig, wantErr: true, }, { name: "named tunnel fetch error", - protocol: "auto", - fetchFunc: mockFetcher(true), + protocol: "unknown", + fetchFunc: mockFetcherWithError(), namedTunnelConfig: testNamedTunnelConfig, expectedProtocol: HTTP2, wantErr: false, @@ -206,20 +164,18 @@ func TestNewProtocolSelector(t *testing.T) { } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) - if test.wantErr { - assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) - } else { - assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) - assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) - fallback, ok := selector.Fallback() - assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) - if test.hasFallback { - assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) - } + selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) + if test.wantErr { + assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) + } else { + assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) + assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) + fallback, ok := selector.Fallback() + assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) + if test.hasFallback { + assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) } - }) + } } } @@ -229,27 +185,27 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { assert.NoError(t, err) assert.Equal(t, H2mux, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} + fetcher.percentage = 100 assert.Equal(t, HTTP2, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} + fetcher.percentage = 0 assert.Equal(t, H2mux, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} + fetcher.percentage = 100 assert.Equal(t, HTTP2, selector.Current()) fetcher.err = fmt.Errorf("failed to fetch") assert.Equal(t, HTTP2, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} + fetcher.percentage = -1 fetcher.err = nil assert.Equal(t, H2mux, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} + fetcher.percentage = 0 assert.Equal(t, H2mux, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} - assert.Equal(t, QUIC, selector.Current()) + fetcher.percentage = 100 + assert.Equal(t, HTTP2, selector.Current()) } func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { @@ -258,36 +214,35 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { assert.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} + fetcher.percentage = 100 assert.Equal(t, HTTP2, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} + fetcher.percentage = 0 assert.Equal(t, HTTP2, selector.Current()) fetcher.err = fmt.Errorf("failed to fetch") assert.Equal(t, HTTP2, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} + fetcher.percentage = -1 fetcher.err = nil assert.Equal(t, H2mux, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} + fetcher.percentage = 0 assert.Equal(t, HTTP2, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} + fetcher.percentage = 100 assert.Equal(t, HTTP2, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} + fetcher.percentage = -1 assert.Equal(t, H2mux, selector.Current()) } func TestProtocolSelectorRefreshTTL(t *testing.T) { - fetcher := dynamicMockFetcher{} - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} + fetcher := dynamicMockFetcher{percentage: 100} selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) assert.NoError(t, err) - assert.Equal(t, QUIC, selector.Current()) + assert.Equal(t, HTTP2, selector.Current()) - fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}} - assert.Equal(t, QUIC, selector.Current()) + fetcher.percentage = 0 + assert.Equal(t, HTTP2, selector.Current()) } diff --git a/edgediscovery/protocol.go b/edgediscovery/protocol.go index 5bbb1e91..8d9f5039 100644 --- a/edgediscovery/protocol.go +++ b/edgediscovery/protocol.go @@ -1,50 +1,45 @@ package edgediscovery import ( - "encoding/json" "fmt" "net" + "strconv" "strings" ) const ( - protocolRecord = "protocol-v2.argotunnel.com" + protocolRecord = "protocol.argotunnel.com" ) var ( errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) ) -// ProtocolPercent represents a single Protocol Percentage combination. -type ProtocolPercent struct { - Protocol string `json:"protocol"` - Percentage int32 `json:"percentage"` -} - -// ProtocolPercents represents the preferred distribution ratio of protocols when protocol isn't specified. -type ProtocolPercents []ProtocolPercent - -// GetPercentage returns the threshold percentage of a single protocol. -func (p ProtocolPercents) GetPercentage(protocol string) int32 { - for _, protocolPercent := range p { - if strings.ToLower(protocolPercent.Protocol) == strings.ToLower(protocol) { - return protocolPercent.Percentage - } - } - return 0 -} - -// ProtocolPercentage returns the ratio of protocols and a specification ratio for their selection. -func ProtocolPercentage() (ProtocolPercents, error) { +func HTTP2Percentage() (int32, error) { records, err := net.LookupTXT(protocolRecord) if err != nil { - return nil, err + return 0, err } if len(records) == 0 { - return nil, errNoProtocolRecord + return 0, errNoProtocolRecord } - - var protocolsWithPercent ProtocolPercents - err = json.Unmarshal([]byte(records[0]), &protocolsWithPercent) - return protocolsWithPercent, err + return parseHTTP2Precentage(records[0]) +} + +// The record looks like http2=percentage +func parseHTTP2Precentage(record string) (int32, error) { + const key = "http2" + slices := strings.Split(record, "=") + if len(slices) != 2 { + return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record) + } + if slices[0] != key { + return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key) + } + percentage, err := strconv.ParseInt(slices[1], 10, 32) + if err != nil { + return 0, err + } + return int32(percentage), nil + } diff --git a/edgediscovery/protocol_test.go b/edgediscovery/protocol_test.go index 37b9353f..874ab6ee 100644 --- a/edgediscovery/protocol_test.go +++ b/edgediscovery/protocol_test.go @@ -6,7 +6,75 @@ import ( "github.com/stretchr/testify/assert" ) -func TestProtocolPercentage(t *testing.T) { - _, err := ProtocolPercentage() +func TestHTTP2Percentage(t *testing.T) { + _, err := HTTP2Percentage() assert.NoError(t, err) } + +func TestParseHTTP2Precentage(t *testing.T) { + tests := []struct { + record string + percentage int32 + wantErr bool + }{ + { + record: "http2=-1", + percentage: -1, + wantErr: false, + }, + { + record: "http2=0", + percentage: 0, + wantErr: false, + }, + { + record: "http2=50", + percentage: 50, + wantErr: false, + }, + { + record: "http2=100", + percentage: 100, + wantErr: false, + }, + { + record: "http2=1000", + percentage: 1000, + wantErr: false, + }, + { + record: "http2=10.5", + wantErr: true, + }, + { + record: "http2=10 h2mux=90", + wantErr: true, + }, + { + record: "http2=ten", + wantErr: true, + }, + + { + record: "h2mux=100", + wantErr: true, + }, + { + record: "http2", + wantErr: true, + }, + { + record: "http2=", + wantErr: true, + }, + } + + for _, test := range tests { + p, err := parseHTTP2Precentage(test.record) + if test.wantErr { + assert.Error(t, err) + } else { + assert.Equal(t, test.percentage, p) + } + } +} diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index 28c2f9ee..63b547da 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -8,18 +8,20 @@ import ( "github.com/stretchr/testify/assert" "github.com/cloudflare/cloudflared/connection" - "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/retry" ) type dynamicMockFetcher struct { - protocolPercents edgediscovery.ProtocolPercents - err error + percentage int32 + err error } func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { - return func() (edgediscovery.ProtocolPercents, error) { - return dmf.protocolPercents, dmf.err + return func() (int32, error) { + if dmf.err != nil { + return 0, dmf.err + } + return dmf.percentage, nil } } @@ -37,7 +39,7 @@ func TestWaitForBackoffFallback(t *testing.T) { }, } mockFetcher := dynamicMockFetcher{ - protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}, + percentage: 0, } warpRoutingEnabled := false protocolSelector, err := connection.NewProtocolSelector(