TUN-8775: Make sure the session Close can only be called once
The previous capture of the sync.OnceValue was re-initialized for each call to `Close`. This needed to be initialized during the creation of the session to ensure that the sync.OnceValue reference was held for the session's lifetime. Closes TUN-8775
This commit is contained in:
parent
f07d04d129
commit
37010529bc
|
@ -73,6 +73,9 @@ type session struct {
|
||||||
contextChan chan context.Context
|
contextChan chan context.Context
|
||||||
metrics Metrics
|
metrics Metrics
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
|
|
||||||
|
// A special close function that we wrap with sync.Once to make sure it is only called once
|
||||||
|
closeFn func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(
|
func NewSession(
|
||||||
|
@ -86,6 +89,7 @@ func NewSession(
|
||||||
log *zerolog.Logger,
|
log *zerolog.Logger,
|
||||||
) Session {
|
) Session {
|
||||||
logger := log.With().Str(logFlowID, id.String()).Logger()
|
logger := log.With().Str(logFlowID, id.String()).Logger()
|
||||||
|
closeChan := make(chan error, 1)
|
||||||
session := &session{
|
session := &session{
|
||||||
id: id,
|
id: id,
|
||||||
closeAfterIdle: closeAfterIdle,
|
closeAfterIdle: closeAfterIdle,
|
||||||
|
@ -96,11 +100,19 @@ func NewSession(
|
||||||
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
||||||
// drop instead of blocking because last active time only needs to be an approximation
|
// drop instead of blocking because last active time only needs to be an approximation
|
||||||
activeAtChan: make(chan time.Time, 1),
|
activeAtChan: make(chan time.Time, 1),
|
||||||
closeChan: make(chan error, 1),
|
closeChan: closeChan,
|
||||||
// contextChan is an unbounded channel to help enforce one active migration of a session at a time.
|
// contextChan is an unbounded channel to help enforce one active migration of a session at a time.
|
||||||
contextChan: make(chan context.Context),
|
contextChan: make(chan context.Context),
|
||||||
metrics: metrics,
|
metrics: metrics,
|
||||||
log: &logger,
|
log: &logger,
|
||||||
|
closeFn: sync.OnceValue(func() error {
|
||||||
|
// We don't want to block on sending to the close channel if it is already full
|
||||||
|
select {
|
||||||
|
case closeChan <- SessionCloseErr:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return origin.Close()
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
session.eyeball.Store(&eyeball)
|
session.eyeball.Store(&eyeball)
|
||||||
return session
|
return session
|
||||||
|
@ -218,14 +230,7 @@ func (s *session) markActive() {
|
||||||
|
|
||||||
func (s *session) Close() error {
|
func (s *session) Close() error {
|
||||||
// Make sure that we only close the origin connection once
|
// Make sure that we only close the origin connection once
|
||||||
return sync.OnceValue(func() error {
|
return s.closeFn()
|
||||||
// We don't want to block on sending to the close channel if it is already full
|
|
||||||
select {
|
|
||||||
case s.closeChan <- SessionCloseErr:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return s.origin.Close()
|
|
||||||
})()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
|
func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
|
||||||
|
|
|
@ -255,11 +255,16 @@ func TestSessionClose_Multiple(t *testing.T) {
|
||||||
if !origin.closed.Load() {
|
if !origin.closed.Load() {
|
||||||
t.Fatal("origin wasn't closed")
|
t.Fatal("origin wasn't closed")
|
||||||
}
|
}
|
||||||
|
// Reset the closed status to make sure it isn't closed again
|
||||||
|
origin.closed.Store(false)
|
||||||
// subsequent closes shouldn't call close again or cause any errors
|
// subsequent closes shouldn't call close again or cause any errors
|
||||||
err = session.Close()
|
err = session.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if origin.closed.Load() {
|
||||||
|
t.Fatal("origin was incorrectly closed twice")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_IdleTimeout(t *testing.T) {
|
func TestSessionServe_IdleTimeout(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue