From d503aeaf7702e4794c2766d6e6eaf2e3a91cc8f8 Mon Sep 17 00:00:00 2001 From: Igor Postelnik Date: Wed, 20 Jan 2021 13:41:09 -0600 Subject: [PATCH] TUN-3118: Changed graceful shutdown to immediately unregister tunnel from the edge, keep the connection open until the edge drops it or grace period expires --- cmd/cloudflared/tunnel/cmd.go | 15 +---- cmd/cloudflared/tunnel/signal.go | 2 +- cmd/cloudflared/tunnel/subcommand_context.go | 2 - connection/h2mux.go | 46 +++++++++----- connection/h2mux_test.go | 62 +++++++++++++++++- connection/http2.go | 50 +++++++++++---- connection/http2_test.go | 58 +++++++++++++++++ connection/rpc.go | 25 +++++++- origin/supervisor.go | 66 ++++++++++++++------ origin/tunnel.go | 49 +++++++++++---- 10 files changed, 295 insertions(+), 80 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 66f576a3..8e97a60b 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -32,7 +32,6 @@ import ( "github.com/coreos/go-systemd/daemon" "github.com/facebookgo/grace/gracenet" "github.com/getsentry/raven-go" - "github.com/google/uuid" "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -199,7 +198,7 @@ func runAdhocNamedTunnel(sc *subcommandContext, name string) error { // runClassicTunnel creates a "classic" non-named tunnel func runClassicTunnel(sc *subcommandContext) error { - return StartServer(sc.c, version, shutdownC, graceShutdownC, nil, sc.log, sc.isUIEnabled) + return StartServer(sc.c, version, nil, sc.log, sc.isUIEnabled) } func routeFromFlag(c *cli.Context) (tunnelstore.Route, bool) { @@ -215,8 +214,6 @@ func routeFromFlag(c *cli.Context) (tunnelstore.Route, bool) { func StartServer( c *cli.Context, version string, - shutdownC, - graceShutdownC chan struct{}, namedTunnel *connection.NamedTunnelConfig, log *zerolog.Logger, isUIEnabled bool, @@ -287,12 +284,6 @@ func StartServer( go writePidFile(connectedSignal, c.String("pidfile"), log) } - cloudflaredID, err := uuid.NewRandom() - if err != nil { - log.Err(err).Msg("Cannot generate cloudflared ID") - return err - } - ctx, cancel := context.WithCancel(context.Background()) go func() { <-shutdownC @@ -363,7 +354,7 @@ func StartServer( wg.Add(1) go func() { defer wg.Done() - errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh) + errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC) }() if isUIEnabled { @@ -1040,7 +1031,7 @@ func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger) continue } } - log.Info().Msgf("Sending reconnect signal %+v", reconnect) + log.Info().Msgf("Sending %+v", reconnect) reconnectCh <- reconnect default: log.Info().Str(LogFieldCommand, command).Msg("Unknown command") diff --git a/cmd/cloudflared/tunnel/signal.go b/cmd/cloudflared/tunnel/signal.go index b9d49f3d..cb71eec5 100644 --- a/cmd/cloudflared/tunnel/signal.go +++ b/cmd/cloudflared/tunnel/signal.go @@ -51,7 +51,7 @@ func waitForSignalWithGraceShutdown(errC chan error, select { case err := <-errC: - logger.Info().Msgf("Initiating graceful shutdown due to %v ...", err) + logger.Info().Msgf("Initiating shutdown due to %v ...", err) close(graceShutdownC) close(shutdownC) return err diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index 5439470a..056d5644 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -274,8 +274,6 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { return StartServer( sc.c, version, - shutdownC, - graceShutdownC, &connection.NamedTunnelConfig{Credentials: credentials}, sc.log, sc.isUIEnabled, diff --git a/connection/h2mux.go b/connection/h2mux.go index b627c305..aaecf6a2 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -2,6 +2,7 @@ package connection import ( "context" + "io" "net" "net/http" "time" @@ -28,7 +29,12 @@ type h2muxConnection struct { connIndexStr string connIndex uint8 - observer *Observer + observer *Observer + gracefulShutdownC chan struct{} + stoppedGracefully bool + + // newRPCClientFunc allows us to mock RPCs during testing + newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient } type MuxerConfig struct { @@ -57,13 +63,16 @@ func NewH2muxConnection( edgeConn net.Conn, connIndex uint8, observer *Observer, + gracefulShutdownC chan struct{}, ) (*h2muxConnection, error, bool) { h := &h2muxConnection{ - config: config, - muxerConfig: muxerConfig, - connIndexStr: uint8ToString(connIndex), - connIndex: connIndex, - observer: observer, + config: config, + muxerConfig: muxerConfig, + connIndexStr: uint8ToString(connIndex), + connIndex: connIndex, + observer: observer, + gracefulShutdownC: gracefulShutdownC, + newRPCClientFunc: newRegistrationRPCClient, } // Establish a muxed connection with the edge @@ -77,21 +86,14 @@ func NewH2muxConnection( return h, nil, false } -func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error { +func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error { errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { return h.serveMuxer(serveCtx) }) errGroup.Go(func() error { - stream, err := h.newRPCStream(serveCtx, register) - if err != nil { - return err - } - rpcClient := newRegistrationRPCClient(ctx, stream, h.observer.log) - defer rpcClient.Close() - - if err = rpcClient.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil { + if err := h.registerNamedTunnel(serveCtx, namedTunnel, connOptions); err != nil { return err } connectedFuse.Connected() @@ -137,6 +139,10 @@ func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel return errGroup.Wait() } +func (h *h2muxConnection) StoppedGracefully() bool { + return h.stoppedGracefully +} + func (h *h2muxConnection) serveMuxer(ctx context.Context) error { // All routines should stop when muxer finish serving. When muxer is shutdown // gracefully, it doesn't return an error, so we need to return errMuxerShutdown @@ -152,13 +158,21 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq) for { select { + case <-h.gracefulShutdownC: + if connectedFuse.IsConnected() { + h.unregister(isNamedTunnel) + } + h.stoppedGracefully = true + h.gracefulShutdownC = nil + case <-ctx.Done(): // UnregisterTunnel blocks until the RPC call returns - if connectedFuse.IsConnected() { + if !h.stoppedGracefully && connectedFuse.IsConnected() { h.unregister(isNamedTunnel) } h.muxer.Shutdown() return + case <-updateMetricsTickC: h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics()) } diff --git a/connection/h2mux_test.go b/connection/h2mux_test.go index 36f8cb63..13314ba9 100644 --- a/connection/h2mux_test.go +++ b/connection/h2mux_test.go @@ -11,10 +11,12 @@ import ( "testing" "time" - "github.com/cloudflare/cloudflared/h2mux" "github.com/gobwas/ws/wsutil" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/h2mux" ) var ( @@ -32,13 +34,20 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) { go func() { edgeMuxConfig := h2mux.MuxerConfig{ Log: testObserver.log, + Handler: h2mux.MuxedStreamFunc(func(stream *h2mux.MuxedStream) error { + // we only expect RPC traffic in client->edge direction, provide minimal support for mocking + require.True(t, stream.IsRPCStream()) + return stream.WriteHeaders([]h2mux.Header{ + {Name: ":status", Value: "200"}, + }) + }), } edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams) require.NoError(t, err) edgeMuxChan <- edgeMux }() var connIndex = uint8(0) - h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver) + h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil) require.NoError(t, err) return h2muxConn, <-edgeMuxChan } @@ -168,6 +177,55 @@ func TestServeStreamWS(t *testing.T) { wg.Wait() } +func TestGracefulShutdownH2Mux(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h2muxConn, edgeMux := newH2MuxConnection(t) + + shutdownC := make(chan struct{}) + unregisteredC := make(chan struct{}) + h2muxConn.gracefulShutdownC = shutdownC + h2muxConn.newRPCClientFunc = func(_ context.Context, _ io.ReadWriteCloser, _ *zerolog.Logger) NamedTunnelRPCClient { + return &mockNamedTunnelRPCClient{ + registered: nil, + unregistered: unregisteredC, + } + } + + var wg sync.WaitGroup + wg.Add(3) + go func() { + defer wg.Done() + _ = edgeMux.Serve(ctx) + }() + go func() { + defer wg.Done() + _ = h2muxConn.serveMuxer(ctx) + }() + + go func() { + defer wg.Done() + h2muxConn.controlLoop(ctx, &mockConnectedFuse{}, true) + }() + + time.Sleep(100 * time.Millisecond) + close(shutdownC) + + select { + case <-unregisteredC: + break // ok + case <-time.Tick(time.Second): + assert.Fail(t, "timed out waiting for control loop to unregister") + } + + cancel() + wg.Wait() + + assert.True(t, h2muxConn.stoppedGracefully) + assert.Nil(t, h2muxConn.gracefulShutdownC) +} + func hasHeader(stream *h2mux.MuxedStream, name, val string) bool { for _, header := range stream.Headers { if header.Name == name && header.Value == val { diff --git a/connection/http2.go b/connection/http2.go index f8eb0762..bb54a2b8 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -2,6 +2,7 @@ package connection import ( "context" + "fmt" "io" "math" "net" @@ -22,6 +23,8 @@ const ( controlStreamUpgrade = "control-stream" ) +var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed") + type http2Connection struct { conn net.Conn server *http2.Server @@ -33,8 +36,10 @@ type http2Connection struct { connIndex uint8 wg *sync.WaitGroup // newRPCClientFunc allows us to mock RPCs during testing - newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient - connectedFuse ConnectedFuse + newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient + connectedFuse ConnectedFuse + gracefulShutdownC chan struct{} + stoppedGracefully bool } func NewHTTP2Connection( @@ -45,25 +50,27 @@ func NewHTTP2Connection( observer *Observer, connIndex uint8, connectedFuse ConnectedFuse, + gracefulShutdownC chan struct{}, ) *http2Connection { return &http2Connection{ conn: conn, server: &http2.Server{ MaxConcurrentStreams: math.MaxUint32, }, - config: config, - namedTunnel: namedTunnelConfig, - connOptions: connOptions, - observer: observer, - connIndexStr: uint8ToString(connIndex), - connIndex: connIndex, - wg: &sync.WaitGroup{}, - newRPCClientFunc: newRegistrationRPCClient, - connectedFuse: connectedFuse, + config: config, + namedTunnel: namedTunnelConfig, + connOptions: connOptions, + observer: observer, + connIndexStr: uint8ToString(connIndex), + connIndex: connIndex, + wg: &sync.WaitGroup{}, + newRPCClientFunc: newRegistrationRPCClient, + connectedFuse: connectedFuse, + gracefulShutdownC: gracefulShutdownC, } } -func (c *http2Connection) Serve(ctx context.Context) { +func (c *http2Connection) Serve(ctx context.Context) error { go func() { <-ctx.Done() c.close() @@ -72,6 +79,11 @@ func (c *http2Connection) Serve(ctx context.Context) { Context: ctx, Handler: c, }) + + if !c.stoppedGracefully { + return errEdgeConnectionClosed + } + return nil } func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -106,6 +118,10 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func (c *http2Connection) StoppedGracefully() bool { + return c.stoppedGracefully +} + func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error { rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log) defer rpcClient.Close() @@ -115,8 +131,16 @@ func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *ht } c.connectedFuse.Connected() - <-ctx.Done() + // wait for connection termination or start of graceful shutdown + select { + case <-ctx.Done(): + break + case <-c.gracefulShutdownC: + c.stoppedGracefully = true + } + rpcClient.GracefulShutdown(ctx, c.config.GracePeriod) + c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection") return nil } diff --git a/connection/http2_test.go b/connection/http2_test.go index 2d3c30cb..96b705cf 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -36,6 +38,7 @@ func newTestHTTP2Connection() (*http2Connection, net.Conn) { testObserver, connIndex, mockConnectedFuse{}, + nil, ), edgeConn } @@ -241,10 +244,64 @@ func TestServeControlStream(t *testing.T) { <-rpcClientFactory.registered cancel() <-rpcClientFactory.unregistered + assert.False(t, http2Conn.stoppedGracefully) wg.Wait() } +func TestGracefulShutdownHTTP2(t *testing.T) { + http2Conn, edgeConn := newTestHTTP2Connection() + + rpcClientFactory := mockRPCClientFactory{ + registered: make(chan struct{}), + unregistered: make(chan struct{}), + } + http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient + http2Conn.gracefulShutdownC = make(chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + http2Conn.Serve(ctx) + }() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) + require.NoError(t, err) + req.Header.Set(internalUpgradeHeader, controlStreamUpgrade) + + edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) + require.NoError(t, err) + + wg.Add(1) + go func() { + defer wg.Done() + _, _ = edgeHTTP2Conn.RoundTrip(req) + }() + + select { + case <-rpcClientFactory.registered: + break //ok + case <-time.Tick(time.Second): + t.Fatal("timeout out waiting for registration") + } + + // signal graceful shutdown + close(http2Conn.gracefulShutdownC) + + select { + case <-rpcClientFactory.unregistered: + break //ok + case <-time.Tick(time.Second): + t.Fatal("timeout out waiting for unregistered signal") + } + assert.True(t, http2Conn.stoppedGracefully) + + cancel() + wg.Wait() +} + func benchmarkServeHTTP(b *testing.B, test testRequest) { http2Conn, edgeConn := newTestHTTP2Connection() @@ -281,6 +338,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) { cancel() wg.Wait() } + func BenchmarkServeHTTPSimple(b *testing.B) { test := testRequest{ name: "ok", diff --git a/connection/rpc.go b/connection/rpc.go index 514d69db..74a78339 100644 --- a/connection/rpc.go +++ b/connection/rpc.go @@ -272,17 +272,36 @@ func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelSe return nil } +func (h *h2muxConnection) registerNamedTunnel( + ctx context.Context, + namedTunnel *NamedTunnelConfig, + connOptions *tunnelpogs.ConnectionOptions, +) error { + stream, err := h.newRPCStream(ctx, register) + if err != nil { + return err + } + rpcClient := h.newRPCClientFunc(ctx, stream, h.observer.log) + defer rpcClient.Close() + + if err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil { + return err + } + return nil +} + func (h *h2muxConnection) unregister(isNamedTunnel bool) { unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod) defer cancel() - stream, err := h.newRPCStream(unregisterCtx, register) + stream, err := h.newRPCStream(unregisterCtx, unregister) if err != nil { return } + defer stream.Close() if isNamedTunnel { - rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer.log) + rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log) defer rpcClient.Close() rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod) @@ -293,4 +312,6 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) { // gracePeriod is encoded in int64 using capnproto _ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds()) } + + h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection") } diff --git a/origin/supervisor.go b/origin/supervisor.go index 5fa431c8..c00ba342 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -3,6 +3,7 @@ package origin import ( "context" "errors" + "fmt" "net" "time" @@ -54,6 +55,9 @@ type Supervisor struct { reconnectCredentialManager *reconnectCredentialManager useReconnectToken bool + + reconnectCh chan ReconnectSignal + gracefulShutdownC chan struct{} } type tunnelError struct { @@ -62,11 +66,13 @@ type tunnelError struct { err error } -func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) { - var ( - edgeIPs *edgediscovery.Edge - err error - ) +func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}) (*Supervisor, error) { + cloudflaredUUID, err := uuid.NewRandom() + if err != nil { + return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) + } + + var edgeIPs *edgediscovery.Edge if len(config.EdgeAddrs) > 0 { edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs) } else { @@ -90,11 +96,16 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor log: config.Log, reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections), useReconnectToken: useReconnectToken, + reconnectCh: reconnectCh, + gracefulShutdownC: gracefulShutdownC, }, nil } -func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { - if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { +func (s *Supervisor) Run( + ctx context.Context, + connectedSignal *signal.Signal, +) error { + if err := s.initialize(ctx, connectedSignal); err != nil { return err } var tunnelsWaiting []int @@ -131,7 +142,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re case tunnelError := <-s.tunnelErrors: tunnelsActive-- if tunnelError.err != nil { - s.log.Err(tunnelError.err).Msg("supervisor: Tunnel disconnected") + s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated") tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) s.waitForNextTunnel(tunnelError.index) @@ -139,14 +150,16 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re backoffTimer = backoff.BackoffTimer() } - // Previously we'd mark the edge address as bad here, but now we'll just silently use - // another. + // Previously we'd mark the edge address as bad here, but now we'll just silently use another. + } else if tunnelsActive == 0 { + // all connected tunnels exited gracefully, no more work to do + return nil } // Backoff was set and its timer expired case <-backoffTimer: backoffTimer = nil for _, index := range tunnelsWaiting { - go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh) + go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index)) } tunnelsActive += len(tunnelsWaiting) tunnelsWaiting = nil @@ -171,14 +184,17 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re } // Returns nil if initialization succeeded, else the initialization error. -func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { - availableAddrs := int(s.edgeIPs.AvailableAddrs()) +func (s *Supervisor) initialize( + ctx context.Context, + connectedSignal *signal.Signal, +) error { + availableAddrs := s.edgeIPs.AvailableAddrs() if s.config.HAConnections > availableAddrs { s.log.Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs) s.config.HAConnections = availableAddrs } - go s.startFirstTunnel(ctx, connectedSignal, reconnectCh) + go s.startFirstTunnel(ctx, connectedSignal) select { case <-ctx.Done(): <-s.tunnelErrors @@ -190,7 +206,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig // At least one successful connection, so start the rest for i := 1; i < s.config.HAConnections; i++ { ch := signal.New(make(chan struct{})) - go s.startTunnel(ctx, i, ch, reconnectCh) + go s.startTunnel(ctx, i, ch) time.Sleep(registrationInterval) } return nil @@ -198,7 +214,10 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig // startTunnel starts the first tunnel connection. The resulting error will be sent on // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed -func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) { +func (s *Supervisor) startFirstTunnel( + ctx context.Context, + connectedSignal *signal.Signal, +) { var ( addr *net.TCPAddr err error @@ -221,7 +240,8 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign firstConnIndex, connectedSignal, s.cloudflaredUUID, - reconnectCh, + s.reconnectCh, + s.gracefulShutdownC, ) // If the first tunnel disconnects, keep restarting it. edgeErrors := 0 @@ -253,14 +273,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign firstConnIndex, connectedSignal, s.cloudflaredUUID, - reconnectCh, + s.reconnectCh, + s.gracefulShutdownC, ) } } // startTunnel starts a new tunnel connection. The resulting error will be sent on // s.tunnelErrors. -func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) { +func (s *Supervisor) startTunnel( + ctx context.Context, + index int, + connectedSignal *signal.Signal, +) { var ( addr *net.TCPAddr err error @@ -281,7 +306,8 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal uint8(index), connectedSignal, s.cloudflaredUUID, - reconnectCh, + s.reconnectCh, + s.gracefulShutdownC, ) } diff --git a/origin/tunnel.go b/origin/tunnel.go index f681dac2..f972d967 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -107,12 +107,18 @@ func (c *TunnelConfig) SupportedFeatures() []string { return features } -func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error { - s, err := NewSupervisor(config, cloudflaredID) +func StartTunnelDaemon( + ctx context.Context, + config *TunnelConfig, + connectedSignal *signal.Signal, + reconnectCh chan ReconnectSignal, + graceShutdownC chan struct{}, +) error { + s, err := NewSupervisor(config, reconnectCh, graceShutdownC) if err != nil { return err } - return s.Run(ctx, connectedSignal, reconnectCh) + return s.Run(ctx, connectedSignal) } func ServeTunnelLoop( @@ -124,6 +130,7 @@ func ServeTunnelLoop( connectedSignal *signal.Signal, cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, + gracefulShutdownC chan struct{}, ) error { haConnections.Inc() defer haConnections.Dec() @@ -158,6 +165,7 @@ func ServeTunnelLoop( cloudflaredUUID, reconnectCh, protocallFallback.protocol, + gracefulShutdownC, ) if !recoverable { return err @@ -242,6 +250,7 @@ func ServeTunnel( cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, protocol connection.Protocol, + gracefulShutdownC chan struct{}, ) (err error, recoverable bool) { // Treat panics as recoverable errors defer func() { @@ -268,7 +277,17 @@ func ServeTunnel( } if protocol == connection.HTTP2 { connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries)) - return ServeHTTP2(ctx, log, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh) + return ServeHTTP2( + ctx, + log, + config, + edgeConn, + connOptions, + connIndex, + connectedFuse, + reconnectCh, + gracefulShutdownC, + ) } return ServeH2mux( ctx, @@ -280,6 +299,7 @@ func ServeTunnel( connectedFuse, cloudflaredUUID, reconnectCh, + gracefulShutdownC, ) } @@ -293,6 +313,7 @@ func ServeH2mux( connectedFuse *connectedFuse, cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, + gracefulShutdownC chan struct{}, ) (err error, recoverable bool) { config.Log.Debug().Msgf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors @@ -302,6 +323,7 @@ func ServeH2mux( edgeConn, connIndex, config.Observer, + gracefulShutdownC, ) if err != nil { return err, recoverable @@ -312,13 +334,13 @@ func ServeH2mux( errGroup.Go(func() (err error) { if config.NamedTunnel != nil { connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries)) - return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse) + return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse) } registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) }) - errGroup.Go(listenReconnect(serveCtx, reconnectCh)) + errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)) err = errGroup.Wait() if err != nil { @@ -367,9 +389,10 @@ func ServeHTTP2( connIndex uint8, connectedFuse connection.ConnectedFuse, reconnectCh chan ReconnectSignal, + gracefulShutdownC chan struct{}, ) (err error, recoverable bool) { log.Debug().Msgf("Connecting via http2") - server := connection.NewHTTP2Connection( + h2conn := connection.NewHTTP2Connection( tlsServerConn, config.ConnectionConfig, config.NamedTunnel, @@ -377,15 +400,15 @@ func ServeHTTP2( config.Observer, connIndex, connectedFuse, + gracefulShutdownC, ) errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { - server.Serve(serveCtx) - return fmt.Errorf("connection with edge closed") + return h2conn.Serve(serveCtx) }) - errGroup.Go(listenReconnect(serveCtx, reconnectCh)) + errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)) err = errGroup.Wait() if err != nil { @@ -394,11 +417,13 @@ func ServeHTTP2( return nil, false } -func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error { +func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) func() error { return func() error { select { case reconnect := <-reconnectCh: - return &reconnect + return reconnect + case <-gracefulShutdownCh: + return nil case <-ctx.Done(): return nil }