TUN-5138: Switch to QUIC on auto protocol based on threshold
This commit is contained in:
		
							parent
							
								
									5a3c0fdffa
								
							
						
					
					
						commit
						ceb509ee98
					
				|  | @ -238,7 +238,7 @@ func prepareTunnelConfig( | |||
| 		log.Info().Msgf("Warp-routing is enabled") | ||||
| 	} | ||||
| 
 | ||||
| 	protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log) | ||||
| 	protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log) | ||||
| 	if err != nil { | ||||
| 		return nil, ingress.Ingress{}, err | ||||
| 	} | ||||
|  |  | |||
|  | @ -7,6 +7,8 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/rs/zerolog" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/edgediscovery" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -24,7 +26,7 @@ const ( | |||
| 
 | ||||
| var ( | ||||
| 	// ProtocolList represents a list of supported protocols for communication with the edge.
 | ||||
| 	ProtocolList = []Protocol{H2mux, HTTP2, QUIC} | ||||
| 	ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp} | ||||
| ) | ||||
| 
 | ||||
| type Protocol int64 | ||||
|  | @ -36,6 +38,12 @@ const ( | |||
| 	HTTP2 | ||||
| 	// QUIC is used only with named tunnels.
 | ||||
| 	QUIC | ||||
| 	// HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to
 | ||||
| 	// H2mux on HTTP2 failure to connect.
 | ||||
| 	HTTP2Warp | ||||
| 	//QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but
 | ||||
| 	// dont' want HTTP2 to fallback to H2mux
 | ||||
| 	QUICWarp | ||||
| ) | ||||
| 
 | ||||
| // Fallback returns the fallback protocol and whether the protocol has a fallback
 | ||||
