diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 901d192c..25eaf2a6 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -238,7 +238,7 @@ func prepareTunnelConfig( log.Info().Msgf("Warp-routing is enabled") } - protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log) + protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log) if err != nil { return nil, ingress.Ingress{}, err } diff --git a/connection/protocol.go b/connection/protocol.go index 1eb37cc1..b1d86aa6 100644 --- a/connection/protocol.go +++ b/connection/protocol.go @@ -7,6 +7,8 @@ import ( "time" "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/edgediscovery" ) const ( @@ -24,7 +26,7 @@ const ( var ( // ProtocolList represents a list of supported protocols for communication with the edge. - ProtocolList = []Protocol{H2mux, HTTP2, QUIC} + ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp} ) type Protocol int64 @@ -36,6 +38,12 @@ const ( HTTP2 // QUIC is used only with named tunnels. QUIC + // HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to + // H2mux on HTTP2 failure to connect. + HTTP2Warp + //QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but + // dont' want HTTP2 to fallback to H2mux + QUICWarp ) // Fallback returns the fallback protocol and whether the protocol has a fallback @@ -45,8 +53,12 @@ func (p Protocol) fallback() (Protocol, bool) { return 0, false case HTTP2: return H2mux, true + case HTTP2Warp: + return 0, false case QUIC: return HTTP2, true + case QUICWarp: + return HTTP2Warp, true default: return 0, false } @@ -56,9 +68,9 @@ func (p Protocol) String() string { switch p { case H2mux: return "h2mux" - case HTTP2: + case HTTP2, HTTP2Warp: return "http2" - case QUIC: + case QUIC, QUICWarp: return "quic" default: return fmt.Sprintf("unknown protocol") @@ -71,11 +83,11 @@ func (p Protocol) TLSSettings() *TLSSettings { return &TLSSettings{ ServerName: edgeH2muxTLSServerName, } - case HTTP2: + case HTTP2, HTTP2Warp: return &TLSSettings{ ServerName: edgeH2TLSServerName, } - case QUIC: + case QUIC, QUICWarp: return &TLSSettings{ ServerName: edgeQUICServerName, NextProtos: []string{"argotunnel"}, @@ -108,29 +120,36 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) { } type autoProtocolSelector struct { - lock sync.RWMutex - current Protocol - switchThrehold int32 - fetchFunc PercentageFetcher - refreshAfter time.Time - ttl time.Duration - log *zerolog.Logger + lock sync.RWMutex + + current Protocol + + // protocolPool is desired protocols in the order of priority they should be picked in. + protocolPool []Protocol + + switchThreshold int32 + fetchFunc PercentageFetcher + refreshAfter time.Time + ttl time.Duration + log *zerolog.Logger } func newAutoProtocolSelector( current Protocol, - switchThrehold int32, + protocolPool []Protocol, + switchThreshold int32, fetchFunc PercentageFetcher, ttl time.Duration, log *zerolog.Logger, ) *autoProtocolSelector { return &autoProtocolSelector{ - current: current, - switchThrehold: switchThrehold, - fetchFunc: fetchFunc, - refreshAfter: time.Now().Add(ttl), - ttl: ttl, - log: log, + current: current, + protocolPool: protocolPool, + switchThreshold: switchThreshold, + fetchFunc: fetchFunc, + refreshAfter: time.Now().Add(ttl), + ttl: ttl, + log: log, } } @@ -141,28 +160,39 @@ func (s *autoProtocolSelector) Current() Protocol { return s.current } - percentage, err := s.fetchFunc() + protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold) if err != nil { s.log.Err(err).Msg("Failed to refresh protocol") return s.current } + s.current = protocol - if s.switchThrehold < percentage { - s.current = HTTP2 - } else { - s.current = H2mux - } s.refreshAfter = time.Now().Add(s.ttl) return s.current } +func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) { + protocolPercentages, err := fetchFunc() + if err != nil { + return 0, err + } + for _, protocol := range protocolPool { + protocolPercentage := protocolPercentages.GetPercentage(protocol.String()) + if protocolPercentage > switchThreshold { + return protocol, nil + } + } + + return protocolPool[len(protocolPool)-1], nil +} + func (s *autoProtocolSelector) Fallback() (Protocol, bool) { s.lock.RLock() defer s.lock.RUnlock() return s.current.fallback() } -type PercentageFetcher func() (int32, error) +type PercentageFetcher func() (edgediscovery.ProtocolPercents, error) func NewProtocolSelector( protocolFlag string, @@ -179,22 +209,34 @@ func NewProtocolSelector( }, nil } - // warp routing cannot be served over h2mux connections - if warpRoutingEnabled { - if protocolFlag == H2mux.String() { - log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") - } - - if protocolFlag == QUIC.String() { - return &staticProtocolSelector{ - current: QUIC, - }, nil - } + threshold := switchThreshold(namedTunnel.Credentials.AccountTag) + fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold) + if err != nil { + log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") return &staticProtocolSelector{ current: HTTP2, }, nil } + if warpRoutingEnabled { + if protocolFlag == H2mux.String() || fetchedProtocol == H2mux { + log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") + protocolFlag = HTTP2.String() + fetchedProtocol = HTTP2Warp + } + return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) + } + return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) +} + +func selectNamedTunnelProtocols( + protocolFlag string, + fetchFunc PercentageFetcher, + ttl time.Duration, + log *zerolog.Logger, + threshold int32, + protocol Protocol, +) (ProtocolSelector, error) { if protocolFlag == H2mux.String() { return &staticProtocolSelector{ current: H2mux, @@ -202,31 +244,41 @@ func NewProtocolSelector( } if protocolFlag == QUIC.String() { - return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + return newAutoProtocolSelector(QUIC, []Protocol{QUIC, HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil } - http2Percentage, err := fetchFunc() - if err != nil { - log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") - return &staticProtocolSelector{ - current: HTTP2, - }, nil - } if protocolFlag == HTTP2.String() { - if http2Percentage < 0 { - return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil - } - return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + return newAutoProtocolSelector(HTTP2, []Protocol{HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil } if protocolFlag != autoSelectFlag { return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) } - threshold := switchThreshold(namedTunnel.Credentials.AccountTag) - if threshold < http2Percentage { - return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil + + return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil +} + +func selectWarpRoutingProtocols( + protocolFlag string, + fetchFunc PercentageFetcher, + ttl time.Duration, + log *zerolog.Logger, + threshold int32, + protocol Protocol, +) (ProtocolSelector, error) { + if protocolFlag == QUIC.String() { + return newAutoProtocolSelector(QUICWarp, []Protocol{QUICWarp, HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil } - return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil + + if protocolFlag == HTTP2.String() { + return newAutoProtocolSelector(HTTP2Warp, []Protocol{HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + } + + if protocolFlag != autoSelectFlag { + return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) + } + + return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil } func switchThreshold(accountTag string) int32 { diff --git a/connection/protocol_test.go b/connection/protocol_test.go index b4a6299a..1139c657 100644 --- a/connection/protocol_test.go +++ b/connection/protocol_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/stretchr/testify/assert" + + "github.com/cloudflare/cloudflared/edgediscovery" ) const ( @@ -21,29 +23,23 @@ var ( } ) -func mockFetcher(percentage int32) PercentageFetcher { - return func() (int32, error) { - return percentage, nil - } -} - -func mockFetcherWithError() PercentageFetcher { - return func() (int32, error) { - return 0, fmt.Errorf("failed to fetch precentage") +func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher { + return func() (edgediscovery.ProtocolPercents, error) { + if getError { + return nil, fmt.Errorf("failed to fetch precentage") + } + return protocolPercent, nil } } type dynamicMockFetcher struct { - percentage int32 - err error + protocolPercents edgediscovery.ProtocolPercents + err error } func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { - return func() (int32, error) { - if dmf.err != nil { - return 0, dmf.err - } - return dmf.percentage, nil + return func() (edgediscovery.ProtocolPercents, error) { + return dmf.protocolPercents, dmf.err } } @@ -69,6 +65,7 @@ func TestNewProtocolSelector(t *testing.T) { name: "named tunnel over h2mux", protocol: "h2mux", expectedProtocol: H2mux, + fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, namedTunnelConfig: testNamedTunnelConfig, }, { @@ -77,28 +74,38 @@ func TestNewProtocolSelector(t *testing.T) { expectedProtocol: HTTP2, hasFallback: true, expectedFallback: H2mux, - fetchFunc: mockFetcher(0), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel http2 disabled", protocol: "http2", expectedProtocol: H2mux, - fetchFunc: mockFetcher(-1), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel quic disabled", + protocol: "quic", + expectedProtocol: HTTP2, + // Hasfallback true is because if http2 fails, then we further fallback to h2mux. + hasFallback: true, + expectedFallback: H2mux, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel auto all http2 disabled", protocol: "auto", expectedProtocol: H2mux, - fetchFunc: mockFetcher(-1), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel auto to h2mux", protocol: "auto", expectedProtocol: H2mux, - fetchFunc: mockFetcher(0), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), namedTunnelConfig: testNamedTunnelConfig, }, { @@ -107,36 +114,71 @@ func TestNewProtocolSelector(t *testing.T) { expectedProtocol: HTTP2, hasFallback: true, expectedFallback: H2mux, - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel auto to quic", + protocol: "auto", + expectedProtocol: QUIC, + hasFallback: true, + expectedFallback: HTTP2, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing requesting h2mux", protocol: "h2mux", - expectedProtocol: HTTP2, + expectedProtocol: HTTP2Warp, hasFallback: false, - expectedFallback: H2mux, - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", + protocol: "h2mux", + expectedProtocol: HTTP2Warp, + hasFallback: false, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing http2", protocol: "http2", - expectedProtocol: HTTP2, + expectedProtocol: HTTP2Warp, hasFallback: false, - expectedFallback: H2mux, - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "warp routing quic", + protocol: "quic", + expectedProtocol: QUICWarp, + hasFallback: true, + expectedFallback: HTTP2Warp, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing auto", protocol: "auto", - expectedProtocol: HTTP2, + expectedProtocol: HTTP2Warp, hasFallback: false, - expectedFallback: H2mux, - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "warp routing auto- quic", + protocol: "auto", + expectedProtocol: QUICWarp, + hasFallback: true, + expectedFallback: HTTP2Warp, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, @@ -149,14 +191,14 @@ func TestNewProtocolSelector(t *testing.T) { { name: "named tunnel unknown protocol", protocol: "unknown", - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), namedTunnelConfig: testNamedTunnelConfig, wantErr: true, }, { name: "named tunnel fetch error", - protocol: "unknown", - fetchFunc: mockFetcherWithError(), + protocol: "auto", + fetchFunc: mockFetcher(true), namedTunnelConfig: testNamedTunnelConfig, expectedProtocol: HTTP2, wantErr: false, @@ -164,18 +206,20 @@ func TestNewProtocolSelector(t *testing.T) { } for _, test := range tests { - selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) - if test.wantErr { - assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) - } else { - assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) - assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) - fallback, ok := selector.Fallback() - assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) - if test.hasFallback { - assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) + t.Run(test.name, func(t *testing.T) { + selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) + if test.wantErr { + assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) + } else { + assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) + assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) + fallback, ok := selector.Fallback() + assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) + if test.hasFallback { + assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) + } } - } + }) } } @@ -185,27 +229,27 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { assert.NoError(t, err) assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 100 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = 0 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 100 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) fetcher.err = fmt.Errorf("failed to fetch") assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = -1 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} fetcher.err = nil assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 0 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 100 - assert.Equal(t, HTTP2, selector.Current()) + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} + assert.Equal(t, QUIC, selector.Current()) } func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { @@ -214,35 +258,36 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { assert.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = 100 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = 0 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} assert.Equal(t, HTTP2, selector.Current()) fetcher.err = fmt.Errorf("failed to fetch") assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = -1 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} fetcher.err = nil assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 0 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = 100 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = -1 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} assert.Equal(t, H2mux, selector.Current()) } func TestProtocolSelectorRefreshTTL(t *testing.T) { - fetcher := dynamicMockFetcher{percentage: 100} + fetcher := dynamicMockFetcher{} + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) assert.NoError(t, err) - assert.Equal(t, HTTP2, selector.Current()) + assert.Equal(t, QUIC, selector.Current()) - fetcher.percentage = 0 - assert.Equal(t, HTTP2, selector.Current()) + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}} + assert.Equal(t, QUIC, selector.Current()) } diff --git a/edgediscovery/protocol.go b/edgediscovery/protocol.go index 8d9f5039..5bbb1e91 100644 --- a/edgediscovery/protocol.go +++ b/edgediscovery/protocol.go @@ -1,45 +1,50 @@ package edgediscovery import ( + "encoding/json" "fmt" "net" - "strconv" "strings" ) const ( - protocolRecord = "protocol.argotunnel.com" + protocolRecord = "protocol-v2.argotunnel.com" ) var ( errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) ) -func HTTP2Percentage() (int32, error) { +// ProtocolPercent represents a single Protocol Percentage combination. +type ProtocolPercent struct { + Protocol string `json:"protocol"` + Percentage int32 `json:"percentage"` +} + +// ProtocolPercents represents the preferred distribution ratio of protocols when protocol isn't specified. +type ProtocolPercents []ProtocolPercent + +// GetPercentage returns the threshold percentage of a single protocol. +func (p ProtocolPercents) GetPercentage(protocol string) int32 { + for _, protocolPercent := range p { + if strings.ToLower(protocolPercent.Protocol) == strings.ToLower(protocol) { + return protocolPercent.Percentage + } + } + return 0 +} + +// ProtocolPercentage returns the ratio of protocols and a specification ratio for their selection. +func ProtocolPercentage() (ProtocolPercents, error) { records, err := net.LookupTXT(protocolRecord) if err != nil { - return 0, err + return nil, err } if len(records) == 0 { - return 0, errNoProtocolRecord + return nil, errNoProtocolRecord } - return parseHTTP2Precentage(records[0]) -} - -// The record looks like http2=percentage -func parseHTTP2Precentage(record string) (int32, error) { - const key = "http2" - slices := strings.Split(record, "=") - if len(slices) != 2 { - return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record) - } - if slices[0] != key { - return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key) - } - percentage, err := strconv.ParseInt(slices[1], 10, 32) - if err != nil { - return 0, err - } - return int32(percentage), nil + var protocolsWithPercent ProtocolPercents + err = json.Unmarshal([]byte(records[0]), &protocolsWithPercent) + return protocolsWithPercent, err } diff --git a/edgediscovery/protocol_test.go b/edgediscovery/protocol_test.go index 874ab6ee..37b9353f 100644 --- a/edgediscovery/protocol_test.go +++ b/edgediscovery/protocol_test.go @@ -6,75 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestHTTP2Percentage(t *testing.T) { - _, err := HTTP2Percentage() +func TestProtocolPercentage(t *testing.T) { + _, err := ProtocolPercentage() assert.NoError(t, err) } - -func TestParseHTTP2Precentage(t *testing.T) { - tests := []struct { - record string - percentage int32 - wantErr bool - }{ - { - record: "http2=-1", - percentage: -1, - wantErr: false, - }, - { - record: "http2=0", - percentage: 0, - wantErr: false, - }, - { - record: "http2=50", - percentage: 50, - wantErr: false, - }, - { - record: "http2=100", - percentage: 100, - wantErr: false, - }, - { - record: "http2=1000", - percentage: 1000, - wantErr: false, - }, - { - record: "http2=10.5", - wantErr: true, - }, - { - record: "http2=10 h2mux=90", - wantErr: true, - }, - { - record: "http2=ten", - wantErr: true, - }, - - { - record: "h2mux=100", - wantErr: true, - }, - { - record: "http2", - wantErr: true, - }, - { - record: "http2=", - wantErr: true, - }, - } - - for _, test := range tests { - p, err := parseHTTP2Precentage(test.record) - if test.wantErr { - assert.Error(t, err) - } else { - assert.Equal(t, test.percentage, p) - } - } -} diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index 63b547da..28c2f9ee 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -8,20 +8,18 @@ import ( "github.com/stretchr/testify/assert" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/retry" ) type dynamicMockFetcher struct { - percentage int32 - err error + protocolPercents edgediscovery.ProtocolPercents + err error } func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { - return func() (int32, error) { - if dmf.err != nil { - return 0, dmf.err - } - return dmf.percentage, nil + return func() (edgediscovery.ProtocolPercents, error) { + return dmf.protocolPercents, dmf.err } } @@ -39,7 +37,7 @@ func TestWaitForBackoffFallback(t *testing.T) { }, } mockFetcher := dynamicMockFetcher{ - percentage: 0, + protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}, } warpRoutingEnabled := false protocolSelector, err := connection.NewProtocolSelector(