diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 2b6defde..70da0447 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -321,24 +321,51 @@ func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time. func (m *Muxer) Serve(ctx context.Context) error { errGroup, _ := errgroup.WithContext(ctx) errGroup.Go(func() error { - err := m.muxReader.run(m.config.Logger) - m.explicitShutdown.Fuse(false) - m.r.Close() - m.abort() - return err + ch := make(chan error) + go func() { + err := m.muxReader.run(m.config.Logger) + m.explicitShutdown.Fuse(false) + m.r.Close() + m.abort() + ch <- err + }() + select { + case err := <-ch: + return err + case <-ctx.Done(): + return ctx.Err() + } }) errGroup.Go(func() error { - err := m.muxWriter.run(m.config.Logger) - m.explicitShutdown.Fuse(false) - m.w.Close() - m.abort() - return err + ch := make(chan error) + go func() { + err := m.muxWriter.run(m.config.Logger) + m.explicitShutdown.Fuse(false) + m.w.Close() + m.abort() + ch <- err + }() + select { + case err := <-ch: + return err + case <-ctx.Done(): + return ctx.Err() + } }) errGroup.Go(func() error { - err := m.muxMetricsUpdater.run(m.config.Logger) - return err + ch := make(chan error) + go func() { + err := m.muxMetricsUpdater.run(m.config.Logger) + ch <- err + }() + select { + case err := <-ch: + return err + case <-ctx.Done(): + return ctx.Err() + } }) err := errGroup.Wait()