TUN-3458: Upgrade to http2 when available, fallback to h2mux when we reach max retries

This commit is contained in:
cthuang 2020-10-14 14:42:00 +01:00
parent b5cdf3b2c7
commit a490443630
13 changed files with 632 additions and 159 deletions

View File

@ -232,15 +232,19 @@ func prepareTunnelConfig(
}
}
protocol, err := determineProtocol(c, namedTunnel)
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, logger)
if err != nil {
return nil, err
}
logger.Infof("Using protocol %s", protocol)
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName())
if err != nil {
logger.Errorf("unable to create TLS config to connect with edge: %s", err)
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
logger.Infof("Initial protocol %s", protocolSelector.Current())
edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList))
for _, p := range connection.ProtocolList {
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName())
if err != nil {
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
}
edgeTLSConfigs[p] = edgeTLSConfig
}
proxyConfig := &origin.ProxyConfig{
@ -252,7 +256,7 @@ func prepareTunnelConfig(
Tags: tags,
}
originClient := origin.NewClient(proxyConfig, logger)
transportConfig := &connection.Config{
connectionConfig := &connection.Config{
OriginClient: originClient,
GracePeriod: c.Duration("grace-period"),
ReplaceExisting: c.Bool("force"),
@ -270,7 +274,7 @@ func prepareTunnelConfig(
}
return &origin.TunnelConfig{
ConnectionConfig: transportConfig,
ConnectionConfig: connectionConfig,
ProxyConfig: proxyConfig,
BuildInfo: buildInfo,
ClientID: clientID,
@ -281,34 +285,20 @@ func prepareTunnelConfig(
IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"),
Logger: logger,
Observer: connection.NewObserver(transportLogger, tunnelEventChan, protocol),
Observer: connection.NewObserver(transportLogger, tunnelEventChan),
ReportedVersion: version,
Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(),
TLSConfig: toEdgeTLSConfig,
NamedTunnel: namedTunnel,
ClassicTunnel: classicTunnel,
MuxerConfig: muxerConfig,
TunnelEventChan: tunnelEventChan,
IngressRules: ingressRules,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
}, nil
}
func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd()))
}
func determineProtocol(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) (connection.Protocol, error) {
if namedTunnel == nil {
return connection.H2mux, nil
}
http2Percentage, err := edgediscovery.HTTP2Percentage()
if err != nil {
return 0, err
}
protocol, ok := connection.SelectProtocol(c.String("protocol"), namedTunnel.Auth.AccountTag, http2Percentage)
if !ok {
return 0, fmt.Errorf("%s is not valid protocol. %s", c.String("protocol"), availableProtocol)
}
return protocol, nil
}

View File

@ -23,6 +23,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstore"
@ -30,7 +31,6 @@ import (
const (
credFileFlagAlias = "cred-file"
availableProtocol = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
)
var (
@ -90,7 +90,7 @@ var (
Name: "protocol",
Value: "h2mux",
Aliases: []string{"p"},
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol),
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", connection.AvailableProtocolFlagMessage),
EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"},
Hidden: true,
})

View File

@ -1,8 +1,6 @@
package connection
import (
"fmt"
"hash/fnv"
"io"
"net/http"
"strconv"
@ -13,10 +11,6 @@ import (
)
const (
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
)
@ -43,57 +37,6 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool {
return c.Hostname == ""
}
type Protocol int64
const (
H2mux Protocol = iota
HTTP2
)
func SelectProtocol(s string, accountTag string, http2Percentage uint32) (Protocol, bool) {
switch s {
case "h2mux":
return H2mux, true
case "http2":
return HTTP2, true
case "auto":
if tryHTTP2(accountTag, http2Percentage) {
return HTTP2, true
}
return H2mux, true
default:
return 0, false
}
}
func tryHTTP2(accountTag string, http2Percentage uint32) bool {
h := fnv.New32a()
h.Write([]byte(accountTag))
return h.Sum32()%100 < http2Percentage
}
func (p Protocol) ServerName() string {
switch p {
case H2mux:
return edgeH2muxTLSServerName
case HTTP2:
return edgeH2TLSServerName
default:
return ""
}
}
func (p Protocol) String() string {
switch p {
case H2mux:
return "h2mux"
case HTTP2:
return "http2"
default:
return fmt.Sprintf("unknown protocol")
}
}
type OriginClient interface {
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
}

View File

@ -37,7 +37,7 @@ type HTTP2Connection struct {
connectedFuse ConnectedFuse
}
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) {
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection {
return &HTTP2Connection{
conn: conn,
server: &http2.Server{
@ -52,7 +52,7 @@ func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, named
connIndex: connIndex,
wg: &sync.WaitGroup{},
connectedFuse: connectedFuse,
}, nil
}
}
func (c *HTTP2Connection) Serve(ctx context.Context) {

View File

@ -299,7 +299,7 @@ func convertRTTMilliSec(t time.Duration) float64 {
}
// Metrics that can be collected without asking the edge
func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
func newTunnelMetrics() *tunnelMetrics {
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
@ -374,16 +374,12 @@ func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
[]string{"rpcName"},
)
prometheus.MustRegister(registerSuccess)
var muxerMetrics *muxerMetrics
if protocol == H2mux {
muxerMetrics = newMuxerMetrics()
}
return &tunnelMetrics{
timerRetries: timerRetries,
serverLocations: serverLocations,
oldServerLocations: make(map[string]string),
muxerMetrics: muxerMetrics,
muxerMetrics: newMuxerMetrics(),
tunnelsHA: NewTunnelsForHA(),
regSuccess: registerSuccess,
regFail: registerFail,

View File

@ -16,10 +16,10 @@ type Observer struct {
tunnelEventChan chan<- ui.TunnelEvent
}
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer {
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent) *Observer {
return &Observer{
logger,
newTunnelMetrics(protocol),
newTunnelMetrics(),
tunnelEventChan,
}
}

View File

@ -9,7 +9,7 @@ import (
)
// can only be called once
var m = newTunnelMetrics(H2mux)
var m = newTunnelMetrics()
func TestRegisterServerLocation(t *testing.T) {
tunnels := 20

179
connection/protocol.go Normal file
View File

@ -0,0 +1,179 @@
package connection
import (
"fmt"
"hash/fnv"
"sync"
"time"
"github.com/cloudflare/cloudflared/logger"
)
const (
AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
// threshold to switch back to h2mux when the user intentionally pick --protocol http2
explicitHTTP2FallbackThreshold = -1
autoSelectFlag = "auto"
)
var (
ProtocolList = []Protocol{H2mux, HTTP2}
)
type Protocol int64
const (
H2mux Protocol = iota
HTTP2
)
func (p Protocol) ServerName() string {
switch p {
case H2mux:
return edgeH2muxTLSServerName
case HTTP2:
return edgeH2TLSServerName
default:
return ""
}
}
// Fallback returns the fallback protocol and whether the protocol has a fallback
func (p Protocol) fallback() (Protocol, bool) {
switch p {
case H2mux:
return 0, false
case HTTP2:
return H2mux, true
default:
return 0, false
}
}
func (p Protocol) String() string {
switch p {
case H2mux:
return "h2mux"
case HTTP2:
return "http2"
default:
return fmt.Sprintf("unknown protocol")
}
}
type ProtocolSelector interface {
Current() Protocol
Fallback() (Protocol, bool)
}
type staticProtocolSelector struct {
current Protocol
}
func (s *staticProtocolSelector) Current() Protocol {
return s.current
}
func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
return 0, false
}
type autoProtocolSelector struct {
lock sync.RWMutex
current Protocol
switchThrehold int32
fetchFunc PercentageFetcher
refreshAfter time.Time
ttl time.Duration
logger logger.Service
}
func newAutoProtocolSelector(
current Protocol,
switchThrehold int32,
fetchFunc PercentageFetcher,
ttl time.Duration,
logger logger.Service,
) *autoProtocolSelector {
return &autoProtocolSelector{
current: current,
switchThrehold: switchThrehold,
fetchFunc: fetchFunc,
refreshAfter: time.Now().Add(ttl),
ttl: ttl,
logger: logger,
}
}
func (s *autoProtocolSelector) Current() Protocol {
s.lock.Lock()
defer s.lock.Unlock()
if time.Now().Before(s.refreshAfter) {
return s.current
}
percentage, err := s.fetchFunc()
if err != nil {
s.logger.Errorf("Failed to refresh protocol, err: %v", err)
return s.current
}
if s.switchThrehold < percentage {
s.current = HTTP2
} else {
s.current = H2mux
}
s.refreshAfter = time.Now().Add(s.ttl)
return s.current
}
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.current.fallback()
}
type PercentageFetcher func() (int32, error)
func NewProtocolSelector(protocolFlag string, namedTunnel *NamedTunnelConfig, fetchFunc PercentageFetcher, ttl time.Duration, logger logger.Service) (ProtocolSelector, error) {
if namedTunnel == nil {
return &staticProtocolSelector{
current: H2mux,
}, nil
}
if protocolFlag == H2mux.String() {
return &staticProtocolSelector{
current: H2mux,
}, nil
}
http2Percentage, err := fetchFunc()
if err != nil {
return nil, err
}
if protocolFlag == HTTP2.String() {
if http2Percentage < 0 {
return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil
}
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil
}
if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
threshold := switchThreshold(namedTunnel.Auth.AccountTag)
if threshold < http2Percentage {
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, logger), nil
}
return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, logger), nil
}
func switchThreshold(accountTag string) int32 {
h := fnv.New32a()
h.Write([]byte(accountTag))
return int32(h.Sum32() % 100)
}

220
connection/protocol_test.go Normal file
View File

@ -0,0 +1,220 @@
package connection
import (
"fmt"
"testing"
"time"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/stretchr/testify/assert"
)
const (
testNoTTL = 0
)
var (
testNamedTunnelConfig = &NamedTunnelConfig{
Auth: pogs.TunnelAuth{
AccountTag: "testAccountTag",
},
}
)
func mockFetcher(percentage int32) PercentageFetcher {
return func() (int32, error) {
return percentage, nil
}
}
func mockFetcherWithError() PercentageFetcher {
return func() (int32, error) {
return 0, fmt.Errorf("failed to fetch precentage")
}
}
type dynamicMockFetcher struct {
percentage int32
err error
}
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
return func() (int32, error) {
if dmf.err != nil {
return 0, dmf.err
}
return dmf.percentage, nil
}
}
func TestNewProtocolSelector(t *testing.T) {
tests := []struct {
name string
protocol string
expectedProtocol Protocol
hasFallback bool
expectedFallback Protocol
namedTunnelConfig *NamedTunnelConfig
fetchFunc PercentageFetcher
wantErr bool
}{
{
name: "classic tunnel",
protocol: "h2mux",
expectedProtocol: H2mux,
namedTunnelConfig: nil,
},
{
name: "named tunnel over h2mux",
protocol: "h2mux",
expectedProtocol: H2mux,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel over http2",
protocol: "http2",
expectedProtocol: HTTP2,
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(0),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel http2 disabled",
protocol: "http2",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(-1),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto all http2 disabled",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(-1),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to h2mux",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(0),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to http2",
protocol: "auto",
expectedProtocol: HTTP2,
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
namedTunnelConfig: testNamedTunnelConfig,
},
{
// None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
name: "classic tunnel unknown protocol",
protocol: "unknown",
expectedProtocol: H2mux,
},
{
name: "named tunnel unknown protocol",
protocol: "unknown",
fetchFunc: mockFetcher(100),
namedTunnelConfig: testNamedTunnelConfig,
wantErr: true,
},
{
name: "named tunnel fetch error",
protocol: "unknown",
fetchFunc: mockFetcherWithError(),
namedTunnelConfig: testNamedTunnelConfig,
wantErr: true,
},
}
logger, _ := logger.New()
for _, test := range tests {
selector, err := NewProtocolSelector(test.protocol, test.namedTunnelConfig, test.fetchFunc, testNoTTL, logger)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
} else {
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
fallback, ok := selector.Fallback()
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
if test.hasFallback {
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
}
}
}
}
func TestAutoProtocolSelectorRefresh(t *testing.T) {
logger, _ := logger.New()
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger)
assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
fetcher.err = nil
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 0
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
}
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
logger, _ := logger.New()
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector("http2", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
fetcher.err = nil
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
assert.Equal(t, H2mux, selector.Current())
}
func TestProtocolSelectorRefreshTTL(t *testing.T) {
logger, _ := logger.New()
fetcher := dynamicMockFetcher{percentage: 100}
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), time.Hour, logger)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
}

View File

@ -97,3 +97,11 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
func (b *BackoffHandler) Retries() int {
return int(b.retries)
}
func (b *BackoffHandler) ReachedMaxRetries() bool {
return b.retries == b.MaxRetries
}
func (b *BackoffHandler) resetNow() {
b.resetDeadline = time.Now()
}

View File

@ -17,10 +17,10 @@ import (
)
const (
// SRV and TXT record resolution TTL
ResolveTTL = time.Hour
// Waiting time before retrying a failed tunnel connection
tunnelRetryDuration = time.Second * 10
// SRV record resolution TTL
resolveTTL = time.Hour
// Interval between registering new tunnels
registrationInterval = time.Second
@ -43,8 +43,6 @@ type Supervisor struct {
cloudflaredUUID uuid.UUID
config *TunnelConfig
edgeIPs *edgediscovery.Edge
lastResolve time.Time
resolverC chan resolveResult
tunnelErrors chan tunnelError
tunnelsConnecting map[int]chan struct{}
// nextConnectedIndex and nextConnectedSignal are used to wait for all
@ -58,10 +56,6 @@ type Supervisor struct {
useReconnectToken bool
}
type resolveResult struct {
err error
}
type tunnelError struct {
index int
addr *net.TCPAddr
@ -74,9 +68,9 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
err error
)
if len(config.EdgeAddrs) > 0 {
edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs)
edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs)
} else {
edgeIPs, err = edgediscovery.ResolveEdge(config.Observer)
edgeIPs, err = edgediscovery.ResolveEdge(config.Logger)
}
if err != nil {
return nil, err
@ -93,14 +87,13 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
logger: config.Observer,
logger: config.Logger,
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
useReconnectToken: useReconnectToken,
}, nil
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.config.Observer
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err
}
@ -117,7 +110,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
refreshAuthBackoffTimer = timer
} else {
logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
s.logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
}
}
@ -136,7 +129,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
case tunnelError := <-s.tunnelErrors:
tunnelsActive--
if tunnelError.err != nil {
logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
s.logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index)
@ -159,7 +152,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
case <-refreshAuthBackoffTimer:
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil {
logger.Errorf("supervisor: Authentication failed: %s", err)
s.logger.Errorf("supervisor: Authentication failed: %s", err)
// Permanent failure. Leave the `select` without setting the
// channel to be non-null, so we'll never hit this case of the `select` again.
continue
@ -171,27 +164,15 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
// No more tunnels outstanding, clear backoff timer
backoff.SetGracePeriod()
}
// DNS resolution returned
case result := <-s.resolverC:
s.lastResolve = time.Now()
s.resolverC = nil
if result.err == nil {
logger.Debug("supervisor: Service discovery refresh complete")
} else {
logger.Errorf("supervisor: Service discovery error: %s", result.err)
}
}
}
}
// Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.logger
s.lastResolve = time.Now()
availableAddrs := int(s.edgeIPs.AvailableAddrs())
if s.config.HAConnections > availableAddrs {
logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.config.HAConnections = availableAddrs
}
@ -304,7 +285,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
return nil, err
}
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP)
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP)
if err != nil {
return nil, err
}

View File

@ -62,13 +62,14 @@ type TunnelConfig struct {
ReportedVersion string
Retries uint
RunFromTerminal bool
TLSConfig *tls.Config
NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig
TunnelEventChan chan ui.TunnelEvent
IngressRules ingress.Ingress
NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig
TunnelEventChan chan ui.TunnelEvent
IngressRules ingress.Ingress
ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config
}
type muxerShutdownError struct{}
@ -157,7 +158,7 @@ func ServeTunnelLoop(ctx context.Context,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *net.TCPAddr,
connectionIndex uint8,
connIndex uint8,
connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
@ -165,7 +166,11 @@ func ServeTunnelLoop(ctx context.Context,
haConnections.Inc()
defer haConnections.Dec()
backoff := BackoffHandler{MaxRetries: config.Retries}
protocallFallback := &protocallFallback{
BackoffHandler{MaxRetries: config.Retries},
config.ProtocolSelector.Current(),
false,
}
connectedFuse := h2mux.NewBooleanFuse()
go func() {
if connectedFuse.Await() {
@ -174,29 +179,90 @@ func ServeTunnelLoop(ctx context.Context,
}()
// Ensure the above goroutine will terminate if we return without connecting
defer connectedFuse.Fuse(false)
// Each connection to keep its own copy of protocol, because individual connections might fallback
// to another protocol when a particular metal doesn't support new protocol
for {
err, recoverable := ServeTunnel(
ctx,
credentialManager,
config,
addr, connectionIndex,
addr,
connIndex,
connectedFuse,
&backoff,
protocallFallback,
cloudflaredUUID,
reconnectCh,
protocallFallback.protocol,
)
if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
}
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
backoff.Backoff(ctx)
continue
}
if !recoverable {
return err
}
err = waitForBackoff(ctx, protocallFallback, config, connIndex, err)
if err != nil {
return err
}
}
}
// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
// max retries
type protocallFallback struct {
BackoffHandler
protocol connection.Protocol
inFallback bool
}
func (pf *protocallFallback) reset() {
pf.resetNow()
pf.inFallback = false
}
func (pf *protocallFallback) fallback(fallback connection.Protocol) {
pf.resetNow()
pf.protocol = fallback
pf.inFallback = true
}
// Expect err to always be non nil
func waitForBackoff(
ctx context.Context,
protobackoff *protocallFallback,
config *TunnelConfig,
connIndex uint8,
err error,
) error {
duration, ok := protobackoff.GetBackoffDuration(ctx)
if !ok {
return err
}
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Reconnecting}
}
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connIndex, duration, err)
protobackoff.Backoff(ctx)
if protobackoff.ReachedMaxRetries() {
fallback, hasFallback := config.ProtocolSelector.Fallback()
if !hasFallback {
return err
}
// Already using fallback protocol, no point to retry
if protobackoff.protocol == fallback {
return err
}
config.Logger.Infof("Fallback to use %s", fallback)
protobackoff.fallback(fallback)
} else if !protobackoff.inFallback {
current := config.ProtocolSelector.Current()
if protobackoff.protocol != current {
protobackoff.protocol = current
config.Logger.Infof("Change protocol to %s", current)
}
}
return nil
}
func ServeTunnel(
@ -204,11 +270,12 @@ func ServeTunnel(
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *net.TCPAddr,
connectionIndex uint8,
connIndex uint8,
fuse *h2mux.BooleanFuse,
backoff *BackoffHandler,
backoff *protocallFallback,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
protocol connection.Protocol,
) (err error, recoverable bool) {
// Treat panics as recoverable errors
defer func() {
@ -226,11 +293,11 @@ func ServeTunnel(
// If launch-ui flag is set, send disconnect msg
if config.TunnelEventChan != nil {
defer func() {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Disconnected}
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Disconnected}
}()
}
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr)
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr)
if err != nil {
return err, true
}
@ -238,11 +305,11 @@ func ServeTunnel(
fuse: fuse,
backoff: backoff,
}
if config.Protocol == connection.HTTP2 {
if protocol == connection.HTTP2 {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh)
return ServeHTTP2(ctx, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
}
return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh)
return ServeH2mux(ctx, credentialManager, config, edgeConn, connIndex, connectedFuse, cloudflaredUUID, reconnectCh)
}
func ServeH2mux(
@ -255,6 +322,7 @@ func ServeH2mux(
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
config.Logger.Debugf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
if err != nil {
@ -266,10 +334,10 @@ func ServeH2mux(
errGroup.Go(func() (err error) {
if config.NamedTunnel != nil {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
}
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
})
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
@ -295,7 +363,7 @@ func ServeH2mux(
config.Logger.Info("Muxer shutdown")
return err, true
case *ReconnectSignal:
config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
config.Logger.Infof("Restarting connection %d due to reconnect signal in %s", connectionIndex, err.Delay)
err.DelayBeforeReconnect()
return err, true
default:
@ -319,10 +387,8 @@ func ServeHTTP2(
connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
if err != nil {
return err, false
}
config.Logger.Debugf("Connecting via http2")
server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
@ -352,12 +418,12 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) fu
type connectedFuse struct {
fuse *h2mux.BooleanFuse
backoff *BackoffHandler
backoff *protocallFallback
}
func (cf *connectedFuse) Connected() {
cf.fuse.Fuse(true)
cf.backoff.SetGracePeriod()
cf.backoff.reset()
}
func (cf *connectedFuse) IsConnected() bool {

90
origin/tunnel_test.go Normal file
View File

@ -0,0 +1,90 @@
package origin
import (
"context"
"fmt"
"testing"
"time"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/stretchr/testify/assert"
)
type dynamicMockFetcher struct {
percentage int32
err error
}
func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
return func() (int32, error) {
if dmf.err != nil {
return 0, dmf.err
}
return dmf.percentage, nil
}
}
func TestWaitForBackoffFallback(t *testing.T) {
maxRetries := uint(3)
backoff := BackoffHandler{
MaxRetries: maxRetries,
BaseTime: time.Millisecond * 10,
}
ctx := context.Background()
logger, err := logger.New()
assert.NoError(t, err)
resolveTTL := time.Duration(0)
namedTunnel := &connection.NamedTunnelConfig{
Auth: pogs.TunnelAuth{
AccountTag: "test-account",
},
}
mockFetcher := dynamicMockFetcher{
percentage: 0,
}
protocolSelector, err := connection.NewProtocolSelector(connection.HTTP2.String(), namedTunnel, mockFetcher.fetch(), resolveTTL, logger)
assert.NoError(t, err)
config := &TunnelConfig{
Logger: logger,
ProtocolSelector: protocolSelector,
}
connIndex := uint8(1)
initProtocol := protocolSelector.Current()
assert.Equal(t, connection.HTTP2, initProtocol)
protocallFallback := &protocallFallback{
backoff,
initProtocol,
false,
}
// Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this
for i := 0; i < int(maxRetries-1); i++ {
err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.NoError(t, err)
assert.Equal(t, initProtocol, protocallFallback.protocol)
}
// Retry fallback protocol
for i := 0; i < int(maxRetries); i++ {
err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.NoError(t, err)
fallback, ok := protocolSelector.Fallback()
assert.True(t, ok)
assert.Equal(t, fallback, protocallFallback.protocol)
}
currentGlobalProtocol := protocolSelector.Current()
assert.Equal(t, initProtocol, currentGlobalProtocol)
// No protocol to fallback, return error
err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.Error(t, err)
protocallFallback.reset()
err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("New error"))
assert.NoError(t, err)
assert.Equal(t, initProtocol, protocallFallback.protocol)
}