TUN-3106: Pass NamedTunnel config to StartServer

This commit is contained in:
Adam Chalmers 2020-06-17 13:33:55 -05:00
parent 9131e842a5
commit 4d3ebaf984
6 changed files with 28 additions and 18 deletions

View File

@ -129,7 +129,7 @@ func action(version string, shutdownC, graceShutdownC chan struct{}) cli.ActionF
tags := make(map[string]string) tags := make(map[string]string)
tags["hostname"] = c.String("hostname") tags["hostname"] = c.String("hostname")
raven.SetTagsContext(tags) raven.SetTagsContext(tags)
raven.CapturePanic(func() { err = tunnel.StartServer(c, version, shutdownC, graceShutdownC) }, nil) raven.CapturePanic(func() { err = tunnel.StartServer(c, version, shutdownC, graceShutdownC, nil) }, nil)
exitCode := 0 exitCode := 0
if err != nil { if err != nil {
handleError(err) handleError(err)

View File

@ -207,7 +207,7 @@ func Commands() []*cli.Command {
} }
func tunnel(c *cli.Context) error { func tunnel(c *cli.Context) error {
return StartServer(c, version, shutdownC, graceShutdownC) return StartServer(c, version, shutdownC, graceShutdownC, nil)
} }
func Init(v string, s, g chan struct{}) { func Init(v string, s, g chan struct{}) {
@ -238,7 +238,7 @@ func createLogger(c *cli.Context, isTransport bool) (logger.Service, error) {
return logger.New(loggerOpts...) return logger.New(loggerOpts...)
} }
func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan struct{}) error { func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan struct{}, namedTunnel *origin.NamedTunnelConfig) error {
logger, err := createLogger(c, false) logger, err := createLogger(c, false)
if err != nil { if err != nil {
return cliutil.PrintLoggerSetupError("error setting up logger", err) return cliutil.PrintLoggerSetupError("error setting up logger", err)
@ -475,7 +475,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh) errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh, namedTunnel)
}() }()
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger) return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger)

View File

@ -18,6 +18,7 @@ import (
"github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/certutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstore" "github.com/cloudflare/cloudflared/tunnelstore"
) )
@ -339,5 +340,5 @@ func runTunnel(c *cli.Context) error {
return err return err
} }
logger.Debugf("Read credentials for %v", credentials.AccountTag) logger.Debugf("Read credentials for %v", credentials.AccountTag)
panic("TODO: start tunnel supervisor") return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: id})
} }

View File

@ -68,6 +68,8 @@ type Supervisor struct {
connDigest map[uint8][]byte connDigest map[uint8][]byte
bufferPool *buffer.Pool bufferPool *buffer.Pool
namedTunnel *NamedTunnelConfig
} }
type resolveResult struct { type resolveResult struct {
@ -80,7 +82,7 @@ type tunnelError struct {
err error err error
} }
func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) { func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel *NamedTunnelConfig) (*Supervisor, error) {
var ( var (
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
err error err error
@ -94,7 +96,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
return nil, err return nil, err
} }
return &Supervisor{ return &Supervisor{
cloudflaredUUID: u, cloudflaredUUID: cloudflaredUUID,
config: config, config: config,
edgeIPs: edgeIPs, edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError), tunnelErrors: make(chan tunnelError),
@ -102,6 +104,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
logger: config.Logger, logger: config.Logger,
connDigest: make(map[uint8][]byte), connDigest: make(map[uint8][]byte),
bufferPool: buffer.NewPool(512 * 1024), bufferPool: buffer.NewPool(512 * 1024),
namedTunnel: namedTunnel,
}, nil }, nil
} }

View File

@ -48,7 +48,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New()) s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -92,7 +92,7 @@ func TestRefreshAuthSuccess(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New()) s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -120,7 +120,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
return time.After(d) return time.After(d)
} }
s, err := NewSupervisor(testConfig(logger), uuid.New()) s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }
@ -142,7 +142,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
func TestRefreshAuthFail(t *testing.T) { func TestRefreshAuthFail(t *testing.T) {
logger := logger.NewOutputWriter(logger.NewMockWriteManager()) logger := logger.NewOutputWriter(logger.NewMockWriteManager())
s, err := NewSupervisor(testConfig(logger), uuid.New()) s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
t.FailNow() t.FailNow()
} }

View File

@ -26,6 +26,7 @@ import (
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/validation"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
@ -178,8 +179,13 @@ func (c *TunnelConfig) SupportedFeatures() []string {
return basic return basic
} }
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error { type NamedTunnelConfig struct {
s, err := NewSupervisor(config, cloudflaredID) Auth pogs.TunnelAuth
ID string
}
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal, namedTunnel *NamedTunnelConfig) error {
s, err := NewSupervisor(config, cloudflaredID, namedTunnel)
if err != nil { if err != nil {
return err return err
} }
@ -192,7 +198,7 @@ func ServeTunnelLoop(ctx context.Context,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionID uint8, connectionID uint8,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
u uuid.UUID, cloudflaredUUID uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
) error { ) error {
@ -216,7 +222,7 @@ func ServeTunnelLoop(ctx context.Context,
addr, connectionID, addr, connectionID,
connectedFuse, connectedFuse,
&backoff, &backoff,
u, cloudflaredUUID,
bufferPool, bufferPool,
reconnectCh, reconnectCh,
) )
@ -240,7 +246,7 @@ func ServeTunnel(
connectionID uint8, connectionID uint8,
connectedFuse *h2mux.BooleanFuse, connectedFuse *h2mux.BooleanFuse,
backoff *BackoffHandler, backoff *BackoffHandler,
u uuid.UUID, cloudflaredUUID uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
@ -300,7 +306,7 @@ func ServeTunnel(
connDigest = digest connDigest = digest
} }
} }
return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, u, credentialManager) return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID, credentialManager)
} }
// log errors and proceed to RegisterTunnel // log errors and proceed to RegisterTunnel
if tokenErr != nil { if tokenErr != nil {
@ -310,7 +316,7 @@ func ServeTunnel(
logger.Errorf("Couldn't get event digest: %s", eventDigestErr) logger.Errorf("Couldn't get event digest: %s", eventDigestErr)
} }
} }
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, u) return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID)
}) })
errGroup.Go(func() error { errGroup.Go(func() error {