From eea3d11e402bb09b4e4ae65e73101064d4683c1f Mon Sep 17 00:00:00 2001 From: cthuang Date: Tue, 23 Nov 2021 12:45:59 +0000 Subject: [PATCH] TUN-5301: Separate datagram multiplex and session management logic from quic connection logic --- connection/quic.go | 147 ++++++++++---------- connection/udp_session.go | 90 ------------- datagramsession/event.go | 33 +++++ datagramsession/manager.go | 134 +++++++++++++++++++ datagramsession/manager_test.go | 214 ++++++++++++++++++++++++++++++ datagramsession/session.go | 70 ++++++++++ datagramsession/session_test.go | 59 ++++++++ datagramsession/transport.go | 11 ++ datagramsession/transport_test.go | 32 +++++ quic/datagram.go | 48 ++++++- 10 files changed, 675 insertions(+), 163 deletions(-) delete mode 100644 connection/udp_session.go create mode 100644 datagramsession/event.go create mode 100644 datagramsession/manager.go create mode 100644 datagramsession/manager_test.go create mode 100644 datagramsession/session.go create mode 100644 datagramsession/session_test.go create mode 100644 datagramsession/transport.go create mode 100644 datagramsession/transport_test.go diff --git a/connection/quic.go b/connection/quic.go index b0590b4f..933cb197 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -16,6 +16,7 @@ import ( "github.com/rs/zerolog" "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/datagramsession" quicpogs "github.com/cloudflare/cloudflared/quic" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) @@ -32,10 +33,11 @@ const ( // QUICConnection represents the type that facilitates Proxying via QUIC streams. type QUICConnection struct { - session quic.Session - logger *zerolog.Logger - httpProxy OriginProxy - udpSessions *udpSessions + session quic.Session + logger *zerolog.Logger + httpProxy OriginProxy + sessionManager datagramsession.Manager + localIP net.IP } // NewQUICConnection returns a new instance of QUICConnection. @@ -49,12 +51,6 @@ func NewQUICConnection( controlStreamHandler ControlStreamHandler, observer *Observer, ) (*QUICConnection, error) { - localIP, err := GetLocalIP() - if err != nil { - return nil, err - } - observer.log.Info().Msgf("UDP proxy will use %s as packet source IP", localIP) - udpSessions := newUDPSessions(localIP) session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) if err != nil { return nil, fmt.Errorf("failed to dial to edge: %w", err) @@ -71,24 +67,36 @@ func NewQUICConnection( return nil, err } + datagramMuxer, err := quicpogs.NewDatagramMuxer(session) + if err != nil { + return nil, err + } + + sessionManager := datagramsession.NewManager(datagramMuxer, observer.log) + + localIP, err := getLocalIP() + if err != nil { + return nil, err + } + return &QUICConnection{ - session: session, - httpProxy: httpProxy, - logger: observer.log, - udpSessions: udpSessions, + session: session, + httpProxy: httpProxy, + logger: observer.log, + sessionManager: sessionManager, + localIP: localIP, }, nil } // Serve starts a QUIC session that begins accepting streams. func (q *QUICConnection) Serve(ctx context.Context) error { errGroup, ctx := errgroup.WithContext(ctx) - errGroup.Go(func() error { - return q.listenEdgeDatagram() - }) - errGroup.Go(func() error { return q.acceptStream(ctx) }) + errGroup.Go(func() error { + return q.sessionManager.Serve(ctx) + }) return errGroup.Wait() } @@ -111,26 +119,6 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error { } } -// listenEdgeDatagram listens for datagram from edge, parse the session ID and find the UDPConn to send the payload -func (q *QUICConnection) listenEdgeDatagram() error { - for { - msg, err := q.session.ReceiveMessage() - if err != nil { - return err - } - go func(msg []byte) { - sessionID, msgWithoutID, err := quicpogs.ExtractSessionID(msg) - if err != nil { - q.logger.Err(err).Msg("Failed to parse session ID from datagram") - return - } - if err := q.udpSessions.send(sessionID, msgWithoutID); err != nil { - q.logger.Err(err).Msg("Failed to send UDP to origin") - } - }(msg) - } -} - // Close closes the session with no errors specified. func (q *QUICConnection) Close() { q.session.CloseWithError(0, "") @@ -186,46 +174,29 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er } func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error { - udpConn, err := q.udpSessions.register(sessionID, dstIP, dstPort) + // Each session is a series of datagram from an eyeball to a dstIP:dstPort. + // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. + originProxy, err := q.newUDPProxy(dstIP, dstPort) if err != nil { + q.logger.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) return err } - q.logger.Debug().Msgf("Register session %v, %v, %v", sessionID, dstIP, dstPort) - go q.listenOriginUDP(sessionID, udpConn) + session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy) + if err != nil { + q.logger.Err(err).Msgf("Failed to register udp session %s", sessionID) + return err + } + go func() { + defer q.sessionManager.UnregisterSession(q.session.Context(), sessionID) + if err := session.Serve(q.session.Context()); err != nil { + q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).Msg("session terminated") + } + }() + q.logger.Debug().Msgf("Registered session %v, %v, %v", sessionID, dstIP, dstPort) return nil } -// listenOriginUDP reads UDP from origin in a loop, and returns when it cannot write to edge or cannot read from origin -func (q *QUICConnection) listenOriginUDP(sessionID uuid.UUID, conn *net.UDPConn) { - defer func() { - q.udpSessions.unregister(sessionID) - conn.Close() - }() - readBuffer := make([]byte, MaxDatagramFrameSize) - for { - n, err := conn.Read(readBuffer) - if n > 0 { - if n > MaxDatagramFrameSize-sessionIDLen { - // TODO: TUN-5302 return ICMP packet too big message - q.logger.Error().Msgf("Origin UDP payload has %d bytes, which exceeds transport MTU %d", n, MaxDatagramFrameSize-sessionIDLen) - continue - } - msgWithID, err := quicpogs.SuffixSessionID(sessionID, readBuffer[:n]) - if err != nil { - q.logger.Err(err).Msg("Failed to suffix session ID to datagram, it will be dropped") - continue - } - if err := q.session.SendMessage(msgWithID); err != nil { - q.logger.Err(err).Msg("Failed to send datagram back to edge") - return - } - } - if err != nil { - q.logger.Err(err).Msg("Failed to read UDP from origin") - return - } - } -} +// TODO: TUN-5422 Implement UnregisterUdpSession RPC // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to // the client. @@ -320,3 +291,35 @@ func isTransferEncodingChunked(req *http.Request) bool { // separated value as well. return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") } + +// TODO: TUN-5303: Define an UDPProxy in ingress package +func (q *QUICConnection) newUDPProxy(dstIP net.IP, dstPort uint16) (*net.UDPConn, error) { + dstAddr := &net.UDPAddr{ + IP: dstIP, + Port: int(dstPort), + } + return net.DialUDP("udp", nil, dstAddr) +} + +// TODO: TUN-5303: Find the local IP once in ingress package +// TODO: TUN-5421 allow user to specify which IP to bind to +func getLocalIP() (net.IP, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + for _, addr := range addrs { + // Find the IP that is not loop back + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if !ip.IsLoopback() { + return ip, nil + } + } + return nil, fmt.Errorf("cannot determine IP to bind to") +} diff --git a/connection/udp_session.go b/connection/udp_session.go deleted file mode 100644 index 1d0a6057..00000000 --- a/connection/udp_session.go +++ /dev/null @@ -1,90 +0,0 @@ -package connection - -import ( - "fmt" - "net" - "sync" - - "github.com/google/uuid" -) - -// TODO: TUN-5422 Unregister session -const ( - sessionIDLen = len(uuid.UUID{}) -) - -type udpSessions struct { - lock sync.RWMutex - sessions map[uuid.UUID]*net.UDPConn - localIP net.IP -} - -func newUDPSessions(localIP net.IP) *udpSessions { - return &udpSessions{ - sessions: make(map[uuid.UUID]*net.UDPConn), - localIP: localIP, - } -} - -func (us *udpSessions) register(id uuid.UUID, dstIP net.IP, dstPort uint16) (*net.UDPConn, error) { - us.lock.Lock() - defer us.lock.Unlock() - dstAddr := &net.UDPAddr{ - IP: dstIP, - Port: int(dstPort), - } - conn, err := net.DialUDP("udp", us.localAddr(), dstAddr) - if err != nil { - return nil, err - } - us.sessions[id] = conn - return conn, nil -} - -func (us *udpSessions) unregister(id uuid.UUID) { - us.lock.Lock() - defer us.lock.Unlock() - delete(us.sessions, id) -} - -func (us *udpSessions) send(id uuid.UUID, payload []byte) error { - us.lock.RLock() - defer us.lock.RUnlock() - conn, ok := us.sessions[id] - if !ok { - return fmt.Errorf("session %s not found", id) - } - _, err := conn.Write(payload) - return err -} - -func (ud *udpSessions) localAddr() *net.UDPAddr { - // TODO: Determine the IP to bind to - - return &net.UDPAddr{ - IP: ud.localIP, - Port: 0, - } -} - -// TODO: TUN-5421 allow user to specify which IP to bind to -func GetLocalIP() (net.IP, error) { - addrs, err := net.InterfaceAddrs() - if err != nil { - return nil, err - } - for _, addr := range addrs { - // Find the IP that is not loop back - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - } - if !ip.IsLoopback() { - return ip, nil - } - } - return nil, fmt.Errorf("cannot determine IP to bind to") -} diff --git a/datagramsession/event.go b/datagramsession/event.go new file mode 100644 index 00000000..67edcd03 --- /dev/null +++ b/datagramsession/event.go @@ -0,0 +1,33 @@ +package datagramsession + +import ( + "io" + + "github.com/google/uuid" +) + +// registerSessionEvent is an event to start tracking a new session +type registerSessionEvent struct { + sessionID uuid.UUID + originProxy io.ReadWriteCloser + resultChan chan *Session +} + +func newRegisterSessionEvent(sessionID uuid.UUID, originProxy io.ReadWriteCloser) *registerSessionEvent { + return ®isterSessionEvent{ + sessionID: sessionID, + originProxy: originProxy, + resultChan: make(chan *Session, 1), + } +} + +// unregisterSessionEvent is an event to stop tracking and terminate the session. +type unregisterSessionEvent struct { + sessionID uuid.UUID +} + +// newDatagram is an event when transport receives new datagram +type newDatagram struct { + sessionID uuid.UUID + payload []byte +} diff --git a/datagramsession/manager.go b/datagramsession/manager.go new file mode 100644 index 00000000..32156079 --- /dev/null +++ b/datagramsession/manager.go @@ -0,0 +1,134 @@ +package datagramsession + +import ( + "context" + "io" + + "github.com/google/uuid" + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" +) + +const ( + requestChanCapacity = 16 +) + +// Manager defines the APIs to manage sessions from the same transport. +type Manager interface { + // Serve starts the event loop + Serve(ctx context.Context) error + // RegisterSession starts tracking a session. Caller is responsible for starting the session + RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error) + // UnregisterSession stops tracking the session and terminates it + UnregisterSession(ctx context.Context, sessionID uuid.UUID) error +} + +type manager struct { + registrationChan chan *registerSessionEvent + unregistrationChan chan *unregisterSessionEvent + datagramChan chan *newDatagram + transport transport + sessions map[uuid.UUID]*Session + log *zerolog.Logger +} + +func NewManager(transport transport, log *zerolog.Logger) Manager { + return &manager{ + registrationChan: make(chan *registerSessionEvent), + unregistrationChan: make(chan *unregisterSessionEvent), + // datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events + datagramChan: make(chan *newDatagram, requestChanCapacity), + transport: transport, + sessions: make(map[uuid.UUID]*Session), + log: log, + } +} + +func (m *manager) Serve(ctx context.Context) error { + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + for { + sessionID, payload, err := m.transport.ReceiveFrom() + if err != nil { + m.log.Err(err).Msg("Failed to receive datagram from transport, closing session manager") + return err + } + datagram := &newDatagram{ + sessionID: sessionID, + payload: payload, + } + select { + case <-ctx.Done(): + return ctx.Err() + // Only the event loop routine can update/lookup the sessions map to avoid concurrent access + // Send the datagram to the event loop. It will find the session to send to + case m.datagramChan <- datagram: + } + } + }) + errGroup.Go(func() error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case datagram := <-m.datagramChan: + m.sendToSession(datagram) + case registration := <-m.registrationChan: + m.registerSession(ctx, registration) + // TODO: TUN-5422: Unregister inactive session upon timeout + case unregistration := <-m.unregistrationChan: + m.unregisterSession(unregistration) + } + } + }) + return errGroup.Wait() +} + +func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) { + event := newRegisterSessionEvent(sessionID, originProxy) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case m.registrationChan <- event: + session := <-event.resultChan + return session, nil + } +} + +func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) { + session := newSession(registration.sessionID, m.transport, registration.originProxy) + m.sessions[registration.sessionID] = session + registration.resultChan <- session +} + +func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID) error { + event := &unregisterSessionEvent{sessionID: sessionID} + select { + case <-ctx.Done(): + return ctx.Err() + case m.unregistrationChan <- event: + return nil + } +} + +func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) { + session, ok := m.sessions[unregistration.sessionID] + if ok { + delete(m.sessions, unregistration.sessionID) + session.close() + } +} + +func (m *manager) sendToSession(datagram *newDatagram) { + session, ok := m.sessions[datagram.sessionID] + if !ok { + m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found") + return + } + // session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't + // need to run in another go routine + _, err := session.writeToDst(datagram.payload) + if err != nil { + m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session") + } +} diff --git a/datagramsession/manager_test.go b/datagramsession/manager_test.go new file mode 100644 index 00000000..1fe11cec --- /dev/null +++ b/datagramsession/manager_test.go @@ -0,0 +1,214 @@ +package datagramsession + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "testing" + + "github.com/google/uuid" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestManagerServe(t *testing.T) { + const ( + sessions = 20 + msgs = 50 + ) + log := zerolog.Nop() + transport := &mockQUICTransport{ + reqChan: newDatagramChannel(), + respChan: newDatagramChannel(), + } + mg := NewManager(transport, &log) + + eyeballTracker := make(map[uuid.UUID]*datagramChannel) + for i := 0; i < sessions; i++ { + sessionID := uuid.New() + eyeballTracker[sessionID] = newDatagramChannel() + } + + ctx, cancel := context.WithCancel(context.Background()) + serveDone := make(chan struct{}) + go func(ctx context.Context) { + mg.Serve(ctx) + close(serveDone) + }(ctx) + + go func(ctx context.Context) { + for { + sessionID, payload, err := transport.respChan.Receive(ctx) + if err != nil { + require.Equal(t, context.Canceled, err) + return + } + respChan := eyeballTracker[sessionID] + require.NoError(t, respChan.Send(ctx, sessionID, payload)) + } + }(ctx) + + errGroup, ctx := errgroup.WithContext(ctx) + for sID, receiver := range eyeballTracker { + // Assign loop variables to local variables + sessionID := sID + eyeballRespReceiver := receiver + errGroup.Go(func() error { + payload := testPayload(sessionID) + expectResp := testResponse(payload) + + cfdConn, originConn := net.Pipe() + + origin := mockOrigin{ + expectMsgCount: msgs, + expectedMsg: payload, + expectedResp: expectResp, + conn: originConn, + } + eyeball := mockEyeball{ + expectMsgCount: msgs, + expectedMsg: expectResp, + expectSessionID: sessionID, + respReceiver: eyeballRespReceiver, + } + + reqErrGroup, reqCtx := errgroup.WithContext(ctx) + reqErrGroup.Go(func() error { + return origin.serve() + }) + reqErrGroup.Go(func() error { + return eyeball.serve(reqCtx) + }) + + session, err := mg.RegisterSession(ctx, sessionID, cfdConn) + require.NoError(t, err) + + sessionDone := make(chan struct{}) + go func() { + session.Serve(ctx) + close(sessionDone) + }() + + for i := 0; i < msgs; i++ { + require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID))) + } + + // Make sure eyeball and origin have received all messages before unregistering the session + require.NoError(t, reqErrGroup.Wait()) + + require.NoError(t, mg.UnregisterSession(ctx, sessionID)) + <-sessionDone + + return nil + }) + } + + require.NoError(t, errGroup.Wait()) + cancel() + transport.close() + <-serveDone +} + +type mockOrigin struct { + expectMsgCount int + expectedMsg []byte + expectedResp []byte + conn io.ReadWriteCloser +} + +func (mo *mockOrigin) serve() error { + expectedMsgLen := len(mo.expectedMsg) + readBuffer := make([]byte, expectedMsgLen+1) + for i := 0; i < mo.expectMsgCount; i++ { + n, err := mo.conn.Read(readBuffer) + if err != nil { + return err + } + if n != expectedMsgLen { + return fmt.Errorf("Expect to read %d bytes, read %d", expectedMsgLen, n) + } + if !bytes.Equal(readBuffer[:n], mo.expectedMsg) { + return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n]) + } + + _, err = mo.conn.Write(mo.expectedResp) + if err != nil { + return err + } + } + return nil +} + +func testPayload(sessionID uuid.UUID) []byte { + return []byte(fmt.Sprintf("Message from %s", sessionID)) +} + +func testResponse(msg []byte) []byte { + return []byte(fmt.Sprintf("Response to %v", msg)) +} + +type mockEyeball struct { + expectMsgCount int + expectedMsg []byte + expectSessionID uuid.UUID + respReceiver *datagramChannel +} + +func (me *mockEyeball) serve(ctx context.Context) error { + for i := 0; i < me.expectMsgCount; i++ { + sessionID, msg, err := me.respReceiver.Receive(ctx) + if err != nil { + return err + } + if sessionID != me.expectSessionID { + return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID) + } + if !bytes.Equal(msg, me.expectedMsg) { + return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg) + } + } + return nil +} + +// datagramChannel is a channel for Datagram with wrapper to send/receive with context +type datagramChannel struct { + datagramChan chan *newDatagram + closedChan chan struct{} +} + +func newDatagramChannel() *datagramChannel { + return &datagramChannel{ + datagramChan: make(chan *newDatagram, 1), + closedChan: make(chan struct{}), + } +} + +func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-rc.closedChan: + return fmt.Errorf("datagram channel closed") + case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}: + return nil + } +} + +func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) { + select { + case <-ctx.Done(): + return uuid.Nil, nil, ctx.Err() + case <-rc.closedChan: + return uuid.Nil, nil, fmt.Errorf("datagram channel closed") + case msg := <-rc.datagramChan: + return msg.sessionID, msg.payload, nil + } +} + +func (rc *datagramChannel) Close() { + // No need to close msgChan, it will be garbage collect once there is no reference to it + close(rc.closedChan) +} diff --git a/datagramsession/session.go b/datagramsession/session.go new file mode 100644 index 00000000..acd4056d --- /dev/null +++ b/datagramsession/session.go @@ -0,0 +1,70 @@ +package datagramsession + +import ( + "context" + "io" + + "github.com/google/uuid" +) + +// Each Session is a bidirectional pipe of datagrams between transport and dstConn +// Currently the only implementation of transport is quic DatagramMuxer +// Destination can be a connection with origin or with eyeball +// When the destination is origin: +// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the +// write method of the Session to send to origin +// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball +// When the destination is eyeball: +// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared +// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the +// write method of the Session to send to eyeball +type Session struct { + id uuid.UUID + transport transport + dstConn io.ReadWriteCloser + doneChan chan struct{} +} + +func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *Session { + return &Session{ + id: id, + transport: transport, + dstConn: dstConn, + doneChan: make(chan struct{}), + } +} + +func (s *Session) Serve(ctx context.Context) error { + serveCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-serveCtx.Done(): + case <-s.doneChan: + } + s.dstConn.Close() + }() + // QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 + // This makes it safe to share readBuffer between iterations + readBuffer := make([]byte, 1280) + for { + // TODO: TUN-5303: origin proxy should determine the buffer size + n, err := s.dstConn.Read(readBuffer) + if n > 0 { + if err := s.transport.SendTo(s.id, readBuffer[:n]); err != nil { + return err + } + } + if err != nil { + return err + } + } +} + +func (s *Session) writeToDst(payload []byte) (int, error) { + return s.dstConn.Write(payload) +} + +func (s *Session) close() { + close(s.doneChan) +} diff --git a/datagramsession/session_test.go b/datagramsession/session_test.go new file mode 100644 index 00000000..6fc25b6e --- /dev/null +++ b/datagramsession/session_test.go @@ -0,0 +1,59 @@ +package datagramsession + +import ( + "context" + "net" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// TestCloseSession makes sure a session will stop after context is done +func TestSessionCtxDone(t *testing.T) { + testSessionReturns(t, true) +} + +// TestCloseSession makes sure a session will stop after close method is called +func TestCloseSession(t *testing.T) { + testSessionReturns(t, false) +} + +func testSessionReturns(t *testing.T, closeByContext bool) { + sessionID := uuid.New() + cfdConn, originConn := net.Pipe() + payload := testPayload(sessionID) + transport := &mockQUICTransport{ + reqChan: newDatagramChannel(), + respChan: newDatagramChannel(), + } + session := newSession(sessionID, transport, cfdConn) + + ctx, cancel := context.WithCancel(context.Background()) + sessionDone := make(chan struct{}) + go func() { + session.Serve(ctx) + close(sessionDone) + }() + + go func() { + n, err := session.writeToDst(payload) + require.NoError(t, err) + require.Equal(t, len(payload), n) + }() + + readBuffer := make([]byte, len(payload)+1) + n, err := originConn.Read(readBuffer) + require.NoError(t, err) + require.Equal(t, len(payload), n) + + if closeByContext { + cancel() + } else { + session.close() + } + + <-sessionDone + // call cancelled again otherwise the linter will warn about possible context leak + cancel() +} diff --git a/datagramsession/transport.go b/datagramsession/transport.go new file mode 100644 index 00000000..4d078ac3 --- /dev/null +++ b/datagramsession/transport.go @@ -0,0 +1,11 @@ +package datagramsession + +import "github.com/google/uuid" + +// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions +type transport interface { + // SendTo writes payload for a session to the transport + SendTo(sessionID uuid.UUID, payload []byte) error + // ReceiveFrom reads the next datagram from the transport + ReceiveFrom() (uuid.UUID, []byte, error) +} diff --git a/datagramsession/transport_test.go b/datagramsession/transport_test.go new file mode 100644 index 00000000..6d2fa91e --- /dev/null +++ b/datagramsession/transport_test.go @@ -0,0 +1,32 @@ +package datagramsession + +import ( + "context" + + "github.com/google/uuid" +) + +type mockQUICTransport struct { + reqChan *datagramChannel + respChan *datagramChannel +} + +func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error { + buf := make([]byte, len(payload)) + // The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 + copy(buf, payload) + return mt.respChan.Send(context.Background(), sessionID, buf) +} + +func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) { + return mt.reqChan.Receive(context.Background()) +} + +func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error { + return mt.reqChan.Send(ctx, sessionID, payload) +} + +func (mt *mockQUICTransport) close() { + mt.reqChan.Close() + mt.respChan.Close() +} diff --git a/quic/datagram.go b/quic/datagram.go index cd538d87..d32056dd 100644 --- a/quic/datagram.go +++ b/quic/datagram.go @@ -4,13 +4,59 @@ import ( "fmt" "github.com/google/uuid" + "github.com/lucas-clemente/quic-go" + "github.com/pkg/errors" ) const ( - sessionIDLen = len(uuid.UUID{}) MaxDatagramFrameSize = 1220 + sessionIDLen = len(uuid.UUID{}) ) +type DatagramMuxer struct { + ID uuid.UUID + session quic.Session +} + +func NewDatagramMuxer(quicSession quic.Session) (*DatagramMuxer, error) { + muxerID, err := uuid.NewRandom() + if err != nil { + return nil, err + } + return &DatagramMuxer{ + ID: muxerID, + session: quicSession, + }, nil +} + +// SendTo suffix the session ID to the payload so the other end of the QUIC session can demultiplex +// the payload from multiple datagram sessions +func (dm *DatagramMuxer) SendTo(sessionID uuid.UUID, payload []byte) error { + if len(payload) > MaxDatagramFrameSize-sessionIDLen { + // TODO: TUN-5302 return ICMP packet too big message + return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(payload), MaxDatagramFrameSize-sessionIDLen) + } + msgWithID, err := SuffixSessionID(sessionID, payload) + if err != nil { + return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped") + } + if err := dm.session.SendMessage(msgWithID); err != nil { + return errors.Wrap(err, "Failed to send datagram back to edge") + } + return nil +} + +// ReceiveFrom extracts datagram session ID, then sends the session ID and payload to session manager +// which determines how to proxy to the origin. It assumes the datagram session has already been +// registered with session manager through other side channel +func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) { + msg, err := dm.session.ReceiveMessage() + if err != nil { + return uuid.Nil, nil, err + } + return ExtractSessionID(msg) +} + // Each QUIC datagram should be suffixed with session ID. // ExtractSessionID extracts the session ID and a slice with only the payload func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {