TUN-6584: Define QUIC datagram v2 format to support proxying IP packets
This commit is contained in:
parent
d3fd581b7b
commit
278df5478a
|
@ -33,14 +33,19 @@ const (
|
|||
HTTPHostKey = "HttpHost"
|
||||
|
||||
QUICMetadataFlowID = "FlowID"
|
||||
// emperically this capacity has been working well
|
||||
demuxChanCapacity = 16
|
||||
)
|
||||
|
||||
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
||||
type QUICConnection struct {
|
||||
session quic.Connection
|
||||
logger *zerolog.Logger
|
||||
orchestrator Orchestrator
|
||||
sessionManager datagramsession.Manager
|
||||
session quic.Connection
|
||||
logger *zerolog.Logger
|
||||
orchestrator Orchestrator
|
||||
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
|
||||
sessionManager datagramsession.Manager
|
||||
// datagramMuxer mux/demux datagrams from quic connection
|
||||
datagramMuxer *quicpogs.DatagramMuxer
|
||||
controlStreamHandler ControlStreamHandler
|
||||
connOptions *tunnelpogs.ConnectionOptions
|
||||
}
|
||||
|
@ -60,18 +65,16 @@ func NewQUICConnection(
|
|||
return nil, &EdgeQuicDialError{Cause: err}
|
||||
}
|
||||
|
||||
datagramMuxer, err := quicpogs.NewDatagramMuxer(session, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionManager := datagramsession.NewManager(datagramMuxer, logger)
|
||||
demuxChan := make(chan *quicpogs.SessionDatagram, demuxChanCapacity)
|
||||
datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, demuxChan)
|
||||
sessionManager := datagramsession.NewManager(logger, datagramMuxer.MuxSession, demuxChan)
|
||||
|
||||
return &QUICConnection{
|
||||
session: session,
|
||||
orchestrator: orchestrator,
|
||||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
datagramMuxer: datagramMuxer,
|
||||
controlStreamHandler: controlStreamHandler,
|
||||
connOptions: connOptions,
|
||||
}, nil
|
||||
|
@ -108,6 +111,11 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
|
|||
return q.sessionManager.Serve(ctx)
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.datagramMuxer.ServeReceive(ctx)
|
||||
})
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
|
|
|
@ -42,9 +42,3 @@ func (sc *errClosedSession) Error() string {
|
|||
return fmt.Sprintf("session closed by local due to %s", sc.message)
|
||||
}
|
||||
}
|
||||
|
||||
// newDatagram is an event when transport receives new datagram
|
||||
type newDatagram struct {
|
||||
sessionID uuid.UUID
|
||||
payload []byte
|
||||
}
|
||||
|
|
|
@ -7,9 +7,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -36,26 +36,25 @@ type Manager interface {
|
|||
type manager struct {
|
||||
registrationChan chan *registerSessionEvent
|
||||
unregistrationChan chan *unregisterSessionEvent
|
||||
datagramChan chan *newDatagram
|
||||
closedChan chan struct{}
|
||||
transport transport
|
||||
sendFunc transportSender
|
||||
receiveChan <-chan *quicpogs.SessionDatagram
|
||||
closedChan <-chan struct{}
|
||||
sessions map[uuid.UUID]*Session
|
||||
log *zerolog.Logger
|
||||
// timeout waiting for an API to finish. This can be overriden in test
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewManager(transport transport, log *zerolog.Logger) *manager {
|
||||
func NewManager(log *zerolog.Logger, sendF transportSender, receiveChan <-chan *quicpogs.SessionDatagram) *manager {
|
||||
return &manager{
|
||||
registrationChan: make(chan *registerSessionEvent),
|
||||
unregistrationChan: make(chan *unregisterSessionEvent),
|
||||
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
|
||||
datagramChan: make(chan *newDatagram, requestChanCapacity),
|
||||
closedChan: make(chan struct{}),
|
||||
transport: transport,
|
||||
sessions: make(map[uuid.UUID]*Session),
|
||||
log: log,
|
||||
timeout: defaultReqTimeout,
|
||||
sendFunc: sendF,
|
||||
receiveChan: receiveChan,
|
||||
closedChan: make(chan struct{}),
|
||||
sessions: make(map[uuid.UUID]*Session),
|
||||
log: log,
|
||||
timeout: defaultReqTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,49 +64,21 @@ func (m *manager) UpdateLogger(log *zerolog.Logger) {
|
|||
}
|
||||
|
||||
func (m *manager) Serve(ctx context.Context) error {
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
for {
|
||||
sessionID, payload, err := m.transport.ReceiveFrom()
|
||||
if err != nil {
|
||||
if aerr, ok := err.(*quic.ApplicationError); ok && uint64(aerr.ErrorCode) == uint64(quic.NoError) {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
datagram := &newDatagram{
|
||||
sessionID: sessionID,
|
||||
payload: payload,
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
// Only the event loop routine can update/lookup the sessions map to avoid concurrent access
|
||||
// Send the datagram to the event loop. It will find the session to send to
|
||||
case m.datagramChan <- datagram:
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.shutdownSessions(ctx.Err())
|
||||
return ctx.Err()
|
||||
// receiveChan is buffered, so the transport can read more datagrams from transport while the event loop is
|
||||
// processing other events
|
||||
case datagram := <-m.receiveChan:
|
||||
m.sendToSession(datagram)
|
||||
case registration := <-m.registrationChan:
|
||||
m.registerSession(ctx, registration)
|
||||
case unregistration := <-m.unregistrationChan:
|
||||
m.unregisterSession(unregistration)
|
||||
}
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case datagram := <-m.datagramChan:
|
||||
m.sendToSession(datagram)
|
||||
case registration := <-m.registrationChan:
|
||||
m.registerSession(ctx, registration)
|
||||
// TODO: TUN-5422: Unregister inactive session upon timeout
|
||||
case unregistration := <-m.unregistrationChan:
|
||||
m.unregisterSession(unregistration)
|
||||
}
|
||||
}
|
||||
})
|
||||
err := errGroup.Wait()
|
||||
close(m.closedChan)
|
||||
m.shutdownSessions(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *manager) shutdownSessions(err error) {
|
||||
|
@ -149,16 +120,17 @@ func (m *manager) registerSession(ctx context.Context, registration *registerSes
|
|||
}
|
||||
|
||||
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
|
||||
logger := m.log.With().Str("sessionID", id.String()).Logger()
|
||||
return &Session{
|
||||
ID: id,
|
||||
transport: m.transport,
|
||||
dstConn: dstConn,
|
||||
ID: id,
|
||||
sendFunc: m.sendFunc,
|
||||
dstConn: dstConn,
|
||||
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
||||
// drop instead of blocking because last active time only needs to be an approximation
|
||||
activeAtChan: make(chan time.Time, 2),
|
||||
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
|
||||
closeChan: make(chan error, 2),
|
||||
log: m.log,
|
||||
log: &logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -191,16 +163,13 @@ func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
|
|||
}
|
||||
}
|
||||
|
||||
func (m *manager) sendToSession(datagram *newDatagram) {
|
||||
session, ok := m.sessions[datagram.sessionID]
|
||||
func (m *manager) sendToSession(datagram *quicpogs.SessionDatagram) {
|
||||
session, ok := m.sessions[datagram.ID]
|
||||
if !ok {
|
||||
m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found")
|
||||
m.log.Error().Str("sessionID", datagram.ID.String()).Msg("session not found")
|
||||
return
|
||||
}
|
||||
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
|
||||
// need to run in another go routine
|
||||
_, err := session.transportToDst(datagram.payload)
|
||||
if err != nil {
|
||||
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
|
||||
}
|
||||
session.transportToDst(datagram.Payload)
|
||||
}
|
||||
|
|
|
@ -14,22 +14,30 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||
)
|
||||
|
||||
var (
|
||||
nopLogger = zerolog.Nop()
|
||||
)
|
||||
|
||||
func TestManagerServe(t *testing.T) {
|
||||
const (
|
||||
sessions = 20
|
||||
msgs = 50
|
||||
sessions = 2
|
||||
msgs = 5
|
||||
remoteUnregisterMsg = "eyeball closed connection"
|
||||
)
|
||||
|
||||
mg, transport := newTestManager(1)
|
||||
|
||||
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
|
||||
for i := 0; i < sessions; i++ {
|
||||
sessionID := uuid.New()
|
||||
eyeballTracker[sessionID] = newDatagramChannel(1)
|
||||
requestChan := make(chan *quicpogs.SessionDatagram)
|
||||
transport := mockQUICTransport{
|
||||
sessions: make(map[uuid.UUID]chan []byte),
|
||||
}
|
||||
for i := 0; i < sessions; i++ {
|
||||
transport.sessions[uuid.New()] = make(chan []byte)
|
||||
}
|
||||
|
||||
mg := NewManager(&nopLogger, transport.MuxSession, requestChan)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
serveDone := make(chan struct{})
|
||||
|
@ -38,53 +46,42 @@ func TestManagerServe(t *testing.T) {
|
|||
close(serveDone)
|
||||
}(ctx)
|
||||
|
||||
go func(ctx context.Context) {
|
||||
for {
|
||||
sessionID, payload, err := transport.respChan.Receive(ctx)
|
||||
if err != nil {
|
||||
require.Equal(t, context.Canceled, err)
|
||||
return
|
||||
}
|
||||
respChan := eyeballTracker[sessionID]
|
||||
require.NoError(t, respChan.Send(ctx, sessionID, payload))
|
||||
}
|
||||
}(ctx)
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
for sID, receiver := range eyeballTracker {
|
||||
for sessionID, eyeballRespChan := range transport.sessions {
|
||||
// Assign loop variables to local variables
|
||||
sID := sessionID
|
||||
payload := testPayload(sID)
|
||||
expectResp := testResponse(payload)
|
||||
|
||||
cfdConn, originConn := net.Pipe()
|
||||
|
||||
origin := mockOrigin{
|
||||
expectMsgCount: msgs,
|
||||
expectedMsg: payload,
|
||||
expectedResp: expectResp,
|
||||
conn: originConn,
|
||||
}
|
||||
|
||||
eyeball := mockEyeballSession{
|
||||
id: sID,
|
||||
expectedMsgCount: msgs,
|
||||
expectedMsg: payload,
|
||||
expectedResponse: expectResp,
|
||||
respReceiver: eyeballRespChan,
|
||||
}
|
||||
|
||||
// Assign loop variables to local variables
|
||||
sessionID := sID
|
||||
eyeballRespReceiver := receiver
|
||||
errGroup.Go(func() error {
|
||||
payload := testPayload(sessionID)
|
||||
expectResp := testResponse(payload)
|
||||
|
||||
cfdConn, originConn := net.Pipe()
|
||||
|
||||
origin := mockOrigin{
|
||||
expectMsgCount: msgs,
|
||||
expectedMsg: payload,
|
||||
expectedResp: expectResp,
|
||||
conn: originConn,
|
||||
}
|
||||
eyeball := mockEyeball{
|
||||
expectMsgCount: msgs,
|
||||
expectedMsg: expectResp,
|
||||
expectSessionID: sessionID,
|
||||
respReceiver: eyeballRespReceiver,
|
||||
}
|
||||
|
||||
session, err := mg.RegisterSession(ctx, sID, cfdConn)
|
||||
require.NoError(t, err)
|
||||
reqErrGroup, reqCtx := errgroup.WithContext(ctx)
|
||||
reqErrGroup.Go(func() error {
|
||||
return origin.serve()
|
||||
})
|
||||
reqErrGroup.Go(func() error {
|
||||
return eyeball.serve(reqCtx)
|
||||
return eyeball.serve(reqCtx, requestChan)
|
||||
})
|
||||
|
||||
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionDone := make(chan struct{})
|
||||
go func() {
|
||||
closedByRemote, err := session.Serve(ctx, time.Minute*2)
|
||||
|
@ -97,23 +94,17 @@ func TestManagerServe(t *testing.T) {
|
|||
close(sessionDone)
|
||||
}()
|
||||
|
||||
for i := 0; i < msgs; i++ {
|
||||
require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID)))
|
||||
}
|
||||
|
||||
// Make sure eyeball and origin have received all messages before unregistering the session
|
||||
require.NoError(t, reqErrGroup.Wait())
|
||||
|
||||
require.NoError(t, mg.UnregisterSession(ctx, sessionID, remoteUnregisterMsg, true))
|
||||
require.NoError(t, mg.UnregisterSession(ctx, sID, remoteUnregisterMsg, true))
|
||||
<-sessionDone
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
require.NoError(t, errGroup.Wait())
|
||||
cancel()
|
||||
transport.close()
|
||||
<-serveDone
|
||||
}
|
||||
|
||||
|
@ -122,7 +113,7 @@ func TestTimeout(t *testing.T) {
|
|||
testTimeout = time.Millisecond * 50
|
||||
)
|
||||
|
||||
mg, _ := newTestManager(1)
|
||||
mg := NewManager(&nopLogger, nil, nil)
|
||||
mg.timeout = testTimeout
|
||||
ctx := context.Background()
|
||||
sessionID := uuid.New()
|
||||
|
@ -135,9 +126,51 @@ func TestTimeout(t *testing.T) {
|
|||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestCloseTransportCloseSessions(t *testing.T) {
|
||||
mg, transport := newTestManager(1)
|
||||
ctx := context.Background()
|
||||
func TestUnregisterSessionCloseSession(t *testing.T) {
|
||||
sessionID := uuid.New()
|
||||
payload := []byte(t.Name())
|
||||
sender := newMockTransportSender(sessionID, payload)
|
||||
mg := NewManager(&nopLogger, sender.muxSession, nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
managerDone := make(chan struct{})
|
||||
go func() {
|
||||
err := mg.Serve(ctx)
|
||||
require.Error(t, err)
|
||||
close(managerDone)
|
||||
}()
|
||||
|
||||
cfdConn, originConn := net.Pipe()
|
||||
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, session)
|
||||
|
||||
unregisteredChan := make(chan struct{})
|
||||
go func() {
|
||||
_, err := originConn.Write(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mg.UnregisterSession(ctx, sessionID, "eyeball closed session", true)
|
||||
require.NoError(t, err)
|
||||
|
||||
close(unregisteredChan)
|
||||
}()
|
||||
|
||||
closedByRemote, err := session.Serve(ctx, time.Minute)
|
||||
require.True(t, closedByRemote)
|
||||
require.Error(t, err)
|
||||
|
||||
<-unregisteredChan
|
||||
cancel()
|
||||
<-managerDone
|
||||
}
|
||||
|
||||
func TestManagerCtxDoneCloseSessions(t *testing.T) {
|
||||
sessionID := uuid.New()
|
||||
payload := []byte(t.Name())
|
||||
sender := newMockTransportSender(sessionID, payload)
|
||||
mg := NewManager(&nopLogger, sender.muxSession, nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
@ -147,35 +180,26 @@ func TestCloseTransportCloseSessions(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}()
|
||||
|
||||
cfdConn, eyeballConn := net.Pipe()
|
||||
session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn)
|
||||
cfdConn, originConn := net.Pipe()
|
||||
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, session)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := eyeballConn.Write([]byte(t.Name()))
|
||||
_, err := originConn.Write(payload)
|
||||
require.NoError(t, err)
|
||||
transport.close()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
closedByRemote, err := session.Serve(ctx, time.Minute)
|
||||
require.True(t, closedByRemote)
|
||||
require.False(t, closedByRemote)
|
||||
require.Error(t, err)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func newTestManager(capacity uint) (*manager, *mockQUICTransport) {
|
||||
log := zerolog.Nop()
|
||||
transport := &mockQUICTransport{
|
||||
reqChan: newDatagramChannel(capacity),
|
||||
respChan: newDatagramChannel(capacity),
|
||||
}
|
||||
return NewManager(transport, &log), transport
|
||||
}
|
||||
|
||||
type mockOrigin struct {
|
||||
expectMsgCount int
|
||||
expectedMsg []byte
|
||||
|
@ -197,7 +221,6 @@ func (mo *mockOrigin) serve() error {
|
|||
if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
|
||||
return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
|
||||
}
|
||||
|
||||
_, err = mo.conn.Write(mo.expectedResp)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -214,72 +237,35 @@ func testResponse(msg []byte) []byte {
|
|||
return []byte(fmt.Sprintf("Response to %v", msg))
|
||||
}
|
||||
|
||||
type mockEyeball struct {
|
||||
expectMsgCount int
|
||||
expectedMsg []byte
|
||||
expectSessionID uuid.UUID
|
||||
respReceiver *datagramChannel
|
||||
type mockQUICTransport struct {
|
||||
sessions map[uuid.UUID]chan []byte
|
||||
}
|
||||
|
||||
func (me *mockEyeball) serve(ctx context.Context) error {
|
||||
for i := 0; i < me.expectMsgCount; i++ {
|
||||
sessionID, msg, err := me.respReceiver.Receive(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sessionID != me.expectSessionID {
|
||||
return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID)
|
||||
}
|
||||
if !bytes.Equal(msg, me.expectedMsg) {
|
||||
return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg)
|
||||
}
|
||||
}
|
||||
func (me *mockQUICTransport) MuxSession(id uuid.UUID, payload []byte) error {
|
||||
session := me.sessions[id]
|
||||
session <- payload
|
||||
return nil
|
||||
}
|
||||
|
||||
// datagramChannel is a channel for Datagram with wrapper to send/receive with context
|
||||
type datagramChannel struct {
|
||||
datagramChan chan *newDatagram
|
||||
closedChan chan struct{}
|
||||
type mockEyeballSession struct {
|
||||
id uuid.UUID
|
||||
expectedMsgCount int
|
||||
expectedMsg []byte
|
||||
expectedResponse []byte
|
||||
respReceiver <-chan []byte
|
||||
}
|
||||
|
||||
func newDatagramChannel(capacity uint) *datagramChannel {
|
||||
return &datagramChannel{
|
||||
datagramChan: make(chan *newDatagram, capacity),
|
||||
closedChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-rc.closedChan:
|
||||
return &errClosedSession{
|
||||
message: fmt.Errorf("datagram channel closed").Error(),
|
||||
byRemote: true,
|
||||
func (me *mockEyeballSession) serve(ctx context.Context, requestChan chan *quicpogs.SessionDatagram) error {
|
||||
for i := 0; i < me.expectedMsgCount; i++ {
|
||||
requestChan <- &quicpogs.SessionDatagram{
|
||||
ID: me.id,
|
||||
Payload: me.expectedMsg,
|
||||
}
|
||||
case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return uuid.Nil, nil, ctx.Err()
|
||||
case <-rc.closedChan:
|
||||
err := &errClosedSession{
|
||||
message: fmt.Errorf("datagram channel closed").Error(),
|
||||
byRemote: true,
|
||||
resp := <-me.respReceiver
|
||||
if !bytes.Equal(resp, me.expectedResponse) {
|
||||
return fmt.Errorf("Expect %v, read %v", me.expectedResponse, resp)
|
||||
}
|
||||
return uuid.Nil, nil, err
|
||||
case msg := <-rc.datagramChan:
|
||||
return msg.sessionID, msg.payload, nil
|
||||
fmt.Println("Resp", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *datagramChannel) Close() {
|
||||
// No need to close msgChan, it will be garbage collect once there is no reference to it
|
||||
close(rc.closedChan)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -18,21 +19,20 @@ func SessionIdleErr(timeout time.Duration) error {
|
|||
return fmt.Errorf("session idle for %v", timeout)
|
||||
}
|
||||
|
||||
type transportSender func(sessionID uuid.UUID, payload []byte) error
|
||||
|
||||
// Session is a bidirectional pipe of datagrams between transport and dstConn
|
||||
// Currently the only implementation of transport is quic DatagramMuxer
|
||||
// Destination can be a connection with origin or with eyeball
|
||||
// When the destination is origin:
|
||||
// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the
|
||||
// write method of the Session to send to origin
|
||||
// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball
|
||||
// - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to origin
|
||||
// - Datagrams from origin are read from conn and Send to transport using the transportSender callback. Transport will return them to eyeball
|
||||
// When the destination is eyeball:
|
||||
// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared
|
||||
// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the
|
||||
// write method of the Session to send to eyeball
|
||||
// - Datagrams from eyeball are read from conn and Send to transport. Transport will send them to cloudflared using the transportSender callback.
|
||||
// - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to the eyeball
|
||||
type Session struct {
|
||||
ID uuid.UUID
|
||||
transport transport
|
||||
dstConn io.ReadWriteCloser
|
||||
ID uuid.UUID
|
||||
sendFunc transportSender
|
||||
dstConn io.ReadWriteCloser
|
||||
// activeAtChan is used to communicate the last read/write time
|
||||
activeAtChan chan time.Time
|
||||
closeChan chan error
|
||||
|
@ -46,9 +46,16 @@ func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (clos
|
|||
const maxPacketSize = 1500
|
||||
readBuffer := make([]byte, maxPacketSize)
|
||||
for {
|
||||
if err := s.dstToTransport(readBuffer); err != nil {
|
||||
s.closeChan <- err
|
||||
return
|
||||
if closeSession, err := s.dstToTransport(readBuffer); err != nil {
|
||||
if err != net.ErrClosed {
|
||||
s.log.Error().Err(err).Msg("Failed to send session payload from destination to transport")
|
||||
} else {
|
||||
s.log.Debug().Msg("Session cannot read from destination because the connection is closed")
|
||||
}
|
||||
if closeSession {
|
||||
s.closeChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -89,32 +96,25 @@ func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Session) dstToTransport(buffer []byte) error {
|
||||
func (s *Session) dstToTransport(buffer []byte) (closeSession bool, err error) {
|
||||
n, err := s.dstConn.Read(buffer)
|
||||
s.markActive()
|
||||
// https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes
|
||||
if n > 0 {
|
||||
if n <= int(s.transport.MTU()) {
|
||||
err = s.transport.SendTo(s.ID, buffer[:n])
|
||||
} else {
|
||||
// drop packet for now, eventually reply with ICMP for PMTUD
|
||||
s.log.Debug().
|
||||
Str("session", s.ID.String()).
|
||||
Int("len", n).
|
||||
Int("mtu", s.transport.MTU()).
|
||||
Msg("dropped packet exceeding MTU")
|
||||
if n > 0 || err == nil {
|
||||
if sendErr := s.sendFunc(s.ID, buffer[:n]); sendErr != nil {
|
||||
return false, sendErr
|
||||
}
|
||||
}
|
||||
// Some UDP application might send 0-size payload.
|
||||
if err == nil && n == 0 {
|
||||
err = s.transport.SendTo(s.ID, []byte{})
|
||||
}
|
||||
return err
|
||||
return err != nil, err
|
||||
}
|
||||
|
||||
func (s *Session) transportToDst(payload []byte) (int, error) {
|
||||
s.markActive()
|
||||
return s.dstConn.Write(payload)
|
||||
n, err := s.dstConn.Write(payload)
|
||||
if err != nil {
|
||||
s.log.Err(err).Msg("Failed to write payload to session")
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
|
||||
|
|
|
@ -11,8 +11,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||
)
|
||||
|
||||
// TestCloseSession makes sure a session will stop after context is done
|
||||
|
@ -41,7 +44,8 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
|
|||
cfdConn, originConn := net.Pipe()
|
||||
payload := testPayload(sessionID)
|
||||
|
||||
mg, _ := newTestManager(1)
|
||||
log := zerolog.Nop()
|
||||
mg := NewManager(&log, nil, nil)
|
||||
session := mg.newSession(sessionID, cfdConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -114,7 +118,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
|
|||
cfdConn, originConn := net.Pipe()
|
||||
payload := testPayload(sessionID)
|
||||
|
||||
mg, _ := newTestManager(100)
|
||||
respChan := make(chan *quicpogs.SessionDatagram)
|
||||
sender := newMockTransportSender(sessionID, payload)
|
||||
mg := NewManager(&nopLogger, sender.muxSession, respChan)
|
||||
session := mg.newSession(sessionID, cfdConn)
|
||||
|
||||
startTime := time.Now()
|
||||
|
@ -177,7 +183,7 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
|
|||
|
||||
func TestMarkActiveNotBlocking(t *testing.T) {
|
||||
const concurrentCalls = 50
|
||||
mg, _ := newTestManager(1)
|
||||
mg := NewManager(&nopLogger, nil, nil)
|
||||
session := mg.newSession(uuid.New(), nil)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentCalls)
|
||||
|
@ -190,11 +196,16 @@ func TestMarkActiveNotBlocking(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
// Some UDP application might send 0-size payload.
|
||||
func TestZeroBytePayload(t *testing.T) {
|
||||
sessionID := uuid.New()
|
||||
cfdConn, originConn := net.Pipe()
|
||||
|
||||
mg, transport := newTestManager(1)
|
||||
sender := sendOnceTransportSender{
|
||||
baseSender: newMockTransportSender(sessionID, make([]byte, 0)),
|
||||
sentChan: make(chan struct{}),
|
||||
}
|
||||
mg := NewManager(&nopLogger, sender.muxSession, nil)
|
||||
session := mg.newSession(sessionID, cfdConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -215,11 +226,39 @@ func TestZeroBytePayload(t *testing.T) {
|
|||
return nil
|
||||
})
|
||||
|
||||
receivedSessionID, payload, err := transport.respChan.Receive(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, payload, 0)
|
||||
require.Equal(t, sessionID, receivedSessionID)
|
||||
|
||||
<-sender.sentChan
|
||||
cancel()
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
type mockTransportSender struct {
|
||||
expectedSessionID uuid.UUID
|
||||
expectedPayload []byte
|
||||
}
|
||||
|
||||
func newMockTransportSender(expectedSessionID uuid.UUID, expectedPayload []byte) *mockTransportSender {
|
||||
return &mockTransportSender{
|
||||
expectedSessionID: expectedSessionID,
|
||||
expectedPayload: expectedPayload,
|
||||
}
|
||||
}
|
||||
|
||||
func (mts *mockTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error {
|
||||
if sessionID != mts.expectedSessionID {
|
||||
return fmt.Errorf("Expect session %s, got %s", mts.expectedSessionID, sessionID)
|
||||
}
|
||||
if !bytes.Equal(payload, mts.expectedPayload) {
|
||||
return fmt.Errorf("Expect %v, read %v", mts.expectedPayload, payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type sendOnceTransportSender struct {
|
||||
baseSender *mockTransportSender
|
||||
sentChan chan struct{}
|
||||
}
|
||||
|
||||
func (sots *sendOnceTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error {
|
||||
defer close(sots.sentChan)
|
||||
return sots.baseSender.muxSession(sessionID, payload)
|
||||
}
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
package datagramsession
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions
|
||||
type transport interface {
|
||||
// SendTo writes payload for a session to the transport
|
||||
SendTo(sessionID uuid.UUID, payload []byte) error
|
||||
// ReceiveFrom reads the next datagram from the transport
|
||||
ReceiveFrom() (uuid.UUID, []byte, error)
|
||||
// Max transmission unit to receive from the transport
|
||||
MTU() int
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
package datagramsession
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type mockQUICTransport struct {
|
||||
reqChan *datagramChannel
|
||||
respChan *datagramChannel
|
||||
}
|
||||
|
||||
func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
|
||||
buf := make([]byte, len(payload))
|
||||
// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
||||
copy(buf, payload)
|
||||
return mt.respChan.Send(context.Background(), sessionID, buf)
|
||||
}
|
||||
|
||||
func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
|
||||
return mt.reqChan.Receive(context.Background())
|
||||
}
|
||||
|
||||
func (mt *mockQUICTransport) MTU() int {
|
||||
return 1280
|
||||
}
|
||||
|
||||
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
|
||||
return mt.reqChan.Send(ctx, sessionID, payload)
|
||||
}
|
||||
|
||||
func (mt *mockQUICTransport) close() {
|
||||
mt.reqChan.Close()
|
||||
mt.respChan.Close()
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -13,53 +14,88 @@ const (
|
|||
sessionIDLen = len(uuid.UUID{})
|
||||
)
|
||||
|
||||
type SessionDatagram struct {
|
||||
ID uuid.UUID
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
type BaseDatagramMuxer interface {
|
||||
// MuxSession suffix the session ID to the payload so the other end of the QUIC connection can demultiplex the
|
||||
// payload from multiple datagram sessions
|
||||
MuxSession(sessionID uuid.UUID, payload []byte) error
|
||||
// ServeReceive starts a loop to receive datagrams from the QUIC connection
|
||||
ServeReceive(ctx context.Context) error
|
||||
}
|
||||
|
||||
type DatagramMuxer struct {
|
||||
session quic.Connection
|
||||
logger *zerolog.Logger
|
||||
session quic.Connection
|
||||
logger *zerolog.Logger
|
||||
demuxChan chan<- *SessionDatagram
|
||||
}
|
||||
|
||||
func NewDatagramMuxer(quicSession quic.Connection, logger *zerolog.Logger) (*DatagramMuxer, error) {
|
||||
func NewDatagramMuxer(quicSession quic.Connection, log *zerolog.Logger, demuxChan chan<- *SessionDatagram) *DatagramMuxer {
|
||||
logger := log.With().Uint8("datagramVersion", 1).Logger()
|
||||
return &DatagramMuxer{
|
||||
session: quicSession,
|
||||
logger: logger,
|
||||
}, nil
|
||||
session: quicSession,
|
||||
logger: &logger,
|
||||
demuxChan: demuxChan,
|
||||
}
|
||||
}
|
||||
|
||||
// SendTo suffix the session ID to the payload so the other end of the QUIC session can demultiplex
|
||||
// the payload from multiple datagram sessions
|
||||
func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error {
|
||||
if len(payload) > maxDatagramPayloadSize {
|
||||
// Maximum application payload to send to / receive from QUIC datagram frame
|
||||
func (dm *DatagramMuxer) mtu() int {
|
||||
return maxDatagramPayloadSize
|
||||
}
|
||||
|
||||
func (dm *DatagramMuxer) MuxSession(sessionID uuid.UUID, payload []byte) error {
|
||||
if len(payload) > dm.mtu() {
|
||||
// TODO: TUN-5302 return ICMP packet too big message
|
||||
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.MTU())
|
||||
// drop packet for now, eventually reply with ICMP for PMTUD
|
||||
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu())
|
||||
}
|
||||
msgWithID, err := suffixSessionID(sessionID, payload)
|
||||
payloadWithMetadata, err := suffixSessionID(sessionID, payload)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
|
||||
}
|
||||
if err := dm.session.SendMessage(msgWithID); err != nil {
|
||||
if err := dm.session.SendMessage(payloadWithMetadata); err != nil {
|
||||
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReceiveFrom extracts datagram session ID, then sends the session ID and payload to session manager
|
||||
// which determines how to proxy to the origin. It assumes the datagram session has already been
|
||||
// registered with session manager through other side channel
|
||||
func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) {
|
||||
msg, err := dm.session.ReceiveMessage()
|
||||
if err != nil {
|
||||
return uuid.Nil, nil, err
|
||||
func (dm *DatagramMuxer) ServeReceive(ctx context.Context) error {
|
||||
for {
|
||||
// Extracts datagram session ID, then sends the session ID and payload to receiver
|
||||
// which determines how to proxy to the origin. It assumes the datagram session has already been
|
||||
// registered with receiver through other side channel
|
||||
msg, err := dm.session.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := dm.demux(ctx, msg); err != nil {
|
||||
dm.logger.Error().Err(err).Msg("Failed to demux datagram")
|
||||
if err == context.Canceled {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
sessionID, payload, err := extractSessionID(msg)
|
||||
if err != nil {
|
||||
return uuid.Nil, nil, err
|
||||
}
|
||||
return sessionID, payload, nil
|
||||
}
|
||||
|
||||
// Maximum application payload to send to / receive from QUIC datagram frame
|
||||
func (dm *DatagramMuxer) MTU() int {
|
||||
return maxDatagramPayloadSize
|
||||
func (dm *DatagramMuxer) demux(ctx context.Context, msg []byte) error {
|
||||
sessionID, payload, err := extractSessionID(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionDatagram := SessionDatagram{
|
||||
ID: sessionID,
|
||||
Payload: payload,
|
||||
}
|
||||
select {
|
||||
case dm.demuxChan <- &sessionDatagram:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Each QUIC datagram should be suffixed with session ID.
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -52,9 +53,29 @@ func TestSuffixSessionIDError(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMaxDatagramPayload(t *testing.T) {
|
||||
payload := make([]byte, maxDatagramPayloadSize)
|
||||
func TestDatagram(t *testing.T) {
|
||||
maxPayload := make([]byte, maxDatagramPayloadSize)
|
||||
noPayloadSession := uuid.New()
|
||||
maxPayloadSession := uuid.New()
|
||||
sessionToPayload := []*SessionDatagram{
|
||||
{
|
||||
ID: noPayloadSession,
|
||||
Payload: make([]byte, 0),
|
||||
},
|
||||
{
|
||||
ID: maxPayloadSession,
|
||||
Payload: maxPayload,
|
||||
},
|
||||
}
|
||||
flowPayloads := [][]byte{
|
||||
maxPayload,
|
||||
}
|
||||
|
||||
testDatagram(t, 1, sessionToPayload, nil)
|
||||
testDatagram(t, 2, sessionToPayload, flowPayloads)
|
||||
}
|
||||
|
||||
func testDatagram(t *testing.T, version uint8, sessionToPayloads []*SessionDatagram, packetPayloads [][]byte) {
|
||||
quicConfig := &quic.Config{
|
||||
KeepAlivePeriod: 5 * time.Millisecond,
|
||||
EnableDatagrams: true,
|
||||
|
@ -63,6 +84,8 @@ func TestMaxDatagramPayload(t *testing.T) {
|
|||
quicListener := newQUICListener(t, quicConfig)
|
||||
defer quicListener.Close()
|
||||
|
||||
logger := zerolog.Nop()
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(context.Background())
|
||||
// Run edge side of datagram muxer
|
||||
errGroup.Go(func() error {
|
||||
|
@ -72,22 +95,32 @@ func TestMaxDatagramPayload(t *testing.T) {
|
|||
return err
|
||||
}
|
||||
|
||||
logger := zerolog.Nop()
|
||||
muxer, err := NewDatagramMuxer(quicSession, &logger)
|
||||
if err != nil {
|
||||
return err
|
||||
sessionDemuxChan := make(chan *SessionDatagram, 16)
|
||||
|
||||
switch version {
|
||||
case 1:
|
||||
muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan)
|
||||
muxer.ServeReceive(ctx)
|
||||
case 2:
|
||||
packetDemuxChan := make(chan []byte, len(packetPayloads))
|
||||
muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan, packetDemuxChan)
|
||||
muxer.ServeReceive(ctx)
|
||||
|
||||
for _, expectedPayload := range packetPayloads {
|
||||
require.Equal(t, expectedPayload, <-packetDemuxChan)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown datagram version %d", version)
|
||||
}
|
||||
|
||||
sessionID, receivedPayload, err := muxer.ReceiveFrom()
|
||||
if err != nil {
|
||||
return err
|
||||
for _, expectedPayload := range sessionToPayloads {
|
||||
actualPayload := <-sessionDemuxChan
|
||||
require.Equal(t, expectedPayload, actualPayload)
|
||||
}
|
||||
require.Equal(t, testSessionID, sessionID)
|
||||
require.True(t, bytes.Equal(payload, receivedPayload))
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
largePayload := make([]byte, MaxDatagramFrameSize)
|
||||
// Run cloudflared side of datagram muxer
|
||||
errGroup.Go(func() error {
|
||||
tlsClientConfig := &tls.Config{
|
||||
|
@ -97,24 +130,35 @@ func TestMaxDatagramPayload(t *testing.T) {
|
|||
// Establish quic connection
|
||||
quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := zerolog.Nop()
|
||||
muxer, err := NewDatagramMuxer(quicSession, &logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer quicSession.CloseWithError(0, "")
|
||||
|
||||
// Wait a few milliseconds for MTU discovery to take place
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
err = muxer.SendTo(testSessionID, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
var muxer BaseDatagramMuxer
|
||||
switch version {
|
||||
case 1:
|
||||
muxer = NewDatagramMuxer(quicSession, &logger, nil)
|
||||
case 2:
|
||||
muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil, nil)
|
||||
for _, payload := range packetPayloads {
|
||||
require.NoError(t, muxerV2.MuxPacket(payload))
|
||||
}
|
||||
// Payload larger than transport MTU, should not be sent
|
||||
require.Error(t, muxerV2.MuxPacket(largePayload))
|
||||
muxer = muxerV2
|
||||
default:
|
||||
return fmt.Errorf("unknown datagram version %d", version)
|
||||
}
|
||||
|
||||
// Payload larger than transport MTU, should return an error
|
||||
largePayload := make([]byte, MaxDatagramFrameSize)
|
||||
err = muxer.SendTo(testSessionID, largePayload)
|
||||
require.Error(t, err)
|
||||
for _, sessionDatagram := range sessionToPayloads {
|
||||
require.NoError(t, muxer.MuxSession(sessionDatagram.ID, sessionDatagram.Payload))
|
||||
}
|
||||
// Payload larger than transport MTU, should not be sent
|
||||
require.Error(t, muxer.MuxSession(testSessionID, largePayload))
|
||||
|
||||
// Wait for edge to finish receiving the messages
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
@ -154,3 +198,35 @@ func generateTLSConfig() *tls.Config {
|
|||
NextProtos: []string{"argotunnel"},
|
||||
}
|
||||
}
|
||||
|
||||
type sessionMuxer interface {
|
||||
SendToSession(sessionID uuid.UUID, payload []byte) error
|
||||
}
|
||||
|
||||
type mockSessionReceiver struct {
|
||||
expectedSessionToPayload map[uuid.UUID][]byte
|
||||
receivedCount int
|
||||
}
|
||||
|
||||
func (msr *mockSessionReceiver) ReceiveDatagram(sessionID uuid.UUID, payload []byte) error {
|
||||
expectedPayload := msr.expectedSessionToPayload[sessionID]
|
||||
if !bytes.Equal(expectedPayload, payload) {
|
||||
return fmt.Errorf("expect %v to have payload %s, got %s", sessionID, string(expectedPayload), string(payload))
|
||||
}
|
||||
msr.receivedCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockFlowReceiver struct {
|
||||
expectedPayloads [][]byte
|
||||
receivedCount int
|
||||
}
|
||||
|
||||
func (mfr *mockFlowReceiver) ReceiveFlow(payload []byte) error {
|
||||
expectedPayload := mfr.expectedPayloads[mfr.receivedCount]
|
||||
if !bytes.Equal(expectedPayload, payload) {
|
||||
return fmt.Errorf("expect flow %d to have payload %s, got %s", mfr.receivedCount, string(expectedPayload), string(payload))
|
||||
}
|
||||
mfr.receivedCount++
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type datagramV2Type byte
|
||||
|
||||
const (
|
||||
udp datagramV2Type = iota
|
||||
ip
|
||||
)
|
||||
|
||||
func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) {
|
||||
if len(b)+1 > MaxDatagramFrameSize {
|
||||
return nil, fmt.Errorf("datagram size %d exceeds max frame size %d", len(b), MaxDatagramFrameSize)
|
||||
}
|
||||
b = append(b, byte(datagramType))
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Maximum application payload to send to / receive from QUIC datagram frame
|
||||
func (dm *DatagramMuxerV2) mtu() int {
|
||||
return maxDatagramPayloadSize
|
||||
}
|
||||
|
||||
type DatagramMuxerV2 struct {
|
||||
session quic.Connection
|
||||
logger *zerolog.Logger
|
||||
sessionDemuxChan chan<- *SessionDatagram
|
||||
packetDemuxChan chan<- []byte
|
||||
}
|
||||
|
||||
func NewDatagramMuxerV2(
|
||||
quicSession quic.Connection,
|
||||
log *zerolog.Logger,
|
||||
sessionDemuxChan chan<- *SessionDatagram,
|
||||
packetDemuxChan chan<- []byte) *DatagramMuxerV2 {
|
||||
logger := log.With().Uint8("datagramVersion", 2).Logger()
|
||||
return &DatagramMuxerV2{
|
||||
session: quicSession,
|
||||
logger: &logger,
|
||||
sessionDemuxChan: sessionDemuxChan,
|
||||
packetDemuxChan: packetDemuxChan,
|
||||
}
|
||||
}
|
||||
|
||||
// MuxSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can
|
||||
// demultiplex the payload from multiple datagram sessions
|
||||
func (dm *DatagramMuxerV2) MuxSession(sessionID uuid.UUID, payload []byte) error {
|
||||
if len(payload) > dm.mtu() {
|
||||
// TODO: TUN-5302 return ICMP packet too big message
|
||||
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu())
|
||||
}
|
||||
msgWithID, err := suffixSessionID(sessionID, payload)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
|
||||
}
|
||||
msgWithIDAndType, err := suffixType(msgWithID, udp)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
|
||||
}
|
||||
if err := dm.session.SendMessage(msgWithIDAndType); err != nil {
|
||||
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MuxPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing
|
||||
// the payload as IP and look at the source and destination.
|
||||
func (dm *DatagramMuxerV2) MuxPacket(packet []byte) error {
|
||||
payloadWithVersion, err := suffixType(packet, ip)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
|
||||
}
|
||||
if err := dm.session.SendMessage(payloadWithVersion); err != nil {
|
||||
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
|
||||
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
|
||||
for {
|
||||
msg, err := dm.session.ReceiveMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := dm.demux(ctx, msg); err != nil {
|
||||
dm.logger.Error().Err(err).Msg("Failed to demux datagram")
|
||||
if err == context.Canceled {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error {
|
||||
if len(msgWithType) < 1 {
|
||||
return fmt.Errorf("QUIC datagram should have at least 1 byte")
|
||||
}
|
||||
msgType := datagramV2Type(msgWithType[len(msgWithType)-1])
|
||||
msg := msgWithType[0 : len(msgWithType)-1]
|
||||
switch msgType {
|
||||
case udp:
|
||||
sessionID, payload, err := extractSessionID(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionDatagram := SessionDatagram{
|
||||
ID: sessionID,
|
||||
Payload: payload,
|
||||
}
|
||||
select {
|
||||
case dm.sessionDemuxChan <- &sessionDatagram:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
case ip:
|
||||
select {
|
||||
case dm.packetDemuxChan <- msg:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("Unexpected datagram type %d", msgType)
|
||||
}
|
||||
}
|
|
@ -33,6 +33,7 @@ const (
|
|||
FeatureSerializedHeaders = "serialized_headers"
|
||||
FeatureQuickReconnects = "quick_reconnects"
|
||||
FeatureAllowRemoteConfig = "allow_remote_config"
|
||||
FeatureDatagramV2 = "support_datagram_v2"
|
||||
)
|
||||
|
||||
type TunnelConfig struct {
|
||||
|
|
Loading…
Reference in New Issue