From 278df5478a7cf7a411a44f8686cc3cd835938afd Mon Sep 17 00:00:00 2001 From: cthuang Date: Mon, 1 Aug 2022 13:48:33 +0100 Subject: [PATCH] TUN-6584: Define QUIC datagram v2 format to support proxying IP packets --- connection/quic.go | 28 ++-- datagramsession/event.go | 6 - datagramsession/manager.go | 101 +++++-------- datagramsession/manager_test.go | 244 ++++++++++++++---------------- datagramsession/session.go | 60 ++++---- datagramsession/session_test.go | 57 +++++-- datagramsession/transport.go | 13 -- datagramsession/transport_test.go | 36 ----- quic/datagram.go | 92 +++++++---- quic/datagram_test.go | 126 ++++++++++++--- quic/datagramv2.go | 136 +++++++++++++++++ supervisor/tunnel.go | 1 + 12 files changed, 548 insertions(+), 352 deletions(-) delete mode 100644 datagramsession/transport.go delete mode 100644 datagramsession/transport_test.go create mode 100644 quic/datagramv2.go diff --git a/connection/quic.go b/connection/quic.go index 84bbb1e8..a7f15e69 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -33,14 +33,19 @@ const ( HTTPHostKey = "HttpHost" QUICMetadataFlowID = "FlowID" + // emperically this capacity has been working well + demuxChanCapacity = 16 ) // QUICConnection represents the type that facilitates Proxying via QUIC streams. type QUICConnection struct { - session quic.Connection - logger *zerolog.Logger - orchestrator Orchestrator - sessionManager datagramsession.Manager + session quic.Connection + logger *zerolog.Logger + orchestrator Orchestrator + // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer + sessionManager datagramsession.Manager + // datagramMuxer mux/demux datagrams from quic connection + datagramMuxer *quicpogs.DatagramMuxer controlStreamHandler ControlStreamHandler connOptions *tunnelpogs.ConnectionOptions } @@ -60,18 +65,16 @@ func NewQUICConnection( return nil, &EdgeQuicDialError{Cause: err} } - datagramMuxer, err := quicpogs.NewDatagramMuxer(session, logger) - if err != nil { - return nil, err - } - - sessionManager := datagramsession.NewManager(datagramMuxer, logger) + demuxChan := make(chan *quicpogs.SessionDatagram, demuxChanCapacity) + datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, demuxChan) + sessionManager := datagramsession.NewManager(logger, datagramMuxer.MuxSession, demuxChan) return &QUICConnection{ session: session, orchestrator: orchestrator, logger: logger, sessionManager: sessionManager, + datagramMuxer: datagramMuxer, controlStreamHandler: controlStreamHandler, connOptions: connOptions, }, nil @@ -108,6 +111,11 @@ func (q *QUICConnection) Serve(ctx context.Context) error { return q.sessionManager.Serve(ctx) }) + errGroup.Go(func() error { + defer cancel() + return q.datagramMuxer.ServeReceive(ctx) + }) + return errGroup.Wait() } diff --git a/datagramsession/event.go b/datagramsession/event.go index d79c6b31..2e91964e 100644 --- a/datagramsession/event.go +++ b/datagramsession/event.go @@ -42,9 +42,3 @@ func (sc *errClosedSession) Error() string { return fmt.Sprintf("session closed by local due to %s", sc.message) } } - -// newDatagram is an event when transport receives new datagram -type newDatagram struct { - sessionID uuid.UUID - payload []byte -} diff --git a/datagramsession/manager.go b/datagramsession/manager.go index 54c5c24d..3cf198bd 100644 --- a/datagramsession/manager.go +++ b/datagramsession/manager.go @@ -7,9 +7,9 @@ import ( "time" "github.com/google/uuid" - "github.com/lucas-clemente/quic-go" "github.com/rs/zerolog" - "golang.org/x/sync/errgroup" + + quicpogs "github.com/cloudflare/cloudflared/quic" ) const ( @@ -36,26 +36,25 @@ type Manager interface { type manager struct { registrationChan chan *registerSessionEvent unregistrationChan chan *unregisterSessionEvent - datagramChan chan *newDatagram - closedChan chan struct{} - transport transport + sendFunc transportSender + receiveChan <-chan *quicpogs.SessionDatagram + closedChan <-chan struct{} sessions map[uuid.UUID]*Session log *zerolog.Logger // timeout waiting for an API to finish. This can be overriden in test timeout time.Duration } -func NewManager(transport transport, log *zerolog.Logger) *manager { +func NewManager(log *zerolog.Logger, sendF transportSender, receiveChan <-chan *quicpogs.SessionDatagram) *manager { return &manager{ registrationChan: make(chan *registerSessionEvent), unregistrationChan: make(chan *unregisterSessionEvent), - // datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events - datagramChan: make(chan *newDatagram, requestChanCapacity), - closedChan: make(chan struct{}), - transport: transport, - sessions: make(map[uuid.UUID]*Session), - log: log, - timeout: defaultReqTimeout, + sendFunc: sendF, + receiveChan: receiveChan, + closedChan: make(chan struct{}), + sessions: make(map[uuid.UUID]*Session), + log: log, + timeout: defaultReqTimeout, } } @@ -65,49 +64,21 @@ func (m *manager) UpdateLogger(log *zerolog.Logger) { } func (m *manager) Serve(ctx context.Context) error { - errGroup, ctx := errgroup.WithContext(ctx) - errGroup.Go(func() error { - for { - sessionID, payload, err := m.transport.ReceiveFrom() - if err != nil { - if aerr, ok := err.(*quic.ApplicationError); ok && uint64(aerr.ErrorCode) == uint64(quic.NoError) { - return nil - } else { - return err - } - } - datagram := &newDatagram{ - sessionID: sessionID, - payload: payload, - } - select { - case <-ctx.Done(): - return ctx.Err() - // Only the event loop routine can update/lookup the sessions map to avoid concurrent access - // Send the datagram to the event loop. It will find the session to send to - case m.datagramChan <- datagram: - } + for { + select { + case <-ctx.Done(): + m.shutdownSessions(ctx.Err()) + return ctx.Err() + // receiveChan is buffered, so the transport can read more datagrams from transport while the event loop is + // processing other events + case datagram := <-m.receiveChan: + m.sendToSession(datagram) + case registration := <-m.registrationChan: + m.registerSession(ctx, registration) + case unregistration := <-m.unregistrationChan: + m.unregisterSession(unregistration) } - }) - errGroup.Go(func() error { - for { - select { - case <-ctx.Done(): - return nil - case datagram := <-m.datagramChan: - m.sendToSession(datagram) - case registration := <-m.registrationChan: - m.registerSession(ctx, registration) - // TODO: TUN-5422: Unregister inactive session upon timeout - case unregistration := <-m.unregistrationChan: - m.unregisterSession(unregistration) - } - } - }) - err := errGroup.Wait() - close(m.closedChan) - m.shutdownSessions(err) - return err + } } func (m *manager) shutdownSessions(err error) { @@ -149,16 +120,17 @@ func (m *manager) registerSession(ctx context.Context, registration *registerSes } func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session { + logger := m.log.With().Str("sessionID", id.String()).Logger() return &Session{ - ID: id, - transport: m.transport, - dstConn: dstConn, + ID: id, + sendFunc: m.sendFunc, + dstConn: dstConn, // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will // drop instead of blocking because last active time only needs to be an approximation activeAtChan: make(chan time.Time, 2), // capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel closeChan: make(chan error, 2), - log: m.log, + log: &logger, } } @@ -191,16 +163,13 @@ func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) { } } -func (m *manager) sendToSession(datagram *newDatagram) { - session, ok := m.sessions[datagram.sessionID] +func (m *manager) sendToSession(datagram *quicpogs.SessionDatagram) { + session, ok := m.sessions[datagram.ID] if !ok { - m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found") + m.log.Error().Str("sessionID", datagram.ID.String()).Msg("session not found") return } // session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't // need to run in another go routine - _, err := session.transportToDst(datagram.payload) - if err != nil { - m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session") - } + session.transportToDst(datagram.Payload) } diff --git a/datagramsession/manager_test.go b/datagramsession/manager_test.go index 7b17bf7d..1d73b33f 100644 --- a/datagramsession/manager_test.go +++ b/datagramsession/manager_test.go @@ -14,22 +14,30 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + + quicpogs "github.com/cloudflare/cloudflared/quic" +) + +var ( + nopLogger = zerolog.Nop() ) func TestManagerServe(t *testing.T) { const ( - sessions = 20 - msgs = 50 + sessions = 2 + msgs = 5 remoteUnregisterMsg = "eyeball closed connection" ) - mg, transport := newTestManager(1) - - eyeballTracker := make(map[uuid.UUID]*datagramChannel) - for i := 0; i < sessions; i++ { - sessionID := uuid.New() - eyeballTracker[sessionID] = newDatagramChannel(1) + requestChan := make(chan *quicpogs.SessionDatagram) + transport := mockQUICTransport{ + sessions: make(map[uuid.UUID]chan []byte), } + for i := 0; i < sessions; i++ { + transport.sessions[uuid.New()] = make(chan []byte) + } + + mg := NewManager(&nopLogger, transport.MuxSession, requestChan) ctx, cancel := context.WithCancel(context.Background()) serveDone := make(chan struct{}) @@ -38,53 +46,42 @@ func TestManagerServe(t *testing.T) { close(serveDone) }(ctx) - go func(ctx context.Context) { - for { - sessionID, payload, err := transport.respChan.Receive(ctx) - if err != nil { - require.Equal(t, context.Canceled, err) - return - } - respChan := eyeballTracker[sessionID] - require.NoError(t, respChan.Send(ctx, sessionID, payload)) - } - }(ctx) - errGroup, ctx := errgroup.WithContext(ctx) - for sID, receiver := range eyeballTracker { + for sessionID, eyeballRespChan := range transport.sessions { + // Assign loop variables to local variables + sID := sessionID + payload := testPayload(sID) + expectResp := testResponse(payload) + + cfdConn, originConn := net.Pipe() + + origin := mockOrigin{ + expectMsgCount: msgs, + expectedMsg: payload, + expectedResp: expectResp, + conn: originConn, + } + + eyeball := mockEyeballSession{ + id: sID, + expectedMsgCount: msgs, + expectedMsg: payload, + expectedResponse: expectResp, + respReceiver: eyeballRespChan, + } + // Assign loop variables to local variables - sessionID := sID - eyeballRespReceiver := receiver errGroup.Go(func() error { - payload := testPayload(sessionID) - expectResp := testResponse(payload) - - cfdConn, originConn := net.Pipe() - - origin := mockOrigin{ - expectMsgCount: msgs, - expectedMsg: payload, - expectedResp: expectResp, - conn: originConn, - } - eyeball := mockEyeball{ - expectMsgCount: msgs, - expectedMsg: expectResp, - expectSessionID: sessionID, - respReceiver: eyeballRespReceiver, - } - + session, err := mg.RegisterSession(ctx, sID, cfdConn) + require.NoError(t, err) reqErrGroup, reqCtx := errgroup.WithContext(ctx) reqErrGroup.Go(func() error { return origin.serve() }) reqErrGroup.Go(func() error { - return eyeball.serve(reqCtx) + return eyeball.serve(reqCtx, requestChan) }) - session, err := mg.RegisterSession(ctx, sessionID, cfdConn) - require.NoError(t, err) - sessionDone := make(chan struct{}) go func() { closedByRemote, err := session.Serve(ctx, time.Minute*2) @@ -97,23 +94,17 @@ func TestManagerServe(t *testing.T) { close(sessionDone) }() - for i := 0; i < msgs; i++ { - require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID))) - } - // Make sure eyeball and origin have received all messages before unregistering the session require.NoError(t, reqErrGroup.Wait()) - require.NoError(t, mg.UnregisterSession(ctx, sessionID, remoteUnregisterMsg, true)) + require.NoError(t, mg.UnregisterSession(ctx, sID, remoteUnregisterMsg, true)) <-sessionDone - return nil }) } require.NoError(t, errGroup.Wait()) cancel() - transport.close() <-serveDone } @@ -122,7 +113,7 @@ func TestTimeout(t *testing.T) { testTimeout = time.Millisecond * 50 ) - mg, _ := newTestManager(1) + mg := NewManager(&nopLogger, nil, nil) mg.timeout = testTimeout ctx := context.Background() sessionID := uuid.New() @@ -135,9 +126,51 @@ func TestTimeout(t *testing.T) { require.ErrorIs(t, err, context.DeadlineExceeded) } -func TestCloseTransportCloseSessions(t *testing.T) { - mg, transport := newTestManager(1) - ctx := context.Background() +func TestUnregisterSessionCloseSession(t *testing.T) { + sessionID := uuid.New() + payload := []byte(t.Name()) + sender := newMockTransportSender(sessionID, payload) + mg := NewManager(&nopLogger, sender.muxSession, nil) + ctx, cancel := context.WithCancel(context.Background()) + + managerDone := make(chan struct{}) + go func() { + err := mg.Serve(ctx) + require.Error(t, err) + close(managerDone) + }() + + cfdConn, originConn := net.Pipe() + session, err := mg.RegisterSession(ctx, sessionID, cfdConn) + require.NoError(t, err) + require.NotNil(t, session) + + unregisteredChan := make(chan struct{}) + go func() { + _, err := originConn.Write(payload) + require.NoError(t, err) + + err = mg.UnregisterSession(ctx, sessionID, "eyeball closed session", true) + require.NoError(t, err) + + close(unregisteredChan) + }() + + closedByRemote, err := session.Serve(ctx, time.Minute) + require.True(t, closedByRemote) + require.Error(t, err) + + <-unregisteredChan + cancel() + <-managerDone +} + +func TestManagerCtxDoneCloseSessions(t *testing.T) { + sessionID := uuid.New() + payload := []byte(t.Name()) + sender := newMockTransportSender(sessionID, payload) + mg := NewManager(&nopLogger, sender.muxSession, nil) + ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup wg.Add(1) @@ -147,35 +180,26 @@ func TestCloseTransportCloseSessions(t *testing.T) { require.Error(t, err) }() - cfdConn, eyeballConn := net.Pipe() - session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn) + cfdConn, originConn := net.Pipe() + session, err := mg.RegisterSession(ctx, sessionID, cfdConn) require.NoError(t, err) require.NotNil(t, session) wg.Add(1) go func() { defer wg.Done() - _, err := eyeballConn.Write([]byte(t.Name())) + _, err := originConn.Write(payload) require.NoError(t, err) - transport.close() + cancel() }() closedByRemote, err := session.Serve(ctx, time.Minute) - require.True(t, closedByRemote) + require.False(t, closedByRemote) require.Error(t, err) wg.Wait() } -func newTestManager(capacity uint) (*manager, *mockQUICTransport) { - log := zerolog.Nop() - transport := &mockQUICTransport{ - reqChan: newDatagramChannel(capacity), - respChan: newDatagramChannel(capacity), - } - return NewManager(transport, &log), transport -} - type mockOrigin struct { expectMsgCount int expectedMsg []byte @@ -197,7 +221,6 @@ func (mo *mockOrigin) serve() error { if !bytes.Equal(readBuffer[:n], mo.expectedMsg) { return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n]) } - _, err = mo.conn.Write(mo.expectedResp) if err != nil { return err @@ -214,72 +237,35 @@ func testResponse(msg []byte) []byte { return []byte(fmt.Sprintf("Response to %v", msg)) } -type mockEyeball struct { - expectMsgCount int - expectedMsg []byte - expectSessionID uuid.UUID - respReceiver *datagramChannel +type mockQUICTransport struct { + sessions map[uuid.UUID]chan []byte } -func (me *mockEyeball) serve(ctx context.Context) error { - for i := 0; i < me.expectMsgCount; i++ { - sessionID, msg, err := me.respReceiver.Receive(ctx) - if err != nil { - return err - } - if sessionID != me.expectSessionID { - return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID) - } - if !bytes.Equal(msg, me.expectedMsg) { - return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg) - } - } +func (me *mockQUICTransport) MuxSession(id uuid.UUID, payload []byte) error { + session := me.sessions[id] + session <- payload return nil } -// datagramChannel is a channel for Datagram with wrapper to send/receive with context -type datagramChannel struct { - datagramChan chan *newDatagram - closedChan chan struct{} +type mockEyeballSession struct { + id uuid.UUID + expectedMsgCount int + expectedMsg []byte + expectedResponse []byte + respReceiver <-chan []byte } -func newDatagramChannel(capacity uint) *datagramChannel { - return &datagramChannel{ - datagramChan: make(chan *newDatagram, capacity), - closedChan: make(chan struct{}), - } -} - -func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error { - select { - case <-ctx.Done(): - return ctx.Err() - case <-rc.closedChan: - return &errClosedSession{ - message: fmt.Errorf("datagram channel closed").Error(), - byRemote: true, +func (me *mockEyeballSession) serve(ctx context.Context, requestChan chan *quicpogs.SessionDatagram) error { + for i := 0; i < me.expectedMsgCount; i++ { + requestChan <- &quicpogs.SessionDatagram{ + ID: me.id, + Payload: me.expectedMsg, } - case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}: - return nil - } -} - -func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) { - select { - case <-ctx.Done(): - return uuid.Nil, nil, ctx.Err() - case <-rc.closedChan: - err := &errClosedSession{ - message: fmt.Errorf("datagram channel closed").Error(), - byRemote: true, + resp := <-me.respReceiver + if !bytes.Equal(resp, me.expectedResponse) { + return fmt.Errorf("Expect %v, read %v", me.expectedResponse, resp) } - return uuid.Nil, nil, err - case msg := <-rc.datagramChan: - return msg.sessionID, msg.payload, nil + fmt.Println("Resp", resp) } -} - -func (rc *datagramChannel) Close() { - // No need to close msgChan, it will be garbage collect once there is no reference to it - close(rc.closedChan) + return nil } diff --git a/datagramsession/session.go b/datagramsession/session.go index 351ae298..8851b01e 100644 --- a/datagramsession/session.go +++ b/datagramsession/session.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "time" "github.com/google/uuid" @@ -18,21 +19,20 @@ func SessionIdleErr(timeout time.Duration) error { return fmt.Errorf("session idle for %v", timeout) } +type transportSender func(sessionID uuid.UUID, payload []byte) error + // Session is a bidirectional pipe of datagrams between transport and dstConn -// Currently the only implementation of transport is quic DatagramMuxer // Destination can be a connection with origin or with eyeball // When the destination is origin: -// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the -// write method of the Session to send to origin -// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball +// - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to origin +// - Datagrams from origin are read from conn and Send to transport using the transportSender callback. Transport will return them to eyeball // When the destination is eyeball: -// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared -// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the -// write method of the Session to send to eyeball +// - Datagrams from eyeball are read from conn and Send to transport. Transport will send them to cloudflared using the transportSender callback. +// - Manager receives datagrams from receiveChan and calls the transportToDst method of the Session to send to the eyeball type Session struct { - ID uuid.UUID - transport transport - dstConn io.ReadWriteCloser + ID uuid.UUID + sendFunc transportSender + dstConn io.ReadWriteCloser // activeAtChan is used to communicate the last read/write time activeAtChan chan time.Time closeChan chan error @@ -46,9 +46,16 @@ func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (clos const maxPacketSize = 1500 readBuffer := make([]byte, maxPacketSize) for { - if err := s.dstToTransport(readBuffer); err != nil { - s.closeChan <- err - return + if closeSession, err := s.dstToTransport(readBuffer); err != nil { + if err != net.ErrClosed { + s.log.Error().Err(err).Msg("Failed to send session payload from destination to transport") + } else { + s.log.Debug().Msg("Session cannot read from destination because the connection is closed") + } + if closeSession { + s.closeChan <- err + return + } } } }() @@ -89,32 +96,25 @@ func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time } } -func (s *Session) dstToTransport(buffer []byte) error { +func (s *Session) dstToTransport(buffer []byte) (closeSession bool, err error) { n, err := s.dstConn.Read(buffer) s.markActive() // https://pkg.go.dev/io#Reader suggests caller should always process n > 0 bytes - if n > 0 { - if n <= int(s.transport.MTU()) { - err = s.transport.SendTo(s.ID, buffer[:n]) - } else { - // drop packet for now, eventually reply with ICMP for PMTUD - s.log.Debug(). - Str("session", s.ID.String()). - Int("len", n). - Int("mtu", s.transport.MTU()). - Msg("dropped packet exceeding MTU") + if n > 0 || err == nil { + if sendErr := s.sendFunc(s.ID, buffer[:n]); sendErr != nil { + return false, sendErr } } - // Some UDP application might send 0-size payload. - if err == nil && n == 0 { - err = s.transport.SendTo(s.ID, []byte{}) - } - return err + return err != nil, err } func (s *Session) transportToDst(payload []byte) (int, error) { s.markActive() - return s.dstConn.Write(payload) + n, err := s.dstConn.Write(payload) + if err != nil { + s.log.Err(err).Msg("Failed to write payload to session") + } + return n, err } // Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there diff --git a/datagramsession/session_test.go b/datagramsession/session_test.go index db4201f6..b0b7a66e 100644 --- a/datagramsession/session_test.go +++ b/datagramsession/session_test.go @@ -11,8 +11,11 @@ import ( "time" "github.com/google/uuid" + "github.com/rs/zerolog" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + + quicpogs "github.com/cloudflare/cloudflared/quic" ) // TestCloseSession makes sure a session will stop after context is done @@ -41,7 +44,8 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D cfdConn, originConn := net.Pipe() payload := testPayload(sessionID) - mg, _ := newTestManager(1) + log := zerolog.Nop() + mg := NewManager(&log, nil, nil) session := mg.newSession(sessionID, cfdConn) ctx, cancel := context.WithCancel(context.Background()) @@ -114,7 +118,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) cfdConn, originConn := net.Pipe() payload := testPayload(sessionID) - mg, _ := newTestManager(100) + respChan := make(chan *quicpogs.SessionDatagram) + sender := newMockTransportSender(sessionID, payload) + mg := NewManager(&nopLogger, sender.muxSession, respChan) session := mg.newSession(sessionID, cfdConn) startTime := time.Now() @@ -177,7 +183,7 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) func TestMarkActiveNotBlocking(t *testing.T) { const concurrentCalls = 50 - mg, _ := newTestManager(1) + mg := NewManager(&nopLogger, nil, nil) session := mg.newSession(uuid.New(), nil) var wg sync.WaitGroup wg.Add(concurrentCalls) @@ -190,11 +196,16 @@ func TestMarkActiveNotBlocking(t *testing.T) { wg.Wait() } +// Some UDP application might send 0-size payload. func TestZeroBytePayload(t *testing.T) { sessionID := uuid.New() cfdConn, originConn := net.Pipe() - mg, transport := newTestManager(1) + sender := sendOnceTransportSender{ + baseSender: newMockTransportSender(sessionID, make([]byte, 0)), + sentChan: make(chan struct{}), + } + mg := NewManager(&nopLogger, sender.muxSession, nil) session := mg.newSession(sessionID, cfdConn) ctx, cancel := context.WithCancel(context.Background()) @@ -215,11 +226,39 @@ func TestZeroBytePayload(t *testing.T) { return nil }) - receivedSessionID, payload, err := transport.respChan.Receive(ctx) - require.NoError(t, err) - require.Len(t, payload, 0) - require.Equal(t, sessionID, receivedSessionID) - + <-sender.sentChan cancel() require.NoError(t, errGroup.Wait()) } + +type mockTransportSender struct { + expectedSessionID uuid.UUID + expectedPayload []byte +} + +func newMockTransportSender(expectedSessionID uuid.UUID, expectedPayload []byte) *mockTransportSender { + return &mockTransportSender{ + expectedSessionID: expectedSessionID, + expectedPayload: expectedPayload, + } +} + +func (mts *mockTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error { + if sessionID != mts.expectedSessionID { + return fmt.Errorf("Expect session %s, got %s", mts.expectedSessionID, sessionID) + } + if !bytes.Equal(payload, mts.expectedPayload) { + return fmt.Errorf("Expect %v, read %v", mts.expectedPayload, payload) + } + return nil +} + +type sendOnceTransportSender struct { + baseSender *mockTransportSender + sentChan chan struct{} +} + +func (sots *sendOnceTransportSender) muxSession(sessionID uuid.UUID, payload []byte) error { + defer close(sots.sentChan) + return sots.baseSender.muxSession(sessionID, payload) +} diff --git a/datagramsession/transport.go b/datagramsession/transport.go deleted file mode 100644 index f41e1e83..00000000 --- a/datagramsession/transport.go +++ /dev/null @@ -1,13 +0,0 @@ -package datagramsession - -import "github.com/google/uuid" - -// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions -type transport interface { - // SendTo writes payload for a session to the transport - SendTo(sessionID uuid.UUID, payload []byte) error - // ReceiveFrom reads the next datagram from the transport - ReceiveFrom() (uuid.UUID, []byte, error) - // Max transmission unit to receive from the transport - MTU() int -} diff --git a/datagramsession/transport_test.go b/datagramsession/transport_test.go deleted file mode 100644 index 6b187722..00000000 --- a/datagramsession/transport_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package datagramsession - -import ( - "context" - - "github.com/google/uuid" -) - -type mockQUICTransport struct { - reqChan *datagramChannel - respChan *datagramChannel -} - -func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error { - buf := make([]byte, len(payload)) - // The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 - copy(buf, payload) - return mt.respChan.Send(context.Background(), sessionID, buf) -} - -func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) { - return mt.reqChan.Receive(context.Background()) -} - -func (mt *mockQUICTransport) MTU() int { - return 1280 -} - -func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error { - return mt.reqChan.Send(ctx, sessionID, payload) -} - -func (mt *mockQUICTransport) close() { - mt.reqChan.Close() - mt.respChan.Close() -} diff --git a/quic/datagram.go b/quic/datagram.go index 350b601f..e8e4fad7 100644 --- a/quic/datagram.go +++ b/quic/datagram.go @@ -1,6 +1,7 @@ package quic import ( + "context" "fmt" "github.com/google/uuid" @@ -13,53 +14,88 @@ const ( sessionIDLen = len(uuid.UUID{}) ) +type SessionDatagram struct { + ID uuid.UUID + Payload []byte +} + +type BaseDatagramMuxer interface { + // MuxSession suffix the session ID to the payload so the other end of the QUIC connection can demultiplex the + // payload from multiple datagram sessions + MuxSession(sessionID uuid.UUID, payload []byte) error + // ServeReceive starts a loop to receive datagrams from the QUIC connection + ServeReceive(ctx context.Context) error +} + type DatagramMuxer struct { - session quic.Connection - logger *zerolog.Logger + session quic.Connection + logger *zerolog.Logger + demuxChan chan<- *SessionDatagram } -func NewDatagramMuxer(quicSession quic.Connection, logger *zerolog.Logger) (*DatagramMuxer, error) { +func NewDatagramMuxer(quicSession quic.Connection, log *zerolog.Logger, demuxChan chan<- *SessionDatagram) *DatagramMuxer { + logger := log.With().Uint8("datagramVersion", 1).Logger() return &DatagramMuxer{ - session: quicSession, - logger: logger, - }, nil + session: quicSession, + logger: &logger, + demuxChan: demuxChan, + } } -// SendTo suffix the session ID to the payload so the other end of the QUIC session can demultiplex -// the payload from multiple datagram sessions -func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error { - if len(payload) > maxDatagramPayloadSize { +// Maximum application payload to send to / receive from QUIC datagram frame +func (dm *DatagramMuxer) mtu() int { + return maxDatagramPayloadSize +} + +func (dm *DatagramMuxer) MuxSession(sessionID uuid.UUID, payload []byte) error { + if len(payload) > dm.mtu() { // TODO: TUN-5302 return ICMP packet too big message - return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.MTU()) + // drop packet for now, eventually reply with ICMP for PMTUD + return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu()) } - msgWithID, err := suffixSessionID(sessionID, payload) + payloadWithMetadata, err := suffixSessionID(sessionID, payload) if err != nil { return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped") } - if err := dm.session.SendMessage(msgWithID); err != nil { + if err := dm.session.SendMessage(payloadWithMetadata); err != nil { return errors.Wrap(err, "Failed to send datagram back to edge") } return nil } -// ReceiveFrom extracts datagram session ID, then sends the session ID and payload to session manager -// which determines how to proxy to the origin. It assumes the datagram session has already been -// registered with session manager through other side channel -func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) { - msg, err := dm.session.ReceiveMessage() - if err != nil { - return uuid.Nil, nil, err +func (dm *DatagramMuxer) ServeReceive(ctx context.Context) error { + for { + // Extracts datagram session ID, then sends the session ID and payload to receiver + // which determines how to proxy to the origin. It assumes the datagram session has already been + // registered with receiver through other side channel + msg, err := dm.session.ReceiveMessage() + if err != nil { + return err + } + if err := dm.demux(ctx, msg); err != nil { + dm.logger.Error().Err(err).Msg("Failed to demux datagram") + if err == context.Canceled { + return err + } + } } - sessionID, payload, err := extractSessionID(msg) - if err != nil { - return uuid.Nil, nil, err - } - return sessionID, payload, nil } -// Maximum application payload to send to / receive from QUIC datagram frame -func (dm *DatagramMuxer) MTU() int { - return maxDatagramPayloadSize +func (dm *DatagramMuxer) demux(ctx context.Context, msg []byte) error { + sessionID, payload, err := extractSessionID(msg) + if err != nil { + return err + } + sessionDatagram := SessionDatagram{ + ID: sessionID, + Payload: payload, + } + select { + case dm.demuxChan <- &sessionDatagram: + return nil + case <-ctx.Done(): + return ctx.Err() + } } // Each QUIC datagram should be suffixed with session ID. diff --git a/quic/datagram_test.go b/quic/datagram_test.go index ac32f410..00673835 100644 --- a/quic/datagram_test.go +++ b/quic/datagram_test.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/pem" + "fmt" "math/big" "testing" "time" @@ -52,9 +53,29 @@ func TestSuffixSessionIDError(t *testing.T) { require.Error(t, err) } -func TestMaxDatagramPayload(t *testing.T) { - payload := make([]byte, maxDatagramPayloadSize) +func TestDatagram(t *testing.T) { + maxPayload := make([]byte, maxDatagramPayloadSize) + noPayloadSession := uuid.New() + maxPayloadSession := uuid.New() + sessionToPayload := []*SessionDatagram{ + { + ID: noPayloadSession, + Payload: make([]byte, 0), + }, + { + ID: maxPayloadSession, + Payload: maxPayload, + }, + } + flowPayloads := [][]byte{ + maxPayload, + } + testDatagram(t, 1, sessionToPayload, nil) + testDatagram(t, 2, sessionToPayload, flowPayloads) +} + +func testDatagram(t *testing.T, version uint8, sessionToPayloads []*SessionDatagram, packetPayloads [][]byte) { quicConfig := &quic.Config{ KeepAlivePeriod: 5 * time.Millisecond, EnableDatagrams: true, @@ -63,6 +84,8 @@ func TestMaxDatagramPayload(t *testing.T) { quicListener := newQUICListener(t, quicConfig) defer quicListener.Close() + logger := zerolog.Nop() + errGroup, ctx := errgroup.WithContext(context.Background()) // Run edge side of datagram muxer errGroup.Go(func() error { @@ -72,22 +95,32 @@ func TestMaxDatagramPayload(t *testing.T) { return err } - logger := zerolog.Nop() - muxer, err := NewDatagramMuxer(quicSession, &logger) - if err != nil { - return err + sessionDemuxChan := make(chan *SessionDatagram, 16) + + switch version { + case 1: + muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan) + muxer.ServeReceive(ctx) + case 2: + packetDemuxChan := make(chan []byte, len(packetPayloads)) + muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan, packetDemuxChan) + muxer.ServeReceive(ctx) + + for _, expectedPayload := range packetPayloads { + require.Equal(t, expectedPayload, <-packetDemuxChan) + } + default: + return fmt.Errorf("unknown datagram version %d", version) } - sessionID, receivedPayload, err := muxer.ReceiveFrom() - if err != nil { - return err + for _, expectedPayload := range sessionToPayloads { + actualPayload := <-sessionDemuxChan + require.Equal(t, expectedPayload, actualPayload) } - require.Equal(t, testSessionID, sessionID) - require.True(t, bytes.Equal(payload, receivedPayload)) - return nil }) + largePayload := make([]byte, MaxDatagramFrameSize) // Run cloudflared side of datagram muxer errGroup.Go(func() error { tlsClientConfig := &tls.Config{ @@ -97,24 +130,35 @@ func TestMaxDatagramPayload(t *testing.T) { // Establish quic connection quicSession, err := quic.DialAddrEarly(quicListener.Addr().String(), tlsClientConfig, quicConfig) require.NoError(t, err) - - logger := zerolog.Nop() - muxer, err := NewDatagramMuxer(quicSession, &logger) - if err != nil { - return err - } + defer quicSession.CloseWithError(0, "") // Wait a few milliseconds for MTU discovery to take place time.Sleep(time.Millisecond * 100) - err = muxer.SendTo(testSessionID, payload) - if err != nil { - return err + + var muxer BaseDatagramMuxer + switch version { + case 1: + muxer = NewDatagramMuxer(quicSession, &logger, nil) + case 2: + muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil, nil) + for _, payload := range packetPayloads { + require.NoError(t, muxerV2.MuxPacket(payload)) + } + // Payload larger than transport MTU, should not be sent + require.Error(t, muxerV2.MuxPacket(largePayload)) + muxer = muxerV2 + default: + return fmt.Errorf("unknown datagram version %d", version) } - // Payload larger than transport MTU, should return an error - largePayload := make([]byte, MaxDatagramFrameSize) - err = muxer.SendTo(testSessionID, largePayload) - require.Error(t, err) + for _, sessionDatagram := range sessionToPayloads { + require.NoError(t, muxer.MuxSession(sessionDatagram.ID, sessionDatagram.Payload)) + } + // Payload larger than transport MTU, should not be sent + require.Error(t, muxer.MuxSession(testSessionID, largePayload)) + + // Wait for edge to finish receiving the messages + time.Sleep(time.Millisecond * 100) return nil }) @@ -154,3 +198,35 @@ func generateTLSConfig() *tls.Config { NextProtos: []string{"argotunnel"}, } } + +type sessionMuxer interface { + SendToSession(sessionID uuid.UUID, payload []byte) error +} + +type mockSessionReceiver struct { + expectedSessionToPayload map[uuid.UUID][]byte + receivedCount int +} + +func (msr *mockSessionReceiver) ReceiveDatagram(sessionID uuid.UUID, payload []byte) error { + expectedPayload := msr.expectedSessionToPayload[sessionID] + if !bytes.Equal(expectedPayload, payload) { + return fmt.Errorf("expect %v to have payload %s, got %s", sessionID, string(expectedPayload), string(payload)) + } + msr.receivedCount++ + return nil +} + +type mockFlowReceiver struct { + expectedPayloads [][]byte + receivedCount int +} + +func (mfr *mockFlowReceiver) ReceiveFlow(payload []byte) error { + expectedPayload := mfr.expectedPayloads[mfr.receivedCount] + if !bytes.Equal(expectedPayload, payload) { + return fmt.Errorf("expect flow %d to have payload %s, got %s", mfr.receivedCount, string(expectedPayload), string(payload)) + } + mfr.receivedCount++ + return nil +} diff --git a/quic/datagramv2.go b/quic/datagramv2.go new file mode 100644 index 00000000..73e55f7d --- /dev/null +++ b/quic/datagramv2.go @@ -0,0 +1,136 @@ +package quic + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "github.com/lucas-clemente/quic-go" + "github.com/pkg/errors" + "github.com/rs/zerolog" +) + +type datagramV2Type byte + +const ( + udp datagramV2Type = iota + ip +) + +func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) { + if len(b)+1 > MaxDatagramFrameSize { + return nil, fmt.Errorf("datagram size %d exceeds max frame size %d", len(b), MaxDatagramFrameSize) + } + b = append(b, byte(datagramType)) + return b, nil +} + +// Maximum application payload to send to / receive from QUIC datagram frame +func (dm *DatagramMuxerV2) mtu() int { + return maxDatagramPayloadSize +} + +type DatagramMuxerV2 struct { + session quic.Connection + logger *zerolog.Logger + sessionDemuxChan chan<- *SessionDatagram + packetDemuxChan chan<- []byte +} + +func NewDatagramMuxerV2( + quicSession quic.Connection, + log *zerolog.Logger, + sessionDemuxChan chan<- *SessionDatagram, + packetDemuxChan chan<- []byte) *DatagramMuxerV2 { + logger := log.With().Uint8("datagramVersion", 2).Logger() + return &DatagramMuxerV2{ + session: quicSession, + logger: &logger, + sessionDemuxChan: sessionDemuxChan, + packetDemuxChan: packetDemuxChan, + } +} + +// MuxSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can +// demultiplex the payload from multiple datagram sessions +func (dm *DatagramMuxerV2) MuxSession(sessionID uuid.UUID, payload []byte) error { + if len(payload) > dm.mtu() { + // TODO: TUN-5302 return ICMP packet too big message + return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), dm.mtu()) + } + msgWithID, err := suffixSessionID(sessionID, payload) + if err != nil { + return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped") + } + msgWithIDAndType, err := suffixType(msgWithID, udp) + if err != nil { + return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped") + } + if err := dm.session.SendMessage(msgWithIDAndType); err != nil { + return errors.Wrap(err, "Failed to send datagram back to edge") + } + return nil +} + +// MuxPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing +// the payload as IP and look at the source and destination. +func (dm *DatagramMuxerV2) MuxPacket(packet []byte) error { + payloadWithVersion, err := suffixType(packet, ip) + if err != nil { + return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped") + } + if err := dm.session.SendMessage(payloadWithVersion); err != nil { + return errors.Wrap(err, "Failed to send datagram back to edge") + } + return nil +} + +// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet +func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error { + for { + msg, err := dm.session.ReceiveMessage() + if err != nil { + return err + } + if err := dm.demux(ctx, msg); err != nil { + dm.logger.Error().Err(err).Msg("Failed to demux datagram") + if err == context.Canceled { + return err + } + } + } +} + +func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error { + if len(msgWithType) < 1 { + return fmt.Errorf("QUIC datagram should have at least 1 byte") + } + msgType := datagramV2Type(msgWithType[len(msgWithType)-1]) + msg := msgWithType[0 : len(msgWithType)-1] + switch msgType { + case udp: + sessionID, payload, err := extractSessionID(msg) + if err != nil { + return err + } + sessionDatagram := SessionDatagram{ + ID: sessionID, + Payload: payload, + } + select { + case dm.sessionDemuxChan <- &sessionDatagram: + return nil + case <-ctx.Done(): + return ctx.Err() + } + case ip: + select { + case dm.packetDemuxChan <- msg: + return nil + case <-ctx.Done(): + return ctx.Err() + } + default: + return fmt.Errorf("Unexpected datagram type %d", msgType) + } +} diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 2abf09f5..7517a02b 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -33,6 +33,7 @@ const ( FeatureSerializedHeaders = "serialized_headers" FeatureQuickReconnects = "quick_reconnects" FeatureAllowRemoteConfig = "allow_remote_config" + FeatureDatagramV2 = "support_datagram_v2" ) type TunnelConfig struct {