TUN-5600: Close QUIC transports as soon as possible while respecting graceful shutdown

This does a few fixes to make sure that the QUICConnection returns from
Serve when the context is cancelled.

QUIC transport now behaves like other transports: closes as soon as there
is no traffic, or at most by grace-period. Note that we do not wait for
UDP traffic since that's connectionless by design.
This commit is contained in:
Nuno Diegues 2022-01-04 19:00:44 +00:00
parent ead93e9f26
commit 628545d229
4 changed files with 33 additions and 17 deletions

View File

@ -34,10 +34,11 @@ const (
// QUICConnection represents the type that facilitates Proxying via QUIC streams. // QUICConnection represents the type that facilitates Proxying via QUIC streams.
type QUICConnection struct { type QUICConnection struct {
session quic.Session session quic.Session
logger *zerolog.Logger logger *zerolog.Logger
httpProxy OriginProxy httpProxy OriginProxy
sessionManager datagramsession.Manager sessionManager datagramsession.Manager
controlStreamHandler ControlStreamHandler
} }
// NewQUICConnection returns a new instance of QUICConnection. // NewQUICConnection returns a new instance of QUICConnection.
@ -49,7 +50,7 @@ func NewQUICConnection(
httpProxy OriginProxy, httpProxy OriginProxy,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler ControlStreamHandler, controlStreamHandler ControlStreamHandler,
observer *Observer, logger *zerolog.Logger,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
if err != nil { if err != nil {
@ -72,34 +73,44 @@ func NewQUICConnection(
return nil, err return nil, err
} }
sessionManager := datagramsession.NewManager(datagramMuxer, observer.log) sessionManager := datagramsession.NewManager(datagramMuxer, logger)
return &QUICConnection{ return &QUICConnection{
session: session, session: session,
httpProxy: httpProxy, httpProxy: httpProxy,
logger: observer.log, logger: logger,
sessionManager: sessionManager, sessionManager: sessionManager,
controlStreamHandler: controlStreamHandler,
}, nil }, nil
} }
// Serve starts a QUIC session that begins accepting streams. // Serve starts a QUIC session that begins accepting streams.
func (q *QUICConnection) Serve(ctx context.Context) error { func (q *QUICConnection) Serve(ctx context.Context) error {
// If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
// as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
// connection).
// If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
// other goroutine as fast as possible.
ctx, cancel := context.WithCancel(ctx)
errGroup, ctx := errgroup.WithContext(ctx) errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
defer cancel()
return q.acceptStream(ctx) return q.acceptStream(ctx)
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
defer cancel()
return q.sessionManager.Serve(ctx) return q.sessionManager.Serve(ctx)
}) })
return errGroup.Wait() return errGroup.Wait()
} }
func (q *QUICConnection) acceptStream(ctx context.Context) error { func (q *QUICConnection) acceptStream(ctx context.Context) error {
defer q.Close()
for { for {
stream, err := q.session.AcceptStream(ctx) stream, err := q.session.AcceptStream(ctx)
if err != nil { if err != nil {
// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
return nil return nil
} }
return fmt.Errorf("failed to accept QUIC stream: %w", err) return fmt.Errorf("failed to accept QUIC stream: %w", err)

View File

@ -661,7 +661,7 @@ func testQUICConnection(ctx context.Context, udpListenerAddr net.Addr, t *testin
originProxy, originProxy,
&tunnelpogs.ConnectionOptions{}, &tunnelpogs.ConnectionOptions{},
fakeControlStream{}, fakeControlStream{},
NewObserver(&log, &log, false), &log,
) )
require.NoError(t, err) require.NoError(t, err)
return qc return qc

View File

@ -5,6 +5,7 @@ import (
"io" "io"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@ -50,7 +51,11 @@ func (m *manager) Serve(ctx context.Context) error {
for { for {
sessionID, payload, err := m.transport.ReceiveFrom() sessionID, payload, err := m.transport.ReceiveFrom()
if err != nil { if err != nil {
return err if aerr, ok := err.(*quic.ApplicationError); ok && uint64(aerr.ErrorCode) == uint64(quic.NoError) {
return nil
} else {
return err
}
} }
datagram := &newDatagram{ datagram := &newDatagram{
sessionID: sessionID, sessionID: sessionID,
@ -69,7 +74,7 @@ func (m *manager) Serve(ctx context.Context) error {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return nil
case datagram := <-m.datagramChan: case datagram := <-m.datagramChan:
m.sendToSession(datagram) m.sendToSession(datagram)
case registration := <-m.registrationChan: case registration := <-m.registrationChan:

View File

@ -548,7 +548,7 @@ func ServeQUIC(
config.ConnectionConfig.OriginProxy, config.ConnectionConfig.OriginProxy,
connOptions, connOptions,
controlStreamHandler, controlStreamHandler,
config.Observer) connLogger.Logger())
if err != nil { if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")
return err, true return err, true
@ -556,11 +556,11 @@ func ServeQUIC(
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
err := quicConn.Serve(ctx) err := quicConn.Serve(serveCtx)
if err != nil { if err != nil {
connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve quic connection") connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve quic connection")
} }
return fmt.Errorf("Connection with edge closed") return err
}) })
errGroup.Go(func() error { errGroup.Go(func() error {