TUN-6584: Define QUIC datagram v2 format to support proxying IP packets

This commit is contained in:
cthuang 2022-08-01 13:48:33 +01:00 committed by Chung-Ting Huang
parent d3fd581b7b
commit 278df5478a
12 changed files with 548 additions and 352 deletions

View File

@ -33,6 +33,8 @@ const (
HTTPHostKey = "HttpHost" HTTPHostKey = "HttpHost"
QUICMetadataFlowID = "FlowID" QUICMetadataFlowID = "FlowID"
// emperically this capacity has been working well
demuxChanCapacity = 16
) )
// QUICConnection represents the type that facilitates Proxying via QUIC streams. // QUICConnection represents the type that facilitates Proxying via QUIC streams.
@ -40,7 +42,10 @@ type QUICConnection struct {
session quic.Connection session quic.Connection
logger *zerolog.Logger logger *zerolog.Logger
orchestrator Orchestrator orchestrator Orchestrator
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
sessionManager datagramsession.Manager sessionManager datagramsession.Manager
// datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *quicpogs.DatagramMuxer
controlStreamHandler ControlStreamHandler controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
} }
@ -60,18 +65,16 @@ func NewQUICConnection(
return nil, &EdgeQuicDialError{Cause: err} return nil, &EdgeQuicDialError{Cause: err}
} }
datagramMuxer, err := quicpogs.NewDatagramMuxer(session, logger) demuxChan := make(chan *quicpogs.SessionDatagram, demuxChanCapacity)
if err != nil { datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, demuxChan)
return nil, err sessionManager := datagramsession.NewManager(logger, datagramMuxer.MuxSession, demuxChan)
}
sessionManager := datagramsession.NewManager(datagramMuxer, logger)
return &QUICConnection{ return &QUICConnection{
session: session, session: session,
orchestrator: orchestrator, orchestrator: orchestrator,
logger: logger, logger: logger,
sessionManager: sessionManager, sessionManager: sessionManager,
datagramMuxer: datagramMuxer,
controlStreamHandler: controlStreamHandler, controlStreamHandler: controlStreamHandler,
connOptions: connOptions, connOptions: connOptions,
}, nil }, nil
@ -108,6 +111,11 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
return q.sessionManager.Serve(ctx) return q.sessionManager.Serve(ctx)
}) })
errGroup.Go(func() error {
defer cancel()
return q.datagramMuxer.ServeReceive(ctx)
})
return errGroup.Wait() return errGroup.Wait()
} }

View File

@ -42,9 +42,3 @@ func (sc *errClosedSession) Error() string {
return fmt.Sprintf("session closed by local due to %s", sc.message) return fmt.Sprintf("session closed by local due to %s", sc.message)
} }
} }
// newDatagram is an event when transport receives new datagram
type newDatagram struct {
sessionID uuid.UUID
payload []byte
}

View File

@ -7,9 +7,9 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
quicpogs "github.com/cloudflare/cloudflared/quic"
) )
const ( const (
@ -36,23 +36,22 @@ type Manager interface {
type manager struct { type manager struct {
registrationChan chan *registerSessionEvent registrationChan chan *registerSessionEvent
unregistrationChan chan *unregisterSessionEvent unregistrationChan chan *unregisterSessionEvent
datagramChan chan *newDatagram sendFunc transportSender
closedChan chan struct{} receiveChan <-chan *quicpogs.SessionDatagram
transport transport closedChan <-chan struct{}
sessions map[uuid.UUID]*Session sessions map[uuid.UUID]*Session
log *zerolog.Logger log *zerolog.Logger
// timeout waiting for an API to finish. This can be overriden in test // timeout waiting for an API to finish. This can be overriden in test
timeout time.Duration timeout time.Duration
} }
func NewManager(transport transport, log *zerolog.Logger) *manager { func NewManager(log *zerolog.Logger, sendF transportSender, receiveChan <-chan *quicpogs.SessionDatagram) *manager {
return &manager{ return &manager{
registrationChan: make(chan *registerSessionEvent), registrationChan: make(chan *registerSessionEvent),
unregistrationChan: make(chan *unregisterSessionEvent), unregistrationChan: make(chan *unregisterSessionEvent),
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events sendFunc: sendF,
datagramChan: make(chan *newDatagram, requestChanCapacity), receiveChan: receiveChan,
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
transport: transport,
sessions: make(map[uuid.UUID]*Session), sessions: make(map[uuid.UUID]*Session),
log: log, log: log,
timeout: defaultReqTimeout, timeout: defaultReqTimeout,
@ -65,49 +64,21 @@ func (m *manager) UpdateLogger(log *zerolog.Logger) {
} }
func (m *manager) Serve(ctx context.Context) error { func (m *manager) Serve(ctx context.Context) error {
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
for { for {
sessionID, payload, err := m.transport.ReceiveFrom()
if err != nil {
if aerr, ok := err.(*quic.ApplicationError); ok && uint64(aerr.ErrorCode) == uint64(quic.NoError) {
return nil
} else {
return err
}
}
datagram := &newDatagram{
sessionID: sessionID,
payload: payload,
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
m.shutdownSessions(ctx.Err())
return ctx.Err() return ctx.Err()
// Only the event loop routine can update/lookup the sessions map to avoid concurrent access // receiveChan is buffered, so the transport can read more datagrams from transport while the event loop is
// Send the datagram to the event loop. It will find the session to send to // processing other events
case m.datagramChan <- datagram: case datagram := <-m.receiveChan:
}
}
})
errGroup.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case datagram := <-m.datagramChan:
m.sendToSession(datagram) m.sendToSession(datagram)
case registration := <-m.registrationChan: case registration := <-m.registrationChan:
m.registerSession(ctx, registration) m.registerSession(ctx, registration)
// TODO: TUN-5422: Unregister inactive session upon timeout
case unregistration := <-m.unregistrationChan: case unregistration := <-m.unregistrationChan:
m.unregisterSession(unregistration) m.unregisterSession(unregistration)
} }
} }
})
err := errGroup.Wait()
close(m.closedChan)
m.shutdownSessions(err)
return err
} }
func (m *manager) shutdownSessions(err error) { func (m *manager) shutdownSessions(err error) {
@ -149,16 +120,17 @@ func (m *manager) registerSession(ctx context.Context, registration *registerSes
} }
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session { func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
logger := m.log.With().Str("sessionID", id.String()).Logger()
return &Session{ return &Session{
ID: id, ID: id,
transport: m.transport, sendFunc: m.sendFunc,
dstConn: dstConn, dstConn: dstConn,
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
// drop instead of blocking because last active time only needs to be an approximation // drop instead of blocking because last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 2), activeAtChan: make(chan time.Time, 2),
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel // capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
closeChan: make(chan error, 2), closeChan: make(chan error, 2),
log: m.log, log: &logger,
} }
} }
@ -191,16 +163,13 @@ func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
} }
} }
func (m *manager) sendToSession(datagram *newDatagram) { func (m *manager) sendToSession(datagram *quicpogs.SessionDatagram) {
session, ok := m.sessions[datagram.sessionID] session, ok := m.sessions[datagram.ID]
if !ok { if !ok {
m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found") m.log.Error().Str("sessionID", datagram.ID.String()).Msg("session not found")
return return
} }
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't // session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
// need to run in another go routine // need to run in another go routine
_, err := session.transportToDst(datagram.payload) session.transportToDst(datagram.Payload)
if err != nil {
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
}
} }

View File

@ -14,22 +14,30 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
quicpogs "github.com/cloudflare/cloudflared/quic"
)
var (
nopLogger = zerolog.Nop()
) )
func TestManagerServe(t *testing.T) { func TestManagerServe(t *testing.T) {
const ( const (
sessions = 20 sessions = 2
msgs = 50 msgs = 5
remoteUnregisterMsg = "eyeball closed connection" remoteUnregisterMsg = "eyeball closed connection"
) )
mg, transport := newTestManager(1) requestChan := make(chan *quicpogs.SessionDatagram)
transport := mockQUICTransport{
eyeballTracker := make(map[uuid.UUID]*datagramChannel) sessions: make(map[uuid.UUID]chan []byte),
for i := 0; i < sessions; i++ {
sessionID := uuid.New()
eyeballTracker[sessionID] = newDatagramChannel(1)
} }
for i := 0; i < sessions; i++ {
transport.sessions[uuid.New()] = make(chan []byte)
}
mg := NewManager(&nopLogger, transport.MuxSession, requestChan)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
serveDone := make(chan struct{}) serveDone := make(chan struct{})
@ -38,25 +46,11 @@ func TestManagerServe(t *testing.T) {
close(serveDone) close(serveDone)
}(ctx) }(ctx)
go func(ctx context.Context) {
for {
sessionID, payload, err := transport.respChan.Receive(ctx)
if err != nil {
require.Equal(t, context.Canceled, err)
return
}
respChan := eyeballTracker[sessionID]
require.NoError(t, respChan.Send(ctx, sessionID, payload))
}
}(ctx)
errGroup, ctx := errgroup.WithContext(ctx) errGroup, ctx := errgroup.WithContext(ctx)
for sID, receiver := range eyeballTracker { for sessionID, eyeballRespChan := range transport.sessions {
// Assign loop variables to local variables // Assign loop variables to local variables
sessionID := sID sID := sessionID
eyeballRespReceiver := receiver payload := testPayload(sID)
errGroup.Go(func() error {
payload := testPayload(sessionID)
expectResp := testResponse(payload) expectResp := testResponse(payload)
cfdConn, originConn := net.Pipe() cfdConn, originConn := net.Pipe()
@ -67,24 +61,27 @@ func TestManagerServe(t *testing.T) {
expectedResp: expectResp, expectedResp: expectResp,
conn: originConn, conn: originConn,
} }
eyeball := mockEyeball{
expectMsgCount: msgs, eyeball := mockEyeballSession{
expectedMsg: expectResp, id: sID,
expectSessionID: sessionID, expectedMsgCount: msgs,
respReceiver: eyeballRespReceiver, expectedMsg: payload,
expectedResponse: expectResp,
respReceiver: eyeballRespChan,
} }
// Assign loop variables to local variables
errGroup.Go(func() error {
session, err := mg.RegisterSession(ctx, sID, cfdConn)
require.NoError(t, err)
reqErrGroup, reqCtx := errgroup.WithContext(ctx) reqErrGroup, reqCtx := errgroup.WithContext(ctx)
reqErrGroup.Go(func() error { reqErrGroup.Go(func() error {
return origin.serve() return origin.serve()
}) })
reqErrGroup.Go(func() error { reqErrGroup.Go(func() error {
return eyeball.serve(reqCtx) return eyeball.serve(reqCtx, requestChan)
}) })
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err)
sessionDone := make(chan struct{}) sessionDone := make(chan struct{})
go func() { go func() {
closedByRemote, err := session.Serve(ctx, time.Minute*2) closedByRemote, err := session.Serve(ctx, time.Minute*2)
@ -97,23 +94,17 @@ func TestManagerServe(t *testing.T) {
close(sessionDone) close(sessionDone)
}() }()
for i := 0; i < msgs; i++ {
require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID)))
}
// Make sure eyeball and origin have received all messages before unregistering the session // Make sure eyeball and origin have received all messages before unregistering the session
require.NoError(t, reqErrGroup.Wait()) require.NoError(t, reqErrGroup.Wait())
require.NoError(t, mg.UnregisterSession(ctx, sessionID, remoteUnregisterMsg, true)) require.NoError(t, mg.UnregisterSession(ctx, sID, remoteUnregisterMsg, true))
<-sessionDone <-sessionDone
return nil return nil
}) })
} }
require.NoError(t, errGroup.Wait()) require.NoError(t, errGroup.Wait())
cancel() cancel()
transport.close()
<-serveDone <-serveDone
} }
@ -122,7 +113,7 @@ func TestTimeout(t *testing.T) {
testTimeout = time.Millisecond * 50 testTimeout = time.Millisecond * 50
) )
mg, _ := newTestManager(1) mg := NewManager(&nopLogger, nil, nil)
mg.timeout = testTimeout mg.timeout = testTimeout
ctx := context.Background() ctx := context.Background()
sessionID := uuid.New() sessionID := uuid.New()
@ -135,9 +126,51 @@ func TestTimeout(t *testing.T) {
require.ErrorIs(t, err, context.DeadlineExceeded) require.ErrorIs(t, err, context.DeadlineExceeded)
} }
func TestCloseTransportCloseSessions(t *testing.T) { func TestUnregisterSessionCloseSession(t *testing.T) {
mg, transport := newTestManager(1) sessionID := uuid.New()
ctx := context.Background() payload := []byte(t.Name())
sender := newMockTransportSender(sessionID, payload)
mg := NewManager(&nopLogger, sender.muxSession, nil)
ctx, cancel := context.WithCancel(context.Background())
managerDone := make(chan struct{})
go func() {
err := mg.Serve(ctx)
require.Error(t, err)
close(managerDone)
}()
cfdConn, originConn := net.Pipe()
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err)
require.NotNil(t, session)
unregisteredChan := make(chan struct{})
go func() {
_, err := originConn.Write(payload)
require.NoError(t, err)
err = mg.UnregisterSession(ctx, sessionID, "eyeball closed session", true)
require.NoError(t, err)
close(unregisteredChan)
}()
closedByRemote, err := session.Serve(ctx, time.Minute)
require.True(t, closedByRemote)
require.Error(t, err)
<-unregisteredChan
cancel()
<-managerDone
}
func TestManagerCtxDoneCloseSessions(t *testing.T) {
sessionID := uuid.New()
payload := []byte(t.Name())
sender := newMockTransportSender(sessionID, payload)
mg := NewManager(&nopLogger, sender.muxSession, nil)
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -147,35 +180,26 @@ func TestCloseTransportCloseSessions(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}() }()
cfdConn, eyeballConn := net.Pipe() cfdConn, originConn := net.Pipe()
session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn) session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, session) require.NotNil(t, session)
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := eyeballConn.Write([]byte(t.Name())) _, err := originConn.Write(payload)
require.NoError(t, err) require.NoError(t, err)
transport.close() cancel()
}() }()
closedByRemote, err := session.Serve(ctx, time.Minute) closedByRemote, err := session.Serve(ctx, time.Minute)
require.True(t, closedByRemote) require.False(t, closedByRemote)
require.Error(t, err) require.Error(t, err)
wg.Wait() wg.Wait()
} }
func newTestManager(capacity uint) (*manager, *mockQUICTransport) {
log := zerolog.Nop()
transport := &mockQUICTransport{
reqChan: newDatagramChannel(capacity),
respChan: newDatagramChannel(capacity),
}
return NewManager(transport, &log), transport
}
type mockOrigin struct { type mockOrigin struct {
expectMsgCount int expectMsgCount int
expectedMsg []byte expectedMsg []byte
@ -197,7 +221,6 @@ func (mo *mockOrigin) serve() error {
if !bytes.Equal(readBuffer[:n], mo.expectedMsg) { if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n]) return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
} }
_, err = mo.conn.Write(mo.expectedResp) _, err = mo.conn.Write(mo.expectedResp)
if err != nil { if err != nil {
return err return err
@ -214,72 +237,35 @@ func testResponse(msg []byte) []byte {
return []byte(fmt.Sprintf("Response to %v", msg)) return []byte(fmt.Sprintf("Response to %v", msg))
} }
type mockEyeball struct { type mockQUICTransport struct {
expectMsgCount int sessions map[uuid.UUID]chan []byte
}
func (me *mockQUICTransport) MuxSession(id uuid.UUID, payload []byte) error {
session := me.sessions[id]
session <- payload
return nil
}
type mockEyeballSession struct {
id uuid.UUID
expectedMsgCount int
expectedMsg []byte expectedMsg []byte
expectSessionID uuid.UUID expectedResponse []byte
respReceiver *datagramChannel respReceiver <-chan []byte
} }
func (me *mockEyeball) serve(ctx context.Context) error { func (me *mockEyeballSession) serve(ctx context.Context, requestChan chan *quicpogs.SessionDatagram) error {
for i := 0; i < me.expectMsgCount; i++ { for i := 0; i < me.expectedMsgCount; i++ {
sessionID, msg, err := me.respReceiver.Receive(ctx) requestChan <- &quicpogs.SessionDatagram{
if err != nil { ID: me.id,
return err Payload: me.expectedMsg,
} }
if sessionID != me.expectSessionID { resp := <-me.respReceiver
return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID) if !bytes.Equal(resp, me.expectedResponse) {
} return fmt.Errorf("Expect %v, read %v", me.expectedResponse, resp)
if !bytes.Equal(msg, me.expectedMsg) {
return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg)
} }
fmt.Println("Resp", resp)
} }
return nil return nil
} }
// datagramChannel is a channel for Datagram with wrapper to send/receive with context
type datagramChannel struct {
datagramChan chan *newDatagram
closedChan chan struct{}
}
func newDatagramChannel(capacity uint) *datagramChannel {
return &datagramChannel{
datagramChan: make(chan *newDatagram, capacity),
closedChan: make(chan struct{}),
}
}
func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-rc.closedChan:
return &errClosedSession{
message: fmt.Errorf("datagram channel closed").Error(),
byRemote: true,
}
case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}:
return nil
}
}
func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) {
select {
case <-ctx.Done():
return uuid.Nil, nil, ctx.Err()
case <-rc.closedChan:
err := &errClosedSession{
message: fmt.Errorf("datagram channel closed").Error(),
byRemote: true,
}
return uuid.Nil, nil, err
case msg := <-rc.datagramChan:
return msg.sessionID, msg.payload, nil
}
}
func (rc *datagramChannel) Close() {
// No need to close msgChan, it will be garbage collect once there is no reference to it
close(rc.closedChan)
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -18,20 +19,19 @@ func SessionIdleErr(timeout time.Duration) error {
return fmt.Errorf("session idle for %v", timeout) return fmt.Errorf("session idle for %v", timeout)
} }
type transportSender func(sessionID uuid.UUID, payload []byte) error
// Session is a bidirectional pipe of datagrams between transport and dstConn // Session is a bidirectional pipe of datagrams between transport and dstConn
// Currently the only implementation of transport is quic DatagramMuxer
// Destination can be a connection with origin or with eyeball // Destination can be a connection with origin or with eyeball
// When the destination is origin: // When the destination is origin:
// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the // - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to origin
// write method of the Session to send to origin // - Datagrams from origin are read from conn and Send to transport using the transportSender callback. Transport will return them to eyeball
// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball
// When the destination is eyeball: // When the destination is eyeball:
// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared // - Datagrams from eyeball are read from conn and Send to transport. Transport will send them to cloudflared using the transportSender callback.
// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the // - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to the eyeball
// write method of the Session to send to eyeball
type Session struct { type Session struct {
ID uuid.UUID ID uuid.UUID
transport transport sendFunc transportSender
dstConn io.ReadWriteCloser dstConn io.ReadWriteCloser
// activeAtChan is used to communicate the last read/write time // activeAtChan is used to communicate the last read/write time
activeAtChan chan time.Time activeAtChan chan time.Time
@ -46,11 +46,18 @@ func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (clos
const maxPacketSize = 1500 const maxPacketSize = 1500
readBuffer := make([]byte, maxPacketSize) readBuffer := make([]byte, maxPacketSize)
for { for {
if err := s.dstToTransport(readBuffer); err != nil { if closeSession, err := s.dstToTransport(readBuffer); err != nil {
if err != net.ErrClosed {
s.log.Error().Err(err).Msg("Failed to send session payload from destination to transport")
} else {
s.log.Debug().Msg("Session cannot read from destination because the connection is closed")
}
if closeSession {
s.closeChan <- err s.closeChan <- err
return return
} }
} }
}
}() }()
err = s.waitForCloseCondition(ctx, closeAfterIdle) err = s.waitForCloseCondition(ctx, closeAfterIdle)
if closeSession, ok := err.(*errClosedSession); ok { if closeSession, ok := err.(*errClosedSession); ok {
@ -89,32 +96,25 @@ func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time
} }
} }
func (s *Session) dstToTransport(buffer []byte) error { func (s *Session) dstToTransport(buffer []byte) (closeSession bool, err error) {
n, err := s.dstConn.Read(buffer) n, err := s.dstConn.Read(buffer)
s.markActive() s.markActive()
// https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes // https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes
if n > 0 { if n > 0 || err == nil {
if n <= int(s.transport.MTU()) { if sendErr := s.sendFunc(s.ID, buffer[:n]); sendErr != nil {
err = s.transport.SendTo(s.ID, buffer[:n]) return false, sendErr
} else {
// drop packet for now, eventually reply with ICMP for PMTUD
s.log.Debug().
Str("session", s.ID.String()).
Int("len", n).
Int("mtu", s.transport.MTU()).
Msg("dropped packet exceeding MTU")
} }
} }
// Some UDP application might send 0-size payload. return err != nil, err
if err == nil && n == 0 {
err = s.transport.SendTo(s.ID, []byte{})
}
return err
} }
func (s *Session) transportToDst(payload []byte) (int, error) { func (s *Session) transportToDst(payload []byte) (int, error) {
s.markActive() s.markActive()
return s.dstConn.Write(payload) n, err := s.dstConn.Write(payload)
if err != nil {
s.log.Err(err).Msg("Failed to write payload to session")
}
return n, err
} }
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there // Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there

View File

@ -11,8 +11,11 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
quicpogs "github.com/cloudflare/cloudflared/quic"
) )
// TestCloseSession makes sure a session will stop after context is done // TestCloseSession makes sure a session will stop after context is done
@ -41,7 +44,8 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
cfdConn, originConn := net.Pipe() cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID) payload := testPayload(sessionID)
mg, _ := newTestManager(1) log := zerolog.Nop()
mg := NewManager(&log, nil, nil)
session := mg.newSession(sessionID, cfdConn) session := mg.newSession(sessionID, cfdConn)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -114,7 +118,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
cfdConn, originConn := net.Pipe() cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID) payload := testPayload(sessionID)
mg, _ := newTestManager(100) respChan := make(chan *quicpogs.SessionDatagram)
sender := newMockTransportSender(sessionID, payload)
mg := NewManager(&nopLogger, sender.muxSession, respChan)
session := mg.newSession(sessionID, cfdConn) session := mg.newSession(sessionID, cfdConn)
startTime := time.Now() startTime := time.Now()
@ -177,7 +183,7 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
func TestMarkActiveNotBlocking(t *testing.T) { func TestMarkActiveNotBlocking(t *testing.T) {
const concurrentCalls = 50 const concurrentCalls = 50
mg, _ := newTestManager(1) mg := NewManager(&nopLogger, nil, nil)
session := mg.newSession(uuid.New(), nil) session := mg.newSession(uuid.New(), nil)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(concurrentCalls) wg.Add(concurrentCalls)
@ -190,11 +196,16 @@ func TestMarkActiveNotBlocking(t *testing.T) {
wg.Wait() wg.Wait()
} }
// Some UDP application might send 0-size payload.
func TestZeroBytePayload(t *testing.T) { func TestZeroBytePayload(t *testing.T) {
sessionID := uuid.New() sessionID := uuid.New()
cfdConn, originConn := net.Pipe() cfdConn, originConn := net.Pipe()
mg, transport := newTestManager(1) sender := sendOnceTransportSender{
baseSender: newMockTransportSender(sessionID, make([]byte, 0)),
sentChan: make(chan struct{}),
}
mg := NewManager(&nopLogger, sender.muxSession, nil)
session := mg.newSession(sessionID, cfdConn) session := mg.newSession(sessionID, cfdConn)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -215,11 +226,39 @@ func TestZeroBytePayload(t *testing.T) {
return nil return nil
}) })
receivedSessionID, payload, err := transport.respChan.Receive(ctx) <-sender.sentChan
require.NoError(t, err)
require.Len(t, payload, 0)
require.Equal(t, sessionID, receivedSessionID)
cancel() cancel()
require.NoError(t, errGroup.Wait()) require.NoError(t, errGroup.Wait())
} }
type mockTransportSender struct {
expectedSessionID uuid.UUID
expectedPayload []byte
}
func newMockTransportSender(expectedSessionID uuid.UUID, expectedPayload []byte) *mockTransportSender {
return &mockTransportSender{
expectedSessionID: expectedSessionID,
expectedPayload: expectedPayload,
}
}
func (mts *mockTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error {
if sessionID != mts.expectedSessionID {
return fmt.Errorf("Expect session %s, got %s", mts.expectedSessionID, sessionID)
}
if !bytes.Equal(payload, mts.expectedPayload) {
return fmt.Errorf("Expect %v, read %v", mts.expectedPayload, payload)
}
return nil
}
type sendOnceTransportSender struct {
baseSender *mockTransportSender
sentChan chan struct{}
}
func (sots *sendOnceTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error {
defer close(sots.sentChan)
return sots.baseSender.muxSession(sessionID, payload)
}

View File

@ -1,13 +0,0 @@
package datagramsession
import "github.com/google/uuid"
// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions
type transport interface {
// SendTo writes payload for a session to the transport
SendTo(sessionID uuid.UUID, payload []byte) error
// ReceiveFrom reads the next datagram from the transport
ReceiveFrom() (uuid.UUID, []byte, error)
// Max transmission unit to receive from the transport
MTU() int
}

View File

@ -1,36 +0,0 @@
package datagramsession
import (
"context"
"github.com/google/uuid"
)
type mockQUICTransport struct {
reqChan *datagramChannel
respChan *datagramChannel
}
func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
buf := make([]byte, len(payload))
// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
copy(buf, payload)
return mt.respChan.Send(context.Background(), sessionID, buf)
}
func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
return mt.reqChan.Receive(context.Background())
}
func (mt *mockQUICTransport) MTU() int {
return 1280
}
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
return mt.reqChan.Send(ctx, sessionID, payload)
}
func (mt *mockQUICTransport) close() {
mt.reqChan.Close()
mt.respChan.Close()
}

View File

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"context"
"fmt" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
@ -13,53 +14,88 @@ const (
sessionIDLen = len(uuid.UUID{}) sessionIDLen = len(uuid.UUID{})
) )
type SessionDatagram struct {
ID uuid.UUID
Payload []byte
}
type BaseDatagramMuxer interface {
// MuxSession suffix the session ID to the payload so the other end of the QUIC connection can demultiplex the
// payload from multiple datagram sessions
MuxSession(sessionID uuid.UUID, payload []byte) error
// ServeReceive starts a loop to receive datagrams from the QUIC connection
ServeReceive(ctx context.Context) error
}
type DatagramMuxer struct { type DatagramMuxer struct {
session quic.Connection session quic.Connection
logger *zerolog.Logger logger *zerolog.Logger
demuxChan chan<- *SessionDatagram
} }
func NewDatagramMuxer(quicSession quic.Connection, logger *zerolog.Logger) (*DatagramMuxer, error) { func NewDatagramMuxer(quicSession quic.Connection, log *zerolog.Logger, demuxChan chan<- *SessionDatagram) *DatagramMuxer {
logger := log.With().Uint8("datagramVersion", 1).Logger()
return &DatagramMuxer{ return &DatagramMuxer{
session: quicSession, session: quicSession,
logger: logger, logger: &logger,
}, nil demuxChan: demuxChan,
}
} }
// SendTo suffix the session ID to the payload so the other end of the QUIC session can demultiplex // Maximum application payload to send to / receive from QUIC datagram frame
// the payload from multiple datagram sessions func (dm *DatagramMuxer) mtu() int {
func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error { return maxDatagramPayloadSize
if len(payload) > maxDatagramPayloadSize { }
func (dm *DatagramMuxer) MuxSession(sessionID uuid.UUID, payload []byte) error {
if len(payload) > dm.mtu() {
// TODO: TUN-5302 return ICMP packet too big message // TODO: TUN-5302 return ICMP packet too big message
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.MTU()) // drop packet for now, eventually reply with ICMP for PMTUD
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu())
} }
msgWithID, err := suffixSessionID(sessionID, payload) payloadWithMetadata, err := suffixSessionID(sessionID, payload)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped") return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
} }
if err := dm.session.SendMessage(msgWithID); err != nil { if err := dm.session.SendMessage(payloadWithMetadata); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge") return errors.Wrap(err, "Failed to send datagram back to edge")
} }
return nil return nil
} }
// ReceiveFrom extracts datagram session ID, then sends the session ID and payload to session manager func (dm *DatagramMuxer) ServeReceive(ctx context.Context) error {
// which determines how to proxy to the origin. It assumes the datagram session has already been for {
// registered with session manager through other side channel // Extracts datagram session ID, then sends the session ID and payload to receiver
func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) { // which determines how to proxy to the origin. It assumes the datagram session has already been
// registered with receiver through other side channel
msg, err := dm.session.ReceiveMessage() msg, err := dm.session.ReceiveMessage()
if err != nil { if err != nil {
return uuid.Nil, nil, err return err
}
if err := dm.demux(ctx, msg); err != nil {
dm.logger.Error().Err(err).Msg("Failed to demux datagram")
if err == context.Canceled {
return err
}
} }
sessionID, payload, err := extractSessionID(msg)
if err != nil {
return uuid.Nil, nil, err
} }
return sessionID, payload, nil
} }
// Maximum application payload to send to / receive from QUIC datagram frame func (dm *DatagramMuxer) demux(ctx context.Context, msg []byte) error {
func (dm *DatagramMuxer) MTU() int { sessionID, payload, err := extractSessionID(msg)
return maxDatagramPayloadSize if err != nil {
return err
}
sessionDatagram := SessionDatagram{
ID: sessionID,
Payload: payload,
}
select {
case dm.demuxChan <- &sessionDatagram:
return nil
case <-ctx.Done():
return ctx.Err()
}
} }
// Each QUIC datagram should be suffixed with session ID. // Each QUIC datagram should be suffixed with session ID.

View File

@ -8,6 +8,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt"
"math/big" "math/big"
"testing" "testing"
"time" "time"
@ -52,9 +53,29 @@ func TestSuffixSessionIDError(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestMaxDatagramPayload(t *testing.T) { func TestDatagram(t *testing.T) {
payload := make([]byte, maxDatagramPayloadSize) maxPayload := make([]byte, maxDatagramPayloadSize)
noPayloadSession := uuid.New()
maxPayloadSession := uuid.New()
sessionToPayload := []*SessionDatagram{
{
ID: noPayloadSession,
Payload: make([]byte, 0),
},
{
ID: maxPayloadSession,
Payload: maxPayload,
},
}
flowPayloads := [][]byte{
maxPayload,
}
testDatagram(t, 1, sessionToPayload, nil)
testDatagram(t, 2, sessionToPayload, flowPayloads)
}
func testDatagram(t *testing.T, version uint8, sessionToPayloads []*SessionDatagram, packetPayloads [][]byte) {
quicConfig := &quic.Config{ quicConfig := &quic.Config{
KeepAlivePeriod: 5 * time.Millisecond, KeepAlivePeriod: 5 * time.Millisecond,
EnableDatagrams: true, EnableDatagrams: true,
@ -63,6 +84,8 @@ func TestMaxDatagramPayload(t *testing.T) {
quicListener := newQUICListener(t, quicConfig) quicListener := newQUICListener(t, quicConfig)
defer quicListener.Close() defer quicListener.Close()
logger := zerolog.Nop()
errGroup, ctx := errgroup.WithContext(context.Background()) errGroup, ctx := errgroup.WithContext(context.Background())
// Run edge side of datagram muxer // Run edge side of datagram muxer
errGroup.Go(func() error { errGroup.Go(func() error {
@ -72,22 +95,32 @@ func TestMaxDatagramPayload(t *testing.T) {
return err return err
} }
logger := zerolog.Nop() sessionDemuxChan := make(chan *SessionDatagram, 16)
muxer, err := NewDatagramMuxer(quicSession, &logger)
if err != nil { switch version {
return err case 1:
muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan)
muxer.ServeReceive(ctx)
case 2:
packetDemuxChan := make(chan []byte, len(packetPayloads))
muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan, packetDemuxChan)
muxer.ServeReceive(ctx)
for _, expectedPayload := range packetPayloads {
require.Equal(t, expectedPayload, <-packetDemuxChan)
}
default:
return fmt.Errorf("unknown datagram version %d", version)
} }
sessionID, receivedPayload, err := muxer.ReceiveFrom() for _, expectedPayload := range sessionToPayloads {
if err != nil { actualPayload := <-sessionDemuxChan
return err require.Equal(t, expectedPayload, actualPayload)
} }
require.Equal(t, testSessionID, sessionID)
require.True(t, bytes.Equal(payload, receivedPayload))
return nil return nil
}) })
largePayload := make([]byte, MaxDatagramFrameSize)
// Run cloudflared side of datagram muxer // Run cloudflared side of datagram muxer
errGroup.Go(func() error { errGroup.Go(func() error {
tlsClientConfig := &tls.Config{ tlsClientConfig := &tls.Config{
@ -97,24 +130,35 @@ func TestMaxDatagramPayload(t *testing.T) {
// Establish quic connection // Establish quic connection
quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig) quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig)
require.NoError(t, err) require.NoError(t, err)
defer quicSession.CloseWithError(0, "")
logger := zerolog.Nop()
muxer, err := NewDatagramMuxer(quicSession, &logger)
if err != nil {
return err
}
// Wait a few milliseconds for MTU discovery to take place // Wait a few milliseconds for MTU discovery to take place
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
err = muxer.SendTo(testSessionID, payload)
if err != nil { var muxer BaseDatagramMuxer
return err switch version {
case 1:
muxer = NewDatagramMuxer(quicSession, &logger, nil)
case 2:
muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil, nil)
for _, payload := range packetPayloads {
require.NoError(t, muxerV2.MuxPacket(payload))
}
// Payload larger than transport MTU, should not be sent
require.Error(t, muxerV2.MuxPacket(largePayload))
muxer = muxerV2
default:
return fmt.Errorf("unknown datagram version %d", version)
} }
// Payload larger than transport MTU, should return an error for _, sessionDatagram := range sessionToPayloads {
largePayload := make([]byte, MaxDatagramFrameSize) require.NoError(t, muxer.MuxSession(sessionDatagram.ID, sessionDatagram.Payload))
err = muxer.SendTo(testSessionID, largePayload) }
require.Error(t, err) // Payload larger than transport MTU, should not be sent
require.Error(t, muxer.MuxSession(testSessionID, largePayload))
// Wait for edge to finish receiving the messages
time.Sleep(time.Millisecond * 100)
return nil return nil
}) })
@ -154,3 +198,35 @@ func generateTLSConfig() *tls.Config {
NextProtos: []string{"argotunnel"}, NextProtos: []string{"argotunnel"},
} }
} }
type sessionMuxer interface {
SendToSession(sessionID uuid.UUID, payload []byte) error
}
type mockSessionReceiver struct {
expectedSessionToPayload map[uuid.UUID][]byte
receivedCount int
}
func (msr *mockSessionReceiver) ReceiveDatagram(sessionID uuid.UUID, payload []byte) error {
expectedPayload := msr.expectedSessionToPayload[sessionID]
if !bytes.Equal(expectedPayload, payload) {
return fmt.Errorf("expect %v to have payload %s, got %s", sessionID, string(expectedPayload), string(payload))
}
msr.receivedCount++
return nil
}
type mockFlowReceiver struct {
expectedPayloads [][]byte
receivedCount int
}
func (mfr *mockFlowReceiver) ReceiveFlow(payload []byte) error {
expectedPayload := mfr.expectedPayloads[mfr.receivedCount]
if !bytes.Equal(expectedPayload, payload) {
return fmt.Errorf("expect flow %d to have payload %s, got %s", mfr.receivedCount, string(expectedPayload), string(payload))
}
mfr.receivedCount++
return nil
}

136
quic/datagramv2.go Normal file
View File

@ -0,0 +1,136 @@
package quic
import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
type datagramV2Type byte
const (
udp datagramV2Type = iota
ip
)
func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) {
if len(b)+1 > MaxDatagramFrameSize {
return nil, fmt.Errorf("datagram size %d exceeds max frame size %d", len(b), MaxDatagramFrameSize)
}
b = append(b, byte(datagramType))
return b, nil
}
// Maximum application payload to send to / receive from QUIC datagram frame
func (dm *DatagramMuxerV2) mtu() int {
return maxDatagramPayloadSize
}
type DatagramMuxerV2 struct {
session quic.Connection
logger *zerolog.Logger
sessionDemuxChan chan<- *SessionDatagram
packetDemuxChan chan<- []byte
}
func NewDatagramMuxerV2(
quicSession quic.Connection,
log *zerolog.Logger,
sessionDemuxChan chan<- *SessionDatagram,
packetDemuxChan chan<- []byte) *DatagramMuxerV2 {
logger := log.With().Uint8("datagramVersion", 2).Logger()
return &DatagramMuxerV2{
session: quicSession,
logger: &logger,
sessionDemuxChan: sessionDemuxChan,
packetDemuxChan: packetDemuxChan,
}
}
// MuxSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can
// demultiplex the payload from multiple datagram sessions
func (dm *DatagramMuxerV2) MuxSession(sessionID uuid.UUID, payload []byte) error {
if len(payload) > dm.mtu() {
// TODO: TUN-5302 return ICMP packet too big message
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu())
}
msgWithID, err := suffixSessionID(sessionID, payload)
if err != nil {
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
}
msgWithIDAndType, err := suffixType(msgWithID, udp)
if err != nil {
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
}
if err := dm.session.SendMessage(msgWithIDAndType); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge")
}
return nil
}
// MuxPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing
// the payload as IP and look at the source and destination.
func (dm *DatagramMuxerV2) MuxPacket(packet []byte) error {
payloadWithVersion, err := suffixType(packet, ip)
if err != nil {
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
}
if err := dm.session.SendMessage(payloadWithVersion); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge")
}
return nil
}
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
for {
msg, err := dm.session.ReceiveMessage()
if err != nil {
return err
}
if err := dm.demux(ctx, msg); err != nil {
dm.logger.Error().Err(err).Msg("Failed to demux datagram")
if err == context.Canceled {
return err
}
}
}
}
func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error {
if len(msgWithType) < 1 {
return fmt.Errorf("QUIC datagram should have at least 1 byte")
}
msgType := datagramV2Type(msgWithType[len(msgWithType)-1])
msg := msgWithType[0 : len(msgWithType)-1]
switch msgType {
case udp:
sessionID, payload, err := extractSessionID(msg)
if err != nil {
return err
}
sessionDatagram := SessionDatagram{
ID: sessionID,
Payload: payload,
}
select {
case dm.sessionDemuxChan <- &sessionDatagram:
return nil
case <-ctx.Done():
return ctx.Err()
}
case ip:
select {
case dm.packetDemuxChan <- msg:
return nil
case <-ctx.Done():
return ctx.Err()
}
default:
return fmt.Errorf("Unexpected datagram type %d", msgType)
}
}

View File

@ -33,6 +33,7 @@ const (
FeatureSerializedHeaders = "serialized_headers" FeatureSerializedHeaders = "serialized_headers"
FeatureQuickReconnects = "quick_reconnects" FeatureQuickReconnects = "quick_reconnects"
FeatureAllowRemoteConfig = "allow_remote_config" FeatureAllowRemoteConfig = "allow_remote_config"
FeatureDatagramV2 = "support_datagram_v2"
) )
type TunnelConfig struct { type TunnelConfig struct {