TUN-8775: Make sure the session Close can only be called once

The previous capture of the sync.OnceValue was re-initialized for each
call to `Close`. This needed to be initialized during the creation of
the session to ensure that the sync.OnceValue reference was held for
the session's lifetime.

Closes TUN-8775
This commit is contained in:
Devin Carr 2024-12-05 14:12:53 -08:00
parent f07d04d129
commit 37010529bc
2 changed files with 19 additions and 9 deletions

View File

@ -73,6 +73,9 @@ type session struct {
contextChan chan context.Context
metrics Metrics
log *zerolog.Logger
// A special close function that we wrap with sync.Once to make sure it is only called once
closeFn func() error
}
func NewSession(
@ -86,6 +89,7 @@ func NewSession(
log *zerolog.Logger,
) Session {
logger := log.With().Str(logFlowID, id.String()).Logger()
closeChan := make(chan error, 1)
session := &session{
id: id,
closeAfterIdle: closeAfterIdle,
@ -96,11 +100,19 @@ func NewSession(
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
// drop instead of blocking because last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 1),
closeChan: make(chan error, 1),
closeChan: closeChan,
// contextChan is an unbounded channel to help enforce one active migration of a session at a time.
contextChan: make(chan context.Context),
metrics: metrics,
log: &logger,
closeFn: sync.OnceValue(func() error {
// We don't want to block on sending to the close channel if it is already full
select {
case closeChan <- SessionCloseErr:
default:
}
return origin.Close()
}),
}
session.eyeball.Store(&eyeball)
return session
@ -218,14 +230,7 @@ func (s *session) markActive() {
func (s *session) Close() error {
// Make sure that we only close the origin connection once
return sync.OnceValue(func() error {
// We don't want to block on sending to the close channel if it is already full
select {
case s.closeChan <- SessionCloseErr:
default:
}
return s.origin.Close()
})()
return s.closeFn()
}
func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {

View File

@ -255,11 +255,16 @@ func TestSessionClose_Multiple(t *testing.T) {
if !origin.closed.Load() {
t.Fatal("origin wasn't closed")
}
// Reset the closed status to make sure it isn't closed again
origin.closed.Store(false)
// subsequent closes shouldn't call close again or cause any errors
err = session.Close()
if err != nil {
t.Fatal(err)
}
if origin.closed.Load() {
t.Fatal("origin was incorrectly closed twice")
}
}
func TestSessionServe_IdleTimeout(t *testing.T) {