TUN-2819: cloudflared should close its connections when a signal is sent
parent 96f11de7ab
commit 6dcf3a4cbc
@@ -7,10 +7,12 @@ import (
     "net"
     "net/url"
     "os"
+    ossig "os/signal"
     "reflect"
     "runtime"
     "runtime/trace"
     "sync"
+    "syscall"
     "time"
 
     "github.com/cloudflare/cloudflared/awsuploader"
@@ -399,10 +401,14 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
         return err
     }
 
+    // When the user sends SIGUSR1, disconnect all connections.
+    reconnectCh := make(chan os.Signal, 1)
+    ossig.Notify(reconnectCh, syscall.SIGUSR1)
+
     wg.Add(1)
     go func() {
         defer wg.Done()
-        errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID)
+        errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
     }()
 
     return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))

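Taken on their own, the hunks above show the pattern this commit relies on at the entry point: register interest in SIGUSR1 once in StartServer, and hand the resulting channel to the tunnel daemon, which treats each delivered signal as a request to drop its connections. Below is a minimal, self-contained sketch of that pattern; runDaemon and the log messages are illustrative stand-ins, not cloudflared code.

package main

import (
    "context"
    "log"
    "os"
    "os/signal"
    "syscall"
    "time"
)

// runDaemon is a stand-in for origin.StartTunnelDaemon: it blocks until the
// context is done and reacts to every value that arrives on reconnectCh.
func runDaemon(ctx context.Context, reconnectCh <-chan os.Signal) error {
    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-reconnectCh:
            // This is where existing connections would be torn down and re-dialed.
            log.Println("SIGUSR1 received: reconnecting")
        }
    }
}

func main() {
    // Buffered with capacity 1 so a signal delivered while the daemon is busy
    // is not dropped; signal.Notify never blocks on a full channel.
    reconnectCh := make(chan os.Signal, 1)
    signal.Notify(reconnectCh, syscall.SIGUSR1)

    // Bounded lifetime only so the sketch terminates on its own.
    ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
    defer cancel()

    if err := runDaemon(ctx, reconnectCh); err != nil {
        log.Println("daemon exited:", err)
    }
}
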
@@ -5,6 +5,7 @@ import (
     "errors"
     "fmt"
     "net"
+    "os"
     "sync"
     "time"
 
@@ -105,9 +106,9 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
     }, nil
 }
 
-func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
+func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
     logger := s.config.Logger
-    if err := s.initialize(ctx, connectedSignal); err != nil {
+    if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
         return err
     }
     var tunnelsWaiting []int
@@ -157,7 +158,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
     case <-backoffTimer:
         backoffTimer = nil
         for _, index := range tunnelsWaiting {
-            go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
+            go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
         }
         tunnelsActive += len(tunnelsWaiting)
         tunnelsWaiting = nil
@@ -191,7 +192,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
 }
 
 // Returns nil if initialization succeeded, else the initialization error.
-func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
+func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
     logger := s.logger
 
     s.lastResolve = time.Now()
@@ -201,7 +202,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
         s.config.HAConnections = availableAddrs
     }
 
-    go s.startFirstTunnel(ctx, connectedSignal)
+    go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
     select {
     case <-ctx.Done():
         <-s.tunnelErrors
@@ -213,7 +214,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
     // At least one successful connection, so start the rest
     for i := 1; i < s.config.HAConnections; i++ {
         ch := signal.New(make(chan struct{}))
-        go s.startTunnel(ctx, i, ch)
+        go s.startTunnel(ctx, i, ch, reconnectCh)
         time.Sleep(registrationInterval)
     }
     return nil
@@ -221,7 +222,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
 
 // startTunnel starts the first tunnel connection. The resulting error will be sent on
 // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
-func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
+func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
     var (
         addr *net.TCPAddr
         err  error
@@ -236,7 +237,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
         return
     }
 
-    err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
+    err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
     // If the first tunnel disconnects, keep restarting it.
     edgeErrors := 0
     for s.unusedIPs() {
@@ -259,13 +260,13 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSign
                 return
             }
         }
-        err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
+        err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
     }
 }
 
 // startTunnel starts a new tunnel connection. The resulting error will be sent on
 // s.tunnelErrors.
-func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
+func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
     var (
         addr *net.TCPAddr
         err  error
@@ -278,7 +279,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
     if err != nil {
         return
     }
-    err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool)
+    err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 }
 
 func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {

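The Supervisor changes above are largely mechanical: the same reconnectCh parameter is threaded through Run, initialize, startFirstTunnel and startTunnel so that every ServeTunnelLoop call receives the channel. A rough sketch of that threading pattern follows, using hypothetical names (supervisor, serveConnection) rather than cloudflared's; note that on a plain Go channel each signal value is consumed by exactly one of the loops sharing it.

package main

import (
    "context"
    "fmt"
    "os"
    "sync"
    "time"
)

type supervisor struct {
    haConnections int
}

// run starts one goroutine per HA connection and hands each of them the same
// reconnect channel, mirroring how reconnectCh is passed down above.
func (s *supervisor) run(ctx context.Context, reconnectCh chan os.Signal) {
    var wg sync.WaitGroup
    for i := 0; i < s.haConnections; i++ {
        wg.Add(1)
        go func(index int) {
            defer wg.Done()
            s.serveConnection(ctx, index, reconnectCh)
        }(i)
    }
    wg.Wait()
}

func (s *supervisor) serveConnection(ctx context.Context, index int, reconnectCh chan os.Signal) {
    select {
    case <-ctx.Done():
        fmt.Printf("connection %d: context done\n", index)
    case <-reconnectCh:
        // Only one of the connection loops observes any given signal value.
        fmt.Printf("connection %d: disconnect requested\n", index)
    }
}

func main() {
    reconnectCh := make(chan os.Signal, 1)
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()

    go func() { reconnectCh <- os.Interrupt }() // simulate one signal delivery

    s := &supervisor{haConnections: 4}
    s.run(ctx, reconnectCh)
}
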
@@ -9,6 +9,7 @@ import (
     "net"
     "net/http"
     "net/url"
+    "os"
     "strconv"
     "strings"
     "sync"
@@ -169,12 +170,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
     }
 }
 
-func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
+func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan os.Signal) error {
     s, err := NewSupervisor(config, cloudflaredID)
     if err != nil {
         return err
     }
-    return s.Run(ctx, connectedSignal)
+    return s.Run(ctx, connectedSignal, reconnectCh)
 }
 
 func ServeTunnelLoop(ctx context.Context,
@@ -185,6 +186,7 @@ func ServeTunnelLoop(ctx context.Context,
     connectedSignal *signal.Signal,
     u uuid.UUID,
     bufferPool *buffer.Pool,
+    reconnectCh chan os.Signal,
 ) error {
     connectionLogger := config.Logger.WithField("connectionID", connectionID)
     config.Metrics.incrementHaConnections()
@@ -209,6 +211,7 @@ func ServeTunnelLoop(ctx context.Context,
         &backoff,
         u,
         bufferPool,
+        reconnectCh,
     )
     if recoverable {
         if duration, ok := backoff.GetBackoffDuration(ctx); ok {
@@ -232,6 +235,7 @@ func ServeTunnel(
     backoff *BackoffHandler,
     u uuid.UUID,
     bufferPool *buffer.Pool,
+    reconnectCh chan os.Signal,
 ) (err error, recoverable bool) {
     // Treat panics as recoverable errors
     defer func() {
@@ -318,6 +322,11 @@ func ServeTunnel(
         }
     })
 
+    errGroup.Go(func() error {
+        <-reconnectCh
+        return fmt.Errorf("received disconnect signal")
+    })
+
     errGroup.Go(func() error {
         // All routines should stop when muxer finish serving. When muxer is shutdown
         // gracefully, it doesn't return an error, so we need to return errMuxerShutdown

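The final hunk is where the channel actually affects a live connection: inside ServeTunnel, one more errGroup goroutine blocks on reconnectCh and returns an error as soon as a value arrives. With an errgroup created via errgroup.WithContext, that first error cancels the group's shared context and unwinds the sibling routines serving the connection. A minimal sketch of that teardown mechanism, with hypothetical stand-ins for the serve routines:

package main

import (
    "context"
    "fmt"
    "os"
    "time"

    "golang.org/x/sync/errgroup"
)

// serveConnection is a stand-in for ServeTunnel: it runs its routines in an
// errgroup and treats a value on reconnectCh as an error that ends this connection.
func serveConnection(ctx context.Context, reconnectCh chan os.Signal) error {
    errGroup, ctx := errgroup.WithContext(ctx)

    errGroup.Go(func() error {
        select {
        case <-reconnectCh:
            return fmt.Errorf("received disconnect signal")
        case <-ctx.Done():
            return ctx.Err()
        }
    })

    // Stand-in for the muxer serve loop: runs until the shared context is cancelled.
    errGroup.Go(func() error {
        <-ctx.Done()
        fmt.Println("serve loop stopping:", ctx.Err())
        return ctx.Err()
    })

    // Wait returns the first non-nil error returned by any goroutine.
    return errGroup.Wait()
}

func main() {
    reconnectCh := make(chan os.Signal, 1)
    go func() {
        time.Sleep(100 * time.Millisecond)
        reconnectCh <- os.Interrupt // simulate a SIGUSR1 being routed in
    }()
    fmt.Println("connection ended:", serveConnection(context.Background(), reconnectCh))
}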