TUN-2819: cloudflared should close its connections when a signal is sent

This commit is contained in:
Adam Chalmers 2020-03-19 10:38:28 -05:00
parent 96f11de7ab
commit 6dcf3a4cbc
3 changed files with 30 additions and 14 deletions

View File

@ -7,10 +7,12 @@ import (
"net" "net"
"net/url" "net/url"
"os" "os"
ossig "os/signal"
"reflect" "reflect"
"runtime" "runtime"
"runtime/trace" "runtime/trace"
"sync" "sync"
"syscall"
"time" "time"
"github.com/cloudflare/cloudflared/awsuploader" "github.com/cloudflare/cloudflared/awsuploader"
@ -399,10 +401,14 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
return err return err
} }
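// When the user sends SIGUSR1, disconnect all connections.
// NOTE(review): reconnectCh is a single channel shared by every tunnel connection, and each
// delivered signal is received by only one waiting goroutine, so one SIGUSR1 may disconnect
// just one connection rather than all of them — confirm.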
reconnectCh := make(chan os.Signal, 1)
ossig.Notify(reconnectCh, syscall.SIGUSR1)
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID) errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
}() }()
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period")) return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"time" "time"
@ -105,9 +106,9 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
}, nil }, nil
} }
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error { func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
logger := s.config.Logger logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal); err != nil { if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err return err
} }
var tunnelsWaiting []int var tunnelsWaiting []int
@ -157,7 +158,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
case <-backoffTimer: case <-backoffTimer:
backoffTimer = nil backoffTimer = nil
for _, index := range tunnelsWaiting { for _, index := range tunnelsWaiting {
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index)) go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
} }
tunnelsActive += len(tunnelsWaiting) tunnelsActive += len(tunnelsWaiting)
tunnelsWaiting = nil tunnelsWaiting = nil
@ -191,7 +192,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
} }
// Returns nil if initialization succeeded, else the initialization error. // Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error { func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
logger := s.logger logger := s.logger
s.lastResolve = time.Now() s.lastResolve = time.Now()
@ -201,7 +202,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
s.config.HAConnections = availableAddrs s.config.HAConnections = availableAddrs
} }
go s.startFirstTunnel(ctx, connectedSignal) go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
select { select {
case <-ctx.Done(): case <-ctx.Done():
<-s.tunnelErrors <-s.tunnelErrors
@ -213,7 +214,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// At least one successful connection, so start the rest // At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ { for i := 1; i < s.config.HAConnections; i++ {
ch := signal.New(make(chan struct{})) ch := signal.New(make(chan struct{}))
go s.startTunnel(ctx, i, ch) go s.startTunnel(ctx, i, ch, reconnectCh)
time.Sleep(registrationInterval) time.Sleep(registrationInterval)
} }
return nil return nil
@ -221,7 +222,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startFirstTunnel starts the first tunnel connection. The resulting error will be sent on // startFirstTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds. // s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds.
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) { func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
var ( var (
addr *net.TCPAddr addr *net.TCPAddr
err error err error
@ -236,7 +237,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool) err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
// If the first tunnel disconnects, keep restarting it. // If the first tunnel disconnects, keep restarting it.
edgeErrors := 0 edgeErrors := 0
for s.unusedIPs() { for s.unusedIPs() {
@ -259,13 +260,13 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
} }
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool) err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
} }
} }
// startTunnel starts a new tunnel connection. The resulting error will be sent on // startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors. // s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) { func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
var ( var (
addr *net.TCPAddr addr *net.TCPAddr
err error err error
@ -278,7 +279,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
if err != nil { if err != nil {
return return
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool) err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
} }
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {

View File

@ -9,6 +9,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -169,12 +170,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
} }
} }
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error { func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan os.Signal) error {
s, err := NewSupervisor(config, cloudflaredID) s, err := NewSupervisor(config, cloudflaredID)
if err != nil { if err != nil {
return err return err
} }
return s.Run(ctx, connectedSignal) return s.Run(ctx, connectedSignal, reconnectCh)
} }
func ServeTunnelLoop(ctx context.Context, func ServeTunnelLoop(ctx context.Context,
@ -185,6 +186,7 @@ func ServeTunnelLoop(ctx context.Context,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
u uuid.UUID, u uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
reconnectCh chan os.Signal,
) error { ) error {
connectionLogger := config.Logger.WithField("connectionID", connectionID) connectionLogger := config.Logger.WithField("connectionID", connectionID)
config.Metrics.incrementHaConnections() config.Metrics.incrementHaConnections()
@ -209,6 +211,7 @@ func ServeTunnelLoop(ctx context.Context,
&backoff, &backoff,
u, u,
bufferPool, bufferPool,
reconnectCh,
) )
if recoverable { if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok { if duration, ok := backoff.GetBackoffDuration(ctx); ok {
@ -232,6 +235,7 @@ func ServeTunnel(
backoff *BackoffHandler, backoff *BackoffHandler,
u uuid.UUID, u uuid.UUID,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
reconnectCh chan os.Signal,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
// Treat panics as recoverable errors // Treat panics as recoverable errors
defer func() { defer func() {
@ -318,6 +322,11 @@ func ServeTunnel(
} }
}) })
errGroup.Go(func() error {
<-reconnectCh
return fmt.Errorf("received disconnect signal")
})
errGroup.Go(func() error { errGroup.Go(func() error {
// All routines should stop when muxer finish serving. When muxer is shutdown // All routines should stop when muxer finish serving. When muxer is shutdown
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown // gracefully, it doesn't return an error, so we need to return errMuxerShutdown