TUN-3458: Upgrade to http2 when available, fallback to h2mux when we reach max retries

This commit is contained in:
cthuang 2020-10-14 14:42:00 +01:00
parent b5cdf3b2c7
commit a490443630
13 changed files with 632 additions and 159 deletions

View File

@ -232,15 +232,19 @@ func prepareTunnelConfig(
} }
} }
protocol, err := determineProtocol(c, namedTunnel) protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
logger.Infof("Using protocol %s", protocol) logger.Infof("Initial protocol %s", protocolSelector.Current())
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName())
if err != nil { edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList))
logger.Errorf("unable to create TLS config to connect with edge: %s", err) for _, p := range connection.ProtocolList {
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName())
if err != nil {
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
}
edgeTLSConfigs[p] = edgeTLSConfig
} }
proxyConfig := &origin.ProxyConfig{ proxyConfig := &origin.ProxyConfig{
@ -252,7 +256,7 @@ func prepareTunnelConfig(
Tags: tags, Tags: tags,
} }
originClient := origin.NewClient(proxyConfig, logger) originClient := origin.NewClient(proxyConfig, logger)
transportConfig := &connection.Config{ connectionConfig := &connection.Config{
OriginClient: originClient, OriginClient: originClient,
GracePeriod: c.Duration("grace-period"), GracePeriod: c.Duration("grace-period"),
ReplaceExisting: c.Bool("force"), ReplaceExisting: c.Bool("force"),
@ -270,7 +274,7 @@ func prepareTunnelConfig(
} }
return &origin.TunnelConfig{ return &origin.TunnelConfig{
ConnectionConfig: transportConfig, ConnectionConfig: connectionConfig,
ProxyConfig: proxyConfig, ProxyConfig: proxyConfig,
BuildInfo: buildInfo, BuildInfo: buildInfo,
ClientID: clientID, ClientID: clientID,
@ -281,34 +285,20 @@ func prepareTunnelConfig(
IsFreeTunnel: isFreeTunnel, IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"), LBPool: c.String("lb-pool"),
Logger: logger, Logger: logger,
Observer: connection.NewObserver(transportLogger, tunnelEventChan, protocol), Observer: connection.NewObserver(transportLogger, tunnelEventChan),
ReportedVersion: version, ReportedVersion: version,
Retries: c.Uint("retries"), Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(), RunFromTerminal: isRunningFromTerminal(),
TLSConfig: toEdgeTLSConfig,
NamedTunnel: namedTunnel, NamedTunnel: namedTunnel,
ClassicTunnel: classicTunnel, ClassicTunnel: classicTunnel,
MuxerConfig: muxerConfig, MuxerConfig: muxerConfig,
TunnelEventChan: tunnelEventChan, TunnelEventChan: tunnelEventChan,
IngressRules: ingressRules, IngressRules: ingressRules,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
}, nil }, nil
} }
func isRunningFromTerminal() bool { func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd())) return terminal.IsTerminal(int(os.Stdout.Fd()))
} }
func determineProtocol(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) (connection.Protocol, error) {
if namedTunnel == nil {
return connection.H2mux, nil
}
http2Percentage, err := edgediscovery.HTTP2Percentage()
if err != nil {
return 0, err
}
protocol, ok := connection.SelectProtocol(c.String("protocol"), namedTunnel.Auth.AccountTag, http2Percentage)
if !ok {
return 0, fmt.Errorf("%s is not valid protocol. %s", c.String("protocol"), availableProtocol)
}
return protocol, nil
}

View File

@ -23,6 +23,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstore" "github.com/cloudflare/cloudflared/tunnelstore"
@ -30,7 +31,6 @@ import (
const ( const (
credFileFlagAlias = "cred-file" credFileFlagAlias = "cred-file"
availableProtocol = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
) )
var ( var (
@ -90,7 +90,7 @@ var (
Name: "protocol", Name: "protocol",
Value: "h2mux", Value: "h2mux",
Aliases: []string{"p"}, Aliases: []string{"p"},
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol), Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", connection.AvailableProtocolFlagMessage),
EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"}, EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"},
Hidden: true, Hidden: true,
}) })

View File

@ -1,8 +1,6 @@
package connection package connection
import ( import (
"fmt"
"hash/fnv"
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
@ -13,10 +11,6 @@ import (
) )
const ( const (
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
) )
@ -43,57 +37,6 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool {
return c.Hostname == "" return c.Hostname == ""
} }
type Protocol int64
const (
H2mux Protocol = iota
HTTP2
)
func SelectProtocol(s string, accountTag string, http2Percentage uint32) (Protocol, bool) {
switch s {
case "h2mux":
return H2mux, true
case "http2":
return HTTP2, true
case "auto":
if tryHTTP2(accountTag, http2Percentage) {
return HTTP2, true
}
return H2mux, true
default:
return 0, false
}
}
func tryHTTP2(accountTag string, http2Percentage uint32) bool {
h := fnv.New32a()
h.Write([]byte(accountTag))
return h.Sum32()%100 < http2Percentage
}
func (p Protocol) ServerName() string {
switch p {
case H2mux:
return edgeH2muxTLSServerName
case HTTP2:
return edgeH2TLSServerName
default:
return ""
}
}
func (p Protocol) String() string {
switch p {
case H2mux:
return "h2mux"
case HTTP2:
return "http2"
default:
return fmt.Sprintf("unknown protocol")
}
}
type OriginClient interface { type OriginClient interface {
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
} }

View File

@ -37,7 +37,7 @@ type HTTP2Connection struct {
connectedFuse ConnectedFuse connectedFuse ConnectedFuse
} }
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) { func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection {
return &HTTP2Connection{ return &HTTP2Connection{
conn: conn, conn: conn,
server: &http2.Server{ server: &http2.Server{
@ -52,7 +52,7 @@ func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, named
connIndex: connIndex, connIndex: connIndex,
wg: &sync.WaitGroup{}, wg: &sync.WaitGroup{},
connectedFuse: connectedFuse, connectedFuse: connectedFuse,
}, nil }
} }
func (c *HTTP2Connection) Serve(ctx context.Context) { func (c *HTTP2Connection) Serve(ctx context.Context) {

View File

@ -299,7 +299,7 @@ func convertRTTMilliSec(t time.Duration) float64 {
} }
// Metrics that can be collected without asking the edge // Metrics that can be collected without asking the edge
func newTunnelMetrics(protocol Protocol) *tunnelMetrics { func newTunnelMetrics() *tunnelMetrics {
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec( maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: MetricsNamespace, Namespace: MetricsNamespace,
@ -374,16 +374,12 @@ func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
[]string{"rpcName"}, []string{"rpcName"},
) )
prometheus.MustRegister(registerSuccess) prometheus.MustRegister(registerSuccess)
var muxerMetrics *muxerMetrics
if protocol == H2mux {
muxerMetrics = newMuxerMetrics()
}
return &tunnelMetrics{ return &tunnelMetrics{
timerRetries: timerRetries, timerRetries: timerRetries,
serverLocations: serverLocations, serverLocations: serverLocations,
oldServerLocations: make(map[string]string), oldServerLocations: make(map[string]string),
muxerMetrics: muxerMetrics, muxerMetrics: newMuxerMetrics(),
tunnelsHA: NewTunnelsForHA(), tunnelsHA: NewTunnelsForHA(),
regSuccess: registerSuccess, regSuccess: registerSuccess,
regFail: registerFail, regFail: registerFail,

View File

@ -16,10 +16,10 @@ type Observer struct {
tunnelEventChan chan<- ui.TunnelEvent tunnelEventChan chan<- ui.TunnelEvent
} }
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer { func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent) *Observer {
return &Observer{ return &Observer{
logger, logger,
newTunnelMetrics(protocol), newTunnelMetrics(),
tunnelEventChan, tunnelEventChan,
} }
} }

View File

@ -9,7 +9,7 @@ import (
) )
// can only be called once // can only be called once
var m = newTunnelMetrics(H2mux) var m = newTunnelMetrics()
func TestRegisterServerLocation(t *testing.T) { func TestRegisterServerLocation(t *testing.T) {
tunnels := 20 tunnels := 20

179
connection/protocol.go Normal file
View File

@ -0,0 +1,179 @@
package connection
import (
"fmt"
"hash/fnv"
"sync"
"time"
"github.com/cloudflare/cloudflared/logger"
)
const (
AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
// threshold to switch back to h2mux when the user intentionally pick --protocol http2
explicitHTTP2FallbackThreshold = -1
autoSelectFlag = "auto"
)
var (
ProtocolList = []Protocol{H2mux, HTTP2}
)
type Protocol int64
const (
H2mux Protocol = iota
HTTP2
)
func (p Protocol) ServerName() string {
switch p {
case H2mux:
return edgeH2muxTLSServerName
case HTTP2:
return edgeH2TLSServerName
default:
return ""
}
}
// Fallback returns the fallback protocol and whether the protocol has a fallback
func (p Protocol) fallback() (Protocol, bool) {
switch p {
case H2mux:
return 0, false
case HTTP2:
return H2mux, true
default:
return 0, false
}
}
func (p Protocol) String() string {
switch p {
case H2mux:
return "h2mux"
case HTTP2:
return "http2"
default:
return fmt.Sprintf("unknown protocol")
}
}
type ProtocolSelector interface {
Current() Protocol
Fallback() (Protocol, bool)
}
type staticProtocolSelector struct {
current Protocol
}
func (s *staticProtocolSelector) Current() Protocol {
return s.current
}
func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
return 0, false
}
type autoProtocolSelector struct {
lock sync.RWMutex
current Protocol
switchThrehold int32
fetchFunc PercentageFetcher
refreshAfter time.Time
ttl time.Duration
logger logger.Service
}
func newAutoProtocolSelector(
current Protocol,
switchThrehold int32,
fetchFunc PercentageFetcher,
ttl time.Duration,
logger logger.Service,
) *autoProtocolSelector {
return &autoProtocolSelector{
current: current,
switchThrehold: switchThrehold,
fetchFunc: fetchFunc,
refreshAfter: time.Now().Add(ttl),
ttl: ttl,
logger: logger,
}
}
func (s *autoProtocolSelector) Current() Protocol {
s.lock.Lock()
defer s.lock.Unlock()
if time.Now().Before(s.refreshAfter) {
return s.current
}
percentage, err := s.fetchFunc()
if err != nil {
s.logger.Errorf("Failed to refresh protocol, err: %v", err)
return s.current
}
if s.switchThrehold < percentage {
s.current = HTTP2
} else {
s.current = H2mux
}
s.refreshAfter = time.Now().Add(s.ttl)
return s.current
}
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.current.fallback()
}
type PercentageFetcher func() (int32, error)
func NewProtocolSelector(protocolFlag string, namedTunnel *NamedTunnelConfig, fetchFunc PercentageFetcher, ttl time.Duration, logger logger.Service) (ProtocolSelector, error) {
if namedTunnel == nil {
return &staticProtocolSelector{
current: H2mux,
}, nil
}
if protocolFlag == H2mux.String() {
return &staticProtocolSelector{
current: H2mux,
}, nil
}
http2Percentage, err := fetchFunc()
if err != nil {
return nil, err
}
if protocolFlag == HTTP2.String() {
if http2Percentage < 0 {
return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil
}
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil
}
if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
threshold := switchThreshold(namedTunnel.Auth.AccountTag)
if threshold < http2Percentage {
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, logger), nil
}
return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, logger), nil
}
func switchThreshold(accountTag string) int32 {
h := fnv.New32a()
h.Write([]byte(accountTag))
return int32(h.Sum32() % 100)
}

220
connection/protocol_test.go Normal file
View File

@ -0,0 +1,220 @@
package connection
import (
"fmt"
"testing"
"time"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/stretchr/testify/assert"
)
const (
testNoTTL = 0
)
var (
testNamedTunnelConfig = &NamedTunnelConfig{
Auth: pogs.TunnelAuth{
AccountTag: "testAccountTag",
},
}
)
func mockFetcher(percentage int32) PercentageFetcher {
return func() (int32, error) {
return percentage, nil
}
}
func mockFetcherWithError() PercentageFetcher {
return func() (int32, error) {
return 0, fmt.Errorf("failed to fetch precentage")
}
}
type dynamicMockFetcher struct {
percentage int32
err error
}
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
return func() (int32, error) {
if dmf.err != nil {
return 0, dmf.err
}
return dmf.percentage, nil
}
}
func TestNewProtocolSelector(t *testing.T) {
tests := []struct {
name string
protocol string
expectedProtocol Protocol
hasFallback bool
expectedFallback Protocol
namedTunnelConfig *NamedTunnelConfig
fetchFunc PercentageFetcher
wantErr bool
}{
{
name: "classic tunnel",
protocol: "h2mux",
expectedProtocol: H2mux,
namedTunnelConfig: nil,
},
{
name: "named tunnel over h2mux",
protocol: "h2mux",
expectedProtocol: H2mux,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel over http2",
protocol: "http2",
expectedProtocol: HTTP2,
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(0),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel http2 disabled",
protocol: "http2",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(-1),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto all http2 disabled",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(-1),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to h2mux",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(0),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to http2",
protocol: "auto",
expectedProtocol: HTTP2,
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
namedTunnelConfig: testNamedTunnelConfig,
},
{
// None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
name: "classic tunnel unknown protocol",
protocol: "unknown",
expectedProtocol: H2mux,
},
{
name: "named tunnel unknown protocol",
protocol: "unknown",
fetchFunc: mockFetcher(100),
namedTunnelConfig: testNamedTunnelConfig,
wantErr: true,
},
{
name: "named tunnel fetch error",
protocol: "unknown",
fetchFunc: mockFetcherWithError(),
namedTunnelConfig: testNamedTunnelConfig,
wantErr: true,
},
}
logger, _ := logger.New()
for _, test := range tests {
selector, err := NewProtocolSelector(test.protocol, test.namedTunnelConfig, test.fetchFunc, testNoTTL, logger)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
} else {
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
fallback, ok := selector.Fallback()
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
if test.hasFallback {
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
}
}
}
}
func TestAutoProtocolSelectorRefresh(t *testing.T) {
logger, _ := logger.New()
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger)
assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
fetcher.err = nil
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 0
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
}
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
logger, _ := logger.New()
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector("http2", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
fetcher.err = nil
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
assert.Equal(t, H2mux, selector.Current())
}
func TestProtocolSelectorRefreshTTL(t *testing.T) {
logger, _ := logger.New()
fetcher := dynamicMockFetcher{percentage: 100}
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), time.Hour, logger)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
}

View File

@ -97,3 +97,11 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
func (b *BackoffHandler) Retries() int { func (b *BackoffHandler) Retries() int {
return int(b.retries) return int(b.retries)
} }
func (b *BackoffHandler) ReachedMaxRetries() bool {
return b.retries == b.MaxRetries
}
func (b *BackoffHandler) resetNow() {
b.resetDeadline = time.Now()
}

View File

@ -17,10 +17,10 @@ import (
) )
const ( const (
// SRV and TXT record resolution TTL
ResolveTTL = time.Hour
// Waiting time before retrying a failed tunnel connection // Waiting time before retrying a failed tunnel connection
tunnelRetryDuration = time.Second * 10 tunnelRetryDuration = time.Second * 10
// SRV record resolution TTL
resolveTTL = time.Hour
// Interval between registering new tunnels // Interval between registering new tunnels
registrationInterval = time.Second registrationInterval = time.Second
@ -43,8 +43,6 @@ type Supervisor struct {
cloudflaredUUID uuid.UUID cloudflaredUUID uuid.UUID
config *TunnelConfig config *TunnelConfig
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
lastResolve time.Time
resolverC chan resolveResult
tunnelErrors chan tunnelError tunnelErrors chan tunnelError
tunnelsConnecting map[int]chan struct{} tunnelsConnecting map[int]chan struct{}
// nextConnectedIndex and nextConnectedSignal are used to wait for all // nextConnectedIndex and nextConnectedSignal are used to wait for all
@ -58,10 +56,6 @@ type Supervisor struct {
useReconnectToken bool useReconnectToken bool
} }
type resolveResult struct {
err error
}
type tunnelError struct { type tunnelError struct {
index int index int
addr *net.TCPAddr addr *net.TCPAddr
@ -74,9 +68,9 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
err error err error
) )
if len(config.EdgeAddrs) > 0 { if len(config.EdgeAddrs) > 0 {
edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs) edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs)
} else { } else {
edgeIPs, err = edgediscovery.ResolveEdge(config.Observer) edgeIPs, err = edgediscovery.ResolveEdge(config.Logger)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -93,14 +87,13 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
edgeIPs: edgeIPs, edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError), tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{}, tunnelsConnecting: map[int]chan struct{}{},
logger: config.Observer, logger: config.Logger,
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections), reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
useReconnectToken: useReconnectToken, useReconnectToken: useReconnectToken,
}, nil }, nil
} }
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.config.Observer
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err return err
} }
@ -117,7 +110,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil { if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
refreshAuthBackoffTimer = timer refreshAuthBackoffTimer = timer
} else { } else {
logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err) s.logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration) refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
} }
} }
@ -136,7 +129,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
case tunnelError := <-s.tunnelErrors: case tunnelError := <-s.tunnelErrors:
tunnelsActive-- tunnelsActive--
if tunnelError.err != nil { if tunnelError.err != nil {
logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err) s.logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index) s.waitForNextTunnel(tunnelError.index)
@ -159,7 +152,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
case <-refreshAuthBackoffTimer: case <-refreshAuthBackoffTimer:
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate) newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil { if err != nil {
logger.Errorf("supervisor: Authentication failed: %s", err) s.logger.Errorf("supervisor: Authentication failed: %s", err)
// Permanent failure. Leave the `select` without setting the // Permanent failure. Leave the `select` without setting the
// channel to be non-null, so we'll never hit this case of the `select` again. // channel to be non-null, so we'll never hit this case of the `select` again.
continue continue
@ -171,27 +164,15 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
// No more tunnels outstanding, clear backoff timer // No more tunnels outstanding, clear backoff timer
backoff.SetGracePeriod() backoff.SetGracePeriod()
} }
// DNS resolution returned
case result := <-s.resolverC:
s.lastResolve = time.Now()
s.resolverC = nil
if result.err == nil {
logger.Debug("supervisor: Service discovery refresh complete")
} else {
logger.Errorf("supervisor: Service discovery error: %s", result.err)
}
} }
} }
} }
// Returns nil if initialization succeeded, else the initialization error. // Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.logger
s.lastResolve = time.Now()
availableAddrs := int(s.edgeIPs.AvailableAddrs()) availableAddrs := int(s.edgeIPs.AvailableAddrs())
if s.config.HAConnections > availableAddrs { if s.config.HAConnections > availableAddrs {
logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs) s.logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.config.HAConnections = availableAddrs s.config.HAConnections = availableAddrs
} }
@ -304,7 +285,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
return nil, err return nil, err
} }
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -62,13 +62,14 @@ type TunnelConfig struct {
ReportedVersion string ReportedVersion string
Retries uint Retries uint
RunFromTerminal bool RunFromTerminal bool
TLSConfig *tls.Config
NamedTunnel *connection.NamedTunnelConfig NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig MuxerConfig *connection.MuxerConfig
TunnelEventChan chan ui.TunnelEvent TunnelEventChan chan ui.TunnelEvent
IngressRules ingress.Ingress IngressRules ingress.Ingress
ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config
} }
type muxerShutdownError struct{} type muxerShutdownError struct{}
@ -157,7 +158,7 @@ func ServeTunnelLoop(ctx context.Context,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
config *TunnelConfig, config *TunnelConfig,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionIndex uint8, connIndex uint8,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
@ -165,7 +166,11 @@ func ServeTunnelLoop(ctx context.Context,
haConnections.Inc() haConnections.Inc()
defer haConnections.Dec() defer haConnections.Dec()
backoff := BackoffHandler{MaxRetries: config.Retries} protocallFallback := &protocallFallback{
BackoffHandler{MaxRetries: config.Retries},
config.ProtocolSelector.Current(),
false,
}
connectedFuse := h2mux.NewBooleanFuse() connectedFuse := h2mux.NewBooleanFuse()
go func() { go func() {
if connectedFuse.Await() { if connectedFuse.Await() {
@ -174,29 +179,90 @@ func ServeTunnelLoop(ctx context.Context,
}() }()
// Ensure the above goroutine will terminate if we return without connecting // Ensure the above goroutine will terminate if we return without connecting
defer connectedFuse.Fuse(false) defer connectedFuse.Fuse(false)
// Each connection to keep its own copy of protocol, because individual connections might fallback
// to another protocol when a particular metal doesn't support new protocol
for { for {
err, recoverable := ServeTunnel( err, recoverable := ServeTunnel(
ctx, ctx,
credentialManager, credentialManager,
config, config,
addr, connectionIndex, addr,
connIndex,
connectedFuse, connectedFuse,
&backoff, protocallFallback,
cloudflaredUUID, cloudflaredUUID,
reconnectCh, reconnectCh,
protocallFallback.protocol,
) )
if recoverable { if !recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok { return err
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
}
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
backoff.Backoff(ctx)
continue
}
} }
err = waitForBackoff(ctx, protocallFallback, config, connIndex, err)
if err != nil {
return err
}
}
}
// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
// max retries
type protocallFallback struct {
BackoffHandler
protocol connection.Protocol
inFallback bool
}
func (pf *protocallFallback) reset() {
pf.resetNow()
pf.inFallback = false
}
func (pf *protocallFallback) fallback(fallback connection.Protocol) {
pf.resetNow()
pf.protocol = fallback
pf.inFallback = true
}
// Expect err to always be non nil
func waitForBackoff(
ctx context.Context,
protobackoff *protocallFallback,
config *TunnelConfig,
connIndex uint8,
err error,
) error {
duration, ok := protobackoff.GetBackoffDuration(ctx)
if !ok {
return err return err
} }
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Reconnecting}
}
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connIndex, duration, err)
protobackoff.Backoff(ctx)
if protobackoff.ReachedMaxRetries() {
fallback, hasFallback := config.ProtocolSelector.Fallback()
if !hasFallback {
return err
}
// Already using fallback protocol, no point to retry
if protobackoff.protocol == fallback {
return err
}
config.Logger.Infof("Fallback to use %s", fallback)
protobackoff.fallback(fallback)
} else if !protobackoff.inFallback {
current := config.ProtocolSelector.Current()
if protobackoff.protocol != current {
protobackoff.protocol = current
config.Logger.Infof("Change protocol to %s", current)
}
}
return nil
} }
func ServeTunnel( func ServeTunnel(
@ -204,11 +270,12 @@ func ServeTunnel(
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
config *TunnelConfig, config *TunnelConfig,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionIndex uint8, connIndex uint8,
fuse *h2mux.BooleanFuse, fuse *h2mux.BooleanFuse,
backoff *BackoffHandler, backoff *protocallFallback,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
protocol connection.Protocol,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
// Treat panics as recoverable errors // Treat panics as recoverable errors
defer func() { defer func() {
@ -226,11 +293,11 @@ func ServeTunnel(
// If launch-ui flag is set, send disconnect msg // If launch-ui flag is set, send disconnect msg
if config.TunnelEventChan != nil { if config.TunnelEventChan != nil {
defer func() { defer func() {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Disconnected} config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Disconnected}
}() }()
} }
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr)
if err != nil { if err != nil {
return err, true return err, true
} }
@ -238,11 +305,11 @@ func ServeTunnel(
fuse: fuse, fuse: fuse,
backoff: backoff, backoff: backoff,
} }
if config.Protocol == connection.HTTP2 { if protocol == connection.HTTP2 {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries)) connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh) return ServeHTTP2(ctx, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
} }
return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh) return ServeH2mux(ctx, credentialManager, config, edgeConn, connIndex, connectedFuse, cloudflaredUUID, reconnectCh)
} }
func ServeH2mux( func ServeH2mux(
@ -255,6 +322,7 @@ func ServeH2mux(
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
config.Logger.Debugf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer) handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
if err != nil { if err != nil {
@ -266,10 +334,10 @@ func ServeH2mux(
errGroup.Go(func() (err error) { errGroup.Go(func() (err error) {
if config.NamedTunnel != nil { if config.NamedTunnel != nil {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries)) connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse) return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
} }
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
}) })
errGroup.Go(listenReconnect(serveCtx, reconnectCh)) errGroup.Go(listenReconnect(serveCtx, reconnectCh))
@ -295,7 +363,7 @@ func ServeH2mux(
config.Logger.Info("Muxer shutdown") config.Logger.Info("Muxer shutdown")
return err, true return err, true
case *ReconnectSignal: case *ReconnectSignal:
config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay) config.Logger.Infof("Restarting connection %d due to reconnect signal in %s", connectionIndex, err.Delay)
err.DelayBeforeReconnect() err.DelayBeforeReconnect()
return err, true return err, true
default: default:
@ -319,10 +387,8 @@ func ServeHTTP2(
connectedFuse connection.ConnectedFuse, connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse) config.Logger.Debugf("Connecting via http2")
if err != nil { server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
return err, false
}
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
@ -352,12 +418,12 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) fu
type connectedFuse struct { type connectedFuse struct {
fuse *h2mux.BooleanFuse fuse *h2mux.BooleanFuse
backoff *BackoffHandler backoff *protocallFallback
} }
func (cf *connectedFuse) Connected() { func (cf *connectedFuse) Connected() {
cf.fuse.Fuse(true) cf.fuse.Fuse(true)
cf.backoff.SetGracePeriod() cf.backoff.reset()
} }
func (cf *connectedFuse) IsConnected() bool { func (cf *connectedFuse) IsConnected() bool {

90
origin/tunnel_test.go Normal file
View File

@ -0,0 +1,90 @@
package origin
import (
"context"
"fmt"
"testing"
"time"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/stretchr/testify/assert"
)
type dynamicMockFetcher struct {
percentage int32
err error
}
func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
return func() (int32, error) {
if dmf.err != nil {
return 0, dmf.err
}
return dmf.percentage, nil
}
}
func TestWaitForBackoffFallback(t *testing.T) {
maxRetries := uint(3)
backoff := BackoffHandler{
MaxRetries: maxRetries,
BaseTime: time.Millisecond * 10,
}
ctx := context.Background()
logger, err := logger.New()
assert.NoError(t, err)
resolveTTL := time.Duration(0)
namedTunnel := &connection.NamedTunnelConfig{
Auth: pogs.TunnelAuth{
AccountTag: "test-account",
},
}
mockFetcher := dynamicMockFetcher{
percentage: 0,
}
protocolSelector, err := connection.NewProtocolSelector(connection.HTTP2.String(), namedTunnel, mockFetcher.fetch(), resolveTTL, logger)
assert.NoError(t, err)
config := &TunnelConfig{
Logger: logger,
ProtocolSelector: protocolSelector,
}
connIndex := uint8(1)
initProtocol := protocolSelector.Current()
assert.Equal(t, connection.HTTP2, initProtocol)
protocallFallback := &protocallFallback{
backoff,
initProtocol,
false,
}
// Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this
for i := 0; i < int(maxRetries-1); i++ {
err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.NoError(t, err)
assert.Equal(t, initProtocol, protocallFallback.protocol)
}
// Retry fallback protocol
for i := 0; i < int(maxRetries); i++ {
err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.NoError(t, err)
fallback, ok := protocolSelector.Fallback()
assert.True(t, ok)
assert.Equal(t, fallback, protocallFallback.protocol)
}
currentGlobalProtocol := protocolSelector.Current()
assert.Equal(t, initProtocol, currentGlobalProtocol)
// No protocol to fallback, return error
err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.Error(t, err)
protocallFallback.reset()
err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("New error"))
assert.NoError(t, err)
assert.Equal(t, initProtocol, protocallFallback.protocol)
}