TUN-5138: Switch to QUIC on auto protocol based on threshold

This commit is contained in:
Sudarsan Reddy 2021-10-11 11:31:05 +01:00
parent bccf4a63dc
commit e445fd92f7
6 changed files with 248 additions and 216 deletions

View File

@ -238,7 +238,7 @@ func prepareTunnelConfig(
log.Info().Msgf("Warp-routing is enabled") log.Info().Msgf("Warp-routing is enabled")
} }
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log) protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log)
if err != nil { if err != nil {
return nil, ingress.Ingress{}, err return nil, ingress.Ingress{}, err
} }

View File

@ -7,6 +7,8 @@ import (
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/edgediscovery"
) )
const ( const (
@ -24,7 +26,7 @@ const (
var ( var (
// ProtocolList represents a list of supported protocols for communication with the edge. // ProtocolList represents a list of supported protocols for communication with the edge.
ProtocolList = []Protocol{H2mux, HTTP2, QUIC} ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp}
) )
type Protocol int64 type Protocol int64
@ -36,6 +38,12 @@ const (
HTTP2 HTTP2
// QUIC is used only with named tunnels. // QUIC is used only with named tunnels.
QUIC QUIC
// HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to
// H2mux on HTTP2 failure to connect.
HTTP2Warp
//QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but
// dont' want HTTP2 to fallback to H2mux
QUICWarp
) )
// Fallback returns the fallback protocol and whether the protocol has a fallback // Fallback returns the fallback protocol and whether the protocol has a fallback
@ -45,8 +53,12 @@ func (p Protocol) fallback() (Protocol, bool) {
return 0, false return 0, false
case HTTP2: case HTTP2:
return H2mux, true return H2mux, true
case HTTP2Warp:
return 0, false
case QUIC: case QUIC:
return HTTP2, true return HTTP2, true
case QUICWarp:
return HTTP2Warp, true
default: default:
return 0, false return 0, false
} }
@ -56,9 +68,9 @@ func (p Protocol) String() string {
switch p { switch p {
case H2mux: case H2mux:
return "h2mux" return "h2mux"
case HTTP2: case HTTP2, HTTP2Warp:
return "http2" return "http2"
case QUIC: case QUIC, QUICWarp:
return "quic" return "quic"
default: default:
return fmt.Sprintf("unknown protocol") return fmt.Sprintf("unknown protocol")
@ -71,11 +83,11 @@ func (p Protocol) TLSSettings() *TLSSettings {
return &TLSSettings{ return &TLSSettings{
ServerName: edgeH2muxTLSServerName, ServerName: edgeH2muxTLSServerName,
} }
case HTTP2: case HTTP2, HTTP2Warp:
return &TLSSettings{ return &TLSSettings{
ServerName: edgeH2TLSServerName, ServerName: edgeH2TLSServerName,
} }
case QUIC: case QUIC, QUICWarp:
return &TLSSettings{ return &TLSSettings{
ServerName: edgeQUICServerName, ServerName: edgeQUICServerName,
NextProtos: []string{"argotunnel"}, NextProtos: []string{"argotunnel"},
@ -108,29 +120,36 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
} }
type autoProtocolSelector struct { type autoProtocolSelector struct {
lock sync.RWMutex lock sync.RWMutex
current Protocol
switchThrehold int32 current Protocol
fetchFunc PercentageFetcher
refreshAfter time.Time // protocolPool is desired protocols in the order of priority they should be picked in.
ttl time.Duration protocolPool []Protocol
log *zerolog.Logger
switchThreshold int32
fetchFunc PercentageFetcher
refreshAfter time.Time
ttl time.Duration
log *zerolog.Logger
} }
func newAutoProtocolSelector( func newAutoProtocolSelector(
current Protocol, current Protocol,
switchThrehold int32, protocolPool []Protocol,
switchThreshold int32,
fetchFunc PercentageFetcher, fetchFunc PercentageFetcher,
ttl time.Duration, ttl time.Duration,
log *zerolog.Logger, log *zerolog.Logger,
) *autoProtocolSelector { ) *autoProtocolSelector {
return &autoProtocolSelector{ return &autoProtocolSelector{
current: current, current: current,
switchThrehold: switchThrehold, protocolPool: protocolPool,
fetchFunc: fetchFunc, switchThreshold: switchThreshold,
refreshAfter: time.Now().Add(ttl), fetchFunc: fetchFunc,
ttl: ttl, refreshAfter: time.Now().Add(ttl),
log: log, ttl: ttl,
log: log,
} }
} }
@ -141,28 +160,39 @@ func (s *autoProtocolSelector) Current() Protocol {
return s.current return s.current
} }
percentage, err := s.fetchFunc() protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold)
if err != nil { if err != nil {
s.log.Err(err).Msg("Failed to refresh protocol") s.log.Err(err).Msg("Failed to refresh protocol")
return s.current return s.current
} }
s.current = protocol
if s.switchThrehold < percentage {
s.current = HTTP2
} else {
s.current = H2mux
}
s.refreshAfter = time.Now().Add(s.ttl) s.refreshAfter = time.Now().Add(s.ttl)
return s.current return s.current
} }
func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) {
protocolPercentages, err := fetchFunc()
if err != nil {
return 0, err
}
for _, protocol := range protocolPool {
protocolPercentage := protocolPercentages.GetPercentage(protocol.String())
if protocolPercentage > switchThreshold {
return protocol, nil
}
}
return protocolPool[len(protocolPool)-1], nil
}
func (s *autoProtocolSelector) Fallback() (Protocol, bool) { func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
s.lock.RLock() s.lock.RLock()
defer s.lock.RUnlock() defer s.lock.RUnlock()
return s.current.fallback() return s.current.fallback()
} }
type PercentageFetcher func() (int32, error) type PercentageFetcher func() (edgediscovery.ProtocolPercents, error)
func NewProtocolSelector( func NewProtocolSelector(
protocolFlag string, protocolFlag string,
@ -179,22 +209,34 @@ func NewProtocolSelector(
}, nil }, nil
} }
// warp routing cannot be served over h2mux connections threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
if warpRoutingEnabled { fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold)
if protocolFlag == H2mux.String() { if err != nil {
log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.")
}
if protocolFlag == QUIC.String() {
return &staticProtocolSelector{
current: QUIC,
}, nil
}
return &staticProtocolSelector{ return &staticProtocolSelector{
current: HTTP2, current: HTTP2,
}, nil }, nil
} }
if warpRoutingEnabled {
if protocolFlag == H2mux.String() || fetchedProtocol == H2mux {
log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.")
protocolFlag = HTTP2.String()
fetchedProtocol = HTTP2Warp
}
return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
}
return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
}
func selectNamedTunnelProtocols(
protocolFlag string,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
threshold int32,
protocol Protocol,
) (ProtocolSelector, error) {
if protocolFlag == H2mux.String() { if protocolFlag == H2mux.String() {
return &staticProtocolSelector{ return &staticProtocolSelector{
current: H2mux, current: H2mux,
@ -202,31 +244,41 @@ func NewProtocolSelector(
} }
if protocolFlag == QUIC.String() { if protocolFlag == QUIC.String() {
return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil return newAutoProtocolSelector(QUIC, []Protocol{QUIC, HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
} }
http2Percentage, err := fetchFunc()
if err != nil {
log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.")
return &staticProtocolSelector{
current: HTTP2,
}, nil
}
if protocolFlag == HTTP2.String() { if protocolFlag == HTTP2.String() {
if http2Percentage < 0 { return newAutoProtocolSelector(HTTP2, []Protocol{HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
} }
if protocolFlag != autoSelectFlag { if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
} }
threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
if threshold < http2Percentage { return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil }
func selectWarpRoutingProtocols(
protocolFlag string,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
threshold int32,
protocol Protocol,
) (ProtocolSelector, error) {
if protocolFlag == QUIC.String() {
return newAutoProtocolSelector(QUICWarp, []Protocol{QUICWarp, HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
} }
return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil
if protocolFlag == HTTP2.String() {
return newAutoProtocolSelector(HTTP2Warp, []Protocol{HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil
} }
func switchThreshold(accountTag string) int32 { func switchThreshold(accountTag string) int32 {

View File

@ -6,6 +6,8 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/edgediscovery"
) )
const ( const (
@ -21,29 +23,23 @@ var (
} }
) )
func mockFetcher(percentage int32) PercentageFetcher { func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher {
return func() (int32, error) { return func() (edgediscovery.ProtocolPercents, error) {
return percentage, nil if getError {
} return nil, fmt.Errorf("failed to fetch precentage")
} }
return protocolPercent, nil
func mockFetcherWithError() PercentageFetcher {
return func() (int32, error) {
return 0, fmt.Errorf("failed to fetch precentage")
} }
} }
type dynamicMockFetcher struct { type dynamicMockFetcher struct {
percentage int32 protocolPercents edgediscovery.ProtocolPercents
err error err error
} }
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
return func() (int32, error) { return func() (edgediscovery.ProtocolPercents, error) {
if dmf.err != nil { return dmf.protocolPercents, dmf.err
return 0, dmf.err
}
return dmf.percentage, nil
} }
} }
@ -69,6 +65,7 @@ func TestNewProtocolSelector(t *testing.T) {
name: "named tunnel over h2mux", name: "named tunnel over h2mux",
protocol: "h2mux", protocol: "h2mux",
expectedProtocol: H2mux, expectedProtocol: H2mux,
fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil },
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
{ {
@ -77,28 +74,38 @@ func TestNewProtocolSelector(t *testing.T) {
expectedProtocol: HTTP2, expectedProtocol: HTTP2,
hasFallback: true, hasFallback: true,
expectedFallback: H2mux, expectedFallback: H2mux,
fetchFunc: mockFetcher(0), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
{ {
name: "named tunnel http2 disabled", name: "named tunnel http2 disabled",
protocol: "http2", protocol: "http2",
expectedProtocol: H2mux, expectedProtocol: H2mux,
fetchFunc: mockFetcher(-1), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel quic disabled",
protocol: "quic",
expectedProtocol: HTTP2,
// Hasfallback true is because if http2 fails, then we further fallback to h2mux.
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
{ {
name: "named tunnel auto all http2 disabled", name: "named tunnel auto all http2 disabled",
protocol: "auto", protocol: "auto",
expectedProtocol: H2mux, expectedProtocol: H2mux,
fetchFunc: mockFetcher(-1), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
{ {
name: "named tunnel auto to h2mux", name: "named tunnel auto to h2mux",
protocol: "auto", protocol: "auto",
expectedProtocol: H2mux, expectedProtocol: H2mux,
fetchFunc: mockFetcher(0), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
{ {
@ -107,36 +114,71 @@ func TestNewProtocolSelector(t *testing.T) {
expectedProtocol: HTTP2, expectedProtocol: HTTP2,
hasFallback: true, hasFallback: true,
expectedFallback: H2mux, expectedFallback: H2mux,
fetchFunc: mockFetcher(100), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to quic",
protocol: "auto",
expectedProtocol: QUIC,
hasFallback: true,
expectedFallback: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
{ {
name: "warp routing requesting h2mux", name: "warp routing requesting h2mux",
protocol: "h2mux", protocol: "h2mux",
expectedProtocol: HTTP2, expectedProtocol: HTTP2Warp,
hasFallback: false, hasFallback: false,
expectedFallback: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
fetchFunc: mockFetcher(100), warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1",
protocol: "h2mux",
expectedProtocol: HTTP2Warp,
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
{ {
name: "warp routing http2", name: "warp routing http2",
protocol: "http2", protocol: "http2",
expectedProtocol: HTTP2, expectedProtocol: HTTP2Warp,
hasFallback: false, hasFallback: false,
expectedFallback: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
fetchFunc: mockFetcher(100), warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing quic",
protocol: "quic",
expectedProtocol: QUICWarp,
hasFallback: true,
expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
{ {
name: "warp routing auto", name: "warp routing auto",
protocol: "auto", protocol: "auto",
expectedProtocol: HTTP2, expectedProtocol: HTTP2Warp,
hasFallback: false, hasFallback: false,
expectedFallback: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
fetchFunc: mockFetcher(100), warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing auto- quic",
protocol: "auto",
expectedProtocol: QUICWarp,
hasFallback: true,
expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
}, },
@ -149,14 +191,14 @@ func TestNewProtocolSelector(t *testing.T) {
{ {
name: "named tunnel unknown protocol", name: "named tunnel unknown protocol",
protocol: "unknown", protocol: "unknown",
fetchFunc: mockFetcher(100), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
wantErr: true, wantErr: true,
}, },
{ {
name: "named tunnel fetch error", name: "named tunnel fetch error",
protocol: "unknown", protocol: "auto",
fetchFunc: mockFetcherWithError(), fetchFunc: mockFetcher(true),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelConfig,
expectedProtocol: HTTP2, expectedProtocol: HTTP2,
wantErr: false, wantErr: false,
@ -164,18 +206,20 @@ func TestNewProtocolSelector(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) t.Run(test.name, func(t *testing.T) {
if test.wantErr { selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) if test.wantErr {
} else { assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) } else {
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
fallback, ok := selector.Fallback() assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) fallback, ok := selector.Fallback()
if test.hasFallback { assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) if test.hasFallback {
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
}
} }
} })
} }
} }
@ -185,27 +229,27 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current()) assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
assert.Equal(t, H2mux, selector.Current()) assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch") fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
fetcher.err = nil fetcher.err = nil
assert.Equal(t, H2mux, selector.Current()) assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 0 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
assert.Equal(t, H2mux, selector.Current()) assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, QUIC, selector.Current())
} }
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
@ -214,35 +258,36 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 100 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch") fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
fetcher.err = nil fetcher.err = nil
assert.Equal(t, H2mux, selector.Current()) assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 0 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 100 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
assert.Equal(t, H2mux, selector.Current()) assert.Equal(t, H2mux, selector.Current())
} }
func TestProtocolSelectorRefreshTTL(t *testing.T) { func TestProtocolSelectorRefreshTTL(t *testing.T) {
fetcher := dynamicMockFetcher{percentage: 100} fetcher := dynamicMockFetcher{}
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, QUIC, selector.Current())
fetcher.percentage = 0 fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}}
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, QUIC, selector.Current())
} }

View File

@ -1,45 +1,50 @@
package edgediscovery package edgediscovery
import ( import (
"encoding/json"
"fmt" "fmt"
"net" "net"
"strconv"
"strings" "strings"
) )
const ( const (
protocolRecord = "protocol.argotunnel.com" protocolRecord = "protocol-v2.argotunnel.com"
) )
var ( var (
errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord)
) )
func HTTP2Percentage() (int32, error) { // ProtocolPercent represents a single Protocol Percentage combination.
type ProtocolPercent struct {
Protocol string `json:"protocol"`
Percentage int32 `json:"percentage"`
}
// ProtocolPercents represents the preferred distribution ratio of protocols when protocol isn't specified.
type ProtocolPercents []ProtocolPercent
// GetPercentage returns the threshold percentage of a single protocol.
func (p ProtocolPercents) GetPercentage(protocol string) int32 {
for _, protocolPercent := range p {
if strings.ToLower(protocolPercent.Protocol) == strings.ToLower(protocol) {
return protocolPercent.Percentage
}
}
return 0
}
// ProtocolPercentage returns the ratio of protocols and a specification ratio for their selection.
func ProtocolPercentage() (ProtocolPercents, error) {
records, err := net.LookupTXT(protocolRecord) records, err := net.LookupTXT(protocolRecord)
if err != nil { if err != nil {
return 0, err return nil, err
} }
if len(records) == 0 { if len(records) == 0 {
return 0, errNoProtocolRecord return nil, errNoProtocolRecord
} }
return parseHTTP2Precentage(records[0])
}
// The record looks like http2=percentage
func parseHTTP2Precentage(record string) (int32, error) {
const key = "http2"
slices := strings.Split(record, "=")
if len(slices) != 2 {
return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record)
}
if slices[0] != key {
return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key)
}
percentage, err := strconv.ParseInt(slices[1], 10, 32)
if err != nil {
return 0, err
}
return int32(percentage), nil
var protocolsWithPercent ProtocolPercents
err = json.Unmarshal([]byte(records[0]), &protocolsWithPercent)
return protocolsWithPercent, err
} }

View File

@ -6,75 +6,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestHTTP2Percentage(t *testing.T) { func TestProtocolPercentage(t *testing.T) {
_, err := HTTP2Percentage() _, err := ProtocolPercentage()
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestParseHTTP2Precentage(t *testing.T) {
tests := []struct {
record string
percentage int32
wantErr bool
}{
{
record: "http2=-1",
percentage: -1,
wantErr: false,
},
{
record: "http2=0",
percentage: 0,
wantErr: false,
},
{
record: "http2=50",
percentage: 50,
wantErr: false,
},
{
record: "http2=100",
percentage: 100,
wantErr: false,
},
{
record: "http2=1000",
percentage: 1000,
wantErr: false,
},
{
record: "http2=10.5",
wantErr: true,
},
{
record: "http2=10 h2mux=90",
wantErr: true,
},
{
record: "http2=ten",
wantErr: true,
},
{
record: "h2mux=100",
wantErr: true,
},
{
record: "http2",
wantErr: true,
},
{
record: "http2=",
wantErr: true,
},
}
for _, test := range tests {
p, err := parseHTTP2Precentage(test.record)
if test.wantErr {
assert.Error(t, err)
} else {
assert.Equal(t, test.percentage, p)
}
}
}

View File

@ -8,20 +8,18 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/retry"
) )
type dynamicMockFetcher struct { type dynamicMockFetcher struct {
percentage int32 protocolPercents edgediscovery.ProtocolPercents
err error err error
} }
func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
return func() (int32, error) { return func() (edgediscovery.ProtocolPercents, error) {
if dmf.err != nil { return dmf.protocolPercents, dmf.err
return 0, dmf.err
}
return dmf.percentage, nil
} }
} }
@ -39,7 +37,7 @@ func TestWaitForBackoffFallback(t *testing.T) {
}, },
} }
mockFetcher := dynamicMockFetcher{ mockFetcher := dynamicMockFetcher{
percentage: 0, protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}},
} }
warpRoutingEnabled := false warpRoutingEnabled := false
protocolSelector, err := connection.NewProtocolSelector( protocolSelector, err := connection.NewProtocolSelector(