TUN-5301: Separate datagram multiplex and session management logic from quic connection logic

This commit is contained in:
cthuang 2021-11-23 12:45:59 +00:00 committed by Arég Harutyunyan
parent dd32dc1364
commit eea3d11e40
10 changed files with 675 additions and 163 deletions

View File

@ -16,6 +16,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/datagramsession"
quicpogs "github.com/cloudflare/cloudflared/quic" quicpogs "github.com/cloudflare/cloudflared/quic"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
@ -32,10 +33,11 @@ const (
// QUICConnection represents the type that facilitates Proxying via QUIC streams. // QUICConnection represents the type that facilitates Proxying via QUIC streams.
type QUICConnection struct { type QUICConnection struct {
session quic.Session session quic.Session
logger *zerolog.Logger logger *zerolog.Logger
httpProxy OriginProxy httpProxy OriginProxy
udpSessions *udpSessions sessionManager datagramsession.Manager
localIP net.IP
} }
// NewQUICConnection returns a new instance of QUICConnection. // NewQUICConnection returns a new instance of QUICConnection.
@ -49,12 +51,6 @@ func NewQUICConnection(
controlStreamHandler ControlStreamHandler, controlStreamHandler ControlStreamHandler,
observer *Observer, observer *Observer,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
localIP, err := GetLocalIP()
if err != nil {
return nil, err
}
observer.log.Info().Msgf("UDP proxy will use %s as packet source IP", localIP)
udpSessions := newUDPSessions(localIP)
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to dial to edge: %w", err) return nil, fmt.Errorf("failed to dial to edge: %w", err)
@ -71,24 +67,36 @@ func NewQUICConnection(
return nil, err return nil, err
} }
datagramMuxer, err := quicpogs.NewDatagramMuxer(session)
if err != nil {
return nil, err
}
sessionManager := datagramsession.NewManager(datagramMuxer, observer.log)
localIP, err := getLocalIP()
if err != nil {
return nil, err
}
return &QUICConnection{ return &QUICConnection{
session: session, session: session,
httpProxy: httpProxy, httpProxy: httpProxy,
logger: observer.log, logger: observer.log,
udpSessions: udpSessions, sessionManager: sessionManager,
localIP: localIP,
}, nil }, nil
} }
// Serve starts a QUIC session that begins accepting streams. // Serve starts a QUIC session that begins accepting streams.
func (q *QUICConnection) Serve(ctx context.Context) error { func (q *QUICConnection) Serve(ctx context.Context) error {
errGroup, ctx := errgroup.WithContext(ctx) errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return q.listenEdgeDatagram()
})
errGroup.Go(func() error { errGroup.Go(func() error {
return q.acceptStream(ctx) return q.acceptStream(ctx)
}) })
errGroup.Go(func() error {
return q.sessionManager.Serve(ctx)
})
return errGroup.Wait() return errGroup.Wait()
} }
@ -111,26 +119,6 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
} }
} }
// listenEdgeDatagram listens for datagram from edge, parse the session ID and find the UDPConn to send the payload
func (q *QUICConnection) listenEdgeDatagram() error {
for {
msg, err := q.session.ReceiveMessage()
if err != nil {
return err
}
go func(msg []byte) {
sessionID, msgWithoutID, err := quicpogs.ExtractSessionID(msg)
if err != nil {
q.logger.Err(err).Msg("Failed to parse session ID from datagram")
return
}
if err := q.udpSessions.send(sessionID, msgWithoutID); err != nil {
q.logger.Err(err).Msg("Failed to send UDP to origin")
}
}(msg)
}
}
// Close closes the session with no errors specified. // Close closes the session with no errors specified.
func (q *QUICConnection) Close() { func (q *QUICConnection) Close() {
q.session.CloseWithError(0, "") q.session.CloseWithError(0, "")
@ -186,46 +174,29 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er
} }
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error { func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
udpConn, err := q.udpSessions.register(sessionID, dstIP, dstPort) // Each session is a series of datagram from an eyeball to a dstIP:dstPort.
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
originProxy, err := q.newUDPProxy(dstIP, dstPort)
if err != nil { if err != nil {
q.logger.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
return err return err
} }
q.logger.Debug().Msgf("Register session %v, %v, %v", sessionID, dstIP, dstPort) session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
go q.listenOriginUDP(sessionID, udpConn) if err != nil {
q.logger.Err(err).Msgf("Failed to register udp session %s", sessionID)
return err
}
go func() {
defer q.sessionManager.UnregisterSession(q.session.Context(), sessionID)
if err := session.Serve(q.session.Context()); err != nil {
q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).Msg("session terminated")
}
}()
q.logger.Debug().Msgf("Registered session %v, %v, %v", sessionID, dstIP, dstPort)
return nil return nil
} }
// listenOriginUDP reads UDP from origin in a loop, and returns when it cannot write to edge or cannot read from origin // TODO: TUN-5422 Implement UnregisterUdpSession RPC
func (q *QUICConnection) listenOriginUDP(sessionID uuid.UUID, conn *net.UDPConn) {
defer func() {
q.udpSessions.unregister(sessionID)
conn.Close()
}()
readBuffer := make([]byte, MaxDatagramFrameSize)
for {
n, err := conn.Read(readBuffer)
if n > 0 {
if n > MaxDatagramFrameSize-sessionIDLen {
// TODO: TUN-5302 return ICMP packet too big message
q.logger.Error().Msgf("Origin UDP payload has %d bytes, which exceeds transport MTU %d", n, MaxDatagramFrameSize-sessionIDLen)
continue
}
msgWithID, err := quicpogs.SuffixSessionID(sessionID, readBuffer[:n])
if err != nil {
q.logger.Err(err).Msg("Failed to suffix session ID to datagram, it will be dropped")
continue
}
if err := q.session.SendMessage(msgWithID); err != nil {
q.logger.Err(err).Msg("Failed to send datagram back to edge")
return
}
}
if err != nil {
q.logger.Err(err).Msg("Failed to read UDP from origin")
return
}
}
}
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
// the client. // the client.
@ -320,3 +291,35 @@ func isTransferEncodingChunked(req *http.Request) bool {
// separated value as well. // separated value as well.
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
} }
// TODO: TUN-5303: Define an UDPProxy in ingress package
func (q *QUICConnection) newUDPProxy(dstIP net.IP, dstPort uint16) (*net.UDPConn, error) {
dstAddr := &net.UDPAddr{
IP: dstIP,
Port: int(dstPort),
}
return net.DialUDP("udp", nil, dstAddr)
}
// TODO: TUN-5303: Find the local IP once in ingress package
// TODO: TUN-5421 allow user to specify which IP to bind to
func getLocalIP() (net.IP, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
// Find the IP that is not loop back
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if !ip.IsLoopback() {
return ip, nil
}
}
return nil, fmt.Errorf("cannot determine IP to bind to")
}

View File

@ -1,90 +0,0 @@
package connection
import (
"fmt"
"net"
"sync"
"github.com/google/uuid"
)
// TODO: TUN-5422 Unregister session
const (
sessionIDLen = len(uuid.UUID{})
)
type udpSessions struct {
lock sync.RWMutex
sessions map[uuid.UUID]*net.UDPConn
localIP net.IP
}
func newUDPSessions(localIP net.IP) *udpSessions {
return &udpSessions{
sessions: make(map[uuid.UUID]*net.UDPConn),
localIP: localIP,
}
}
func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) (*net.UDPConn, error) {
us.lock.Lock()
defer us.lock.Unlock()
dstAddr := &net.UDPAddr{
IP: dstIP,
Port: int(dstPort),
}
conn, err := net.DialUDP("udp", us.localAddr(), dstAddr)
if err != nil {
return nil, err
}
us.sessions[id] = conn
return conn, nil
}
func (us *udpSessions) unregister(id uuid.UUID) {
us.lock.Lock()
defer us.lock.Unlock()
delete(us.sessions, id)
}
func (us *udpSessions) send(id uuid.UUID, payload []byte) error {
us.lock.RLock()
defer us.lock.RUnlock()
conn, ok := us.sessions[id]
if !ok {
return fmt.Errorf("session %s not found", id)
}
_, err := conn.Write(payload)
return err
}
func (ud *udpSessions) localAddr() *net.UDPAddr {
// TODO: Determine the IP to bind to
return &net.UDPAddr{
IP: ud.localIP,
Port: 0,
}
}
// TODO: TUN-5421 allow user to specify which IP to bind to
func GetLocalIP() (net.IP, error) {
addrs, err := net.InterfaceAddrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
// Find the IP that is not loop back
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if !ip.IsLoopback() {
return ip, nil
}
}
return nil, fmt.Errorf("cannot determine IP to bind to")
}

33
datagramsession/event.go Normal file
View File

