

10 changed files with 674 additions and 162 deletions
@@ -1,90 +0,0 @@
package connection

import (
	"fmt"
	"net"
	"sync"

	"github.com/google/uuid"
)

// TODO: TUN-5422 Unregister session

const (
	sessionIDLen = len(uuid.UUID{})
)

type udpSessions struct {
	lock     sync.RWMutex
	sessions map[uuid.UUID]*net.UDPConn
	localIP  net.IP
}

func newUDPSessions(localIP net.IP) *udpSessions {
	return &udpSessions{
		sessions: make(map[uuid.UUID]*net.UDPConn),
		localIP:  localIP,
	}
}

func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) (*net.UDPConn, error) {
	us.lock.Lock()
	defer us.lock.Unlock()
	dstAddr := &net.UDPAddr{
		IP:   dstIP,
		Port: int(dstPort),
	}
	conn, err := net.DialUDP("udp", us.localAddr(), dstAddr)
	if err != nil {
		return nil, err
	}
	us.sessions[id] = conn
	return conn, nil
}

func (us *udpSessions) unregister(id uuid.UUID) {
	us.lock.Lock()
	defer us.lock.Unlock()
	delete(us.sessions, id)
}

func (us *udpSessions) send(id uuid.UUID, payload []byte) error {
	us.lock.RLock()
	defer us.lock.RUnlock()
	conn, ok := us.sessions[id]
	if !ok {
		return fmt.Errorf("session %s not found", id)
	}
	_, err := conn.Write(payload)
	return err
}

func (ud *udpSessions) localAddr() *net.UDPAddr {
	// TODO: Determine the IP to bind to
	return &net.UDPAddr{
		IP:   ud.localIP,
		Port: 0,
	}
}

// TODO: TUN-5421 allow user to specify which IP to bind to
func GetLocalIP() (net.IP, error) {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return nil, err
	}
	for _, addr := range addrs {
		// Find the IP that is not loopback
		var ip net.IP
		switch v := addr.(type) {
		case *net.IPNet:
			ip = v.IP
		case *net.IPAddr:
			ip = v.IP
		}
		if !ip.IsLoopback() {
			return ip, nil
		}
	}
	return nil, fmt.Errorf("cannot determine IP to bind to")
}
@@ -0,0 +1,33 @@
package datagramsession

import (
	"io"

	"github.com/google/uuid"
)

// registerSessionEvent is an event to start tracking a new session
type registerSessionEvent struct {
	sessionID   uuid.UUID
	originProxy io.ReadWriteCloser
	resultChan  chan *Session
}

func newRegisterSessionEvent(sessionID uuid.UUID, originProxy io.ReadWriteCloser) *registerSessionEvent {
	return &registerSessionEvent{
		sessionID:   sessionID,
		originProxy: originProxy,
		resultChan:  make(chan *Session, 1),
	}
}

// unregisterSessionEvent is an event to stop tracking and terminate the session.
type unregisterSessionEvent struct {
	sessionID uuid.UUID
}

// newDatagram is an event when the transport receives a new datagram
type newDatagram struct {
	sessionID uuid.UUID
	payload   []byte
}
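
Taken together, these events implement a request/response handoff over channels: the caller allocates the event, including its private resultChan, and the manager's event loop replies on it. Because resultChan has capacity 1, the loop's reply never blocks even if the caller has not yet reached its receive. The sketch below previews the pattern with hypothetical eventLoop/requestRegistration helpers standing in for manager.Serve and manager.RegisterSession, which appear later in this diff:

package datagramsession

import (
	"context"
	"io"

	"github.com/google/uuid"
)

// eventLoop is a condensed, hypothetical version of what manager.Serve does
// with these events: it is the sole owner of the sessions map, so no locks
// are needed.
func eventLoop(ctx context.Context, requests <-chan *registerSessionEvent, sessions map[uuid.UUID]*Session) {
	for {
		select {
		case <-ctx.Done():
			return
		case event := <-requests:
			session := newSession(event.sessionID, nil /* transport elided in this sketch */, event.originProxy)
			sessions[event.sessionID] = session
			// resultChan has capacity 1, so this send never blocks the loop.
			event.resultChan <- session
		}
	}
}

// requestRegistration is the caller's side of the handoff.
func requestRegistration(ctx context.Context, requests chan<- *registerSessionEvent, id uuid.UUID, proxy io.ReadWriteCloser) (*Session, error) {
	event := newRegisterSessionEvent(id, proxy)
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case requests <- event:
		return <-event.resultChan, nil
	}
}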
@@ -0,0 +1,134 @@
package datagramsession

import (
	"context"
	"io"

	"github.com/google/uuid"
	"github.com/rs/zerolog"
	"golang.org/x/sync/errgroup"
)

const (
	requestChanCapacity = 16
)

// Manager defines the APIs to manage sessions from the same transport.
type Manager interface {
	// Serve starts the event loop
	Serve(ctx context.Context) error
	// RegisterSession starts tracking a session. Caller is responsible for starting the session
	RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error)
	// UnregisterSession stops tracking the session and terminates it
	UnregisterSession(ctx context.Context, sessionID uuid.UUID) error
}

type manager struct {
	registrationChan   chan *registerSessionEvent
	unregistrationChan chan *unregisterSessionEvent
	datagramChan       chan *newDatagram
	transport          transport
	sessions           map[uuid.UUID]*Session
	log                *zerolog.Logger
}

func NewManager(transport transport, log *zerolog.Logger) Manager {
	return &manager{
		registrationChan:   make(chan *registerSessionEvent),
		unregistrationChan: make(chan *unregisterSessionEvent),
		// datagramChan is buffered, so it can read more datagrams from the transport while the event loop is processing other events
		datagramChan: make(chan *newDatagram, requestChanCapacity),
		transport:    transport,
		sessions:     make(map[uuid.UUID]*Session),
		log:          log,
	}
}

func (m *manager) Serve(ctx context.Context) error {
	errGroup, ctx := errgroup.WithContext(ctx)
	errGroup.Go(func() error {
		for {
			sessionID, payload, err := m.transport.ReceiveFrom()
			if err != nil {
				m.log.Err(err).Msg("Failed to receive datagram from transport, closing session manager")
				return err
			}
			datagram := &newDatagram{
				sessionID: sessionID,
				payload:   payload,
			}
			select {
			case <-ctx.Done():
				return ctx.Err()
			// Only the event loop routine can update/lookup the sessions map to avoid concurrent access
			// Send the datagram to the event loop. It will find the session to send to
			case m.datagramChan <- datagram:
			}
		}
	})
	errGroup.Go(func() error {
		for {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case datagram := <-m.datagramChan:
				m.sendToSession(datagram)
			case registration := <-m.registrationChan:
				m.registerSession(ctx, registration)
			// TODO: TUN-5422: Unregister inactive session upon timeout
			case unregistration := <-m.unregistrationChan:
				m.unregisterSession(unregistration)
			}
		}
	})
	return errGroup.Wait()
}

func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
	event := newRegisterSessionEvent(sessionID, originProxy)
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case m.registrationChan <- event:
		session := <-event.resultChan
		return session, nil
	}
}

func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
	session := newSession(registration.sessionID, m.transport, registration.originProxy)
	m.sessions[registration.sessionID] = session
	registration.resultChan <- session
}

func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID) error {
	event := &unregisterSessionEvent{sessionID: sessionID}
	select {
	case <-ctx.Done():
		return ctx.Err()
	case m.unregistrationChan <- event:
		return nil
	}
}

func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
	session, ok := m.sessions[unregistration.sessionID]
	if ok {
		delete(m.sessions, unregistration.sessionID)
		session.close()
	}
}

func (m *manager) sendToSession(datagram *newDatagram) {
	session, ok := m.sessions[datagram.sessionID]
	if !ok {
		m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found")
		return
	}
	// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
	// need to run in another goroutine
	_, err := session.writeToDst(datagram.payload)
	if err != nil {
		m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
	}
}
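
End to end, a caller constructs the Manager around a transport, runs its event loop, and registers one session per UDP flow. A minimal wiring sketch under stated assumptions: the muxer argument is any value implementing SendTo/ReceiveFrom (in practice the quic DatagramMuxer, per the Session comment later in this diff), and the UDP dial to a local DNS resolver merely stands in for a real origin connection:

package main

import (
	"context"
	"net"

	"github.com/cloudflare/cloudflared/datagramsession"
	"github.com/google/uuid"
	"github.com/rs/zerolog"
)

// serveOneSession shows the intended call order for the Manager API.
func serveOneSession(ctx context.Context, muxer interface {
	SendTo(sessionID uuid.UUID, payload []byte) error
	ReceiveFrom() (uuid.UUID, []byte, error)
}) error {
	log := zerolog.Nop()
	manager := datagramsession.NewManager(muxer, &log)

	// The event loop owns the sessions map; run it for the life of the transport.
	go manager.Serve(ctx)

	// Dial the origin for this session (a local DNS resolver as a stand-in).
	originConn, err := net.Dial("udp", "127.0.0.1:53")
	if err != nil {
		return err
	}

	sessionID := uuid.New()
	session, err := manager.RegisterSession(ctx, sessionID, originConn)
	if err != nil {
		return err
	}
	defer manager.UnregisterSession(ctx, sessionID)

	// Pipe datagrams read from originConn back to the transport until the
	// session is closed or ctx is done.
	return session.Serve(ctx)
}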
@@ -0,0 +1,214 @@
package datagramsession

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net"
	"testing"

	"github.com/google/uuid"
	"github.com/rs/zerolog"
	"github.com/stretchr/testify/require"
	"golang.org/x/sync/errgroup"
)

func TestManagerServe(t *testing.T) {
	const (
		sessions = 20
		msgs     = 50
	)
	log := zerolog.Nop()
	transport := &mockQUICTransport{
		reqChan:  newDatagramChannel(),
		respChan: newDatagramChannel(),
	}
	mg := NewManager(transport, &log)

	eyeballTracker := make(map[uuid.UUID]*datagramChannel)
	for i := 0; i < sessions; i++ {
		sessionID := uuid.New()
		eyeballTracker[sessionID] = newDatagramChannel()
	}

	ctx, cancel := context.WithCancel(context.Background())
	serveDone := make(chan struct{})
	go func(ctx context.Context) {
		mg.Serve(ctx)
		close(serveDone)
	}(ctx)

	go func(ctx context.Context) {
		for {
			sessionID, payload, err := transport.respChan.Receive(ctx)
			if err != nil {
				require.Equal(t, context.Canceled, err)
				return
			}
			respChan := eyeballTracker[sessionID]
			require.NoError(t, respChan.Send(ctx, sessionID, payload))
		}
	}(ctx)

	errGroup, ctx := errgroup.WithContext(ctx)
	for sID, receiver := range eyeballTracker {
		// Assign loop variables to local variables
		sessionID := sID
		eyeballRespReceiver := receiver
		errGroup.Go(func() error {
			payload := testPayload(sessionID)
			expectResp := testResponse(payload)

			cfdConn, originConn := net.Pipe()

			origin := mockOrigin{
				expectMsgCount: msgs,
				expectedMsg:    payload,
				expectedResp:   expectResp,
				conn:           originConn,
			}
			eyeball := mockEyeball{
				expectMsgCount:  msgs,
				expectedMsg:     expectResp,
				expectSessionID: sessionID,
				respReceiver:    eyeballRespReceiver,
			}

			reqErrGroup, reqCtx := errgroup.WithContext(ctx)
			reqErrGroup.Go(func() error {
				return origin.serve()
			})
			reqErrGroup.Go(func() error {
				return eyeball.serve(reqCtx)
			})

			session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
			require.NoError(t, err)

			sessionDone := make(chan struct{})
			go func() {
				session.Serve(ctx)
				close(sessionDone)
			}()

			for i := 0; i < msgs; i++ {
				require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID)))
			}

			// Make sure eyeball and origin have received all messages before unregistering the session
			require.NoError(t, reqErrGroup.Wait())

			require.NoError(t, mg.UnregisterSession(ctx, sessionID))
			<-sessionDone

			return nil
		})
	}

	require.NoError(t, errGroup.Wait())
	cancel()
	transport.close()
	<-serveDone
}

type mockOrigin struct {
	expectMsgCount int
	expectedMsg    []byte
	expectedResp   []byte
	conn           io.ReadWriteCloser
}

func (mo *mockOrigin) serve() error {
	expectedMsgLen := len(mo.expectedMsg)
	readBuffer := make([]byte, expectedMsgLen+1)
	for i := 0; i < mo.expectMsgCount; i++ {
		n, err := mo.conn.Read(readBuffer)
		if err != nil {
			return err
		}
		if n != expectedMsgLen {
			return fmt.Errorf("Expect to read %d bytes, read %d", expectedMsgLen, n)
		}
		if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
			return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
		}

		_, err = mo.conn.Write(mo.expectedResp)
		if err != nil {
			return err
		}
	}
	return nil
}

func testPayload(sessionID uuid.UUID) []byte {
	return []byte(fmt.Sprintf("Message from %s", sessionID))
}

func testResponse(msg []byte) []byte {
	return []byte(fmt.Sprintf("Response to %v", msg))
}

type mockEyeball struct {
	expectMsgCount  int
	expectedMsg     []byte
	expectSessionID uuid.UUID
	respReceiver    *datagramChannel
}

func (me *mockEyeball) serve(ctx context.Context) error {
	for i := 0; i < me.expectMsgCount; i++ {
		sessionID, msg, err := me.respReceiver.Receive(ctx)
		if err != nil {
			return err
		}
		if sessionID != me.expectSessionID {
			return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID)
		}
		if !bytes.Equal(msg, me.expectedMsg) {
			return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg)
		}
	}
	return nil
}

// datagramChannel is a channel of datagrams with wrappers to send/receive with a context
type datagramChannel struct {
	datagramChan chan *newDatagram
	closedChan   chan struct{}
}

func newDatagramChannel() *datagramChannel {
	return &datagramChannel{
		datagramChan: make(chan *newDatagram, 1),
		closedChan:   make(chan struct{}),
	}
}

func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-rc.closedChan:
		return fmt.Errorf("datagram channel closed")
	case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}:
		return nil
	}
}

func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) {
	select {
	case <-ctx.Done():
		return uuid.Nil, nil, ctx.Err()
	case <-rc.closedChan:
		return uuid.Nil, nil, fmt.Errorf("datagram channel closed")
	case msg := <-rc.datagramChan:
		return msg.sessionID, msg.payload, nil
	}
}

func (rc *datagramChannel) Close() {
	// No need to close datagramChan; it will be garbage collected once there are no more references to it
	close(rc.closedChan)
}
@@ -0,0 +1,70 @@
package datagramsession

import (
	"context"
	"io"

	"github.com/google/uuid"
)

// Each Session is a bidirectional pipe of datagrams between transport and dstConn
// Currently the only implementation of transport is quic DatagramMuxer
// Destination can be a connection with origin or with eyeball
// When the destination is origin:
// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the
//   write method of the Session to send to origin
// - Datagrams from origin are read from conn and sent to the transport via SendTo. The transport will return them to the eyeball
// When the destination is eyeball:
// - Datagrams from eyeball are read from conn and sent to the transport via SendTo. The transport will send them to cloudflared
// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the
//   write method of the Session to send to eyeball
type Session struct {
	id        uuid.UUID
	transport transport
	dstConn   io.ReadWriteCloser
	doneChan  chan struct{}
}

func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *Session {
	return &Session{
		id:        id,
		transport: transport,
		dstConn:   dstConn,
		doneChan:  make(chan struct{}),
	}
}

func (s *Session) Serve(ctx context.Context) error {
	serveCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	go func() {
		select {
		case <-serveCtx.Done():
		case <-s.doneChan:
		}
		s.dstConn.Close()
	}()
	// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
	// This makes it safe to share readBuffer between iterations
	readBuffer := make([]byte, 1280)
	for {
		// TODO: TUN-5303: origin proxy should determine the buffer size
		n, err := s.dstConn.Read(readBuffer)
		if n > 0 {
			if err := s.transport.SendTo(s.id, readBuffer[:n]); err != nil {
				return err
			}
		}
		if err != nil {
			return err
		}
	}
}

func (s *Session) writeToDst(payload []byte) (int, error) {
	return s.dstConn.Write(payload)
}

func (s *Session) close() {
	close(s.doneChan)
}
@@ -0,0 +1,59 @@
package datagramsession

import (
	"context"
	"net"
	"testing"

	"github.com/google/uuid"
	"github.com/stretchr/testify/require"
)

// TestSessionCtxDone makes sure a session will stop after its context is done
func TestSessionCtxDone(t *testing.T) {
	testSessionReturns(t, true)
}

// TestCloseSession makes sure a session will stop after the close method is called
func TestCloseSession(t *testing.T) {
	testSessionReturns(t, false)
}

func testSessionReturns(t *testing.T, closeByContext bool) {
	sessionID := uuid.New()
	cfdConn, originConn := net.Pipe()
	payload := testPayload(sessionID)
	transport := &mockQUICTransport{
		reqChan:  newDatagramChannel(),
		respChan: newDatagramChannel(),
	}
	session := newSession(sessionID, transport, cfdConn)

	ctx, cancel := context.WithCancel(context.Background())
	sessionDone := make(chan struct{})
	go func() {
		session.Serve(ctx)
		close(sessionDone)
	}()

	go func() {
		n, err := session.writeToDst(payload)
		require.NoError(t, err)
		require.Equal(t, len(payload), n)
	}()

	readBuffer := make([]byte, len(payload)+1)
	n, err := originConn.Read(readBuffer)
	require.NoError(t, err)
	require.Equal(t, len(payload), n)

	if closeByContext {
		cancel()
	} else {
		session.close()
	}

	<-sessionDone
	// call cancel again, otherwise the linter will warn about a possible context leak
	cancel()
}
@@ -0,0 +1,11 @@
package datagramsession

import "github.com/google/uuid"

// transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions
type transport interface {
	// SendTo writes payload for a session to the transport
	SendTo(sessionID uuid.UUID, payload []byte) error
	// ReceiveFrom reads the next datagram from the transport
	ReceiveFrom() (uuid.UUID, []byte, error)
}
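
The only real implementation of this interface is the quic DatagramMuxer (per the Session comment above), but the contract is easy to picture: every datagram must carry enough information to route it to a session. Below is a purely hypothetical implementation, not code from this change, that multiplexes by prepending the 16-byte session ID to each payload on a connected socket:

package datagramsession

import (
	"fmt"
	"net"

	"github.com/google/uuid"
)

const sessionIDLen = len(uuid.UUID{})

// prefixTransport is a hypothetical transport sketch: it multiplexes sessions
// by prepending the session ID to every datagram it writes, and strips it off
// on every read.
type prefixTransport struct {
	conn net.Conn
	buf  []byte // reusable read buffer, e.g. make([]byte, sessionIDLen+1280)
}

func (t *prefixTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
	datagram := make([]byte, 0, sessionIDLen+len(payload))
	datagram = append(datagram, sessionID[:]...)
	datagram = append(datagram, payload...)
	_, err := t.conn.Write(datagram)
	return err
}

func (t *prefixTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
	n, err := t.conn.Read(t.buf)
	if err != nil {
		return uuid.Nil, nil, err
	}
	if n < sessionIDLen {
		return uuid.Nil, nil, fmt.Errorf("datagram too short: %d bytes", n)
	}
	sessionID, err := uuid.FromBytes(t.buf[:sessionIDLen])
	if err != nil {
		return uuid.Nil, nil, err
	}
	// Copy the payload out so t.buf can be reused on the next read, matching
	// the copy semantics the manager relies on from quic-go.
	payload := make([]byte, n-sessionIDLen)
	copy(payload, t.buf[sessionIDLen:n])
	return sessionID, payload, nil
}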
@@ -0,0 +1,32 @@
package datagramsession

import (
	"context"

	"github.com/google/uuid"
)

type mockQUICTransport struct {
	reqChan  *datagramChannel
	respChan *datagramChannel
}

func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
	buf := make([]byte, len(payload))
	// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
	copy(buf, payload)
	return mt.respChan.Send(context.Background(), sessionID, buf)
}

func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
	return mt.reqChan.Receive(context.Background())
}

func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
	return mt.reqChan.Send(ctx, sessionID, payload)
}

func (mt *mockQUICTransport) close() {
	mt.reqChan.Close()
	mt.respChan.Close()
}