TUN-3007: Implement named tunnel connection registration and unregistration.

Removed flag for using quick reconnect, this logic is now always enabled.
This commit is contained in:
Igor Postelnik 2020-06-25 13:25:39 -05:00
parent 932e383051
commit 2a3d486126
9 changed files with 248 additions and 141 deletions

View File

@ -106,7 +106,7 @@ func ssh(c *cli.Context) error {
wsConn := carrier.NewWSConnection(logger, false) wsConn := carrier.NewWSConnection(logger, false)
if c.NArg() > 0 || c.IsSet(sshURLFlag) { if c.NArg() > 0 || c.IsSet(sshURLFlag) {
localForwarder, err := config.ValidateUrl(c) localForwarder, err := config.ValidateUrl(c, true)
if err != nil { if err != nil {
logger.Errorf("Error validating origin URL: %s", err) logger.Errorf("Error validating origin URL: %s", err)
return errors.Wrap(err, "error validating origin URL") return errors.Wrap(err, "error validating origin URL")

View File

@ -6,11 +6,12 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"github.com/cloudflare/cloudflared/validation"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"gopkg.in/urfave/cli.v2" "gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc" "gopkg.in/urfave/cli.v2/altsrc"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"github.com/cloudflare/cloudflared/validation"
) )
var ( var (
@ -176,9 +177,9 @@ func ValidateUnixSocket(c *cli.Context) (string, error) {
// ValidateUrl will validate url flag correctness. It can be either from --url or argument // ValidateUrl will validate url flag correctness. It can be either from --url or argument
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument // Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
func ValidateUrl(c *cli.Context) (string, error) { func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) {
var url = c.String("url") var url = c.String("url")
if c.NArg() > 0 { if allowFromArgs && c.NArg() > 0 {
if c.IsSet("url") { if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
} }

View File

@ -359,7 +359,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
defer wg.Done() defer wg.Done()
hello.StartHelloWorldServer(logger, helloListener, shutdownC) hello.StartHelloWorldServer(logger, helloListener, shutdownC)
}() }()
c.Set("url", "https://"+helloListener.Addr().String()) forceSetFlag(c, "url", "https://"+helloListener.Addr().String())
} }
if c.IsSet(sshServerFlag) { if c.IsSet(sshServerFlag) {
@ -409,7 +409,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
close(shutdownC) close(shutdownC)
} }
}() }()
c.Set("url", "ssh://"+localServerAddress) forceSetFlag(c, "url", "ssh://"+localServerAddress)
} }
url := c.String("url") url := c.String("url")
@ -453,7 +453,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
} }
errC <- websocket.StartProxyServer(logger, listener, staticHost, shutdownC, streamHandler) errC <- websocket.StartProxyServer(logger, listener, staticHost, shutdownC, streamHandler)
}() }()
c.Set("url", "http://"+listener.Addr().String()) forceSetFlag(c, "url", "http://"+listener.Addr().String())
} }
transportLogger, err := createLogger(c, true) transportLogger, err := createLogger(c, true)
@ -461,7 +461,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
return errors.Wrap(err, "error setting up transport logger") return errors.Wrap(err, "error setting up transport logger")
} }
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger) tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger, namedTunnel)
if err != nil { if err != nil {
return err return err
} }
@ -475,12 +475,21 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh, namedTunnel) errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
}() }()
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger) return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger)
} }
// forceSetFlag attempts to set the given flag value in the closest context that has it defined
func forceSetFlag(c *cli.Context, name, value string) {
for _, ctx := range c.Lineage() {
if err := ctx.Set(name, value); err == nil {
break
}
}
}
func Before(c *cli.Context) error { func Before(c *cli.Context) error {
logger, err := createLogger(c, false) logger, err := createLogger(c, false)
if err != nil { if err != nil {
@ -969,13 +978,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"}, EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "use-quick-reconnects",
Usage: "Test reestablishing connections with the new 'connection digest' flow.",
Value: true,
EnvVars: []string{"TUNNEL_USE_QUICK_RECONNECTS"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "dial-edge-timeout", Name: "dial-edge-timeout",
Usage: "Maximum wait time to set up a connection with the edge", Usage: "Maximum wait time to set up a connection with the edge",

View File

@ -158,7 +158,10 @@ func prepareTunnelConfig(
version string, version string,
logger logger.Service, logger logger.Service,
transportLogger logger.Service, transportLogger logger.Service,
namedTunnel *origin.NamedTunnelConfig,
) (*origin.TunnelConfig, error) { ) (*origin.TunnelConfig, error) {
compatibilityMode := namedTunnel == nil
hostname, err := validation.ValidateHostname(c.String("hostname")) hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil { if err != nil {
logger.Errorf("Invalid hostname: %s", err) logger.Errorf("Invalid hostname: %s", err)
@ -181,7 +184,7 @@ func prepareTunnelConfig(
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
originURL, err := config.ValidateUrl(c) originURL, err := config.ValidateUrl(c, compatibilityMode)
if err != nil { if err != nil {
logger.Errorf("Error validating origin URL: %s", err) logger.Errorf("Error validating origin URL: %s", err)
return nil, errors.Wrap(err, "Error validating origin URL") return nil, errors.Wrap(err, "Error validating origin URL")
@ -254,38 +257,52 @@ func prepareTunnelConfig(
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
} }
if namedTunnel != nil {
clientUUID, err := uuid.NewRandom()
if err != nil {
return nil, errors.Wrap(err, "can't generate clientUUID")
}
namedTunnel.Client = tunnelpogs.ClientInfo{
ClientID: clientUUID[:],
Features: []string{origin.FeatureSerializedHeaders},
Version: version,
Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch),
}
}
return &origin.TunnelConfig{ return &origin.TunnelConfig{
BuildInfo: buildInfo, BuildInfo: buildInfo,
ClientID: clientID, ClientID: clientID,
ClientTlsConfig: httpTransport.TLSClientConfig, ClientTlsConfig: httpTransport.TLSClientConfig,
CompressionQuality: c.Uint64("compression-quality"), CompressionQuality: c.Uint64("compression-quality"),
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
GracePeriod: c.Duration("grace-period"), GracePeriod: c.Duration("grace-period"),
HAConnections: c.Int("ha-connections"), HAConnections: c.Int("ha-connections"),
HTTPTransport: httpTransport, HTTPTransport: httpTransport,
HeartbeatInterval: c.Duration("heartbeat-interval"), HeartbeatInterval: c.Duration("heartbeat-interval"),
Hostname: hostname, Hostname: hostname,
HTTPHostHeader: c.String("http-host-header"), HTTPHostHeader: c.String("http-host-header"),
IncidentLookup: origin.NewIncidentLookup(), IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"), IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel, IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"), LBPool: c.String("lb-pool"),
Logger: logger, Logger: logger,
TransportLogger: transportLogger, TransportLogger: transportLogger,
MaxHeartbeats: c.Uint64("heartbeat-count"), MaxHeartbeats: c.Uint64("heartbeat-count"),
Metrics: tunnelMetrics, Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"), MetricsUpdateFreq: c.Duration("metrics-update-freq"),
NoChunkedEncoding: c.Bool("no-chunked-encoding"), NoChunkedEncoding: c.Bool("no-chunked-encoding"),
OriginCert: originCert, OriginCert: originCert,
OriginUrl: originURL, OriginUrl: originURL,
ReportedVersion: version, ReportedVersion: version,
Retries: c.Uint("retries"), Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(), RunFromTerminal: isRunningFromTerminal(),
Tags: tags, Tags: tags,
TlsConfig: toEdgeTLSConfig, TlsConfig: toEdgeTLSConfig,
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"), NamedTunnel: namedTunnel,
UseReconnectToken: c.Bool("use-reconnect-token"), ReplaceExisting: c.Bool("force"),
UseQuickReconnects: c.Bool("use-quick-reconnects"), // turn off use of reconnect token and auth refresh when using named tunnels
UseReconnectToken: compatibilityMode && c.Bool("use-reconnect-token"),
}, nil }, nil
} }

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/urfave/cli.v2" "gopkg.in/urfave/cli.v2"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@ -34,7 +35,7 @@ var (
Aliases: []string{"o"}, Aliases: []string{"o"},
Usage: "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'", Usage: "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'",
} }
forceFlag = &cli.StringFlag{ forceFlag = &cli.BoolFlag{
Name: "force", Name: "force",
Aliases: []string{"f"}, Aliases: []string{"f"},
Usage: "By default, if a tunnel is currently being run from a cloudflared, you can't " + Usage: "By default, if a tunnel is currently being run from a cloudflared, you can't " +
@ -148,9 +149,12 @@ func readTunnelCredentials(tunnelID, originCertPath string) (*pogs.TunnelAuth, e
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath) return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath)
} }
auth := pogs.TunnelAuth{}
err = json.Unmarshal(body, &auth) var auth pogs.TunnelAuth
return &auth, errors.Wrap(err, "couldn't parse tunnel credentials from JSON") if err = json.Unmarshal(body, &auth); err != nil {
return nil, err
}
return &auth, nil
} }
func buildListCommand() *cli.Command { func buildListCommand() *cli.Command {
@ -325,6 +329,10 @@ func runTunnel(c *cli.Context) error {
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`) return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
} }
id := c.Args().First() id := c.Args().First()
tunnelID, err := uuid.Parse(id)
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
}
logger, err := logger.New() logger, err := logger.New()
if err != nil { if err != nil {
@ -340,5 +348,5 @@ func runTunnel(c *cli.Context) error {
return err return err
} }
logger.Debugf("Read credentials for %v", credentials.AccountTag) logger.Debugf("Read credentials for %v", credentials.AccountTag)
return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: id}) return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID})
} }

View File

@ -68,8 +68,6 @@ type Supervisor struct {
connDigest map[uint8][]byte connDigest map[uint8][]byte
bufferPool *buffer.Pool bufferPool *buffer.Pool
namedTunnel *NamedTunnelConfig
} }
type resolveResult struct { type resolveResult struct {
@ -82,7 +80,7 @@ type tunnelError struct {
err error err error
} }
func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel *NamedTunnelConfig) (*Supervisor, error) { func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
var ( var (
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
err error err error
@ -95,6 +93,7 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Supervisor{ return &Supervisor{
cloudflaredUUID: cloudflaredUUID, cloudflaredUUID: cloudflaredUUID,
config: config, config: config,
@ -104,7 +103,6 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel
logger: config.Logger, logger: config.Logger,
connDigest: make(map[uint8][]byte), connDigest: make(map[uint8][]byte),
bufferPool: buffer.NewPool(512 * 1024), bufferPool: buffer.NewPool(512 * 1024),
namedTunnel: namedTunnel,
}, nil }, nil
} }
@ -229,17 +227,17 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
addr *net.TCPAddr addr *net.TCPAddr
err error err error
) )
const thisConnID = 0 const firstConnIndex = 0
defer func() { defer func() {
s.tunnelErrors <- tunnelError{index: thisConnID, addr: addr, err: err} s.tunnelErrors <- tunnelError{index: firstConnIndex, addr: addr, err: err}
}() }()
addr, err = s.edgeIPs.GetAddr(thisConnID) addr, err = s.edgeIPs.GetAddr(firstConnIndex)
if err != nil { if err != nil {
return return
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
// If the first tunnel disconnects, keep restarting it. // If the first tunnel disconnects, keep restarting it.
edgeErrors := 0 edgeErrors := 0
for s.unusedIPs() { for s.unusedIPs() {
@ -257,12 +255,12 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
} }
if edgeErrors >= 2 { if edgeErrors >= 2 {
addr, err = s.edgeIPs.GetDifferentAddr(thisConnID) addr, err = s.edgeIPs.GetDifferentAddr(firstConnIndex)
if err != nil { if err != nil {
return return
} }
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
} }
} }

View File

@ -48,7 +48,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) s, err := NewSupervisor(testConfig(logger), uuid.New())
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -92,7 +92,7 @@ func TestRefreshAuthSuccess(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) s, err := NewSupervisor(testConfig(logger), uuid.New())
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -120,7 +120,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) s, err := NewSupervisor(testConfig(logger), uuid.New())
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -142,7 +142,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
func TestRefreshAuthFail(t *testing.T) { func TestRefreshAuthFail(t *testing.T) {
logger := logger.NewOutputWriter(logger.NewMockWriteManager()) logger := logger.NewOutputWriter(logger.NewMockWriteManager())
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) s, err := NewSupervisor(testConfig(logger), uuid.New())
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }

View File

@ -48,47 +48,46 @@ type registerRPCName string
const ( const (
register registerRPCName = "register" register registerRPCName = "register"
reconnect registerRPCName = "reconnect" reconnect registerRPCName = "reconnect"
unknown registerRPCName = "unknown"
) )
type TunnelConfig struct { type TunnelConfig struct {
BuildInfo *buildinfo.BuildInfo BuildInfo *buildinfo.BuildInfo
ClientID string ClientID string
ClientTlsConfig *tls.Config ClientTlsConfig *tls.Config
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
CompressionQuality uint64 CompressionQuality uint64
EdgeAddrs []string EdgeAddrs []string
GracePeriod time.Duration GracePeriod time.Duration
HAConnections int HAConnections int
HTTPTransport http.RoundTripper HTTPTransport http.RoundTripper
HeartbeatInterval time.Duration HeartbeatInterval time.Duration
Hostname string Hostname string
HTTPHostHeader string HTTPHostHeader string
IncidentLookup IncidentLookup IncidentLookup IncidentLookup
IsAutoupdated bool IsAutoupdated bool
IsFreeTunnel bool IsFreeTunnel bool
LBPool string LBPool string
Logger logger.Service Logger logger.Service
TransportLogger logger.Service TransportLogger logger.Service
MaxHeartbeats uint64 MaxHeartbeats uint64
Metrics *TunnelMetrics Metrics *TunnelMetrics
MetricsUpdateFreq time.Duration MetricsUpdateFreq time.Duration
NoChunkedEncoding bool NoChunkedEncoding bool
OriginCert []byte OriginCert []byte
ReportedVersion string ReportedVersion string
Retries uint Retries uint
RunFromTerminal bool RunFromTerminal bool
Tags []tunnelpogs.Tag Tags []tunnelpogs.Tag
TlsConfig *tls.Config TlsConfig *tls.Config
UseDeclarativeTunnel bool WSGI bool
WSGI bool
// OriginUrl may not be used if a user specifies a unix socket. // OriginUrl may not be used if a user specifies a unix socket.
OriginUrl string OriginUrl string
// feature-flag to use new edge reconnect tokens // feature-flag to use new edge reconnect tokens
UseReconnectToken bool UseReconnectToken bool
// feature-flag for using ConnectionDigest
UseQuickReconnects bool NamedTunnel *NamedTunnelConfig
ReplaceExisting bool
} }
// ReconnectTunnelCredentialManager is invoked by functions in this file to // ReconnectTunnelCredentialManager is invoked by functions in this file to
@ -103,6 +102,8 @@ type ReconnectTunnelCredentialManager interface {
type dupConnRegisterTunnelError struct{} type dupConnRegisterTunnelError struct{}
var errDuplicationConnection = &dupConnRegisterTunnelError{}
func (e dupConnRegisterTunnelError) Error() string { func (e dupConnRegisterTunnelError) Error() string {
return "already connected to this server" return "already connected to this server"
} }
@ -171,21 +172,35 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
} }
} }
func (c *TunnelConfig) SupportedFeatures() []string { func (c *TunnelConfig) ConnectionOptions(originLocalAddr string) *tunnelpogs.ConnectionOptions {
basic := []string{FeatureSerializedHeaders} // attempt to parse out origin IP, but don't fail since it's informational field
if c.UseQuickReconnects { host, _, _ := net.SplitHostPort(originLocalAddr)
basic = append(basic, FeatureQuickReconnects) originIP := net.ParseIP(host)
return &tunnelpogs.ConnectionOptions{
Client: c.NamedTunnel.Client,
OriginLocalIP: originIP,
ReplaceExisting: c.ReplaceExisting,
CompressionQuality: uint8(c.CompressionQuality),
} }
return basic }
func (c *TunnelConfig) SupportedFeatures() []string {
features := []string{FeatureSerializedHeaders}
if c.NamedTunnel == nil {
features = append(features, FeatureQuickReconnects)
}
return features
} }
type NamedTunnelConfig struct { type NamedTunnelConfig struct {
Auth pogs.TunnelAuth Auth pogs.TunnelAuth
ID string ID uuid.UUID
Client pogs.ClientInfo
} }
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal, namedTunnel *NamedTunnelConfig) error { func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
s, err := NewSupervisor(config, cloudflaredID, namedTunnel) s, err := NewSupervisor(config, cloudflaredID)
if err != nil { if err != nil {
return err return err
} }
@ -196,7 +211,7 @@ func ServeTunnelLoop(ctx context.Context,
credentialManager ReconnectTunnelCredentialManager, credentialManager ReconnectTunnelCredentialManager,
config *TunnelConfig, config *TunnelConfig,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionID uint8, connectionIndex uint8,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
@ -219,7 +234,7 @@ func ServeTunnelLoop(ctx context.Context,
credentialManager, credentialManager,
config, config,
config.Logger, config.Logger,
addr, connectionID, addr, connectionIndex,
connectedFuse, connectedFuse,
&backoff, &backoff,
cloudflaredUUID, cloudflaredUUID,
@ -228,7 +243,7 @@ func ServeTunnelLoop(ctx context.Context,
) )
if recoverable { if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok { if duration, ok := backoff.GetBackoffDuration(ctx); ok {
config.Logger.Infof("Retrying in %s seconds: connectionID: %d", duration, connectionID) config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration)
backoff.Backoff(ctx) backoff.Backoff(ctx)
continue continue
} }
@ -243,7 +258,7 @@ func ServeTunnel(
config *TunnelConfig, config *TunnelConfig,
logger logger.Service, logger logger.Service,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionID uint8, connectionIndex uint8,
connectedFuse *h2mux.BooleanFuse, connectedFuse *h2mux.BooleanFuse,
backoff *BackoffHandler, backoff *BackoffHandler,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
@ -262,22 +277,18 @@ func ServeTunnel(
} }
}() }()
connectionTag := uint8ToString(connectionID) connectionTag := uint8ToString(connectionIndex)
// additional tags to send other than hostname which is set in cloudflared main package
tags := make(map[string]string)
tags["ha"] = connectionTag
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID, bufferPool) handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case connection.DialError: case connection.DialError:
logger.Errorf("Unable to dial edge: %s connectionID: %d", err, connectionID) logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err)
case h2mux.MuxerHandshakeError: case h2mux.MuxerHandshakeError:
logger.Errorf("Handshake failed with edge server: %s connectionID: %d", err, connectionID) logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err)
default: default:
logger.Errorf("Tunnel creation failure: %s connectionID: %d", err, connectionID) logger.Errorf("Connection %d failed: %s", connectionIndex, err)
return err, false return err, false
} }
return err, true return err, true
@ -293,20 +304,21 @@ func ServeTunnel(
} }
}() }()
if config.NamedTunnel != nil {
return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr)
}
if config.UseReconnectToken && connectedFuse.Value() { if config.UseReconnectToken && connectedFuse.Value() {
token, tokenErr := credentialManager.ReconnectToken() token, tokenErr := credentialManager.ReconnectToken()
eventDigest, eventDigestErr := credentialManager.EventDigest() eventDigest, eventDigestErr := credentialManager.EventDigest()
// if we have both credentials, we can reconnect // if we have both credentials, we can reconnect
if tokenErr == nil && eventDigestErr == nil { if tokenErr == nil && eventDigestErr == nil {
var connDigest []byte var connDigest []byte
if digest, connDigestErr := credentialManager.ConnDigest(connectionIndex); connDigestErr == nil {
// check if we can use Quick Reconnects connDigest = digest
if config.UseQuickReconnects {
if digest, connDigestErr := credentialManager.ConnDigest(connectionID); connDigestErr == nil {
connDigest = digest
}
} }
return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID, credentialManager)
return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
} }
// log errors and proceed to RegisterTunnel // log errors and proceed to RegisterTunnel
if tokenErr != nil { if tokenErr != nil {
@ -316,7 +328,7 @@ func ServeTunnel(
logger.Errorf("Couldn't get event digest: %s", eventDigestErr) logger.Errorf("Couldn't get event digest: %s", eventDigestErr)
} }
} }
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID) return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
@ -325,12 +337,15 @@ func ServeTunnel(
select { select {
case <-serveCtx.Done(): case <-serveCtx.Done():
// UnregisterTunnel blocks until the RPC call returns // UnregisterTunnel blocks until the RPC call returns
var err error
if connectedFuse.Value() { if connectedFuse.Value() {
err = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger) if config.NamedTunnel != nil {
_ = UnregisterConnection(ctx, handler.muxer, config)
} else {
_ = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger)
}
} }
handler.muxer.Shutdown() handler.muxer.Shutdown()
return err return nil
case <-updateMetricsTickC: case <-updateMetricsTickC:
handler.UpdateMetrics(connectionTag) handler.UpdateMetrics(connectionTag)
} }
@ -361,8 +376,6 @@ func ServeTunnel(
err = errGroup.Wait() err = errGroup.Wait()
if err != nil { if err != nil {
_ = newClientRegisterTunnelError(err, config.Metrics.regFail, unknown)
switch castedErr := err.(type) { switch castedErr := err.(type) {
case dupConnRegisterTunnelError: case dupConnRegisterTunnelError:
logger.Info("Already connected to this server, selecting a different one") logger.Info("Already connected to this server, selecting a different one")
@ -382,7 +395,7 @@ func ServeTunnel(
logger.Info("Muxer shutdown") logger.Info("Muxer shutdown")
return err, true return err, true
case *ReconnectSignal: case *ReconnectSignal:
logger.Infof("Restarting due to reconnect signal in %d seconds", castedErr.Delay) logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, castedErr.Delay)
castedErr.DelayBeforeReconnect() castedErr.DelayBeforeReconnect()
return err, true return err, true
default: default:
@ -393,6 +406,74 @@ func ServeTunnel(
return nil, true return nil, true
} }
func RegisterConnection(
ctx context.Context,
muxer *h2mux.Muxer,
config *TunnelConfig,
connectionIndex uint8,
originLocalAddr string,
) error {
const registerConnection = "registerConnection"
config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
if err != nil {
// RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, registerConnection)
}
defer rpc.Close()
conn, err := rpc.RegisterConnection(
ctx,
config.NamedTunnel.Auth,
config.NamedTunnel.ID,
connectionIndex,
config.ConnectionOptions(originLocalAddr),
)
if err != nil {
if err.Error() == DuplicateConnectionError {
config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc()
return errDuplicationConnection
}
config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc()
return serverRegistrationErrorFromRPC(err)
}
config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc()
config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID)
return nil
}
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
return &serverRegisterTunnelError{
cause: retryable.Unwrap(),
permanent: false,
}
}
return &serverRegisterTunnelError{
cause: err,
permanent: true,
}
}
func UnregisterConnection(
ctx context.Context,
muxer *h2mux.Muxer,
config *TunnelConfig,
) error {
config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
if err != nil {
// RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
}
defer rpc.Close()
return rpc.UnregisterConnection(ctx)
}
func RegisterTunnel( func RegisterTunnel(
ctx context.Context, ctx context.Context,
credentialManager ReconnectTunnelCredentialManager, credentialManager ReconnectTunnelCredentialManager,
@ -437,7 +518,7 @@ func ReconnectTunnel(
config *TunnelConfig, config *TunnelConfig,
logger logger.Service, logger logger.Service,
connectionID uint8, connectionID uint8,
originLocalIP string, originLocalAddr string,
uuid uuid.UUID, uuid uuid.UUID,
credentialManager ReconnectTunnelCredentialManager, credentialManager ReconnectTunnelCredentialManager,
) error { ) error {
@ -459,7 +540,7 @@ func ReconnectTunnel(
eventDigest, eventDigest,
connDigest, connDigest,
config.Hostname, config.Hostname,
config.RegistrationOptions(connectionID, originLocalIP, uuid), config.RegistrationOptions(connectionID, originLocalAddr, uuid),
) )
if registrationErr := registration.DeserializeError(); registrationErr != nil { if registrationErr := registration.DeserializeError(); registrationErr != nil {
// ReconnectTunnel RPC failure // ReconnectTunnel RPC failure
@ -508,11 +589,11 @@ func processRegistrationSuccess(
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error { func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error {
if err.Error() == DuplicateConnectionError { if err.Error() == DuplicateConnectionError {
metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc() metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
return dupConnRegisterTunnelError{} return errDuplicationConnection
} }
metrics.regFail.WithLabelValues("server_error", string(name)).Inc() metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
return serverRegisterTunnelError{ return serverRegisterTunnelError{
cause: fmt.Errorf("Server error: %s", err.Error()), cause: err,
permanent: err.IsPermanent(), permanent: err.IsPermanent(),
} }
} }

View File

@ -223,7 +223,7 @@ func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth Tu
return nil, newRPCError("unknown result which %d", result.Which()) return nil, newRPCError("unknown result which %d", result.Which())
} }
func (c TunnelServer_PogsClient) Unregister(ctx context.Context) error { func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error {
client := tunnelrpc.TunnelServer{Client: c.Client} client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error { promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
return nil return nil