TUN-2819: cloudflared should close its connections when a signal is sent

This commit is contained in:
Adam Chalmers 2020-03-19 10:38:28 -05:00
parent 96f11de7ab
commit 6dcf3a4cbc
3 changed files with 30 additions and 14 deletions

View File

@ -7,10 +7,12 @@ import (
"net"
"net/url"
"os"
ossig "os/signal"
"reflect"
"runtime"
"runtime/trace"
"sync"
"syscall"
"time"
"github.com/cloudflare/cloudflared/awsuploader"
@ -399,10 +401,14 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
return err
}
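// When the user sends SIGUSR1, notify the tunnel connections via reconnectCh so they disconnect.
// NOTE(review): reconnectCh is a single channel shared by every connection, and each value sent
// on it is received by only one goroutine, so one SIGUSR1 appears to disconnect one connection
// rather than all of them — confirm the intended behavior.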
reconnectCh := make(chan os.Signal, 1)
ossig.Notify(reconnectCh, syscall.SIGUSR1)
wg.Add(1)
go func() {
defer wg.Done()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID)
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
}()
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"os"
"sync"
"time"
@ -105,9 +106,9 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
}, nil
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal); err != nil {
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err
}
var tunnelsWaiting []int
@ -157,7 +158,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
case <-backoffTimer:
backoffTimer = nil
for _, index := range tunnelsWaiting {
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
}
tunnelsActive += len(tunnelsWaiting)
tunnelsWaiting = nil
@ -191,7 +192,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
}
// Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
logger := s.logger
s.lastResolve = time.Now()
@ -201,7 +202,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
s.config.HAConnections = availableAddrs
}
go s.startFirstTunnel(ctx, connectedSignal)
go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
select {
case <-ctx.Done():
<-s.tunnelErrors
@ -213,7 +214,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ {
ch := signal.New(make(chan struct{}))
go s.startTunnel(ctx, i, ch)
go s.startTunnel(ctx, i, ch, reconnectCh)
time.Sleep(registrationInterval)
}
return nil
@ -221,7 +222,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startFirstTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds.
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
var (
addr *net.TCPAddr
err error
@ -236,7 +237,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
// If the first tunnel disconnects, keep restarting it.
edgeErrors := 0
for s.unusedIPs() {
@ -259,13 +260,13 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return
}
}
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
}
}
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
var (
addr *net.TCPAddr
err error
@ -278,7 +279,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
if err != nil {
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool)
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
}
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {

View File

@ -9,6 +9,7 @@ import (
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
@ -169,12 +170,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
}
}
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan os.Signal) error {
s, err := NewSupervisor(config, cloudflaredID)
if err != nil {
return err
}
return s.Run(ctx, connectedSignal)
return s.Run(ctx, connectedSignal, reconnectCh)
}
func ServeTunnelLoop(ctx context.Context,
@ -185,6 +186,7 @@ func ServeTunnelLoop(ctx context.Context,
connectedSignal *signal.Signal,
u uuid.UUID,
bufferPool *buffer.Pool,
reconnectCh chan os.Signal,
) error {
connectionLogger := config.Logger.WithField("connectionID", connectionID)
config.Metrics.incrementHaConnections()
@ -209,6 +211,7 @@ func ServeTunnelLoop(ctx context.Context,
&backoff,
u,
bufferPool,
reconnectCh,
)
if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
@ -232,6 +235,7 @@ func ServeTunnel(
backoff *BackoffHandler,
u uuid.UUID,
bufferPool *buffer.Pool,
reconnectCh chan os.Signal,
) (err error, recoverable bool) {
// Treat panics as recoverable errors
defer func() {
@ -318,6 +322,11 @@ func ServeTunnel(
}
})
errGroup.Go(func() error {
<-reconnectCh
return fmt.Errorf("received disconnect signal")
})
errGroup.Go(func() error {
// All routines should stop when the muxer finishes serving. When the muxer is shut down
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown