diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index dfd326d6..9a43fb46 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -7,10 +7,12 @@ import ( "net" "net/url" "os" + ossig "os/signal" "reflect" "runtime" "runtime/trace" "sync" + "syscall" "time" "github.com/cloudflare/cloudflared/awsuploader" @@ -399,10 +401,14 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan return err } + // When the user sends SIGUSR1, disconnect all connections. + reconnectCh := make(chan os.Signal, 1) + ossig.Notify(reconnectCh, syscall.SIGUSR1) + wg.Add(1) go func() { defer wg.Done() - errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID) + errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh) }() return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period")) diff --git a/origin/supervisor.go b/origin/supervisor.go index 1cb7c336..8c3a672e 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "os" "sync" "time" @@ -105,9 +106,9 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) { }, nil } -func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error { +func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error { logger := s.config.Logger - if err := s.initialize(ctx, connectedSignal); err != nil { + if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { return err } var tunnelsWaiting []int @@ -157,7 +158,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er case <-backoffTimer: backoffTimer = nil for _, index := range tunnelsWaiting { - go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index)) + go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh) } tunnelsActive += len(tunnelsWaiting) tunnelsWaiting = nil @@ -191,7 +192,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er } // Returns nil if initialization succeeded, else the initialization error. -func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error { +func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error { logger := s.logger s.lastResolve = time.Now() @@ -201,7 +202,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig s.config.HAConnections = availableAddrs } - go s.startFirstTunnel(ctx, connectedSignal) + go s.startFirstTunnel(ctx, connectedSignal, reconnectCh) select { case <-ctx.Done(): <-s.tunnelErrors @@ -213,7 +214,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig // At least one successful connection, so start the rest for i := 1; i < s.config.HAConnections; i++ { ch := signal.New(make(chan struct{})) - go s.startTunnel(ctx, i, ch) + go s.startTunnel(ctx, i, ch, reconnectCh) time.Sleep(registrationInterval) } return nil @@ -221,7 +222,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig // startTunnel starts the first tunnel connection. The resulting error will be sent on // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed -func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) { +func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) { var ( addr *net.TCPAddr err error @@ -236,7 +237,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign return } - err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool) + err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) // If the first tunnel disconnects, keep restarting it. edgeErrors := 0 for s.unusedIPs() { @@ -259,13 +260,13 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign return } } - err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool) + err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) } } // startTunnel starts a new tunnel connection. The resulting error will be sent on // s.tunnelErrors. -func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) { +func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan os.Signal) { var ( addr *net.TCPAddr err error @@ -278,7 +279,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal if err != nil { return } - err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool) + err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) } func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { diff --git a/origin/tunnel.go b/origin/tunnel.go index 26e4e4ae..6f0bd2b4 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/url" + "os" "strconv" "strings" "sync" @@ -169,12 +170,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str } } -func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error { +func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan os.Signal) error { s, err := NewSupervisor(config, cloudflaredID) if err != nil { return err } - return s.Run(ctx, connectedSignal) + return s.Run(ctx, connectedSignal, reconnectCh) } func ServeTunnelLoop(ctx context.Context, @@ -185,6 +186,7 @@ func ServeTunnelLoop(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID, bufferPool *buffer.Pool, + reconnectCh chan os.Signal, ) error { connectionLogger := config.Logger.WithField("connectionID", connectionID) config.Metrics.incrementHaConnections() @@ -209,6 +211,7 @@ func ServeTunnelLoop(ctx context.Context, &backoff, u, bufferPool, + reconnectCh, ) if recoverable { if duration, ok := backoff.GetBackoffDuration(ctx); ok { @@ -232,6 +235,7 @@ func ServeTunnel( backoff *BackoffHandler, u uuid.UUID, bufferPool *buffer.Pool, + reconnectCh chan os.Signal, ) (err error, recoverable bool) { // Treat panics as recoverable errors defer func() { @@ -318,6 +322,11 @@ func ServeTunnel( } }) + errGroup.Go(func() error { + <-reconnectCh + return fmt.Errorf("received disconnect signal") + }) + errGroup.Go(func() error { // All routines should stop when muxer finish serving. When muxer is shutdown // gracefully, it doesn't return an error, so we need to return errMuxerShutdown