@ -0,0 +1,33 @@
package datagramsession
import (
"io"
"github.com/google/uuid"
)
// registerSessionEvent is an event to start tracking a new session
type registerSessionEvent struct {
sessionID uuid.UUID
originProxy io.ReadWriteCloser
resultChan chan *Session
}
func newRegisterSessionEvent(sessionID uuid.UUID, originProxy io.ReadWriteCloser) *registerSessionEvent {
return &registerSessionEvent{
sessionID: sessionID,
originProxy: originProxy,
resultChan: make(chan *Session, 1),
}
}
// unregisterSessionEvent is an event to stop tracking and terminate the session.
type unregisterSessionEvent struct {
sessionID uuid.UUID
}
// newDatagram is an event when transport receives new datagram
type newDatagram struct {
sessionID uuid.UUID
payload []byte
}

134
datagramsession/manager.go Normal file
View File

@ -0,0 +1,134 @@
package datagramsession
import (
"context"
"io"
"github.com/google/uuid"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
)
const (
requestChanCapacity = 16
)
// Manager defines the APIs to manage sessions from the same transport.
type Manager interface {
// Serve starts the event loop
Serve(ctx context.Context) error
// RegisterSession starts tracking a session. Caller is responsible for starting the session
RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error)
// UnregisterSession stops tracking the session and terminates it
UnregisterSession(ctx context.Context, sessionID uuid.UUID) error
}
type manager struct {
registrationChan chan *registerSessionEvent
unregistrationChan chan *unregisterSessionEvent
datagramChan chan *newDatagram
transport transport
sessions map[uuid.UUID]*Session
log *zerolog.Logger
}
func NewManager(transport transport, log *zerolog.Logger) Manager {
return &manager{
registrationChan: make(chan *registerSessionEvent),
unregistrationChan: make(chan *unregisterSessionEvent),
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
datagramChan: make(chan *newDatagram, requestChanCapacity),
transport: transport,
sessions: make(map[uuid.UUID]*Session),
log: log,
}
}
func (m *manager) Serve(ctx context.Context) error {
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
for {
sessionID, payload, err := m.transport.ReceiveFrom()
if err != nil {
m.log.Err(err).Msg("Failed to receive datagram from transport, closing session manager")
return err
}
datagram := &newDatagram{
sessionID: sessionID,
payload: payload,
}
select {
case <-ctx.Done():
return ctx.Err()
// Only the event loop routine can update/lookup the sessions map to avoid concurrent access
// Send the datagram to the event loop. It will find the session to send to
case m.datagramChan <- datagram:
}
}
})
errGroup.Go(func() error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case datagram := <-m.datagramChan:
m.sendToSession(datagram)
case registration := <-m.registrationChan:
m.registerSession(ctx, registration)
// TODO: TUN-5422: Unregister inactive session upon timeout
case unregistration := <-m.unregistrationChan:
m.unregisterSession(unregistration)
}
}
})
return errGroup.Wait()
}
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
event := newRegisterSessionEvent(sessionID, originProxy)
select {
case <-ctx.Done():
return nil, ctx.Err()
case m.registrationChan <- event:
session := <-event.resultChan
return session, nil
}
}
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
session := newSession(registration.sessionID, m.transport, registration.originProxy)
m.sessions[registration.sessionID] = session
registration.resultChan <- session
}
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID) error {
event := &unregisterSessionEvent{sessionID: sessionID}
select {
case <-ctx.Done():
return ctx.Err()
case m.unregistrationChan <- event:
return nil
}
}
func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
session, ok := m.sessions[unregistration.sessionID]
if ok {
delete(m.sessions, unregistration.sessionID)
session.close()
}
}
func (m *manager) sendToSession(datagram *newDatagram) {
session, ok := m.sessions[datagram.sessionID]
if !ok {
m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found")
return
}
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
// need to run in another go routine
_, err := session.writeToDst(datagram.payload)
if err != nil {
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
}
}

View File

@ -0,0 +1,214 @@
package datagramsession
import (
"bytes"
"context"
"fmt"
"io"
"net"
"testing"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
func TestManagerServe(t *testing.T) {
const (
sessions = 20
msgs = 50
)
log := zerolog.Nop()
transport := &mockQUICTransport{
reqChan: newDatagramChannel(),
respChan: newDatagramChannel(),
}
mg := NewManager(transport, &log)
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
for i := 0; i < sessions; i++ {
sessionID := uuid.New()
eyeballTracker[sessionID] = newDatagramChannel()
}
ctx, cancel := context.WithCancel(context.Background())
serveDone := make(chan struct{})
go func(ctx context.Context) {
mg.Serve(ctx)
close(serveDone)
}(ctx)
go func(ctx context.Context) {
for {
sessionID, payload, err := transport.respChan.Receive(ctx)
if err != nil {
require.Equal(t, context.Canceled, err)
return
}
respChan := eyeballTracker[sessionID]
require.NoError(t, respChan.Send(ctx, sessionID, payload))
}
}(ctx)
errGroup, ctx := errgroup.WithContext(ctx)
for sID, receiver := range eyeballTracker {
// Assign loop variables to local variables
sessionID := sID
eyeballRespReceiver := receiver
errGroup.Go(func() error {
payload := testPayload(sessionID)
expectResp := testResponse(payload)
cfdConn, originConn := net.Pipe()
origin := mockOrigin{
expectMsgCount: msgs,
expectedMsg: payload,
expectedResp: expectResp,
conn: originConn,
}
eyeball := mockEyeball{
expectMsgCount: msgs,
expectedMsg: expectResp,
expectSessionID: sessionID,
respReceiver: eyeballRespReceiver,
}
reqErrGroup, reqCtx := errgroup.WithContext(ctx)
reqErrGroup.Go(func() error {
return origin.serve()
})
reqErrGroup.Go(func() error {
return eyeball.serve(reqCtx)
})
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err)
sessionDone := make(chan struct{})
go func() {
session.Serve(ctx)
close(sessionDone)
}()
for i := 0; i < msgs; i++ {
require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID)))
}
// Make sure eyeball and origin have received all messages before unregistering the session
require.NoError(t, reqErrGroup.Wait())
require.NoError(t, mg.UnregisterSession(ctx, sessionID))
<-sessionDone
return nil
})
}
require.NoError(t, errGroup.Wait())
cancel()
transport.close()
<-serveDone
}
type mockOrigin struct {
expectMsgCount int
expectedMsg []byte
expectedResp []byte
conn io.ReadWriteCloser
}
func (mo *mockOrigin) serve() error {
expectedMsgLen := len(mo.expectedMsg)
readBuffer := make([]byte, expectedMsgLen+1)
for i := 0; i < mo.expectMsgCount; i++ {
n, err := mo.conn.Read(readBuffer)
if err != nil {
return err
}
if n != expectedMsgLen {
return fmt.Errorf("Expect to read %d bytes, read %d", expectedMsgLen, n)
}
if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
}
_, err = mo.conn.Write(mo.expectedResp)
if err != nil {
return err
}
}
return nil
}
func testPayload(sessionID uuid.UUID) []byte {
return []byte(fmt.Sprintf("Message from %s", sessionID))
}
func testResponse(msg []byte) []byte {
return []byte(fmt.Sprintf("Response to %v", msg))
}
type mockEyeball struct {
expectMsgCount int
expectedMsg []byte
expectSessionID uuid.UUID
respReceiver *datagramChannel
}
func (me *mockEyeball) serve(ctx context.Context) error {
for i := 0; i < me.expectMsgCount; i++ {
sessionID, msg, err := me.respReceiver.Receive(ctx)
if err != nil {
return err
}
if sessionID != me.expectSessionID {
return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID)
}
if !bytes.Equal(msg, me.expectedMsg) {
return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg)
}
}
return nil
}
// datagramChannel is a channel for Datagram with wrapper to send/receive with context
type datagramChannel struct {
datagramChan chan *newDatagram
closedChan chan struct{}
}
func newDatagramChannel() *datagramChannel {
return &datagramChannel{
datagramChan: make(chan *newDatagram, 1),
closedChan: make(chan struct{}),
}
}
func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-rc.closedChan:
return fmt.Errorf("datagram channel closed")
case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}:
return nil
}
}
func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) {
select {
case <-ctx.Done():
return uuid.Nil, nil, ctx.Err()
case <-rc.closedChan:
return uuid.Nil, nil, fmt.Errorf("datagram channel closed")
case msg := <-rc.datagramChan:
return msg.sessionID, msg.payload, nil
}
}
func (rc *datagramChannel) Close() {
// No need to close msgChan, it will be garbage collect once there is no reference to it
close(rc.closedChan)
}

View File

@ -0,0 +1,70 @@
package datagramsession
import (
"context"
"io"
"github.com/google/uuid"
)
// Each Session is a bidirectional pipe of datagrams between transport and dstConn
// Currently the only implementation of transport is quic DatagramMuxer
// Destination can be a connection with origin or with eyeball
// When the destination is origin:
// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the
// write method of the Session to send to origin
// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball
// When the destination is eyeball:
// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared
// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the
// write method of the Session to send to eyeball
type Session struct {
id uuid.UUID
transport transport
dstConn io.ReadWriteCloser
doneChan chan struct{}
}
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *Session {
return &Session{
id: id,
transport: transport,
dstConn: dstConn,
doneChan: make(chan struct{}),
}
}
func (s *Session) Serve(ctx context.Context) error {
serveCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-serveCtx.Done():
case <-s.doneChan:
}
s.dstConn.Close()
}()
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
// This makes it safe to share readBuffer between iterations
readBuffer := make([]byte, 1280)
for {
// TODO: TUN-5303: origin proxy should determine the buffer size
n, err := s.dstConn.Read(readBuffer)
if n > 0 {
if err := s.transport.SendTo(s.id, readBuffer[:n]); err != nil {
return err
}
}
if err != nil {
return err
}
}
}
func (s *Session) writeToDst(payload []byte) (int, error) {
return s.dstConn.Write(payload)
}
func (s *Session) close() {
close(s.doneChan)
}

View File

@ -0,0 +1,59 @@
package datagramsession
import (
"context"
"net"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
// TestCloseSession makes sure a session will stop after context is done
func TestSessionCtxDone(t *testing.T) {
testSessionReturns(t, true)
}
// TestCloseSession makes sure a session will stop after close method is called
func TestCloseSession(t *testing.T) {
testSessionReturns(t, false)
}
func testSessionReturns(t *testing.T, closeByContext bool) {
sessionID := uuid.New()
cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID)
transport := &mockQUICTransport{
reqChan: newDatagramChannel(),
respChan: newDatagramChannel(),
}
session := newSession(sessionID, transport, cfdConn)
ctx, cancel := context.WithCancel(context.Background())
sessionDone := make(chan struct{})
go func() {
session.Serve(ctx)
close(sessionDone)
}()
go func() {
n, err := session.writeToDst(payload)
require.NoError(t, err)
require.Equal(t, len(payload), n)
}()
readBuffer := make([]byte, len(payload)+1)
n, err := originConn.Read(readBuffer)
require.NoError(t, err)
require.Equal(t, len(payload), n)
if closeByContext {
cancel()
} else {
session.close()
}
<-sessionDone
// call cancelled again otherwise the linter will warn about possible context leak
cancel()
}

View File

@ -0,0 +1,11 @@
package datagramsession
import "github.com/google/uuid"
// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions
type transport interface {
// SendTo writes payload for a session to the transport
SendTo(sessionID uuid.UUID, payload []byte) error
// ReceiveFrom reads the next datagram from the transport
ReceiveFrom() (uuid.UUID, []byte, error)
}

View File

@ -0,0 +1,32 @@
package datagramsession
import (
"context"
"github.com/google/uuid"
)
type mockQUICTransport struct {
reqChan *datagramChannel
respChan *datagramChannel
}
func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
buf := make([]byte, len(payload))
// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
copy(buf, payload)
return mt.respChan.Send(context.Background(), sessionID, buf)
}
func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
return mt.reqChan.Receive(context.Background())
}
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
return mt.reqChan.Send(ctx, sessionID, payload)
}
func (mt *mockQUICTransport) close() {
mt.reqChan.Close()
mt.respChan.Close()
}

View File

@ -4,13 +4,59 @@ import (
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors"
) )
const ( const (
sessionIDLen = len(uuid.UUID{})
MaxDatagramFrameSize = 1220 MaxDatagramFrameSize = 1220
sessionIDLen = len(uuid.UUID{})
) )
type DatagramMuxer struct {
ID uuid.UUID
session quic.Session
}
func NewDatagramMuxer(quicSession quic.Session) (*DatagramMuxer, error) {
muxerID, err := uuid.NewRandom()
if err != nil {
return nil, err
}
return &DatagramMuxer{
ID: muxerID,
session: quicSession,
}, nil
}
// SendTo suffix the session ID to the payload so the other end of the QUIC session can demultiplex
// the payload from multiple datagram sessions
func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error {
if len(payload) > MaxDatagramFrameSize-sessionIDLen {
// TODO: TUN-5302 return ICMP packet too big message
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), MaxDatagramFrameSize-sessionIDLen)
}
msgWithID, err := SuffixSessionID(sessionID, payload)
if err != nil {
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
}
if err := dm.session.SendMessage(msgWithID); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge")
}
return nil
}
// ReceiveFrom extracts datagram session ID, then sends the session ID and payload to session manager
// which determines how to proxy to the origin. It assumes the datagram session has already been
// registered with session manager through other side channel
func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) {
msg, err := dm.session.ReceiveMessage()
if err != nil {
return uuid.Nil, nil, err
}
return ExtractSessionID(msg)
}
// Each QUIC datagram should be suffixed with session ID. // Each QUIC datagram should be suffixed with session ID.
// ExtractSessionID extracts the session ID and a slice with only the payload // ExtractSessionID extracts the session ID and a slice with only the payload
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) { func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {