TUN-3118: Changed graceful shutdown to immediately unregister tunnel from the edge, keep the connection open until the edge drops it or grace period expires
This commit is contained in:
parent
db0562c7b8
commit
d503aeaf77
|
@ -32,7 +32,6 @@ import (
|
||||||
"github.com/coreos/go-systemd/daemon"
|
"github.com/coreos/go-systemd/daemon"
|
||||||
"github.com/facebookgo/grace/gracenet"
|
"github.com/facebookgo/grace/gracenet"
|
||||||
"github.com/getsentry/raven-go"
|
"github.com/getsentry/raven-go"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/mitchellh/go-homedir"
|
"github.com/mitchellh/go-homedir"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -199,7 +198,7 @@ func runAdhocNamedTunnel(sc *subcommandContext, name string) error {
|
||||||
|
|
||||||
// runClassicTunnel creates a "classic" non-named tunnel
|
// runClassicTunnel creates a "classic" non-named tunnel
|
||||||
func runClassicTunnel(sc *subcommandContext) error {
|
func runClassicTunnel(sc *subcommandContext) error {
|
||||||
return StartServer(sc.c, version, shutdownC, graceShutdownC, nil, sc.log, sc.isUIEnabled)
|
return StartServer(sc.c, version, nil, sc.log, sc.isUIEnabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func routeFromFlag(c *cli.Context) (tunnelstore.Route, bool) {
|
func routeFromFlag(c *cli.Context) (tunnelstore.Route, bool) {
|
||||||
|
@ -215,8 +214,6 @@ func routeFromFlag(c *cli.Context) (tunnelstore.Route, bool) {
|
||||||
func StartServer(
|
func StartServer(
|
||||||
c *cli.Context,
|
c *cli.Context,
|
||||||
version string,
|
version string,
|
||||||
shutdownC,
|
|
||||||
graceShutdownC chan struct{},
|
|
||||||
namedTunnel *connection.NamedTunnelConfig,
|
namedTunnel *connection.NamedTunnelConfig,
|
||||||
log *zerolog.Logger,
|
log *zerolog.Logger,
|
||||||
isUIEnabled bool,
|
isUIEnabled bool,
|
||||||
|
@ -287,12 +284,6 @@ func StartServer(
|
||||||
go writePidFile(connectedSignal, c.String("pidfile"), log)
|
go writePidFile(connectedSignal, c.String("pidfile"), log)
|
||||||
}
|
}
|
||||||
|
|
||||||
cloudflaredID, err := uuid.NewRandom()
|
|
||||||
if err != nil {
|
|
||||||
log.Err(err).Msg("Cannot generate cloudflared ID")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
go func() {
|
go func() {
|
||||||
<-shutdownC
|
<-shutdownC
|
||||||
|
@ -363,7 +354,7 @@ func StartServer(
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
|
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if isUIEnabled {
|
if isUIEnabled {
|
||||||
|
@ -1040,7 +1031,7 @@ func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Info().Msgf("Sending reconnect signal %+v", reconnect)
|
log.Info().Msgf("Sending %+v", reconnect)
|
||||||
reconnectCh <- reconnect
|
reconnectCh <- reconnect
|
||||||
default:
|
default:
|
||||||
log.Info().Str(LogFieldCommand, command).Msg("Unknown command")
|
log.Info().Str(LogFieldCommand, command).Msg("Unknown command")
|
||||||
|
|
|
@ -51,7 +51,7 @@ func waitForSignalWithGraceShutdown(errC chan error,
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errC:
|
case err := <-errC:
|
||||||
logger.Info().Msgf("Initiating graceful shutdown due to %v ...", err)
|
logger.Info().Msgf("Initiating shutdown due to %v ...", err)
|
||||||
close(graceShutdownC)
|
close(graceShutdownC)
|
||||||
close(shutdownC)
|
close(shutdownC)
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -274,8 +274,6 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
|
||||||
return StartServer(
|
return StartServer(
|
||||||
sc.c,
|
sc.c,
|
||||||
version,
|
version,
|
||||||
shutdownC,
|
|
||||||
graceShutdownC,
|
|
||||||
&connection.NamedTunnelConfig{Credentials: credentials},
|
&connection.NamedTunnelConfig{Credentials: credentials},
|
||||||
sc.log,
|
sc.log,
|
||||||
sc.isUIEnabled,
|
sc.isUIEnabled,
|
||||||
|
|
|
@ -2,6 +2,7 @@ package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -29,6 +30,11 @@ type h2muxConnection struct {
|
||||||
connIndex uint8
|
connIndex uint8
|
||||||
|
|
||||||
observer *Observer
|
observer *Observer
|
||||||
|
gracefulShutdownC chan struct{}
|
||||||
|
stoppedGracefully bool
|
||||||
|
|
||||||
|
// newRPCClientFunc allows us to mock RPCs during testing
|
||||||
|
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
||||||
}
|
}
|
||||||
|
|
||||||
type MuxerConfig struct {
|
type MuxerConfig struct {
|
||||||
|
@ -57,6 +63,7 @@ func NewH2muxConnection(
|
||||||
edgeConn net.Conn,
|
edgeConn net.Conn,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
observer *Observer,
|
observer *Observer,
|
||||||
|
gracefulShutdownC chan struct{},
|
||||||
) (*h2muxConnection, error, bool) {
|
) (*h2muxConnection, error, bool) {
|
||||||
h := &h2muxConnection{
|
h := &h2muxConnection{
|
||||||
config: config,
|
config: config,
|
||||||
|
@ -64,6 +71,8 @@ func NewH2muxConnection(
|
||||||
connIndexStr: uint8ToString(connIndex),
|
connIndexStr: uint8ToString(connIndex),
|
||||||
connIndex: connIndex,
|
connIndex: connIndex,
|
||||||
observer: observer,
|
observer: observer,
|
||||||
|
gracefulShutdownC: gracefulShutdownC,
|
||||||
|
newRPCClientFunc: newRegistrationRPCClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Establish a muxed connection with the edge
|
// Establish a muxed connection with the edge
|
||||||
|
@ -77,21 +86,14 @@ func NewH2muxConnection(
|
||||||
return h, nil, false
|
return h, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
|
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
return h.serveMuxer(serveCtx)
|
return h.serveMuxer(serveCtx)
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
stream, err := h.newRPCStream(serveCtx, register)
|
if err := h.registerNamedTunnel(serveCtx, namedTunnel, connOptions); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer.log)
|
|
||||||
defer rpcClient.Close()
|
|
||||||
|
|
||||||
if err = rpcClient.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
connectedFuse.Connected()
|
connectedFuse.Connected()
|
||||||
|
@ -137,6 +139,10 @@ func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel
|
||||||
return errGroup.Wait()
|
return errGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) StoppedGracefully() bool {
|
||||||
|
return h.stoppedGracefully
|
||||||
|
}
|
||||||
|
|
||||||
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
|
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
|
||||||
// All routines should stop when muxer finish serving. When muxer is shutdown
|
// All routines should stop when muxer finish serving. When muxer is shutdown
|
||||||
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
||||||
|
@ -152,13 +158,21 @@ func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse Connect
|
||||||
updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
|
updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-h.gracefulShutdownC:
|
||||||
|
if connectedFuse.IsConnected() {
|
||||||
|
h.unregister(isNamedTunnel)
|
||||||
|
}
|
||||||
|
h.stoppedGracefully = true
|
||||||
|
h.gracefulShutdownC = nil
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// UnregisterTunnel blocks until the RPC call returns
|
// UnregisterTunnel blocks until the RPC call returns
|
||||||
if connectedFuse.IsConnected() {
|
if !h.stoppedGracefully && connectedFuse.IsConnected() {
|
||||||
h.unregister(isNamedTunnel)
|
h.unregister(isNamedTunnel)
|
||||||
}
|
}
|
||||||
h.muxer.Shutdown()
|
h.muxer.Shutdown()
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-updateMetricsTickC:
|
case <-updateMetricsTickC:
|
||||||
h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
|
h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,10 +11,12 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -32,13 +34,20 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
|
||||||
go func() {
|
go func() {
|
||||||
edgeMuxConfig := h2mux.MuxerConfig{
|
edgeMuxConfig := h2mux.MuxerConfig{
|
||||||
Log: testObserver.log,
|
Log: testObserver.log,
|
||||||
|
Handler: h2mux.MuxedStreamFunc(func(stream *h2mux.MuxedStream) error {
|
||||||
|
// we only expect RPC traffic in client->edge direction, provide minimal support for mocking
|
||||||
|
require.True(t, stream.IsRPCStream())
|
||||||
|
return stream.WriteHeaders([]h2mux.Header{
|
||||||
|
{Name: ":status", Value: "200"},
|
||||||
|
})
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
|
edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
edgeMuxChan <- edgeMux
|
edgeMuxChan <- edgeMux
|
||||||
}()
|
}()
|
||||||
var connIndex = uint8(0)
|
var connIndex = uint8(0)
|
||||||
h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver)
|
h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return h2muxConn, <-edgeMuxChan
|
return h2muxConn, <-edgeMuxChan
|
||||||
}
|
}
|
||||||
|
@ -168,6 +177,55 @@ func TestServeStreamWS(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGracefulShutdownH2Mux(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
h2muxConn, edgeMux := newH2MuxConnection(t)
|
||||||
|
|
||||||
|
shutdownC := make(chan struct{})
|
||||||
|
unregisteredC := make(chan struct{})
|
||||||
|
h2muxConn.gracefulShutdownC = shutdownC
|
||||||
|
h2muxConn.newRPCClientFunc = func(_ context.Context, _ io.ReadWriteCloser, _ *zerolog.Logger) NamedTunnelRPCClient {
|
||||||
|
return &mockNamedTunnelRPCClient{
|
||||||
|
registered: nil,
|
||||||
|
unregistered: unregisteredC,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(3)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = edgeMux.Serve(ctx)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = h2muxConn.serveMuxer(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
h2muxConn.controlLoop(ctx, &mockConnectedFuse{}, true)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
close(shutdownC)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-unregisteredC:
|
||||||
|
break // ok
|
||||||
|
case <-time.Tick(time.Second):
|
||||||
|
assert.Fail(t, "timed out waiting for control loop to unregister")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
assert.True(t, h2muxConn.stoppedGracefully)
|
||||||
|
assert.Nil(t, h2muxConn.gracefulShutdownC)
|
||||||
|
}
|
||||||
|
|
||||||
func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
|
func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
|
||||||
for _, header := range stream.Headers {
|
for _, header := range stream.Headers {
|
||||||
if header.Name == name && header.Value == val {
|
if header.Name == name && header.Value == val {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
@ -22,6 +23,8 @@ const (
|
||||||
controlStreamUpgrade = "control-stream"
|
controlStreamUpgrade = "control-stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
|
||||||
|
|
||||||
type http2Connection struct {
|
type http2Connection struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
server *http2.Server
|
server *http2.Server
|
||||||
|
@ -35,6 +38,8 @@ type http2Connection struct {
|
||||||
// newRPCClientFunc allows us to mock RPCs during testing
|
// newRPCClientFunc allows us to mock RPCs during testing
|
||||||
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
||||||
connectedFuse ConnectedFuse
|
connectedFuse ConnectedFuse
|
||||||
|
gracefulShutdownC chan struct{}
|
||||||
|
stoppedGracefully bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTP2Connection(
|
func NewHTTP2Connection(
|
||||||
|
@ -45,6 +50,7 @@ func NewHTTP2Connection(
|
||||||
observer *Observer,
|
observer *Observer,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
connectedFuse ConnectedFuse,
|
connectedFuse ConnectedFuse,
|
||||||
|
gracefulShutdownC chan struct{},
|
||||||
) *http2Connection {
|
) *http2Connection {
|
||||||
return &http2Connection{
|
return &http2Connection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
@ -60,10 +66,11 @@ func NewHTTP2Connection(
|
||||||
wg: &sync.WaitGroup{},
|
wg: &sync.WaitGroup{},
|
||||||
newRPCClientFunc: newRegistrationRPCClient,
|
newRPCClientFunc: newRegistrationRPCClient,
|
||||||
connectedFuse: connectedFuse,
|
connectedFuse: connectedFuse,
|
||||||
|
gracefulShutdownC: gracefulShutdownC,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *http2Connection) Serve(ctx context.Context) {
|
func (c *http2Connection) Serve(ctx context.Context) error {
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
c.close()
|
c.close()
|
||||||
|
@ -72,6 +79,11 @@ func (c *http2Connection) Serve(ctx context.Context) {
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
Handler: c,
|
Handler: c,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if !c.stoppedGracefully {
|
||||||
|
return errEdgeConnectionClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -106,6 +118,10 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *http2Connection) StoppedGracefully() bool {
|
||||||
|
return c.stoppedGracefully
|
||||||
|
}
|
||||||
|
|
||||||
func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
||||||
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
|
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
|
||||||
defer rpcClient.Close()
|
defer rpcClient.Close()
|
||||||
|
@ -115,8 +131,16 @@ func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *ht
|
||||||
}
|
}
|
||||||
c.connectedFuse.Connected()
|
c.connectedFuse.Connected()
|
||||||
|
|
||||||
<-ctx.Done()
|
// wait for connection termination or start of graceful shutdown
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
break
|
||||||
|
case <-c.gracefulShutdownC:
|
||||||
|
c.stoppedGracefully = true
|
||||||
|
}
|
||||||
|
|
||||||
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
|
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
|
||||||
|
c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
|
||||||
|
@ -36,6 +38,7 @@ func newTestHTTP2Connection() (*http2Connection, net.Conn) {
|
||||||
testObserver,
|
testObserver,
|
||||||
connIndex,
|
connIndex,
|
||||||
mockConnectedFuse{},
|
mockConnectedFuse{},
|
||||||
|
nil,
|
||||||
), edgeConn
|
), edgeConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,10 +244,64 @@ func TestServeControlStream(t *testing.T) {
|
||||||
<-rpcClientFactory.registered
|
<-rpcClientFactory.registered
|
||||||
cancel()
|
cancel()
|
||||||
<-rpcClientFactory.unregistered
|
<-rpcClientFactory.unregistered
|
||||||
|
assert.False(t, http2Conn.stoppedGracefully)
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGracefulShutdownHTTP2(t *testing.T) {
|
||||||
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
rpcClientFactory := mockRPCClientFactory{
|
||||||
|
registered: make(chan struct{}),
|
||||||
|
unregistered: make(chan struct{}),
|
||||||
|
}
|
||||||
|
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
|
||||||
|
http2Conn.gracefulShutdownC = make(chan struct{})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
http2Conn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
||||||
|
|
||||||
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = edgeHTTP2Conn.RoundTrip(req)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-rpcClientFactory.registered:
|
||||||
|
break //ok
|
||||||
|
case <-time.Tick(time.Second):
|
||||||
|
t.Fatal("timeout out waiting for registration")
|
||||||
|
}
|
||||||
|
|
||||||
|
// signal graceful shutdown
|
||||||
|
close(http2Conn.gracefulShutdownC)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-rpcClientFactory.unregistered:
|
||||||
|
break //ok
|
||||||
|
case <-time.Tick(time.Second):
|
||||||
|
t.Fatal("timeout out waiting for unregistered signal")
|
||||||
|
}
|
||||||
|
assert.True(t, http2Conn.stoppedGracefully)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
||||||
http2Conn, edgeConn := newTestHTTP2Connection()
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
@ -281,6 +338,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
||||||
cancel()
|
cancel()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkServeHTTPSimple(b *testing.B) {
|
func BenchmarkServeHTTPSimple(b *testing.B) {
|
||||||
test := testRequest{
|
test := testRequest{
|
||||||
name: "ok",
|
name: "ok",
|
||||||
|
|
|
@ -272,17 +272,36 @@ func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelSe
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) registerNamedTunnel(
|
||||||
|
ctx context.Context,
|
||||||
|
namedTunnel *NamedTunnelConfig,
|
||||||
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
|
) error {
|
||||||
|
stream, err := h.newRPCStream(ctx, register)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rpcClient := h.newRPCClientFunc(ctx, stream, h.observer.log)
|
||||||
|
defer rpcClient.Close()
|
||||||
|
|
||||||
|
if err = rpcClient.RegisterConnection(ctx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *h2muxConnection) unregister(isNamedTunnel bool) {
|
func (h *h2muxConnection) unregister(isNamedTunnel bool) {
|
||||||
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
|
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
stream, err := h.newRPCStream(unregisterCtx, register)
|
stream, err := h.newRPCStream(unregisterCtx, unregister)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
if isNamedTunnel {
|
if isNamedTunnel {
|
||||||
rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer.log)
|
rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log)
|
||||||
defer rpcClient.Close()
|
defer rpcClient.Close()
|
||||||
|
|
||||||
rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
|
rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
|
||||||
|
@ -293,4 +312,6 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) {
|
||||||
// gracePeriod is encoded in int64 using capnproto
|
// gracePeriod is encoded in int64 using capnproto
|
||||||
_ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
|
_ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection")
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package origin
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -54,6 +55,9 @@ type Supervisor struct {
|
||||||
|
|
||||||
reconnectCredentialManager *reconnectCredentialManager
|
reconnectCredentialManager *reconnectCredentialManager
|
||||||
useReconnectToken bool
|
useReconnectToken bool
|
||||||
|
|
||||||
|
reconnectCh chan ReconnectSignal
|
||||||
|
gracefulShutdownC chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type tunnelError struct {
|
type tunnelError struct {
|
||||||
|
@ -62,11 +66,13 @@ type tunnelError struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
|
func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}) (*Supervisor, error) {
|
||||||
var (
|
cloudflaredUUID, err := uuid.NewRandom()
|
||||||
edgeIPs *edgediscovery.Edge
|
if err != nil {
|
||||||
err error
|
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
|
||||||
)
|
}
|
||||||
|
|
||||||
|
var edgeIPs *edgediscovery.Edge
|
||||||
if len(config.EdgeAddrs) > 0 {
|
if len(config.EdgeAddrs) > 0 {
|
||||||
edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
|
edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
|
||||||
} else {
|
} else {
|
||||||
|
@ -90,11 +96,16 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
|
||||||
log: config.Log,
|
log: config.Log,
|
||||||
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
|
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
|
||||||
useReconnectToken: useReconnectToken,
|
useReconnectToken: useReconnectToken,
|
||||||
|
reconnectCh: reconnectCh,
|
||||||
|
gracefulShutdownC: gracefulShutdownC,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
func (s *Supervisor) Run(
|
||||||
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
ctx context.Context,
|
||||||
|
connectedSignal *signal.Signal,
|
||||||
|
) error {
|
||||||
|
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var tunnelsWaiting []int
|
var tunnelsWaiting []int
|
||||||
|
@ -131,7 +142,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
case tunnelError := <-s.tunnelErrors:
|
case tunnelError := <-s.tunnelErrors:
|
||||||
tunnelsActive--
|
tunnelsActive--
|
||||||
if tunnelError.err != nil {
|
if tunnelError.err != nil {
|
||||||
s.log.Err(tunnelError.err).Msg("supervisor: Tunnel disconnected")
|
s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
|
||||||
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||||
s.waitForNextTunnel(tunnelError.index)
|
s.waitForNextTunnel(tunnelError.index)
|
||||||
|
|
||||||
|
@ -139,14 +150,16 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
backoffTimer = backoff.BackoffTimer()
|
backoffTimer = backoff.BackoffTimer()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Previously we'd mark the edge address as bad here, but now we'll just silently use
|
// Previously we'd mark the edge address as bad here, but now we'll just silently use another.
|
||||||
// another.
|
} else if tunnelsActive == 0 {
|
||||||
|
// all connected tunnels exited gracefully, no more work to do
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
// Backoff was set and its timer expired
|
// Backoff was set and its timer expired
|
||||||
case <-backoffTimer:
|
case <-backoffTimer:
|
||||||
backoffTimer = nil
|
backoffTimer = nil
|
||||||
for _, index := range tunnelsWaiting {
|
for _, index := range tunnelsWaiting {
|
||||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
|
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
||||||
}
|
}
|
||||||
tunnelsActive += len(tunnelsWaiting)
|
tunnelsActive += len(tunnelsWaiting)
|
||||||
tunnelsWaiting = nil
|
tunnelsWaiting = nil
|
||||||
|
@ -171,14 +184,17 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns nil if initialization succeeded, else the initialization error.
|
// Returns nil if initialization succeeded, else the initialization error.
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
func (s *Supervisor) initialize(
|
||||||
availableAddrs := int(s.edgeIPs.AvailableAddrs())
|
ctx context.Context,
|
||||||
|
connectedSignal *signal.Signal,
|
||||||
|
) error {
|
||||||
|
availableAddrs := s.edgeIPs.AvailableAddrs()
|
||||||
if s.config.HAConnections > availableAddrs {
|
if s.config.HAConnections > availableAddrs {
|
||||||
s.log.Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
|
s.log.Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
|
||||||
s.config.HAConnections = availableAddrs
|
s.config.HAConnections = availableAddrs
|
||||||
}
|
}
|
||||||
|
|
||||||
go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
|
go s.startFirstTunnel(ctx, connectedSignal)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
<-s.tunnelErrors
|
<-s.tunnelErrors
|
||||||
|
@ -190,7 +206,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
// At least one successful connection, so start the rest
|
// At least one successful connection, so start the rest
|
||||||
for i := 1; i < s.config.HAConnections; i++ {
|
for i := 1; i < s.config.HAConnections; i++ {
|
||||||
ch := signal.New(make(chan struct{}))
|
ch := signal.New(make(chan struct{}))
|
||||||
go s.startTunnel(ctx, i, ch, reconnectCh)
|
go s.startTunnel(ctx, i, ch)
|
||||||
time.Sleep(registrationInterval)
|
time.Sleep(registrationInterval)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -198,7 +214,10 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
|
|
||||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
|
func (s *Supervisor) startFirstTunnel(
|
||||||
|
ctx context.Context,
|
||||||
|
connectedSignal *signal.Signal,
|
||||||
|
) {
|
||||||
var (
|
var (
|
||||||
addr *net.TCPAddr
|
addr *net.TCPAddr
|
||||||
err error
|
err error
|
||||||
|
@ -221,7 +240,8 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
firstConnIndex,
|
firstConnIndex,
|
||||||
connectedSignal,
|
connectedSignal,
|
||||||
s.cloudflaredUUID,
|
s.cloudflaredUUID,
|
||||||
reconnectCh,
|
s.reconnectCh,
|
||||||
|
s.gracefulShutdownC,
|
||||||
)
|
)
|
||||||
// If the first tunnel disconnects, keep restarting it.
|
// If the first tunnel disconnects, keep restarting it.
|
||||||
edgeErrors := 0
|
edgeErrors := 0
|
||||||
|
@ -253,14 +273,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
firstConnIndex,
|
firstConnIndex,
|
||||||
connectedSignal,
|
connectedSignal,
|
||||||
s.cloudflaredUUID,
|
s.cloudflaredUUID,
|
||||||
reconnectCh,
|
s.reconnectCh,
|
||||||
|
s.gracefulShutdownC,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors.
|
// s.tunnelErrors.
|
||||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
|
func (s *Supervisor) startTunnel(
|
||||||
|
ctx context.Context,
|
||||||
|
index int,
|
||||||
|
connectedSignal *signal.Signal,
|
||||||
|
) {
|
||||||
var (
|
var (
|
||||||
addr *net.TCPAddr
|
addr *net.TCPAddr
|
||||||
err error
|
err error
|
||||||
|
@ -281,7 +306,8 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
|
||||||
uint8(index),
|
uint8(index),
|
||||||
connectedSignal,
|
connectedSignal,
|
||||||
s.cloudflaredUUID,
|
s.cloudflaredUUID,
|
||||||
reconnectCh,
|
s.reconnectCh,
|
||||||
|
s.gracefulShutdownC,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -107,12 +107,18 @@ func (c *TunnelConfig) SupportedFeatures() []string {
|
||||||
return features
|
return features
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
|
func StartTunnelDaemon(
|
||||||
s, err := NewSupervisor(config, cloudflaredID)
|
ctx context.Context,
|
||||||
|
config *TunnelConfig,
|
||||||
|
connectedSignal *signal.Signal,
|
||||||
|
reconnectCh chan ReconnectSignal,
|
||||||
|
graceShutdownC chan struct{},
|
||||||
|
) error {
|
||||||
|
s, err := NewSupervisor(config, reconnectCh, graceShutdownC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.Run(ctx, connectedSignal, reconnectCh)
|
return s.Run(ctx, connectedSignal)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeTunnelLoop(
|
func ServeTunnelLoop(
|
||||||
|
@ -124,6 +130,7 @@ func ServeTunnelLoop(
|
||||||
connectedSignal *signal.Signal,
|
connectedSignal *signal.Signal,
|
||||||
cloudflaredUUID uuid.UUID,
|
cloudflaredUUID uuid.UUID,
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
|
gracefulShutdownC chan struct{},
|
||||||
) error {
|
) error {
|
||||||
haConnections.Inc()
|
haConnections.Inc()
|
||||||
defer haConnections.Dec()
|
defer haConnections.Dec()
|
||||||
|
@ -158,6 +165,7 @@ func ServeTunnelLoop(
|
||||||
cloudflaredUUID,
|
cloudflaredUUID,
|
||||||
reconnectCh,
|
reconnectCh,
|
||||||
protocallFallback.protocol,
|
protocallFallback.protocol,
|
||||||
|
gracefulShutdownC,
|
||||||
)
|
)
|
||||||
if !recoverable {
|
if !recoverable {
|
||||||
return err
|
return err
|
||||||
|
@ -242,6 +250,7 @@ func ServeTunnel(
|
||||||
cloudflaredUUID uuid.UUID,
|
cloudflaredUUID uuid.UUID,
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
protocol connection.Protocol,
|
protocol connection.Protocol,
|
||||||
|
gracefulShutdownC chan struct{},
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
// Treat panics as recoverable errors
|
// Treat panics as recoverable errors
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -268,7 +277,17 @@ func ServeTunnel(
|
||||||
}
|
}
|
||||||
if protocol == connection.HTTP2 {
|
if protocol == connection.HTTP2 {
|
||||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
|
||||||
return ServeHTTP2(ctx, log, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
|
return ServeHTTP2(
|
||||||
|
ctx,
|
||||||
|
log,
|
||||||
|
config,
|
||||||
|
edgeConn,
|
||||||
|
connOptions,
|
||||||
|
connIndex,
|
||||||
|
connectedFuse,
|
||||||
|
reconnectCh,
|
||||||
|
gracefulShutdownC,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return ServeH2mux(
|
return ServeH2mux(
|
||||||
ctx,
|
ctx,
|
||||||
|
@ -280,6 +299,7 @@ func ServeTunnel(
|
||||||
connectedFuse,
|
connectedFuse,
|
||||||
cloudflaredUUID,
|
cloudflaredUUID,
|
||||||
reconnectCh,
|
reconnectCh,
|
||||||
|
gracefulShutdownC,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -293,6 +313,7 @@ func ServeH2mux(
|
||||||
connectedFuse *connectedFuse,
|
connectedFuse *connectedFuse,
|
||||||
cloudflaredUUID uuid.UUID,
|
cloudflaredUUID uuid.UUID,
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
|
gracefulShutdownC chan struct{},
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
config.Log.Debug().Msgf("Connecting via h2mux")
|
config.Log.Debug().Msgf("Connecting via h2mux")
|
||||||
// Returns error from parsing the origin URL or handshake errors
|
// Returns error from parsing the origin URL or handshake errors
|
||||||
|
@ -302,6 +323,7 @@ func ServeH2mux(
|
||||||
edgeConn,
|
edgeConn,
|
||||||
connIndex,
|
connIndex,
|
||||||
config.Observer,
|
config.Observer,
|
||||||
|
gracefulShutdownC,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, recoverable
|
return err, recoverable
|
||||||
|
@ -312,13 +334,13 @@ func ServeH2mux(
|
||||||
errGroup.Go(func() (err error) {
|
errGroup.Go(func() (err error) {
|
||||||
if config.NamedTunnel != nil {
|
if config.NamedTunnel != nil {
|
||||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
|
||||||
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
|
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
|
||||||
}
|
}
|
||||||
registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
||||||
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC))
|
||||||
|
|
||||||
err = errGroup.Wait()
|
err = errGroup.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -367,9 +389,10 @@ func ServeHTTP2(
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
connectedFuse connection.ConnectedFuse,
|
connectedFuse connection.ConnectedFuse,
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
|
gracefulShutdownC chan struct{},
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
log.Debug().Msgf("Connecting via http2")
|
log.Debug().Msgf("Connecting via http2")
|
||||||
server := connection.NewHTTP2Connection(
|
h2conn := connection.NewHTTP2Connection(
|
||||||
tlsServerConn,
|
tlsServerConn,
|
||||||
config.ConnectionConfig,
|
config.ConnectionConfig,
|
||||||
config.NamedTunnel,
|
config.NamedTunnel,
|
||||||
|
@ -377,15 +400,15 @@ func ServeHTTP2(
|
||||||
config.Observer,
|
config.Observer,
|
||||||
connIndex,
|
connIndex,
|
||||||
connectedFuse,
|
connectedFuse,
|
||||||
|
gracefulShutdownC,
|
||||||
)
|
)
|
||||||
|
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
server.Serve(serveCtx)
|
return h2conn.Serve(serveCtx)
|
||||||
return fmt.Errorf("connection with edge closed")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC))
|
||||||
|
|
||||||
err = errGroup.Wait()
|
err = errGroup.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -394,11 +417,13 @@ func ServeHTTP2(
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error {
|
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) func() error {
|
||||||
return func() error {
|
return func() error {
|
||||||
select {
|
select {
|
||||||
case reconnect := <-reconnectCh:
|
case reconnect := <-reconnectCh:
|
||||||
return &reconnect
|
return reconnect
|
||||||
|
case <-gracefulShutdownCh:
|
||||||
|
return nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue