diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 822fa85a..c44e0ecc 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -232,15 +232,19 @@ func prepareTunnelConfig( } } - protocol, err := determineProtocol(c, namedTunnel) + protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, logger) if err != nil { return nil, err } - logger.Infof("Using protocol %s", protocol) - toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName()) - if err != nil { - logger.Errorf("unable to create TLS config to connect with edge: %s", err) - return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") + logger.Infof("Initial protocol %s", protocolSelector.Current()) + + edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList)) + for _, p := range connection.ProtocolList { + edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName()) + if err != nil { + return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") + } + edgeTLSConfigs[p] = edgeTLSConfig } proxyConfig := &origin.ProxyConfig{ @@ -252,7 +256,7 @@ func prepareTunnelConfig( Tags: tags, } originClient := origin.NewClient(proxyConfig, logger) - transportConfig := &connection.Config{ + connectionConfig := &connection.Config{ OriginClient: originClient, GracePeriod: c.Duration("grace-period"), ReplaceExisting: c.Bool("force"), @@ -270,7 +274,7 @@ func prepareTunnelConfig( } return &origin.TunnelConfig{ - ConnectionConfig: transportConfig, + ConnectionConfig: connectionConfig, ProxyConfig: proxyConfig, BuildInfo: buildInfo, ClientID: clientID, @@ -281,34 +285,20 @@ func prepareTunnelConfig( IsFreeTunnel: isFreeTunnel, LBPool: c.String("lb-pool"), Logger: logger, - Observer: connection.NewObserver(transportLogger, tunnelEventChan, protocol), + Observer: connection.NewObserver(transportLogger, tunnelEventChan), ReportedVersion: version, Retries: c.Uint("retries"), RunFromTerminal: isRunningFromTerminal(), - TLSConfig: toEdgeTLSConfig, NamedTunnel: namedTunnel, ClassicTunnel: classicTunnel, MuxerConfig: muxerConfig, TunnelEventChan: tunnelEventChan, IngressRules: ingressRules, + ProtocolSelector: protocolSelector, + EdgeTLSConfigs: edgeTLSConfigs, }, nil } func isRunningFromTerminal() bool { return terminal.IsTerminal(int(os.Stdout.Fd())) } - -func determineProtocol(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) (connection.Protocol, error) { - if namedTunnel == nil { - return connection.H2mux, nil - } - http2Percentage, err := edgediscovery.HTTP2Percentage() - if err != nil { - return 0, err - } - protocol, ok := connection.SelectProtocol(c.String("protocol"), namedTunnel.Auth.AccountTag, http2Percentage) - if !ok { - return 0, fmt.Errorf("%s is not valid protocol. %s", c.String("protocol"), availableProtocol) - } - return protocol, nil -} diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 469ec60d..6980b599 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -23,6 +23,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelstore" @@ -30,7 +31,6 @@ import ( const ( credFileFlagAlias = "cred-file" - availableProtocol = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux" ) var ( @@ -90,7 +90,7 @@ var ( Name: "protocol", Value: "h2mux", Aliases: []string{"p"}, - Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol), + Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", connection.AvailableProtocolFlagMessage), EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"}, Hidden: true, }) diff --git a/connection/connection.go b/connection/connection.go index 54df755f..bfc779df 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -1,8 +1,6 @@ package connection import ( - "fmt" - "hash/fnv" "io" "net/http" "strconv" @@ -13,10 +11,6 @@ import ( ) const ( - // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge - edgeH2muxTLSServerName = "cftunnel.com" - // edgeH2TLSServerName is the server name to establish http2 connection with edge - edgeH2TLSServerName = "h2.cftunnel.com" lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" ) @@ -43,57 +37,6 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool { return c.Hostname == "" } -type Protocol int64 - -const ( - H2mux Protocol = iota - HTTP2 -) - -func SelectProtocol(s string, accountTag string, http2Percentage uint32) (Protocol, bool) { - switch s { - case "h2mux": - return H2mux, true - case "http2": - return HTTP2, true - case "auto": - if tryHTTP2(accountTag, http2Percentage) { - return HTTP2, true - } - return H2mux, true - default: - return 0, false - } -} - -func tryHTTP2(accountTag string, http2Percentage uint32) bool { - h := fnv.New32a() - h.Write([]byte(accountTag)) - return h.Sum32()%100 < http2Percentage -} - -func (p Protocol) ServerName() string { - switch p { - case H2mux: - return edgeH2muxTLSServerName - case HTTP2: - return edgeH2TLSServerName - default: - return "" - } -} - -func (p Protocol) String() string { - switch p { - case H2mux: - return "h2mux" - case HTTP2: - return "http2" - default: - return fmt.Sprintf("unknown protocol") - } -} - type OriginClient interface { Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error } diff --git a/connection/http2.go b/connection/http2.go index 5b6e3754..a9417d23 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -37,7 +37,7 @@ type HTTP2Connection struct { connectedFuse ConnectedFuse } -func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) { +func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection { return &HTTP2Connection{ conn: conn, server: &http2.Server{ @@ -52,7 +52,7 @@ func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, named connIndex: connIndex, wg: &sync.WaitGroup{}, connectedFuse: connectedFuse, - }, nil + } } func (c *HTTP2Connection) Serve(ctx context.Context) { diff --git a/connection/metrics.go b/connection/metrics.go index 405611bc..0d86fd05 100644 --- a/connection/metrics.go +++ b/connection/metrics.go @@ -299,7 +299,7 @@ func convertRTTMilliSec(t time.Duration) float64 { } // Metrics that can be collected without asking the edge -func newTunnelMetrics(protocol Protocol) *tunnelMetrics { +func newTunnelMetrics() *tunnelMetrics { maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: MetricsNamespace, @@ -374,16 +374,12 @@ func newTunnelMetrics(protocol Protocol) *tunnelMetrics { []string{"rpcName"}, ) prometheus.MustRegister(registerSuccess) - var muxerMetrics *muxerMetrics - if protocol == H2mux { - muxerMetrics = newMuxerMetrics() - } return &tunnelMetrics{ timerRetries: timerRetries, serverLocations: serverLocations, oldServerLocations: make(map[string]string), - muxerMetrics: muxerMetrics, + muxerMetrics: newMuxerMetrics(), tunnelsHA: NewTunnelsForHA(), regSuccess: registerSuccess, regFail: registerFail, diff --git a/connection/observer.go b/connection/observer.go index 9dfb15aa..fbcf7b6c 100644 --- a/connection/observer.go +++ b/connection/observer.go @@ -16,10 +16,10 @@ type Observer struct { tunnelEventChan chan<- ui.TunnelEvent } -func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer { +func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent) *Observer { return &Observer{ logger, - newTunnelMetrics(protocol), + newTunnelMetrics(), tunnelEventChan, } } diff --git a/connection/observer_test.go b/connection/observer_test.go index aa47430e..6116cded 100644 --- a/connection/observer_test.go +++ b/connection/observer_test.go @@ -9,7 +9,7 @@ import ( ) // can only be called once -var m = newTunnelMetrics(H2mux) +var m = newTunnelMetrics() func TestRegisterServerLocation(t *testing.T) { tunnels := 20 diff --git a/connection/protocol.go b/connection/protocol.go new file mode 100644 index 00000000..9118a5c6 --- /dev/null +++ b/connection/protocol.go @@ -0,0 +1,179 @@ +package connection + +import ( + "fmt" + "hash/fnv" + "sync" + "time" + + "github.com/cloudflare/cloudflared/logger" +) + +const ( + AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux" + // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge + edgeH2muxTLSServerName = "cftunnel.com" + // edgeH2TLSServerName is the server name to establish http2 connection with edge + edgeH2TLSServerName = "h2.cftunnel.com" + // threshold to switch back to h2mux when the user intentionally pick --protocol http2 + explicitHTTP2FallbackThreshold = -1 + autoSelectFlag = "auto" +) + +var ( + ProtocolList = []Protocol{H2mux, HTTP2} +) + +type Protocol int64 + +const ( + H2mux Protocol = iota + HTTP2 +) + +func (p Protocol) ServerName() string { + switch p { + case H2mux: + return edgeH2muxTLSServerName + case HTTP2: + return edgeH2TLSServerName + default: + return "" + } +} + +// Fallback returns the fallback protocol and whether the protocol has a fallback +func (p Protocol) fallback() (Protocol, bool) { + switch p { + case H2mux: + return 0, false + case HTTP2: + return H2mux, true + default: + return 0, false + } +} + +func (p Protocol) String() string { + switch p { + case H2mux: + return "h2mux" + case HTTP2: + return "http2" + default: + return fmt.Sprintf("unknown protocol") + } +} + +type ProtocolSelector interface { + Current() Protocol + Fallback() (Protocol, bool) +} + +type staticProtocolSelector struct { + current Protocol +} + +func (s *staticProtocolSelector) Current() Protocol { + return s.current +} + +func (s *staticProtocolSelector) Fallback() (Protocol, bool) { + return 0, false +} + +type autoProtocolSelector struct { + lock sync.RWMutex + current Protocol + switchThrehold int32 + fetchFunc PercentageFetcher + refreshAfter time.Time + ttl time.Duration + logger logger.Service +} + +func newAutoProtocolSelector( + current Protocol, + switchThrehold int32, + fetchFunc PercentageFetcher, + ttl time.Duration, + logger logger.Service, +) *autoProtocolSelector { + return &autoProtocolSelector{ + current: current, + switchThrehold: switchThrehold, + fetchFunc: fetchFunc, + refreshAfter: time.Now().Add(ttl), + ttl: ttl, + logger: logger, + } +} + +func (s *autoProtocolSelector) Current() Protocol { + s.lock.Lock() + defer s.lock.Unlock() + if time.Now().Before(s.refreshAfter) { + return s.current + } + + percentage, err := s.fetchFunc() + if err != nil { + s.logger.Errorf("Failed to refresh protocol, err: %v", err) + return s.current + } + + if s.switchThrehold < percentage { + s.current = HTTP2 + } else { + s.current = H2mux + } + s.refreshAfter = time.Now().Add(s.ttl) + return s.current +} + +func (s *autoProtocolSelector) Fallback() (Protocol, bool) { + s.lock.RLock() + defer s.lock.RUnlock() + return s.current.fallback() +} + +type PercentageFetcher func() (int32, error) + +func NewProtocolSelector(protocolFlag string, namedTunnel *NamedTunnelConfig, fetchFunc PercentageFetcher, ttl time.Duration, logger logger.Service) (ProtocolSelector, error) { + if namedTunnel == nil { + return &staticProtocolSelector{ + current: H2mux, + }, nil + } + if protocolFlag == H2mux.String() { + return &staticProtocolSelector{ + current: H2mux, + }, nil + } + + http2Percentage, err := fetchFunc() + if err != nil { + return nil, err + } + if protocolFlag == HTTP2.String() { + if http2Percentage < 0 { + return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil + } + return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil + } + + if protocolFlag != autoSelectFlag { + return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) + } + threshold := switchThreshold(namedTunnel.Auth.AccountTag) + if threshold < http2Percentage { + return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, logger), nil + } + return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, logger), nil +} + +func switchThreshold(accountTag string) int32 { + h := fnv.New32a() + h.Write([]byte(accountTag)) + return int32(h.Sum32() % 100) +} diff --git a/connection/protocol_test.go b/connection/protocol_test.go new file mode 100644 index 00000000..3cfdf846 --- /dev/null +++ b/connection/protocol_test.go @@ -0,0 +1,220 @@ +package connection + +import ( + "fmt" + "testing" + "time" + + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/stretchr/testify/assert" +) + +const ( + testNoTTL = 0 +) + +var ( + testNamedTunnelConfig = &NamedTunnelConfig{ + Auth: pogs.TunnelAuth{ + AccountTag: "testAccountTag", + }, + } +) + +func mockFetcher(percentage int32) PercentageFetcher { + return func() (int32, error) { + return percentage, nil + } +} + +func mockFetcherWithError() PercentageFetcher { + return func() (int32, error) { + return 0, fmt.Errorf("failed to fetch precentage") + } +} + +type dynamicMockFetcher struct { + percentage int32 + err error +} + +func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { + return func() (int32, error) { + if dmf.err != nil { + return 0, dmf.err + } + return dmf.percentage, nil + } +} + +func TestNewProtocolSelector(t *testing.T) { + tests := []struct { + name string + protocol string + expectedProtocol Protocol + hasFallback bool + expectedFallback Protocol + namedTunnelConfig *NamedTunnelConfig + fetchFunc PercentageFetcher + wantErr bool + }{ + { + name: "classic tunnel", + protocol: "h2mux", + expectedProtocol: H2mux, + namedTunnelConfig: nil, + }, + { + name: "named tunnel over h2mux", + protocol: "h2mux", + expectedProtocol: H2mux, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel over http2", + protocol: "http2", + expectedProtocol: HTTP2, + hasFallback: true, + expectedFallback: H2mux, + fetchFunc: mockFetcher(0), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel http2 disabled", + protocol: "http2", + expectedProtocol: H2mux, + fetchFunc: mockFetcher(-1), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel auto all http2 disabled", + protocol: "auto", + expectedProtocol: H2mux, + fetchFunc: mockFetcher(-1), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel auto to h2mux", + protocol: "auto", + expectedProtocol: H2mux, + fetchFunc: mockFetcher(0), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel auto to http2", + protocol: "auto", + expectedProtocol: HTTP2, + hasFallback: true, + expectedFallback: H2mux, + fetchFunc: mockFetcher(100), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + // None named tunnel can only use h2mux, so specifying an unknown protocol is not an error + name: "classic tunnel unknown protocol", + protocol: "unknown", + expectedProtocol: H2mux, + }, + { + name: "named tunnel unknown protocol", + protocol: "unknown", + fetchFunc: mockFetcher(100), + namedTunnelConfig: testNamedTunnelConfig, + wantErr: true, + }, + { + name: "named tunnel fetch error", + protocol: "unknown", + fetchFunc: mockFetcherWithError(), + namedTunnelConfig: testNamedTunnelConfig, + wantErr: true, + }, + } + logger, _ := logger.New() + for _, test := range tests { + selector, err := NewProtocolSelector(test.protocol, test.namedTunnelConfig, test.fetchFunc, testNoTTL, logger) + if test.wantErr { + assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) + } else { + assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) + assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) + fallback, ok := selector.Fallback() + assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) + if test.hasFallback { + assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) + } + } + } +} + +func TestAutoProtocolSelectorRefresh(t *testing.T) { + logger, _ := logger.New() + fetcher := dynamicMockFetcher{} + selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger) + assert.NoError(t, err) + assert.Equal(t, H2mux, selector.Current()) + + fetcher.percentage = 100 + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.percentage = 0 + assert.Equal(t, H2mux, selector.Current()) + + fetcher.percentage = 100 + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.err = fmt.Errorf("failed to fetch") + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.percentage = -1 + fetcher.err = nil + assert.Equal(t, H2mux, selector.Current()) + + fetcher.percentage = 0 + assert.Equal(t, H2mux, selector.Current()) + + fetcher.percentage = 100 + assert.Equal(t, HTTP2, selector.Current()) +} + +func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { + logger, _ := logger.New() + fetcher := dynamicMockFetcher{} + selector, err := NewProtocolSelector("http2", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger) + assert.NoError(t, err) + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.percentage = 100 + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.percentage = 0 + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.err = fmt.Errorf("failed to fetch") + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.percentage = -1 + fetcher.err = nil + assert.Equal(t, H2mux, selector.Current()) + + fetcher.percentage = 0 + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.percentage = 100 + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.percentage = -1 + assert.Equal(t, H2mux, selector.Current()) +} + +func TestProtocolSelectorRefreshTTL(t *testing.T) { + logger, _ := logger.New() + fetcher := dynamicMockFetcher{percentage: 100} + selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), time.Hour, logger) + assert.NoError(t, err) + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.percentage = 0 + assert.Equal(t, HTTP2, selector.Current()) +} diff --git a/origin/backoffhandler.go b/origin/backoffhandler.go index 8ff9752b..e99605e1 100644 --- a/origin/backoffhandler.go +++ b/origin/backoffhandler.go @@ -97,3 +97,11 @@ func (b BackoffHandler) GetBaseTime() time.Duration { func (b *BackoffHandler) Retries() int { return int(b.retries) } + +func (b *BackoffHandler) ReachedMaxRetries() bool { + return b.retries == b.MaxRetries +} + +func (b *BackoffHandler) resetNow() { + b.resetDeadline = time.Now() +} diff --git a/origin/supervisor.go b/origin/supervisor.go index af81378a..7759e103 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -17,10 +17,10 @@ import ( ) const ( + // SRV and TXT record resolution TTL + ResolveTTL = time.Hour // Waiting time before retrying a failed tunnel connection tunnelRetryDuration = time.Second * 10 - // SRV record resolution TTL - resolveTTL = time.Hour // Interval between registering new tunnels registrationInterval = time.Second @@ -43,8 +43,6 @@ type Supervisor struct { cloudflaredUUID uuid.UUID config *TunnelConfig edgeIPs *edgediscovery.Edge - lastResolve time.Time - resolverC chan resolveResult tunnelErrors chan tunnelError tunnelsConnecting map[int]chan struct{} // nextConnectedIndex and nextConnectedSignal are used to wait for all @@ -58,10 +56,6 @@ type Supervisor struct { useReconnectToken bool } -type resolveResult struct { - err error -} - type tunnelError struct { index int addr *net.TCPAddr @@ -74,9 +68,9 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor err error ) if len(config.EdgeAddrs) > 0 { - edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs) + edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs) } else { - edgeIPs, err = edgediscovery.ResolveEdge(config.Observer) + edgeIPs, err = edgediscovery.ResolveEdge(config.Logger) } if err != nil { return nil, err @@ -93,14 +87,13 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor edgeIPs: edgeIPs, tunnelErrors: make(chan tunnelError), tunnelsConnecting: map[int]chan struct{}{}, - logger: config.Observer, + logger: config.Logger, reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections), useReconnectToken: useReconnectToken, }, nil } func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { - logger := s.config.Observer if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { return err } @@ -117,7 +110,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil { refreshAuthBackoffTimer = timer } else { - logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err) + s.logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err) refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration) } } @@ -136,7 +129,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re case tunnelError := <-s.tunnelErrors: tunnelsActive-- if tunnelError.err != nil { - logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err) + s.logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err) tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) s.waitForNextTunnel(tunnelError.index) @@ -159,7 +152,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re case <-refreshAuthBackoffTimer: newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate) if err != nil { - logger.Errorf("supervisor: Authentication failed: %s", err) + s.logger.Errorf("supervisor: Authentication failed: %s", err) // Permanent failure. Leave the `select` without setting the // channel to be non-null, so we'll never hit this case of the `select` again. continue @@ -171,27 +164,15 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re // No more tunnels outstanding, clear backoff timer backoff.SetGracePeriod() } - // DNS resolution returned - case result := <-s.resolverC: - s.lastResolve = time.Now() - s.resolverC = nil - if result.err == nil { - logger.Debug("supervisor: Service discovery refresh complete") - } else { - logger.Errorf("supervisor: Service discovery error: %s", result.err) - } } } } // Returns nil if initialization succeeded, else the initialization error. func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { - logger := s.logger - - s.lastResolve = time.Now() availableAddrs := int(s.edgeIPs.AvailableAddrs()) if s.config.HAConnections > availableAddrs { - logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs) + s.logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs) s.config.HAConnections = availableAddrs } @@ -304,7 +285,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) return nil, err } - edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP) + edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP) if err != nil { return nil, err } diff --git a/origin/tunnel.go b/origin/tunnel.go index d96432ea..75e01f33 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -62,13 +62,14 @@ type TunnelConfig struct { ReportedVersion string Retries uint RunFromTerminal bool - TLSConfig *tls.Config - NamedTunnel *connection.NamedTunnelConfig - ClassicTunnel *connection.ClassicTunnelConfig - MuxerConfig *connection.MuxerConfig - TunnelEventChan chan ui.TunnelEvent - IngressRules ingress.Ingress + NamedTunnel *connection.NamedTunnelConfig + ClassicTunnel *connection.ClassicTunnelConfig + MuxerConfig *connection.MuxerConfig + TunnelEventChan chan ui.TunnelEvent + IngressRules ingress.Ingress + ProtocolSelector connection.ProtocolSelector + EdgeTLSConfigs map[connection.Protocol]*tls.Config } type muxerShutdownError struct{} @@ -157,7 +158,7 @@ func ServeTunnelLoop(ctx context.Context, credentialManager *reconnectCredentialManager, config *TunnelConfig, addr *net.TCPAddr, - connectionIndex uint8, + connIndex uint8, connectedSignal *signal.Signal, cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, @@ -165,7 +166,11 @@ func ServeTunnelLoop(ctx context.Context, haConnections.Inc() defer haConnections.Dec() - backoff := BackoffHandler{MaxRetries: config.Retries} + protocallFallback := &protocallFallback{ + BackoffHandler{MaxRetries: config.Retries}, + config.ProtocolSelector.Current(), + false, + } connectedFuse := h2mux.NewBooleanFuse() go func() { if connectedFuse.Await() { @@ -174,29 +179,90 @@ func ServeTunnelLoop(ctx context.Context, }() // Ensure the above goroutine will terminate if we return without connecting defer connectedFuse.Fuse(false) + // Each connection to keep its own copy of protocol, because individual connections might fallback + // to another protocol when a particular metal doesn't support new protocol for { err, recoverable := ServeTunnel( ctx, credentialManager, config, - addr, connectionIndex, + addr, + connIndex, connectedFuse, - &backoff, + protocallFallback, cloudflaredUUID, reconnectCh, + protocallFallback.protocol, ) - if recoverable { - if duration, ok := backoff.GetBackoffDuration(ctx); ok { - if config.TunnelEventChan != nil { - config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting} - } - config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err) - backoff.Backoff(ctx) - continue - } + if !recoverable { + return err } + + err = waitForBackoff(ctx, protocallFallback, config, connIndex, err) + if err != nil { + return err + } + } +} + +// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches +// max retries +type protocallFallback struct { + BackoffHandler + protocol connection.Protocol + inFallback bool +} + +func (pf *protocallFallback) reset() { + pf.resetNow() + pf.inFallback = false +} + +func (pf *protocallFallback) fallback(fallback connection.Protocol) { + pf.resetNow() + pf.protocol = fallback + pf.inFallback = true +} + +// Expect err to always be non nil +func waitForBackoff( + ctx context.Context, + protobackoff *protocallFallback, + config *TunnelConfig, + connIndex uint8, + err error, +) error { + duration, ok := protobackoff.GetBackoffDuration(ctx) + if !ok { return err } + + if config.TunnelEventChan != nil { + config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Reconnecting} + } + + config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connIndex, duration, err) + protobackoff.Backoff(ctx) + + if protobackoff.ReachedMaxRetries() { + fallback, hasFallback := config.ProtocolSelector.Fallback() + if !hasFallback { + return err + } + // Already using fallback protocol, no point to retry + if protobackoff.protocol == fallback { + return err + } + config.Logger.Infof("Fallback to use %s", fallback) + protobackoff.fallback(fallback) + } else if !protobackoff.inFallback { + current := config.ProtocolSelector.Current() + if protobackoff.protocol != current { + protobackoff.protocol = current + config.Logger.Infof("Change protocol to %s", current) + } + } + return nil } func ServeTunnel( @@ -204,11 +270,12 @@ func ServeTunnel( credentialManager *reconnectCredentialManager, config *TunnelConfig, addr *net.TCPAddr, - connectionIndex uint8, + connIndex uint8, fuse *h2mux.BooleanFuse, - backoff *BackoffHandler, + backoff *protocallFallback, cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, + protocol connection.Protocol, ) (err error, recoverable bool) { // Treat panics as recoverable errors defer func() { @@ -226,11 +293,11 @@ func ServeTunnel( // If launch-ui flag is set, send disconnect msg if config.TunnelEventChan != nil { defer func() { - config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Disconnected} + config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Disconnected} }() } - edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr) + edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr) if err != nil { return err, true } @@ -238,11 +305,11 @@ func ServeTunnel( fuse: fuse, backoff: backoff, } - if config.Protocol == connection.HTTP2 { + if protocol == connection.HTTP2 { connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries)) - return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh) + return ServeHTTP2(ctx, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh) } - return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh) + return ServeH2mux(ctx, credentialManager, config, edgeConn, connIndex, connectedFuse, cloudflaredUUID, reconnectCh) } func ServeH2mux( @@ -255,6 +322,7 @@ func ServeH2mux( cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, ) (err error, recoverable bool) { + config.Logger.Debugf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer) if err != nil { @@ -266,10 +334,10 @@ func ServeH2mux( errGroup.Go(func() (err error) { if config.NamedTunnel != nil { connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries)) - return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse) + return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse) } registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) - return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) + return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) }) errGroup.Go(listenReconnect(serveCtx, reconnectCh)) @@ -295,7 +363,7 @@ func ServeH2mux( config.Logger.Info("Muxer shutdown") return err, true case *ReconnectSignal: - config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay) + config.Logger.Infof("Restarting connection %d due to reconnect signal in %s", connectionIndex, err.Delay) err.DelayBeforeReconnect() return err, true default: @@ -319,10 +387,8 @@ func ServeHTTP2( connectedFuse connection.ConnectedFuse, reconnectCh chan ReconnectSignal, ) (err error, recoverable bool) { - server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse) - if err != nil { - return err, false - } + config.Logger.Debugf("Connecting via http2") + server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse) errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { @@ -352,12 +418,12 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) fu type connectedFuse struct { fuse *h2mux.BooleanFuse - backoff *BackoffHandler + backoff *protocallFallback } func (cf *connectedFuse) Connected() { cf.fuse.Fuse(true) - cf.backoff.SetGracePeriod() + cf.backoff.reset() } func (cf *connectedFuse) IsConnected() bool { diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go new file mode 100644 index 00000000..e6660b82 --- /dev/null +++ b/origin/tunnel_test.go @@ -0,0 +1,90 @@ +package origin + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/stretchr/testify/assert" +) + +type dynamicMockFetcher struct { + percentage int32 + err error +} + +func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { + return func() (int32, error) { + if dmf.err != nil { + return 0, dmf.err + } + return dmf.percentage, nil + } +} +func TestWaitForBackoffFallback(t *testing.T) { + maxRetries := uint(3) + backoff := BackoffHandler{ + MaxRetries: maxRetries, + BaseTime: time.Millisecond * 10, + } + ctx := context.Background() + logger, err := logger.New() + assert.NoError(t, err) + resolveTTL := time.Duration(0) + namedTunnel := &connection.NamedTunnelConfig{ + Auth: pogs.TunnelAuth{ + AccountTag: "test-account", + }, + } + mockFetcher := dynamicMockFetcher{ + percentage: 0, + } + protocolSelector, err := connection.NewProtocolSelector(connection.HTTP2.String(), namedTunnel, mockFetcher.fetch(), resolveTTL, logger) + assert.NoError(t, err) + config := &TunnelConfig{ + Logger: logger, + ProtocolSelector: protocolSelector, + } + connIndex := uint8(1) + + initProtocol := protocolSelector.Current() + assert.Equal(t, connection.HTTP2, initProtocol) + + protocallFallback := &protocallFallback{ + backoff, + initProtocol, + false, + } + + // Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this + for i := 0; i < int(maxRetries-1); i++ { + err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error")) + assert.NoError(t, err) + assert.Equal(t, initProtocol, protocallFallback.protocol) + } + + // Retry fallback protocol + for i := 0; i < int(maxRetries); i++ { + err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error")) + assert.NoError(t, err) + fallback, ok := protocolSelector.Fallback() + assert.True(t, ok) + assert.Equal(t, fallback, protocallFallback.protocol) + } + + currentGlobalProtocol := protocolSelector.Current() + assert.Equal(t, initProtocol, currentGlobalProtocol) + + // No protocol to fallback, return error + err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error")) + assert.Error(t, err) + + protocallFallback.reset() + err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("New error")) + assert.NoError(t, err) + assert.Equal(t, initProtocol, protocallFallback.protocol) +}