TUN-3007: Implement named tunnel connection registration and unregistration.

Removed flag for using quick reconnect, this logic is now always enabled.
This commit is contained in:
Igor Postelnik 2020-06-25 13:25:39 -05:00
parent 932e383051
commit 2a3d486126
9 changed files with 248 additions and 141 deletions

View File

@ -106,7 +106,7 @@ func ssh(c *cli.Context) error {
wsConn := carrier.NewWSConnection(logger, false) wsConn := carrier.NewWSConnection(logger, false)
if c.NArg() > 0 || c.IsSet(sshURLFlag) { if c.NArg() > 0 || c.IsSet(sshURLFlag) {
localForwarder, err := config.ValidateUrl(c) localForwarder, err := config.ValidateUrl(c, true)
if err != nil { if err != nil {
logger.Errorf("Error validating origin URL: %s", err) logger.Errorf("Error validating origin URL: %s", err)
return errors.Wrap(err, "error validating origin URL") return errors.Wrap(err, "error validating origin URL")

View File

@ -6,11 +6,12 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"github.com/cloudflare/cloudflared/validation"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"gopkg.in/urfave/cli.v2" "gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc" "gopkg.in/urfave/cli.v2/altsrc"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"github.com/cloudflare/cloudflared/validation"
) )
var ( var (
@ -176,9 +177,9 @@ func ValidateUnixSocket(c *cli.Context) (string, error) {
// ValidateUrl will validate url flag correctness. It can be either from --url or argument // ValidateUrl will validate url flag correctness. It can be either from --url or argument
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument // Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
func ValidateUrl(c *cli.Context) (string, error) { func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) {
var url = c.String("url") var url = c.String("url")
if c.NArg() > 0 { if allowFromArgs && c.NArg() > 0 {
if c.IsSet("url") { if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
} }

View File

@ -359,7 +359,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
defer wg.Done() defer wg.Done()
hello.StartHelloWorldServer(logger, helloListener, shutdownC) hello.StartHelloWorldServer(logger, helloListener, shutdownC)
}() }()
c.Set("url", "https://"+helloListener.Addr().String()) forceSetFlag(c, "url", "https://"+helloListener.Addr().String())
} }
if c.IsSet(sshServerFlag) { if c.IsSet(sshServerFlag) {
@ -409,7 +409,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
close(shutdownC) close(shutdownC)
} }
}() }()
c.Set("url", "ssh://"+localServerAddress) forceSetFlag(c, "url", "ssh://"+localServerAddress)
} }
url := c.String("url") url := c.String("url")
@ -453,7 +453,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
} }
errC <- websocket.StartProxyServer(logger, listener, staticHost, shutdownC, streamHandler) errC <- websocket.StartProxyServer(logger, listener, staticHost, shutdownC, streamHandler)
}() }()
c.Set("url", "http://"+listener.Addr().String()) forceSetFlag(c, "url", "http://"+listener.Addr().String())
} }
transportLogger, err := createLogger(c, true) transportLogger, err := createLogger(c, true)
@ -461,7 +461,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
return errors.Wrap(err, "error setting up transport logger") return errors.Wrap(err, "error setting up transport logger")
} }
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger) tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger, namedTunnel)
if err != nil { if err != nil {
return err return err
} }
@ -475,12 +475,21 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh, namedTunnel) errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
}() }()
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger) return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger)
} }
// forceSetFlag attempts to set the given flag value in the closest context that has it defined
func forceSetFlag(c *cli.Context, name, value string) {
for _, ctx := range c.Lineage() {
if err := ctx.Set(name, value); err == nil {
break
}
}
}
func Before(c *cli.Context) error { func Before(c *cli.Context) error {
logger, err := createLogger(c, false) logger, err := createLogger(c, false)
if err != nil { if err != nil {
@ -969,13 +978,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"}, EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "use-quick-reconnects",
Usage: "Test reestablishing connections with the new 'connection digest' flow.",
Value: true,
EnvVars: []string{"TUNNEL_USE_QUICK_RECONNECTS"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "dial-edge-timeout", Name: "dial-edge-timeout",
Usage: "Maximum wait time to set up a connection with the edge", Usage: "Maximum wait time to set up a connection with the edge",

View File

@ -158,7 +158,10 @@ func prepareTunnelConfig(
version string, version string,
logger logger.Service, logger logger.Service,
transportLogger logger.Service, transportLogger logger.Service,
namedTunnel *origin.NamedTunnelConfig,
) (*origin.TunnelConfig, error) { ) (*origin.TunnelConfig, error) {
compatibilityMode := namedTunnel == nil
hostname, err := validation.ValidateHostname(c.String("hostname")) hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil { if err != nil {
logger.Errorf("Invalid hostname: %s", err) logger.Errorf("Invalid hostname: %s", err)
@ -181,7 +184,7 @@ func prepareTunnelConfig(
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
originURL, err := config.ValidateUrl(c) originURL, err := config.ValidateUrl(c, compatibilityMode)
if err != nil { if err != nil {
logger.Errorf("Error validating origin URL: %s", err) logger.Errorf("Error validating origin URL: %s", err)
return nil, errors.Wrap(err, "Error validating origin URL") return nil, errors.Wrap(err, "Error validating origin URL")
@ -254,6 +257,19 @@ func prepareTunnelConfig(
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
} }
if namedTunnel != nil {
clientUUID, err := uuid.NewRandom()
if err != nil {
return nil, errors.Wrap(err, "can't generate clientUUID")
}
namedTunnel.Client = tunnelpogs.ClientInfo{
ClientID: clientUUID[:],
Features: []string{origin.FeatureSerializedHeaders},
Version: version,
Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch),
}
}
return &origin.TunnelConfig{ return &origin.TunnelConfig{
BuildInfo: buildInfo, BuildInfo: buildInfo,
ClientID: clientID, ClientID: clientID,
@ -283,9 +299,10 @@ func prepareTunnelConfig(
RunFromTerminal: isRunningFromTerminal(), RunFromTerminal: isRunningFromTerminal(),
Tags: tags, Tags: tags,
TlsConfig: toEdgeTLSConfig, TlsConfig: toEdgeTLSConfig,
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"), NamedTunnel: namedTunnel,
UseReconnectToken: c.Bool("use-reconnect-token"), ReplaceExisting: c.Bool("force"),
UseQuickReconnects: c.Bool("use-quick-reconnects"), // turn off use of reconnect token and auth refresh when using named tunnels
UseReconnectToken: compatibilityMode && c.Bool("use-reconnect-token"),
}, nil }, nil
} }

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/urfave/cli.v2" "gopkg.in/urfave/cli.v2"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@ -34,7 +35,7 @@ var (
Aliases: []string{"o"}, Aliases: []string{"o"},
Usage: "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'", Usage: "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'",
} }
forceFlag = &cli.StringFlag{ forceFlag = &cli.BoolFlag{
Name: "force", Name: "force",
Aliases: []string{"f"}, Aliases: []string{"f"},
Usage: "By default, if a tunnel is currently being run from a cloudflared, you can't " + Usage: "By default, if a tunnel is currently being run from a cloudflared, you can't " +
@ -148,9 +149,12 @@ func readTunnelCredentials(tunnelID, originCertPath string) (*pogs.TunnelAuth, e
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath) return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath)
} }
auth := pogs.TunnelAuth{}
err = json.Unmarshal(body, &auth) var auth pogs.TunnelAuth
return &auth, errors.Wrap(err, "couldn't parse tunnel credentials from JSON") if err = json.Unmarshal(body, &auth); err != nil {
return nil, err
}
return &auth, nil
} }
func buildListCommand() *cli.Command { func buildListCommand() *cli.Command {
@ -325,6 +329,10 @@ func runTunnel(c *cli.Context) error {
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`) return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
} }
id := c.Args().First() id := c.Args().First()
tunnelID, err := uuid.Parse(id)
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
}
logger, err := logger.New() logger, err := logger.New()
if err != nil { if err != nil {
@ -340,5 +348,5 @@ func runTunnel(c *cli.Context) error {
return err return err
} }
logger.Debugf("Read credentials for %v", credentials.AccountTag) logger.Debugf("Read credentials for %v", credentials.AccountTag)
return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: id}) return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID})
} }

View File

@ -68,8 +68,6 @@ type Supervisor struct {
connDigest map[uint8][]byte connDigest map[uint8][]byte
bufferPool *buffer.Pool bufferPool *buffer.Pool
namedTunnel *NamedTunnelConfig
} }
type resolveResult struct { type resolveResult struct {
@ -82,7 +80,7 @@ type tunnelError struct {
err error err error
} }
func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel *NamedTunnelConfig) (*Supervisor, error) { func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
var ( var (
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
err error err error
@ -95,6 +93,7 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Supervisor{ return &Supervisor{
cloudflaredUUID: cloudflaredUUID, cloudflaredUUID: cloudflaredUUID,
config: config, config: config,
@ -104,7 +103,6 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel
logger: config.Logger, logger: config.Logger,
connDigest: make(map[uint8][]byte), connDigest: make(map[uint8][]byte),
bufferPool: buffer.NewPool(512 * 1024), bufferPool: buffer.NewPool(512 * 1024),
namedTunnel: namedTunnel,
}, nil }, nil
} }
@ -229,17 +227,17 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
addr *net.TCPAddr addr *net.TCPAddr
err error err error
) )
const thisConnID = 0 const firstConnIndex = 0
defer func() { defer func() {
s.tunnelErrors <- tunnelError{index: thisConnID, addr: addr, err: err} s.tunnelErrors <- tunnelError{index: firstConnIndex, addr: addr, err: err}
}() }()
addr, err = s.edgeIPs.GetAddr(thisConnID) addr, err = s.edgeIPs.GetAddr(firstConnIndex)
if err != nil { if err != nil {
return return
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
// If the first tunnel disconnects, keep restarting it. // If the first tunnel disconnects, keep restarting it.
edgeErrors := 0 edgeErrors := 0
for s.unusedIPs() { for s.unusedIPs() {
@ -257,12 +255,12 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
} }
if edgeErrors >= 2 { if edgeErrors >= 2 {
addr, err = s.edgeIPs.GetDifferentAddr(thisConnID) addr, err = s.edgeIPs.GetDifferentAddr(firstConnIndex)
if err != nil { if err != nil {
return return
} }
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
} }
} }

View File

@ -48,7 +48,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) s, err := NewSupervisor(testConfig(logger), uuid.New())
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -92,7 +92,7 @@ func TestRefreshAuthSuccess(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) s, err := NewSupervisor(testConfig(logger), uuid.New())
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -120,7 +120,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) s, err := NewSupervisor(testConfig(logger), uuid.New())
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -142,7 +142,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
func TestRefreshAuthFail(t *testing.T) { func TestRefreshAuthFail(t *testing.T) {
logger := logger.NewOutputWriter(logger.NewMockWriteManager()) logger := logger.NewOutputWriter(logger.NewMockWriteManager())
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) s, err := NewSupervisor(testConfig(logger), uuid.New())
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }

View File

@ -48,7 +48,6 @@ type registerRPCName string
const ( const (
register registerRPCName = "register" register registerRPCName = "register"
reconnect registerRPCName = "reconnect" reconnect registerRPCName = "reconnect"
unknown registerRPCName = "unknown"
) )
type TunnelConfig struct { type TunnelConfig struct {
@ -80,15 +79,15 @@ type TunnelConfig struct {
RunFromTerminal bool RunFromTerminal bool
Tags []tunnelpogs.Tag Tags []tunnelpogs.Tag
TlsConfig *tls.Config TlsConfig *tls.Config
UseDeclarativeTunnel bool
WSGI bool WSGI bool
// OriginUrl may not be used if a user specifies a unix socket. // OriginUrl may not be used if a user specifies a unix socket.
OriginUrl string OriginUrl string
// feature-flag to use new edge reconnect tokens // feature-flag to use new edge reconnect tokens
UseReconnectToken bool UseReconnectToken bool
// feature-flag for using ConnectionDigest
UseQuickReconnects bool NamedTunnel *NamedTunnelConfig
ReplaceExisting bool
} }
// ReconnectTunnelCredentialManager is invoked by functions in this file to // ReconnectTunnelCredentialManager is invoked by functions in this file to
@ -103,6 +102,8 @@ type ReconnectTunnelCredentialManager interface {
type dupConnRegisterTunnelError struct{} type dupConnRegisterTunnelError struct{}
var errDuplicationConnection = &dupConnRegisterTunnelError{}
func (e dupConnRegisterTunnelError) Error() string { func (e dupConnRegisterTunnelError) Error() string {
return "already connected to this server" return "already connected to this server"
} }
@ -171,21 +172,35 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
} }
} }
func (c *TunnelConfig) SupportedFeatures() []string { func (c *TunnelConfig) ConnectionOptions(originLocalAddr string) *tunnelpogs.ConnectionOptions {
basic := []string{FeatureSerializedHeaders} // attempt to parse out origin IP, but don't fail since it's informational field
if c.UseQuickReconnects { host, _, _ := net.SplitHostPort(originLocalAddr)
basic = append(basic, FeatureQuickReconnects) originIP := net.ParseIP(host)
return &tunnelpogs.ConnectionOptions{
Client: c.NamedTunnel.Client,
OriginLocalIP: originIP,
ReplaceExisting: c.ReplaceExisting,
CompressionQuality: uint8(c.CompressionQuality),
} }
return basic }
func (c *TunnelConfig) SupportedFeatures() []string {
features := []string{FeatureSerializedHeaders}
if c.NamedTunnel == nil {
features = append(features, FeatureQuickReconnects)
}
return features
} }
type NamedTunnelConfig struct { type NamedTunnelConfig struct {
Auth pogs.TunnelAuth Auth pogs.TunnelAuth
ID string ID uuid.UUID
Client pogs.ClientInfo
} }
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal, namedTunnel *NamedTunnelConfig) error { func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
s, err := NewSupervisor(config, cloudflaredID, namedTunnel) s, err := NewSupervisor(config, cloudflaredID)
if err != nil { if err != nil {
return err return err
} }
@ -196,7 +211,7 @@ func ServeTunnelLoop(ctx context.Context,
credentialManager ReconnectTunnelCredentialManager, credentialManager ReconnectTunnelCredentialManager,
config *TunnelConfig, config *TunnelConfig,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionID uint8, connectionIndex uint8,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
@ -219,7 +234,7 @@ func ServeTunnelLoop(ctx context.Context,
credentialManager, credentialManager,
config, config,
config.Logger, config.Logger,
addr, connectionID, addr, connectionIndex,
connectedFuse, connectedFuse,
&backoff, &backoff,
cloudflaredUUID, cloudflaredUUID,
@ -228,7 +243,7 @@ func ServeTunnelLoop(ctx context.Context,
) )
if recoverable { if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok { if duration, ok := backoff.GetBackoffDuration(ctx); ok {
config.Logger.Infof("Retrying in %s seconds: connectionID: %d", duration, connectionID) config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration)
backoff.Backoff(ctx) backoff.Backoff(ctx)
continue continue
} }
@ -243,7 +258,7 @@ func ServeTunnel(
config *TunnelConfig, config *TunnelConfig,
logger logger.Service, logger logger.Service,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionID uint8, connectionIndex uint8,
connectedFuse *h2mux.BooleanFuse, connectedFuse *h2mux.BooleanFuse,
backoff *BackoffHandler, backoff *BackoffHandler,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
@ -262,22 +277,18 @@ func ServeTunnel(
} }
}() }()
connectionTag := uint8ToString(connectionID) connectionTag := uint8ToString(connectionIndex)
// additional tags to send other than hostname which is set in cloudflared main package
tags := make(map[string]string)
tags["ha"] = connectionTag
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID, bufferPool) handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool)
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case connection.DialError: case connection.DialError:
logger.Errorf("Unable to dial edge: %s connectionID: %d", err, connectionID) logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err)
case h2mux.MuxerHandshakeError: case h2mux.MuxerHandshakeError:
logger.Errorf("Handshake failed with edge server: %s connectionID: %d", err, connectionID) logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err)
default: default:
logger.Errorf("Tunnel creation failure: %s connectionID: %d", err, connectionID) logger.Errorf("Connection %d failed: %s", connectionIndex, err)
return err, false return err, false
} }
return err, true return err, true
@ -293,20 +304,21 @@ func ServeTunnel(
} }
}() }()
if config.NamedTunnel != nil {
return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr)
}
if config.UseReconnectToken && connectedFuse.Value() { if config.UseReconnectToken && connectedFuse.Value() {
token, tokenErr := credentialManager.ReconnectToken() token, tokenErr := credentialManager.ReconnectToken()
eventDigest, eventDigestErr := credentialManager.EventDigest() eventDigest, eventDigestErr := credentialManager.EventDigest()
// if we have both credentials, we can reconnect // if we have both credentials, we can reconnect
if tokenErr == nil && eventDigestErr == nil { if tokenErr == nil && eventDigestErr == nil {
var connDigest []byte var connDigest []byte
if digest, connDigestErr := credentialManager.ConnDigest(connectionIndex); connDigestErr == nil {
// check if we can use Quick Reconnects
if config.UseQuickReconnects {
if digest, connDigestErr := credentialManager.ConnDigest(connectionID); connDigestErr == nil {
connDigest = digest connDigest = digest
} }
}
return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID, credentialManager) return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
} }
// log errors and proceed to RegisterTunnel // log errors and proceed to RegisterTunnel
if tokenErr != nil { if tokenErr != nil {
@ -316,7 +328,7 @@ func ServeTunnel(
logger.Errorf("Couldn't get event digest: %s", eventDigestErr) logger.Errorf("Couldn't get event digest: %s", eventDigestErr)
} }
} }
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID) return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
@ -325,12 +337,15 @@ func ServeTunnel(
select { select {
case <-serveCtx.Done(): case <-serveCtx.Done():
// UnregisterTunnel blocks until the RPC call returns // UnregisterTunnel blocks until the RPC call returns
var err error
if connectedFuse.Value() { if connectedFuse.Value() {
err = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger) if config.NamedTunnel != nil {
_ = UnregisterConnection(ctx, handler.muxer, config)
} else {
_ = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger)
}
} }
handler.muxer.Shutdown() handler.muxer.Shutdown()
return err return nil
case <-updateMetricsTickC: case <-updateMetricsTickC:
handler.UpdateMetrics(connectionTag) handler.UpdateMetrics(connectionTag)
} }
@ -361,8 +376,6 @@ func ServeTunnel(
err = errGroup.Wait() err = errGroup.Wait()
if err != nil { if err != nil {
_ = newClientRegisterTunnelError(err, config.Metrics.regFail, unknown)
switch castedErr := err.(type) { switch castedErr := err.(type) {
case dupConnRegisterTunnelError: case dupConnRegisterTunnelError:
logger.Info("Already connected to this server, selecting a different one") logger.Info("Already connected to this server, selecting a different one")
@ -382,7 +395,7 @@ func ServeTunnel(
logger.Info("Muxer shutdown") logger.Info("Muxer shutdown")
return err, true return err, true
case *ReconnectSignal: case *ReconnectSignal:
logger.Infof("Restarting due to reconnect signal in %d seconds", castedErr.Delay) logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, castedErr.Delay)
castedErr.DelayBeforeReconnect() castedErr.DelayBeforeReconnect()
return err, true return err, true
default: default:
@ -393,6 +406,74 @@ func ServeTunnel(
return nil, true return nil, true
} }
func RegisterConnection(
ctx context.Context,
muxer *h2mux.Muxer,
config *TunnelConfig,
connectionIndex uint8,
originLocalAddr string,
) error {
const registerConnection = "registerConnection"
config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
if err != nil {
// RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, registerConnection)
}
defer rpc.Close()
conn, err := rpc.RegisterConnection(
ctx,
config.NamedTunnel.Auth,
config.NamedTunnel.ID,
connectionIndex,
config.ConnectionOptions(originLocalAddr),
)
if err != nil {
if err.Error() == DuplicateConnectionError {
config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc()
return errDuplicationConnection
}
config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc()
return serverRegistrationErrorFromRPC(err)
}
config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc()
config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID)
return nil
}
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
return &serverRegisterTunnelError{
cause: retryable.Unwrap(),
permanent: false,
}
}
return &serverRegisterTunnelError{
cause: err,
permanent: true,
}
}
func UnregisterConnection(
ctx context.Context,
muxer *h2mux.Muxer,
config *TunnelConfig,
) error {
config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
if err != nil {
// RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
}
defer rpc.Close()
return rpc.UnregisterConnection(ctx)
}
func RegisterTunnel( func RegisterTunnel(
ctx context.Context, ctx context.Context,
credentialManager ReconnectTunnelCredentialManager, credentialManager ReconnectTunnelCredentialManager,
@ -437,7 +518,7 @@ func ReconnectTunnel(
config *TunnelConfig, config *TunnelConfig,
logger logger.Service, logger logger.Service,
connectionID uint8, connectionID uint8,
originLocalIP string, originLocalAddr string,
uuid uuid.UUID, uuid uuid.UUID,
credentialManager ReconnectTunnelCredentialManager, credentialManager ReconnectTunnelCredentialManager,
) error { ) error {
@ -459,7 +540,7 @@ func ReconnectTunnel(
eventDigest, eventDigest,
connDigest, connDigest,
config.Hostname, config.Hostname,
config.RegistrationOptions(connectionID, originLocalIP, uuid), config.RegistrationOptions(connectionID, originLocalAddr, uuid),
) )
if registrationErr := registration.DeserializeError(); registrationErr != nil { if registrationErr := registration.DeserializeError(); registrationErr != nil {
// ReconnectTunnel RPC failure // ReconnectTunnel RPC failure
@ -508,11 +589,11 @@ func processRegistrationSuccess(
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error { func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error {
if err.Error() == DuplicateConnectionError { if err.Error() == DuplicateConnectionError {
metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc() metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
return dupConnRegisterTunnelError{} return errDuplicationConnection
} }
metrics.regFail.WithLabelValues("server_error", string(name)).Inc() metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
return serverRegisterTunnelError{ return serverRegisterTunnelError{
cause: fmt.Errorf("Server error: %s", err.Error()), cause: err,
permanent: err.IsPermanent(), permanent: err.IsPermanent(),
} }
} }

View File

@ -223,7 +223,7 @@ func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth Tu
return nil, newRPCError("unknown result which %d", result.Which()) return nil, newRPCError("unknown result which %d", result.Which())
} }
func (c TunnelServer_PogsClient) Unregister(ctx context.Context) error { func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error {
client := tunnelrpc.TunnelServer{Client: c.Client} client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error { promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
return nil return nil