TUN-3753: Select http2 protocol when warp routing is enabled

This commit is contained in:
cthuang 2021-01-21 15:23:18 +00:00 committed by Nuno Diegues
parent 3b93914612
commit 2146f71b45
4 changed files with 69 additions and 21 deletions

View File

@ -231,7 +231,13 @@ func prepareTunnelConfig(
}
}
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log)
var warpRoutingService *ingress.WarpRoutingService
warpRoutingEnabled := isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel)
if warpRoutingEnabled {
warpRoutingService = ingress.NewWarpRoutingService()
}
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log)
if err != nil {
return nil, ingress.Ingress{}, err
}
@ -246,11 +252,6 @@ func prepareTunnelConfig(
edgeTLSConfigs[p] = edgeTLSConfig
}
var warpRoutingService *ingress.WarpRoutingService
if isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel, protocolSelector.Current()) {
warpRoutingService = ingress.NewWarpRoutingService()
}
originProxy := origin.NewOriginProxy(ingressRules, warpRoutingService, tags, log)
connectionConfig := &connection.Config{
OriginProxy: originProxy,
@ -292,8 +293,8 @@ func prepareTunnelConfig(
}, ingressRules, nil
}
func isWarpRoutingEnabled(warpConfig config.WarpRoutingConfig, isNamedTunnel bool, protocol connection.Protocol) bool {
return warpConfig.Enabled && isNamedTunnel && protocol == connection.HTTP2
func isWarpRoutingEnabled(warpConfig config.WarpRoutingConfig, isNamedTunnel bool) bool {
return warpConfig.Enabled && isNamedTunnel
}
func isRunningFromTerminal() bool {

View File

@ -141,16 +141,29 @@ type PercentageFetcher func() (int32, error)
func NewProtocolSelector(
protocolFlag string,
warpRoutingEnabled bool,
namedTunnel *NamedTunnelConfig,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
) (ProtocolSelector, error) {
// Classic tunnel is only supported with h2mux
if namedTunnel == nil {
return &staticProtocolSelector{
current: H2mux,
}, nil
}
// warp routing can only be served over http2 connections
if warpRoutingEnabled {
if protocolFlag == H2mux.String() {
log.Warn().Msg("Warp routing is only supported by http2 protocol. Upgrading protocol to http2")
}
return &staticProtocolSelector{
current: HTTP2,
}, nil
}
if protocolFlag == H2mux.String() {
return &staticProtocolSelector{
current: H2mux,

View File

@ -9,7 +9,8 @@ import (
)
const (
testNoTTL = 0
testNoTTL = 0
noWarpRoutingEnabled = false
)
var (
@ -48,14 +49,15 @@ func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
func TestNewProtocolSelector(t *testing.T) {
tests := []struct {
name string
protocol string
expectedProtocol Protocol
hasFallback bool
expectedFallback Protocol
namedTunnelConfig *NamedTunnelConfig
fetchFunc PercentageFetcher
wantErr bool
name string
protocol string
expectedProtocol Protocol
hasFallback bool
expectedFallback Protocol
warpRoutingEnabled bool
namedTunnelConfig *NamedTunnelConfig
fetchFunc PercentageFetcher
wantErr bool
}{
{
name: "classic tunnel",
@ -108,6 +110,36 @@ func TestNewProtocolSelector(t *testing.T) {
fetchFunc: mockFetcher(100),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing requesting h2mux",
protocol: "h2mux",
expectedProtocol: HTTP2,
hasFallback: false,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing http2",
protocol: "http2",
expectedProtocol: HTTP2,
hasFallback: false,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing auto",
protocol: "auto",
expectedProtocol: HTTP2,
hasFallback: false,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
// None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
name: "classic tunnel unknown protocol",
@ -131,7 +163,7 @@ func TestNewProtocolSelector(t *testing.T) {
}
for _, test := range tests {
selector, err := NewProtocolSelector(test.protocol, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
} else {
@ -148,7 +180,7 @@ func TestNewProtocolSelector(t *testing.T) {
func TestAutoProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current())
@ -177,7 +209,7 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) {
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector("http2", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
@ -206,7 +238,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
func TestProtocolSelectorRefreshTTL(t *testing.T) {
fetcher := dynamicMockFetcher{percentage: 100}
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())

View File

@ -40,8 +40,10 @@ func TestWaitForBackoffFallback(t *testing.T) {
mockFetcher := dynamicMockFetcher{
percentage: 0,
}
warpRoutingEnabled := false
protocolSelector, err := connection.NewProtocolSelector(
connection.HTTP2.String(),
warpRoutingEnabled,
namedTunnel,
mockFetcher.fetch(),
resolveTTL,