From 4d3ebaf9844daf183ee7dc30c61dec8d57e16591 Mon Sep 17 00:00:00 2001 From: Adam Chalmers Date: Wed, 17 Jun 2020 13:33:55 -0500 Subject: [PATCH] TUN-3106: Pass NamedTunnel config to StartServer --- cmd/cloudflared/main.go | 2 +- cmd/cloudflared/tunnel/cmd.go | 6 +++--- cmd/cloudflared/tunnel/subcommands.go | 3 ++- origin/supervisor.go | 7 +++++-- origin/supervisor_test.go | 8 ++++---- origin/tunnel.go | 20 +++++++++++++------- 6 files changed, 28 insertions(+), 18 deletions(-) diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index 3a318e2d..71263059 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -129,7 +129,7 @@ func action(version string, shutdownC, graceShutdownC chan struct{}) cli.ActionF tags := make(map[string]string) tags["hostname"] = c.String("hostname") raven.SetTagsContext(tags) - raven.CapturePanic(func() { err = tunnel.StartServer(c, version, shutdownC, graceShutdownC) }, nil) + raven.CapturePanic(func() { err = tunnel.StartServer(c, version, shutdownC, graceShutdownC, nil) }, nil) exitCode := 0 if err != nil { handleError(err) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index d5d3b439..f3b8476d 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -207,7 +207,7 @@ func Commands() []*cli.Command { } func tunnel(c *cli.Context) error { - return StartServer(c, version, shutdownC, graceShutdownC) + return StartServer(c, version, shutdownC, graceShutdownC, nil) } func Init(v string, s, g chan struct{}) { @@ -238,7 +238,7 @@ func createLogger(c *cli.Context, isTransport bool) (logger.Service, error) { return logger.New(loggerOpts...) } -func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan struct{}) error { +func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan struct{}, namedTunnel *origin.NamedTunnelConfig) error { logger, err := createLogger(c, false) if err != nil { return cliutil.PrintLoggerSetupError("error setting up logger", err) @@ -475,7 +475,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan wg.Add(1) go func() { defer wg.Done() - errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh) + errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh, namedTunnel) }() return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger) diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index a8dc7f98..b49882bd 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -18,6 +18,7 @@ import ( "github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelstore" ) @@ -339,5 +340,5 @@ func runTunnel(c *cli.Context) error { return err } logger.Debugf("Read credentials for %v", credentials.AccountTag) - panic("TODO: start tunnel supervisor") + return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: id}) } diff --git a/origin/supervisor.go b/origin/supervisor.go index b3c680f9..027d3173 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -68,6 +68,8 @@ type Supervisor struct { connDigest map[uint8][]byte bufferPool *buffer.Pool + + namedTunnel *NamedTunnelConfig } type resolveResult struct { @@ -80,7 +82,7 @@ type tunnelError struct { err error } -func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) { +func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel *NamedTunnelConfig) (*Supervisor, error) { var ( edgeIPs *edgediscovery.Edge err error @@ -94,7 +96,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) { return nil, err } return &Supervisor{ - cloudflaredUUID: u, + cloudflaredUUID: cloudflaredUUID, config: config, edgeIPs: edgeIPs, tunnelErrors: make(chan tunnelError), @@ -102,6 +104,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) { logger: config.Logger, connDigest: make(map[uint8][]byte), bufferPool: buffer.NewPool(512 * 1024), + namedTunnel: namedTunnel, }, nil } diff --git a/origin/supervisor_test.go b/origin/supervisor_test.go index 21eeec60..7b1ff701 100644 --- a/origin/supervisor_test.go +++ b/origin/supervisor_test.go @@ -48,7 +48,7 @@ func TestRefreshAuthBackoff(t *testing.T) { return time.After(d) } - s, err := NewSupervisor(testConfig(logger), uuid.New()) + s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) if !assert.NoError(t, err) { t.FailNow() } @@ -92,7 +92,7 @@ func TestRefreshAuthSuccess(t *testing.T) { return time.After(d) } - s, err := NewSupervisor(testConfig(logger), uuid.New()) + s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) if !assert.NoError(t, err) { t.FailNow() } @@ -120,7 +120,7 @@ func TestRefreshAuthUnknown(t *testing.T) { return time.After(d) } - s, err := NewSupervisor(testConfig(logger), uuid.New()) + s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) if !assert.NoError(t, err) { t.FailNow() } @@ -142,7 +142,7 @@ func TestRefreshAuthUnknown(t *testing.T) { func TestRefreshAuthFail(t *testing.T) { logger := logger.NewOutputWriter(logger.NewMockWriteManager()) - s, err := NewSupervisor(testConfig(logger), uuid.New()) + s, err := NewSupervisor(testConfig(logger), uuid.New(), nil) if !assert.NoError(t, err) { t.FailNow() } diff --git a/origin/tunnel.go b/origin/tunnel.go index dc5bbb78..20361e5d 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -26,6 +26,7 @@ import ( "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelrpc" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/websocket" @@ -178,8 +179,13 @@ func (c *TunnelConfig) SupportedFeatures() []string { return basic } -func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error { - s, err := NewSupervisor(config, cloudflaredID) +type NamedTunnelConfig struct { + Auth pogs.TunnelAuth + ID string +} + +func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal, namedTunnel *NamedTunnelConfig) error { + s, err := NewSupervisor(config, cloudflaredID, namedTunnel) if err != nil { return err } @@ -192,7 +198,7 @@ func ServeTunnelLoop(ctx context.Context, addr *net.TCPAddr, connectionID uint8, connectedSignal *signal.Signal, - u uuid.UUID, + cloudflaredUUID uuid.UUID, bufferPool *buffer.Pool, reconnectCh chan ReconnectSignal, ) error { @@ -216,7 +222,7 @@ func ServeTunnelLoop(ctx context.Context, addr, connectionID, connectedFuse, &backoff, - u, + cloudflaredUUID, bufferPool, reconnectCh, ) @@ -240,7 +246,7 @@ func ServeTunnel( connectionID uint8, connectedFuse *h2mux.BooleanFuse, backoff *BackoffHandler, - u uuid.UUID, + cloudflaredUUID uuid.UUID, bufferPool *buffer.Pool, reconnectCh chan ReconnectSignal, ) (err error, recoverable bool) { @@ -300,7 +306,7 @@ func ServeTunnel( connDigest = digest } } - return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, u, credentialManager) + return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID, credentialManager) } // log errors and proceed to RegisterTunnel if tokenErr != nil { @@ -310,7 +316,7 @@ func ServeTunnel( logger.Errorf("Couldn't get event digest: %s", eventDigestErr) } } - return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, u) + return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID) }) errGroup.Go(func() error {