TUN-6584: Define QUIC datagram v2 format to support proxying IP packets
This commit is contained in:
parent
d3fd581b7b
commit
278df5478a
|
@ -33,6 +33,8 @@ const (
|
||||||
HTTPHostKey = "HttpHost"
|
HTTPHostKey = "HttpHost"
|
||||||
|
|
||||||
QUICMetadataFlowID = "FlowID"
|
QUICMetadataFlowID = "FlowID"
|
||||||
|
// emperically this capacity has been working well
|
||||||
|
demuxChanCapacity = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
||||||
|
@ -40,7 +42,10 @@ type QUICConnection struct {
|
||||||
session quic.Connection
|
session quic.Connection
|
||||||
logger *zerolog.Logger
|
logger *zerolog.Logger
|
||||||
orchestrator Orchestrator
|
orchestrator Orchestrator
|
||||||
|
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
|
||||||
sessionManager datagramsession.Manager
|
sessionManager datagramsession.Manager
|
||||||
|
// datagramMuxer mux/demux datagrams from quic connection
|
||||||
|
datagramMuxer *quicpogs.DatagramMuxer
|
||||||
controlStreamHandler ControlStreamHandler
|
controlStreamHandler ControlStreamHandler
|
||||||
connOptions *tunnelpogs.ConnectionOptions
|
connOptions *tunnelpogs.ConnectionOptions
|
||||||
}
|
}
|
||||||
|
@ -60,18 +65,16 @@ func NewQUICConnection(
|
||||||
return nil, &EdgeQuicDialError{Cause: err}
|
return nil, &EdgeQuicDialError{Cause: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
datagramMuxer, err := quicpogs.NewDatagramMuxer(session, logger)
|
demuxChan := make(chan *quicpogs.SessionDatagram, demuxChanCapacity)
|
||||||
if err != nil {
|
datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, demuxChan)
|
||||||
return nil, err
|
sessionManager := datagramsession.NewManager(logger, datagramMuxer.MuxSession, demuxChan)
|
||||||
}
|
|
||||||
|
|
||||||
sessionManager := datagramsession.NewManager(datagramMuxer, logger)
|
|
||||||
|
|
||||||
return &QUICConnection{
|
return &QUICConnection{
|
||||||
session: session,
|
session: session,
|
||||||
orchestrator: orchestrator,
|
orchestrator: orchestrator,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
sessionManager: sessionManager,
|
sessionManager: sessionManager,
|
||||||
|
datagramMuxer: datagramMuxer,
|
||||||
controlStreamHandler: controlStreamHandler,
|
controlStreamHandler: controlStreamHandler,
|
||||||
connOptions: connOptions,
|
connOptions: connOptions,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -108,6 +111,11 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
|
||||||
return q.sessionManager.Serve(ctx)
|
return q.sessionManager.Serve(ctx)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
defer cancel()
|
||||||
|
return q.datagramMuxer.ServeReceive(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
return errGroup.Wait()
|
return errGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -42,9 +42,3 @@ func (sc *errClosedSession) Error() string {
|
||||||
return fmt.Sprintf("session closed by local due to %s", sc.message)
|
return fmt.Sprintf("session closed by local due to %s", sc.message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDatagram is an event when transport receives new datagram
|
|
||||||
type newDatagram struct {
|
|
||||||
sessionID uuid.UUID
|
|
||||||
payload []byte
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,9 +7,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/lucas-clemente/quic-go"
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
|
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -36,23 +36,22 @@ type Manager interface {
|
||||||
type manager struct {
|
type manager struct {
|
||||||
registrationChan chan *registerSessionEvent
|
registrationChan chan *registerSessionEvent
|
||||||
unregistrationChan chan *unregisterSessionEvent
|
unregistrationChan chan *unregisterSessionEvent
|
||||||
datagramChan chan *newDatagram
|
sendFunc transportSender
|
||||||
closedChan chan struct{}
|
receiveChan <-chan *quicpogs.SessionDatagram
|
||||||
transport transport
|
closedChan <-chan struct{}
|
||||||
sessions map[uuid.UUID]*Session
|
sessions map[uuid.UUID]*Session
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
// timeout waiting for an API to finish. This can be overriden in test
|
// timeout waiting for an API to finish. This can be overriden in test
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(transport transport, log *zerolog.Logger) *manager {
|
func NewManager(log *zerolog.Logger, sendF transportSender, receiveChan <-chan *quicpogs.SessionDatagram) *manager {
|
||||||
return &manager{
|
return &manager{
|
||||||
registrationChan: make(chan *registerSessionEvent),
|
registrationChan: make(chan *registerSessionEvent),
|
||||||
unregistrationChan: make(chan *unregisterSessionEvent),
|
unregistrationChan: make(chan *unregisterSessionEvent),
|
||||||
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
|
sendFunc: sendF,
|
||||||
datagramChan: make(chan *newDatagram, requestChanCapacity),
|
receiveChan: receiveChan,
|
||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
transport: transport,
|
|
||||||
sessions: make(map[uuid.UUID]*Session),
|
sessions: make(map[uuid.UUID]*Session),
|
||||||
log: log,
|
log: log,
|
||||||
timeout: defaultReqTimeout,
|
timeout: defaultReqTimeout,
|
||||||
|
@ -65,49 +64,21 @@ func (m *manager) UpdateLogger(log *zerolog.Logger) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) Serve(ctx context.Context) error {
|
func (m *manager) Serve(ctx context.Context) error {
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
|
||||||
errGroup.Go(func() error {
|
|
||||||
for {
|
for {
|
||||||
sessionID, payload, err := m.transport.ReceiveFrom()
|
|
||||||
if err != nil {
|
|
||||||
if aerr, ok := err.(*quic.ApplicationError); ok && uint64(aerr.ErrorCode) == uint64(quic.NoError) {
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
datagram := &newDatagram{
|
|
||||||
sessionID: sessionID,
|
|
||||||
payload: payload,
|
|
||||||
}
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
m.shutdownSessions(ctx.Err())
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
// Only the event loop routine can update/lookup the sessions map to avoid concurrent access
|
// receiveChan is buffered, so the transport can read more datagrams from transport while the event loop is
|
||||||
// Send the datagram to the event loop. It will find the session to send to
|
// processing other events
|
||||||
case m.datagramChan <- datagram:
|
case datagram := <-m.receiveChan:
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
errGroup.Go(func() error {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
case datagram := <-m.datagramChan:
|
|
||||||
m.sendToSession(datagram)
|
m.sendToSession(datagram)
|
||||||
case registration := <-m.registrationChan:
|
case registration := <-m.registrationChan:
|
||||||
m.registerSession(ctx, registration)
|
m.registerSession(ctx, registration)
|
||||||
// TODO: TUN-5422: Unregister inactive session upon timeout
|
|
||||||
case unregistration := <-m.unregistrationChan:
|
case unregistration := <-m.unregistrationChan:
|
||||||
m.unregisterSession(unregistration)
|
m.unregisterSession(unregistration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
|
||||||
err := errGroup.Wait()
|
|
||||||
close(m.closedChan)
|
|
||||||
m.shutdownSessions(err)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) shutdownSessions(err error) {
|
func (m *manager) shutdownSessions(err error) {
|
||||||
|
@ -149,16 +120,17 @@ func (m *manager) registerSession(ctx context.Context, registration *registerSes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
|
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
|
||||||
|
logger := m.log.With().Str("sessionID", id.String()).Logger()
|
||||||
return &Session{
|
return &Session{
|
||||||
ID: id,
|
ID: id,
|
||||||
transport: m.transport,
|
sendFunc: m.sendFunc,
|
||||||
dstConn: dstConn,
|
dstConn: dstConn,
|
||||||
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
||||||
// drop instead of blocking because last active time only needs to be an approximation
|
// drop instead of blocking because last active time only needs to be an approximation
|
||||||
activeAtChan: make(chan time.Time, 2),
|
activeAtChan: make(chan time.Time, 2),
|
||||||
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
|
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
|
||||||
closeChan: make(chan error, 2),
|
closeChan: make(chan error, 2),
|
||||||
log: m.log,
|
log: &logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,16 +163,13 @@ func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *manager) sendToSession(datagram *newDatagram) {
|
func (m *manager) sendToSession(datagram *quicpogs.SessionDatagram) {
|
||||||
session, ok := m.sessions[datagram.sessionID]
|
session, ok := m.sessions[datagram.ID]
|
||||||
if !ok {
|
if !ok {
|
||||||
m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found")
|
m.log.Error().Str("sessionID", datagram.ID.String()).Msg("session not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
|
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
|
||||||
// need to run in another go routine
|
// need to run in another go routine
|
||||||
_, err := session.transportToDst(datagram.payload)
|
session.transportToDst(datagram.Payload)
|
||||||
if err != nil {
|
|
||||||
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,22 +14,30 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
nopLogger = zerolog.Nop()
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestManagerServe(t *testing.T) {
|
func TestManagerServe(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
sessions = 20
|
sessions = 2
|
||||||
msgs = 50
|
msgs = 5
|
||||||
remoteUnregisterMsg = "eyeball closed connection"
|
remoteUnregisterMsg = "eyeball closed connection"
|
||||||
)
|
)
|
||||||
|
|
||||||
mg, transport := newTestManager(1)
|
requestChan := make(chan *quicpogs.SessionDatagram)
|
||||||
|
transport := mockQUICTransport{
|
||||||
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
|
sessions: make(map[uuid.UUID]chan []byte),
|
||||||
for i := 0; i < sessions; i++ {
|
|
||||||
sessionID := uuid.New()
|
|
||||||
eyeballTracker[sessionID] = newDatagramChannel(1)
|
|
||||||
}
|
}
|
||||||
|
for i := 0; i < sessions; i++ {
|
||||||
|
transport.sessions[uuid.New()] = make(chan []byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
mg := NewManager(&nopLogger, transport.MuxSession, requestChan)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
serveDone := make(chan struct{})
|
serveDone := make(chan struct{})
|
||||||
|
@ -38,25 +46,11 @@ func TestManagerServe(t *testing.T) {
|
||||||
close(serveDone)
|
close(serveDone)
|
||||||
}(ctx)
|
}(ctx)
|
||||||
|
|
||||||
go func(ctx context.Context) {
|
|
||||||
for {
|
|
||||||
sessionID, payload, err := transport.respChan.Receive(ctx)
|
|
||||||
if err != nil {
|
|
||||||
require.Equal(t, context.Canceled, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
respChan := eyeballTracker[sessionID]
|
|
||||||
require.NoError(t, respChan.Send(ctx, sessionID, payload))
|
|
||||||
}
|
|
||||||
}(ctx)
|
|
||||||
|
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
for sID, receiver := range eyeballTracker {
|
for sessionID, eyeballRespChan := range transport.sessions {
|
||||||
// Assign loop variables to local variables
|
// Assign loop variables to local variables
|
||||||
sessionID := sID
|
sID := sessionID
|
||||||
eyeballRespReceiver := receiver
|
payload := testPayload(sID)
|
||||||
errGroup.Go(func() error {
|
|
||||||
payload := testPayload(sessionID)
|
|
||||||
expectResp := testResponse(payload)
|
expectResp := testResponse(payload)
|
||||||
|
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
|
@ -67,24 +61,27 @@ func TestManagerServe(t *testing.T) {
|
||||||
expectedResp: expectResp,
|
expectedResp: expectResp,
|
||||||
conn: originConn,
|
conn: originConn,
|
||||||
}
|
}
|
||||||
eyeball := mockEyeball{
|
|
||||||
expectMsgCount: msgs,
|
eyeball := mockEyeballSession{
|
||||||
expectedMsg: expectResp,
|
id: sID,
|
||||||
expectSessionID: sessionID,
|
expectedMsgCount: msgs,
|
||||||
respReceiver: eyeballRespReceiver,
|
expectedMsg: payload,
|
||||||
|
expectedResponse: expectResp,
|
||||||
|
respReceiver: eyeballRespChan,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Assign loop variables to local variables
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
session, err := mg.RegisterSession(ctx, sID, cfdConn)
|
||||||
|
require.NoError(t, err)
|
||||||
reqErrGroup, reqCtx := errgroup.WithContext(ctx)
|
reqErrGroup, reqCtx := errgroup.WithContext(ctx)
|
||||||
reqErrGroup.Go(func() error {
|
reqErrGroup.Go(func() error {
|
||||||
return origin.serve()
|
return origin.serve()
|
||||||
})
|
})
|
||||||
reqErrGroup.Go(func() error {
|
reqErrGroup.Go(func() error {
|
||||||
return eyeball.serve(reqCtx)
|
return eyeball.serve(reqCtx, requestChan)
|
||||||
})
|
})
|
||||||
|
|
||||||
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
sessionDone := make(chan struct{})
|
sessionDone := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
closedByRemote, err := session.Serve(ctx, time.Minute*2)
|
closedByRemote, err := session.Serve(ctx, time.Minute*2)
|
||||||
|
@ -97,23 +94,17 @@ func TestManagerServe(t *testing.T) {
|
||||||
close(sessionDone)
|
close(sessionDone)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for i := 0; i < msgs; i++ {
|
|
||||||
require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure eyeball and origin have received all messages before unregistering the session
|
// Make sure eyeball and origin have received all messages before unregistering the session
|
||||||
require.NoError(t, reqErrGroup.Wait())
|
require.NoError(t, reqErrGroup.Wait())
|
||||||
|
|
||||||
require.NoError(t, mg.UnregisterSession(ctx, sessionID, remoteUnregisterMsg, true))
|
require.NoError(t, mg.UnregisterSession(ctx, sID, remoteUnregisterMsg, true))
|
||||||
<-sessionDone
|
<-sessionDone
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, errGroup.Wait())
|
require.NoError(t, errGroup.Wait())
|
||||||
cancel()
|
cancel()
|
||||||
transport.close()
|
|
||||||
<-serveDone
|
<-serveDone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,7 +113,7 @@ func TestTimeout(t *testing.T) {
|
||||||
testTimeout = time.Millisecond * 50
|
testTimeout = time.Millisecond * 50
|
||||||
)
|
)
|
||||||
|
|
||||||
mg, _ := newTestManager(1)
|
mg := NewManager(&nopLogger, nil, nil)
|
||||||
mg.timeout = testTimeout
|
mg.timeout = testTimeout
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
sessionID := uuid.New()
|
sessionID := uuid.New()
|
||||||
|
@ -135,9 +126,51 @@ func TestTimeout(t *testing.T) {
|
||||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCloseTransportCloseSessions(t *testing.T) {
|
func TestUnregisterSessionCloseSession(t *testing.T) {
|
||||||
mg, transport := newTestManager(1)
|
sessionID := uuid.New()
|
||||||
ctx := context.Background()
|
payload := []byte(t.Name())
|
||||||
|
sender := newMockTransportSender(sessionID, payload)
|
||||||
|
mg := NewManager(&nopLogger, sender.muxSession, nil)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
managerDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
err := mg.Serve(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
|
close(managerDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfdConn, originConn := net.Pipe()
|
||||||
|
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, session)
|
||||||
|
|
||||||
|
unregisteredChan := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_, err := originConn.Write(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = mg.UnregisterSession(ctx, sessionID, "eyeball closed session", true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
close(unregisteredChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
closedByRemote, err := session.Serve(ctx, time.Minute)
|
||||||
|
require.True(t, closedByRemote)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
<-unregisteredChan
|
||||||
|
cancel()
|
||||||
|
<-managerDone
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerCtxDoneCloseSessions(t *testing.T) {
|
||||||
|
sessionID := uuid.New()
|
||||||
|
payload := []byte(t.Name())
|
||||||
|
sender := newMockTransportSender(sessionID, payload)
|
||||||
|
mg := NewManager(&nopLogger, sender.muxSession, nil)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -147,35 +180,26 @@ func TestCloseTransportCloseSessions(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cfdConn, eyeballConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn)
|
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, session)
|
require.NotNil(t, session)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
_, err := eyeballConn.Write([]byte(t.Name()))
|
_, err := originConn.Write(payload)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
transport.close()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
closedByRemote, err := session.Serve(ctx, time.Minute)
|
closedByRemote, err := session.Serve(ctx, time.Minute)
|
||||||
require.True(t, closedByRemote)
|
require.False(t, closedByRemote)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestManager(capacity uint) (*manager, *mockQUICTransport) {
|
|
||||||
log := zerolog.Nop()
|
|
||||||
transport := &mockQUICTransport{
|
|
||||||
reqChan: newDatagramChannel(capacity),
|
|
||||||
respChan: newDatagramChannel(capacity),
|
|
||||||
}
|
|
||||||
return NewManager(transport, &log), transport
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockOrigin struct {
|
type mockOrigin struct {
|
||||||
expectMsgCount int
|
expectMsgCount int
|
||||||
expectedMsg []byte
|
expectedMsg []byte
|
||||||
|
@ -197,7 +221,6 @@ func (mo *mockOrigin) serve() error {
|
||||||
if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
|
if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
|
||||||
return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
|
return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = mo.conn.Write(mo.expectedResp)
|
_, err = mo.conn.Write(mo.expectedResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -214,72 +237,35 @@ func testResponse(msg []byte) []byte {
|
||||||
return []byte(fmt.Sprintf("Response to %v", msg))
|
return []byte(fmt.Sprintf("Response to %v", msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockEyeball struct {
|
type mockQUICTransport struct {
|
||||||
expectMsgCount int
|
sessions map[uuid.UUID]chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (me *mockQUICTransport) MuxSession(id uuid.UUID, payload []byte) error {
|
||||||
|
session := me.sessions[id]
|
||||||
|
session <- payload
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockEyeballSession struct {
|
||||||
|
id uuid.UUID
|
||||||
|
expectedMsgCount int
|
||||||
expectedMsg []byte
|
expectedMsg []byte
|
||||||
expectSessionID uuid.UUID
|
expectedResponse []byte
|
||||||
respReceiver *datagramChannel
|
respReceiver <-chan []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (me *mockEyeball) serve(ctx context.Context) error {
|
func (me *mockEyeballSession) serve(ctx context.Context, requestChan chan *quicpogs.SessionDatagram) error {
|
||||||
for i := 0; i < me.expectMsgCount; i++ {
|
for i := 0; i < me.expectedMsgCount; i++ {
|
||||||
sessionID, msg, err := me.respReceiver.Receive(ctx)
|
requestChan <- &quicpogs.SessionDatagram{
|
||||||
if err != nil {
|
ID: me.id,
|
||||||
return err
|
Payload: me.expectedMsg,
|
||||||
}
|
}
|
||||||
if sessionID != me.expectSessionID {
|
resp := <-me.respReceiver
|
||||||
return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID)
|
if !bytes.Equal(resp, me.expectedResponse) {
|
||||||
}
|
return fmt.Errorf("Expect %v, read %v", me.expectedResponse, resp)
|
||||||
if !bytes.Equal(msg, me.expectedMsg) {
|
|
||||||
return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg)
|
|
||||||
}
|
}
|
||||||
|
fmt.Println("Resp", resp)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// datagramChannel is a channel for Datagram with wrapper to send/receive with context
|
|
||||||
type datagramChannel struct {
|
|
||||||
datagramChan chan *newDatagram
|
|
||||||
closedChan chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newDatagramChannel(capacity uint) *datagramChannel {
|
|
||||||
return &datagramChannel{
|
|
||||||
datagramChan: make(chan *newDatagram, capacity),
|
|
||||||
closedChan: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-rc.closedChan:
|
|
||||||
return &errClosedSession{
|
|
||||||
message: fmt.Errorf("datagram channel closed").Error(),
|
|
||||||
byRemote: true,
|
|
||||||
}
|
|
||||||
case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return uuid.Nil, nil, ctx.Err()
|
|
||||||
case <-rc.closedChan:
|
|
||||||
err := &errClosedSession{
|
|
||||||
message: fmt.Errorf("datagram channel closed").Error(),
|
|
||||||
byRemote: true,
|
|
||||||
}
|
|
||||||
return uuid.Nil, nil, err
|
|
||||||
case msg := <-rc.datagramChan:
|
|
||||||
return msg.sessionID, msg.payload, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rc *datagramChannel) Close() {
|
|
||||||
// No need to close msgChan, it will be garbage collect once there is no reference to it
|
|
||||||
close(rc.closedChan)
|
|
||||||
}
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
@ -18,20 +19,19 @@ func SessionIdleErr(timeout time.Duration) error {
|
||||||
return fmt.Errorf("session idle for %v", timeout)
|
return fmt.Errorf("session idle for %v", timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type transportSender func(sessionID uuid.UUID, payload []byte) error
|
||||||
|
|
||||||
// Session is a bidirectional pipe of datagrams between transport and dstConn
|
// Session is a bidirectional pipe of datagrams between transport and dstConn
|
||||||
// Currently the only implementation of transport is quic DatagramMuxer
|
|
||||||
// Destination can be a connection with origin or with eyeball
|
// Destination can be a connection with origin or with eyeball
|
||||||
// When the destination is origin:
|
// When the destination is origin:
|
||||||
// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the
|
// - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to origin
|
||||||
// write method of the Session to send to origin
|
// - Datagrams from origin are read from conn and Send to transport using the transportSender callback. Transport will return them to eyeball
|
||||||
// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball
|
|
||||||
// When the destination is eyeball:
|
// When the destination is eyeball:
|
||||||
// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared
|
// - Datagrams from eyeball are read from conn and Send to transport. Transport will send them to cloudflared using the transportSender callback.
|
||||||
// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the
|
// - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to the eyeball
|
||||||
// write method of the Session to send to eyeball
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID uuid.UUID
|
ID uuid.UUID
|
||||||
transport transport
|
sendFunc transportSender
|
||||||
dstConn io.ReadWriteCloser
|
dstConn io.ReadWriteCloser
|
||||||
// activeAtChan is used to communicate the last read/write time
|
// activeAtChan is used to communicate the last read/write time
|
||||||
activeAtChan chan time.Time
|
activeAtChan chan time.Time
|
||||||
|
@ -46,11 +46,18 @@ func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (clos
|
||||||
const maxPacketSize = 1500
|
const maxPacketSize = 1500
|
||||||
readBuffer := make([]byte, maxPacketSize)
|
readBuffer := make([]byte, maxPacketSize)
|
||||||
for {
|
for {
|
||||||
if err := s.dstToTransport(readBuffer); err != nil {
|
if closeSession, err := s.dstToTransport(readBuffer); err != nil {
|
||||||
|
if err != net.ErrClosed {
|
||||||
|
s.log.Error().Err(err).Msg("Failed to send session payload from destination to transport")
|
||||||
|
} else {
|
||||||
|
s.log.Debug().Msg("Session cannot read from destination because the connection is closed")
|
||||||
|
}
|
||||||
|
if closeSession {
|
||||||
s.closeChan <- err
|
s.closeChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
err = s.waitForCloseCondition(ctx, closeAfterIdle)
|
err = s.waitForCloseCondition(ctx, closeAfterIdle)
|
||||||
if closeSession, ok := err.(*errClosedSession); ok {
|
if closeSession, ok := err.(*errClosedSession); ok {
|
||||||
|
@ -89,32 +96,25 @@ func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) dstToTransport(buffer []byte) error {
|
func (s *Session) dstToTransport(buffer []byte) (closeSession bool, err error) {
|
||||||
n, err := s.dstConn.Read(buffer)
|
n, err := s.dstConn.Read(buffer)
|
||||||
s.markActive()
|
s.markActive()
|
||||||
// https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes
|
// https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes
|
||||||
if n > 0 {
|
if n > 0 || err == nil {
|
||||||
if n <= int(s.transport.MTU()) {
|
if sendErr := s.sendFunc(s.ID, buffer[:n]); sendErr != nil {
|
||||||
err = s.transport.SendTo(s.ID, buffer[:n])
|
return false, sendErr
|
||||||
} else {
|
|
||||||
// drop packet for now, eventually reply with ICMP for PMTUD
|
|
||||||
s.log.Debug().
|
|
||||||
Str("session", s.ID.String()).
|
|
||||||
Int("len", n).
|
|
||||||
Int("mtu", s.transport.MTU()).
|
|
||||||
Msg("dropped packet exceeding MTU")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Some UDP application might send 0-size payload.
|
return err != nil, err
|
||||||
if err == nil && n == 0 {
|
|
||||||
err = s.transport.SendTo(s.ID, []byte{})
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) transportToDst(payload []byte) (int, error) {
|
func (s *Session) transportToDst(payload []byte) (int, error) {
|
||||||
s.markActive()
|
s.markActive()
|
||||||
return s.dstConn.Write(payload)
|
n, err := s.dstConn.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
s.log.Err(err).Msg("Failed to write payload to session")
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
|
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
|
||||||
|
|
|
@ -11,8 +11,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestCloseSession makes sure a session will stop after context is done
|
// TestCloseSession makes sure a session will stop after context is done
|
||||||
|
@ -41,7 +44,8 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
payload := testPayload(sessionID)
|
payload := testPayload(sessionID)
|
||||||
|
|
||||||
mg, _ := newTestManager(1)
|
log := zerolog.Nop()
|
||||||
|
mg := NewManager(&log, nil, nil)
|
||||||
session := mg.newSession(sessionID, cfdConn)
|
session := mg.newSession(sessionID, cfdConn)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -114,7 +118,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
payload := testPayload(sessionID)
|
payload := testPayload(sessionID)
|
||||||
|
|
||||||
mg, _ := newTestManager(100)
|
respChan := make(chan *quicpogs.SessionDatagram)
|
||||||
|
sender := newMockTransportSender(sessionID, payload)
|
||||||
|
mg := NewManager(&nopLogger, sender.muxSession, respChan)
|
||||||
session := mg.newSession(sessionID, cfdConn)
|
session := mg.newSession(sessionID, cfdConn)
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
@ -177,7 +183,7 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
|
||||||
|
|
||||||
func TestMarkActiveNotBlocking(t *testing.T) {
|
func TestMarkActiveNotBlocking(t *testing.T) {
|
||||||
const concurrentCalls = 50
|
const concurrentCalls = 50
|
||||||
mg, _ := newTestManager(1)
|
mg := NewManager(&nopLogger, nil, nil)
|
||||||
session := mg.newSession(uuid.New(), nil)
|
session := mg.newSession(uuid.New(), nil)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(concurrentCalls)
|
wg.Add(concurrentCalls)
|
||||||
|
@ -190,11 +196,16 @@ func TestMarkActiveNotBlocking(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Some UDP application might send 0-size payload.
|
||||||
func TestZeroBytePayload(t *testing.T) {
|
func TestZeroBytePayload(t *testing.T) {
|
||||||
sessionID := uuid.New()
|
sessionID := uuid.New()
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
|
|
||||||
mg, transport := newTestManager(1)
|
sender := sendOnceTransportSender{
|
||||||
|
baseSender: newMockTransportSender(sessionID, make([]byte, 0)),
|
||||||
|
sentChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
mg := NewManager(&nopLogger, sender.muxSession, nil)
|
||||||
session := mg.newSession(sessionID, cfdConn)
|
session := mg.newSession(sessionID, cfdConn)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -215,11 +226,39 @@ func TestZeroBytePayload(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
receivedSessionID, payload, err := transport.respChan.Receive(ctx)
|
<-sender.sentChan
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, payload, 0)
|
|
||||||
require.Equal(t, sessionID, receivedSessionID)
|
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
require.NoError(t, errGroup.Wait())
|
require.NoError(t, errGroup.Wait())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockTransportSender struct {
|
||||||
|
expectedSessionID uuid.UUID
|
||||||
|
expectedPayload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockTransportSender(expectedSessionID uuid.UUID, expectedPayload []byte) *mockTransportSender {
|
||||||
|
return &mockTransportSender{
|
||||||
|
expectedSessionID: expectedSessionID,
|
||||||
|
expectedPayload: expectedPayload,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mts *mockTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error {
|
||||||
|
if sessionID != mts.expectedSessionID {
|
||||||
|
return fmt.Errorf("Expect session %s, got %s", mts.expectedSessionID, sessionID)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(payload, mts.expectedPayload) {
|
||||||
|
return fmt.Errorf("Expect %v, read %v", mts.expectedPayload, payload)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sendOnceTransportSender struct {
|
||||||
|
baseSender *mockTransportSender
|
||||||
|
sentChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sots *sendOnceTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error {
|
||||||
|
defer close(sots.sentChan)
|
||||||
|
return sots.baseSender.muxSession(sessionID, payload)
|
||||||
|
}
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
package datagramsession
|
|
||||||
|
|
||||||
import "github.com/google/uuid"
|
|
||||||
|
|
||||||
// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions
|
|
||||||
type transport interface {
|
|
||||||
// SendTo writes payload for a session to the transport
|
|
||||||
SendTo(sessionID uuid.UUID, payload []byte) error
|
|
||||||
// ReceiveFrom reads the next datagram from the transport
|
|
||||||
ReceiveFrom() (uuid.UUID, []byte, error)
|
|
||||||
// Max transmission unit to receive from the transport
|
|
||||||
MTU() int
|
|
||||||
}
|
|
|
@ -1,36 +0,0 @@
|
||||||
package datagramsession
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockQUICTransport struct {
|
|
||||||
reqChan *datagramChannel
|
|
||||||
respChan *datagramChannel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
|
|
||||||
buf := make([]byte, len(payload))
|
|
||||||
// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
|
||||||
copy(buf, payload)
|
|
||||||
return mt.respChan.Send(context.Background(), sessionID, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
|
|
||||||
return mt.reqChan.Receive(context.Background())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mt *mockQUICTransport) MTU() int {
|
|
||||||
return 1280
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
|
|
||||||
return mt.reqChan.Send(ctx, sessionID, payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mt *mockQUICTransport) close() {
|
|
||||||
mt.reqChan.Close()
|
|
||||||
mt.respChan.Close()
|
|
||||||
}
|
|
|
@ -1,6 +1,7 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
@ -13,53 +14,88 @@ const (
|
||||||
sessionIDLen = len(uuid.UUID{})
|
sessionIDLen = len(uuid.UUID{})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type SessionDatagram struct {
|
||||||
|
ID uuid.UUID
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaseDatagramMuxer interface {
|
||||||
|
// MuxSession suffix the session ID to the payload so the other end of the QUIC connection can demultiplex the
|
||||||
|
// payload from multiple datagram sessions
|
||||||
|
MuxSession(sessionID uuid.UUID, payload []byte) error
|
||||||
|
// ServeReceive starts a loop to receive datagrams from the QUIC connection
|
||||||
|
ServeReceive(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
type DatagramMuxer struct {
|
type DatagramMuxer struct {
|
||||||
session quic.Connection
|
session quic.Connection
|
||||||
logger *zerolog.Logger
|
logger *zerolog.Logger
|
||||||
|
demuxChan chan<- *SessionDatagram
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatagramMuxer(quicSession quic.Connection, logger *zerolog.Logger) (*DatagramMuxer, error) {
|
func NewDatagramMuxer(quicSession quic.Connection, log *zerolog.Logger, demuxChan chan<- *SessionDatagram) *DatagramMuxer {
|
||||||
|
logger := log.With().Uint8("datagramVersion", 1).Logger()
|
||||||
return &DatagramMuxer{
|
return &DatagramMuxer{
|
||||||
session: quicSession,
|
session: quicSession,
|
||||||
logger: logger,
|
logger: &logger,
|
||||||
}, nil
|
demuxChan: demuxChan,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendTo suffix the session ID to the payload so the other end of the QUIC session can demultiplex
|
// Maximum application payload to send to / receive from QUIC datagram frame
|
||||||
// the payload from multiple datagram sessions
|
func (dm *DatagramMuxer) mtu() int {
|
||||||
func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error {
|
return maxDatagramPayloadSize
|
||||||
if len(payload) > maxDatagramPayloadSize {
|
|
||||||
// TODO: TUN-5302 return ICMP packet too big message
|
|
||||||
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.MTU())
|
|
||||||
}
|
}
|
||||||
msgWithID, err := suffixSessionID(sessionID, payload)
|
|
||||||
|
func (dm *DatagramMuxer) MuxSession(sessionID uuid.UUID, payload []byte) error {
|
||||||
|
if len(payload) > dm.mtu() {
|
||||||
|
// TODO: TUN-5302 return ICMP packet too big message
|
||||||
|
// drop packet for now, eventually reply with ICMP for PMTUD
|
||||||
|
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu())
|
||||||
|
}
|
||||||
|
payloadWithMetadata, err := suffixSessionID(sessionID, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
|
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
|
||||||
}
|
}
|
||||||
if err := dm.session.SendMessage(msgWithID); err != nil {
|
if err := dm.session.SendMessage(payloadWithMetadata); err != nil {
|
||||||
return errors.Wrap(err, "Failed to send datagram back to edge")
|
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReceiveFrom extracts datagram session ID, then sends the session ID and payload to session manager
|
func (dm *DatagramMuxer) ServeReceive(ctx context.Context) error {
|
||||||
|
for {
|
||||||
|
// Extracts datagram session ID, then sends the session ID and payload to receiver
|
||||||
// which determines how to proxy to the origin. It assumes the datagram session has already been
|
// which determines how to proxy to the origin. It assumes the datagram session has already been
|
||||||
// registered with session manager through other side channel
|
// registered with receiver through other side channel
|
||||||
func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) {
|
|
||||||
msg, err := dm.session.ReceiveMessage()
|
msg, err := dm.session.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return uuid.Nil, nil, err
|
return err
|
||||||
|
}
|
||||||
|
if err := dm.demux(ctx, msg); err != nil {
|
||||||
|
dm.logger.Error().Err(err).Msg("Failed to demux datagram")
|
||||||
|
if err == context.Canceled {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sessionID, payload, err := extractSessionID(msg)
|
|
||||||
if err != nil {
|
|
||||||
return uuid.Nil, nil, err
|
|
||||||
}
|
}
|
||||||
return sessionID, payload, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Maximum application payload to send to / receive from QUIC datagram frame
|
func (dm *DatagramMuxer) demux(ctx context.Context, msg []byte) error {
|
||||||
func (dm *DatagramMuxer) MTU() int {
|
sessionID, payload, err := extractSessionID(msg)
|
||||||
return maxDatagramPayloadSize
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sessionDatagram := SessionDatagram{
|
||||||
|
ID: sessionID,
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case dm.demuxChan <- &sessionDatagram:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Each QUIC datagram should be suffixed with session ID.
|
// Each QUIC datagram should be suffixed with session ID.
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -52,9 +53,29 @@ func TestSuffixSessionIDError(t *testing.T) {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMaxDatagramPayload(t *testing.T) {
|
func TestDatagram(t *testing.T) {
|
||||||
payload := make([]byte, maxDatagramPayloadSize)
|
maxPayload := make([]byte, maxDatagramPayloadSize)
|
||||||
|
noPayloadSession := uuid.New()
|
||||||
|
maxPayloadSession := uuid.New()
|
||||||
|
sessionToPayload := []*SessionDatagram{
|
||||||
|
{
|
||||||
|
ID: noPayloadSession,
|
||||||
|
Payload: make([]byte, 0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: maxPayloadSession,
|
||||||
|
Payload: maxPayload,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
flowPayloads := [][]byte{
|
||||||
|
maxPayload,
|
||||||
|
}
|
||||||
|
|
||||||
|
testDatagram(t, 1, sessionToPayload, nil)
|
||||||
|
testDatagram(t, 2, sessionToPayload, flowPayloads)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDatagram(t *testing.T, version uint8, sessionToPayloads []*SessionDatagram, packetPayloads [][]byte) {
|
||||||
quicConfig := &quic.Config{
|
quicConfig := &quic.Config{
|
||||||
KeepAlivePeriod: 5 * time.Millisecond,
|
KeepAlivePeriod: 5 * time.Millisecond,
|
||||||
EnableDatagrams: true,
|
EnableDatagrams: true,
|
||||||
|
@ -63,6 +84,8 @@ func TestMaxDatagramPayload(t *testing.T) {
|
||||||
quicListener := newQUICListener(t, quicConfig)
|
quicListener := newQUICListener(t, quicConfig)
|
||||||
defer quicListener.Close()
|
defer quicListener.Close()
|
||||||
|
|
||||||
|
logger := zerolog.Nop()
|
||||||
|
|
||||||
errGroup, ctx := errgroup.WithContext(context.Background())
|
errGroup, ctx := errgroup.WithContext(context.Background())
|
||||||
// Run edge side of datagram muxer
|
// Run edge side of datagram muxer
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
|
@ -72,22 +95,32 @@ func TestMaxDatagramPayload(t *testing.T) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := zerolog.Nop()
|
sessionDemuxChan := make(chan *SessionDatagram, 16)
|
||||||
muxer, err := NewDatagramMuxer(quicSession, &logger)
|
|
||||||
if err != nil {
|
switch version {
|
||||||
return err
|
case 1:
|
||||||
|
muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan)
|
||||||
|
muxer.ServeReceive(ctx)
|
||||||
|
case 2:
|
||||||
|
packetDemuxChan := make(chan []byte, len(packetPayloads))
|
||||||
|
muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan, packetDemuxChan)
|
||||||
|
muxer.ServeReceive(ctx)
|
||||||
|
|
||||||
|
for _, expectedPayload := range packetPayloads {
|
||||||
|
require.Equal(t, expectedPayload, <-packetDemuxChan)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown datagram version %d", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionID, receivedPayload, err := muxer.ReceiveFrom()
|
for _, expectedPayload := range sessionToPayloads {
|
||||||
if err != nil {
|
actualPayload := <-sessionDemuxChan
|
||||||
return err
|
require.Equal(t, expectedPayload, actualPayload)
|
||||||
}
|
}
|
||||||
require.Equal(t, testSessionID, sessionID)
|
|
||||||
require.True(t, bytes.Equal(payload, receivedPayload))
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
largePayload := make([]byte, MaxDatagramFrameSize)
|
||||||
// Run cloudflared side of datagram muxer
|
// Run cloudflared side of datagram muxer
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
tlsClientConfig := &tls.Config{
|
tlsClientConfig := &tls.Config{
|
||||||
|
@ -97,24 +130,35 @@ func TestMaxDatagramPayload(t *testing.T) {
|
||||||
// Establish quic connection
|
// Establish quic connection
|
||||||
quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig)
|
quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
defer quicSession.CloseWithError(0, "")
|
||||||
logger := zerolog.Nop()
|
|
||||||
muxer, err := NewDatagramMuxer(quicSession, &logger)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait a few milliseconds for MTU discovery to take place
|
// Wait a few milliseconds for MTU discovery to take place
|
||||||
time.Sleep(time.Millisecond * 100)
|
time.Sleep(time.Millisecond * 100)
|
||||||
err = muxer.SendTo(testSessionID, payload)
|
|
||||||
if err != nil {
|
var muxer BaseDatagramMuxer
|
||||||
return err
|
switch version {
|
||||||
|
case 1:
|
||||||
|
muxer = NewDatagramMuxer(quicSession, &logger, nil)
|
||||||
|
case 2:
|
||||||
|
muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil, nil)
|
||||||
|
for _, payload := range packetPayloads {
|
||||||
|
require.NoError(t, muxerV2.MuxPacket(payload))
|
||||||
|
}
|
||||||
|
// Payload larger than transport MTU, should not be sent
|
||||||
|
require.Error(t, muxerV2.MuxPacket(largePayload))
|
||||||
|
muxer = muxerV2
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown datagram version %d", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Payload larger than transport MTU, should return an error
|
for _, sessionDatagram := range sessionToPayloads {
|
||||||
largePayload := make([]byte, MaxDatagramFrameSize)
|
require.NoError(t, muxer.MuxSession(sessionDatagram.ID, sessionDatagram.Payload))
|
||||||
err = muxer.SendTo(testSessionID, largePayload)
|
}
|
||||||
require.Error(t, err)
|
// Payload larger than transport MTU, should not be sent
|
||||||
|
require.Error(t, muxer.MuxSession(testSessionID, largePayload))
|
||||||
|
|
||||||
|
// Wait for edge to finish receiving the messages
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -154,3 +198,35 @@ func generateTLSConfig() *tls.Config {
|
||||||
NextProtos: []string{"argotunnel"},
|
NextProtos: []string{"argotunnel"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type sessionMuxer interface {
|
||||||
|
SendToSession(sessionID uuid.UUID, payload []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockSessionReceiver struct {
|
||||||
|
expectedSessionToPayload map[uuid.UUID][]byte
|
||||||
|
receivedCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (msr *mockSessionReceiver) ReceiveDatagram(sessionID uuid.UUID, payload []byte) error {
|
||||||
|
expectedPayload := msr.expectedSessionToPayload[sessionID]
|
||||||
|
if !bytes.Equal(expectedPayload, payload) {
|
||||||
|
return fmt.Errorf("expect %v to have payload %s, got %s", sessionID, string(expectedPayload), string(payload))
|
||||||
|
}
|
||||||
|
msr.receivedCount++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockFlowReceiver struct {
|
||||||
|
expectedPayloads [][]byte
|
||||||
|
receivedCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mfr *mockFlowReceiver) ReceiveFlow(payload []byte) error {
|
||||||
|
expectedPayload := mfr.expectedPayloads[mfr.receivedCount]
|
||||||
|
if !bytes.Equal(expectedPayload, payload) {
|
||||||
|
return fmt.Errorf("expect flow %d to have payload %s, got %s", mfr.receivedCount, string(expectedPayload), string(payload))
|
||||||
|
}
|
||||||
|
mfr.receivedCount++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,136 @@
|
||||||
|
package quic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type datagramV2Type byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
udp datagramV2Type = iota
|
||||||
|
ip
|
||||||
|
)
|
||||||
|
|
||||||
|
func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) {
|
||||||
|
if len(b)+1 > MaxDatagramFrameSize {
|
||||||
|
return nil, fmt.Errorf("datagram size %d exceeds max frame size %d", len(b), MaxDatagramFrameSize)
|
||||||
|
}
|
||||||
|
b = append(b, byte(datagramType))
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maximum application payload to send to / receive from QUIC datagram frame
|
||||||
|
func (dm *DatagramMuxerV2) mtu() int {
|
||||||
|
return maxDatagramPayloadSize
|
||||||
|
}
|
||||||
|
|
||||||
|
type DatagramMuxerV2 struct {
|
||||||
|
session quic.Connection
|
||||||
|
logger *zerolog.Logger
|
||||||
|
sessionDemuxChan chan<- *SessionDatagram
|
||||||
|
packetDemuxChan chan<- []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDatagramMuxerV2(
|
||||||
|
quicSession quic.Connection,
|
||||||
|
log *zerolog.Logger,
|
||||||
|
sessionDemuxChan chan<- *SessionDatagram,
|
||||||
|
packetDemuxChan chan<- []byte) *DatagramMuxerV2 {
|
||||||
|
logger := log.With().Uint8("datagramVersion", 2).Logger()
|
||||||
|
return &DatagramMuxerV2{
|
||||||
|
session: quicSession,
|
||||||
|
logger: &logger,
|
||||||
|
sessionDemuxChan: sessionDemuxChan,
|
||||||
|
packetDemuxChan: packetDemuxChan,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MuxSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can
|
||||||
|
// demultiplex the payload from multiple datagram sessions
|
||||||
|
func (dm *DatagramMuxerV2) MuxSession(sessionID uuid.UUID, payload []byte) error {
|
||||||
|
if len(payload) > dm.mtu() {
|
||||||
|
// TODO: TUN-5302 return ICMP packet too big message
|
||||||
|
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu())
|
||||||
|
}
|
||||||
|
msgWithID, err := suffixSessionID(sessionID, payload)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
|
||||||
|
}
|
||||||
|
msgWithIDAndType, err := suffixType(msgWithID, udp)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
|
||||||
|
}
|
||||||
|
if err := dm.session.SendMessage(msgWithIDAndType); err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MuxPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing
|
||||||
|
// the payload as IP and look at the source and destination.
|
||||||
|
func (dm *DatagramMuxerV2) MuxPacket(packet []byte) error {
|
||||||
|
payloadWithVersion, err := suffixType(packet, ip)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
|
||||||
|
}
|
||||||
|
if err := dm.session.SendMessage(payloadWithVersion); err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
|
||||||
|
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
|
||||||
|
for {
|
||||||
|
msg, err := dm.session.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := dm.demux(ctx, msg); err != nil {
|
||||||
|
dm.logger.Error().Err(err).Msg("Failed to demux datagram")
|
||||||
|
if err == context.Canceled {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error {
|
||||||
|
if len(msgWithType) < 1 {
|
||||||
|
return fmt.Errorf("QUIC datagram should have at least 1 byte")
|
||||||
|
}
|
||||||
|
msgType := datagramV2Type(msgWithType[len(msgWithType)-1])
|
||||||
|
msg := msgWithType[0 : len(msgWithType)-1]
|
||||||
|
switch msgType {
|
||||||
|
case udp:
|
||||||
|
sessionID, payload, err := extractSessionID(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sessionDatagram := SessionDatagram{
|
||||||
|
ID: sessionID,
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case dm.sessionDemuxChan <- &sessionDatagram:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
case ip:
|
||||||
|
select {
|
||||||
|
case dm.packetDemuxChan <- msg:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Unexpected datagram type %d", msgType)
|
||||||
|
}
|
||||||
|
}
|
|
@ -33,6 +33,7 @@ const (
|
||||||
FeatureSerializedHeaders = "serialized_headers"
|
FeatureSerializedHeaders = "serialized_headers"
|
||||||
FeatureQuickReconnects = "quick_reconnects"
|
FeatureQuickReconnects = "quick_reconnects"
|
||||||
FeatureAllowRemoteConfig = "allow_remote_config"
|
FeatureAllowRemoteConfig = "allow_remote_config"
|
||||||
|
FeatureDatagramV2 = "support_datagram_v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelConfig struct {
|
type TunnelConfig struct {
|
||||||
|
|
Loading…
Reference in New Issue