CUSTESC-53681: Correct QUIC connection management for datagram handlers
Corrects the pattern of using errgroups and context cancellation to simplify the logic for canceling extra routines for the QUIC connection. This is because the extra context cancellation is redundant with the fact that the errgroup also cancels its own provided context when a routine returns (error or not). For the datagram handler specifically, since it can respond faster to a context cancellation from the QUIC connection, we wrap the error before surfacing it outside of the QUIC connection scope to the supervisor. Additionally, the supervisor will look for this error type to check if it should retry the QUIC connection. These two operations are required because the supervisor does not look for a context canceled error when deciding to retry a connection. If a context canceled error from the datagram handler were to be returned up to the supervisor on the initial connection, the cloudflared application would exit. We want to ensure that cloudflared maintains connection attempts even if any of the services on top of a QUIC connection fail (datagram handler in this case). Additional logging is also introduced along these paths to help with understanding the error conditions from the specific handlers on top of a QUIC connection. Related CUSTESC-53681 Closes TUN-9610
This commit is contained in:
parent
8825ceecb5
commit
41dffd7f3c
|
|
@ -82,7 +82,7 @@ func (c *controlStream) ServeControlStream(
|
||||||
tunnelConfigGetter TunnelConfigJSONGetter,
|
tunnelConfigGetter TunnelConfigJSONGetter,
|
||||||
) error {
|
) error {
|
||||||
registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout)
|
registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout)
|
||||||
|
c.observer.logConnecting(c.connIndex, c.edgeAddress, c.protocol)
|
||||||
registrationDetails, err := registrationClient.RegisterConnection(
|
registrationDetails, err := registrationClient.RegisterConnection(
|
||||||
ctx,
|
ctx,
|
||||||
c.tunnelProperties.Credentials.Auth(),
|
c.tunnelProperties.Credentials.Auth(),
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
package connection
|
package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -53,26 +52,26 @@ func serverRegistrationErrorFromRPC(err error) ServerRegisterTunnelError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type muxerShutdownError struct{}
|
type ControlStreamError struct{}
|
||||||
|
|
||||||
func (e muxerShutdownError) Error() string {
|
var _ error = &ControlStreamError{}
|
||||||
return "muxer shutdown"
|
|
||||||
|
func (e *ControlStreamError) Error() string {
|
||||||
|
return "control stream encountered a failure while serving"
|
||||||
}
|
}
|
||||||
|
|
||||||
var errMuxerStopped = muxerShutdownError{}
|
type StreamListenerError struct{}
|
||||||
|
|
||||||
func isHandshakeErrRecoverable(err error, connIndex uint8, observer *Observer) bool {
|
var _ error = &StreamListenerError{}
|
||||||
log := observer.log.With().
|
|
||||||
Uint8(LogFieldConnIndex, connIndex).
|
|
||||||
Err(err).
|
|
||||||
Logger()
|
|
||||||
|
|
||||||
switch err.(type) {
|
func (e *StreamListenerError) Error() string {
|
||||||
case edgediscovery.DialError:
|
return "accept stream listener encountered a failure while serving"
|
||||||
log.Error().Msg("Connection unable to dial edge")
|
|
||||||
default:
|
|
||||||
log.Error().Msg("Connection failed")
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
|
type DatagramManagerError struct{}
|
||||||
|
|
||||||
|
var _ error = &DatagramManagerError{}
|
||||||
|
|
||||||
|
func (e *DatagramManagerError) Error() string {
|
||||||
|
return "datagram manager encountered a failure while serving"
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,15 @@ func (o *Observer) RegisterSink(sink EventSink) {
|
||||||
o.addSinkChan <- sink
|
o.addSinkChan <- sink
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Observer) logConnecting(connIndex uint8, address net.IP, protocol Protocol) {
|
||||||
|
o.log.Debug().
|
||||||
|
Int(management.EventTypeKey, int(management.Cloudflared)).
|
||||||
|
Uint8(LogFieldConnIndex, connIndex).
|
||||||
|
IPAddr(LogFieldIPAddress, address).
|
||||||
|
Str(LogFieldProtocol, protocol.String()).
|
||||||
|
Msg("Registering tunnel connection")
|
||||||
|
}
|
||||||
|
|
||||||
func (o *Observer) logConnected(connectionID uuid.UUID, connIndex uint8, location string, address net.IP, protocol Protocol) {
|
func (o *Observer) logConnected(connectionID uuid.UUID, connIndex uint8, location string, address net.IP, protocol Protocol) {
|
||||||
o.log.Info().
|
o.log.Info().
|
||||||
Int(management.EventTypeKey, int(management.Cloudflared)).
|
Int(management.EventTypeKey, int(management.Cloudflared)).
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package connection
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
|
@ -12,7 +13,6 @@ import (
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
@ -65,7 +65,7 @@ func NewTunnelConnection(
|
||||||
streamWriteTimeout time.Duration,
|
streamWriteTimeout time.Duration,
|
||||||
gracePeriod time.Duration,
|
gracePeriod time.Duration,
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
) (TunnelConnection, error) {
|
) TunnelConnection {
|
||||||
return &quicConnection{
|
return &quicConnection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
|
@ -77,10 +77,11 @@ func NewTunnelConnection(
|
||||||
rpcTimeout: rpcTimeout,
|
rpcTimeout: rpcTimeout,
|
||||||
streamWriteTimeout: streamWriteTimeout,
|
streamWriteTimeout: streamWriteTimeout,
|
||||||
gracePeriod: gracePeriod,
|
gracePeriod: gracePeriod,
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve starts a QUIC connection that begins accepting streams.
|
// Serve starts a QUIC connection that begins accepting streams.
|
||||||
|
// Returning a nil error means cloudflared will exit for good and will not attempt to reconnect.
|
||||||
func (q *quicConnection) Serve(ctx context.Context) error {
|
func (q *quicConnection) Serve(ctx context.Context) error {
|
||||||
// The edge assumes the first stream is used for the control plane
|
// The edge assumes the first stream is used for the control plane
|
||||||
controlStream, err := q.conn.OpenStream()
|
controlStream, err := q.conn.OpenStream()
|
||||||
|
|
@ -88,16 +89,16 @@ func (q *quicConnection) Serve(ctx context.Context) error {
|
||||||
return fmt.Errorf("failed to open a registration control stream: %w", err)
|
return fmt.Errorf("failed to open a registration control stream: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
|
|
||||||
// as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
|
|
||||||
// connection).
|
|
||||||
// If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
|
// If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
|
||||||
// other goroutine as fast as possible.
|
// other goroutines. We enforce returning a non-nil error for each function started in the errgroup by logging
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
// the error returned and returning a custom error type instead.
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
// In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control
|
// Close the quic connection if any of the following routines return from the errgroup (regardless of their error)
|
||||||
// stream is already fully registered before the other goroutines can proceed.
|
// because they are no longer processing requests for the connection.
|
||||||
|
defer q.Close()
|
||||||
|
|
||||||
|
// Start the control stream routine
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
// err is equal to nil if we exit due to unregistration. If that happens we want to wait the full
|
// err is equal to nil if we exit due to unregistration. If that happens we want to wait the full
|
||||||
// amount of the grace period, allowing requests to finish before we cancel the context, which will
|
// amount of the grace period, allowing requests to finish before we cancel the context, which will
|
||||||
|
|
@ -114,16 +115,26 @@ func (q *quicConnection) Serve(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cancel()
|
if err != nil {
|
||||||
return err
|
q.logger.Error().Err(err).Msg("failed to serve the control stream")
|
||||||
|
}
|
||||||
|
return &ControlStreamError{}
|
||||||
})
|
})
|
||||||
|
// Start the accept stream loop routine
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
defer cancel()
|
err := q.acceptStream(ctx)
|
||||||
return q.acceptStream(ctx)
|
if err != nil {
|
||||||
|
q.logger.Error().Err(err).Msg("failed to accept incoming stream requests")
|
||||||
|
}
|
||||||
|
return &StreamListenerError{}
|
||||||
})
|
})
|
||||||
|
// Start the datagram handler routine
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
defer cancel()
|
err := q.datagramHandler.Serve(ctx)
|
||||||
return q.datagramHandler.Serve(ctx)
|
if err != nil {
|
||||||
|
q.logger.Error().Err(err).Msg("failed to run the datagram handler")
|
||||||
|
}
|
||||||
|
return &DatagramManagerError{}
|
||||||
})
|
})
|
||||||
|
|
||||||
return errGroup.Wait()
|
return errGroup.Wait()
|
||||||
|
|
@ -140,7 +151,6 @@ func (q *quicConnection) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *quicConnection) acceptStream(ctx context.Context) error {
|
func (q *quicConnection) acceptStream(ctx context.Context) error {
|
||||||
defer q.Close()
|
|
||||||
for {
|
for {
|
||||||
quicStream, err := q.conn.AcceptStream(ctx)
|
quicStream, err := q.conn.AcceptStream(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -230,7 +240,7 @@ func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.Re
|
||||||
ConnIndex: q.connIndex,
|
ConnIndex: q.connIndex,
|
||||||
}), rwa.connectResponseSent
|
}), rwa.connectResponseSent
|
||||||
default:
|
default:
|
||||||
return errors.Errorf("unsupported error type: %s", request.Type), false
|
return fmt.Errorf("unsupported error type: %s", request.Type), false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -847,7 +847,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
|
||||||
&log,
|
&log,
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelConn, err := NewTunnelConnection(
|
tunnelConn := NewTunnelConnection(
|
||||||
ctx,
|
ctx,
|
||||||
conn,
|
conn,
|
||||||
index,
|
index,
|
||||||
|
|
@ -860,7 +860,6 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
|
||||||
0*time.Second,
|
0*time.Second,
|
||||||
&log,
|
&log,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
|
||||||
return tunnelConn, datagramConn
|
return tunnelConn, datagramConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -98,24 +98,17 @@ func NewDatagramV2Connection(ctx context.Context,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *datagramV2Connection) Serve(ctx context.Context) error {
|
func (d *datagramV2Connection) Serve(ctx context.Context) error {
|
||||||
// If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
|
// If either goroutine from the errgroup returns at all (error or nil), we rely on its cancellation to make sure
|
||||||
// as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
|
// the other goroutines exit as well.
|
||||||
// connection).
|
|
||||||
// If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
|
|
||||||
// other goroutine as fast as possible.
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
defer cancel()
|
|
||||||
return d.sessionManager.Serve(ctx)
|
return d.sessionManager.Serve(ctx)
|
||||||
})
|
})
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
defer cancel()
|
|
||||||
return d.datagramMuxer.ServeReceive(ctx)
|
return d.datagramMuxer.ServeReceive(ctx)
|
||||||
})
|
})
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
defer cancel()
|
|
||||||
return d.packetRouter.Serve(ctx)
|
return d.packetRouter.Serve(ctx)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -175,7 +175,7 @@ func (c *datagramConn) Serve(ctx context.Context) error {
|
||||||
// Monitor the context of cloudflared
|
// Monitor the context of cloudflared
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
// Monitor the context of the underlying connection
|
// Monitor the context of the underlying quic connection
|
||||||
case <-connCtx.Done():
|
case <-connCtx.Done():
|
||||||
return connCtx.Err()
|
return connCtx.Err()
|
||||||
// Monitor for any hard errors from reading the connection
|
// Monitor for any hard errors from reading the connection
|
||||||
|
|
|
||||||
|
|
@ -132,6 +132,7 @@ func (s *Supervisor) Run(
|
||||||
if err == errEarlyShutdown {
|
if err == errEarlyShutdown {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
s.log.Logger().Error().Err(err).Msg("initial tunnel connection failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var tunnelsWaiting []int
|
var tunnelsWaiting []int
|
||||||
|
|
@ -154,6 +155,7 @@ func (s *Supervisor) Run(
|
||||||
// (note that this may also be caused by context cancellation)
|
// (note that this may also be caused by context cancellation)
|
||||||
case tunnelError := <-s.tunnelErrors:
|
case tunnelError := <-s.tunnelErrors:
|
||||||
tunnelsActive--
|
tunnelsActive--
|
||||||
|
s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
|
||||||
if tunnelError.err != nil && !shuttingDown {
|
if tunnelError.err != nil && !shuttingDown {
|
||||||
switch tunnelError.err.(type) {
|
switch tunnelError.err.(type) {
|
||||||
case ReconnectSignal:
|
case ReconnectSignal:
|
||||||
|
|
@ -166,7 +168,6 @@ func (s *Supervisor) Run(
|
||||||
if _, retry := s.tunnelsProtocolFallback[tunnelError.index].GetMaxBackoffDuration(ctx); !retry {
|
if _, retry := s.tunnelsProtocolFallback[tunnelError.index].GetMaxBackoffDuration(ctx); !retry {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
|
|
||||||
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||||
s.waitForNextTunnel(tunnelError.index)
|
s.waitForNextTunnel(tunnelError.index)
|
||||||
|
|
||||||
|
|
@ -285,7 +286,10 @@ func (s *Supervisor) startFirstTunnel(
|
||||||
*quic.IdleTimeoutError,
|
*quic.IdleTimeoutError,
|
||||||
*quic.ApplicationError,
|
*quic.ApplicationError,
|
||||||
edgediscovery.DialError,
|
edgediscovery.DialError,
|
||||||
*connection.EdgeQuicDialError:
|
*connection.EdgeQuicDialError,
|
||||||
|
*connection.ControlStreamError,
|
||||||
|
*connection.StreamListenerError,
|
||||||
|
*connection.DatagramManagerError:
|
||||||
// Try again for these types of errors
|
// Try again for these types of errors
|
||||||
default:
|
default:
|
||||||
// Uncaught errors should bail startup
|
// Uncaught errors should bail startup
|
||||||
|
|
@ -301,13 +305,9 @@ func (s *Supervisor) startTunnel(
|
||||||
index int,
|
index int,
|
||||||
connectedSignal *signal.Signal,
|
connectedSignal *signal.Signal,
|
||||||
) {
|
) {
|
||||||
var err error
|
|
||||||
defer func() {
|
|
||||||
s.tunnelErrors <- tunnelError{index: index, err: err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// nolint: gosec
|
// nolint: gosec
|
||||||
err = s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
|
err := s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
|
||||||
|
s.tunnelErrors <- tunnelError{index: index, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
||||||
|
|
|
||||||
|
|
@ -556,6 +556,7 @@ func (e *EdgeTunnelServer) serveQUIC(
|
||||||
pqMode := connOptions.FeatureSnapshot.PostQuantum
|
pqMode := connOptions.FeatureSnapshot.PostQuantum
|
||||||
curvePref, err := curvePreference(pqMode, fips.IsFipsEnabled(), tlsConfig.CurvePreferences)
|
curvePref, err := curvePreference(pqMode, fips.IsFipsEnabled(), tlsConfig.CurvePreferences)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
connLogger.ConnAwareLogger().Err(err).Msgf("failed to get curve preferences")
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -627,7 +628,7 @@ func (e *EdgeTunnelServer) serveQUIC(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap the [quic.Connection] as a TunnelConnection
|
// Wrap the [quic.Connection] as a TunnelConnection
|
||||||
tunnelConn, err := connection.NewTunnelConnection(
|
tunnelConn := connection.NewTunnelConnection(
|
||||||
ctx,
|
ctx,
|
||||||
conn,
|
conn,
|
||||||
connIndex,
|
connIndex,
|
||||||
|
|
@ -640,17 +641,13 @@ func (e *EdgeTunnelServer) serveQUIC(
|
||||||
e.config.GracePeriod,
|
e.config.GracePeriod,
|
||||||
connLogger.Logger(),
|
connLogger.Logger(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new tunnel connection")
|
|
||||||
return err, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serve the TunnelConnection
|
// Serve the TunnelConnection
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
err := tunnelConn.Serve(serveCtx)
|
err := tunnelConn.Serve(serveCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve tunnel connection")
|
connLogger.ConnAwareLogger().Err(err).Msg("failed to serve tunnel connection")
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue