TUN-6123: For a given connection with edge, close all datagram sessions through this connection when it's closed

Authored by cthuang, 2022-04-21 10:52:19 +01:00; committed by Chung Ting Huang
parent a97233bb3e
commit 8f0498f66a
4 changed files with 103 additions and 48 deletions

datagramsession/manager.go

@@ -2,6 +2,7 @@ package datagramsession
 import (
 	"context"
+	"fmt"
 	"io"
 	"time"
@@ -16,6 +17,10 @@ const (
 	defaultReqTimeout = time.Second * 5
 )

+var (
+	errSessionManagerClosed = fmt.Errorf("session manager closed")
+)
+
 // Manager defines the APIs to manage sessions from the same transport.
 type Manager interface {
 	// Serve starts the event loop
@@ -30,6 +35,7 @@ type manager struct {
 	registrationChan   chan *registerSessionEvent
 	unregistrationChan chan *unregisterSessionEvent
 	datagramChan       chan *newDatagram
+	closedChan         chan struct{}
 	transport          transport
 	sessions           map[uuid.UUID]*Session
 	log                *zerolog.Logger
@@ -43,6 +49,7 @@ func NewManager(transport transport, log *zerolog.Logger) *manager {
 		unregistrationChan: make(chan *unregisterSessionEvent),
 		// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
 		datagramChan: make(chan *newDatagram, requestChanCapacity),
+		closedChan:   make(chan struct{}),
 		transport:    transport,
 		sessions:     make(map[uuid.UUID]*Session),
 		log:          log,
@@ -90,7 +97,24 @@ func (m *manager) Serve(ctx context.Context) error {
 			}
 		}
 	})
-	return errGroup.Wait()
+	err := errGroup.Wait()
+	close(m.closedChan)
+	m.shutdownSessions(err)
+	return err
+}
+
+func (m *manager) shutdownSessions(err error) {
+	if err == nil {
+		err = errSessionManagerClosed
+	}
+	closeSessionErr := &errClosedSession{
+		message: err.Error(),
+		// Usually connection with remote has been closed, so set this to true to skip unregistering from remote
+		byRemote: true,
+	}
+	for _, s := range m.sessions {
+		s.close(closeSessionErr)
+	}
 }

 func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
@@ -104,15 +128,33 @@ func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, orig
 	case m.registrationChan <- event:
 		session := <-event.resultChan
 		return session, nil
+	// Once closedChan is closed, manager won't accept more registration because nothing is
+	// reading from registrationChan and it's an unbuffered channel
+	case <-m.closedChan:
+		return nil, errSessionManagerClosed
 	}
 }

 func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
-	session := newSession(registration.sessionID, m.transport, registration.originProxy, m.log)
+	session := m.newSession(registration.sessionID, registration.originProxy)
 	m.sessions[registration.sessionID] = session
 	registration.resultChan <- session
 }

+func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
+	return &Session{
+		ID:        id,
+		transport: m.transport,
+		dstConn:   dstConn,
+		// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
+		// drop instead of blocking because last active time only needs to be an approximation
+		activeAtChan: make(chan time.Time, 2),
+		// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
+		closeChan: make(chan error, 2),
+		log:       m.log,
+	}
+}
+
 func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
 	ctx, cancel := context.WithTimeout(ctx, m.timeout)
 	defer cancel()
@@ -129,6 +171,8 @@ func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, me
 		return ctx.Err()
 	case m.unregistrationChan <- event:
 		return nil
+	case <-m.closedChan:
+		return errSessionManagerClosed
 	}
 }
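The new closedChan ties the manager's lifetime to the lifetime of the connection with the edge: once Serve returns, the channel is closed, any caller blocked in RegisterSession or UnregisterSession fails with errSessionManagerClosed, and shutdownSessions closes every remaining session. The snippet below is a minimal, standalone sketch of that coordination pattern (event loop plus closed channel); the loopManager and registerEvent names and the simplified bookkeeping are illustrative only, not the cloudflared API.

// Standalone sketch of the "event loop + closedChan" shutdown pattern (illustrative names).
package main

import (
	"context"
	"errors"
	"fmt"
)

var errManagerClosed = errors.New("session manager closed")

type registerEvent struct {
	id     int
	result chan int
}

type loopManager struct {
	registrationChan chan registerEvent // unbuffered: only the event loop reads from it
	closedChan       chan struct{}      // closed once the event loop has returned
	sessions         map[int]bool
}

func newLoopManager() *loopManager {
	return &loopManager{
		registrationChan: make(chan registerEvent),
		closedChan:       make(chan struct{}),
		sessions:         make(map[int]bool),
	}
}

// serve runs the event loop until ctx is cancelled, then closes closedChan and
// tears down every remaining session, mirroring shutdownSessions above.
func (m *loopManager) serve(ctx context.Context) error {
	defer func() {
		close(m.closedChan)
		for id := range m.sessions {
			fmt.Printf("session %d closed by manager shutdown\n", id)
		}
	}()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case ev := <-m.registrationChan:
			m.sessions[ev.id] = true
			ev.result <- ev.id
		}
	}
}

// register blocks until the event loop picks up the event, or fails fast once
// closedChan is closed (nothing will ever read registrationChan again).
func (m *loopManager) register(ctx context.Context, id int) (int, error) {
	ev := registerEvent{id: id, result: make(chan int, 1)}
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case m.registrationChan <- ev:
		return <-ev.result, nil
	case <-m.closedChan:
		return 0, errManagerClosed
	}
}

func main() {
	m := newLoopManager()
	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		defer close(done)
		_ = m.serve(ctx)
	}()

	if id, err := m.register(ctx, 1); err == nil {
		fmt.Println("registered session", id)
	}

	cancel() // simulate the connection with the edge going away
	<-done

	if _, err := m.register(context.Background(), 2); err != nil {
		fmt.Println("registration rejected:", err) // session manager closed
	}
}

Because the registration channel is unbuffered and only the event loop reads it, closing closedChan after the loop exits is enough to guarantee that no registration can slip in once shutdown has started.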

datagramsession/manager_test.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"sync"
 	"testing"
 	"time"
@@ -21,12 +22,8 @@ func TestManagerServe(t *testing.T) {
 		msgs                = 50
 		remoteUnregisterMsg = "eyeball closed connection"
 	)
-	log := zerolog.Nop()
-	transport := &mockQUICTransport{
-		reqChan:  newDatagramChannel(1),
-		respChan: newDatagramChannel(1),
-	}
-	mg := NewManager(transport, &log)
+	mg, transport := newTestManager(1)

 	eyeballTracker := make(map[uuid.UUID]*datagramChannel)
 	for i := 0; i < sessions; i++ {
@@ -124,12 +121,8 @@ func TestTimeout(t *testing.T) {
 	const (
 		testTimeout = time.Millisecond * 50
 	)
-	log := zerolog.Nop()
-	transport := &mockQUICTransport{
-		reqChan:  newDatagramChannel(1),
-		respChan: newDatagramChannel(1),
-	}
-	mg := NewManager(transport, &log)
+	mg, _ := newTestManager(1)
 	mg.timeout = testTimeout
 	ctx := context.Background()
 	sessionID := uuid.New()
@@ -142,6 +135,47 @@ func TestTimeout(t *testing.T) {
 	require.ErrorIs(t, err, context.DeadlineExceeded)
 }

+func TestCloseTransportCloseSessions(t *testing.T) {
+	mg, transport := newTestManager(1)
+	ctx := context.Background()
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		err := mg.Serve(ctx)
+		require.Error(t, err)
+	}()
+
+	cfdConn, eyeballConn := net.Pipe()
+	session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn)
+	require.NoError(t, err)
+	require.NotNil(t, session)
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		_, err := eyeballConn.Write([]byte(t.Name()))
+		require.NoError(t, err)
+		transport.close()
+	}()
+
+	closedByRemote, err := session.Serve(ctx, time.Minute)
+	require.True(t, closedByRemote)
+	require.Error(t, err)
+
+	wg.Wait()
+}
+
+func newTestManager(capacity uint) (*manager, *mockQUICTransport) {
+	log := zerolog.Nop()
+	transport := &mockQUICTransport{
+		reqChan:  newDatagramChannel(capacity),
+		respChan: newDatagramChannel(capacity),
+	}
+	return NewManager(transport, &log), transport
+}
+
 type mockOrigin struct {
 	expectMsgCount int
 	expectedMsg    []byte

datagramsession/session.go

@@ -39,20 +39,6 @@ type Session struct {
 	log *zerolog.Logger
 }

-func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser, log *zerolog.Logger) *Session {
-	return &Session{
-		ID:        id,
-		transport: transport,
-		dstConn:   dstConn,
-		// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
-		// drop instead of blocking because last active time only needs to be an approximation
-		activeAtChan: make(chan time.Time, 2),
-		// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
-		closeChan: make(chan error, 2),
-		log:       log,
-	}
-}
-
 func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) {
 	go func() {
 		// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
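Both the constructor removed here and its replacement on the manager keep activeAtChan at capacity 2 and rely on markActive dropping updates when the channel is full. For readers unfamiliar with the idiom, below is a minimal standalone sketch of that drop-instead-of-block send; the names are illustrative, not the cloudflared code.

// Sketch of a non-blocking send, as described by the activeAtChan comment above.
package main

import (
	"fmt"
	"time"
)

// markActive records the last-active time on a small buffered channel. If the
// channel is full (many concurrent reads/writes), the update is dropped rather
// than blocking the data path: the value only needs to be approximate.
func markActive(activeAt chan time.Time) {
	select {
	case activeAt <- time.Now():
	default: // channel full: drop the update instead of blocking
	}
}

func main() {
	activeAt := make(chan time.Time, 2)
	for i := 0; i < 10; i++ {
		markActive(activeAt) // never blocks, even though capacity is 2
	}
	fmt.Println("buffered updates:", len(activeAt)) // buffered updates: 2
}

Dropping is safe here because the last-active timestamp only needs to be an approximation; blocking would stall the datagram path.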

datagramsession/session_test.go

@@ -11,7 +11,6 @@ import (
 	"time"

 	"github.com/google/uuid"
-	"github.com/rs/zerolog"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/sync/errgroup"
 )
@@ -41,12 +40,9 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
 	sessionID := uuid.New()
 	cfdConn, originConn := net.Pipe()
 	payload := testPayload(sessionID)
-	transport := &mockQUICTransport{
-		reqChan:  newDatagramChannel(1),
-		respChan: newDatagramChannel(1),
-	}
-	log := zerolog.Nop()
-	session := newSession(sessionID, transport, cfdConn, &log)
+
+	mg, _ := newTestManager(1)
+	session := mg.newSession(sessionID, cfdConn)

 	ctx, cancel := context.WithCancel(context.Background())
 	sessionDone := make(chan struct{})
@@ -117,12 +113,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
 	sessionID := uuid.New()
 	cfdConn, originConn := net.Pipe()
 	payload := testPayload(sessionID)
-	transport := &mockQUICTransport{
-		reqChan:  newDatagramChannel(100),
-		respChan: newDatagramChannel(100),
-	}
-	log := zerolog.Nop()
-	session := newSession(sessionID, transport, cfdConn, &log)
+
+	mg, _ := newTestManager(100)
+	session := mg.newSession(sessionID, cfdConn)

 	startTime := time.Now()
 	activeUntil := startTime.Add(activeTime)
@@ -184,7 +177,8 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
 func TestMarkActiveNotBlocking(t *testing.T) {
 	const concurrentCalls = 50
-	session := newSession(uuid.New(), nil, nil, nil)
+	mg, _ := newTestManager(1)
+	session := mg.newSession(uuid.New(), nil)
 	var wg sync.WaitGroup
 	wg.Add(concurrentCalls)
 	for i := 0; i < concurrentCalls; i++ {
@@ -199,12 +193,9 @@ func TestMarkActiveNotBlocking(t *testing.T) {
 func TestZeroBytePayload(t *testing.T) {
 	sessionID := uuid.New()
 	cfdConn, originConn := net.Pipe()
-	transport := &mockQUICTransport{
-		reqChan:  newDatagramChannel(1),
-		respChan: newDatagramChannel(1),
-	}
-	log := zerolog.Nop()
-	session := newSession(sessionID, transport, cfdConn, &log)
+
+	mg, transport := newTestManager(1)
+	session := mg.newSession(sessionID, cfdConn)

 	ctx, cancel := context.WithCancel(context.Background())
 	errGroup, ctx := errgroup.WithContext(ctx)