TUN-6123: For a given connection with edge, close all datagram sessions through this connection when it's closed
This commit is contained in:
parent
a97233bb3e
commit
8f0498f66a
|
@ -2,6 +2,7 @@ package datagramsession
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -16,6 +17,10 @@ const (
|
||||||
defaultReqTimeout = time.Second * 5
|
defaultReqTimeout = time.Second * 5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errSessionManagerClosed = fmt.Errorf("session manager closed")
|
||||||
|
)
|
||||||
|
|
||||||
// Manager defines the APIs to manage sessions from the same transport.
|
// Manager defines the APIs to manage sessions from the same transport.
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
// Serve starts the event loop
|
// Serve starts the event loop
|
||||||
|
@ -30,6 +35,7 @@ type manager struct {
|
||||||
registrationChan chan *registerSessionEvent
|
registrationChan chan *registerSessionEvent
|
||||||
unregistrationChan chan *unregisterSessionEvent
|
unregistrationChan chan *unregisterSessionEvent
|
||||||
datagramChan chan *newDatagram
|
datagramChan chan *newDatagram
|
||||||
|
closedChan chan struct{}
|
||||||
transport transport
|
transport transport
|
||||||
sessions map[uuid.UUID]*Session
|
sessions map[uuid.UUID]*Session
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
|
@ -43,6 +49,7 @@ func NewManager(transport transport, log *zerolog.Logger) *manager {
|
||||||
unregistrationChan: make(chan *unregisterSessionEvent),
|
unregistrationChan: make(chan *unregisterSessionEvent),
|
||||||
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
|
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
|
||||||
datagramChan: make(chan *newDatagram, requestChanCapacity),
|
datagramChan: make(chan *newDatagram, requestChanCapacity),
|
||||||
|
closedChan: make(chan struct{}),
|
||||||
transport: transport,
|
transport: transport,
|
||||||
sessions: make(map[uuid.UUID]*Session),
|
sessions: make(map[uuid.UUID]*Session),
|
||||||
log: log,
|
log: log,
|
||||||
|
@ -90,7 +97,24 @@ func (m *manager) Serve(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return errGroup.Wait()
|
err := errGroup.Wait()
|
||||||
|
close(m.closedChan)
|
||||||
|
m.shutdownSessions(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *manager) shutdownSessions(err error) {
|
||||||
|
if err == nil {
|
||||||
|
err = errSessionManagerClosed
|
||||||
|
}
|
||||||
|
closeSessionErr := &errClosedSession{
|
||||||
|
message: err.Error(),
|
||||||
|
// Usually connection with remote has been closed, so set this to true to skip unregistering from remote
|
||||||
|
byRemote: true,
|
||||||
|
}
|
||||||
|
for _, s := range m.sessions {
|
||||||
|
s.close(closeSessionErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
|
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
|
||||||
|
@ -104,15 +128,33 @@ func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, orig
|
||||||
case m.registrationChan <- event:
|
case m.registrationChan <- event:
|
||||||
session := <-event.resultChan
|
session := <-event.resultChan
|
||||||
return session, nil
|
return session, nil
|
||||||
|
// Once closedChan is closed, manager won't accept more registration because nothing is
|
||||||
|
// reading from registrationChan and it's an unbuffered channel
|
||||||
|
case <-m.closedChan:
|
||||||
|
return nil, errSessionManagerClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
|
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
|
||||||
session := newSession(registration.sessionID, m.transport, registration.originProxy, m.log)
|
session := m.newSession(registration.sessionID, registration.originProxy)
|
||||||
m.sessions[registration.sessionID] = session
|
m.sessions[registration.sessionID] = session
|
||||||
registration.resultChan <- session
|
registration.resultChan <- session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
|
||||||
|
return &Session{
|
||||||
|
ID: id,
|
||||||
|
transport: m.transport,
|
||||||
|
dstConn: dstConn,
|
||||||
|
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
||||||
|
// drop instead of blocking because last active time only needs to be an approximation
|
||||||
|
activeAtChan: make(chan time.Time, 2),
|
||||||
|
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
|
||||||
|
closeChan: make(chan error, 2),
|
||||||
|
log: m.log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
|
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
|
||||||
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
ctx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -129,6 +171,8 @@ func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, me
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case m.unregistrationChan <- event:
|
case m.unregistrationChan <- event:
|
||||||
return nil
|
return nil
|
||||||
|
case <-m.closedChan:
|
||||||
|
return errSessionManagerClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -21,12 +22,8 @@ func TestManagerServe(t *testing.T) {
|
||||||
msgs = 50
|
msgs = 50
|
||||||
remoteUnregisterMsg = "eyeball closed connection"
|
remoteUnregisterMsg = "eyeball closed connection"
|
||||||
)
|
)
|
||||||
log := zerolog.Nop()
|
|
||||||
transport := &mockQUICTransport{
|
mg, transport := newTestManager(1)
|
||||||
reqChan: newDatagramChannel(1),
|
|
||||||
respChan: newDatagramChannel(1),
|
|
||||||
}
|
|
||||||
mg := NewManager(transport, &log)
|
|
||||||
|
|
||||||
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
|
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
|
||||||
for i := 0; i < sessions; i++ {
|
for i := 0; i < sessions; i++ {
|
||||||
|
@ -124,12 +121,8 @@ func TestTimeout(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
testTimeout = time.Millisecond * 50
|
testTimeout = time.Millisecond * 50
|
||||||
)
|
)
|
||||||
log := zerolog.Nop()
|
|
||||||
transport := &mockQUICTransport{
|
mg, _ := newTestManager(1)
|
||||||
reqChan: newDatagramChannel(1),
|
|
||||||
respChan: newDatagramChannel(1),
|
|
||||||
}
|
|
||||||
mg := NewManager(transport, &log)
|
|
||||||
mg.timeout = testTimeout
|
mg.timeout = testTimeout
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
sessionID := uuid.New()
|
sessionID := uuid.New()
|
||||||
|
@ -142,6 +135,47 @@ func TestTimeout(t *testing.T) {
|
||||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloseTransportCloseSessions(t *testing.T) {
|
||||||
|
mg, transport := newTestManager(1)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := mg.Serve(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfdConn, eyeballConn := net.Pipe()
|
||||||
|
session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, session)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, err := eyeballConn.Write([]byte(t.Name()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
transport.close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
closedByRemote, err := session.Serve(ctx, time.Minute)
|
||||||
|
require.True(t, closedByRemote)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestManager(capacity uint) (*manager, *mockQUICTransport) {
|
||||||
|
log := zerolog.Nop()
|
||||||
|
transport := &mockQUICTransport{
|
||||||
|
reqChan: newDatagramChannel(capacity),
|
||||||
|
respChan: newDatagramChannel(capacity),
|
||||||
|
}
|
||||||
|
return NewManager(transport, &log), transport
|
||||||
|
}
|
||||||
|
|
||||||
type mockOrigin struct {
|
type mockOrigin struct {
|
||||||
expectMsgCount int
|
expectMsgCount int
|
||||||
expectedMsg []byte
|
expectedMsg []byte
|
||||||
|
|
|
@ -39,20 +39,6 @@ type Session struct {
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser, log *zerolog.Logger) *Session {
|
|
||||||
return &Session{
|
|
||||||
ID: id,
|
|
||||||
transport: transport,
|
|
||||||
dstConn: dstConn,
|
|
||||||
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
|
||||||
// drop instead of blocking because last active time only needs to be an approximation
|
|
||||||
activeAtChan: make(chan time.Time, 2),
|
|
||||||
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
|
|
||||||
closeChan: make(chan error, 2),
|
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) {
|
func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) {
|
||||||
go func() {
|
go func() {
|
||||||
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
@ -41,12 +40,9 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
|
||||||
sessionID := uuid.New()
|
sessionID := uuid.New()
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
payload := testPayload(sessionID)
|
payload := testPayload(sessionID)
|
||||||
transport := &mockQUICTransport{
|
|
||||||
reqChan: newDatagramChannel(1),
|
mg, _ := newTestManager(1)
|
||||||
respChan: newDatagramChannel(1),
|
session := mg.newSession(sessionID, cfdConn)
|
||||||
}
|
|
||||||
log := zerolog.Nop()
|
|
||||||
session := newSession(sessionID, transport, cfdConn, &log)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
sessionDone := make(chan struct{})
|
sessionDone := make(chan struct{})
|
||||||
|
@ -117,12 +113,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
|
||||||
sessionID := uuid.New()
|
sessionID := uuid.New()
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
payload := testPayload(sessionID)
|
payload := testPayload(sessionID)
|
||||||
transport := &mockQUICTransport{
|
|
||||||
reqChan: newDatagramChannel(100),
|
mg, _ := newTestManager(100)
|
||||||
respChan: newDatagramChannel(100),
|
session := mg.newSession(sessionID, cfdConn)
|
||||||
}
|
|
||||||
log := zerolog.Nop()
|
|
||||||
session := newSession(sessionID, transport, cfdConn, &log)
|
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
activeUntil := startTime.Add(activeTime)
|
activeUntil := startTime.Add(activeTime)
|
||||||
|
@ -184,7 +177,8 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
|
||||||
|
|
||||||
func TestMarkActiveNotBlocking(t *testing.T) {
|
func TestMarkActiveNotBlocking(t *testing.T) {
|
||||||
const concurrentCalls = 50
|
const concurrentCalls = 50
|
||||||
session := newSession(uuid.New(), nil, nil, nil)
|
mg, _ := newTestManager(1)
|
||||||
|
session := mg.newSession(uuid.New(), nil)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(concurrentCalls)
|
wg.Add(concurrentCalls)
|
||||||
for i := 0; i < concurrentCalls; i++ {
|
for i := 0; i < concurrentCalls; i++ {
|
||||||
|
@ -199,12 +193,9 @@ func TestMarkActiveNotBlocking(t *testing.T) {
|
||||||
func TestZeroBytePayload(t *testing.T) {
|
func TestZeroBytePayload(t *testing.T) {
|
||||||
sessionID := uuid.New()
|
sessionID := uuid.New()
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
transport := &mockQUICTransport{
|
|
||||||
reqChan: newDatagramChannel(1),
|
mg, transport := newTestManager(1)
|
||||||
respChan: newDatagramChannel(1),
|
session := mg.newSession(sessionID, cfdConn)
|
||||||
}
|
|
||||||
log := zerolog.Nop()
|
|
||||||
session := newSession(sessionID, transport, cfdConn, &log)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
|
Loading…
Reference in New Issue