TUN-2819: cloudflared should close its connections when a signal is sent
This commit is contained in:
parent
96f11de7ab
commit
6dcf3a4cbc
|
@ -7,10 +7,12 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
ossig "os/signal"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/trace"
|
"runtime/trace"
|
||||||
"sync"
|
"sync"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/awsuploader"
|
"github.com/cloudflare/cloudflared/awsuploader"
|
||||||
|
@ -399,10 +401,14 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When the user sends SIGUSR1, disconnect all connections.
|
||||||
|
reconnectCh := make(chan os.Signal, 1)
|
||||||
|
ossig.Notify(reconnectCh, syscall.SIGUSR1)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID)
|
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))
|
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -105,9 +106,9 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
|
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
|
||||||
logger := s.config.Logger
|
logger := s.config.Logger
|
||||||
if err := s.initialize(ctx, connectedSignal); err != nil {
|
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var tunnelsWaiting []int
|
var tunnelsWaiting []int
|
||||||
|
@ -157,7 +158,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
|
||||||
case <-backoffTimer:
|
case <-backoffTimer:
|
||||||
backoffTimer = nil
|
backoffTimer = nil
|
||||||
for _, index := range tunnelsWaiting {
|
for _, index := range tunnelsWaiting {
|
||||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
|
||||||
}
|
}
|
||||||
tunnelsActive += len(tunnelsWaiting)
|
tunnelsActive += len(tunnelsWaiting)
|
||||||
tunnelsWaiting = nil
|
tunnelsWaiting = nil
|
||||||
|
@ -191,7 +192,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns nil if initialization succeeded, else the initialization error.
|
// Returns nil if initialization succeeded, else the initialization error.
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) error {
|
||||||
logger := s.logger
|
logger := s.logger
|
||||||
|
|
||||||
s.lastResolve = time.Now()
|
s.lastResolve = time.Now()
|
||||||
|
@ -201,7 +202,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
s.config.HAConnections = availableAddrs
|
s.config.HAConnections = availableAddrs
|
||||||
}
|
}
|
||||||
|
|
||||||
go s.startFirstTunnel(ctx, connectedSignal)
|
go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
<-s.tunnelErrors
|
<-s.tunnelErrors
|
||||||
|
@ -213,7 +214,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
// At least one successful connection, so start the rest
|
// At least one successful connection, so start the rest
|
||||||
for i := 1; i < s.config.HAConnections; i++ {
|
for i := 1; i < s.config.HAConnections; i++ {
|
||||||
ch := signal.New(make(chan struct{}))
|
ch := signal.New(make(chan struct{}))
|
||||||
go s.startTunnel(ctx, i, ch)
|
go s.startTunnel(ctx, i, ch, reconnectCh)
|
||||||
time.Sleep(registrationInterval)
|
time.Sleep(registrationInterval)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -221,7 +222,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
|
|
||||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
|
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
|
||||||
var (
|
var (
|
||||||
addr *net.TCPAddr
|
addr *net.TCPAddr
|
||||||
err error
|
err error
|
||||||
|
@ -236,7 +237,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
|
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||||
// If the first tunnel disconnects, keep restarting it.
|
// If the first tunnel disconnects, keep restarting it.
|
||||||
edgeErrors := 0
|
edgeErrors := 0
|
||||||
for s.unusedIPs() {
|
for s.unusedIPs() {
|
||||||
|
@ -259,13 +260,13 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
|
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors.
|
// s.tunnelErrors.
|
||||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
|
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan os.Signal) {
|
||||||
var (
|
var (
|
||||||
addr *net.TCPAddr
|
addr *net.TCPAddr
|
||||||
err error
|
err error
|
||||||
|
@ -278,7 +279,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool)
|
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -169,12 +170,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
|
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan os.Signal) error {
|
||||||
s, err := NewSupervisor(config, cloudflaredID)
|
s, err := NewSupervisor(config, cloudflaredID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.Run(ctx, connectedSignal)
|
return s.Run(ctx, connectedSignal, reconnectCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeTunnelLoop(ctx context.Context,
|
func ServeTunnelLoop(ctx context.Context,
|
||||||
|
@ -185,6 +186,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
connectedSignal *signal.Signal,
|
connectedSignal *signal.Signal,
|
||||||
u uuid.UUID,
|
u uuid.UUID,
|
||||||
bufferPool *buffer.Pool,
|
bufferPool *buffer.Pool,
|
||||||
|
reconnectCh chan os.Signal,
|
||||||
) error {
|
) error {
|
||||||
connectionLogger := config.Logger.WithField("connectionID", connectionID)
|
connectionLogger := config.Logger.WithField("connectionID", connectionID)
|
||||||
config.Metrics.incrementHaConnections()
|
config.Metrics.incrementHaConnections()
|
||||||
|
@ -209,6 +211,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
&backoff,
|
&backoff,
|
||||||
u,
|
u,
|
||||||
bufferPool,
|
bufferPool,
|
||||||
|
reconnectCh,
|
||||||
)
|
)
|
||||||
if recoverable {
|
if recoverable {
|
||||||
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||||
|
@ -232,6 +235,7 @@ func ServeTunnel(
|
||||||
backoff *BackoffHandler,
|
backoff *BackoffHandler,
|
||||||
u uuid.UUID,
|
u uuid.UUID,
|
||||||
bufferPool *buffer.Pool,
|
bufferPool *buffer.Pool,
|
||||||
|
reconnectCh chan os.Signal,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
// Treat panics as recoverable errors
|
// Treat panics as recoverable errors
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -318,6 +322,11 @@ func ServeTunnel(
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
<-reconnectCh
|
||||||
|
return fmt.Errorf("received disconnect signal")
|
||||||
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
// All routines should stop when muxer finish serving. When muxer is shutdown
|
// All routines should stop when muxer finish serving. When muxer is shutdown
|
||||||
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
||||||
|
|
Loading…
Reference in New Issue