diff --git a/datagramsession/manager.go b/datagramsession/manager.go index c4144279..a691e3f9 100644 --- a/datagramsession/manager.go +++ b/datagramsession/manager.go @@ -2,6 +2,7 @@ package datagramsession import ( "context" + "fmt" "io" "time" @@ -16,6 +17,10 @@ const ( defaultReqTimeout = time.Second * 5 ) +var ( + errSessionManagerClosed = fmt.Errorf("session manager closed") +) + // Manager defines the APIs to manage sessions from the same transport. type Manager interface { // Serve starts the event loop @@ -30,6 +35,7 @@ type manager struct { registrationChan chan *registerSessionEvent unregistrationChan chan *unregisterSessionEvent datagramChan chan *newDatagram + closedChan chan struct{} transport transport sessions map[uuid.UUID]*Session log *zerolog.Logger @@ -43,6 +49,7 @@ func NewManager(transport transport, log *zerolog.Logger) *manager { unregistrationChan: make(chan *unregisterSessionEvent), // datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events datagramChan: make(chan *newDatagram, requestChanCapacity), + closedChan: make(chan struct{}), transport: transport, sessions: make(map[uuid.UUID]*Session), log: log, @@ -90,7 +97,24 @@ func (m *manager) Serve(ctx context.Context) error { } } }) - return errGroup.Wait() + err := errGroup.Wait() + close(m.closedChan) + m.shutdownSessions(err) + return err +} + +func (m *manager) shutdownSessions(err error) { + if err == nil { + err = errSessionManagerClosed + } + closeSessionErr := &errClosedSession{ + message: err.Error(), + // Usually connection with remote has been closed, so set this to true to skip unregistering from remote + byRemote: true, + } + for _, s := range m.sessions { + s.close(closeSessionErr) + } } func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) { @@ -104,15 +128,33 @@ func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, orig case m.registrationChan <- event: session := <-event.resultChan return session, nil + // Once closedChan is closed, manager won't accept more registration because nothing is + // reading from registrationChan and it's an unbuffered channel + case <-m.closedChan: + return nil, errSessionManagerClosed } } func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) { - session := newSession(registration.sessionID, m.transport, registration.originProxy, m.log) + session := m.newSession(registration.sessionID, registration.originProxy) m.sessions[registration.sessionID] = session registration.resultChan <- session } +func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session { + return &Session{ + ID: id, + transport: m.transport, + dstConn: dstConn, + // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will + // drop instead of blocking because last active time only needs to be an approximation + activeAtChan: make(chan time.Time, 2), + // capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel + closeChan: make(chan error, 2), + log: m.log, + } +} + func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error { ctx, cancel := context.WithTimeout(ctx, m.timeout) defer cancel() @@ -129,6 +171,8 @@ func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, me return ctx.Err() case m.unregistrationChan <- event: return nil + case <-m.closedChan: + return errSessionManagerClosed } } diff --git a/datagramsession/manager_test.go b/datagramsession/manager_test.go index 823c55bb..a9361624 100644 --- a/datagramsession/manager_test.go +++ b/datagramsession/manager_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "sync" "testing" "time" @@ -21,12 +22,8 @@ func TestManagerServe(t *testing.T) { msgs = 50 remoteUnregisterMsg = "eyeball closed connection" ) - log := zerolog.Nop() - transport := &mockQUICTransport{ - reqChan: newDatagramChannel(1), - respChan: newDatagramChannel(1), - } - mg := NewManager(transport, &log) + + mg, transport := newTestManager(1) eyeballTracker := make(map[uuid.UUID]*datagramChannel) for i := 0; i < sessions; i++ { @@ -124,12 +121,8 @@ func TestTimeout(t *testing.T) { const ( testTimeout = time.Millisecond * 50 ) - log := zerolog.Nop() - transport := &mockQUICTransport{ - reqChan: newDatagramChannel(1), - respChan: newDatagramChannel(1), - } - mg := NewManager(transport, &log) + + mg, _ := newTestManager(1) mg.timeout = testTimeout ctx := context.Background() sessionID := uuid.New() @@ -142,6 +135,47 @@ func TestTimeout(t *testing.T) { require.ErrorIs(t, err, context.DeadlineExceeded) } +func TestCloseTransportCloseSessions(t *testing.T) { + mg, transport := newTestManager(1) + ctx := context.Background() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := mg.Serve(ctx) + require.Error(t, err) + }() + + cfdConn, eyeballConn := net.Pipe() + session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn) + require.NoError(t, err) + require.NotNil(t, session) + + wg.Add(1) + go func() { + defer wg.Done() + _, err := eyeballConn.Write([]byte(t.Name())) + require.NoError(t, err) + transport.close() + }() + + closedByRemote, err := session.Serve(ctx, time.Minute) + require.True(t, closedByRemote) + require.Error(t, err) + + wg.Wait() +} + +func newTestManager(capacity uint) (*manager, *mockQUICTransport) { + log := zerolog.Nop() + transport := &mockQUICTransport{ + reqChan: newDatagramChannel(capacity), + respChan: newDatagramChannel(capacity), + } + return NewManager(transport, &log), transport +} + type mockOrigin struct { expectMsgCount int expectedMsg []byte diff --git a/datagramsession/session.go b/datagramsession/session.go index 0a1199e8..351ae298 100644 --- a/datagramsession/session.go +++ b/datagramsession/session.go @@ -39,20 +39,6 @@ type Session struct { log *zerolog.Logger } -func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser, log *zerolog.Logger) *Session { - return &Session{ - ID: id, - transport: transport, - dstConn: dstConn, - // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will - // drop instead of blocking because last active time only needs to be an approximation - activeAtChan: make(chan time.Time, 2), - // capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel - closeChan: make(chan error, 2), - log: log, - } -} - func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) { go func() { // QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 diff --git a/datagramsession/session_test.go b/datagramsession/session_test.go index bf1b1e20..db4201f6 100644 --- a/datagramsession/session_test.go +++ b/datagramsession/session_test.go @@ -11,7 +11,6 @@ import ( "time" "github.com/google/uuid" - "github.com/rs/zerolog" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) @@ -41,12 +40,9 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D sessionID := uuid.New() cfdConn, originConn := net.Pipe() payload := testPayload(sessionID) - transport := &mockQUICTransport{ - reqChan: newDatagramChannel(1), - respChan: newDatagramChannel(1), - } - log := zerolog.Nop() - session := newSession(sessionID, transport, cfdConn, &log) + + mg, _ := newTestManager(1) + session := mg.newSession(sessionID, cfdConn) ctx, cancel := context.WithCancel(context.Background()) sessionDone := make(chan struct{}) @@ -117,12 +113,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) sessionID := uuid.New() cfdConn, originConn := net.Pipe() payload := testPayload(sessionID) - transport := &mockQUICTransport{ - reqChan: newDatagramChannel(100), - respChan: newDatagramChannel(100), - } - log := zerolog.Nop() - session := newSession(sessionID, transport, cfdConn, &log) + + mg, _ := newTestManager(100) + session := mg.newSession(sessionID, cfdConn) startTime := time.Now() activeUntil := startTime.Add(activeTime) @@ -184,7 +177,8 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) func TestMarkActiveNotBlocking(t *testing.T) { const concurrentCalls = 50 - session := newSession(uuid.New(), nil, nil, nil) + mg, _ := newTestManager(1) + session := mg.newSession(uuid.New(), nil) var wg sync.WaitGroup wg.Add(concurrentCalls) for i := 0; i < concurrentCalls; i++ { @@ -199,12 +193,9 @@ func TestMarkActiveNotBlocking(t *testing.T) { func TestZeroBytePayload(t *testing.T) { sessionID := uuid.New() cfdConn, originConn := net.Pipe() - transport := &mockQUICTransport{ - reqChan: newDatagramChannel(1), - respChan: newDatagramChannel(1), - } - log := zerolog.Nop() - session := newSession(sessionID, transport, cfdConn, &log) + + mg, transport := newTestManager(1) + session := mg.newSession(sessionID, cfdConn) ctx, cancel := context.WithCancel(context.Background()) errGroup, ctx := errgroup.WithContext(ctx)