diff --git a/connection/control.go b/connection/control.go index 4d5cd437..17664a94 100644 --- a/connection/control.go +++ b/connection/control.go @@ -29,7 +29,7 @@ type controlStream struct { // ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown. type ControlStreamHandler interface { - ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error + ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error IsStopped() bool } @@ -61,6 +61,7 @@ func (c *controlStream) ServeControlStream( ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, + shouldWaitForUnregister bool, ) error { rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) defer rpcClient.Close() @@ -70,6 +71,16 @@ func (c *controlStream) ServeControlStream( } c.connectedFuse.Connected() + if shouldWaitForUnregister { + c.waitForUnregister(ctx, rpcClient) + } else { + go c.waitForUnregister(ctx, rpcClient) + } + + return nil +} + +func (c *controlStream) waitForUnregister(ctx context.Context, rpcClient NamedTunnelRPCClient) { // wait for connection termination or start of graceful shutdown select { case <-ctx.Done(): @@ -81,7 +92,6 @@ func (c *controlStream) ServeControlStream( c.observer.sendUnregisteringEvent(c.connIndex) rpcClient.GracefulShutdown(ctx, c.gracePeriod) c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Unregistered tunnel connection") - return nil } func (c *controlStream) IsStopped() bool { diff --git a/connection/http2.go b/connection/http2.go index 44346371..2e869890 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -109,7 +109,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch connType { case TypeControlStream: - if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil { + if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, true); err != nil { c.controlStreamErr = err c.log.Error().Err(err) respWriter.WriteErrorResponse() diff --git a/connection/quic.go b/connection/quic.go index a3e349c5..c4f7e0ae 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -57,7 +57,7 @@ func NewQUICConnection( return nil, errors.Wrap(err, "failed to open a registration stream") } - err = controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions) + err = controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions, false) if err != nil { // Not wrapping error here to be consistent with the http2 message. return nil, err diff --git a/connection/quic_test.go b/connection/quic_test.go index d29f7cdc..332eb987 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -183,7 +183,7 @@ type fakeControlStream struct { ControlStreamHandler } -func (fakeControlStream) ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error { +func (fakeControlStream) ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error { return nil } func (fakeControlStream) IsStopped() bool {