TUN-6584: Define QUIC datagram v2 format to support proxying IP packets

This commit is contained in:
cthuang 2022-08-01 13:48:33 +01:00 committed by Chung-Ting Huang
parent d3fd581b7b
commit 278df5478a
12 changed files with 548 additions and 352 deletions

View File

@ -33,14 +33,19 @@ const (
HTTPHostKey = "HttpHost"
QUICMetadataFlowID = "FlowID"
// empirically this capacity has been working well
demuxChanCapacity = 16
)
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
type QUICConnection struct {
session quic.Connection
logger *zerolog.Logger
orchestrator Orchestrator
sessionManager datagramsession.Manager
session quic.Connection
logger *zerolog.Logger
orchestrator Orchestrator
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
sessionManager datagramsession.Manager
// datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *quicpogs.DatagramMuxer
controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions
}
@ -60,18 +65,16 @@ func NewQUICConnection(
return nil, &EdgeQuicDialError{Cause: err}
}
datagramMuxer, err := quicpogs.NewDatagramMuxer(session, logger)
if err != nil {
return nil, err
}
sessionManager := datagramsession.NewManager(datagramMuxer, logger)
demuxChan := make(chan *quicpogs.SessionDatagram, demuxChanCapacity)
datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, demuxChan)
sessionManager := datagramsession.NewManager(logger, datagramMuxer.MuxSession, demuxChan)
return &QUICConnection{
session: session,
orchestrator: orchestrator,
logger: logger,
sessionManager: sessionManager,
datagramMuxer: datagramMuxer,
controlStreamHandler: controlStreamHandler,
connOptions: connOptions,
}, nil
@ -108,6 +111,11 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
return q.sessionManager.Serve(ctx)
})
errGroup.Go(func() error {
defer cancel()
return q.datagramMuxer.ServeReceive(ctx)
})
return errGroup.Wait()
}

View File

@ -42,9 +42,3 @@ func (sc *errClosedSession) Error() string {
return fmt.Sprintf("session closed by local due to %s", sc.message)
}
}
// newDatagram is an event when transport receives new datagram
type newDatagram struct {
sessionID uuid.UUID
payload []byte
}

View File

@ -7,9 +7,9 @@ import (
"time"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
quicpogs "github.com/cloudflare/cloudflared/quic"
)
const (
@ -36,26 +36,25 @@ type Manager interface {
type manager struct {
registrationChan chan *registerSessionEvent
unregistrationChan chan *unregisterSessionEvent
datagramChan chan *newDatagram
closedChan chan struct{}
transport transport
sendFunc transportSender
receiveChan <-chan *quicpogs.SessionDatagram
closedChan <-chan struct{}
sessions map[uuid.UUID]*Session
log *zerolog.Logger
// timeout waiting for an API to finish. This can be overridden in test
timeout time.Duration
}
func NewManager(transport transport, log *zerolog.Logger) *manager {
func NewManager(log *zerolog.Logger, sendF transportSender, receiveChan <-chan *quicpogs.SessionDatagram) *manager {
return &manager{
registrationChan: make(chan *registerSessionEvent),
unregistrationChan: make(chan *unregisterSessionEvent),
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
datagramChan: make(chan *newDatagram, requestChanCapacity),
closedChan: make(chan struct{}),
transport: transport,
sessions: make(map[uuid.UUID]*Session),
log: log,
timeout: defaultReqTimeout,
sendFunc: sendF,
receiveChan: receiveChan,
closedChan: make(chan struct{}),
sessions: make(map[uuid.UUID]*Session),
log: log,
timeout: defaultReqTimeout,
}
}
@ -65,49 +64,21 @@ func (m *manager) UpdateLogger(log *zerolog.Logger) {
}
func (m *manager) Serve(ctx context.Context) error {
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
for {
sessionID, payload, err := m.transport.ReceiveFrom()
if err != nil {
if aerr, ok := err.(*quic.ApplicationError); ok && uint64(aerr.ErrorCode) == uint64(quic.NoError) {
return nil
} else {
return err
}
}
datagram := &newDatagram{
sessionID: sessionID,
payload: payload,
}
select {
case <-ctx.Done():
return ctx.Err()
// Only the event loop routine can update/lookup the sessions map to avoid concurrent access
// Send the datagram to the event loop. It will find the session to send to
case m.datagramChan <- datagram:
}
for {
select {
case <-ctx.Done():
m.shutdownSessions(ctx.Err())
return ctx.Err()
// receiveChan is buffered, so the transport can read more datagrams from transport while the event loop is
// processing other events
case datagram := <-m.receiveChan:
m.sendToSession(datagram)
case registration := <-m.registrationChan:
m.registerSession(ctx, registration)
case unregistration := <-m.unregistrationChan:
m.unregisterSession(unregistration)
}
})
errGroup.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case datagram := <-m.datagramChan:
m.sendToSession(datagram)
case registration := <-m.registrationChan:
m.registerSession(ctx, registration)
// TODO: TUN-5422: Unregister inactive session upon timeout
case unregistration := <-m.unregistrationChan:
m.unregisterSession(unregistration)
}
}
})
err := errGroup.Wait()
close(m.closedChan)
m.shutdownSessions(err)
return err
}
}
func (m *manager) shutdownSessions(err error) {
@ -149,16 +120,17 @@ func (m *manager) registerSession(ctx context.Context, registration *registerSes
}
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
logger := m.log.With().Str("sessionID", id.String()).Logger()
return &Session{
ID: id,
transport: m.transport,
dstConn: dstConn,
ID: id,
sendFunc: m.sendFunc,
dstConn: dstConn,
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
// drop instead of blocking because last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 2),
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
closeChan: make(chan error, 2),
log: m.log,
log: &logger,
}
}
@ -191,16 +163,13 @@ func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
}
}
func (m *manager) sendToSession(datagram *newDatagram) {
session, ok := m.sessions[datagram.sessionID]
func (m *manager) sendToSession(datagram *quicpogs.SessionDatagram) {
session, ok := m.sessions[datagram.ID]
if !ok {
m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found")
m.log.Error().Str("sessionID", datagram.ID.String()).Msg("session not found")
return
}
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
// need to run in another go routine
_, err := session.transportToDst(datagram.payload)
if err != nil {
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
}
session.transportToDst(datagram.Payload)
}

View File

@ -14,22 +14,30 @@ import (
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
quicpogs "github.com/cloudflare/cloudflared/quic"
)
var (
nopLogger = zerolog.Nop()
)
func TestManagerServe(t *testing.T) {
const (
sessions = 20
msgs = 50
sessions = 2
msgs = 5
remoteUnregisterMsg = "eyeball closed connection"
)
mg, transport := newTestManager(1)
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
for i := 0; i < sessions; i++ {
sessionID := uuid.New()
eyeballTracker[sessionID] = newDatagramChannel(1)
requestChan := make(chan *quicpogs.SessionDatagram)
transport := mockQUICTransport{
sessions: make(map[uuid.UUID]chan []byte),
}
for i := 0; i < sessions; i++ {
transport.sessions[uuid.New()] = make(chan []byte)
}
mg := NewManager(&nopLogger, transport.MuxSession, requestChan)
ctx, cancel := context.WithCancel(context.Background())
serveDone := make(chan struct{})
@ -38,53 +46,42 @@ func TestManagerServe(t *testing.T) {
close(serveDone)
}(ctx)
go func(ctx context.Context) {
for {
sessionID, payload, err := transport.respChan.Receive(ctx)
if err != nil {
require.Equal(t, context.Canceled, err)
return
}
respChan := eyeballTracker[sessionID]
require.NoError(t, respChan.Send(ctx, sessionID, payload))
}
}(ctx)
errGroup, ctx := errgroup.WithContext(ctx)
for sID, receiver := range eyeballTracker {
for sessionID, eyeballRespChan := range transport.sessions {
// Assign loop variables to local variables
sID := sessionID
payload := testPayload(sID)
expectResp := testResponse(payload)
cfdConn, originConn := net.Pipe()
origin := mockOrigin{
expectMsgCount: msgs,
expectedMsg: payload,
expectedResp: expectResp,
conn: originConn,
}
eyeball := mockEyeballSession{
id: sID,
expectedMsgCount: msgs,
expectedMsg: payload,
expectedResponse: expectResp,
respReceiver: eyeballRespChan,
}
// Assign loop variables to local variables
sessionID := sID
eyeballRespReceiver := receiver
errGroup.Go(func() error {
payload := testPayload(sessionID)
expectResp := testResponse(payload)
cfdConn, originConn := net.Pipe()
origin := mockOrigin{
expectMsgCount: msgs,
expectedMsg: payload,
expectedResp: expectResp,
conn: originConn,
}
eyeball := mockEyeball{
expectMsgCount: msgs,
expectedMsg: expectResp,
expectSessionID: sessionID,
respReceiver: eyeballRespReceiver,
}
session, err := mg.RegisterSession(ctx, sID, cfdConn)
require.NoError(t, err)
reqErrGroup, reqCtx := errgroup.WithContext(ctx)
reqErrGroup.Go(func() error {
return origin.serve()
})
reqErrGroup.Go(func() error {
return eyeball.serve(reqCtx)
return eyeball.serve(reqCtx, requestChan)
})
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err)
sessionDone := make(chan struct{})
go func() {
closedByRemote, err := session.Serve(ctx, time.Minute*2)
@ -97,23 +94,17 @@ func TestManagerServe(t *testing.T) {
close(sessionDone)
}()
for i := 0; i < msgs; i++ {
require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID)))
}
// Make sure eyeball and origin have received all messages before unregistering the session
require.NoError(t, reqErrGroup.Wait())
require.NoError(t, mg.UnregisterSession(ctx, sessionID, remoteUnregisterMsg, true))
require.NoError(t, mg.UnregisterSession(ctx, sID, remoteUnregisterMsg, true))
<-sessionDone
return nil
})
}
require.NoError(t, errGroup.Wait())
cancel()
transport.close()
<-serveDone
}
@ -122,7 +113,7 @@ func TestTimeout(t *testing.T) {
testTimeout = time.Millisecond * 50
)
mg, _ := newTestManager(1)
mg := NewManager(&nopLogger, nil, nil)
mg.timeout = testTimeout
ctx := context.Background()
sessionID := uuid.New()
@ -135,9 +126,51 @@ func TestTimeout(t *testing.T) {
require.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestCloseTransportCloseSessions(t *testing.T) {
mg, transport := newTestManager(1)
ctx := context.Background()
func TestUnregisterSessionCloseSession(t *testing.T) {
sessionID := uuid.New()
payload := []byte(t.Name())
sender := newMockTransportSender(sessionID, payload)
mg := NewManager(&nopLogger, sender.muxSession, nil)
ctx, cancel := context.WithCancel(context.Background())
managerDone := make(chan struct{})
go func() {
err := mg.Serve(ctx)
require.Error(t, err)
close(managerDone)
}()
cfdConn, originConn := net.Pipe()
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err)
require.NotNil(t, session)
unregisteredChan := make(chan struct{})
go func() {
_, err := originConn.Write(payload)
require.NoError(t, err)
err = mg.UnregisterSession(ctx, sessionID, "eyeball closed session", true)
require.NoError(t, err)
close(unregisteredChan)
}()
closedByRemote, err := session.Serve(ctx, time.Minute)
require.True(t, closedByRemote)
require.Error(t, err)
<-unregisteredChan
cancel()
<-managerDone
}
func TestManagerCtxDoneCloseSessions(t *testing.T) {
sessionID := uuid.New()
payload := []byte(t.Name())
sender := newMockTransportSender(sessionID, payload)
mg := NewManager(&nopLogger, sender.muxSession, nil)
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
@ -147,35 +180,26 @@ func TestCloseTransportCloseSessions(t *testing.T) {
require.Error(t, err)
}()
cfdConn, eyeballConn := net.Pipe()
session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn)
cfdConn, originConn := net.Pipe()
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err)
require.NotNil(t, session)
wg.Add(1)
go func() {
defer wg.Done()
_, err := eyeballConn.Write([]byte(t.Name()))
_, err := originConn.Write(payload)
require.NoError(t, err)
transport.close()
cancel()
}()
closedByRemote, err := session.Serve(ctx, time.Minute)
require.True(t, closedByRemote)
require.False(t, closedByRemote)
require.Error(t, err)
wg.Wait()
}
func newTestManager(capacity uint) (*manager, *mockQUICTransport) {
log := zerolog.Nop()
transport := &mockQUICTransport{
reqChan: newDatagramChannel(capacity),
respChan: newDatagramChannel(capacity),
}
return NewManager(transport, &log), transport
}
type mockOrigin struct {
expectMsgCount int
expectedMsg []byte
@ -197,7 +221,6 @@ func (mo *mockOrigin) serve() error {
if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
}
_, err = mo.conn.Write(mo.expectedResp)
if err != nil {
return err
@ -214,72 +237,35 @@ func testResponse(msg []byte) []byte {
return []byte(fmt.Sprintf("Response to %v", msg))
}
type mockEyeball struct {
expectMsgCount int
expectedMsg []byte
expectSessionID uuid.UUID
respReceiver *datagramChannel
type mockQUICTransport struct {
sessions map[uuid.UUID]chan []byte
}
func (me *mockEyeball) serve(ctx context.Context) error {
for i := 0; i < me.expectMsgCount; i++ {
sessionID, msg, err := me.respReceiver.Receive(ctx)
if err != nil {
return err
}
if sessionID != me.expectSessionID {
return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID)
}
if !bytes.Equal(msg, me.expectedMsg) {
return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg)
}
}
// MuxSession delivers payload to the channel registered for this session ID.
// Returns an error for an unknown session; the original looked up the map
// unconditionally, so a missing ID produced a nil channel and the send below
// would block forever instead of failing the test.
func (me *mockQUICTransport) MuxSession(id uuid.UUID, payload []byte) error {
	session, ok := me.sessions[id]
	if !ok {
		return fmt.Errorf("session %s not found", id)
	}
	session <- payload
	return nil
}
// datagramChannel is a channel for Datagram with wrapper to send/receive with context
type datagramChannel struct {
datagramChan chan *newDatagram
closedChan chan struct{}
type mockEyeballSession struct {
id uuid.UUID
expectedMsgCount int
expectedMsg []byte
expectedResponse []byte
respReceiver <-chan []byte
}
func newDatagramChannel(capacity uint) *datagramChannel {
return &datagramChannel{
datagramChan: make(chan *newDatagram, capacity),
closedChan: make(chan struct{}),
}
}
func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-rc.closedChan:
return &errClosedSession{
message: fmt.Errorf("datagram channel closed").Error(),
byRemote: true,
func (me *mockEyeballSession) serve(ctx context.Context, requestChan chan *quicpogs.SessionDatagram) error {
for i := 0; i < me.expectedMsgCount; i++ {
requestChan <- &quicpogs.SessionDatagram{
ID: me.id,
Payload: me.expectedMsg,
}
case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}:
return nil
}
}
func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) {
select {
case <-ctx.Done():
return uuid.Nil, nil, ctx.Err()
case <-rc.closedChan:
err := &errClosedSession{
message: fmt.Errorf("datagram channel closed").Error(),
byRemote: true,
resp := <-me.respReceiver
if !bytes.Equal(resp, me.expectedResponse) {
return fmt.Errorf("Expect %v, read %v", me.expectedResponse, resp)
}
return uuid.Nil, nil, err
case msg := <-rc.datagramChan:
return msg.sessionID, msg.payload, nil
fmt.Println("Resp", resp)
}
}
func (rc *datagramChannel) Close() {
// No need to close msgChan, it will be garbage collect once there is no reference to it
close(rc.closedChan)
return nil
}

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"net"
"time"
"github.com/google/uuid"
@ -18,21 +19,20 @@ func SessionIdleErr(timeout time.Duration) error {
return fmt.Errorf("session idle for %v", timeout)
}
type transportSender func(sessionID uuid.UUID, payload []byte) error
// Session is a bidirectional pipe of datagrams between transport and dstConn
// Currently the only implementation of transport is quic DatagramMuxer
// Destination can be a connection with origin or with eyeball
// When the destination is origin:
// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the
// write method of the Session to send to origin
// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball
// - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to origin
// - Datagrams from origin are read from conn and Send to transport using the transportSender callback. Transport will return them to eyeball
// When the destination is eyeball:
// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared
// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the
// write method of the Session to send to eyeball
// - Datagrams from eyeball are read from conn and Send to transport. Transport will send them to cloudflared using the transportSender callback.
// - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to the eyeball
type Session struct {
ID uuid.UUID
transport transport
dstConn io.ReadWriteCloser
ID uuid.UUID
sendFunc transportSender
dstConn io.ReadWriteCloser
// activeAtChan is used to communicate the last read/write time
activeAtChan chan time.Time
closeChan chan error
@ -46,9 +46,16 @@ func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (clos
const maxPacketSize = 1500
readBuffer := make([]byte, maxPacketSize)
for {
if err := s.dstToTransport(readBuffer); err != nil {
s.closeChan <- err
return
if closeSession, err := s.dstToTransport(readBuffer); err != nil {
if err != net.ErrClosed {
s.log.Error().Err(err).Msg("Failed to send session payload from destination to transport")
} else {
s.log.Debug().Msg("Session cannot read from destination because the connection is closed")
}
if closeSession {
s.closeChan <- err
return
}
}
}
}()
@ -89,32 +96,25 @@ func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time
}
}
func (s *Session) dstToTransport(buffer []byte) error {
func (s *Session) dstToTransport(buffer []byte) (closeSession bool, err error) {
n, err := s.dstConn.Read(buffer)
s.markActive()
// https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes
if n > 0 {
if n <= int(s.transport.MTU()) {
err = s.transport.SendTo(s.ID, buffer[:n])
} else {
// drop packet for now, eventually reply with ICMP for PMTUD
s.log.Debug().
Str("session", s.ID.String()).
Int("len", n).
Int("mtu", s.transport.MTU()).
Msg("dropped packet exceeding MTU")
if n > 0 || err == nil {
if sendErr := s.sendFunc(s.ID, buffer[:n]); sendErr != nil {
return false, sendErr
}
}
// Some UDP application might send 0-size payload.
if err == nil && n == 0 {
err = s.transport.SendTo(s.ID, []byte{})
}
return err
return err != nil, err
}
func (s *Session) transportToDst(payload []byte) (int, error) {
s.markActive()
return s.dstConn.Write(payload)
n, err := s.dstConn.Write(payload)
if err != nil {
s.log.Err(err).Msg("Failed to write payload to session")
}
return n, err
}
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there

View File

@ -11,8 +11,11 @@ import (
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
quicpogs "github.com/cloudflare/cloudflared/quic"
)
// TestCloseSession makes sure a session will stop after context is done
@ -41,7 +44,8 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID)
mg, _ := newTestManager(1)
log := zerolog.Nop()
mg := NewManager(&log, nil, nil)
session := mg.newSession(sessionID, cfdConn)
ctx, cancel := context.WithCancel(context.Background())
@ -114,7 +118,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID)
mg, _ := newTestManager(100)
respChan := make(chan *quicpogs.SessionDatagram)
sender := newMockTransportSender(sessionID, payload)
mg := NewManager(&nopLogger, sender.muxSession, respChan)
session := mg.newSession(sessionID, cfdConn)
startTime := time.Now()
@ -177,7 +183,7 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
func TestMarkActiveNotBlocking(t *testing.T) {
const concurrentCalls = 50
mg, _ := newTestManager(1)
mg := NewManager(&nopLogger, nil, nil)
session := mg.newSession(uuid.New(), nil)
var wg sync.WaitGroup
wg.Add(concurrentCalls)
@ -190,11 +196,16 @@ func TestMarkActiveNotBlocking(t *testing.T) {
wg.Wait()
}
// Some UDP application might send 0-size payload.
func TestZeroBytePayload(t *testing.T) {
sessionID := uuid.New()
cfdConn, originConn := net.Pipe()
mg, transport := newTestManager(1)
sender := sendOnceTransportSender{
baseSender: newMockTransportSender(sessionID, make([]byte, 0)),
sentChan: make(chan struct{}),
}
mg := NewManager(&nopLogger, sender.muxSession, nil)
session := mg.newSession(sessionID, cfdConn)
ctx, cancel := context.WithCancel(context.Background())
@ -215,11 +226,39 @@ func TestZeroBytePayload(t *testing.T) {
return nil
})
receivedSessionID, payload, err := transport.respChan.Receive(ctx)
require.NoError(t, err)
require.Len(t, payload, 0)
require.Equal(t, sessionID, receivedSessionID)
<-sender.sentChan
cancel()
require.NoError(t, errGroup.Wait())
}
// mockTransportSender implements the transportSender callback for tests. It
// asserts that every datagram handed to the transport carries the expected
// session ID and payload.
type mockTransportSender struct {
	expectedSessionID uuid.UUID
	expectedPayload   []byte
}

// newMockTransportSender returns a sender that only accepts datagrams for the
// given session ID with the given payload.
func newMockTransportSender(expectedSessionID uuid.UUID, expectedPayload []byte) *mockTransportSender {
	return &mockTransportSender{
		expectedSessionID: expectedSessionID,
		expectedPayload:   expectedPayload,
	}
}

// muxSession matches the transportSender signature; it errors when either the
// session ID or the payload differs from what the test expects.
func (mts *mockTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error {
	if sessionID != mts.expectedSessionID {
		return fmt.Errorf("Expect session %s, got %s", mts.expectedSessionID, sessionID)
	}
	if !bytes.Equal(payload, mts.expectedPayload) {
		return fmt.Errorf("Expect %v, read %v", mts.expectedPayload, payload)
	}
	return nil
}
// sendOnceTransportSender wraps a mockTransportSender and signals sentChan by
// closing it after the first send. NOTE(review): muxSession closes sentChan on
// every call, so calling it more than once panics — callers must invoke it at
// most once.
type sendOnceTransportSender struct {
	baseSender *mockTransportSender
	sentChan   chan struct{}
}

// muxSession delegates validation to the base sender, then closes sentChan so
// the test can wait for the send to have happened.
func (sots *sendOnceTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error {
	defer close(sots.sentChan)
	return sots.baseSender.muxSession(sessionID, payload)
}

View File

@ -1,13 +0,0 @@
package datagramsession
import "github.com/google/uuid"
// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions
type transport interface {
// SendTo writes payload for a session to the transport
SendTo(sessionID uuid.UUID, payload []byte) error
// ReceiveFrom reads the next datagram from the transport
ReceiveFrom() (uuid.UUID, []byte, error)
// Max transmission unit to receive from the transport
MTU() int
}

View File

@ -1,36 +0,0 @@
package datagramsession
import (
"context"
"github.com/google/uuid"
)
type mockQUICTransport struct {
reqChan *datagramChannel
respChan *datagramChannel
}
func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
buf := make([]byte, len(payload))
// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
copy(buf, payload)
return mt.respChan.Send(context.Background(), sessionID, buf)
}
func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
return mt.reqChan.Receive(context.Background())
}
func (mt *mockQUICTransport) MTU() int {
return 1280
}
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
return mt.reqChan.Send(ctx, sessionID, payload)
}
func (mt *mockQUICTransport) close() {
mt.reqChan.Close()
mt.respChan.Close()
}

View File

@ -1,6 +1,7 @@
package quic
import (
"context"
"fmt"
"github.com/google/uuid"
@ -13,53 +14,88 @@ const (
sessionIDLen = len(uuid.UUID{})
)
// SessionDatagram is a demultiplexed datagram: the session it belongs to and
// its application payload with the session-ID suffix already stripped.
type SessionDatagram struct {
	// ID of the datagram session this payload belongs to
	ID uuid.UUID
	// Payload is the application data, without transport metadata
	Payload []byte
}

// BaseDatagramMuxer is the common interface for mux/demux of datagrams over a
// QUIC connection, shared by the v1 and v2 datagram formats.
type BaseDatagramMuxer interface {
	// MuxSession suffix the session ID to the payload so the other end of the QUIC connection can demultiplex the
	// payload from multiple datagram sessions
	MuxSession(sessionID uuid.UUID, payload []byte) error
	// ServeReceive starts a loop to receive datagrams from the QUIC connection
	ServeReceive(ctx context.Context) error
}
type DatagramMuxer struct {
session quic.Connection
logger *zerolog.Logger
session quic.Connection
logger *zerolog.Logger
demuxChan chan<- *SessionDatagram
}
func NewDatagramMuxer(quicSession quic.Connection, logger *zerolog.Logger) (*DatagramMuxer, error) {
func NewDatagramMuxer(quicSession quic.Connection, log *zerolog.Logger, demuxChan chan<- *SessionDatagram) *DatagramMuxer {
logger := log.With().Uint8("datagramVersion", 1).Logger()
return &DatagramMuxer{
session: quicSession,
logger: logger,
}, nil
session: quicSession,
logger: &logger,
demuxChan: demuxChan,
}
}
// SendTo suffix the session ID to the payload so the other end of the QUIC session can demultiplex
// the payload from multiple datagram sessions
func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error {
if len(payload) > maxDatagramPayloadSize {
// Maximum application payload to send to / receive from QUIC datagram frame
func (dm *DatagramMuxer) mtu() int {
return maxDatagramPayloadSize
}
func (dm *DatagramMuxer) MuxSession(sessionID uuid.UUID, payload []byte) error {
if len(payload) > dm.mtu() {
// TODO: TUN-5302 return ICMP packet too big message
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.MTU())
// drop packet for now, eventually reply with ICMP for PMTUD
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu())
}
msgWithID, err := suffixSessionID(sessionID, payload)
payloadWithMetadata, err := suffixSessionID(sessionID, payload)
if err != nil {
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
}
if err := dm.session.SendMessage(msgWithID); err != nil {
if err := dm.session.SendMessage(payloadWithMetadata); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge")
}
return nil
}
// ReceiveFrom extracts datagram session ID, then sends the session ID and payload to session manager
// which determines how to proxy to the origin. It assumes the datagram session has already been
// registered with session manager through other side channel
func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) {
msg, err := dm.session.ReceiveMessage()
if err != nil {
return uuid.Nil, nil, err
func (dm *DatagramMuxer) ServeReceive(ctx context.Context) error {
for {
// Extracts datagram session ID, then sends the session ID and payload to receiver
// which determines how to proxy to the origin. It assumes the datagram session has already been
// registered with receiver through other side channel
msg, err := dm.session.ReceiveMessage()
if err != nil {
return err
}
if err := dm.demux(ctx, msg); err != nil {
dm.logger.Error().Err(err).Msg("Failed to demux datagram")
if err == context.Canceled {
return err
}
}
}
sessionID, payload, err := extractSessionID(msg)
if err != nil {
return uuid.Nil, nil, err
}
return sessionID, payload, nil
}
// Maximum application payload to send to / receive from QUIC datagram frame
func (dm *DatagramMuxer) MTU() int {
return maxDatagramPayloadSize
// demux extracts the session ID from the raw datagram and forwards the
// resulting SessionDatagram to demuxChan. It blocks until the receiver accepts
// the datagram or ctx is cancelled, returning ctx.Err() in the latter case.
func (dm *DatagramMuxer) demux(ctx context.Context, msg []byte) error {
	sessionID, payload, err := extractSessionID(msg)
	if err != nil {
		return err
	}
	sessionDatagram := SessionDatagram{
		ID:      sessionID,
		Payload: payload,
	}
	select {
	case dm.demuxChan <- &sessionDatagram:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// Each QUIC datagram should be suffixed with session ID.

View File

@ -8,6 +8,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"testing"
"time"
@ -52,9 +53,29 @@ func TestSuffixSessionIDError(t *testing.T) {
require.Error(t, err)
}
func TestMaxDatagramPayload(t *testing.T) {
payload := make([]byte, maxDatagramPayloadSize)
func TestDatagram(t *testing.T) {
maxPayload := make([]byte, maxDatagramPayloadSize)
noPayloadSession := uuid.New()
maxPayloadSession := uuid.New()
sessionToPayload := []*SessionDatagram{
{
ID: noPayloadSession,
Payload: make([]byte, 0),
},
{
ID: maxPayloadSession,
Payload: maxPayload,
},
}
flowPayloads := [][]byte{
maxPayload,
}
testDatagram(t, 1, sessionToPayload, nil)
testDatagram(t, 2, sessionToPayload, flowPayloads)
}
func testDatagram(t *testing.T, version uint8, sessionToPayloads []*SessionDatagram, packetPayloads [][]byte) {
quicConfig := &quic.Config{
KeepAlivePeriod: 5 * time.Millisecond,
EnableDatagrams: true,
@ -63,6 +84,8 @@ func TestMaxDatagramPayload(t *testing.T) {
quicListener := newQUICListener(t, quicConfig)
defer quicListener.Close()
logger := zerolog.Nop()
errGroup, ctx := errgroup.WithContext(context.Background())
// Run edge side of datagram muxer
errGroup.Go(func() error {
@ -72,22 +95,32 @@ func TestMaxDatagramPayload(t *testing.T) {
return err
}
logger := zerolog.Nop()
muxer, err := NewDatagramMuxer(quicSession, &logger)
if err != nil {
return err
sessionDemuxChan := make(chan *SessionDatagram, 16)
switch version {
case 1:
muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan)
muxer.ServeReceive(ctx)
case 2:
packetDemuxChan := make(chan []byte, len(packetPayloads))
muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan, packetDemuxChan)
muxer.ServeReceive(ctx)
for _, expectedPayload := range packetPayloads {
require.Equal(t, expectedPayload, <-packetDemuxChan)
}
default:
return fmt.Errorf("unknown datagram version %d", version)
}
sessionID, receivedPayload, err := muxer.ReceiveFrom()
if err != nil {
return err
for _, expectedPayload := range sessionToPayloads {
actualPayload := <-sessionDemuxChan
require.Equal(t, expectedPayload, actualPayload)
}
require.Equal(t, testSessionID, sessionID)
require.True(t, bytes.Equal(payload, receivedPayload))
return nil
})
largePayload := make([]byte, MaxDatagramFrameSize)
// Run cloudflared side of datagram muxer
errGroup.Go(func() error {
tlsClientConfig := &tls.Config{
@ -97,24 +130,35 @@ func TestMaxDatagramPayload(t *testing.T) {
// Establish quic connection
quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig)
require.NoError(t, err)
logger := zerolog.Nop()
muxer, err := NewDatagramMuxer(quicSession, &logger)
if err != nil {
return err
}
defer quicSession.CloseWithError(0, "")
// Wait a few milliseconds for MTU discovery to take place
time.Sleep(time.Millisecond * 100)
err = muxer.SendTo(testSessionID, payload)
if err != nil {
return err
var muxer BaseDatagramMuxer
switch version {
case 1:
muxer = NewDatagramMuxer(quicSession, &logger, nil)
case 2:
muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil, nil)
for _, payload := range packetPayloads {
require.NoError(t, muxerV2.MuxPacket(payload))
}
// Payload larger than transport MTU, should not be sent
require.Error(t, muxerV2.MuxPacket(largePayload))
muxer = muxerV2
default:
return fmt.Errorf("unknown datagram version %d", version)
}
// Payload larger than transport MTU, should return an error
largePayload := make([]byte, MaxDatagramFrameSize)
err = muxer.SendTo(testSessionID, largePayload)
require.Error(t, err)
for _, sessionDatagram := range sessionToPayloads {
require.NoError(t, muxer.MuxSession(sessionDatagram.ID, sessionDatagram.Payload))
}
// Payload larger than transport MTU, should not be sent
require.Error(t, muxer.MuxSession(testSessionID, largePayload))
// Wait for edge to finish receiving the messages
time.Sleep(time.Millisecond * 100)
return nil
})
@ -154,3 +198,35 @@ func generateTLSConfig() *tls.Config {
NextProtos: []string{"argotunnel"},
}
}
// sessionMuxer sends a datagram payload addressed to the session identified
// by sessionID over the underlying transport.
type sessionMuxer interface {
	SendToSession(sessionID uuid.UUID, payload []byte) error
}
// mockSessionReceiver verifies that each received session datagram carries the
// payload registered for its session ID, counting successful deliveries.
type mockSessionReceiver struct {
	// expectedSessionToPayload maps a session ID to the payload every datagram
	// for that session is expected to carry.
	expectedSessionToPayload map[uuid.UUID][]byte
	// receivedCount is incremented once per successfully matched datagram.
	receivedCount int
}

// ReceiveDatagram compares payload against the expectation for sessionID and
// returns an error on mismatch; on a match it bumps receivedCount.
func (msr *mockSessionReceiver) ReceiveDatagram(sessionID uuid.UUID, payload []byte) error {
	want := msr.expectedSessionToPayload[sessionID]
	if bytes.Equal(want, payload) {
		msr.receivedCount++
		return nil
	}
	return fmt.Errorf("expect %v to have payload %s, got %s", sessionID, string(want), string(payload))
}
// mockFlowReceiver verifies that flow payloads arrive in the exact order of
// expectedPayloads, counting successful deliveries.
type mockFlowReceiver struct {
	// expectedPayloads lists the payloads expected, in arrival order.
	expectedPayloads [][]byte
	// receivedCount is both the index of the next expected payload and the
	// number of successfully matched flows so far.
	receivedCount int
}

// ReceiveFlow checks payload against the next expected entry and returns an
// error on mismatch; on a match it advances receivedCount.
func (mfr *mockFlowReceiver) ReceiveFlow(payload []byte) error {
	idx := mfr.receivedCount
	if bytes.Equal(mfr.expectedPayloads[idx], payload) {
		mfr.receivedCount++
		return nil
	}
	return fmt.Errorf("expect flow %d to have payload %s, got %s", idx, string(mfr.expectedPayloads[idx]), string(payload))
}

136
quic/datagramv2.go Normal file
View File

@ -0,0 +1,136 @@
package quic
import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
// datagramV2Type is the one-byte tag suffixed to every datagram v2 frame to
// identify the kind of payload it carries.
type datagramV2Type byte

const (
	// udp tags a datagram that carries a UDP session payload; the session ID
	// is suffixed before the type byte (see MuxSession).
	udp datagramV2Type = iota
	// ip tags a datagram that carries a raw IP packet (see MuxPacket).
	ip
)
// suffixType appends datagramType as the final byte of b. It returns an error
// instead when the extra byte would push the frame past MaxDatagramFrameSize.
// NOTE(review): append may write into b's spare capacity — callers appear to
// treat b as consumed after this call.
func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) {
	if len(b) >= MaxDatagramFrameSize {
		return nil, fmt.Errorf("datagram size %d exceeds max frame size %d", len(b), MaxDatagramFrameSize)
	}
	return append(b, byte(datagramType)), nil
}
// Maximum application payload to send to / receive from QUIC datagram frame
// (i.e. the payload budget left after the muxer's own suffixed metadata).
func (dm *DatagramMuxerV2) mtu() int {
	return maxDatagramPayloadSize
}
// DatagramMuxerV2 multiplexes UDP session datagrams and raw IP packets over a
// single QUIC connection, and demultiplexes received datagrams onto the
// corresponding channel based on the type byte suffixed to each frame.
type DatagramMuxerV2 struct {
	session quic.Connection
	logger *zerolog.Logger
	// sessionDemuxChan receives demuxed UDP session datagrams
	sessionDemuxChan chan<- *SessionDatagram
	// packetDemuxChan receives demuxed raw IP packets
	packetDemuxChan chan<- []byte
}
// NewDatagramMuxerV2 creates a muxer bound to quicSession. Demuxed UDP session
// datagrams are delivered on sessionDemuxChan and demuxed IP packets on
// packetDemuxChan. The logger is tagged with datagramVersion=2.
func NewDatagramMuxerV2(
	quicSession quic.Connection,
	log *zerolog.Logger,
	sessionDemuxChan chan<- *SessionDatagram,
	packetDemuxChan chan<- []byte) *DatagramMuxerV2 {
	taggedLogger := log.With().Uint8("datagramVersion", 2).Logger()
	muxer := DatagramMuxerV2{
		session:          quicSession,
		logger:           &taggedLogger,
		sessionDemuxChan: sessionDemuxChan,
		packetDemuxChan:  packetDemuxChan,
	}
	return &muxer
}
// MuxSession suffixes the session ID and then the datagram type to payload so
// the other end of the QUIC connection can demultiplex payloads from multiple
// datagram sessions. Payloads larger than the transport MTU are rejected.
func (dm *DatagramMuxerV2) MuxSession(sessionID uuid.UUID, payload []byte) error {
	if mtu := dm.mtu(); len(payload) > mtu {
		// TODO: TUN-5302 return ICMP packet too big message
		return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), mtu)
	}
	withID, err := suffixSessionID(sessionID, payload)
	if err != nil {
		return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
	}
	frame, err := suffixType(withID, udp)
	if err != nil {
		return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
	}
	if err := dm.session.SendMessage(frame); err != nil {
		return errors.Wrap(err, "Failed to send datagram back to edge")
	}
	return nil
}
// MuxPacket suffixes the datagram type to packet and sends it. The other end
// of the QUIC connection can demultiplex by parsing the payload as IP and
// looking at the source and destination.
func (dm *DatagramMuxerV2) MuxPacket(packet []byte) error {
	frame, suffixErr := suffixType(packet, ip)
	if suffixErr != nil {
		return errors.Wrap(suffixErr, "Failed to suffix datagram type, it will be dropped")
	}
	if sendErr := dm.session.SendMessage(frame); sendErr != nil {
		return errors.Wrap(sendErr, "Failed to send datagram back to edge")
	}
	return nil
}
// ServeReceive reads datagrams from the QUIC connection until the connection
// or ctx fails, demuxing each one as either a session datagram or an IP
// packet. A demux failure is logged and the loop continues, unless the
// failure was context cancellation, which terminates the loop.
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
	for {
		msg, err := dm.session.ReceiveMessage()
		if err != nil {
			return err
		}
		demuxErr := dm.demux(ctx, msg)
		if demuxErr == nil {
			continue
		}
		dm.logger.Error().Err(demuxErr).Msg("Failed to demux datagram")
		if demuxErr == context.Canceled {
			return demuxErr
		}
	}
}
// demux routes one received datagram based on the type byte suffixed to it:
// udp frames are parsed for their session ID and sent to sessionDemuxChan; ip
// frames are sent as-is to packetDemuxChan. Sends block until the channel
// accepts the datagram or ctx is done.
func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error {
	if len(msgWithType) == 0 {
		return fmt.Errorf("QUIC datagram should have at least 1 byte")
	}
	last := len(msgWithType) - 1
	msgType := datagramV2Type(msgWithType[last])
	payload := msgWithType[:last]
	switch msgType {
	case udp:
		sessionID, sessionPayload, err := extractSessionID(payload)
		if err != nil {
			return err
		}
		datagram := SessionDatagram{
			ID:      sessionID,
			Payload: sessionPayload,
		}
		select {
		case dm.sessionDemuxChan <- &datagram:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	case ip:
		select {
		case dm.packetDemuxChan <- payload:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	default:
		return fmt.Errorf("Unexpected datagram type %d", msgType)
	}
}

View File

@ -33,6 +33,7 @@ const (
FeatureSerializedHeaders = "serialized_headers"
FeatureQuickReconnects = "quick_reconnects"
FeatureAllowRemoteConfig = "allow_remote_config"
FeatureDatagramV2 = "support_datagram_v2"
)
type TunnelConfig struct {