TUN-3118: Changed graceful shutdown to immediately unregister the tunnel from the edge and keep the connection open until the edge drops it or the grace period expires

Igor Postelnik 2021-01-20 13:41:09 -06:00
parent db0562c7b8
commit d503aeaf77
10 changed files with 295 additions and 80 deletions
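
The change plumbs a graceful-shutdown signal through the tunnel code: once shutdown starts, each connection unregisters from the edge immediately and then stays open until the edge drops it or the grace period runs out. A minimal, self-contained sketch of that flow follows; every name in it is illustrative rather than the cloudflared API.

package main

import (
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"
)

func main() {
	gracePeriod := 30 * time.Second

	// Closed exactly once to start a graceful shutdown.
	graceShutdownC := make(chan struct{})
	// Stands in for the edge eventually dropping the connection.
	edgeClosedC := make(chan struct{})

	go func() {
		sigC := make(chan os.Signal, 1)
		signal.Notify(sigC, syscall.SIGINT, syscall.SIGTERM)
		<-sigC
		close(graceShutdownC)
	}()

	// Per-connection control flow.
	<-graceShutdownC
	fmt.Println("graceful shutdown: unregistering from the edge immediately")
	// ...the unregister RPC would happen here...

	// Keep the connection open for in-flight traffic until the edge
	// closes it or the grace period expires, whichever comes first.
	select {
	case <-edgeClosedC:
		fmt.Println("edge dropped the connection")
	case <-time.After(gracePeriod):
		fmt.Println("grace period expired; closing the connection")
	}
}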

View File

@@ -32,7 +32,6 @@ import (
 	"github.com/coreos/go-systemd/daemon"
 	"github.com/facebookgo/grace/gracenet"
 	"github.com/getsentry/raven-go"
-	"github.com/google/uuid"
 	"github.com/mitchellh/go-homedir"
 	"github.com/pkg/errors"
 	"github.com/rs/zerolog"
@@ -199,7 +198,7 @@ func runAdhocNamedTunnel(sc *subcommandContext, name string) error {
 // runClassicTunnel creates a "classic" non-named tunnel
 func runClassicTunnel(sc *subcommandContext) error {
-	return StartServer(sc.c, version, shutdownC, graceShutdownC, nil, sc.log, sc.isUIEnabled)
+	return StartServer(sc.c, version, nil, sc.log, sc.isUIEnabled)
 }
 func routeFromFlag(c *cli.Context) (tunnelstore.Route, bool) {
@@ -215,8 +214,6 @@ func routeFromFlag(c *cli.Context) (tunnelstore.Route, bool) {
 func StartServer(
 	c *cli.Context,
 	version string,
-	shutdownC,
-	graceShutdownC chan struct{},
 	namedTunnel *connection.NamedTunnelConfig,
 	log *zerolog.Logger,
 	isUIEnabled bool,
@@ -287,12 +284,6 @@ func StartServer(
 		go writePidFile(connectedSignal, c.String("pidfile"), log)
 	}
-	cloudflaredID, err := uuid.NewRandom()
-	if err != nil {
-		log.Err(err).Msg("Cannot generate cloudflared ID")
-		return err
-	}
 	ctx, cancel := context.WithCancel(context.Background())
 	go func() {
 		<-shutdownC
@@ -363,7 +354,7 @@ func StartServer(
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
-		errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
+		errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC)
 	}()
 	if isUIEnabled {
@@ -1040,7 +1031,7 @@ func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger)
 				continue
 			}
 		}
-		log.Info().Msgf("Sending reconnect signal %+v", reconnect)
+		log.Info().Msgf("Sending %+v", reconnect)
 		reconnectCh <- reconnect
 	default:
 		log.Info().Str(LogFieldCommand, command).Msg("Unknown command")

View File

@@ -51,7 +51,7 @@ func waitForSignalWithGraceShutdown(errC chan error,
 	select {
 	case err := <-errC:
-		logger.Info().Msgf("Initiating graceful shutdown due to %v ...", err)
+		logger.Info().Msgf("Initiating shutdown due to %v ...", err)
 		close(graceShutdownC)
 		close(shutdownC)
 		return err
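
For context, the handler above reacts to a fatal error by closing both channels at once: graceShutdownC starts graceful shutdown and shutdownC stops everything else. A hedged sketch of signaling in that style is below; the OS-signal branch and the grace-period wait are assumptions about the surrounding code, not a copy of cloudflared's implementation.

package main

import (
	"errors"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// waitForShutdown mirrors the shape of the handler above: a fatal error skips
// straight to a full stop, while an OS signal (assumed here) starts a graceful
// shutdown and allows up to gracePeriod before the hard stop.
func waitForShutdown(
	errC chan error,
	shutdownC, graceShutdownC chan struct{},
	gracePeriod time.Duration,
) error {
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(signals)

	select {
	case err := <-errC:
		// Fatal error: begin graceful shutdown and hard shutdown together.
		close(graceShutdownC)
		close(shutdownC)
		return err
	case <-signals:
		// Start graceful shutdown; connections unregister immediately...
		close(graceShutdownC)
		// ...then give them up to gracePeriod to drain before stopping.
		select {
		case err := <-errC:
			close(shutdownC)
			return err
		case <-time.After(gracePeriod):
			close(shutdownC)
		}
	case <-shutdownC:
	}
	return nil
}

func main() {
	errC := make(chan error, 1)
	shutdownC := make(chan struct{})
	graceShutdownC := make(chan struct{})

	errC <- errors.New("tunnel error")
	_ = waitForShutdown(errC, shutdownC, graceShutdownC, 30*time.Second)
}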

View File

@@ -274,8 +274,6 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
 	return StartServer(
 		sc.c,
 		version,
-		shutdownC,
-		graceShutdownC,
 		&connection.NamedTunnelConfig{Credentials: credentials},
 		sc.log,
 		sc.isUIEnabled,

View File

@@ -2,6 +2,7 @@ package connection
 import (
 	"context"
+	"io"
 	"net"
 	"net/http"
 	"time"
@@ -29,6 +30,11 @@ type h2muxConnection struct {
 	connIndex    uint8
 	observer     *Observer
+	gracefulShutdownC chan struct{}
+	stoppedGracefully bool
+
+	// newRPCClientFunc allows us to mock RPCs during testing
+	newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
 }
 type MuxerConfig struct {
@@ -57,6 +63,7 @@ func NewH2muxConnection(
 	edgeConn net.Conn,
 	connIndex uint8,
 	observer *Observer,
+	gracefulShutdownC chan struct{},
 ) (*h2muxConnection, error, bool) {
 	h := &h2muxConnection{
 		config: config,
@@ -64,6 +71,8 @@ func NewH2muxConnection(
 		connIndexStr: uint8ToString(connIndex),
 		connIndex:    connIndex,
 		observer:     observer,
+		gracefulShutdownC: gracefulShutdownC,
+		newRPCClientFunc:  newRegistrationRPCClient,
 	}
 	// Establish a muxed connection with the edge
@@ -77,21 +86,14 @@ func NewH2muxConnection(
 	return h, nil, false
 }
-func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
+func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
 	errGroup, serveCtx := errgroup.WithContext(ctx)
 	errGroup.Go(func() error {
 		return h.serveMuxer(serveCtx)
 	})
 	errGroup.Go(func() error {
-		stream, err := h.newRPCStream(serveCtx, register)
-		if err != nil {
-			return err
-		}
-		rpcClient := newRegistrationRPCClient(ctx, stream, h.observer.log)
-		defer rpcClient.Close()
-		if err = rpcClient.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
+		if err := h.registerNamedTunnel(serveCtx, namedTunnel, connOptions); err != nil {
 			return err
 		}
 		connectedFuse.Connected()
@@ -137,6 +139,10 @@ func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel
 	return errGroup.Wait()
 }
+func (h *h2muxConnection) StoppedGracefully() bool {
+	return h.stoppedGracefully
+}
 func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
 	// All routines should stop when muxer finish serving. When muxer is shutdown
 	// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
@@ -152,13 +158,21 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect
 	updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
 	for {
 		select {
+		case <-h.gracefulShutdownC:
+			if connectedFuse.IsConnected() {
+				h.unregister(isNamedTunnel)
+			}
+			h.stoppedGracefully = true
+			h.gracefulShutdownC = nil
 		case <-ctx.Done():
 			// UnregisterTunnel blocks until the RPC call returns
-			if connectedFuse.IsConnected() {
+			if !h.stoppedGracefully && connectedFuse.IsConnected() {
 				h.unregister(isNamedTunnel)
 			}
 			h.muxer.Shutdown()
 			return
 		case <-updateMetricsTickC:
 			h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
 		}
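
The controlLoop above relies on a common Go idiom: after the graceful-shutdown case fires once, the channel field is set to nil, and a receive from a nil channel blocks forever, so that case can never fire again while the loop keeps handling ctx.Done() and the metrics tick. A standalone illustration of the idiom (all names here are illustrative):

package main

import (
	"fmt"
	"time"
)

func main() {
	shutdownC := make(chan struct{})
	tick := time.Tick(10 * time.Millisecond)
	done := time.After(100 * time.Millisecond)

	go func() { close(shutdownC) }()

	handled := 0
	for {
		select {
		case <-shutdownC:
			handled++
			fmt.Println("graceful shutdown requested")
			// Disable this case: a receive from a nil channel blocks forever,
			// so the closed channel can't keep firing on every loop iteration.
			shutdownC = nil
		case <-tick:
			// keep doing periodic work while draining
		case <-done:
			fmt.Println("shutdown case fired", handled, "time(s)")
			return
		}
	}
}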

View File

@@ -11,10 +11,12 @@ import (
 	"testing"
 	"time"
-	"github.com/cloudflare/cloudflared/h2mux"
 	"github.com/gobwas/ws/wsutil"
+	"github.com/rs/zerolog"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+
+	"github.com/cloudflare/cloudflared/h2mux"
 )
 var (
@@ -32,13 +34,20 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
 	go func() {
 		edgeMuxConfig := h2mux.MuxerConfig{
 			Log: testObserver.log,
+			Handler: h2mux.MuxedStreamFunc(func(stream *h2mux.MuxedStream) error {
+				// we only expect RPC traffic in client->edge direction, provide minimal support for mocking
+				require.True(t, stream.IsRPCStream())
+				return stream.WriteHeaders([]h2mux.Header{
+					{Name: ":status", Value: "200"},
+				})
+			}),
 		}
 		edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
 		require.NoError(t, err)
 		edgeMuxChan <- edgeMux
 	}()
 	var connIndex = uint8(0)
-	h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver)
+	h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil)
 	require.NoError(t, err)
 	return h2muxConn, <-edgeMuxChan
 }
@@ -168,6 +177,55 @@ func TestServeStreamWS(t *testing.T) {
 	wg.Wait()
 }
+func TestGracefulShutdownH2Mux(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	h2muxConn, edgeMux := newH2MuxConnection(t)
+
+	shutdownC := make(chan struct{})
+	unregisteredC := make(chan struct{})
+	h2muxConn.gracefulShutdownC = shutdownC
+	h2muxConn.newRPCClientFunc = func(_ context.Context, _ io.ReadWriteCloser, _ *zerolog.Logger) NamedTunnelRPCClient {
+		return &mockNamedTunnelRPCClient{
+			registered:   nil,
+			unregistered: unregisteredC,
+		}
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(3)
+	go func() {
+		defer wg.Done()
+		_ = edgeMux.Serve(ctx)
+	}()
+	go func() {
+		defer wg.Done()
+		_ = h2muxConn.serveMuxer(ctx)
+	}()
+	go func() {
+		defer wg.Done()
+		h2muxConn.controlLoop(ctx, &mockConnectedFuse{}, true)
+	}()
+
+	time.Sleep(100 * time.Millisecond)
+	close(shutdownC)
+
+	select {
+	case <-unregisteredC:
+		break // ok
+	case <-time.Tick(time.Second):
+		assert.Fail(t, "timed out waiting for control loop to unregister")
+	}
+
+	cancel()
+	wg.Wait()
+
+	assert.True(t, h2muxConn.stoppedGracefully)
+	assert.Nil(t, h2muxConn.gracefulShutdownC)
+}
 func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
 	for _, header := range stream.Headers {
 		if header.Name == name && header.Value == val {
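
TestGracefulShutdownH2Mux drives the connection with a stubbed RPC client and then waits, with a timeout, for the control loop to report that it unregistered. The self-contained sketch below mirrors that shape with illustrative stand-ins: the goroutine plays the role of the control loop, the closed channel plays the role of the unregister RPC, and a one-shot time.After provides the timeout.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	shutdownC := make(chan struct{})
	unregisteredC := make(chan struct{})

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-shutdownC:
			close(unregisteredC) // stand-in for the unregister RPC
		case <-ctx.Done():
		}
	}()

	close(shutdownC) // trigger graceful shutdown

	select {
	case <-unregisteredC:
		fmt.Println("unregistered")
	case <-time.After(time.Second):
		fmt.Println("timed out waiting for unregister")
	}

	cancel()
	wg.Wait()
}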

View File

@@ -2,6 +2,7 @@ package connection
 import (
 	"context"
+	"fmt"
 	"io"
 	"math"
 	"net"
@@ -22,6 +23,8 @@ const (
 	controlStreamUpgrade = "control-stream"
 )
+var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
+
 type http2Connection struct {
 	conn   net.Conn
 	server *http2.Server
@@ -35,6 +38,8 @@ type http2Connection struct {
 	// newRPCClientFunc allows us to mock RPCs during testing
 	newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
 	connectedFuse     ConnectedFuse
+	gracefulShutdownC chan struct{}
+	stoppedGracefully bool
 }
 func NewHTTP2Connection(
@@ -45,6 +50,7 @@ func NewHTTP2Connection(
 	observer *Observer,
 	connIndex uint8,
 	connectedFuse ConnectedFuse,
+	gracefulShutdownC chan struct{},
 ) *http2Connection {
 	return &http2Connection{
 		conn: conn,
@@ -60,10 +66,11 @@ func NewHTTP2Connection(
 		wg:                &sync.WaitGroup{},
 		newRPCClientFunc:  newRegistrationRPCClient,
 		connectedFuse:     connectedFuse,
+		gracefulShutdownC: gracefulShutdownC,
 	}
 }
-func (c *http2Connection) Serve(ctx context.Context) {
+func (c *http2Connection) Serve(ctx context.Context) error {
 	go func() {
 		<-ctx.Done()
 		c.close()
@@ -72,6 +79,11 @@ func (c *http2Connection) Serve(ctx context.Context) {
 		Context: ctx,
 		Handler: c,
 	})
+	if !c.stoppedGracefully {
+		return errEdgeConnectionClosed
+	}
+	return nil
 }
 func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -106,6 +118,10 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 }
+func (c *http2Connection) StoppedGracefully() bool {
+	return c.stoppedGracefully
+}
 func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
 	rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
 	defer rpcClient.Close()
@@ -115,8 +131,16 @@ func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
 	}
 	c.connectedFuse.Connected()
-	<-ctx.Done()
+	// wait for connection termination or start of graceful shutdown
+	select {
+	case <-ctx.Done():
+		break
+	case <-c.gracefulShutdownC:
+		c.stoppedGracefully = true
+	}
 	rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
+	c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
 	return nil
 }
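
Serve now reports how the connection ended: errEdgeConnectionClosed when the edge dropped it unexpectedly, nil when the stop was graceful, so the caller can decide whether to reconnect. A compact illustration of that sentinel-error pattern, with placeholder names:

package main

import (
	"errors"
	"fmt"
)

var errEdgeConnectionClosed = errors.New("connection with edge closed")

// serve is a stand-in for http2Connection.Serve: it blocks until the
// connection ends and reports whether that end was a graceful stop.
func serve(stoppedGracefully bool) error {
	if !stoppedGracefully {
		return errEdgeConnectionClosed
	}
	return nil
}

func main() {
	for _, graceful := range []bool{false, true} {
		err := serve(graceful)
		switch {
		case errors.Is(err, errEdgeConnectionClosed):
			fmt.Println("edge closed the connection unexpectedly; reconnect")
		case err == nil:
			fmt.Println("graceful stop; do not reconnect")
		}
	}
}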

View File

@@ -12,6 +12,8 @@ import (
 	"testing"
 	"time"
+	"github.com/stretchr/testify/assert"
+
 	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@@ -36,6 +38,7 @@ func newTestHTTP2Connection() (*http2Connection, net.Conn) {
 		testObserver,
 		connIndex,
 		mockConnectedFuse{},
+		nil,
 	), edgeConn
 }
@@ -241,10 +244,64 @@ func TestServeControlStream(t *testing.T) {
 	<-rpcClientFactory.registered
 	cancel()
 	<-rpcClientFactory.unregistered
+	assert.False(t, http2Conn.stoppedGracefully)
 	wg.Wait()
 }
+func TestGracefulShutdownHTTP2(t *testing.T) {
+	http2Conn, edgeConn := newTestHTTP2Connection()
+	rpcClientFactory := mockRPCClientFactory{
+		registered:   make(chan struct{}),
+		unregistered: make(chan struct{}),
+	}
+	http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
+	http2Conn.gracefulShutdownC = make(chan struct{})
+
+	ctx, cancel := context.WithCancel(context.Background())
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		http2Conn.Serve(ctx)
+	}()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
+	require.NoError(t, err)
+	req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
+
+	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
+	require.NoError(t, err)
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		_, _ = edgeHTTP2Conn.RoundTrip(req)
+	}()
+
+	select {
+	case <-rpcClientFactory.registered:
+		break //ok
+	case <-time.Tick(time.Second):
+		t.Fatal("timeout out waiting for registration")
+	}
+
+	// signal graceful shutdown
+	close(http2Conn.gracefulShutdownC)
+
+	select {
+	case <-rpcClientFactory.unregistered:
+		break //ok
+	case <-time.Tick(time.Second):
+		t.Fatal("timeout out waiting for unregistered signal")
+	}
+	assert.True(t, http2Conn.stoppedGracefully)
+
+	cancel()
+	wg.Wait()
+}
 func benchmarkServeHTTP(b *testing.B, test testRequest) {
 	http2Conn, edgeConn := newTestHTTP2Connection()
@@ -281,6 +338,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
 	cancel()
 	wg.Wait()
 }
+
 func BenchmarkServeHTTPSimple(b *testing.B) {
 	test := testRequest{
 		name: "ok",
View File

@@ -272,17 +272,36 @@ func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelSe
 	return nil
 }
+func (h *h2muxConnection) registerNamedTunnel(
+	ctx context.Context,
+	namedTunnel *NamedTunnelConfig,
+	connOptions *tunnelpogs.ConnectionOptions,
+) error {
+	stream, err := h.newRPCStream(ctx, register)
+	if err != nil {
+		return err
+	}
+	rpcClient := h.newRPCClientFunc(ctx, stream, h.observer.log)
+	defer rpcClient.Close()
+
+	if err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
+		return err
+	}
+	return nil
+}
+
 func (h *h2muxConnection) unregister(isNamedTunnel bool) {
 	unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
 	defer cancel()
-	stream, err := h.newRPCStream(unregisterCtx, register)
+	stream, err := h.newRPCStream(unregisterCtx, unregister)
 	if err != nil {
 		return
 	}
+	defer stream.Close()
 	if isNamedTunnel {
-		rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer.log)
+		rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log)
 		defer rpcClient.Close()
 		rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
@@ -293,4 +312,6 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) {
 		// gracePeriod is encoded in int64 using capnproto
 		_ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
 	}
+
+	h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection")
 }
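
unregister bounds its shutdown RPCs with a context whose deadline is the grace period, so an unresponsive edge cannot stall shutdown indefinitely. A minimal sketch of that bounding; the RPC here is a placeholder that never answers:

package main

import (
	"context"
	"fmt"
	"time"
)

// gracefulShutdownRPC stands in for rpcClient.GracefulShutdown / UnregisterTunnel.
func gracefulShutdownRPC(ctx context.Context) error {
	select {
	case <-time.After(5 * time.Second): // pretend the edge never answers
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	gracePeriod := 100 * time.Millisecond

	// The unregister call may not outlive the grace period.
	ctx, cancel := context.WithTimeout(context.Background(), gracePeriod)
	defer cancel()

	if err := gracefulShutdownRPC(ctx); err != nil {
		fmt.Println("unregister gave up:", err) // context deadline exceeded
	}
}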

View File

@@ -3,6 +3,7 @@ package origin
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net"
 	"time"
@@ -54,6 +55,9 @@ type Supervisor struct {
 	reconnectCredentialManager *reconnectCredentialManager
 	useReconnectToken          bool
+
+	reconnectCh       chan ReconnectSignal
+	gracefulShutdownC chan struct{}
 }
 type tunnelError struct {
@@ -62,11 +66,13 @@ type tunnelError struct {
 	err   error
 }
-func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
-	var (
-		edgeIPs *edgediscovery.Edge
-		err     error
-	)
+func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}) (*Supervisor, error) {
+	cloudflaredUUID, err := uuid.NewRandom()
+	if err != nil {
+		return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
+	}
+
+	var edgeIPs *edgediscovery.Edge
 	if len(config.EdgeAddrs) > 0 {
 		edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
 	} else {
@@ -90,11 +96,16 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
 		log:                        config.Log,
 		reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
 		useReconnectToken:          useReconnectToken,
+		reconnectCh:                reconnectCh,
+		gracefulShutdownC:          gracefulShutdownC,
 	}, nil
 }
-func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
-	if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
+func (s *Supervisor) Run(
+	ctx context.Context,
+	connectedSignal *signal.Signal,
+) error {
+	if err := s.initialize(ctx, connectedSignal); err != nil {
 		return err
 	}
 	var tunnelsWaiting []int
@@ -131,7 +142,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
 		case tunnelError := <-s.tunnelErrors:
 			tunnelsActive--
 			if tunnelError.err != nil {
-				s.log.Err(tunnelError.err).Msg("supervisor: Tunnel disconnected")
+				s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
 				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
 				s.waitForNextTunnel(tunnelError.index)
@@ -139,14 +150,16 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
 					backoffTimer = backoff.BackoffTimer()
 				}
-				// Previously we'd mark the edge address as bad here, but now we'll just silently use
-				// another.
+				// Previously we'd mark the edge address as bad here, but now we'll just silently use another.
+			} else if tunnelsActive == 0 {
+				// all connected tunnels exited gracefully, no more work to do
+				return nil
 			}
 		// Backoff was set and its timer expired
 		case <-backoffTimer:
 			backoffTimer = nil
 			for _, index := range tunnelsWaiting {
-				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
+				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
 			}
 			tunnelsActive += len(tunnelsWaiting)
 			tunnelsWaiting = nil
@@ -171,14 +184,17 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
 }
 // Returns nil if initialization succeeded, else the initialization error.
-func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
-	availableAddrs := int(s.edgeIPs.AvailableAddrs())
+func (s *Supervisor) initialize(
+	ctx context.Context,
+	connectedSignal *signal.Signal,
+) error {
+	availableAddrs := s.edgeIPs.AvailableAddrs()
 	if s.config.HAConnections > availableAddrs {
 		s.log.Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
 		s.config.HAConnections = availableAddrs
 	}
-	go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
+	go s.startFirstTunnel(ctx, connectedSignal)
 	select {
 	case <-ctx.Done():
 		<-s.tunnelErrors
@@ -190,7 +206,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
 	// At least one successful connection, so start the rest
 	for i := 1; i < s.config.HAConnections; i++ {
 		ch := signal.New(make(chan struct{}))
-		go s.startTunnel(ctx, i, ch, reconnectCh)
+		go s.startTunnel(ctx, i, ch)
 		time.Sleep(registrationInterval)
 	}
 	return nil
@@ -198,7 +214,10 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
 // startTunnel starts the first tunnel connection. The resulting error will be sent on
 // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
-func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
+func (s *Supervisor) startFirstTunnel(
+	ctx context.Context,
+	connectedSignal *signal.Signal,
+) {
 	var (
 		addr *net.TCPAddr
 		err  error
@@ -221,7 +240,8 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
 		firstConnIndex,
 		connectedSignal,
 		s.cloudflaredUUID,
-		reconnectCh,
+		s.reconnectCh,
+		s.gracefulShutdownC,
 	)
 	// If the first tunnel disconnects, keep restarting it.
 	edgeErrors := 0
@@ -253,14 +273,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
 			firstConnIndex,
 			connectedSignal,
 			s.cloudflaredUUID,
-			reconnectCh,
+			s.reconnectCh,
+			s.gracefulShutdownC,
 		)
 	}
 }
 // startTunnel starts a new tunnel connection. The resulting error will be sent on
 // s.tunnelErrors.
-func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
+func (s *Supervisor) startTunnel(
+	ctx context.Context,
+	index int,
+	connectedSignal *signal.Signal,
+) {
 	var (
 		addr *net.TCPAddr
 		err  error
@@ -281,7 +306,8 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
 		uint8(index),
 		connectedSignal,
 		s.cloudflaredUUID,
-		reconnectCh,
+		s.reconnectCh,
+		s.gracefulShutdownC,
 	)
 }
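
Because a connection can now exit with a nil error after a graceful stop, the supervisor counts active tunnels down and returns once the last one has exited cleanly, instead of treating every exit as a failure to retry. A stripped-down sketch of that accounting (retry scheduling is elided):

package main

import "fmt"

type tunnelError struct {
	index int
	err   error
}

// run drains tunnel results and returns nil once every tunnel has exited
// without an error, i.e. every connection stopped gracefully.
func run(tunnelErrors <-chan tunnelError, tunnelsActive int) error {
	for event := range tunnelErrors {
		tunnelsActive--
		if event.err != nil {
			fmt.Printf("connection %d terminated: %v; would schedule a retry\n", event.index, event.err)
			continue
		}
		if tunnelsActive == 0 {
			// All connected tunnels exited gracefully; no more work to do.
			return nil
		}
	}
	return nil
}

func main() {
	errC := make(chan tunnelError, 4)
	for i := 0; i < 4; i++ {
		errC <- tunnelError{index: i, err: nil} // all graceful
	}
	close(errC)
	fmt.Println("supervisor returned:", run(errC, 4))
}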

View File

@@ -107,12 +107,18 @@ func (c *TunnelConfig) SupportedFeatures() []string {
 	return features
 }
-func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
-	s, err := NewSupervisor(config, cloudflaredID)
+func StartTunnelDaemon(
+	ctx context.Context,
+	config *TunnelConfig,
+	connectedSignal *signal.Signal,
+	reconnectCh chan ReconnectSignal,
+	graceShutdownC chan struct{},
+) error {
+	s, err := NewSupervisor(config, reconnectCh, graceShutdownC)
 	if err != nil {
 		return err
 	}
-	return s.Run(ctx, connectedSignal, reconnectCh)
+	return s.Run(ctx, connectedSignal)
 }
 func ServeTunnelLoop(
@@ -124,6 +130,7 @@ func ServeTunnelLoop(
 	connectedSignal *signal.Signal,
 	cloudflaredUUID uuid.UUID,
 	reconnectCh chan ReconnectSignal,
+	gracefulShutdownC chan struct{},
 ) error {
 	haConnections.Inc()
 	defer haConnections.Dec()
@@ -158,6 +165,7 @@ func ServeTunnelLoop(
 			cloudflaredUUID,
 			reconnectCh,
 			protocallFallback.protocol,
+			gracefulShutdownC,
 		)
 		if !recoverable {
 			return err
@@ -242,6 +250,7 @@ func ServeTunnel(
 	cloudflaredUUID uuid.UUID,
 	reconnectCh chan ReconnectSignal,
 	protocol connection.Protocol,
+	gracefulShutdownC chan struct{},
 ) (err error, recoverable bool) {
 	// Treat panics as recoverable errors
 	defer func() {
@@ -268,7 +277,17 @@ func ServeTunnel(
 	}
 	if protocol == connection.HTTP2 {
 		connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
-		return ServeHTTP2(ctx, log, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
+		return ServeHTTP2(
+			ctx,
+			log,
+			config,
+			edgeConn,
+			connOptions,
+			connIndex,
+			connectedFuse,
+			reconnectCh,
+			gracefulShutdownC,
+		)
 	}
 	return ServeH2mux(
 		ctx,
@@ -280,6 +299,7 @@ func ServeTunnel(
 		connectedFuse,
 		cloudflaredUUID,
 		reconnectCh,
+		gracefulShutdownC,
 	)
 }
@@ -293,6 +313,7 @@ func ServeH2mux(
 	connectedFuse *connectedFuse,
 	cloudflaredUUID uuid.UUID,
 	reconnectCh chan ReconnectSignal,
+	gracefulShutdownC chan struct{},
 ) (err error, recoverable bool) {
 	config.Log.Debug().Msgf("Connecting via h2mux")
 	// Returns error from parsing the origin URL or handshake errors
@@ -302,6 +323,7 @@ func ServeH2mux(
 		edgeConn,
 		connIndex,
 		config.Observer,
+		gracefulShutdownC,
 	)
 	if err != nil {
 		return err, recoverable
@@ -312,13 +334,13 @@ func ServeH2mux(
 	errGroup.Go(func() (err error) {
 		if config.NamedTunnel != nil {
 			connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
-			return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
+			return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
 		}
 		registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
 		return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
 	})
-	errGroup.Go(listenReconnect(serveCtx, reconnectCh))
+	errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC))
 	err = errGroup.Wait()
 	if err != nil {
@@ -367,9 +389,10 @@ func ServeHTTP2(
 	connIndex uint8,
 	connectedFuse connection.ConnectedFuse,
 	reconnectCh chan ReconnectSignal,
+	gracefulShutdownC chan struct{},
 ) (err error, recoverable bool) {
 	log.Debug().Msgf("Connecting via http2")
-	server := connection.NewHTTP2Connection(
+	h2conn := connection.NewHTTP2Connection(
 		tlsServerConn,
 		config.ConnectionConfig,
 		config.NamedTunnel,
@@ -377,15 +400,15 @@ func ServeHTTP2(
 		config.Observer,
 		connIndex,
 		connectedFuse,
+		gracefulShutdownC,
 	)
 	errGroup, serveCtx := errgroup.WithContext(ctx)
 	errGroup.Go(func() error {
-		server.Serve(serveCtx)
-		return fmt.Errorf("connection with edge closed")
+		return h2conn.Serve(serveCtx)
 	})
-	errGroup.Go(listenReconnect(serveCtx, reconnectCh))
+	errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC))
 	err = errGroup.Wait()
 	if err != nil {
@@ -394,11 +417,13 @@ func ServeHTTP2(
 	return nil, false
 }
-func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error {
+func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) func() error {
 	return func() error {
 		select {
 		case reconnect := <-reconnectCh:
-			return &reconnect
+			return reconnect
+		case <-gracefulShutdownCh:
+			return nil
 		case <-ctx.Done():
 			return nil
 		}
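
listenReconnect now resolves to one of three outcomes for the errgroup it runs in: a reconnect signal is returned as an error so the connection is torn down and retried, while graceful shutdown and context cancellation both return nil so the group can unwind without a retry. A self-contained sketch of the same select, with a stand-in for ReconnectSignal (which satisfies error, as the `return reconnect` above shows):

package main

import (
	"context"
	"fmt"
)

// reconnectSignal stands in for origin.ReconnectSignal; it implements error
// so it can propagate through an errgroup and force a reconnect.
type reconnectSignal struct{}

func (reconnectSignal) Error() string { return "reconnect signal" }

func listenReconnect(ctx context.Context, reconnectCh <-chan reconnectSignal, gracefulShutdownC <-chan struct{}) func() error {
	return func() error {
		select {
		case reconnect := <-reconnectCh:
			return reconnect // an error: tear down and reconnect
		case <-gracefulShutdownC:
			return nil // graceful: let the errgroup exit cleanly
		case <-ctx.Done():
			return nil
		}
	}
}

func main() {
	reconnectCh := make(chan reconnectSignal, 1)
	gracefulShutdownC := make(chan struct{})
	reconnectCh <- reconnectSignal{}

	err := listenReconnect(context.Background(), reconnectCh, gracefulShutdownC)()
	fmt.Println("errgroup result:", err) // "reconnect signal"
}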