TUN-8775: Make sure the session Close can only be called once

The previous capture of the sync.OnceValue was re-initialized for each
call to `Close`. This needed to be initialized during the creation of
the session to ensure that the sync.OnceValue reference was held for
the session's lifetime.

Closes TUN-8775
This commit is contained in:
Devin Carr 2024-12-05 14:12:53 -08:00
parent f07d04d129
commit 37010529bc
2 changed files with 19 additions and 9 deletions

View File

@ -73,6 +73,9 @@ type session struct {
contextChan chan context.Context contextChan chan context.Context
metrics Metrics metrics Metrics
log *zerolog.Logger log *zerolog.Logger
// A special close function that we wrap with sync.Once to make sure it is only called once
closeFn func() error
} }
func NewSession( func NewSession(
@ -86,6 +89,7 @@ func NewSession(
log *zerolog.Logger, log *zerolog.Logger,
) Session { ) Session {
logger := log.With().Str(logFlowID, id.String()).Logger() logger := log.With().Str(logFlowID, id.String()).Logger()
closeChan := make(chan error, 1)
session := &session{ session := &session{
id: id, id: id,
closeAfterIdle: closeAfterIdle, closeAfterIdle: closeAfterIdle,
@ -96,11 +100,19 @@ func NewSession(
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
// drop instead of blocking because last active time only needs to be an approximation // drop instead of blocking because last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 1), activeAtChan: make(chan time.Time, 1),
closeChan: make(chan error, 1), closeChan: closeChan,
// contextChan is an unbounded channel to help enforce one active migration of a session at a time. // contextChan is an unbounded channel to help enforce one active migration of a session at a time.
contextChan: make(chan context.Context), contextChan: make(chan context.Context),
metrics: metrics, metrics: metrics,
log: &logger, log: &logger,
closeFn: sync.OnceValue(func() error {
// We don't want to block on sending to the close channel if it is already full
select {
case closeChan <- SessionCloseErr:
default:
}
return origin.Close()
}),
} }
session.eyeball.Store(&eyeball) session.eyeball.Store(&eyeball)
return session return session
@ -218,14 +230,7 @@ func (s *session) markActive() {
func (s *session) Close() error { func (s *session) Close() error {
// Make sure that we only close the origin connection once // Make sure that we only close the origin connection once
return sync.OnceValue(func() error { return s.closeFn()
// We don't want to block on sending to the close channel if it is already full
select {
case s.closeChan <- SessionCloseErr:
default:
}
return s.origin.Close()
})()
} }
func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error { func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {

View File

@ -255,11 +255,16 @@ func TestSessionClose_Multiple(t *testing.T) {
if !origin.closed.Load() { if !origin.closed.Load() {
t.Fatal("origin wasn't closed") t.Fatal("origin wasn't closed")
} }
// Reset the closed status to make sure it isn't closed again
origin.closed.Store(false)
// subsequent closes shouldn't call close again or cause any errors // subsequent closes shouldn't call close again or cause any errors
err = session.Close() err = session.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if origin.closed.Load() {
t.Fatal("origin was incorrectly closed twice")
}
} }
func TestSessionServe_IdleTimeout(t *testing.T) { func TestSessionServe_IdleTimeout(t *testing.T) {