diff --git a/connection/control.go b/connection/control.go index f586f842..c0c6a1d7 100644 --- a/connection/control.go +++ b/connection/control.go @@ -29,7 +29,9 @@ type controlStream struct { // ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown. type ControlStreamHandler interface { - ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error + // ServeControlStream handles the control plane of the transport in the current goroutine calling this + ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error + // IsStopped tells whether the method above has finished IsStopped() bool } @@ -61,7 +63,6 @@ func (c *controlStream) ServeControlStream( ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, - shouldWaitForUnregister bool, ) error { rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) @@ -71,12 +72,7 @@ func (c *controlStream) ServeControlStream( } c.connectedFuse.Connected() - if shouldWaitForUnregister { - c.waitForUnregister(ctx, rpcClient) - } else { - go c.waitForUnregister(ctx, rpcClient) - } - + c.waitForUnregister(ctx, rpcClient) return nil } diff --git a/connection/http2.go b/connection/http2.go index 794af9d5..c0ab8f23 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -108,7 +108,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch connType { case TypeControlStream: - if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, true); err != nil { + if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil { c.controlStreamErr = err c.log.Error().Err(err) respWriter.WriteErrorResponse() diff --git a/connection/quic.go b/connection/quic.go index b586be26..c1b4ff9d 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -39,11 +39,11 @@ type QUICConnection struct { httpProxy OriginProxy sessionManager datagramsession.Manager controlStreamHandler ControlStreamHandler + connOptions *tunnelpogs.ConnectionOptions } // NewQUICConnection returns a new instance of QUICConnection. func NewQUICConnection( - ctx context.Context, quicConfig *quic.Config, edgeAddr net.Addr, tlsConfig *tls.Config, @@ -57,17 +57,6 @@ func NewQUICConnection( return nil, fmt.Errorf("failed to dial to edge: %w", err) } - registrationStream, err := session.OpenStream() - if err != nil { - return nil, fmt.Errorf("failed to open a registration stream: %w", err) - } - - err = controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions, false) - if err != nil { - // Not wrapping error here to be consistent with the http2 message. - return nil, err - } - datagramMuxer, err := quicpogs.NewDatagramMuxer(session) if err != nil { return nil, err @@ -81,11 +70,18 @@ func NewQUICConnection( logger: logger, sessionManager: sessionManager, controlStreamHandler: controlStreamHandler, + connOptions: connOptions, }, nil } // Serve starts a QUIC session that begins accepting streams. func (q *QUICConnection) Serve(ctx context.Context) error { + // origintunneld assumes the first stream is used for the control plane + controlStream, err := q.session.OpenStream() + if err != nil { + return fmt.Errorf("failed to open a registration control stream: %w", err) + } + // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this // connection). @@ -93,6 +89,13 @@ func (q *QUICConnection) Serve(ctx context.Context) error { // other goroutine as fast as possible. ctx, cancel := context.WithCancel(ctx) errGroup, ctx := errgroup.WithContext(ctx) + + // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control + // stream is already fully registered before the other goroutines can proceed. + errGroup.Go(func() error { + defer cancel() + return q.serveControlStream(ctx, controlStream) + }) errGroup.Go(func() error { defer cancel() return q.acceptStream(ctx) @@ -101,9 +104,21 @@ func (q *QUICConnection) Serve(ctx context.Context) error { defer cancel() return q.sessionManager.Serve(ctx) }) + return errGroup.Wait() } +func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error { + // This blocks until the control plane is done. + err := q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions) + if err != nil { + // Not wrapping error here to be consistent with the http2 message. + return err + } + + return nil +} + func (q *QUICConnection) acceptStream(ctx context.Context) error { defer q.Close() for { diff --git a/connection/quic_test.go b/connection/quic_test.go index 4c2ab20e..5bc0827e 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -154,7 +154,7 @@ func TestQUICServer(t *testing.T) { ) }() - qc := testQUICConnection(ctx, udpListener.LocalAddr(), t) + qc := testQUICConnection(udpListener.LocalAddr(), t) go qc.Serve(ctx) wg.Wait() @@ -167,7 +167,8 @@ type fakeControlStream struct { ControlStreamHandler } -func (fakeControlStream) ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error { +func (fakeControlStream) ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error { + <-ctx.Done() return nil } func (fakeControlStream) IsStopped() bool { @@ -532,7 +533,7 @@ func TestServeUDPSession(t *testing.T) { edgeQUICSessionChan <- edgeQUICSession }() - qc := testQUICConnection(ctx, udpListener.LocalAddr(), t) + qc := testQUICConnection(udpListener.LocalAddr(), t) go qc.Serve(ctx) edgeQUICSession := <-edgeQUICSessionChan @@ -645,7 +646,7 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI return nil } -func testQUICConnection(ctx context.Context, udpListenerAddr net.Addr, t *testing.T) *QUICConnection { +func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection { tlsClientConfig := &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"argotunnel"}, @@ -654,7 +655,6 @@ func testQUICConnection(ctx context.Context, udpListenerAddr net.Addr, t *testin originProxy := &mockOriginProxyWithRequest{} log := zerolog.New(os.Stdout) qc, err := NewQUICConnection( - ctx, testQUICConfig, udpListenerAddr, tlsClientConfig, diff --git a/origin/tunnel.go b/origin/tunnel.go index d13eec8b..d3452d40 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -35,13 +35,6 @@ const ( quicMaxIdleTimeout = 15 * time.Second ) -type rpcName string - -const ( - reconnect rpcName = "reconnect" - authenticate rpcName = " authenticate" -) - type TunnelConfig struct { ConnectionConfig *connection.Config OSArch string @@ -535,44 +528,39 @@ func ServeQUIC( EnableDatagrams: true, Tracer: quicpogs.NewClientTracer(connLogger.Logger(), connIndex), } - for { - select { - case <-ctx.Done(): - return - default: - quicConn, err := connection.NewQUICConnection( - ctx, - quicConfig, - edgeAddr, - tlsConfig, - config.ConnectionConfig.OriginProxy, - connOptions, - controlStreamHandler, - connLogger.Logger()) - if err != nil { - connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") - return err, true - } - errGroup, serveCtx := errgroup.WithContext(ctx) - errGroup.Go(func() error { - err := quicConn.Serve(serveCtx) - if err != nil { - connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve quic connection") - } - return err - }) - - errGroup.Go(func() error { - return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) - }) - - err = errGroup.Wait() - if err == nil { - return nil, false - } - } + quicConn, err := connection.NewQUICConnection( + quicConfig, + edgeAddr, + tlsConfig, + config.ConnectionConfig.OriginProxy, + connOptions, + controlStreamHandler, + connLogger.Logger()) + if err != nil { + connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") + return err, true } + + errGroup, serveCtx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + err := quicConn.Serve(serveCtx) + if err != nil { + connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve quic connection") + } + return err + }) + + errGroup.Go(func() error { + err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) + if err != nil { + // forcefully break the connection (this is only used for testing) + quicConn.Close() + } + return err + }) + + return errGroup.Wait(), false } func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {