TUN-5138: Switch to QUIC on auto protocol based on threshold
This commit is contained in:
		
							parent
							
								
									5a3c0fdffa
								
							
						
					
					
						commit
						ceb509ee98
					
				|  | @ -238,7 +238,7 @@ func prepareTunnelConfig( | ||||||
| 		log.Info().Msgf("Warp-routing is enabled") | 		log.Info().Msgf("Warp-routing is enabled") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log) | 	protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, ingress.Ingress{}, err | 		return nil, ingress.Ingress{}, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -7,6 +7,8 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/rs/zerolog" | 	"github.com/rs/zerolog" | ||||||
|  | 
 | ||||||
|  | 	"github.com/cloudflare/cloudflared/edgediscovery" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | @ -24,7 +26,7 @@ const ( | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
| 	// ProtocolList represents a list of supported protocols for communication with the edge.
 | 	// ProtocolList represents a list of supported protocols for communication with the edge.
 | ||||||
| 	ProtocolList = []Protocol{H2mux, HTTP2, QUIC} | 	ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp} | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Protocol int64 | type Protocol int64 | ||||||
|  | @ -36,6 +38,12 @@ const ( | ||||||
| 	HTTP2 | 	HTTP2 | ||||||
| 	// QUIC is used only with named tunnels.
 | 	// QUIC is used only with named tunnels.
 | ||||||
| 	QUIC | 	QUIC | ||||||
|  | 	// HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to
 | ||||||
|  | 	// H2mux on HTTP2 failure to connect.
 | ||||||
|  | 	HTTP2Warp | ||||||
|  | 	//QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but
 | ||||||
|  | 	// dont' want HTTP2 to fallback to H2mux
 | ||||||
|  | 	QUICWarp | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Fallback returns the fallback protocol and whether the protocol has a fallback
 | // Fallback returns the fallback protocol and whether the protocol has a fallback
 | ||||||
|  | @ -45,8 +53,12 @@ func (p Protocol) fallback() (Protocol, bool) { | ||||||
| 		return 0, false | 		return 0, false | ||||||
| 	case HTTP2: | 	case HTTP2: | ||||||
| 		return H2mux, true | 		return H2mux, true | ||||||
|  | 	case HTTP2Warp: | ||||||
|  | 		return 0, false | ||||||
| 	case QUIC: | 	case QUIC: | ||||||
| 		return HTTP2, true | 		return HTTP2, true | ||||||
|  | 	case QUICWarp: | ||||||
|  | 		return HTTP2Warp, true | ||||||
| 	default: | 	default: | ||||||
| 		return 0, false | 		return 0, false | ||||||
| 	} | 	} | ||||||
|  | @ -56,9 +68,9 @@ func (p Protocol) String() string { | ||||||
| 	switch p { | 	switch p { | ||||||
| 	case H2mux: | 	case H2mux: | ||||||
| 		return "h2mux" | 		return "h2mux" | ||||||
| 	case HTTP2: | 	case HTTP2, HTTP2Warp: | ||||||
| 		return "http2" | 		return "http2" | ||||||
| 	case QUIC: | 	case QUIC, QUICWarp: | ||||||
| 		return "quic" | 		return "quic" | ||||||
| 	default: | 	default: | ||||||
| 		return fmt.Sprintf("unknown protocol") | 		return fmt.Sprintf("unknown protocol") | ||||||
|  | @ -71,11 +83,11 @@ func (p Protocol) TLSSettings() *TLSSettings { | ||||||
| 		return &TLSSettings{ | 		return &TLSSettings{ | ||||||
| 			ServerName: edgeH2muxTLSServerName, | 			ServerName: edgeH2muxTLSServerName, | ||||||
| 		} | 		} | ||||||
| 	case HTTP2: | 	case HTTP2, HTTP2Warp: | ||||||
| 		return &TLSSettings{ | 		return &TLSSettings{ | ||||||
| 			ServerName: edgeH2TLSServerName, | 			ServerName: edgeH2TLSServerName, | ||||||
| 		} | 		} | ||||||
| 	case QUIC: | 	case QUIC, QUICWarp: | ||||||
| 		return &TLSSettings{ | 		return &TLSSettings{ | ||||||
| 			ServerName: edgeQUICServerName, | 			ServerName: edgeQUICServerName, | ||||||
| 			NextProtos: []string{"argotunnel"}, | 			NextProtos: []string{"argotunnel"}, | ||||||
|  | @ -109,8 +121,13 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) { | ||||||
| 
 | 
 | ||||||
| type autoProtocolSelector struct { | type autoProtocolSelector struct { | ||||||
| 	lock sync.RWMutex | 	lock sync.RWMutex | ||||||
|  | 
 | ||||||
| 	current Protocol | 	current Protocol | ||||||
| 	switchThrehold int32 | 
 | ||||||
|  | 	// protocolPool is desired protocols in the order of priority they should be picked in.
 | ||||||
|  | 	protocolPool []Protocol | ||||||
|  | 
 | ||||||
|  | 	switchThreshold int32 | ||||||
| 	fetchFunc       PercentageFetcher | 	fetchFunc       PercentageFetcher | ||||||
| 	refreshAfter    time.Time | 	refreshAfter    time.Time | ||||||
| 	ttl             time.Duration | 	ttl             time.Duration | ||||||
|  | @ -119,14 +136,16 @@ type autoProtocolSelector struct { | ||||||
| 
 | 
 | ||||||
| func newAutoProtocolSelector( | func newAutoProtocolSelector( | ||||||
| 	current Protocol, | 	current Protocol, | ||||||
| 	switchThrehold int32, | 	protocolPool []Protocol, | ||||||
|  | 	switchThreshold int32, | ||||||
| 	fetchFunc PercentageFetcher, | 	fetchFunc PercentageFetcher, | ||||||
| 	ttl time.Duration, | 	ttl time.Duration, | ||||||
| 	log *zerolog.Logger, | 	log *zerolog.Logger, | ||||||
| ) *autoProtocolSelector { | ) *autoProtocolSelector { | ||||||
| 	return &autoProtocolSelector{ | 	return &autoProtocolSelector{ | ||||||
| 		current:         current, | 		current:         current, | ||||||
| 		switchThrehold: switchThrehold, | 		protocolPool:    protocolPool, | ||||||
|  | 		switchThreshold: switchThreshold, | ||||||
| 		fetchFunc:       fetchFunc, | 		fetchFunc:       fetchFunc, | ||||||
| 		refreshAfter:    time.Now().Add(ttl), | 		refreshAfter:    time.Now().Add(ttl), | ||||||
| 		ttl:             ttl, | 		ttl:             ttl, | ||||||
|  | @ -141,28 +160,39 @@ func (s *autoProtocolSelector) Current() Protocol { | ||||||
| 		return s.current | 		return s.current | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	percentage, err := s.fetchFunc() | 	protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		s.log.Err(err).Msg("Failed to refresh protocol") | 		s.log.Err(err).Msg("Failed to refresh protocol") | ||||||
| 		return s.current | 		return s.current | ||||||
| 	} | 	} | ||||||
|  | 	s.current = protocol | ||||||
| 
 | 
 | ||||||
| 	if s.switchThrehold < percentage { |  | ||||||
| 		s.current = HTTP2 |  | ||||||
| 	} else { |  | ||||||
| 		s.current = H2mux |  | ||||||
| 	} |  | ||||||
| 	s.refreshAfter = time.Now().Add(s.ttl) | 	s.refreshAfter = time.Now().Add(s.ttl) | ||||||
| 	return s.current | 	return s.current | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) { | ||||||
|  | 	protocolPercentages, err := fetchFunc() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, err | ||||||
|  | 	} | ||||||
|  | 	for _, protocol := range protocolPool { | ||||||
|  | 		protocolPercentage := protocolPercentages.GetPercentage(protocol.String()) | ||||||
|  | 		if protocolPercentage > switchThreshold { | ||||||
|  | 			return protocol, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return protocolPool[len(protocolPool)-1], nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (s *autoProtocolSelector) Fallback() (Protocol, bool) { | func (s *autoProtocolSelector) Fallback() (Protocol, bool) { | ||||||
| 	s.lock.RLock() | 	s.lock.RLock() | ||||||
| 	defer s.lock.RUnlock() | 	defer s.lock.RUnlock() | ||||||
| 	return s.current.fallback() | 	return s.current.fallback() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type PercentageFetcher func() (int32, error) | type PercentageFetcher func() (edgediscovery.ProtocolPercents, error) | ||||||
| 
 | 
 | ||||||
| func NewProtocolSelector( | func NewProtocolSelector( | ||||||
| 	protocolFlag string, | 	protocolFlag string, | ||||||
|  | @ -179,22 +209,34 @@ func NewProtocolSelector( | ||||||
| 		}, nil | 		}, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// warp routing cannot be served over h2mux connections
 | 	threshold := switchThreshold(namedTunnel.Credentials.AccountTag) | ||||||
| 	if warpRoutingEnabled { | 	fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold) | ||||||
| 		if protocolFlag == H2mux.String() { | 	if err != nil { | ||||||
| 			log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") | 		log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if protocolFlag == QUIC.String() { |  | ||||||
| 			return &staticProtocolSelector{ |  | ||||||
| 				current: QUIC, |  | ||||||
| 			}, nil |  | ||||||
| 		} |  | ||||||
| 		return &staticProtocolSelector{ | 		return &staticProtocolSelector{ | ||||||
| 			current: HTTP2, | 			current: HTTP2, | ||||||
| 		}, nil | 		}, nil | ||||||
| 	} | 	} | ||||||
|  | 	if warpRoutingEnabled { | ||||||
|  | 		if protocolFlag == H2mux.String() || fetchedProtocol == H2mux { | ||||||
|  | 			log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") | ||||||
|  | 			protocolFlag = HTTP2.String() | ||||||
|  | 			fetchedProtocol = HTTP2Warp | ||||||
|  | 		} | ||||||
|  | 		return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
|  | 	return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func selectNamedTunnelProtocols( | ||||||
|  | 	protocolFlag string, | ||||||
|  | 	fetchFunc PercentageFetcher, | ||||||
|  | 	ttl time.Duration, | ||||||
|  | 	log *zerolog.Logger, | ||||||
|  | 	threshold int32, | ||||||
|  | 	protocol Protocol, | ||||||
|  | ) (ProtocolSelector, error) { | ||||||
| 	if protocolFlag == H2mux.String() { | 	if protocolFlag == H2mux.String() { | ||||||
| 		return &staticProtocolSelector{ | 		return &staticProtocolSelector{ | ||||||
| 			current: H2mux, | 			current: H2mux, | ||||||
|  | @ -202,31 +244,41 @@ func NewProtocolSelector( | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if protocolFlag == QUIC.String() { | 	if protocolFlag == QUIC.String() { | ||||||
| 		return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | 		return newAutoProtocolSelector(QUIC, []Protocol{QUIC, HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	http2Percentage, err := fetchFunc() |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") |  | ||||||
| 		return &staticProtocolSelector{ |  | ||||||
| 			current: HTTP2, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	if protocolFlag == HTTP2.String() { | 	if protocolFlag == HTTP2.String() { | ||||||
| 		if http2Percentage < 0 { | 		return newAutoProtocolSelector(HTTP2, []Protocol{HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||||
| 			return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil |  | ||||||
| 		} |  | ||||||
| 		return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if protocolFlag != autoSelectFlag { | 	if protocolFlag != autoSelectFlag { | ||||||
| 		return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) | 		return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) | ||||||
| 	} | 	} | ||||||
| 	threshold := switchThreshold(namedTunnel.Credentials.AccountTag) | 
 | ||||||
| 	if threshold < http2Percentage { | 	return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil | ||||||
| 		return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil | } | ||||||
|  | 
 | ||||||
|  | func selectWarpRoutingProtocols( | ||||||
|  | 	protocolFlag string, | ||||||
|  | 	fetchFunc PercentageFetcher, | ||||||
|  | 	ttl time.Duration, | ||||||
|  | 	log *zerolog.Logger, | ||||||
|  | 	threshold int32, | ||||||
|  | 	protocol Protocol, | ||||||
|  | ) (ProtocolSelector, error) { | ||||||
|  | 	if protocolFlag == QUIC.String() { | ||||||
|  | 		return newAutoProtocolSelector(QUICWarp, []Protocol{QUICWarp, HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||||
| 	} | 	} | ||||||
| 	return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil | 
 | ||||||
|  | 	if protocolFlag == HTTP2.String() { | ||||||
|  | 		return newAutoProtocolSelector(HTTP2Warp, []Protocol{HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if protocolFlag != autoSelectFlag { | ||||||
|  | 		return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func switchThreshold(accountTag string) int32 { | func switchThreshold(accountTag string) int32 { | ||||||
|  |  | ||||||
|  | @ -6,6 +6,8 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
|  | 
 | ||||||
|  | 	"github.com/cloudflare/cloudflared/edgediscovery" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | @ -21,29 +23,23 @@ var ( | ||||||
| 	} | 	} | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func mockFetcher(percentage int32) PercentageFetcher { | func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher { | ||||||
| 	return func() (int32, error) { | 	return func() (edgediscovery.ProtocolPercents, error) { | ||||||
| 		return percentage, nil | 		if getError { | ||||||
|  | 			return nil, fmt.Errorf("failed to fetch precentage") | ||||||
| 		} | 		} | ||||||
| } | 		return protocolPercent, nil | ||||||
| 
 |  | ||||||
| func mockFetcherWithError() PercentageFetcher { |  | ||||||
| 	return func() (int32, error) { |  | ||||||
| 		return 0, fmt.Errorf("failed to fetch precentage") |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type dynamicMockFetcher struct { | type dynamicMockFetcher struct { | ||||||
| 	percentage int32 | 	protocolPercents edgediscovery.ProtocolPercents | ||||||
| 	err              error | 	err              error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { | func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { | ||||||
| 	return func() (int32, error) { | 	return func() (edgediscovery.ProtocolPercents, error) { | ||||||
| 		if dmf.err != nil { | 		return dmf.protocolPercents, dmf.err | ||||||
| 			return 0, dmf.err |  | ||||||
| 		} |  | ||||||
| 		return dmf.percentage, nil |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -69,6 +65,7 @@ func TestNewProtocolSelector(t *testing.T) { | ||||||
| 			name:              "named tunnel over h2mux", | 			name:              "named tunnel over h2mux", | ||||||
| 			protocol:          "h2mux", | 			protocol:          "h2mux", | ||||||
| 			expectedProtocol:  H2mux, | 			expectedProtocol:  H2mux, | ||||||
|  | 			fetchFunc:         func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, | ||||||
| 			namedTunnelConfig: testNamedTunnelConfig, | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
|  | @ -77,28 +74,38 @@ func TestNewProtocolSelector(t *testing.T) { | ||||||
| 			expectedProtocol:  HTTP2, | 			expectedProtocol:  HTTP2, | ||||||
| 			hasFallback:       true, | 			hasFallback:       true, | ||||||
| 			expectedFallback:  H2mux, | 			expectedFallback:  H2mux, | ||||||
| 			fetchFunc:         mockFetcher(0), | 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), | ||||||
| 			namedTunnelConfig: testNamedTunnelConfig, | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:              "named tunnel http2 disabled", | 			name:              "named tunnel http2 disabled", | ||||||
| 			protocol:          "http2", | 			protocol:          "http2", | ||||||
| 			expectedProtocol:  H2mux, | 			expectedProtocol:  H2mux, | ||||||
| 			fetchFunc:         mockFetcher(-1), | 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), | ||||||
|  | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:             "named tunnel quic disabled", | ||||||
|  | 			protocol:         "quic", | ||||||
|  | 			expectedProtocol: HTTP2, | ||||||
|  | 			// Hasfallback true is because if http2 fails, then we further fallback to h2mux.
 | ||||||
|  | 			hasFallback:       true, | ||||||
|  | 			expectedFallback:  H2mux, | ||||||
|  | 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), | ||||||
| 			namedTunnelConfig: testNamedTunnelConfig, | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:              "named tunnel auto all http2 disabled", | 			name:              "named tunnel auto all http2 disabled", | ||||||
| 			protocol:          "auto", | 			protocol:          "auto", | ||||||
| 			expectedProtocol:  H2mux, | 			expectedProtocol:  H2mux, | ||||||
| 			fetchFunc:         mockFetcher(-1), | 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), | ||||||
| 			namedTunnelConfig: testNamedTunnelConfig, | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:              "named tunnel auto to h2mux", | 			name:              "named tunnel auto to h2mux", | ||||||
| 			protocol:          "auto", | 			protocol:          "auto", | ||||||
| 			expectedProtocol:  H2mux, | 			expectedProtocol:  H2mux, | ||||||
| 			fetchFunc:         mockFetcher(0), | 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), | ||||||
| 			namedTunnelConfig: testNamedTunnelConfig, | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
|  | @ -107,36 +114,71 @@ func TestNewProtocolSelector(t *testing.T) { | ||||||
| 			expectedProtocol:  HTTP2, | 			expectedProtocol:  HTTP2, | ||||||
| 			hasFallback:       true, | 			hasFallback:       true, | ||||||
| 			expectedFallback:  H2mux, | 			expectedFallback:  H2mux, | ||||||
| 			fetchFunc:         mockFetcher(100), | 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||||
|  | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:              "named tunnel auto to quic", | ||||||
|  | 			protocol:          "auto", | ||||||
|  | 			expectedProtocol:  QUIC, | ||||||
|  | 			hasFallback:       true, | ||||||
|  | 			expectedFallback:  HTTP2, | ||||||
|  | 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), | ||||||
| 			namedTunnelConfig: testNamedTunnelConfig, | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:               "warp routing requesting h2mux", | 			name:               "warp routing requesting h2mux", | ||||||
| 			protocol:           "h2mux", | 			protocol:           "h2mux", | ||||||
| 			expectedProtocol:   HTTP2, | 			expectedProtocol:   HTTP2Warp, | ||||||
| 			hasFallback:        false, | 			hasFallback:        false, | ||||||
| 			expectedFallback:   H2mux, | 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||||
| 			fetchFunc:          mockFetcher(100), | 			warpRoutingEnabled: true, | ||||||
|  | 			namedTunnelConfig:  testNamedTunnelConfig, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:               "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", | ||||||
|  | 			protocol:           "h2mux", | ||||||
|  | 			expectedProtocol:   HTTP2Warp, | ||||||
|  | 			hasFallback:        false, | ||||||
|  | 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), | ||||||
| 			warpRoutingEnabled: true, | 			warpRoutingEnabled: true, | ||||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | 			namedTunnelConfig:  testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:               "warp routing http2", | 			name:               "warp routing http2", | ||||||
| 			protocol:           "http2", | 			protocol:           "http2", | ||||||
| 			expectedProtocol:   HTTP2, | 			expectedProtocol:   HTTP2Warp, | ||||||
| 			hasFallback:        false, | 			hasFallback:        false, | ||||||
| 			expectedFallback:   H2mux, | 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||||
| 			fetchFunc:          mockFetcher(100), | 			warpRoutingEnabled: true, | ||||||
|  | 			namedTunnelConfig:  testNamedTunnelConfig, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:               "warp routing quic", | ||||||
|  | 			protocol:           "quic", | ||||||
|  | 			expectedProtocol:   QUICWarp, | ||||||
|  | 			hasFallback:        true, | ||||||
|  | 			expectedFallback:   HTTP2Warp, | ||||||
|  | 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), | ||||||
| 			warpRoutingEnabled: true, | 			warpRoutingEnabled: true, | ||||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | 			namedTunnelConfig:  testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:               "warp routing auto", | 			name:               "warp routing auto", | ||||||
| 			protocol:           "auto", | 			protocol:           "auto", | ||||||
| 			expectedProtocol:   HTTP2, | 			expectedProtocol:   HTTP2Warp, | ||||||
| 			hasFallback:        false, | 			hasFallback:        false, | ||||||
| 			expectedFallback:   H2mux, | 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||||
| 			fetchFunc:          mockFetcher(100), | 			warpRoutingEnabled: true, | ||||||
|  | 			namedTunnelConfig:  testNamedTunnelConfig, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:               "warp routing auto- quic", | ||||||
|  | 			protocol:           "auto", | ||||||
|  | 			expectedProtocol:   QUICWarp, | ||||||
|  | 			hasFallback:        true, | ||||||
|  | 			expectedFallback:   HTTP2Warp, | ||||||
|  | 			fetchFunc:          mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), | ||||||
| 			warpRoutingEnabled: true, | 			warpRoutingEnabled: true, | ||||||
| 			namedTunnelConfig:  testNamedTunnelConfig, | 			namedTunnelConfig:  testNamedTunnelConfig, | ||||||
| 		}, | 		}, | ||||||
|  | @ -149,14 +191,14 @@ func TestNewProtocolSelector(t *testing.T) { | ||||||
| 		{ | 		{ | ||||||
| 			name:              "named tunnel unknown protocol", | 			name:              "named tunnel unknown protocol", | ||||||
| 			protocol:          "unknown", | 			protocol:          "unknown", | ||||||
| 			fetchFunc:         mockFetcher(100), | 			fetchFunc:         mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), | ||||||
| 			namedTunnelConfig: testNamedTunnelConfig, | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
| 			wantErr:           true, | 			wantErr:           true, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:              "named tunnel fetch error", | 			name:              "named tunnel fetch error", | ||||||
| 			protocol:          "unknown", | 			protocol:          "auto", | ||||||
| 			fetchFunc:         mockFetcherWithError(), | 			fetchFunc:         mockFetcher(true), | ||||||
| 			namedTunnelConfig: testNamedTunnelConfig, | 			namedTunnelConfig: testNamedTunnelConfig, | ||||||
| 			expectedProtocol:  HTTP2, | 			expectedProtocol:  HTTP2, | ||||||
| 			wantErr:           false, | 			wantErr:           false, | ||||||
|  | @ -164,6 +206,7 @@ func TestNewProtocolSelector(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, test := range tests { | 	for _, test := range tests { | ||||||
|  | 		t.Run(test.name, func(t *testing.T) { | ||||||
| 			selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) | 			selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) | ||||||
| 			if test.wantErr { | 			if test.wantErr { | ||||||
| 				assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) | 				assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) | ||||||
|  | @ -176,6 +219,7 @@ func TestNewProtocolSelector(t *testing.T) { | ||||||
| 					assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) | 					assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -185,27 +229,27 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 	assert.Equal(t, H2mux, selector.Current()) | 	assert.Equal(t, H2mux, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 100 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 0 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} | ||||||
| 	assert.Equal(t, H2mux, selector.Current()) | 	assert.Equal(t, H2mux, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 100 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.err = fmt.Errorf("failed to fetch") | 	fetcher.err = fmt.Errorf("failed to fetch") | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = -1 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} | ||||||
| 	fetcher.err = nil | 	fetcher.err = nil | ||||||
| 	assert.Equal(t, H2mux, selector.Current()) | 	assert.Equal(t, H2mux, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 0 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} | ||||||
| 	assert.Equal(t, H2mux, selector.Current()) | 	assert.Equal(t, H2mux, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 100 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, QUIC, selector.Current()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { | func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { | ||||||
|  | @ -214,35 +258,36 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 100 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 0 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.err = fmt.Errorf("failed to fetch") | 	fetcher.err = fmt.Errorf("failed to fetch") | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = -1 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} | ||||||
| 	fetcher.err = nil | 	fetcher.err = nil | ||||||
| 	assert.Equal(t, H2mux, selector.Current()) | 	assert.Equal(t, H2mux, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 0 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 100 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, HTTP2, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = -1 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} | ||||||
| 	assert.Equal(t, H2mux, selector.Current()) | 	assert.Equal(t, H2mux, selector.Current()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProtocolSelectorRefreshTTL(t *testing.T) { | func TestProtocolSelectorRefreshTTL(t *testing.T) { | ||||||
| 	fetcher := dynamicMockFetcher{percentage: 100} | 	fetcher := dynamicMockFetcher{} | ||||||
|  | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} | ||||||
| 	selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) | 	selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, QUIC, selector.Current()) | ||||||
| 
 | 
 | ||||||
| 	fetcher.percentage = 0 | 	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}} | ||||||
| 	assert.Equal(t, HTTP2, selector.Current()) | 	assert.Equal(t, QUIC, selector.Current()) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,45 +1,50 @@ | ||||||
| package edgediscovery | package edgediscovery | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"strconv" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	protocolRecord = "protocol.argotunnel.com" | 	protocolRecord = "protocol-v2.argotunnel.com" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
| 	errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) | 	errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func HTTP2Percentage() (int32, error) { | // ProtocolPercent represents a single Protocol Percentage combination.
 | ||||||
|  | type ProtocolPercent struct { | ||||||
|  | 	Protocol   string `json:"protocol"` | ||||||
|  | 	Percentage int32  `json:"percentage"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ProtocolPercents represents the preferred distribution ratio of protocols when protocol isn't specified.
 | ||||||
|  | type ProtocolPercents []ProtocolPercent | ||||||
|  | 
 | ||||||
|  | // GetPercentage returns the threshold percentage of a single protocol.
 | ||||||
|  | func (p ProtocolPercents) GetPercentage(protocol string) int32 { | ||||||
|  | 	for _, protocolPercent := range p { | ||||||
|  | 		if strings.ToLower(protocolPercent.Protocol) == strings.ToLower(protocol) { | ||||||
|  | 			return protocolPercent.Percentage | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ProtocolPercentage returns the ratio of protocols and a specification ratio for their selection.
 | ||||||
|  | func ProtocolPercentage() (ProtocolPercents, error) { | ||||||
| 	records, err := net.LookupTXT(protocolRecord) | 	records, err := net.LookupTXT(protocolRecord) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if len(records) == 0 { | 	if len(records) == 0 { | ||||||
| 		return 0, errNoProtocolRecord | 		return nil, errNoProtocolRecord | ||||||
| 	} | 	} | ||||||
| 	return parseHTTP2Precentage(records[0]) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // The record looks like http2=percentage
 |  | ||||||
| func parseHTTP2Precentage(record string) (int32, error) { |  | ||||||
| 	const key = "http2" |  | ||||||
| 	slices := strings.Split(record, "=") |  | ||||||
| 	if len(slices) != 2 { |  | ||||||
| 		return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record) |  | ||||||
| 	} |  | ||||||
| 	if slices[0] != key { |  | ||||||
| 		return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key) |  | ||||||
| 	} |  | ||||||
| 	percentage, err := strconv.ParseInt(slices[1], 10, 32) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return 0, err |  | ||||||
| 	} |  | ||||||
| 	return int32(percentage), nil |  | ||||||
| 
 | 
 | ||||||
|  | 	var protocolsWithPercent ProtocolPercents | ||||||
|  | 	err = json.Unmarshal([]byte(records[0]), &protocolsWithPercent) | ||||||
|  | 	return protocolsWithPercent, err | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -6,75 +6,7 @@ import ( | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestHTTP2Percentage(t *testing.T) { | func TestProtocolPercentage(t *testing.T) { | ||||||
| 	_, err := HTTP2Percentage() | 	_, err := ProtocolPercentage() | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func TestParseHTTP2Precentage(t *testing.T) { |  | ||||||
| 	tests := []struct { |  | ||||||
| 		record     string |  | ||||||
| 		percentage int32 |  | ||||||
| 		wantErr    bool |  | ||||||
| 	}{ |  | ||||||
| 		{ |  | ||||||
| 			record:     "http2=-1", |  | ||||||
| 			percentage: -1, |  | ||||||
| 			wantErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:     "http2=0", |  | ||||||
| 			percentage: 0, |  | ||||||
| 			wantErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:     "http2=50", |  | ||||||
| 			percentage: 50, |  | ||||||
| 			wantErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:     "http2=100", |  | ||||||
| 			percentage: 100, |  | ||||||
| 			wantErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:     "http2=1000", |  | ||||||
| 			percentage: 1000, |  | ||||||
| 			wantErr:    false, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:  "http2=10.5", |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:  "http2=10 h2mux=90", |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:  "http2=ten", |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 
 |  | ||||||
| 		{ |  | ||||||
| 			record:  "h2mux=100", |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:  "http2", |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 		{ |  | ||||||
| 			record:  "http2=", |  | ||||||
| 			wantErr: true, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	for _, test := range tests { |  | ||||||
| 		p, err := parseHTTP2Precentage(test.record) |  | ||||||
| 		if test.wantErr { |  | ||||||
| 			assert.Error(t, err) |  | ||||||
| 		} else { |  | ||||||
| 			assert.Equal(t, test.percentage, p) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -351,7 +351,7 @@ func serveTunnel( | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	switch protocol { | 	switch protocol { | ||||||
| 	case connection.QUIC: | 	case connection.QUIC, connection.QUICWarp: | ||||||
| 		connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries())) | 		connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries())) | ||||||
| 		return ServeQUIC(ctx, | 		return ServeQUIC(ctx, | ||||||
| 			addr.UDP, | 			addr.UDP, | ||||||
|  | @ -361,7 +361,7 @@ func serveTunnel( | ||||||
| 			reconnectCh, | 			reconnectCh, | ||||||
| 			gracefulShutdownC) | 			gracefulShutdownC) | ||||||
| 
 | 
 | ||||||
| 	case connection.HTTP2: | 	case connection.HTTP2, connection.HTTP2Warp: | ||||||
| 		edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) | 		edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge") | 			connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge") | ||||||
|  |  | ||||||
|  | @ -8,20 +8,18 @@ import ( | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 
 | 
 | ||||||
| 	"github.com/cloudflare/cloudflared/connection" | 	"github.com/cloudflare/cloudflared/connection" | ||||||
|  | 	"github.com/cloudflare/cloudflared/edgediscovery" | ||||||
| 	"github.com/cloudflare/cloudflared/retry" | 	"github.com/cloudflare/cloudflared/retry" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type dynamicMockFetcher struct { | type dynamicMockFetcher struct { | ||||||
| 	percentage int32 | 	protocolPercents edgediscovery.ProtocolPercents | ||||||
| 	err              error | 	err              error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { | func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { | ||||||
| 	return func() (int32, error) { | 	return func() (edgediscovery.ProtocolPercents, error) { | ||||||
| 		if dmf.err != nil { | 		return dmf.protocolPercents, dmf.err | ||||||
| 			return 0, dmf.err |  | ||||||
| 		} |  | ||||||
| 		return dmf.percentage, nil |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -39,7 +37,7 @@ func TestWaitForBackoffFallback(t *testing.T) { | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	mockFetcher := dynamicMockFetcher{ | 	mockFetcher := dynamicMockFetcher{ | ||||||
| 		percentage: 0, | 		protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}, | ||||||
| 	} | 	} | ||||||
| 	warpRoutingEnabled := false | 	warpRoutingEnabled := false | ||||||
| 	protocolSelector, err := connection.NewProtocolSelector( | 	protocolSelector, err := connection.NewProtocolSelector( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue