diff --git a/quic/v3/session.go b/quic/v3/session.go index 57641c90..fa1b1f6e 100644 --- a/quic/v3/session.go +++ b/quic/v3/session.go @@ -73,6 +73,9 @@ type session struct { contextChan chan context.Context metrics Metrics log *zerolog.Logger + + // A special close function that we wrap with sync.Once to make sure it is only called once + closeFn func() error } func NewSession( @@ -86,6 +89,7 @@ func NewSession( log *zerolog.Logger, ) Session { logger := log.With().Str(logFlowID, id.String()).Logger() + closeChan := make(chan error, 1) session := &session{ id: id, closeAfterIdle: closeAfterIdle, @@ -96,11 +100,19 @@ func NewSession( // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will // drop instead of blocking because last active time only needs to be an approximation activeAtChan: make(chan time.Time, 1), - closeChan: make(chan error, 1), + closeChan: closeChan, // contextChan is an unbounded channel to help enforce one active migration of a session at a time. contextChan: make(chan context.Context), metrics: metrics, log: &logger, + closeFn: sync.OnceValue(func() error { + // We don't want to block on sending to the close channel if it is already full + select { + case closeChan <- SessionCloseErr: + default: + } + return origin.Close() + }), } session.eyeball.Store(&eyeball) return session @@ -218,14 +230,7 @@ func (s *session) markActive() { func (s *session) Close() error { // Make sure that we only close the origin connection once - return sync.OnceValue(func() error { - // We don't want to block on sending to the close channel if it is already full - select { - case s.closeChan <- SessionCloseErr: - default: - } - return s.origin.Close() - })() + return s.closeFn() } func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error { diff --git a/quic/v3/session_test.go b/quic/v3/session_test.go index f47ceb14..8c570074 100644 --- a/quic/v3/session_test.go +++ b/quic/v3/session_test.go @@ -255,11 +255,16 @@ func TestSessionClose_Multiple(t *testing.T) { if !origin.closed.Load() { t.Fatal("origin wasn't closed") } + // Reset the closed status to make sure it isn't closed again + origin.closed.Store(false) // subsequent closes shouldn't call close again or cause any errors err = session.Close() if err != nil { t.Fatal(err) } + if origin.closed.Load() { + t.Fatal("origin was incorrectly closed twice") + } } func TestSessionServe_IdleTimeout(t *testing.T) {