TUN-5956: Add timeout to session manager APIs

This commit is contained in:
cthuang 2022-03-28 10:06:28 +01:00 committed by Chung Ting Huang
parent c5d1662244
commit c0f85ab85b
2 changed files with 34 additions and 1 deletions

View File

@ -3,6 +3,7 @@ package datagramsession
import ( import (
"context" "context"
"io" "io"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
@ -12,6 +13,7 @@ import (
const ( const (
requestChanCapacity = 16 requestChanCapacity = 16
defaultReqTimeout = time.Second * 5
) )
// Manager defines the APIs to manage sessions from the same transport. // Manager defines the APIs to manage sessions from the same transport.
@ -31,9 +33,11 @@ type manager struct {
transport transport transport transport
sessions map[uuid.UUID]*Session sessions map[uuid.UUID]*Session
log *zerolog.Logger log *zerolog.Logger
// timeout waiting for an API to finish. This can be overriden in test
timeout time.Duration
} }
func NewManager(transport transport, log *zerolog.Logger) Manager { func NewManager(transport transport, log *zerolog.Logger) *manager {
return &manager{ return &manager{
registrationChan: make(chan *registerSessionEvent), registrationChan: make(chan *registerSessionEvent),
unregistrationChan: make(chan *unregisterSessionEvent), unregistrationChan: make(chan *unregisterSessionEvent),
@ -42,6 +46,7 @@ func NewManager(transport transport, log *zerolog.Logger) Manager {
transport: transport, transport: transport,
sessions: make(map[uuid.UUID]*Session), sessions: make(map[uuid.UUID]*Session),
log: log, log: log,
timeout: defaultReqTimeout,
} }
} }
@ -89,9 +94,12 @@ func (m *manager) Serve(ctx context.Context) error {
} }
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) { func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
ctx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
event := newRegisterSessionEvent(sessionID, originProxy) event := newRegisterSessionEvent(sessionID, originProxy)
select { select {
case <-ctx.Done(): case <-ctx.Done():
m.log.Error().Msg("Datagram session registration timeout")
return nil, ctx.Err() return nil, ctx.Err()
case m.registrationChan <- event: case m.registrationChan <- event:
session := <-event.resultChan session := <-event.resultChan
@ -106,6 +114,8 @@ func (m *manager) registerSession(ctx context.Context, registration *registerSes
} }
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error { func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
ctx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
event := &unregisterSessionEvent{ event := &unregisterSessionEvent{
sessionID: sessionID, sessionID: sessionID,
err: &errClosedSession{ err: &errClosedSession{
@ -115,6 +125,7 @@ func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, me
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
m.log.Error().Msg("Datagram session unregistration timeout")
return ctx.Err() return ctx.Err()
case m.unregistrationChan <- event: case m.unregistrationChan <- event:
return nil return nil

View File

@ -120,6 +120,28 @@ func TestManagerServe(t *testing.T) {
<-serveDone <-serveDone
} }
func TestTimeout(t *testing.T) {
const (
testTimeout = time.Millisecond * 50
)
log := zerolog.Nop()
transport := &mockQUICTransport{
reqChan: newDatagramChannel(1),
respChan: newDatagramChannel(1),
}
mg := NewManager(transport, &log)
mg.timeout = testTimeout
ctx := context.Background()
sessionID := uuid.New()
// session manager is not running, so event loop is not running and therefore calling the APIs should timeout
session, err := mg.RegisterSession(ctx, sessionID, nil)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Nil(t, session)
err = mg.UnregisterSession(ctx, sessionID, "session gone", true)
require.ErrorIs(t, err, context.DeadlineExceeded)
}
type mockOrigin struct { type mockOrigin struct {
expectMsgCount int expectMsgCount int
expectedMsg []byte expectedMsg []byte