TUN-3458: Upgrade to http2 when available, fallback to h2mux when we reach max retries
This commit is contained in:
parent
b5cdf3b2c7
commit
a490443630
|
@ -232,16 +232,20 @@ func prepareTunnelConfig(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := determineProtocol(c, namedTunnel)
|
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
logger.Infof("Using protocol %s", protocol)
|
logger.Infof("Initial protocol %s", protocolSelector.Current())
|
||||||
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName())
|
|
||||||
|
edgeTLSConfigs := make(map[connection.Protocol]*tls.Config, len(connection.ProtocolList))
|
||||||
|
for _, p := range connection.ProtocolList {
|
||||||
|
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("unable to create TLS config to connect with edge: %s", err)
|
|
||||||
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
||||||
}
|
}
|
||||||
|
edgeTLSConfigs[p] = edgeTLSConfig
|
||||||
|
}
|
||||||
|
|
||||||
proxyConfig := &origin.ProxyConfig{
|
proxyConfig := &origin.ProxyConfig{
|
||||||
Client: httpTransport,
|
Client: httpTransport,
|
||||||
|
@ -252,7 +256,7 @@ func prepareTunnelConfig(
|
||||||
Tags: tags,
|
Tags: tags,
|
||||||
}
|
}
|
||||||
originClient := origin.NewClient(proxyConfig, logger)
|
originClient := origin.NewClient(proxyConfig, logger)
|
||||||
transportConfig := &connection.Config{
|
connectionConfig := &connection.Config{
|
||||||
OriginClient: originClient,
|
OriginClient: originClient,
|
||||||
GracePeriod: c.Duration("grace-period"),
|
GracePeriod: c.Duration("grace-period"),
|
||||||
ReplaceExisting: c.Bool("force"),
|
ReplaceExisting: c.Bool("force"),
|
||||||
|
@ -270,7 +274,7 @@ func prepareTunnelConfig(
|
||||||
}
|
}
|
||||||
|
|
||||||
return &origin.TunnelConfig{
|
return &origin.TunnelConfig{
|
||||||
ConnectionConfig: transportConfig,
|
ConnectionConfig: connectionConfig,
|
||||||
ProxyConfig: proxyConfig,
|
ProxyConfig: proxyConfig,
|
||||||
BuildInfo: buildInfo,
|
BuildInfo: buildInfo,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
|
@ -281,34 +285,20 @@ func prepareTunnelConfig(
|
||||||
IsFreeTunnel: isFreeTunnel,
|
IsFreeTunnel: isFreeTunnel,
|
||||||
LBPool: c.String("lb-pool"),
|
LBPool: c.String("lb-pool"),
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
Observer: connection.NewObserver(transportLogger, tunnelEventChan, protocol),
|
Observer: connection.NewObserver(transportLogger, tunnelEventChan),
|
||||||
ReportedVersion: version,
|
ReportedVersion: version,
|
||||||
Retries: c.Uint("retries"),
|
Retries: c.Uint("retries"),
|
||||||
RunFromTerminal: isRunningFromTerminal(),
|
RunFromTerminal: isRunningFromTerminal(),
|
||||||
TLSConfig: toEdgeTLSConfig,
|
|
||||||
NamedTunnel: namedTunnel,
|
NamedTunnel: namedTunnel,
|
||||||
ClassicTunnel: classicTunnel,
|
ClassicTunnel: classicTunnel,
|
||||||
MuxerConfig: muxerConfig,
|
MuxerConfig: muxerConfig,
|
||||||
TunnelEventChan: tunnelEventChan,
|
TunnelEventChan: tunnelEventChan,
|
||||||
IngressRules: ingressRules,
|
IngressRules: ingressRules,
|
||||||
|
ProtocolSelector: protocolSelector,
|
||||||
|
EdgeTLSConfigs: edgeTLSConfigs,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRunningFromTerminal() bool {
|
func isRunningFromTerminal() bool {
|
||||||
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func determineProtocol(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) (connection.Protocol, error) {
|
|
||||||
if namedTunnel == nil {
|
|
||||||
return connection.H2mux, nil
|
|
||||||
}
|
|
||||||
http2Percentage, err := edgediscovery.HTTP2Percentage()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
protocol, ok := connection.SelectProtocol(c.String("protocol"), namedTunnel.Auth.AccountTag, http2Percentage)
|
|
||||||
if !ok {
|
|
||||||
return 0, fmt.Errorf("%s is not valid protocol. %s", c.String("protocol"), availableProtocol)
|
|
||||||
}
|
|
||||||
return protocol, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflared/tunnelstore"
|
"github.com/cloudflare/cloudflared/tunnelstore"
|
||||||
|
@ -30,7 +31,6 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
credFileFlagAlias = "cred-file"
|
credFileFlagAlias = "cred-file"
|
||||||
availableProtocol = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -90,7 +90,7 @@ var (
|
||||||
Name: "protocol",
|
Name: "protocol",
|
||||||
Value: "h2mux",
|
Value: "h2mux",
|
||||||
Aliases: []string{"p"},
|
Aliases: []string{"p"},
|
||||||
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol),
|
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", connection.AvailableProtocolFlagMessage),
|
||||||
EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"},
|
EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package connection
|
package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"hash/fnv"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -13,10 +11,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
|
|
||||||
edgeH2muxTLSServerName = "cftunnel.com"
|
|
||||||
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
|
||||||
edgeH2TLSServerName = "h2.cftunnel.com"
|
|
||||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,57 +37,6 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool {
|
||||||
return c.Hostname == ""
|
return c.Hostname == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type Protocol int64
|
|
||||||
|
|
||||||
const (
|
|
||||||
H2mux Protocol = iota
|
|
||||||
HTTP2
|
|
||||||
)
|
|
||||||
|
|
||||||
func SelectProtocol(s string, accountTag string, http2Percentage uint32) (Protocol, bool) {
|
|
||||||
switch s {
|
|
||||||
case "h2mux":
|
|
||||||
return H2mux, true
|
|
||||||
case "http2":
|
|
||||||
return HTTP2, true
|
|
||||||
case "auto":
|
|
||||||
if tryHTTP2(accountTag, http2Percentage) {
|
|
||||||
return HTTP2, true
|
|
||||||
}
|
|
||||||
return H2mux, true
|
|
||||||
default:
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func tryHTTP2(accountTag string, http2Percentage uint32) bool {
|
|
||||||
h := fnv.New32a()
|
|
||||||
h.Write([]byte(accountTag))
|
|
||||||
return h.Sum32()%100 < http2Percentage
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p Protocol) ServerName() string {
|
|
||||||
switch p {
|
|
||||||
case H2mux:
|
|
||||||
return edgeH2muxTLSServerName
|
|
||||||
case HTTP2:
|
|
||||||
return edgeH2TLSServerName
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p Protocol) String() string {
|
|
||||||
switch p {
|
|
||||||
case H2mux:
|
|
||||||
return "h2mux"
|
|
||||||
case HTTP2:
|
|
||||||
return "http2"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("unknown protocol")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type OriginClient interface {
|
type OriginClient interface {
|
||||||
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
|
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ type HTTP2Connection struct {
|
||||||
connectedFuse ConnectedFuse
|
connectedFuse ConnectedFuse
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) {
|
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection {
|
||||||
return &HTTP2Connection{
|
return &HTTP2Connection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
server: &http2.Server{
|
server: &http2.Server{
|
||||||
|
@ -52,7 +52,7 @@ func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, named
|
||||||
connIndex: connIndex,
|
connIndex: connIndex,
|
||||||
wg: &sync.WaitGroup{},
|
wg: &sync.WaitGroup{},
|
||||||
connectedFuse: connectedFuse,
|
connectedFuse: connectedFuse,
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HTTP2Connection) Serve(ctx context.Context) {
|
func (c *HTTP2Connection) Serve(ctx context.Context) {
|
||||||
|
|
|
@ -299,7 +299,7 @@ func convertRTTMilliSec(t time.Duration) float64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Metrics that can be collected without asking the edge
|
// Metrics that can be collected without asking the edge
|
||||||
func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
|
func newTunnelMetrics() *tunnelMetrics {
|
||||||
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
||||||
prometheus.GaugeOpts{
|
prometheus.GaugeOpts{
|
||||||
Namespace: MetricsNamespace,
|
Namespace: MetricsNamespace,
|
||||||
|
@ -374,16 +374,12 @@ func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
|
||||||
[]string{"rpcName"},
|
[]string{"rpcName"},
|
||||||
)
|
)
|
||||||
prometheus.MustRegister(registerSuccess)
|
prometheus.MustRegister(registerSuccess)
|
||||||
var muxerMetrics *muxerMetrics
|
|
||||||
if protocol == H2mux {
|
|
||||||
muxerMetrics = newMuxerMetrics()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &tunnelMetrics{
|
return &tunnelMetrics{
|
||||||
timerRetries: timerRetries,
|
timerRetries: timerRetries,
|
||||||
serverLocations: serverLocations,
|
serverLocations: serverLocations,
|
||||||
oldServerLocations: make(map[string]string),
|
oldServerLocations: make(map[string]string),
|
||||||
muxerMetrics: muxerMetrics,
|
muxerMetrics: newMuxerMetrics(),
|
||||||
tunnelsHA: NewTunnelsForHA(),
|
tunnelsHA: NewTunnelsForHA(),
|
||||||
regSuccess: registerSuccess,
|
regSuccess: registerSuccess,
|
||||||
regFail: registerFail,
|
regFail: registerFail,
|
||||||
|
|
|
@ -16,10 +16,10 @@ type Observer struct {
|
||||||
tunnelEventChan chan<- ui.TunnelEvent
|
tunnelEventChan chan<- ui.TunnelEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer {
|
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent) *Observer {
|
||||||
return &Observer{
|
return &Observer{
|
||||||
logger,
|
logger,
|
||||||
newTunnelMetrics(protocol),
|
newTunnelMetrics(),
|
||||||
tunnelEventChan,
|
tunnelEventChan,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// can only be called once
|
// can only be called once
|
||||||
var m = newTunnelMetrics(H2mux)
|
var m = newTunnelMetrics()
|
||||||
|
|
||||||
func TestRegisterServerLocation(t *testing.T) {
|
func TestRegisterServerLocation(t *testing.T) {
|
||||||
tunnels := 20
|
tunnels := 20
|
||||||
|
|
|
@ -0,0 +1,179 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
|
||||||
|
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
|
||||||
|
edgeH2muxTLSServerName = "cftunnel.com"
|
||||||
|
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
||||||
|
edgeH2TLSServerName = "h2.cftunnel.com"
|
||||||
|
// threshold to switch back to h2mux when the user intentionally pick --protocol http2
|
||||||
|
explicitHTTP2FallbackThreshold = -1
|
||||||
|
autoSelectFlag = "auto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ProtocolList = []Protocol{H2mux, HTTP2}
|
||||||
|
)
|
||||||
|
|
||||||
|
type Protocol int64
|
||||||
|
|
||||||
|
const (
|
||||||
|
H2mux Protocol = iota
|
||||||
|
HTTP2
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p Protocol) ServerName() string {
|
||||||
|
switch p {
|
||||||
|
case H2mux:
|
||||||
|
return edgeH2muxTLSServerName
|
||||||
|
case HTTP2:
|
||||||
|
return edgeH2TLSServerName
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback returns the fallback protocol and whether the protocol has a fallback
|
||||||
|
func (p Protocol) fallback() (Protocol, bool) {
|
||||||
|
switch p {
|
||||||
|
case H2mux:
|
||||||
|
return 0, false
|
||||||
|
case HTTP2:
|
||||||
|
return H2mux, true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Protocol) String() string {
|
||||||
|
switch p {
|
||||||
|
case H2mux:
|
||||||
|
return "h2mux"
|
||||||
|
case HTTP2:
|
||||||
|
return "http2"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown protocol")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProtocolSelector interface {
|
||||||
|
Current() Protocol
|
||||||
|
Fallback() (Protocol, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type staticProtocolSelector struct {
|
||||||
|
current Protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *staticProtocolSelector) Current() Protocol {
|
||||||
|
return s.current
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
type autoProtocolSelector struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
current Protocol
|
||||||
|
switchThrehold int32
|
||||||
|
fetchFunc PercentageFetcher
|
||||||
|
refreshAfter time.Time
|
||||||
|
ttl time.Duration
|
||||||
|
logger logger.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAutoProtocolSelector(
|
||||||
|
current Protocol,
|
||||||
|
switchThrehold int32,
|
||||||
|
fetchFunc PercentageFetcher,
|
||||||
|
ttl time.Duration,
|
||||||
|
logger logger.Service,
|
||||||
|
) *autoProtocolSelector {
|
||||||
|
return &autoProtocolSelector{
|
||||||
|
current: current,
|
||||||
|
switchThrehold: switchThrehold,
|
||||||
|
fetchFunc: fetchFunc,
|
||||||
|
refreshAfter: time.Now().Add(ttl),
|
||||||
|
ttl: ttl,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *autoProtocolSelector) Current() Protocol {
|
||||||
|
s.lock.Lock()
|
||||||
|
defer s.lock.Unlock()
|
||||||
|
if time.Now().Before(s.refreshAfter) {
|
||||||
|
return s.current
|
||||||
|
}
|
||||||
|
|
||||||
|
percentage, err := s.fetchFunc()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Failed to refresh protocol, err: %v", err)
|
||||||
|
return s.current
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.switchThrehold < percentage {
|
||||||
|
s.current = HTTP2
|
||||||
|
} else {
|
||||||
|
s.current = H2mux
|
||||||
|
}
|
||||||
|
s.refreshAfter = time.Now().Add(s.ttl)
|
||||||
|
return s.current
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
|
||||||
|
s.lock.RLock()
|
||||||
|
defer s.lock.RUnlock()
|
||||||
|
return s.current.fallback()
|
||||||
|
}
|
||||||
|
|
||||||
|
type PercentageFetcher func() (int32, error)
|
||||||
|
|
||||||
|
func NewProtocolSelector(protocolFlag string, namedTunnel *NamedTunnelConfig, fetchFunc PercentageFetcher, ttl time.Duration, logger logger.Service) (ProtocolSelector, error) {
|
||||||
|
if namedTunnel == nil {
|
||||||
|
return &staticProtocolSelector{
|
||||||
|
current: H2mux,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
if protocolFlag == H2mux.String() {
|
||||||
|
return &staticProtocolSelector{
|
||||||
|
current: H2mux,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
http2Percentage, err := fetchFunc()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if protocolFlag == HTTP2.String() {
|
||||||
|
if http2Percentage < 0 {
|
||||||
|
return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil
|
||||||
|
}
|
||||||
|
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if protocolFlag != autoSelectFlag {
|
||||||
|
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
||||||
|
}
|
||||||
|
threshold := switchThreshold(namedTunnel.Auth.AccountTag)
|
||||||
|
if threshold < http2Percentage {
|
||||||
|
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, logger), nil
|
||||||
|
}
|
||||||
|
return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, logger), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func switchThreshold(accountTag string) int32 {
|
||||||
|
h := fnv.New32a()
|
||||||
|
h.Write([]byte(accountTag))
|
||||||
|
return int32(h.Sum32() % 100)
|
||||||
|
}
|
|
@ -0,0 +1,220 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testNoTTL = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testNamedTunnelConfig = &NamedTunnelConfig{
|
||||||
|
Auth: pogs.TunnelAuth{
|
||||||
|
AccountTag: "testAccountTag",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockFetcher(percentage int32) PercentageFetcher {
|
||||||
|
return func() (int32, error) {
|
||||||
|
return percentage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockFetcherWithError() PercentageFetcher {
|
||||||
|
return func() (int32, error) {
|
||||||
|
return 0, fmt.Errorf("failed to fetch precentage")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type dynamicMockFetcher struct {
|
||||||
|
percentage int32
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
|
||||||
|
return func() (int32, error) {
|
||||||
|
if dmf.err != nil {
|
||||||
|
return 0, dmf.err
|
||||||
|
}
|
||||||
|
return dmf.percentage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProtocolSelector(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
protocol string
|
||||||
|
expectedProtocol Protocol
|
||||||
|
hasFallback bool
|
||||||
|
expectedFallback Protocol
|
||||||
|
namedTunnelConfig *NamedTunnelConfig
|
||||||
|
fetchFunc PercentageFetcher
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "classic tunnel",
|
||||||
|
protocol: "h2mux",
|
||||||
|
expectedProtocol: H2mux,
|
||||||
|
namedTunnelConfig: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel over h2mux",
|
||||||
|
protocol: "h2mux",
|
||||||
|
expectedProtocol: H2mux,
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel over http2",
|
||||||
|
protocol: "http2",
|
||||||
|
expectedProtocol: HTTP2,
|
||||||
|
hasFallback: true,
|
||||||
|
expectedFallback: H2mux,
|
||||||
|
fetchFunc: mockFetcher(0),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel http2 disabled",
|
||||||
|
protocol: "http2",
|
||||||
|
expectedProtocol: H2mux,
|
||||||
|
fetchFunc: mockFetcher(-1),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel auto all http2 disabled",
|
||||||
|
protocol: "auto",
|
||||||
|
expectedProtocol: H2mux,
|
||||||
|
fetchFunc: mockFetcher(-1),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel auto to h2mux",
|
||||||
|
protocol: "auto",
|
||||||
|
expectedProtocol: H2mux,
|
||||||
|
fetchFunc: mockFetcher(0),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel auto to http2",
|
||||||
|
protocol: "auto",
|
||||||
|
expectedProtocol: HTTP2,
|
||||||
|
hasFallback: true,
|
||||||
|
expectedFallback: H2mux,
|
||||||
|
fetchFunc: mockFetcher(100),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
|
||||||
|
name: "classic tunnel unknown protocol",
|
||||||
|
protocol: "unknown",
|
||||||
|
expectedProtocol: H2mux,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel unknown protocol",
|
||||||
|
protocol: "unknown",
|
||||||
|
fetchFunc: mockFetcher(100),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "named tunnel fetch error",
|
||||||
|
protocol: "unknown",
|
||||||
|
fetchFunc: mockFetcherWithError(),
|
||||||
|
namedTunnelConfig: testNamedTunnelConfig,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
logger, _ := logger.New()
|
||||||
|
for _, test := range tests {
|
||||||
|
selector, err := NewProtocolSelector(test.protocol, test.namedTunnelConfig, test.fetchFunc, testNoTTL, logger)
|
||||||
|
if test.wantErr {
|
||||||
|
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
|
||||||
|
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
|
||||||
|
fallback, ok := selector.Fallback()
|
||||||
|
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
|
||||||
|
if test.hasFallback {
|
||||||
|
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoProtocolSelectorRefresh(t *testing.T) {
|
||||||
|
logger, _ := logger.New()
|
||||||
|
fetcher := dynamicMockFetcher{}
|
||||||
|
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 100
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 0
|
||||||
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 100
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.err = fmt.Errorf("failed to fetch")
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = -1
|
||||||
|
fetcher.err = nil
|
||||||
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 0
|
||||||
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 100
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
|
||||||
|
logger, _ := logger.New()
|
||||||
|
fetcher := dynamicMockFetcher{}
|
||||||
|
selector, err := NewProtocolSelector("http2", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 100
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 0
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.err = fmt.Errorf("failed to fetch")
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = -1
|
||||||
|
fetcher.err = nil
|
||||||
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 0
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 100
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = -1
|
||||||
|
assert.Equal(t, H2mux, selector.Current())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtocolSelectorRefreshTTL(t *testing.T) {
|
||||||
|
logger, _ := logger.New()
|
||||||
|
fetcher := dynamicMockFetcher{percentage: 100}
|
||||||
|
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), time.Hour, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
|
||||||
|
fetcher.percentage = 0
|
||||||
|
assert.Equal(t, HTTP2, selector.Current())
|
||||||
|
}
|
|
@ -97,3 +97,11 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
|
||||||
func (b *BackoffHandler) Retries() int {
|
func (b *BackoffHandler) Retries() int {
|
||||||
return int(b.retries)
|
return int(b.retries)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BackoffHandler) ReachedMaxRetries() bool {
|
||||||
|
return b.retries == b.MaxRetries
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BackoffHandler) resetNow() {
|
||||||
|
b.resetDeadline = time.Now()
|
||||||
|
}
|
||||||
|
|
|
@ -17,10 +17,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// SRV and TXT record resolution TTL
|
||||||
|
ResolveTTL = time.Hour
|
||||||
// Waiting time before retrying a failed tunnel connection
|
// Waiting time before retrying a failed tunnel connection
|
||||||
tunnelRetryDuration = time.Second * 10
|
tunnelRetryDuration = time.Second * 10
|
||||||
// SRV record resolution TTL
|
|
||||||
resolveTTL = time.Hour
|
|
||||||
// Interval between registering new tunnels
|
// Interval between registering new tunnels
|
||||||
registrationInterval = time.Second
|
registrationInterval = time.Second
|
||||||
|
|
||||||
|
@ -43,8 +43,6 @@ type Supervisor struct {
|
||||||
cloudflaredUUID uuid.UUID
|
cloudflaredUUID uuid.UUID
|
||||||
config *TunnelConfig
|
config *TunnelConfig
|
||||||
edgeIPs *edgediscovery.Edge
|
edgeIPs *edgediscovery.Edge
|
||||||
lastResolve time.Time
|
|
||||||
resolverC chan resolveResult
|
|
||||||
tunnelErrors chan tunnelError
|
tunnelErrors chan tunnelError
|
||||||
tunnelsConnecting map[int]chan struct{}
|
tunnelsConnecting map[int]chan struct{}
|
||||||
// nextConnectedIndex and nextConnectedSignal are used to wait for all
|
// nextConnectedIndex and nextConnectedSignal are used to wait for all
|
||||||
|
@ -58,10 +56,6 @@ type Supervisor struct {
|
||||||
useReconnectToken bool
|
useReconnectToken bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type resolveResult struct {
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type tunnelError struct {
|
type tunnelError struct {
|
||||||
index int
|
index int
|
||||||
addr *net.TCPAddr
|
addr *net.TCPAddr
|
||||||
|
@ -74,9 +68,9 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if len(config.EdgeAddrs) > 0 {
|
if len(config.EdgeAddrs) > 0 {
|
||||||
edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs)
|
edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs)
|
||||||
} else {
|
} else {
|
||||||
edgeIPs, err = edgediscovery.ResolveEdge(config.Observer)
|
edgeIPs, err = edgediscovery.ResolveEdge(config.Logger)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -93,14 +87,13 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
|
||||||
edgeIPs: edgeIPs,
|
edgeIPs: edgeIPs,
|
||||||
tunnelErrors: make(chan tunnelError),
|
tunnelErrors: make(chan tunnelError),
|
||||||
tunnelsConnecting: map[int]chan struct{}{},
|
tunnelsConnecting: map[int]chan struct{}{},
|
||||||
logger: config.Observer,
|
logger: config.Logger,
|
||||||
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
|
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
|
||||||
useReconnectToken: useReconnectToken,
|
useReconnectToken: useReconnectToken,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
||||||
logger := s.config.Observer
|
|
||||||
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -117,7 +110,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
|
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
|
||||||
refreshAuthBackoffTimer = timer
|
refreshAuthBackoffTimer = timer
|
||||||
} else {
|
} else {
|
||||||
logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
|
s.logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
|
||||||
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
|
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -136,7 +129,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
case tunnelError := <-s.tunnelErrors:
|
case tunnelError := <-s.tunnelErrors:
|
||||||
tunnelsActive--
|
tunnelsActive--
|
||||||
if tunnelError.err != nil {
|
if tunnelError.err != nil {
|
||||||
logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
|
s.logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
|
||||||
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||||
s.waitForNextTunnel(tunnelError.index)
|
s.waitForNextTunnel(tunnelError.index)
|
||||||
|
|
||||||
|
@ -159,7 +152,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
case <-refreshAuthBackoffTimer:
|
case <-refreshAuthBackoffTimer:
|
||||||
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
|
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("supervisor: Authentication failed: %s", err)
|
s.logger.Errorf("supervisor: Authentication failed: %s", err)
|
||||||
// Permanent failure. Leave the `select` without setting the
|
// Permanent failure. Leave the `select` without setting the
|
||||||
// channel to be non-null, so we'll never hit this case of the `select` again.
|
// channel to be non-null, so we'll never hit this case of the `select` again.
|
||||||
continue
|
continue
|
||||||
|
@ -171,27 +164,15 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
// No more tunnels outstanding, clear backoff timer
|
// No more tunnels outstanding, clear backoff timer
|
||||||
backoff.SetGracePeriod()
|
backoff.SetGracePeriod()
|
||||||
}
|
}
|
||||||
// DNS resolution returned
|
|
||||||
case result := <-s.resolverC:
|
|
||||||
s.lastResolve = time.Now()
|
|
||||||
s.resolverC = nil
|
|
||||||
if result.err == nil {
|
|
||||||
logger.Debug("supervisor: Service discovery refresh complete")
|
|
||||||
} else {
|
|
||||||
logger.Errorf("supervisor: Service discovery error: %s", result.err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns nil if initialization succeeded, else the initialization error.
|
// Returns nil if initialization succeeded, else the initialization error.
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
||||||
logger := s.logger
|
|
||||||
|
|
||||||
s.lastResolve = time.Now()
|
|
||||||
availableAddrs := int(s.edgeIPs.AvailableAddrs())
|
availableAddrs := int(s.edgeIPs.AvailableAddrs())
|
||||||
if s.config.HAConnections > availableAddrs {
|
if s.config.HAConnections > availableAddrs {
|
||||||
logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
|
s.logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
|
||||||
s.config.HAConnections = availableAddrs
|
s.config.HAConnections = availableAddrs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -304,7 +285,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP)
|
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
128
origin/tunnel.go
128
origin/tunnel.go
|
@ -62,13 +62,14 @@ type TunnelConfig struct {
|
||||||
ReportedVersion string
|
ReportedVersion string
|
||||||
Retries uint
|
Retries uint
|
||||||
RunFromTerminal bool
|
RunFromTerminal bool
|
||||||
TLSConfig *tls.Config
|
|
||||||
|
|
||||||
NamedTunnel *connection.NamedTunnelConfig
|
NamedTunnel *connection.NamedTunnelConfig
|
||||||
ClassicTunnel *connection.ClassicTunnelConfig
|
ClassicTunnel *connection.ClassicTunnelConfig
|
||||||
MuxerConfig *connection.MuxerConfig
|
MuxerConfig *connection.MuxerConfig
|
||||||
TunnelEventChan chan ui.TunnelEvent
|
TunnelEventChan chan ui.TunnelEvent
|
||||||
IngressRules ingress.Ingress
|
IngressRules ingress.Ingress
|
||||||
|
ProtocolSelector connection.ProtocolSelector
|
||||||
|
EdgeTLSConfigs map[connection.Protocol]*tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
type muxerShutdownError struct{}
|
type muxerShutdownError struct{}
|
||||||
|
@ -157,7 +158,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
credentialManager *reconnectCredentialManager,
|
credentialManager *reconnectCredentialManager,
|
||||||
config *TunnelConfig,
|
config *TunnelConfig,
|
||||||
addr *net.TCPAddr,
|
addr *net.TCPAddr,
|
||||||
connectionIndex uint8,
|
connIndex uint8,
|
||||||
connectedSignal *signal.Signal,
|
connectedSignal *signal.Signal,
|
||||||
cloudflaredUUID uuid.UUID,
|
cloudflaredUUID uuid.UUID,
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
|
@ -165,7 +166,11 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
haConnections.Inc()
|
haConnections.Inc()
|
||||||
defer haConnections.Dec()
|
defer haConnections.Dec()
|
||||||
|
|
||||||
backoff := BackoffHandler{MaxRetries: config.Retries}
|
protocallFallback := &protocallFallback{
|
||||||
|
BackoffHandler{MaxRetries: config.Retries},
|
||||||
|
config.ProtocolSelector.Current(),
|
||||||
|
false,
|
||||||
|
}
|
||||||
connectedFuse := h2mux.NewBooleanFuse()
|
connectedFuse := h2mux.NewBooleanFuse()
|
||||||
go func() {
|
go func() {
|
||||||
if connectedFuse.Await() {
|
if connectedFuse.Await() {
|
||||||
|
@ -174,29 +179,90 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
}()
|
}()
|
||||||
// Ensure the above goroutine will terminate if we return without connecting
|
// Ensure the above goroutine will terminate if we return without connecting
|
||||||
defer connectedFuse.Fuse(false)
|
defer connectedFuse.Fuse(false)
|
||||||
|
// Each connection to keep its own copy of protocol, because individual connections might fallback
|
||||||
|
// to another protocol when a particular metal doesn't support new protocol
|
||||||
for {
|
for {
|
||||||
err, recoverable := ServeTunnel(
|
err, recoverable := ServeTunnel(
|
||||||
ctx,
|
ctx,
|
||||||
credentialManager,
|
credentialManager,
|
||||||
config,
|
config,
|
||||||
addr, connectionIndex,
|
addr,
|
||||||
|
connIndex,
|
||||||
connectedFuse,
|
connectedFuse,
|
||||||
&backoff,
|
protocallFallback,
|
||||||
cloudflaredUUID,
|
cloudflaredUUID,
|
||||||
reconnectCh,
|
reconnectCh,
|
||||||
|
protocallFallback.protocol,
|
||||||
)
|
)
|
||||||
if recoverable {
|
if !recoverable {
|
||||||
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
|
||||||
if config.TunnelEventChan != nil {
|
|
||||||
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
|
|
||||||
}
|
|
||||||
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
|
|
||||||
backoff.Backoff(ctx)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = waitForBackoff(ctx, protocallFallback, config, connIndex, err)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
|
||||||
|
// max retries
|
||||||
|
type protocallFallback struct {
|
||||||
|
BackoffHandler
|
||||||
|
protocol connection.Protocol
|
||||||
|
inFallback bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *protocallFallback) reset() {
|
||||||
|
pf.resetNow()
|
||||||
|
pf.inFallback = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *protocallFallback) fallback(fallback connection.Protocol) {
|
||||||
|
pf.resetNow()
|
||||||
|
pf.protocol = fallback
|
||||||
|
pf.inFallback = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect err to always be non nil
|
||||||
|
func waitForBackoff(
|
||||||
|
ctx context.Context,
|
||||||
|
protobackoff *protocallFallback,
|
||||||
|
config *TunnelConfig,
|
||||||
|
connIndex uint8,
|
||||||
|
err error,
|
||||||
|
) error {
|
||||||
|
duration, ok := protobackoff.GetBackoffDuration(ctx)
|
||||||
|
if !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.TunnelEventChan != nil {
|
||||||
|
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Reconnecting}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connIndex, duration, err)
|
||||||
|
protobackoff.Backoff(ctx)
|
||||||
|
|
||||||
|
if protobackoff.ReachedMaxRetries() {
|
||||||
|
fallback, hasFallback := config.ProtocolSelector.Fallback()
|
||||||
|
if !hasFallback {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Already using fallback protocol, no point to retry
|
||||||
|
if protobackoff.protocol == fallback {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
config.Logger.Infof("Fallback to use %s", fallback)
|
||||||
|
protobackoff.fallback(fallback)
|
||||||
|
} else if !protobackoff.inFallback {
|
||||||
|
current := config.ProtocolSelector.Current()
|
||||||
|
if protobackoff.protocol != current {
|
||||||
|
protobackoff.protocol = current
|
||||||
|
config.Logger.Infof("Change protocol to %s", current)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeTunnel(
|
func ServeTunnel(
|
||||||
|
@ -204,11 +270,12 @@ func ServeTunnel(
|
||||||
credentialManager *reconnectCredentialManager,
|
credentialManager *reconnectCredentialManager,
|
||||||
config *TunnelConfig,
|
config *TunnelConfig,
|
||||||
addr *net.TCPAddr,
|
addr *net.TCPAddr,
|
||||||
connectionIndex uint8,
|
connIndex uint8,
|
||||||
fuse *h2mux.BooleanFuse,
|
fuse *h2mux.BooleanFuse,
|
||||||
backoff *BackoffHandler,
|
backoff *protocallFallback,
|
||||||
cloudflaredUUID uuid.UUID,
|
cloudflaredUUID uuid.UUID,
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
|
protocol connection.Protocol,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
// Treat panics as recoverable errors
|
// Treat panics as recoverable errors
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -226,11 +293,11 @@ func ServeTunnel(
|
||||||
// If launch-ui flag is set, send disconnect msg
|
// If launch-ui flag is set, send disconnect msg
|
||||||
if config.TunnelEventChan != nil {
|
if config.TunnelEventChan != nil {
|
||||||
defer func() {
|
defer func() {
|
||||||
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Disconnected}
|
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Disconnected}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr)
|
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
|
@ -238,11 +305,11 @@ func ServeTunnel(
|
||||||
fuse: fuse,
|
fuse: fuse,
|
||||||
backoff: backoff,
|
backoff: backoff,
|
||||||
}
|
}
|
||||||
if config.Protocol == connection.HTTP2 {
|
if protocol == connection.HTTP2 {
|
||||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
|
||||||
return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh)
|
return ServeHTTP2(ctx, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
|
||||||
}
|
}
|
||||||
return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh)
|
return ServeH2mux(ctx, credentialManager, config, edgeConn, connIndex, connectedFuse, cloudflaredUUID, reconnectCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeH2mux(
|
func ServeH2mux(
|
||||||
|
@ -255,6 +322,7 @@ func ServeH2mux(
|
||||||
cloudflaredUUID uuid.UUID,
|
cloudflaredUUID uuid.UUID,
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
|
config.Logger.Debugf("Connecting via h2mux")
|
||||||
// Returns error from parsing the origin URL or handshake errors
|
// Returns error from parsing the origin URL or handshake errors
|
||||||
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
|
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -266,10 +334,10 @@ func ServeH2mux(
|
||||||
errGroup.Go(func() (err error) {
|
errGroup.Go(func() (err error) {
|
||||||
if config.NamedTunnel != nil {
|
if config.NamedTunnel != nil {
|
||||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
|
||||||
return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
|
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
|
||||||
}
|
}
|
||||||
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
||||||
return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
||||||
|
@ -295,7 +363,7 @@ func ServeH2mux(
|
||||||
config.Logger.Info("Muxer shutdown")
|
config.Logger.Info("Muxer shutdown")
|
||||||
return err, true
|
return err, true
|
||||||
case *ReconnectSignal:
|
case *ReconnectSignal:
|
||||||
config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
|
config.Logger.Infof("Restarting connection %d due to reconnect signal in %s", connectionIndex, err.Delay)
|
||||||
err.DelayBeforeReconnect()
|
err.DelayBeforeReconnect()
|
||||||
return err, true
|
return err, true
|
||||||
default:
|
default:
|
||||||
|
@ -319,10 +387,8 @@ func ServeHTTP2(
|
||||||
connectedFuse connection.ConnectedFuse,
|
connectedFuse connection.ConnectedFuse,
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
|
config.Logger.Debugf("Connecting via http2")
|
||||||
if err != nil {
|
server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
|
||||||
return err, false
|
|
||||||
}
|
|
||||||
|
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
|
@ -352,12 +418,12 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) fu
|
||||||
|
|
||||||
type connectedFuse struct {
|
type connectedFuse struct {
|
||||||
fuse *h2mux.BooleanFuse
|
fuse *h2mux.BooleanFuse
|
||||||
backoff *BackoffHandler
|
backoff *protocallFallback
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cf *connectedFuse) Connected() {
|
func (cf *connectedFuse) Connected() {
|
||||||
cf.fuse.Fuse(true)
|
cf.fuse.Fuse(true)
|
||||||
cf.backoff.SetGracePeriod()
|
cf.backoff.reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cf *connectedFuse) IsConnected() bool {
|
func (cf *connectedFuse) IsConnected() bool {
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dynamicMockFetcher struct {
|
||||||
|
percentage int32
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
|
||||||
|
return func() (int32, error) {
|
||||||
|
if dmf.err != nil {
|
||||||
|
return 0, dmf.err
|
||||||
|
}
|
||||||
|
return dmf.percentage, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestWaitForBackoffFallback(t *testing.T) {
|
||||||
|
maxRetries := uint(3)
|
||||||
|
backoff := BackoffHandler{
|
||||||
|
MaxRetries: maxRetries,
|
||||||
|
BaseTime: time.Millisecond * 10,
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
logger, err := logger.New()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
resolveTTL := time.Duration(0)
|
||||||
|
namedTunnel := &connection.NamedTunnelConfig{
|
||||||
|
Auth: pogs.TunnelAuth{
|
||||||
|
AccountTag: "test-account",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mockFetcher := dynamicMockFetcher{
|
||||||
|
percentage: 0,
|
||||||
|
}
|
||||||
|
protocolSelector, err := connection.NewProtocolSelector(connection.HTTP2.String(), namedTunnel, mockFetcher.fetch(), resolveTTL, logger)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
config := &TunnelConfig{
|
||||||
|
Logger: logger,
|
||||||
|
ProtocolSelector: protocolSelector,
|
||||||
|
}
|
||||||
|
connIndex := uint8(1)
|
||||||
|
|
||||||
|
initProtocol := protocolSelector.Current()
|
||||||
|
assert.Equal(t, connection.HTTP2, initProtocol)
|
||||||
|
|
||||||
|
protocallFallback := &protocallFallback{
|
||||||
|
backoff,
|
||||||
|
initProtocol,
|
||||||
|
false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this
|
||||||
|
for i := 0; i < int(maxRetries-1); i++ {
|
||||||
|
err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, initProtocol, protocallFallback.protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry fallback protocol
|
||||||
|
for i := 0; i < int(maxRetries); i++ {
|
||||||
|
err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
fallback, ok := protocolSelector.Fallback()
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, fallback, protocallFallback.protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentGlobalProtocol := protocolSelector.Current()
|
||||||
|
assert.Equal(t, initProtocol, currentGlobalProtocol)
|
||||||
|
|
||||||
|
// No protocol to fallback, return error
|
||||||
|
err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
protocallFallback.reset()
|
||||||
|
err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("New error"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, initProtocol, protocallFallback.protocol)
|
||||||
|
}
|
Loading…
Reference in New Issue