|  | @ -45,8 +53,12 @@ func (p Protocol) fallback() (Protocol, bool) { | |||
| 		return 0, false | ||||
| 	case HTTP2: | ||||
| 		return H2mux, true | ||||
| 	case HTTP2Warp: | ||||
| 		return 0, false | ||||
| 	case QUIC: | ||||
| 		return HTTP2, true | ||||
| 	case QUICWarp: | ||||
| 		return HTTP2Warp, true | ||||
| 	default: | ||||
| 		return 0, false | ||||
| 	} | ||||
|  | @ -56,9 +68,9 @@ func (p Protocol) String() string { | |||
| 	switch p { | ||||
| 	case H2mux: | ||||
| 		return "h2mux" | ||||
| 	case HTTP2: | ||||
| 	case HTTP2, HTTP2Warp: | ||||
| 		return "http2" | ||||
| 	case QUIC: | ||||
| 	case QUIC, QUICWarp: | ||||
| 		return "quic" | ||||
| 	default: | ||||
| 		return fmt.Sprintf("unknown protocol") | ||||
|  | @ -71,11 +83,11 @@ func (p Protocol) TLSSettings() *TLSSettings { | |||
| 		return &TLSSettings{ | ||||
| 			ServerName: edgeH2muxTLSServerName, | ||||
| 		} | ||||
| 	case HTTP2: | ||||
| 	case HTTP2, HTTP2Warp: | ||||
| 		return &TLSSettings{ | ||||
| 			ServerName: edgeH2TLSServerName, | ||||
| 		} | ||||
| 	case QUIC: | ||||
| 	case QUIC, QUICWarp: | ||||
| 		return &TLSSettings{ | ||||
| 			ServerName: edgeQUICServerName, | ||||
| 			NextProtos: []string{"argotunnel"}, | ||||
|  | @ -109,8 +121,13 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) { | |||
| 
 | ||||
| type autoProtocolSelector struct { | ||||
| 	lock sync.RWMutex | ||||
| 
 | ||||
| 	current Protocol | ||||
| 	switchThrehold int32 | ||||
| 
 | ||||
| 	// protocolPool is desired protocols in the order of priority they should be picked in.
 | ||||
| 	protocolPool []Protocol | ||||
| 
 | ||||
| 	switchThreshold int32 | ||||
| 	fetchFunc       PercentageFetcher | ||||
| 	refreshAfter    time.Time | ||||
| 	ttl             time.Duration | ||||
|  | @ -119,14 +136,16 @@ type autoProtocolSelector struct { | |||
| 
 | ||||
| func newAutoProtocolSelector( | ||||
| 	current Protocol, | ||||
| 	switchThrehold int32, | ||||
| 	protocolPool []Protocol, | ||||
| 	switchThreshold int32, | ||||
| 	fetchFunc PercentageFetcher, | ||||
| 	ttl time.Duration, | ||||
| 	log *zerolog.Logger, | ||||
| ) *autoProtocolSelector { | ||||
| 	return &autoProtocolSelector{ | ||||
| 		current:         current, | ||||
| 		switchThrehold: switchThrehold, | ||||
| 		protocolPool:    protocolPool, | ||||
| 		switchThreshold: switchThreshold, | ||||
| 		fetchFunc:       fetchFunc, | ||||
| 		refreshAfter:    time.Now().Add(ttl), | ||||
| 		ttl:             ttl, | ||||
|  | @ -141,28 +160,39 @@ func (s *autoProtocolSelector) Current() Protocol { | |||
| 		return s.current | ||||
| 	} | ||||
| 
 | ||||
| 	percentage, err := s.fetchFunc() | ||||
| 	protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold) | ||||
| 	if err != nil { | ||||
| 		s.log.Err(err).Msg("Failed to refresh protocol") | ||||
| 		return s.current | ||||
| 	} | ||||
| 	s.current = protocol | ||||
| 
 | ||||
| 	if s.switchThrehold < percentage { | ||||
| 		s.current = HTTP2 | ||||
| 	} else { | ||||
| 		s.current = H2mux | ||||
| 	} | ||||
| 	s.refreshAfter = time.Now().Add(s.ttl) | ||||
| 	return s.current | ||||
| } | ||||
| 
 | ||||
| func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) { | ||||
| 	protocolPercentages, err := fetchFunc() | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	for _, protocol := range protocolPool { | ||||
| 		protocolPercentage := protocolPercentages.GetPercentage(protocol.String()) | ||||
| 		if protocolPercentage > switchThreshold { | ||||
| 			return protocol, nil | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return protocolPool[len(protocolPool)-1], nil | ||||
| } | ||||
| 
 | ||||
| func (s *autoProtocolSelector) Fallback() (Protocol, bool) { | ||||
| 	s.lock.RLock() | ||||
| 	defer s.lock.RUnlock() | ||||
| 	return s.current.fallback() | ||||
| } | ||||
| 
 | ||||
| type PercentageFetcher func() (int32, error) | ||||
| type PercentageFetcher func() (edgediscovery.ProtocolPercents, error) | ||||
| 
 | ||||
| func NewProtocolSelector( | ||||
| 	protocolFlag string, | ||||
|  | @ -179,22 +209,34 @@ func NewProtocolSelector( | |||
| 		}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// warp routing cannot be served over h2mux connections
 | ||||
| 	if warpRoutingEnabled { | ||||
| 		if protocolFlag == H2mux.String() { | ||||
| 			log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") | ||||
| 		} | ||||
| 
 | ||||
| 		if protocolFlag == QUIC.String() { | ||||
| 			return &staticProtocolSelector{ | ||||
| 				current: QUIC, | ||||
| 			}, nil | ||||
| 		} | ||||
| 	threshold := switchThreshold(namedTunnel.Credentials.AccountTag) | ||||
| 	fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold) | ||||
| 	if err != nil { | ||||
| 		log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") | ||||
| 		return &staticProtocolSelector{ | ||||
| 			current: HTTP2, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	if warpRoutingEnabled { | ||||
| 		if protocolFlag == H2mux.String() || fetchedProtocol == H2mux { | ||||
| 			log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") | ||||
| 			protocolFlag = HTTP2.String() | ||||
| 			fetchedProtocol = HTTP2Warp | ||||
| 		} | ||||
| 		return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) | ||||
| 	} | ||||
| 
 | ||||
| 	return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) | ||||
| } | ||||
| 
 | ||||
| func selectNamedTunnelProtocols( | ||||
| 	protocolFlag string, | ||||
| 	fetchFunc PercentageFetcher, | ||||
| 	ttl time.Duration, | ||||
| 	log *zerolog.Logger, | ||||
| 	threshold int32, | ||||
| 	protocol Protocol, | ||||
| ) (ProtocolSelector, error) { | ||||
| 	if protocolFlag == H2mux.String() { | ||||
| 		return &staticProtocolSelector{ | ||||
| 			current: H2mux, | ||||
|  | @ -202,31 +244,41 @@ func NewProtocolSelector( | |||
| 	} | ||||
| 
 | ||||
| 	if protocolFlag == QUIC.String() { | ||||
| 		return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||
| 		return newAutoProtocolSelector(QUIC, []Protocol{QUIC, HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||
| 	} | ||||
| 
 | ||||
| 	http2Percentage, err := fetchFunc() | ||||
| 	if err != nil { | ||||
| 		log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") | ||||
| 		return &staticProtocolSelector{ | ||||
| 			current: HTTP2, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	if protocolFlag == HTTP2.String() { | ||||
| 		if http2Percentage < 0 { | ||||
| 			return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||
| 		} | ||||
| 		return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||
| 		return newAutoProtocolSelector(HTTP2, []Protocol{HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||
| 	} | ||||
| 
 | ||||
| 	if protocolFlag != autoSelectFlag { | ||||
| 		return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) | ||||
| 	} | ||||
| 	threshold := switchThreshold(namedTunnel.Credentials.AccountTag) | ||||
| 	if threshold < http2Percentage { | ||||
| 		return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil | ||||
| 
 | ||||
| 	return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil | ||||
| } | ||||
| 
 | ||||
| func selectWarpRoutingProtocols( | ||||
| 	protocolFlag string, | ||||
| 	fetchFunc PercentageFetcher, | ||||
| 	ttl time.Duration, | ||||
| 	log *zerolog.Logger, | ||||
| 	threshold int32, | ||||
| 	protocol Protocol, | ||||
| ) (ProtocolSelector, error) { | ||||
| 	if protocolFlag == QUIC.String() { | ||||
| 		return newAutoProtocolSelector(QUICWarp, []Protocol{QUICWarp, HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||
| 	} | ||||
| 	return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil | ||||
| 
 | ||||
| 	if protocolFlag == HTTP2.String() { | ||||
| 		return newAutoProtocolSelector(HTTP2Warp, []Protocol{HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||
| 	} | ||||
| 
 | ||||
| 	if protocolFlag != autoSelectFlag { | ||||
| 		return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) | ||||
| 	} | ||||
| 
 | ||||
| 	return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil | ||||
| } | ||||
| 
 | ||||
| func switchThreshold(accountTag string) int32 { | ||||
|  |  | |||
|  | @ -6,6 +6,8 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/edgediscovery" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
|  | @ -21,29 +23,23 @@ var ( | |||
| 	} | ||||
| ) | ||||
| 
 | ||||
| func mockFetcher(percentage int32) PercentageFetcher { | ||||
| 	return func() (int32, error) { | ||||
| 		return percentage, nil | ||||
| func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher { | ||||
| 	return func() (edgediscovery.ProtocolPercents, error) { | ||||
| 		if getError { | ||||
| 			return nil, fmt.Errorf("failed to fetch precentage") | ||||
| 		} | ||||
| } | ||||
| 
 | ||||
| func mockFetcherWithError() PercentageFetcher { | ||||
| 	return func() (int32, error) { | ||||
| 		return 0, fmt.Errorf("failed to fetch precentage") | ||||
| 		return protocolPercent, nil | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type dynamicMockFetcher struct { | ||||
| 	percentage int32 | ||||
| 	protocolPercents edgediscovery.ProtocolPercents | ||||
| 	err              error | ||||
| } | ||||
| 
 | ||||
| func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { | ||||
| 	return func() (int32, error) { | ||||
| 		if dmf.err != nil { | ||||
| 			return 0, dmf.err | ||||
| 		} | ||||
| 		return dmf.percentage, nil | ||||
| 	return func() (edgediscovery.ProtocolPercents, error) { | ||||
| 		return dmf.protocolPercents, dmf.err | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -69,6 +65,7 @@ func TestNewProtocolSelector(t *testing.T) { | |||
| 			name:              "named tunnel over h2mux", | ||||
| 			protocol:          "h2mux", | ||||
| 			expectedProtocol:  H2mux, | ||||
| 			fetchFunc:         func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
|  | @ -77,28 +74,38 @@ func TestNewProtocolSelector(t *testing.T) { | |||
| 			expectedProtocol:  HTTP2, | ||||
| 			hasFallback:       true, | ||||
| 			expectedFallback:  H2mux, | ||||
| 			fetchFunc:         mockFetcher(0), | ||||
| 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:              "named tunnel http2 disabled", | ||||
| 			protocol:          "http2", | ||||
| 			expectedProtocol:  H2mux, | ||||
| 			fetchFunc:         mockFetcher(-1), | ||||
| 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:             "named tunnel quic disabled", | ||||
| 			protocol:         "quic", | ||||
| 			expectedProtocol: HTTP2, | ||||
| 			// Hasfallback true is because if http2 fails, then we further fallback to h2mux.
 | ||||
| 			hasFallback:       true, | ||||
| 			expectedFallback:  H2mux, | ||||
| 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:              "named tunnel auto all http2 disabled", | ||||
| 			protocol:          "auto", | ||||
| 			expectedProtocol:  H2mux, | ||||
| 			fetchFunc:         mockFetcher(-1), | ||||
| 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:              "named tunnel auto to h2mux", | ||||
| 			protocol:          "auto", | ||||
| 			expectedProtocol:  H2mux, | ||||
| 			fetchFunc:         mockFetcher(0), | ||||
| 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
|  | @ -107,36 +114,71 @@ func TestNewProtocolSelector(t *testing.T) { | |||
| 			expectedProtocol:  HTTP2, | ||||
| 			hasFallback:       true, | ||||
| 			expectedFallback:  H2mux, | ||||
| 			fetchFunc:         mockFetcher(100), | ||||
| 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:              "named tunnel auto to quic", | ||||
| 			protocol:          "auto", | ||||
| 			expectedProtocol:  QUIC, | ||||
| 			hasFallback:       true, | ||||
| 			expectedFallback:  HTTP2, | ||||
| 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:               "warp routing requesting h2mux", | ||||
| 			protocol:           "h2mux", | ||||
| 			expectedProtocol:   HTTP2, | ||||
| 			expectedProtocol:   HTTP2Warp, | ||||
| 			hasFallback:        false, | ||||
| 			expectedFallback:   H2mux, | ||||
| 			fetchFunc:          mockFetcher(100), | ||||
| 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||
| 			warpRoutingEnabled: true, | ||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:               "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", | ||||
| 			protocol:           "h2mux", | ||||
| 			expectedProtocol:   HTTP2Warp, | ||||
| 			hasFallback:        false, | ||||
| 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), | ||||
| 			warpRoutingEnabled: true, | ||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:               "warp routing http2", | ||||
| 			protocol:           "http2", | ||||
| 			expectedProtocol:   HTTP2, | ||||
| 			expectedProtocol:   HTTP2Warp, | ||||
| 			hasFallback:        false, | ||||
| 			expectedFallback:   H2mux, | ||||
| 			fetchFunc:          mockFetcher(100), | ||||
| 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||
| 			warpRoutingEnabled: true, | ||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:               "warp routing quic", | ||||
| 			protocol:           "quic", | ||||
| 			expectedProtocol:   QUICWarp, | ||||
| 			hasFallback:        true, | ||||
| 			expectedFallback:   HTTP2Warp, | ||||
| 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), | ||||
| 			warpRoutingEnabled: true, | ||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:               "warp routing auto", | ||||
| 			protocol:           "auto", | ||||
| 			expectedProtocol:   HTTP2, | ||||
| 			expectedProtocol:   HTTP2Warp, | ||||
| 			hasFallback:        false, | ||||
| 			expectedFallback:   H2mux, | ||||
| 			fetchFunc:          mockFetcher(100), | ||||
| 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||
| 			warpRoutingEnabled: true, | ||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:               "warp routing auto- quic", | ||||
| 			protocol:           "auto", | ||||
| 			expectedProtocol:   QUICWarp, | ||||
| 			hasFallback:        true, | ||||
| 			expectedFallback:   HTTP2Warp, | ||||
| 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), | ||||
| 			warpRoutingEnabled: true, | ||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | ||||
| 		}, | ||||
|  | @ -149,14 +191,14 @@ func TestNewProtocolSelector(t *testing.T) { | |||
| 		{ | ||||
| 			name:              "named tunnel unknown protocol", | ||||
| 			protocol:          "unknown", | ||||
| 			fetchFunc:         mockFetcher(100), | ||||
| 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 			wantErr:           true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:              "named tunnel fetch error", | ||||
| 			protocol:          "unknown", | ||||
| 			fetchFunc:         mockFetcherWithError(), | ||||
| 			protocol:          "auto", | ||||
| 			fetchFunc:         mockFetcher(true), | ||||
| 			namedTunnelConfig: testNamedTunnelConfig, | ||||
| 			expectedProtocol:  HTTP2, | ||||
| 			wantErr:           false, | ||||
|  | @ -164,6 +206,7 @@ func TestNewProtocolSelector(t *testing.T) { | |||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		t.Run(test.name, func(t *testing.T) { | ||||
| 			selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) | ||||
| 			if test.wantErr { | ||||
| 				assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) | ||||
|  | @ -176,6 +219,7 @@ func TestNewProtocolSelector(t *testing.T) { | |||
| 					assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -185,27 +229,27 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { | |||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, H2mux, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 100 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 0 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} | ||||
| 	assert.Equal(t, H2mux, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 100 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.err = fmt.Errorf("failed to fetch") | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = -1 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} | ||||
| 	fetcher.err = nil | ||||
| 	assert.Equal(t, H2mux, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 0 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} | ||||
| 	assert.Equal(t, H2mux, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 100 | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} | ||||
| 	assert.Equal(t, QUIC, selector.Current()) | ||||
| } | ||||
| 
 | ||||
| func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { | ||||
|  | @ -214,35 +258,36 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { | |||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 100 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 0 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.err = fmt.Errorf("failed to fetch") | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = -1 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} | ||||
| 	fetcher.err = nil | ||||
| 	assert.Equal(t, H2mux, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 0 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 100 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = -1 | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} | ||||
| 	assert.Equal(t, H2mux, selector.Current()) | ||||
| } | ||||
| 
 | ||||
| func TestProtocolSelectorRefreshTTL(t *testing.T) { | ||||
| 	fetcher := dynamicMockFetcher{percentage: 100} | ||||
| 	fetcher := dynamicMockFetcher{} | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} | ||||
| 	selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 	assert.Equal(t, QUIC, selector.Current()) | ||||
| 
 | ||||
| 	fetcher.percentage = 0 | ||||
| 	assert.Equal(t, HTTP2, selector.Current()) | ||||
| 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}} | ||||
| 	assert.Equal(t, QUIC, selector.Current()) | ||||
| } | ||||
|  |  | |||
|  | @ -1,45 +1,50 @@ | |||
| package edgediscovery | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	protocolRecord = "protocol.argotunnel.com" | ||||
| 	protocolRecord = "protocol-v2.argotunnel.com" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) | ||||
| ) | ||||
| 
 | ||||
| func HTTP2Percentage() (int32, error) { | ||||
| // ProtocolPercent represents a single Protocol Percentage combination.
 | ||||
| type ProtocolPercent struct { | ||||
| 	Protocol   string `json:"protocol"` | ||||
| 	Percentage int32  `json:"percentage"` | ||||
| } | ||||
| 
 | ||||
| // ProtocolPercents represents the preferred distribution ratio of protocols when protocol isn't specified.
 | ||||
| type ProtocolPercents []ProtocolPercent | ||||
| 
 | ||||
| // GetPercentage returns the threshold percentage of a single protocol.
 | ||||
| func (p ProtocolPercents) GetPercentage(protocol string) int32 { | ||||
| 	for _, protocolPercent := range p { | ||||
| 		if strings.ToLower(protocolPercent.Protocol) == strings.ToLower(protocol) { | ||||
| 			return protocolPercent.Percentage | ||||
| 		} | ||||
| 	} | ||||
| 	return 0 | ||||
| } | ||||
| 
 | ||||
| // ProtocolPercentage returns the ratio of protocols and a specification ratio for their selection.
 | ||||
| func ProtocolPercentage() (ProtocolPercents, error) { | ||||
| 	records, err := net.LookupTXT(protocolRecord) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if len(records) == 0 { | ||||
| 		return 0, errNoProtocolRecord | ||||
| 		return nil, errNoProtocolRecord | ||||
| 	} | ||||
| 	return parseHTTP2Precentage(records[0]) | ||||
| } | ||||
| 
 | ||||
| // The record looks like http2=percentage
 | ||||
| func parseHTTP2Precentage(record string) (int32, error) { | ||||
| 	const key = "http2" | ||||
| 	slices := strings.Split(record, "=") | ||||
| 	if len(slices) != 2 { | ||||
| 		return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record) | ||||
| 	} | ||||
| 	if slices[0] != key { | ||||
| 		return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key) | ||||
| 	} | ||||
| 	percentage, err := strconv.ParseInt(slices[1], 10, 32) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	return int32(percentage), nil | ||||
| 
 | ||||
| 	var protocolsWithPercent ProtocolPercents | ||||
| 	err = json.Unmarshal([]byte(records[0]), &protocolsWithPercent) | ||||
| 	return protocolsWithPercent, err | ||||
| } | ||||
|  |  | |||
|  | @ -6,75 +6,7 @@ import ( | |||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| func TestHTTP2Percentage(t *testing.T) { | ||||
| 	_, err := HTTP2Percentage() | ||||
| func TestProtocolPercentage(t *testing.T) { | ||||
| 	_, err := ProtocolPercentage() | ||||
| 	assert.NoError(t, err) | ||||
| } | ||||
| 
 | ||||
| func TestParseHTTP2Precentage(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		record     string | ||||
| 		percentage int32 | ||||
| 		wantErr    bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			record:     "http2=-1", | ||||
| 			percentage: -1, | ||||
| 			wantErr:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:     "http2=0", | ||||
| 			percentage: 0, | ||||
| 			wantErr:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:     "http2=50", | ||||
| 			percentage: 50, | ||||
| 			wantErr:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:     "http2=100", | ||||
| 			percentage: 100, | ||||
| 			wantErr:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:     "http2=1000", | ||||
| 			percentage: 1000, | ||||
| 			wantErr:    false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:  "http2=10.5", | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:  "http2=10 h2mux=90", | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:  "http2=ten", | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 
 | ||||
| 		{ | ||||
| 			record:  "h2mux=100", | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:  "http2", | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			record:  "http2=", | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		p, err := parseHTTP2Precentage(test.record) | ||||
| 		if test.wantErr { | ||||
| 			assert.Error(t, err) | ||||
| 		} else { | ||||
| 			assert.Equal(t, test.percentage, p) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -351,7 +351,7 @@ func serveTunnel( | |||
| 	) | ||||
| 
 | ||||
| 	switch protocol { | ||||
| 	case connection.QUIC: | ||||
| 	case connection.QUIC, connection.QUICWarp: | ||||
| 		connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries())) | ||||
| 		return ServeQUIC(ctx, | ||||
| 			addr.UDP, | ||||
|  | @ -361,7 +361,7 @@ func serveTunnel( | |||
| 			reconnectCh, | ||||
| 			gracefulShutdownC) | ||||
| 
 | ||||
| 	case connection.HTTP2: | ||||
| 	case connection.HTTP2, connection.HTTP2Warp: | ||||
| 		edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) | ||||
| 		if err != nil { | ||||
| 			connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge") | ||||
|  |  | |||
|  | @ -8,20 +8,18 @@ import ( | |||
| 	"github.com/stretchr/testify/assert" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/connection" | ||||
| 	"github.com/cloudflare/cloudflared/edgediscovery" | ||||
| 	"github.com/cloudflare/cloudflared/retry" | ||||
| ) | ||||
| 
 | ||||
| type dynamicMockFetcher struct { | ||||
| 	percentage int32 | ||||
| 	protocolPercents edgediscovery.ProtocolPercents | ||||
| 	err              error | ||||
| } | ||||
| 
 | ||||
| func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { | ||||
| 	return func() (int32, error) { | ||||
| 		if dmf.err != nil { | ||||
| 			return 0, dmf.err | ||||
| 		} | ||||
| 		return dmf.percentage, nil | ||||
| 	return func() (edgediscovery.ProtocolPercents, error) { | ||||
| 		return dmf.protocolPercents, dmf.err | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -39,7 +37,7 @@ func TestWaitForBackoffFallback(t *testing.T) { | |||
| 		}, | ||||
| 	} | ||||
| 	mockFetcher := dynamicMockFetcher{ | ||||
| 		percentage: 0, | ||||
| 		protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}, | ||||
| 	} | ||||
| 	warpRoutingEnabled := false | ||||
| 	protocolSelector, err := connection.NewProtocolSelector( | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue