TUN-3007: Implement named tunnel connection registration and unregistration.
Removed flag for using quick reconnect, this logic is now always enabled.
This commit is contained in:
		
							parent
							
								
									932e383051
								
							
						
					
					
						commit
						2a3d486126
					
				| 
						 | 
					@ -106,7 +106,7 @@ func ssh(c *cli.Context) error {
 | 
				
			||||||
	wsConn := carrier.NewWSConnection(logger, false)
 | 
						wsConn := carrier.NewWSConnection(logger, false)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if c.NArg() > 0 || c.IsSet(sshURLFlag) {
 | 
						if c.NArg() > 0 || c.IsSet(sshURLFlag) {
 | 
				
			||||||
		localForwarder, err := config.ValidateUrl(c)
 | 
							localForwarder, err := config.ValidateUrl(c, true)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			logger.Errorf("Error validating origin URL: %s", err)
 | 
								logger.Errorf("Error validating origin URL: %s", err)
 | 
				
			||||||
			return errors.Wrap(err, "error validating origin URL")
 | 
								return errors.Wrap(err, "error validating origin URL")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,11 +6,12 @@ import (
 | 
				
			||||||
	"path/filepath"
 | 
						"path/filepath"
 | 
				
			||||||
	"runtime"
 | 
						"runtime"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/cloudflare/cloudflared/validation"
 | 
					 | 
				
			||||||
	homedir "github.com/mitchellh/go-homedir"
 | 
						homedir "github.com/mitchellh/go-homedir"
 | 
				
			||||||
	"gopkg.in/urfave/cli.v2"
 | 
						"gopkg.in/urfave/cli.v2"
 | 
				
			||||||
	"gopkg.in/urfave/cli.v2/altsrc"
 | 
						"gopkg.in/urfave/cli.v2/altsrc"
 | 
				
			||||||
	"gopkg.in/yaml.v2"
 | 
						"gopkg.in/yaml.v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/validation"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
| 
						 | 
					@ -176,9 +177,9 @@ func ValidateUnixSocket(c *cli.Context) (string, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ValidateUrl will validate url flag correctness. It can be either from --url or argument
 | 
					// ValidateUrl will validate url flag correctness. It can be either from --url or argument
 | 
				
			||||||
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
 | 
					// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
 | 
				
			||||||
func ValidateUrl(c *cli.Context) (string, error) {
 | 
					func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) {
 | 
				
			||||||
	var url = c.String("url")
 | 
						var url = c.String("url")
 | 
				
			||||||
	if c.NArg() > 0 {
 | 
						if allowFromArgs && c.NArg() > 0 {
 | 
				
			||||||
		if c.IsSet("url") {
 | 
							if c.IsSet("url") {
 | 
				
			||||||
			return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
 | 
								return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -359,7 +359,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
 | 
				
			||||||
			defer wg.Done()
 | 
								defer wg.Done()
 | 
				
			||||||
			hello.StartHelloWorldServer(logger, helloListener, shutdownC)
 | 
								hello.StartHelloWorldServer(logger, helloListener, shutdownC)
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
		c.Set("url", "https://"+helloListener.Addr().String())
 | 
							forceSetFlag(c, "url", "https://"+helloListener.Addr().String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if c.IsSet(sshServerFlag) {
 | 
						if c.IsSet(sshServerFlag) {
 | 
				
			||||||
| 
						 | 
					@ -409,7 +409,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
 | 
				
			||||||
				close(shutdownC)
 | 
									close(shutdownC)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
		c.Set("url", "ssh://"+localServerAddress)
 | 
							forceSetFlag(c, "url", "ssh://"+localServerAddress)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	url := c.String("url")
 | 
						url := c.String("url")
 | 
				
			||||||
| 
						 | 
					@ -453,7 +453,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			errC <- websocket.StartProxyServer(logger, listener, staticHost, shutdownC, streamHandler)
 | 
								errC <- websocket.StartProxyServer(logger, listener, staticHost, shutdownC, streamHandler)
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
		c.Set("url", "http://"+listener.Addr().String())
 | 
							forceSetFlag(c, "url", "http://"+listener.Addr().String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	transportLogger, err := createLogger(c, true)
 | 
						transportLogger, err := createLogger(c, true)
 | 
				
			||||||
| 
						 | 
					@ -461,7 +461,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
 | 
				
			||||||
		return errors.Wrap(err, "error setting up transport logger")
 | 
							return errors.Wrap(err, "error setting up transport logger")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger)
 | 
						tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger, namedTunnel)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -475,12 +475,21 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
 | 
				
			||||||
	wg.Add(1)
 | 
						wg.Add(1)
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		defer wg.Done()
 | 
							defer wg.Done()
 | 
				
			||||||
		errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh, namedTunnel)
 | 
							errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger)
 | 
						return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// forceSetFlag attempts to set the given flag value in the closest context that has it defined
 | 
				
			||||||
 | 
					func forceSetFlag(c *cli.Context, name, value string) {
 | 
				
			||||||
 | 
						for _, ctx := range c.Lineage() {
 | 
				
			||||||
 | 
							if err := ctx.Set(name, value); err == nil {
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Before(c *cli.Context) error {
 | 
					func Before(c *cli.Context) error {
 | 
				
			||||||
	logger, err := createLogger(c, false)
 | 
						logger, err := createLogger(c, false)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -969,13 +978,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
 | 
				
			||||||
			EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
 | 
								EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
 | 
				
			||||||
			Hidden:  true,
 | 
								Hidden:  true,
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
					 | 
				
			||||||
			Name:    "use-quick-reconnects",
 | 
					 | 
				
			||||||
			Usage:   "Test reestablishing connections with the new 'connection digest' flow.",
 | 
					 | 
				
			||||||
			Value:   true,
 | 
					 | 
				
			||||||
			EnvVars: []string{"TUNNEL_USE_QUICK_RECONNECTS"},
 | 
					 | 
				
			||||||
			Hidden:  true,
 | 
					 | 
				
			||||||
		}),
 | 
					 | 
				
			||||||
		altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
							altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
				
			||||||
			Name:    "dial-edge-timeout",
 | 
								Name:    "dial-edge-timeout",
 | 
				
			||||||
			Usage:   "Maximum wait time to set up a connection with the edge",
 | 
								Usage:   "Maximum wait time to set up a connection with the edge",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -158,7 +158,10 @@ func prepareTunnelConfig(
 | 
				
			||||||
	version string,
 | 
						version string,
 | 
				
			||||||
	logger logger.Service,
 | 
						logger logger.Service,
 | 
				
			||||||
	transportLogger logger.Service,
 | 
						transportLogger logger.Service,
 | 
				
			||||||
 | 
						namedTunnel *origin.NamedTunnelConfig,
 | 
				
			||||||
) (*origin.TunnelConfig, error) {
 | 
					) (*origin.TunnelConfig, error) {
 | 
				
			||||||
 | 
						compatibilityMode := namedTunnel == nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hostname, err := validation.ValidateHostname(c.String("hostname"))
 | 
						hostname, err := validation.ValidateHostname(c.String("hostname"))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Errorf("Invalid hostname: %s", err)
 | 
							logger.Errorf("Invalid hostname: %s", err)
 | 
				
			||||||
| 
						 | 
					@ -181,7 +184,7 @@ func prepareTunnelConfig(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
 | 
						tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	originURL, err := config.ValidateUrl(c)
 | 
						originURL, err := config.ValidateUrl(c, compatibilityMode)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Errorf("Error validating origin URL: %s", err)
 | 
							logger.Errorf("Error validating origin URL: %s", err)
 | 
				
			||||||
		return nil, errors.Wrap(err, "Error validating origin URL")
 | 
							return nil, errors.Wrap(err, "Error validating origin URL")
 | 
				
			||||||
| 
						 | 
					@ -254,6 +257,19 @@ func prepareTunnelConfig(
 | 
				
			||||||
		return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
 | 
							return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if namedTunnel != nil {
 | 
				
			||||||
 | 
							clientUUID, err := uuid.NewRandom()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, errors.Wrap(err, "can't generate clientUUID")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							namedTunnel.Client = tunnelpogs.ClientInfo{
 | 
				
			||||||
 | 
								ClientID: clientUUID[:],
 | 
				
			||||||
 | 
								Features: []string{origin.FeatureSerializedHeaders},
 | 
				
			||||||
 | 
								Version:  version,
 | 
				
			||||||
 | 
								Arch:     fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch),
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &origin.TunnelConfig{
 | 
						return &origin.TunnelConfig{
 | 
				
			||||||
		BuildInfo:          buildInfo,
 | 
							BuildInfo:          buildInfo,
 | 
				
			||||||
		ClientID:           clientID,
 | 
							ClientID:           clientID,
 | 
				
			||||||
| 
						 | 
					@ -283,9 +299,10 @@ func prepareTunnelConfig(
 | 
				
			||||||
		RunFromTerminal:    isRunningFromTerminal(),
 | 
							RunFromTerminal:    isRunningFromTerminal(),
 | 
				
			||||||
		Tags:               tags,
 | 
							Tags:               tags,
 | 
				
			||||||
		TlsConfig:          toEdgeTLSConfig,
 | 
							TlsConfig:          toEdgeTLSConfig,
 | 
				
			||||||
		UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
 | 
							NamedTunnel:        namedTunnel,
 | 
				
			||||||
		UseReconnectToken:    c.Bool("use-reconnect-token"),
 | 
							ReplaceExisting:    c.Bool("force"),
 | 
				
			||||||
		UseQuickReconnects:   c.Bool("use-quick-reconnects"),
 | 
							// turn off use of reconnect token and auth refresh when using named tunnels
 | 
				
			||||||
 | 
							UseReconnectToken:  compatibilityMode && c.Bool("use-reconnect-token"),
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -11,6 +11,7 @@ import (
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/google/uuid"
 | 
				
			||||||
	"github.com/pkg/errors"
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
	"gopkg.in/urfave/cli.v2"
 | 
						"gopkg.in/urfave/cli.v2"
 | 
				
			||||||
	"gopkg.in/yaml.v2"
 | 
						"gopkg.in/yaml.v2"
 | 
				
			||||||
| 
						 | 
					@ -34,7 +35,7 @@ var (
 | 
				
			||||||
		Aliases: []string{"o"},
 | 
							Aliases: []string{"o"},
 | 
				
			||||||
		Usage:   "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'",
 | 
							Usage:   "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	forceFlag = &cli.StringFlag{
 | 
						forceFlag = &cli.BoolFlag{
 | 
				
			||||||
		Name:    "force",
 | 
							Name:    "force",
 | 
				
			||||||
		Aliases: []string{"f"},
 | 
							Aliases: []string{"f"},
 | 
				
			||||||
		Usage: "By default, if a tunnel is currently being run from a cloudflared, you can't " +
 | 
							Usage: "By default, if a tunnel is currently being run from a cloudflared, you can't " +
 | 
				
			||||||
| 
						 | 
					@ -148,9 +149,12 @@ func readTunnelCredentials(tunnelID, originCertPath string) (*pogs.TunnelAuth, e
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath)
 | 
							return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	auth := pogs.TunnelAuth{}
 | 
					
 | 
				
			||||||
	err = json.Unmarshal(body, &auth)
 | 
						var auth pogs.TunnelAuth
 | 
				
			||||||
	return &auth, errors.Wrap(err, "couldn't parse tunnel credentials from JSON")
 | 
						if err = json.Unmarshal(body, &auth); err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &auth, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func buildListCommand() *cli.Command {
 | 
					func buildListCommand() *cli.Command {
 | 
				
			||||||
| 
						 | 
					@ -325,6 +329,10 @@ func runTunnel(c *cli.Context) error {
 | 
				
			||||||
		return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
 | 
							return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	id := c.Args().First()
 | 
						id := c.Args().First()
 | 
				
			||||||
 | 
						tunnelID, err := uuid.Parse(id)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errors.Wrap(err, "error parsing tunnel ID")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	logger, err := logger.New()
 | 
						logger, err := logger.New()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -340,5 +348,5 @@ func runTunnel(c *cli.Context) error {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	logger.Debugf("Read credentials for %v", credentials.AccountTag)
 | 
						logger.Debugf("Read credentials for %v", credentials.AccountTag)
 | 
				
			||||||
	return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: id})
 | 
						return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -68,8 +68,6 @@ type Supervisor struct {
 | 
				
			||||||
	connDigest     map[uint8][]byte
 | 
						connDigest     map[uint8][]byte
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	bufferPool *buffer.Pool
 | 
						bufferPool *buffer.Pool
 | 
				
			||||||
 | 
					 | 
				
			||||||
	namedTunnel *NamedTunnelConfig
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type resolveResult struct {
 | 
					type resolveResult struct {
 | 
				
			||||||
| 
						 | 
					@ -82,7 +80,7 @@ type tunnelError struct {
 | 
				
			||||||
	err   error
 | 
						err   error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel *NamedTunnelConfig) (*Supervisor, error) {
 | 
					func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
 | 
				
			||||||
	var (
 | 
						var (
 | 
				
			||||||
		edgeIPs *edgediscovery.Edge
 | 
							edgeIPs *edgediscovery.Edge
 | 
				
			||||||
		err     error
 | 
							err     error
 | 
				
			||||||
| 
						 | 
					@ -95,6 +93,7 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &Supervisor{
 | 
						return &Supervisor{
 | 
				
			||||||
		cloudflaredUUID:   cloudflaredUUID,
 | 
							cloudflaredUUID:   cloudflaredUUID,
 | 
				
			||||||
		config:            config,
 | 
							config:            config,
 | 
				
			||||||
| 
						 | 
					@ -104,7 +103,6 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel
 | 
				
			||||||
		logger:            config.Logger,
 | 
							logger:            config.Logger,
 | 
				
			||||||
		connDigest:        make(map[uint8][]byte),
 | 
							connDigest:        make(map[uint8][]byte),
 | 
				
			||||||
		bufferPool:        buffer.NewPool(512 * 1024),
 | 
							bufferPool:        buffer.NewPool(512 * 1024),
 | 
				
			||||||
		namedTunnel:       namedTunnel,
 | 
					 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -229,17 +227,17 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
 | 
				
			||||||
		addr *net.TCPAddr
 | 
							addr *net.TCPAddr
 | 
				
			||||||
		err  error
 | 
							err  error
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	const thisConnID = 0
 | 
						const firstConnIndex = 0
 | 
				
			||||||
	defer func() {
 | 
						defer func() {
 | 
				
			||||||
		s.tunnelErrors <- tunnelError{index: thisConnID, addr: addr, err: err}
 | 
							s.tunnelErrors <- tunnelError{index: firstConnIndex, addr: addr, err: err}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	addr, err = s.edgeIPs.GetAddr(thisConnID)
 | 
						addr, err = s.edgeIPs.GetAddr(firstConnIndex)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 | 
						err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 | 
				
			||||||
	// If the first tunnel disconnects, keep restarting it.
 | 
						// If the first tunnel disconnects, keep restarting it.
 | 
				
			||||||
	edgeErrors := 0
 | 
						edgeErrors := 0
 | 
				
			||||||
	for s.unusedIPs() {
 | 
						for s.unusedIPs() {
 | 
				
			||||||
| 
						 | 
					@ -257,12 +255,12 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if edgeErrors >= 2 {
 | 
							if edgeErrors >= 2 {
 | 
				
			||||||
			addr, err = s.edgeIPs.GetDifferentAddr(thisConnID)
 | 
								addr, err = s.edgeIPs.GetDifferentAddr(firstConnIndex)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 | 
							err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -48,7 +48,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
 | 
				
			||||||
		return time.After(d)
 | 
							return time.After(d)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
 | 
						s, err := NewSupervisor(testConfig(logger), uuid.New())
 | 
				
			||||||
	if !assert.NoError(t, err) {
 | 
						if !assert.NoError(t, err) {
 | 
				
			||||||
		t.FailNow()
 | 
							t.FailNow()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -92,7 +92,7 @@ func TestRefreshAuthSuccess(t *testing.T) {
 | 
				
			||||||
		return time.After(d)
 | 
							return time.After(d)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
 | 
						s, err := NewSupervisor(testConfig(logger), uuid.New())
 | 
				
			||||||
	if !assert.NoError(t, err) {
 | 
						if !assert.NoError(t, err) {
 | 
				
			||||||
		t.FailNow()
 | 
							t.FailNow()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -120,7 +120,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
 | 
				
			||||||
		return time.After(d)
 | 
							return time.After(d)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
 | 
						s, err := NewSupervisor(testConfig(logger), uuid.New())
 | 
				
			||||||
	if !assert.NoError(t, err) {
 | 
						if !assert.NoError(t, err) {
 | 
				
			||||||
		t.FailNow()
 | 
							t.FailNow()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -142,7 +142,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
 | 
				
			||||||
func TestRefreshAuthFail(t *testing.T) {
 | 
					func TestRefreshAuthFail(t *testing.T) {
 | 
				
			||||||
	logger := logger.NewOutputWriter(logger.NewMockWriteManager())
 | 
						logger := logger.NewOutputWriter(logger.NewMockWriteManager())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
 | 
						s, err := NewSupervisor(testConfig(logger), uuid.New())
 | 
				
			||||||
	if !assert.NoError(t, err) {
 | 
						if !assert.NoError(t, err) {
 | 
				
			||||||
		t.FailNow()
 | 
							t.FailNow()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										165
									
								
								origin/tunnel.go
								
								
								
								
							
							
						
						
									
										165
									
								
								origin/tunnel.go
								
								
								
								
							| 
						 | 
					@ -48,7 +48,6 @@ type registerRPCName string
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	register  registerRPCName = "register"
 | 
						register  registerRPCName = "register"
 | 
				
			||||||
	reconnect registerRPCName = "reconnect"
 | 
						reconnect registerRPCName = "reconnect"
 | 
				
			||||||
	unknown   registerRPCName = "unknown"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TunnelConfig struct {
 | 
					type TunnelConfig struct {
 | 
				
			||||||
| 
						 | 
					@ -80,15 +79,15 @@ type TunnelConfig struct {
 | 
				
			||||||
	RunFromTerminal    bool
 | 
						RunFromTerminal    bool
 | 
				
			||||||
	Tags               []tunnelpogs.Tag
 | 
						Tags               []tunnelpogs.Tag
 | 
				
			||||||
	TlsConfig          *tls.Config
 | 
						TlsConfig          *tls.Config
 | 
				
			||||||
	UseDeclarativeTunnel bool
 | 
					 | 
				
			||||||
	WSGI               bool
 | 
						WSGI               bool
 | 
				
			||||||
	// OriginUrl may not be used if a user specifies a unix socket.
 | 
						// OriginUrl may not be used if a user specifies a unix socket.
 | 
				
			||||||
	OriginUrl string
 | 
						OriginUrl string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// feature-flag to use new edge reconnect tokens
 | 
						// feature-flag to use new edge reconnect tokens
 | 
				
			||||||
	UseReconnectToken bool
 | 
						UseReconnectToken bool
 | 
				
			||||||
	// feature-flag for using ConnectionDigest
 | 
					
 | 
				
			||||||
	UseQuickReconnects bool
 | 
						NamedTunnel     *NamedTunnelConfig
 | 
				
			||||||
 | 
						ReplaceExisting bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ReconnectTunnelCredentialManager is invoked by functions in this file to
 | 
					// ReconnectTunnelCredentialManager is invoked by functions in this file to
 | 
				
			||||||
| 
						 | 
					@ -103,6 +102,8 @@ type ReconnectTunnelCredentialManager interface {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type dupConnRegisterTunnelError struct{}
 | 
					type dupConnRegisterTunnelError struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var errDuplicationConnection = &dupConnRegisterTunnelError{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (e dupConnRegisterTunnelError) Error() string {
 | 
					func (e dupConnRegisterTunnelError) Error() string {
 | 
				
			||||||
	return "already connected to this server"
 | 
						return "already connected to this server"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -171,21 +172,35 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *TunnelConfig) SupportedFeatures() []string {
 | 
					func (c *TunnelConfig) ConnectionOptions(originLocalAddr string) *tunnelpogs.ConnectionOptions {
 | 
				
			||||||
	basic := []string{FeatureSerializedHeaders}
 | 
						// attempt to parse out origin IP, but don't fail since it's informational field
 | 
				
			||||||
	if c.UseQuickReconnects {
 | 
						host, _, _ := net.SplitHostPort(originLocalAddr)
 | 
				
			||||||
		basic = append(basic, FeatureQuickReconnects)
 | 
						originIP := net.ParseIP(host)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &tunnelpogs.ConnectionOptions{
 | 
				
			||||||
 | 
							Client:             c.NamedTunnel.Client,
 | 
				
			||||||
 | 
							OriginLocalIP:      originIP,
 | 
				
			||||||
 | 
							ReplaceExisting:    c.ReplaceExisting,
 | 
				
			||||||
 | 
							CompressionQuality: uint8(c.CompressionQuality),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return basic
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *TunnelConfig) SupportedFeatures() []string {
 | 
				
			||||||
 | 
						features := []string{FeatureSerializedHeaders}
 | 
				
			||||||
 | 
						if c.NamedTunnel == nil {
 | 
				
			||||||
 | 
							features = append(features, FeatureQuickReconnects)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return features
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type NamedTunnelConfig struct {
 | 
					type NamedTunnelConfig struct {
 | 
				
			||||||
	Auth   pogs.TunnelAuth
 | 
						Auth   pogs.TunnelAuth
 | 
				
			||||||
	ID   string
 | 
						ID     uuid.UUID
 | 
				
			||||||
 | 
						Client pogs.ClientInfo
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal, namedTunnel *NamedTunnelConfig) error {
 | 
					func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
 | 
				
			||||||
	s, err := NewSupervisor(config, cloudflaredID, namedTunnel)
 | 
						s, err := NewSupervisor(config, cloudflaredID)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -196,7 +211,7 @@ func ServeTunnelLoop(ctx context.Context,
 | 
				
			||||||
	credentialManager ReconnectTunnelCredentialManager,
 | 
						credentialManager ReconnectTunnelCredentialManager,
 | 
				
			||||||
	config *TunnelConfig,
 | 
						config *TunnelConfig,
 | 
				
			||||||
	addr *net.TCPAddr,
 | 
						addr *net.TCPAddr,
 | 
				
			||||||
	connectionID uint8,
 | 
						connectionIndex uint8,
 | 
				
			||||||
	connectedSignal *signal.Signal,
 | 
						connectedSignal *signal.Signal,
 | 
				
			||||||
	cloudflaredUUID uuid.UUID,
 | 
						cloudflaredUUID uuid.UUID,
 | 
				
			||||||
	bufferPool *buffer.Pool,
 | 
						bufferPool *buffer.Pool,
 | 
				
			||||||
| 
						 | 
					@ -219,7 +234,7 @@ func ServeTunnelLoop(ctx context.Context,
 | 
				
			||||||
			credentialManager,
 | 
								credentialManager,
 | 
				
			||||||
			config,
 | 
								config,
 | 
				
			||||||
			config.Logger,
 | 
								config.Logger,
 | 
				
			||||||
			addr, connectionID,
 | 
								addr, connectionIndex,
 | 
				
			||||||
			connectedFuse,
 | 
								connectedFuse,
 | 
				
			||||||
			&backoff,
 | 
								&backoff,
 | 
				
			||||||
			cloudflaredUUID,
 | 
								cloudflaredUUID,
 | 
				
			||||||
| 
						 | 
					@ -228,7 +243,7 @@ func ServeTunnelLoop(ctx context.Context,
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
		if recoverable {
 | 
							if recoverable {
 | 
				
			||||||
			if duration, ok := backoff.GetBackoffDuration(ctx); ok {
 | 
								if duration, ok := backoff.GetBackoffDuration(ctx); ok {
 | 
				
			||||||
				config.Logger.Infof("Retrying in %s seconds: connectionID: %d", duration, connectionID)
 | 
									config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration)
 | 
				
			||||||
				backoff.Backoff(ctx)
 | 
									backoff.Backoff(ctx)
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					@ -243,7 +258,7 @@ func ServeTunnel(
 | 
				
			||||||
	config *TunnelConfig,
 | 
						config *TunnelConfig,
 | 
				
			||||||
	logger logger.Service,
 | 
						logger logger.Service,
 | 
				
			||||||
	addr *net.TCPAddr,
 | 
						addr *net.TCPAddr,
 | 
				
			||||||
	connectionID uint8,
 | 
						connectionIndex uint8,
 | 
				
			||||||
	connectedFuse *h2mux.BooleanFuse,
 | 
						connectedFuse *h2mux.BooleanFuse,
 | 
				
			||||||
	backoff *BackoffHandler,
 | 
						backoff *BackoffHandler,
 | 
				
			||||||
	cloudflaredUUID uuid.UUID,
 | 
						cloudflaredUUID uuid.UUID,
 | 
				
			||||||
| 
						 | 
					@ -262,22 +277,18 @@ func ServeTunnel(
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	connectionTag := uint8ToString(connectionID)
 | 
						connectionTag := uint8ToString(connectionIndex)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// additional tags to send other than hostname which is set in cloudflared main package
 | 
					 | 
				
			||||||
	tags := make(map[string]string)
 | 
					 | 
				
			||||||
	tags["ha"] = connectionTag
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Returns error from parsing the origin URL or handshake errors
 | 
						// Returns error from parsing the origin URL or handshake errors
 | 
				
			||||||
	handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID, bufferPool)
 | 
						handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		switch err.(type) {
 | 
							switch err.(type) {
 | 
				
			||||||
		case connection.DialError:
 | 
							case connection.DialError:
 | 
				
			||||||
			logger.Errorf("Unable to dial edge: %s connectionID: %d", err, connectionID)
 | 
								logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err)
 | 
				
			||||||
		case h2mux.MuxerHandshakeError:
 | 
							case h2mux.MuxerHandshakeError:
 | 
				
			||||||
			logger.Errorf("Handshake failed with edge server: %s connectionID: %d", err, connectionID)
 | 
								logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			logger.Errorf("Tunnel creation failure: %s connectionID: %d", err, connectionID)
 | 
								logger.Errorf("Connection %d failed: %s", connectionIndex, err)
 | 
				
			||||||
			return err, false
 | 
								return err, false
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return err, true
 | 
							return err, true
 | 
				
			||||||
| 
						 | 
					@ -293,20 +304,21 @@ func ServeTunnel(
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if config.NamedTunnel != nil {
 | 
				
			||||||
 | 
								return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if config.UseReconnectToken && connectedFuse.Value() {
 | 
							if config.UseReconnectToken && connectedFuse.Value() {
 | 
				
			||||||
			token, tokenErr := credentialManager.ReconnectToken()
 | 
								token, tokenErr := credentialManager.ReconnectToken()
 | 
				
			||||||
			eventDigest, eventDigestErr := credentialManager.EventDigest()
 | 
								eventDigest, eventDigestErr := credentialManager.EventDigest()
 | 
				
			||||||
			// if we have both credentials, we can reconnect
 | 
								// if we have both credentials, we can reconnect
 | 
				
			||||||
			if tokenErr == nil && eventDigestErr == nil {
 | 
								if tokenErr == nil && eventDigestErr == nil {
 | 
				
			||||||
				var connDigest []byte
 | 
									var connDigest []byte
 | 
				
			||||||
 | 
									if digest, connDigestErr := credentialManager.ConnDigest(connectionIndex); connDigestErr == nil {
 | 
				
			||||||
				// check if we can use Quick Reconnects
 | 
					 | 
				
			||||||
				if config.UseQuickReconnects {
 | 
					 | 
				
			||||||
					if digest, connDigestErr := credentialManager.ConnDigest(connectionID); connDigestErr == nil {
 | 
					 | 
				
			||||||
					connDigest = digest
 | 
										connDigest = digest
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				}
 | 
					
 | 
				
			||||||
				return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID, credentialManager)
 | 
									return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			// log errors and proceed to RegisterTunnel
 | 
								// log errors and proceed to RegisterTunnel
 | 
				
			||||||
			if tokenErr != nil {
 | 
								if tokenErr != nil {
 | 
				
			||||||
| 
						 | 
					@ -316,7 +328,7 @@ func ServeTunnel(
 | 
				
			||||||
				logger.Errorf("Couldn't get event digest: %s", eventDigestErr)
 | 
									logger.Errorf("Couldn't get event digest: %s", eventDigestErr)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID)
 | 
							return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	errGroup.Go(func() error {
 | 
						errGroup.Go(func() error {
 | 
				
			||||||
| 
						 | 
					@ -325,12 +337,15 @@ func ServeTunnel(
 | 
				
			||||||
			select {
 | 
								select {
 | 
				
			||||||
			case <-serveCtx.Done():
 | 
								case <-serveCtx.Done():
 | 
				
			||||||
				// UnregisterTunnel blocks until the RPC call returns
 | 
									// UnregisterTunnel blocks until the RPC call returns
 | 
				
			||||||
				var err error
 | 
					 | 
				
			||||||
				if connectedFuse.Value() {
 | 
									if connectedFuse.Value() {
 | 
				
			||||||
					err = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger)
 | 
										if config.NamedTunnel != nil {
 | 
				
			||||||
 | 
											_ = UnregisterConnection(ctx, handler.muxer, config)
 | 
				
			||||||
 | 
										} else {
 | 
				
			||||||
 | 
											_ = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				handler.muxer.Shutdown()
 | 
									handler.muxer.Shutdown()
 | 
				
			||||||
				return err
 | 
									return nil
 | 
				
			||||||
			case <-updateMetricsTickC:
 | 
								case <-updateMetricsTickC:
 | 
				
			||||||
				handler.UpdateMetrics(connectionTag)
 | 
									handler.UpdateMetrics(connectionTag)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					@ -361,8 +376,6 @@ func ServeTunnel(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = errGroup.Wait()
 | 
						err = errGroup.Wait()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		_ = newClientRegisterTunnelError(err, config.Metrics.regFail, unknown)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		switch castedErr := err.(type) {
 | 
							switch castedErr := err.(type) {
 | 
				
			||||||
		case dupConnRegisterTunnelError:
 | 
							case dupConnRegisterTunnelError:
 | 
				
			||||||
			logger.Info("Already connected to this server, selecting a different one")
 | 
								logger.Info("Already connected to this server, selecting a different one")
 | 
				
			||||||
| 
						 | 
					@ -382,7 +395,7 @@ func ServeTunnel(
 | 
				
			||||||
			logger.Info("Muxer shutdown")
 | 
								logger.Info("Muxer shutdown")
 | 
				
			||||||
			return err, true
 | 
								return err, true
 | 
				
			||||||
		case *ReconnectSignal:
 | 
							case *ReconnectSignal:
 | 
				
			||||||
			logger.Infof("Restarting due to reconnect signal in %d seconds", castedErr.Delay)
 | 
								logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, castedErr.Delay)
 | 
				
			||||||
			castedErr.DelayBeforeReconnect()
 | 
								castedErr.DelayBeforeReconnect()
 | 
				
			||||||
			return err, true
 | 
								return err, true
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
| 
						 | 
					@ -393,6 +406,74 @@ func ServeTunnel(
 | 
				
			||||||
	return nil, true
 | 
						return nil, true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func RegisterConnection(
 | 
				
			||||||
 | 
						ctx context.Context,
 | 
				
			||||||
 | 
						muxer *h2mux.Muxer,
 | 
				
			||||||
 | 
						config *TunnelConfig,
 | 
				
			||||||
 | 
						connectionIndex uint8,
 | 
				
			||||||
 | 
						originLocalAddr string,
 | 
				
			||||||
 | 
					) error {
 | 
				
			||||||
 | 
						const registerConnection = "registerConnection"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
 | 
				
			||||||
 | 
						rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							// RPC stream open error
 | 
				
			||||||
 | 
							return newClientRegisterTunnelError(err, config.Metrics.rpcFail, registerConnection)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer rpc.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := rpc.RegisterConnection(
 | 
				
			||||||
 | 
							ctx,
 | 
				
			||||||
 | 
							config.NamedTunnel.Auth,
 | 
				
			||||||
 | 
							config.NamedTunnel.ID,
 | 
				
			||||||
 | 
							connectionIndex,
 | 
				
			||||||
 | 
							config.ConnectionOptions(originLocalAddr),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							if err.Error() == DuplicateConnectionError {
 | 
				
			||||||
 | 
								config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc()
 | 
				
			||||||
 | 
								return errDuplicationConnection
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc()
 | 
				
			||||||
 | 
							return serverRegistrationErrorFromRPC(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc()
 | 
				
			||||||
 | 
						config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
 | 
				
			||||||
 | 
						if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
 | 
				
			||||||
 | 
							return &serverRegisterTunnelError{
 | 
				
			||||||
 | 
								cause:     retryable.Unwrap(),
 | 
				
			||||||
 | 
								permanent: false,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &serverRegisterTunnelError{
 | 
				
			||||||
 | 
							cause:     err,
 | 
				
			||||||
 | 
							permanent: true,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func UnregisterConnection(
 | 
				
			||||||
 | 
						ctx context.Context,
 | 
				
			||||||
 | 
						muxer *h2mux.Muxer,
 | 
				
			||||||
 | 
						config *TunnelConfig,
 | 
				
			||||||
 | 
					) error {
 | 
				
			||||||
 | 
						config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
 | 
				
			||||||
 | 
						rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							// RPC stream open error
 | 
				
			||||||
 | 
							return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer rpc.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return rpc.UnregisterConnection(ctx)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func RegisterTunnel(
 | 
					func RegisterTunnel(
 | 
				
			||||||
	ctx context.Context,
 | 
						ctx context.Context,
 | 
				
			||||||
	credentialManager ReconnectTunnelCredentialManager,
 | 
						credentialManager ReconnectTunnelCredentialManager,
 | 
				
			||||||
| 
						 | 
					@ -437,7 +518,7 @@ func ReconnectTunnel(
 | 
				
			||||||
	config *TunnelConfig,
 | 
						config *TunnelConfig,
 | 
				
			||||||
	logger logger.Service,
 | 
						logger logger.Service,
 | 
				
			||||||
	connectionID uint8,
 | 
						connectionID uint8,
 | 
				
			||||||
	originLocalIP string,
 | 
						originLocalAddr string,
 | 
				
			||||||
	uuid uuid.UUID,
 | 
						uuid uuid.UUID,
 | 
				
			||||||
	credentialManager ReconnectTunnelCredentialManager,
 | 
						credentialManager ReconnectTunnelCredentialManager,
 | 
				
			||||||
) error {
 | 
					) error {
 | 
				
			||||||
| 
						 | 
					@ -459,7 +540,7 @@ func ReconnectTunnel(
 | 
				
			||||||
		eventDigest,
 | 
							eventDigest,
 | 
				
			||||||
		connDigest,
 | 
							connDigest,
 | 
				
			||||||
		config.Hostname,
 | 
							config.Hostname,
 | 
				
			||||||
		config.RegistrationOptions(connectionID, originLocalIP, uuid),
 | 
							config.RegistrationOptions(connectionID, originLocalAddr, uuid),
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	if registrationErr := registration.DeserializeError(); registrationErr != nil {
 | 
						if registrationErr := registration.DeserializeError(); registrationErr != nil {
 | 
				
			||||||
		// ReconnectTunnel RPC failure
 | 
							// ReconnectTunnel RPC failure
 | 
				
			||||||
| 
						 | 
					@ -508,11 +589,11 @@ func processRegistrationSuccess(
 | 
				
			||||||
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error {
 | 
					func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error {
 | 
				
			||||||
	if err.Error() == DuplicateConnectionError {
 | 
						if err.Error() == DuplicateConnectionError {
 | 
				
			||||||
		metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
 | 
							metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
 | 
				
			||||||
		return dupConnRegisterTunnelError{}
 | 
							return errDuplicationConnection
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
 | 
						metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
 | 
				
			||||||
	return serverRegisterTunnelError{
 | 
						return serverRegisterTunnelError{
 | 
				
			||||||
		cause:     fmt.Errorf("Server error: %s", err.Error()),
 | 
							cause:     err,
 | 
				
			||||||
		permanent: err.IsPermanent(),
 | 
							permanent: err.IsPermanent(),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -223,7 +223,7 @@ func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth Tu
 | 
				
			||||||
	return nil, newRPCError("unknown result which %d", result.Which())
 | 
						return nil, newRPCError("unknown result which %d", result.Which())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c TunnelServer_PogsClient) Unregister(ctx context.Context) error {
 | 
					func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error {
 | 
				
			||||||
	client := tunnelrpc.TunnelServer{Client: c.Client}
 | 
						client := tunnelrpc.TunnelServer{Client: c.Client}
 | 
				
			||||||
	promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
 | 
						promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue