From b5cdf3b2c70fcd5224f9f2cb7bd38b59733a7618 Mon Sep 17 00:00:00 2001 From: cthuang Date: Wed, 14 Oct 2020 11:28:07 +0100 Subject: [PATCH] TUN-3456: New protocol option auto to automatically select between http2 and h2mux --- cmd/cloudflared/tunnel/configuration.go | 23 ++++-- cmd/cloudflared/tunnel/subcommand_context.go | 6 +- cmd/cloudflared/tunnel/subcommands.go | 6 +- connection/connection.go | 33 ++++++-- edgediscovery/protocol.go | 45 +++++++++++ edgediscovery/protocol_test.go | 80 ++++++++++++++++++++ origin/tunnel.go | 2 +- 7 files changed, 176 insertions(+), 19 deletions(-) create mode 100644 edgediscovery/protocol.go create mode 100644 edgediscovery/protocol_test.go diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 14318713..822fa85a 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -11,6 +11,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" @@ -231,7 +232,11 @@ func prepareTunnelConfig( } } - protocol := determineProtocol(namedTunnel) + protocol, err := determineProtocol(c, namedTunnel) + if err != nil { + return nil, err + } + logger.Infof("Using protocol %s", protocol) toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName()) if err != nil { logger.Errorf("unable to create TLS config to connect with edge: %s", err) @@ -293,9 +298,17 @@ func isRunningFromTerminal() bool { return terminal.IsTerminal(int(os.Stdout.Fd())) } -func determineProtocol(namedTunnel *connection.NamedTunnelConfig) connection.Protocol { - if namedTunnel != nil { - return namedTunnel.Protocol +func determineProtocol(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) (connection.Protocol, error) { + if namedTunnel == nil { + return connection.H2mux, nil } - return connection.H2mux + http2Percentage, err := edgediscovery.HTTP2Percentage() + if err != nil { + return 0, err + } + protocol, ok := connection.SelectProtocol(c.String("protocol"), namedTunnel.Auth.AccountTag, http2Percentage) + if !ok { + return 0, fmt.Errorf("%s is not valid protocol. %s", c.String("protocol"), availableProtocol) + } + return protocol, nil } diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index ce04ab5a..6f3cca20 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -260,16 +260,12 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { return err } - protocol, ok := connection.ParseProtocol(sc.c.String("protocol")) - if !ok { - return fmt.Errorf("%s is not valid protocol. %s", sc.c.String("protocol"), availableProtocol) - } return StartServer( sc.c, version, shutdownC, graceShutdownC, - &connection.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol}, + &connection.NamedTunnelConfig{Auth: *credentials, ID: tunnelID}, sc.logger, sc.isUIEnabled, ) diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 959a022a..469ec60d 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -30,7 +30,7 @@ import ( const ( credFileFlagAlias = "cred-file" - availableProtocol = "Available protocols: http2, Go's implementation and h2mux, Cloudflare's implementation of HTTP/2." + availableProtocol = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux" ) var ( @@ -86,14 +86,14 @@ var ( Usage: "Allows you to delete a tunnel, even if it has active connections.", EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"}, } - selectProtocolFlag = &cli.StringFlag{ + selectProtocolFlag = altsrc.NewStringFlag(&cli.StringFlag{ Name: "protocol", Value: "h2mux", Aliases: []string{"p"}, Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol), EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"}, Hidden: true, - } + }) ) func buildCreateCommand() *cli.Command { diff --git a/connection/connection.go b/connection/connection.go index 5f7103b9..54df755f 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -1,6 +1,8 @@ package connection import ( + "fmt" + "hash/fnv" "io" "net/http" "strconv" @@ -25,10 +27,9 @@ type Config struct { } type NamedTunnelConfig struct { - Auth pogs.TunnelAuth - ID uuid.UUID - Client pogs.ClientInfo - Protocol Protocol + Auth pogs.TunnelAuth + ID uuid.UUID + Client pogs.ClientInfo } type ClassicTunnelConfig struct { @@ -49,17 +50,28 @@ const ( HTTP2 ) -func ParseProtocol(s string) (Protocol, bool) { +func SelectProtocol(s string, accountTag string, http2Percentage uint32) (Protocol, bool) { switch s { case "h2mux": return H2mux, true case "http2": return HTTP2, true + case "auto": + if tryHTTP2(accountTag, http2Percentage) { + return HTTP2, true + } + return H2mux, true default: return 0, false } } +func tryHTTP2(accountTag string, http2Percentage uint32) bool { + h := fnv.New32a() + h.Write([]byte(accountTag)) + return h.Sum32()%100 < http2Percentage +} + func (p Protocol) ServerName() string { switch p { case H2mux: @@ -71,6 +83,17 @@ func (p Protocol) ServerName() string { } } +func (p Protocol) String() string { + switch p { + case H2mux: + return "h2mux" + case HTTP2: + return "http2" + default: + return fmt.Sprintf("unknown protocol") + } +} + type OriginClient interface { Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error } diff --git a/edgediscovery/protocol.go b/edgediscovery/protocol.go new file mode 100644 index 00000000..8d9f5039 --- /dev/null +++ b/edgediscovery/protocol.go @@ -0,0 +1,45 @@ +package edgediscovery + +import ( + "fmt" + "net" + "strconv" + "strings" +) + +const ( + protocolRecord = "protocol.argotunnel.com" +) + +var ( + errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) +) + +func HTTP2Percentage() (int32, error) { + records, err := net.LookupTXT(protocolRecord) + if err != nil { + return 0, err + } + if len(records) == 0 { + return 0, errNoProtocolRecord + } + return parseHTTP2Precentage(records[0]) +} + +// The record looks like http2=percentage +func parseHTTP2Precentage(record string) (int32, error) { + const key = "http2" + slices := strings.Split(record, "=") + if len(slices) != 2 { + return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record) + } + if slices[0] != key { + return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key) + } + percentage, err := strconv.ParseInt(slices[1], 10, 32) + if err != nil { + return 0, err + } + return int32(percentage), nil + +} diff --git a/edgediscovery/protocol_test.go b/edgediscovery/protocol_test.go new file mode 100644 index 00000000..874ab6ee --- /dev/null +++ b/edgediscovery/protocol_test.go @@ -0,0 +1,80 @@ +package edgediscovery + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHTTP2Percentage(t *testing.T) { + _, err := HTTP2Percentage() + assert.NoError(t, err) +} + +func TestParseHTTP2Precentage(t *testing.T) { + tests := []struct { + record string + percentage int32 + wantErr bool + }{ + { + record: "http2=-1", + percentage: -1, + wantErr: false, + }, + { + record: "http2=0", + percentage: 0, + wantErr: false, + }, + { + record: "http2=50", + percentage: 50, + wantErr: false, + }, + { + record: "http2=100", + percentage: 100, + wantErr: false, + }, + { + record: "http2=1000", + percentage: 1000, + wantErr: false, + }, + { + record: "http2=10.5", + wantErr: true, + }, + { + record: "http2=10 h2mux=90", + wantErr: true, + }, + { + record: "http2=ten", + wantErr: true, + }, + + { + record: "h2mux=100", + wantErr: true, + }, + { + record: "http2", + wantErr: true, + }, + { + record: "http2=", + wantErr: true, + }, + } + + for _, test := range tests { + p, err := parseHTTP2Precentage(test.record) + if test.wantErr { + assert.Error(t, err) + } else { + assert.Equal(t, test.percentage, p) + } + } +} diff --git a/origin/tunnel.go b/origin/tunnel.go index b98d2a47..d96432ea 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -238,7 +238,7 @@ func ServeTunnel( fuse: fuse, backoff: backoff, } - if config.NamedTunnel != nil && config.NamedTunnel.Protocol == connection.HTTP2 { + if config.Protocol == connection.HTTP2 { connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries)) return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh) }