TUN-5138: Switch to QUIC on auto protocol based on threshold
This commit is contained in:
parent
bccf4a63dc
commit
e445fd92f7
|
@ -238,7 +238,7 @@ func prepareTunnelConfig(
|
||||||
log.Info().Msgf("Warp-routing is enabled")
|
log.Info().Msgf("Warp-routing is enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log)
|
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ingress.Ingress{}, err
|
return nil, ingress.Ingress{}, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -24,7 +26,7 @@ const (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ProtocolList represents a list of supported protocols for communication with the edge.
|
// ProtocolList represents a list of supported protocols for communication with the edge.
|
||||||
ProtocolList = []Protocol{H2mux, HTTP2, QUIC}
|
ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp}
|
||||||
)
|
)
|
||||||
|
|
||||||
type Protocol int64
|
type Protocol int64
|
||||||
|
@ -36,6 +38,12 @@ const (
|
||||||
HTTP2
|
HTTP2
|
||||||
// QUIC is used only with named tunnels.
|
// QUIC is used only with named tunnels.
|
||||||
QUIC
|
QUIC
|
||||||
|
// HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to
|
||||||
|
// H2mux on HTTP2 failure to connect.
|
||||||
|
HTTP2Warp
|
||||||
|
//QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but
|
||||||
|
// dont' want HTTP2 to fallback to H2mux
|
||||||
|
QUICWarp
|
||||||
)
|
)
|
||||||
|
|
||||||
// Fallback returns the fallback protocol and whether the protocol has a fallback
|
// Fallback returns the fallback protocol and whether the protocol has a fallback
|
||||||
|
@ -45,8 +53,12 @@ func (p Protocol) fallback() (Protocol, bool) {
|
||||||
return 0, false
|
return 0, false
|
||||||
case HTTP2:
|
case HTTP2:
|
||||||
return H2mux, true
|
return H2mux, true
|
||||||
|
case HTTP2Warp:
|
||||||
|
return 0, false
|
||||||
case QUIC:
|
case QUIC:
|
||||||
return HTTP2, true
|
return HTTP2, true
|
||||||
|
case QUICWarp:
|
||||||
|
return HTTP2Warp, true
|
||||||
default:
|
default:
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
@ -56,9 +68,9 @@ func (p Protocol) String() string {
|
||||||
switch p {
|
switch p {
|
||||||
case H2mux:
|
case H2mux:
|
||||||
return "h2mux"
|
return "h2mux"
|
||||||
case HTTP2:
|
case HTTP2, HTTP2Warp:
|
||||||
return "http2"
|
return "http2"
|
||||||
case QUIC:
|
case QUIC, QUICWarp:
|
||||||
return "quic"
|
return "quic"
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("unknown protocol")
|
return fmt.Sprintf("unknown protocol")
|
||||||
|
@ -71,11 +83,11 @@ func (p Protocol) TLSSettings() *TLSSettings {
|
||||||
return &TLSSettings{
|
return &TLSSettings{
|
||||||
ServerName: edgeH2muxTLSServerName,
|
ServerName: edgeH2muxTLSServerName,
|
||||||
}
|
}
|
||||||
case HTTP2:
|
case HTTP2, HTTP2Warp:
|
||||||
return &TLSSettings{
|
return &TLSSettings{
|
||||||
ServerName: edgeH2TLSServerName,
|
ServerName: edgeH2TLSServerName,
|
||||||
}
|
}
|
||||||
case QUIC:
|
case QUIC, QUICWarp:
|
||||||
return &TLSSettings{
|
return &TLSSettings{
|
||||||
ServerName: edgeQUICServerName,
|
ServerName: edgeQUICServerName,
|
||||||
NextProtos: []string{"argotunnel"},
|
NextProtos: []string{"argotunnel"},
|
||||||
|
@ -108,29 +120,36 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type autoProtocolSelector struct {
|
type autoProtocolSelector struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
current Protocol
|
|
||||||
switchThrehold int32
|
current Protocol
|
||||||
fetchFunc PercentageFetcher
|
|
||||||
refreshAfter time.Time
|
// protocolPool is desired protocols in the order of priority they should be picked in.
|
||||||
ttl time.Duration
|
protocolPool []Protocol
|
||||||
log *zerolog.Logger
|
|
||||||
|
switchThreshold int32
|
||||||
|
fetchFunc PercentageFetcher
|
||||||
|
refreshAfter time.Time
|
||||||
|
ttl time.Duration
|
||||||
|
log *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAutoProtocolSelector(
|
func newAutoProtocolSelector(
|
||||||
current Protocol,
|
current Protocol,
|
||||||
switchThrehold int32,
|
protocolPool []Protocol,
|
||||||
|
switchThreshold int32,
|
||||||
fetchFunc PercentageFetcher,
|
fetchFunc PercentageFetcher,
|
||||||
ttl time.Duration,
|
ttl time.Duration,
|
||||||
log *zerolog.Logger,
|
log *zerolog.Logger,
|
||||||
) *autoProtocolSelector {
|
) *autoProtocolSelector {
|
||||||
return &autoProtocolSelector{
|
return &autoProtocolSelector{
|
||||||
current: current,
|
current: current,
|
||||||
switchThrehold: switchThrehold,
|
protocolPool: protocolPool,
|
||||||
fetchFunc: fetchFunc,
|
switchThreshold: switchThreshold,
|
||||||
refreshAfter: time.Now().Add(ttl),
|
fetchFunc: fetchFunc,
|
||||||
ttl: ttl,
|
refreshAfter: time.Now().Add(ttl),
|
||||||
log: log,
|
ttl: ttl,
|
||||||
|
log: log,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,28 +160,39 @@ func (s *autoProtocolSelector) Current() Protocol {
|
||||||
return s.current
|
return s.current
|
||||||
}
|
}
|
||||||
|
|
||||||
percentage, err := s.fetchFunc()
|
protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Err(err).Msg("Failed to refresh protocol")
|
s.log.Err(err).Msg("Failed to refresh protocol")
|
||||||
return s.current
|
return s.current
|
||||||
}
|
}
|
||||||
|
s.current = protocol
|
||||||
|
|
||||||
if s.switchThrehold < percentage {
|
|
||||||
s.current = HTTP2
|
|
||||||
} else {
|
|
||||||
s.current = H2mux
|
|
||||||
}
|
|
||||||
s.refreshAfter = time.Now().Add(s.ttl)
|
s.refreshAfter = time.Now().Add(s.ttl)
|
||||||
return s.current
|
return s.current
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) {
|
||||||
|
protocolPercentages, err := fetchFunc()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
for _, protocol := range protocolPool {
|
||||||
|
protocolPercentage := protocolPercentages.GetPercentage(protocol.String())
|
||||||
|
if protocolPercentage > switchThreshold {
|
||||||
|
return protocol, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return protocolPool[len(protocolPool)-1], nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
|
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
|
||||||
s.lock.RLock()
|
s.lock.RLock()
|
||||||
defer s.lock.RUnlock()
|
defer s.lock.RUnlock()
|
||||||
return s.current.fallback()
|
return s.current.fallback()
|
||||||
}
|
}
|
||||||
|
|
||||||
type PercentageFetcher func() (int32, error)
|
type PercentageFetcher func() (edgediscovery.ProtocolPercents, error)
|
||||||
|
|
||||||
func NewProtocolSelector(
|
func NewProtocolSelector(
|
||||||
protocolFlag string,
|
protocolFlag string,
|
||||||
|
@ -179,22 +209,34 @@ func NewProtocolSelector(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// warp routing cannot be served over h2mux connections
|
threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
|
||||||
if warpRoutingEnabled {
|
fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold)
|
||||||
if protocolFlag == H2mux.String() {
|
if err != nil {
|
||||||
log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.")
|
log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.")
|
||||||
}
|
|
||||||
|
|
||||||
if protocolFlag == QUIC.String() {
|
|
||||||
return &staticProtocolSelector{
|
|
||||||
current: QUIC,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return &staticProtocolSelector{
|
return &staticProtocolSelector{
|
||||||
current: HTTP2,
|
current: HTTP2,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
if warpRoutingEnabled {
|
||||||
|
if protocolFlag == H2mux.String() || fetchedProtocol == H2mux {
|
||||||
|
log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.")
|
||||||
|
protocolFlag = HTTP2.String()
|
||||||
|
fetchedProtocol = HTTP2Warp
|
||||||
|
}
|
||||||
|
return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectNamedTunnelProtocols(
|
||||||
|
protocolFlag string,
|
||||||
|
fetchFunc PercentageFetcher,
|
||||||
|
ttl time.Duration,
|
||||||
|
log *zerolog.Logger,
|
||||||
|
threshold int32,
|
||||||
|
protocol Protocol,
|
||||||
|
) (ProtocolSelector, error) {
|
||||||
if protocolFlag == H2mux.String() {
|
if protocolFlag == H2mux.String() {
|
||||||
return &staticProtocolSelector{
|
return &staticProtocolSelector{
|
||||||
current: H2mux,
|
current: H2mux,
|
||||||
|
@ -202,31 +244,41 @@ func NewProtocolSelector(
|
||||||
}
|
}
|
||||||
|
|
||||||
if protocolFlag == QUIC.String() {
|
if protocolFlag == QUIC.String() {
|
||||||
return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
return newAutoProtocolSelector(QUIC, []Protocol{QUIC, HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
http2Percentage, err := fetchFunc()
|
|
||||||
if err != nil {
|
|
||||||
log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.")
|
|
||||||
return &staticProtocolSelector{
|
|
||||||
current: HTTP2,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
if protocolFlag == HTTP2.String() {
|
if protocolFlag == HTTP2.String() {
|
||||||
if http2Percentage < 0 {
|
return newAutoProtocolSelector(HTTP2, []Protocol{HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
||||||
return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
|
||||||
}
|
|
||||||
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if protocolFlag != autoSelectFlag {
|
if protocolFlag != autoSelectFlag {
|
||||||
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
||||||
}
|
}
|
||||||
threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
|
|
||||||
if threshold < http2Percentage {
|
return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil
|
||||||
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil
|
}
|
||||||
|
|
||||||
|
func selectWarpRoutingProtocols(
|
||||||
|
protocolFlag string,
|
||||||
|
fetchFunc PercentageFetcher,
|
||||||
|
ttl time.Duration,
|
||||||
|
log *zerolog.Logger,
|
||||||
|
threshold int32,
|
||||||
|
protocol Protocol,
|
||||||
|
) (ProtocolSelector, error) {
|
||||||
|
if protocolFlag == QUIC.String() {
|
||||||
|
return newAutoProtocolSelector(QUICWarp, []Protocol{QUICWarp, HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
||||||
}
|
}
|
||||||
return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil
|
|
||||||
|
if protocolFlag == HTTP2.String() {
|
||||||
|
return newAutoProtocolSelector(HTTP2Warp, []Protocol{HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if protocolFlag != autoSelectFlag {
|
||||||
|
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func switchThreshold(accountTag string) int32 {
|
func switchThreshold(accountTag string) int32 {
|
||||||
|
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -21,29 +23,23 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func mockFetcher(percentage int32) PercentageFetcher {
|
func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher {
|
||||||
return func() (int32, error) {
|
return func() (edgediscovery.ProtocolPercents, error) {
|
||||||
return percentage, nil
|
if getError {
|
||||||
}
|
return nil, fmt.Errorf("failed to fetch precentage")
|
||||||
}
|
}
|
||||||
|
return protocolPercent, nil
|
||||||
func mockFetcherWithError() PercentageFetcher {
|
|
||||||
return func() (int32, error) {
|
|
||||||
return 0, fmt.Errorf("failed to fetch precentage")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type dynamicMockFetcher struct {
|
type dynamicMockFetcher struct {
|
||||||
percentage int32
|
protocolPercents edgediscovery.ProtocolPercents
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
|
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
|
||||||
return func() (int32, error) {
|
return func() (edgediscovery.ProtocolPercents, error) {
|
||||||
if dmf.err != nil {
|
return dmf.protocolPercents, dmf.err
|
||||||
return 0, dmf.err
|
|
||||||
}
|
|
||||||
return dmf.percentage, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,6 +65,7 @@ func TestNewProtocolSelector(t *testing.T) {
|
||||||
name: "named tunnel over h2mux",
|
name: "named tunnel over h2mux",
|
||||||
protocol: "h2mux",
|
protocol: "h2mux",
|
||||||
expectedProtocol: H2mux,
|
expectedProtocol: H2mux,
|
||||||
|
fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil },
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -77,28 +74,38 @@ func TestNewProtocolSelector(t *testing.T) {
|
||||||
expectedProtocol: HTTP2,
|
expectedProtocol: HTTP2,
|
||||||
hasFallback: true,
|
hasFallback: true,
|
||||||
expectedFallback: H2mux,
|
expectedFallback: H2mux,
|
||||||
fetchFunc: mockFetcher(0),
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "named tunnel http2 disabled",
|
name: "named tunnel http2 disabled",
|
||||||
protocol: "http2",
|
protocol: "http2",
|
||||||
expectedProtocol: H2mux,
|
expectedProtocol: H2mux,
|
||||||
fetchFunc: mockFetcher(-1),
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel quic disabled",
|
||||||
|
protocol: "quic",
|
||||||
|
expectedProtocol: HTTP2,
|
||||||
|
// Hasfallback true is because if http2 fails, then we further fallback to h2mux.
|
||||||
|
hasFallback: true,
|
||||||
|
expectedFallback: H2mux,
|
||||||
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "named tunnel auto all http2 disabled",
|
name: "named tunnel auto all http2 disabled",
|
||||||
protocol: "auto",
|
protocol: "auto",
|
||||||
expectedProtocol: H2mux,
|
expectedProtocol: H2mux,
|
||||||
fetchFunc: mockFetcher(-1),
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "named tunnel auto to h2mux",
|
name: "named tunnel auto to h2mux",
|
||||||
protocol: "auto",
|
protocol: "auto",
|
||||||
expectedProtocol: H2mux,
|
expectedProtocol: H2mux,
|
||||||
fetchFunc: mockFetcher(0),
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -107,36 +114,71 @@ func TestNewProtocolSelector(t *testing.T) {
|
||||||
expectedProtocol: HTTP2,
|
expectedProtocol: HTTP2,
|
||||||
hasFallback: true,
|
hasFallback: true,
|
||||||
expectedFallback: H2mux,
|
expectedFallback: H2mux,
|
||||||
fetchFunc: mockFetcher(100),
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel auto to quic",
|
||||||
|
protocol: "auto",
|
||||||
|
expectedProtocol: QUIC,
|
||||||
|
hasFallback: true,
|
||||||
|
expectedFallback: HTTP2,
|
||||||
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "warp routing requesting h2mux",
|
name: "warp routing requesting h2mux",
|
||||||
protocol: "h2mux",
|
protocol: "h2mux",
|
||||||
expectedProtocol: HTTP2,
|
expectedProtocol: HTTP2Warp,
|
||||||
hasFallback: false,
|
hasFallback: false,
|
||||||
expectedFallback: H2mux,
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
|
||||||
fetchFunc: mockFetcher(100),
|
warpRoutingEnabled: true,
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1",
|
||||||
|
protocol: "h2mux",
|
||||||
|
expectedProtocol: HTTP2Warp,
|
||||||
|
hasFallback: false,
|
||||||
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
|
||||||
warpRoutingEnabled: true,
|
warpRoutingEnabled: true,
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "warp routing http2",
|
name: "warp routing http2",
|
||||||
protocol: "http2",
|
protocol: "http2",
|
||||||
expectedProtocol: HTTP2,
|
expectedProtocol: HTTP2Warp,
|
||||||
hasFallback: false,
|
hasFallback: false,
|
||||||
expectedFallback: H2mux,
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
|
||||||
fetchFunc: mockFetcher(100),
|
warpRoutingEnabled: true,
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warp routing quic",
|
||||||
|
protocol: "quic",
|
||||||
|
expectedProtocol: QUICWarp,
|
||||||
|
hasFallback: true,
|
||||||
|
expectedFallback: HTTP2Warp,
|
||||||
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
|
||||||
warpRoutingEnabled: true,
|
warpRoutingEnabled: true,
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "warp routing auto",
|
name: "warp routing auto",
|
||||||
protocol: "auto",
|
protocol: "auto",
|
||||||
expectedProtocol: HTTP2,
|
expectedProtocol: HTTP2Warp,
|
||||||
hasFallback: false,
|
hasFallback: false,
|
||||||
expectedFallback: H2mux,
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
|
||||||
fetchFunc: mockFetcher(100),
|
warpRoutingEnabled: true,
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "warp routing auto- quic",
|
||||||
|
protocol: "auto",
|
||||||
|
expectedProtocol: QUICWarp,
|
||||||
|
hasFallback: true,
|
||||||
|
expectedFallback: HTTP2Warp,
|
||||||
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
|
||||||
warpRoutingEnabled: true,
|
warpRoutingEnabled: true,
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
},
|
},
|
||||||
|
@ -149,14 +191,14 @@ func TestNewProtocolSelector(t *testing.T) {
|
||||||
{
|
{
|
||||||
name: "named tunnel unknown protocol",
|
name: "named tunnel unknown protocol",
|
||||||
protocol: "unknown",
|
protocol: "unknown",
|
||||||
fetchFunc: mockFetcher(100),
|
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "named tunnel fetch error",
|
name: "named tunnel fetch error",
|
||||||
protocol: "unknown",
|
protocol: "auto",
|
||||||
fetchFunc: mockFetcherWithError(),
|
fetchFunc: mockFetcher(true),
|
||||||
namedTunnelConfig: testNamedTunnelConfig,
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
expectedProtocol: HTTP2,
|
expectedProtocol: HTTP2,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -164,18 +206,20 @@ func TestNewProtocolSelector(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
|
t.Run(test.name, func(t *testing.T) {
|
||||||
if test.wantErr {
|
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
|
||||||
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
|
if test.wantErr {
|
||||||
} else {
|
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
|
||||||
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
|
} else {
|
||||||
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
|
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
|
||||||
fallback, ok := selector.Fallback()
|
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
|
||||||
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
|
fallback, ok := selector.Fallback()
|
||||||
if test.hasFallback {
|
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
|
||||||
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
|
if test.hasFallback {
|
||||||
|
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,27 +229,27 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, H2mux, selector.Current())
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 100
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 0
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
|
||||||
assert.Equal(t, H2mux, selector.Current())
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 100
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.err = fmt.Errorf("failed to fetch")
|
fetcher.err = fmt.Errorf("failed to fetch")
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = -1
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
|
||||||
fetcher.err = nil
|
fetcher.err = nil
|
||||||
assert.Equal(t, H2mux, selector.Current())
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 0
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
|
||||||
assert.Equal(t, H2mux, selector.Current())
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 100
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, QUIC, selector.Current())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
|
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
|
||||||
|
@ -214,35 +258,36 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 100
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 0
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.err = fmt.Errorf("failed to fetch")
|
fetcher.err = fmt.Errorf("failed to fetch")
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = -1
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
|
||||||
fetcher.err = nil
|
fetcher.err = nil
|
||||||
assert.Equal(t, H2mux, selector.Current())
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 0
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 100
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = -1
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
|
||||||
assert.Equal(t, H2mux, selector.Current())
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProtocolSelectorRefreshTTL(t *testing.T) {
|
func TestProtocolSelectorRefreshTTL(t *testing.T) {
|
||||||
fetcher := dynamicMockFetcher{percentage: 100}
|
fetcher := dynamicMockFetcher{}
|
||||||
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
|
||||||
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
|
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, QUIC, selector.Current())
|
||||||
|
|
||||||
fetcher.percentage = 0
|
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}}
|
||||||
assert.Equal(t, HTTP2, selector.Current())
|
assert.Equal(t, QUIC, selector.Current())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,45 +1,50 @@
|
||||||
package edgediscovery
|
package edgediscovery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
protocolRecord = "protocol.argotunnel.com"
|
protocolRecord = "protocol-v2.argotunnel.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord)
|
errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord)
|
||||||
)
|
)
|
||||||
|
|
||||||
func HTTP2Percentage() (int32, error) {
|
// ProtocolPercent represents a single Protocol Percentage combination.
|
||||||
|
type ProtocolPercent struct {
|
||||||
|
Protocol string `json:"protocol"`
|
||||||
|
Percentage int32 `json:"percentage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProtocolPercents represents the preferred distribution ratio of protocols when protocol isn't specified.
|
||||||
|
type ProtocolPercents []ProtocolPercent
|
||||||
|
|
||||||
|
// GetPercentage returns the threshold percentage of a single protocol.
|
||||||
|
func (p ProtocolPercents) GetPercentage(protocol string) int32 {
|
||||||
|
for _, protocolPercent := range p {
|
||||||
|
if strings.ToLower(protocolPercent.Protocol) == strings.ToLower(protocol) {
|
||||||
|
return protocolPercent.Percentage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProtocolPercentage returns the ratio of protocols and a specification ratio for their selection.
|
||||||
|
func ProtocolPercentage() (ProtocolPercents, error) {
|
||||||
records, err := net.LookupTXT(protocolRecord)
|
records, err := net.LookupTXT(protocolRecord)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(records) == 0 {
|
if len(records) == 0 {
|
||||||
return 0, errNoProtocolRecord
|
return nil, errNoProtocolRecord
|
||||||
}
|
}
|
||||||
return parseHTTP2Precentage(records[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// The record looks like http2=percentage
|
|
||||||
func parseHTTP2Precentage(record string) (int32, error) {
|
|
||||||
const key = "http2"
|
|
||||||
slices := strings.Split(record, "=")
|
|
||||||
if len(slices) != 2 {
|
|
||||||
return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record)
|
|
||||||
}
|
|
||||||
if slices[0] != key {
|
|
||||||
return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key)
|
|
||||||
}
|
|
||||||
percentage, err := strconv.ParseInt(slices[1], 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return int32(percentage), nil
|
|
||||||
|
|
||||||
|
var protocolsWithPercent ProtocolPercents
|
||||||
|
err = json.Unmarshal([]byte(records[0]), &protocolsWithPercent)
|
||||||
|
return protocolsWithPercent, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,75 +6,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHTTP2Percentage(t *testing.T) {
|
func TestProtocolPercentage(t *testing.T) {
|
||||||
_, err := HTTP2Percentage()
|
_, err := ProtocolPercentage()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseHTTP2Precentage(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
record string
|
|
||||||
percentage int32
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
record: "http2=-1",
|
|
||||||
percentage: -1,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2=0",
|
|
||||||
percentage: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2=50",
|
|
||||||
percentage: 50,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2=100",
|
|
||||||
percentage: 100,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2=1000",
|
|
||||||
percentage: 1000,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2=10.5",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2=10 h2mux=90",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2=ten",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
record: "h2mux=100",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
record: "http2=",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
p, err := parseHTTP2Precentage(test.record)
|
|
||||||
if test.wantErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, test.percentage, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -8,20 +8,18 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
"github.com/cloudflare/cloudflared/retry"
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dynamicMockFetcher struct {
|
type dynamicMockFetcher struct {
|
||||||
percentage int32
|
protocolPercents edgediscovery.ProtocolPercents
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
|
func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
|
||||||
return func() (int32, error) {
|
return func() (edgediscovery.ProtocolPercents, error) {
|
||||||
if dmf.err != nil {
|
return dmf.protocolPercents, dmf.err
|
||||||
return 0, dmf.err
|
|
||||||
}
|
|
||||||
return dmf.percentage, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +37,7 @@ func TestWaitForBackoffFallback(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
mockFetcher := dynamicMockFetcher{
|
mockFetcher := dynamicMockFetcher{
|
||||||
percentage: 0,
|
protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}},
|
||||||
}
|
}
|
||||||
warpRoutingEnabled := false
|
warpRoutingEnabled := false
|
||||||
protocolSelector, err := connection.NewProtocolSelector(
|
protocolSelector, err := connection.NewProtocolSelector(
|
||||||
|
|
Loading…
Reference in New Issue