From 2a3d486126449a7b9228c5a93ac1e0691bb0e00f Mon Sep 17 00:00:00 2001 From: Igor Postelnik Date: Thu, 25 Jun 2020 13:25:39 -0500 Subject: [PATCH] TUN-3007: Implement named tunnel connection registration and unregistration. Removed flag for using quick reconnect, this logic is now always enabled. --- cmd/cloudflared/access/carrier.go | 2 +- cmd/cloudflared/config/configuration.go | 7 +- cmd/cloudflared/tunnel/cmd.go | 26 +-- cmd/cloudflared/tunnel/configuration.go | 81 +++++---- cmd/cloudflared/tunnel/subcommands.go | 18 +- origin/supervisor.go | 18 +- origin/supervisor_test.go | 8 +- origin/tunnel.go | 227 ++++++++++++++++-------- tunnelrpc/pogs/connectionrpc.go | 2 +- 9 files changed, 248 insertions(+), 141 deletions(-) diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index 29e9b385..d503b064 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -106,7 +106,7 @@ func ssh(c *cli.Context) error { wsConn := carrier.NewWSConnection(logger, false) if c.NArg() > 0 || c.IsSet(sshURLFlag) { - localForwarder, err := config.ValidateUrl(c) + localForwarder, err := config.ValidateUrl(c, true) if err != nil { logger.Errorf("Error validating origin URL: %s", err) return errors.Wrap(err, "error validating origin URL") diff --git a/cmd/cloudflared/config/configuration.go b/cmd/cloudflared/config/configuration.go index ed72f3c0..76c39fbf 100644 --- a/cmd/cloudflared/config/configuration.go +++ b/cmd/cloudflared/config/configuration.go @@ -6,11 +6,12 @@ import ( "path/filepath" "runtime" - "github.com/cloudflare/cloudflared/validation" homedir "github.com/mitchellh/go-homedir" "gopkg.in/urfave/cli.v2" "gopkg.in/urfave/cli.v2/altsrc" "gopkg.in/yaml.v2" + + "github.com/cloudflare/cloudflared/validation" ) var ( @@ -176,9 +177,9 @@ func ValidateUnixSocket(c *cli.Context) (string, error) { // ValidateUrl will validate url flag correctness. It can be either from --url or argument // Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument -func ValidateUrl(c *cli.Context) (string, error) { +func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) { var url = c.String("url") - if c.NArg() > 0 { + if allowFromArgs && c.NArg() > 0 { if c.IsSet("url") { return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") } diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index f3b8476d..347e243d 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -359,7 +359,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan defer wg.Done() hello.StartHelloWorldServer(logger, helloListener, shutdownC) }() - c.Set("url", "https://"+helloListener.Addr().String()) + forceSetFlag(c, "url", "https://"+helloListener.Addr().String()) } if c.IsSet(sshServerFlag) { @@ -409,7 +409,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan close(shutdownC) } }() - c.Set("url", "ssh://"+localServerAddress) + forceSetFlag(c, "url", "ssh://"+localServerAddress) } url := c.String("url") @@ -453,7 +453,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan } errC <- websocket.StartProxyServer(logger, listener, staticHost, shutdownC, streamHandler) }() - c.Set("url", "http://"+listener.Addr().String()) + forceSetFlag(c, "url", "http://"+listener.Addr().String()) } transportLogger, err := createLogger(c, true) @@ -461,7 +461,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan return errors.Wrap(err, "error setting up transport logger") } - tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger) + tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger, namedTunnel) if err != nil { return err } @@ -475,12 +475,21 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan wg.Add(1) go func() { defer wg.Done() - errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh, namedTunnel) + errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh) }() return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger) } +// forceSetFlag attempts to set the given flag value in the closest context that has it defined +func forceSetFlag(c *cli.Context, name, value string) { + for _, ctx := range c.Lineage() { + if err := ctx.Set(name, value); err == nil { + break + } + } +} + func Before(c *cli.Context) error { logger, err := createLogger(c, false) if err != nil { @@ -969,13 +978,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"}, Hidden: true, }), - altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: "use-quick-reconnects", - Usage: "Test reestablishing connections with the new 'connection digest' flow.", - Value: true, - EnvVars: []string{"TUNNEL_USE_QUICK_RECONNECTS"}, - Hidden: true, - }), altsrc.NewDurationFlag(&cli.DurationFlag{ Name: "dial-edge-timeout", Usage: "Maximum wait time to set up a connection with the edge", diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 274eafa0..86d06600 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -158,7 +158,10 @@ func prepareTunnelConfig( version string, logger logger.Service, transportLogger logger.Service, + namedTunnel *origin.NamedTunnelConfig, ) (*origin.TunnelConfig, error) { + compatibilityMode := namedTunnel == nil + hostname, err := validation.ValidateHostname(c.String("hostname")) if err != nil { logger.Errorf("Invalid hostname: %s", err) @@ -181,7 +184,7 @@ func prepareTunnelConfig( tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) - originURL, err := config.ValidateUrl(c) + originURL, err := config.ValidateUrl(c, compatibilityMode) if err != nil { logger.Errorf("Error validating origin URL: %s", err) return nil, errors.Wrap(err, "Error validating origin URL") @@ -254,38 +257,52 @@ func prepareTunnelConfig( return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") } + if namedTunnel != nil { + clientUUID, err := uuid.NewRandom() + if err != nil { + return nil, errors.Wrap(err, "can't generate clientUUID") + } + namedTunnel.Client = tunnelpogs.ClientInfo{ + ClientID: clientUUID[:], + Features: []string{origin.FeatureSerializedHeaders}, + Version: version, + Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch), + } + } + return &origin.TunnelConfig{ - BuildInfo: buildInfo, - ClientID: clientID, - ClientTlsConfig: httpTransport.TLSClientConfig, - CompressionQuality: c.Uint64("compression-quality"), - EdgeAddrs: c.StringSlice("edge"), - GracePeriod: c.Duration("grace-period"), - HAConnections: c.Int("ha-connections"), - HTTPTransport: httpTransport, - HeartbeatInterval: c.Duration("heartbeat-interval"), - Hostname: hostname, - HTTPHostHeader: c.String("http-host-header"), - IncidentLookup: origin.NewIncidentLookup(), - IsAutoupdated: c.Bool("is-autoupdated"), - IsFreeTunnel: isFreeTunnel, - LBPool: c.String("lb-pool"), - Logger: logger, - TransportLogger: transportLogger, - MaxHeartbeats: c.Uint64("heartbeat-count"), - Metrics: tunnelMetrics, - MetricsUpdateFreq: c.Duration("metrics-update-freq"), - NoChunkedEncoding: c.Bool("no-chunked-encoding"), - OriginCert: originCert, - OriginUrl: originURL, - ReportedVersion: version, - Retries: c.Uint("retries"), - RunFromTerminal: isRunningFromTerminal(), - Tags: tags, - TlsConfig: toEdgeTLSConfig, - UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"), - UseReconnectToken: c.Bool("use-reconnect-token"), - UseQuickReconnects: c.Bool("use-quick-reconnects"), + BuildInfo: buildInfo, + ClientID: clientID, + ClientTlsConfig: httpTransport.TLSClientConfig, + CompressionQuality: c.Uint64("compression-quality"), + EdgeAddrs: c.StringSlice("edge"), + GracePeriod: c.Duration("grace-period"), + HAConnections: c.Int("ha-connections"), + HTTPTransport: httpTransport, + HeartbeatInterval: c.Duration("heartbeat-interval"), + Hostname: hostname, + HTTPHostHeader: c.String("http-host-header"), + IncidentLookup: origin.NewIncidentLookup(), + IsAutoupdated: c.Bool("is-autoupdated"), + IsFreeTunnel: isFreeTunnel, + LBPool: c.String("lb-pool"), + Logger: logger, + TransportLogger: transportLogger, + MaxHeartbeats: c.Uint64("heartbeat-count"), + Metrics: tunnelMetrics, + MetricsUpdateFreq: c.Duration("metrics-update-freq"), + NoChunkedEncoding: c.Bool("no-chunked-encoding"), + OriginCert: originCert, + OriginUrl: originURL, + ReportedVersion: version, + Retries: c.Uint("retries"), + RunFromTerminal: isRunningFromTerminal(), + Tags: tags, + TlsConfig: toEdgeTLSConfig, + NamedTunnel: namedTunnel, + ReplaceExisting: c.Bool("force"), + // turn off use of reconnect token and auth refresh when using named tunnels + UseReconnectToken: compatibilityMode && c.Bool("use-reconnect-token"), }, nil } diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index b49882bd..837f5039 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/google/uuid" "github.com/pkg/errors" "gopkg.in/urfave/cli.v2" "gopkg.in/yaml.v2" @@ -34,7 +35,7 @@ var ( Aliases: []string{"o"}, Usage: "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'", } - forceFlag = &cli.StringFlag{ + forceFlag = &cli.BoolFlag{ Name: "force", Aliases: []string{"f"}, Usage: "By default, if a tunnel is currently being run from a cloudflared, you can't " + @@ -148,9 +149,12 @@ func readTunnelCredentials(tunnelID, originCertPath string) (*pogs.TunnelAuth, e if err != nil { return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath) } - auth := pogs.TunnelAuth{} - err = json.Unmarshal(body, &auth) - return &auth, errors.Wrap(err, "couldn't parse tunnel credentials from JSON") + + var auth pogs.TunnelAuth + if err = json.Unmarshal(body, &auth); err != nil { + return nil, err + } + return &auth, nil } func buildListCommand() *cli.Command { @@ -325,6 +329,10 @@ func runTunnel(c *cli.Context) error { return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`) } id := c.Args().First() + tunnelID, err := uuid.Parse(id) + if err != nil { + return errors.Wrap(err, "error parsing tunnel ID") + } logger, err := logger.New() if err != nil { @@ -340,5 +348,5 @@ func runTunnel(c *cli.Context) error { return err } logger.Debugf("Read credentials for %v", credentials.AccountTag) - return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: id}) + return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID}) } diff --git a/origin/supervisor.go b/origin/supervisor.go index 027d3173..de980681 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -68,8 +68,6 @@ type Supervisor struct { connDigest map[uint8][]byte bufferPool *buffer.Pool - - namedTunnel *NamedTunnelConfig } type resolveResult struct { @@ -82,7 +80,7 @@ type tunnelError struct { err error } -func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel *NamedTunnelConfig) (*Supervisor, error) { +func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) { var ( edgeIPs *edgediscovery.Edge err error @@ -95,6 +93,7 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel if err != nil { return nil, err } + return &Supervisor{ cloudflaredUUID: cloudflaredUUID, config: config, @@ -104,7 +103,6 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel logger: config.Logger, connDigest: make(map[uint8][]byte), bufferPool: buffer.NewPool(512 * 1024), - namedTunnel: namedTunnel, }, nil } @@ -229,17 +227,17 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign addr *net.TCPAddr err error ) - const thisConnID = 0 + const firstConnIndex = 0 defer func() { - s.tunnelErrors <- tunnelError{index: thisConnID, addr: addr, err: err} + s.tunnelErrors <- tunnelError{index: firstConnIndex, addr: addr, err: err} }() - addr, err = s.edgeIPs.GetAddr(thisConnID) + addr, err = s.edgeIPs.GetAddr(firstConnIndex) if err != nil { return } - err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) + err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) // If the first tunnel disconnects, keep restarting it. edgeErrors := 0 for s.unusedIPs() { @@ -257,12 +255,12 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign return } if edgeErrors >= 2 { - addr, err = s.edgeIPs.GetDifferentAddr(thisConnID) + addr, err = s.edgeIPs.GetDifferentAddr(firstConnIndex) if err != nil { return } } - err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) + err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) } } diff --git a/origin/supervisor_test.go b/origin/supervisor_test.go index 7b1ff701..21eeec60 100644 --- a/origin/supervisor_test.go +++ b/origin/supervisor_test.go @@ -48,7 +48,7 @@ func TestRefreshAuthBackoff(t *testing.T) { return time.After(d) } - s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) + s, err := NewSupervisor(testConfig(logger), uuid.New()) if !assert.NoError(t, err) { t.FailNow() } @@ -92,7 +92,7 @@ func TestRefreshAuthSuccess(t *testing.T) { return time.After(d) } - s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) + s, err := NewSupervisor(testConfig(logger), uuid.New()) if !assert.NoError(t, err) { t.FailNow() } @@ -120,7 +120,7 @@ func TestRefreshAuthUnknown(t *testing.T) { return time.After(d) } - s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) + s, err := NewSupervisor(testConfig(logger), uuid.New()) if !assert.NoError(t, err) { t.FailNow() } @@ -142,7 +142,7 @@ func TestRefreshAuthUnknown(t *testing.T) { func TestRefreshAuthFail(t *testing.T) { logger := logger.NewOutputWriter(logger.NewMockWriteManager()) - s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) + s, err := NewSupervisor(testConfig(logger), uuid.New()) if !assert.NoError(t, err) { t.FailNow() } diff --git a/origin/tunnel.go b/origin/tunnel.go index 20361e5d..ce463a4f 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -48,47 +48,46 @@ type registerRPCName string const ( register registerRPCName = "register" reconnect registerRPCName = "reconnect" - unknown registerRPCName = "unknown" ) type TunnelConfig struct { - BuildInfo *buildinfo.BuildInfo - ClientID string - ClientTlsConfig *tls.Config - CloseConnOnce *sync.Once // Used to close connectedSignal no more than once - CompressionQuality uint64 - EdgeAddrs []string - GracePeriod time.Duration - HAConnections int - HTTPTransport http.RoundTripper - HeartbeatInterval time.Duration - Hostname string - HTTPHostHeader string - IncidentLookup IncidentLookup - IsAutoupdated bool - IsFreeTunnel bool - LBPool string - Logger logger.Service - TransportLogger logger.Service - MaxHeartbeats uint64 - Metrics *TunnelMetrics - MetricsUpdateFreq time.Duration - NoChunkedEncoding bool - OriginCert []byte - ReportedVersion string - Retries uint - RunFromTerminal bool - Tags []tunnelpogs.Tag - TlsConfig *tls.Config - UseDeclarativeTunnel bool - WSGI bool + BuildInfo *buildinfo.BuildInfo + ClientID string + ClientTlsConfig *tls.Config + CloseConnOnce *sync.Once // Used to close connectedSignal no more than once + CompressionQuality uint64 + EdgeAddrs []string + GracePeriod time.Duration + HAConnections int + HTTPTransport http.RoundTripper + HeartbeatInterval time.Duration + Hostname string + HTTPHostHeader string + IncidentLookup IncidentLookup + IsAutoupdated bool + IsFreeTunnel bool + LBPool string + Logger logger.Service + TransportLogger logger.Service + MaxHeartbeats uint64 + Metrics *TunnelMetrics + MetricsUpdateFreq time.Duration + NoChunkedEncoding bool + OriginCert []byte + ReportedVersion string + Retries uint + RunFromTerminal bool + Tags []tunnelpogs.Tag + TlsConfig *tls.Config + WSGI bool // OriginUrl may not be used if a user specifies a unix socket. OriginUrl string // feature-flag to use new edge reconnect tokens UseReconnectToken bool - // feature-flag for using ConnectionDigest - UseQuickReconnects bool + + NamedTunnel *NamedTunnelConfig + ReplaceExisting bool } // ReconnectTunnelCredentialManager is invoked by functions in this file to @@ -103,6 +102,8 @@ type ReconnectTunnelCredentialManager interface { type dupConnRegisterTunnelError struct{} +var errDuplicationConnection = &dupConnRegisterTunnelError{} + func (e dupConnRegisterTunnelError) Error() string { return "already connected to this server" } @@ -171,21 +172,35 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str } } -func (c *TunnelConfig) SupportedFeatures() []string { - basic := []string{FeatureSerializedHeaders} - if c.UseQuickReconnects { - basic = append(basic, FeatureQuickReconnects) +func (c *TunnelConfig) ConnectionOptions(originLocalAddr string) *tunnelpogs.ConnectionOptions { + // attempt to parse out origin IP, but don't fail since it's informational field + host, _, _ := net.SplitHostPort(originLocalAddr) + originIP := net.ParseIP(host) + + return &tunnelpogs.ConnectionOptions{ + Client: c.NamedTunnel.Client, + OriginLocalIP: originIP, + ReplaceExisting: c.ReplaceExisting, + CompressionQuality: uint8(c.CompressionQuality), } - return basic +} + +func (c *TunnelConfig) SupportedFeatures() []string { + features := []string{FeatureSerializedHeaders} + if c.NamedTunnel == nil { + features = append(features, FeatureQuickReconnects) + } + return features } type NamedTunnelConfig struct { - Auth pogs.TunnelAuth - ID string + Auth pogs.TunnelAuth + ID uuid.UUID + Client pogs.ClientInfo } -func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal, namedTunnel *NamedTunnelConfig) error { - s, err := NewSupervisor(config, cloudflaredID, namedTunnel) +func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error { + s, err := NewSupervisor(config, cloudflaredID) if err != nil { return err } @@ -196,7 +211,7 @@ func ServeTunnelLoop(ctx context.Context, credentialManager ReconnectTunnelCredentialManager, config *TunnelConfig, addr *net.TCPAddr, - connectionID uint8, + connectionIndex uint8, connectedSignal *signal.Signal, cloudflaredUUID uuid.UUID, bufferPool *buffer.Pool, @@ -219,7 +234,7 @@ func ServeTunnelLoop(ctx context.Context, credentialManager, config, config.Logger, - addr, connectionID, + addr, connectionIndex, connectedFuse, &backoff, cloudflaredUUID, @@ -228,7 +243,7 @@ func ServeTunnelLoop(ctx context.Context, ) if recoverable { if duration, ok := backoff.GetBackoffDuration(ctx); ok { - config.Logger.Infof("Retrying in %s seconds: connectionID: %d", duration, connectionID) + config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration) backoff.Backoff(ctx) continue } @@ -243,7 +258,7 @@ func ServeTunnel( config *TunnelConfig, logger logger.Service, addr *net.TCPAddr, - connectionID uint8, + connectionIndex uint8, connectedFuse *h2mux.BooleanFuse, backoff *BackoffHandler, cloudflaredUUID uuid.UUID, @@ -262,22 +277,18 @@ func ServeTunnel( } }() - connectionTag := uint8ToString(connectionID) - - // additional tags to send other than hostname which is set in cloudflared main package - tags := make(map[string]string) - tags["ha"] = connectionTag + connectionTag := uint8ToString(connectionIndex) // Returns error from parsing the origin URL or handshake errors - handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID, bufferPool) + handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool) if err != nil { switch err.(type) { case connection.DialError: - logger.Errorf("Unable to dial edge: %s connectionID: %d", err, connectionID) + logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err) case h2mux.MuxerHandshakeError: - logger.Errorf("Handshake failed with edge server: %s connectionID: %d", err, connectionID) + logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err) default: - logger.Errorf("Tunnel creation failure: %s connectionID: %d", err, connectionID) + logger.Errorf("Connection %d failed: %s", connectionIndex, err) return err, false } return err, true @@ -293,20 +304,21 @@ func ServeTunnel( } }() + if config.NamedTunnel != nil { + return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr) + } + if config.UseReconnectToken && connectedFuse.Value() { token, tokenErr := credentialManager.ReconnectToken() eventDigest, eventDigestErr := credentialManager.EventDigest() // if we have both credentials, we can reconnect if tokenErr == nil && eventDigestErr == nil { var connDigest []byte - - // check if we can use Quick Reconnects - if config.UseQuickReconnects { - if digest, connDigestErr := credentialManager.ConnDigest(connectionID); connDigestErr == nil { - connDigest = digest - } + if digest, connDigestErr := credentialManager.ConnDigest(connectionIndex); connDigestErr == nil { + connDigest = digest } - return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID, credentialManager) + + return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager) } // log errors and proceed to RegisterTunnel if tokenErr != nil { @@ -316,7 +328,7 @@ func ServeTunnel( logger.Errorf("Couldn't get event digest: %s", eventDigestErr) } } - return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID) + return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID) }) errGroup.Go(func() error { @@ -325,12 +337,15 @@ func ServeTunnel( select { case <-serveCtx.Done(): // UnregisterTunnel blocks until the RPC call returns - var err error if connectedFuse.Value() { - err = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger) + if config.NamedTunnel != nil { + _ = UnregisterConnection(ctx, handler.muxer, config) + } else { + _ = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger) + } } handler.muxer.Shutdown() - return err + return nil case <-updateMetricsTickC: handler.UpdateMetrics(connectionTag) } @@ -361,8 +376,6 @@ func ServeTunnel( err = errGroup.Wait() if err != nil { - _ = newClientRegisterTunnelError(err, config.Metrics.regFail, unknown) - switch castedErr := err.(type) { case dupConnRegisterTunnelError: logger.Info("Already connected to this server, selecting a different one") @@ -382,7 +395,7 @@ func ServeTunnel( logger.Info("Muxer shutdown") return err, true case *ReconnectSignal: - logger.Infof("Restarting due to reconnect signal in %d seconds", castedErr.Delay) + logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, castedErr.Delay) castedErr.DelayBeforeReconnect() return err, true default: @@ -393,6 +406,74 @@ func ServeTunnel( return nil, true } +func RegisterConnection( + ctx context.Context, + muxer *h2mux.Muxer, + config *TunnelConfig, + connectionIndex uint8, + originLocalAddr string, +) error { + const registerConnection = "registerConnection" + + config.TransportLogger.Debug("initiating RPC stream for RegisterConnection") + rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout) + if err != nil { + // RPC stream open error + return newClientRegisterTunnelError(err, config.Metrics.rpcFail, registerConnection) + } + defer rpc.Close() + + conn, err := rpc.RegisterConnection( + ctx, + config.NamedTunnel.Auth, + config.NamedTunnel.ID, + connectionIndex, + config.ConnectionOptions(originLocalAddr), + ) + if err != nil { + if err.Error() == DuplicateConnectionError { + config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc() + return errDuplicationConnection + } + config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc() + return serverRegistrationErrorFromRPC(err) + } + + config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc() + config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID) + + return nil +} + +func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError { + if retryable, ok := err.(*tunnelpogs.RetryableError); ok { + return &serverRegisterTunnelError{ + cause: retryable.Unwrap(), + permanent: false, + } + } + return &serverRegisterTunnelError{ + cause: err, + permanent: true, + } +} + +func UnregisterConnection( + ctx context.Context, + muxer *h2mux.Muxer, + config *TunnelConfig, +) error { + config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection") + rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout) + if err != nil { + // RPC stream open error + return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register) + } + defer rpc.Close() + + return rpc.UnregisterConnection(ctx) +} + func RegisterTunnel( ctx context.Context, credentialManager ReconnectTunnelCredentialManager, @@ -437,7 +518,7 @@ func ReconnectTunnel( config *TunnelConfig, logger logger.Service, connectionID uint8, - originLocalIP string, + originLocalAddr string, uuid uuid.UUID, credentialManager ReconnectTunnelCredentialManager, ) error { @@ -459,7 +540,7 @@ func ReconnectTunnel( eventDigest, connDigest, config.Hostname, - config.RegistrationOptions(connectionID, originLocalIP, uuid), + config.RegistrationOptions(connectionID, originLocalAddr, uuid), ) if registrationErr := registration.DeserializeError(); registrationErr != nil { // ReconnectTunnel RPC failure @@ -508,11 +589,11 @@ func processRegistrationSuccess( func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error { if err.Error() == DuplicateConnectionError { metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc() - return dupConnRegisterTunnelError{} + return errDuplicationConnection } metrics.regFail.WithLabelValues("server_error", string(name)).Inc() return serverRegisterTunnelError{ - cause: fmt.Errorf("Server error: %s", err.Error()), + cause: err, permanent: err.IsPermanent(), } } diff --git a/tunnelrpc/pogs/connectionrpc.go b/tunnelrpc/pogs/connectionrpc.go index 9327eea3..79102b0a 100644 --- a/tunnelrpc/pogs/connectionrpc.go +++ b/tunnelrpc/pogs/connectionrpc.go @@ -223,7 +223,7 @@ func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth Tu return nil, newRPCError("unknown result which %d", result.Which()) } -func (c TunnelServer_PogsClient) Unregister(ctx context.Context) error { +func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error { client := tunnelrpc.TunnelServer{Client: c.Client} promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error { return nil