From d2bc15e2240cb5ec20579cbdfade283d71cda98e Mon Sep 17 00:00:00 2001 From: cthuang Date: Wed, 17 Aug 2022 18:23:04 +0100 Subject: [PATCH] TUN-6667: DatagramMuxerV2 provides a method to receive RawPacket --- connection/quic.go | 8 ++--- quic/datagram_test.go | 71 +++++++++++++++++++++++++++++++++------- quic/datagramv2.go | 75 +++++++++++++++++++++++++++---------------- 3 files changed, 111 insertions(+), 43 deletions(-) diff --git a/connection/quic.go b/connection/quic.go index 33a2fc7a..edc17af2 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -47,7 +47,7 @@ type QUICConnection struct { // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer sessionManager datagramsession.Manager // datagramMuxer mux/demux datagrams from quic connection - datagramMuxer *quicpogs.DatagramMuxer + datagramMuxer quicpogs.BaseDatagramMuxer controlStreamHandler ControlStreamHandler connOptions *tunnelpogs.ConnectionOptions } @@ -67,9 +67,9 @@ func NewQUICConnection( return nil, &EdgeQuicDialError{Cause: err} } - demuxChan := make(chan *packet.Session, demuxChanCapacity) - datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, demuxChan) - sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, demuxChan) + sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) + datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan) + sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) return &QUICConnection{ session: session, diff --git a/quic/datagram_test.go b/quic/datagram_test.go index 32ce4399..c26362a9 100644 --- a/quic/datagram_test.go +++ b/quic/datagram_test.go @@ -9,13 +9,17 @@ import ( "encoding/pem" "fmt" "math/big" + "net/netip" "testing" "time" + "github.com/google/gopacket/layers" "github.com/google/uuid" "github.com/lucas-clemente/quic-go" "github.com/rs/zerolog" "github.com/stretchr/testify/require" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" "golang.org/x/sync/errgroup" "github.com/cloudflare/cloudflared/packet" @@ -68,15 +72,45 @@ func TestDatagram(t *testing.T) { Payload: maxPayload, }, } - flowPayloads := [][]byte{ - maxPayload, + + packets := []packet.ICMP{ + { + IP: &packet.IP{ + Src: netip.MustParseAddr("172.16.0.1"), + Dst: netip.MustParseAddr("192.168.0.1"), + Protocol: layers.IPProtocolICMPv4, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeTimeExceeded, + Code: 0, + Body: &icmp.TimeExceeded{ + Data: []byte("original packet"), + }, + }, + }, + { + IP: &packet.IP{ + Src: netip.MustParseAddr("172.16.0.2"), + Dst: netip.MustParseAddr("192.168.0.2"), + Protocol: layers.IPProtocolICMPv4, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 6182, + Seq: 9151, + Data: []byte("Test ICMP echo"), + }, + }, + }, } testDatagram(t, 1, sessionToPayload, nil) - testDatagram(t, 2, sessionToPayload, flowPayloads) + testDatagram(t, 2, sessionToPayload, packets) } -func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Session, packetPayloads [][]byte) { +func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Session, packets []packet.ICMP) { quicConfig := &quic.Config{ KeepAlivePeriod: 5 * time.Millisecond, EnableDatagrams: true, @@ -103,12 +137,20 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan) muxer.ServeReceive(ctx) case 2: - packetDemuxChan := make(chan []byte, len(packetPayloads)) - muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan, packetDemuxChan) + muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan) muxer.ServeReceive(ctx) - for _, expectedPayload := range packetPayloads { - require.Equal(t, expectedPayload, <-packetDemuxChan) + icmpDecoder := packet.NewICMPDecoder() + for _, pk := range packets { + received, err := muxer.ReceivePacket(ctx) + require.NoError(t, err) + + receivedICMP, err := icmpDecoder.Decode(received.Data) + require.NoError(t, err) + require.Equal(t, pk.IP, receivedICMP.IP) + require.Equal(t, pk.Type, receivedICMP.Type) + require.Equal(t, pk.Code, receivedICMP.Code) + require.Equal(t, pk.Body, receivedICMP.Body) } default: return fmt.Errorf("unknown datagram version %d", version) @@ -141,12 +183,17 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi case 1: muxer = NewDatagramMuxer(quicSession, &logger, nil) case 2: - muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil, nil) - for _, payload := range packetPayloads { - require.NoError(t, muxerV2.MuxPacket(payload)) + muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil) + encoder := packet.NewEncoder() + for _, pk := range packets { + encodedPacket, err := encoder.Encode(&pk) + require.NoError(t, err) + require.NoError(t, muxerV2.SendPacket(encodedPacket)) } // Payload larger than transport MTU, should not be sent - require.Error(t, muxerV2.MuxPacket(largePayload)) + require.Error(t, muxerV2.SendPacket(packet.RawPacket{ + Data: largePayload, + })) muxer = muxerV2 default: return fmt.Errorf("unknown datagram version %d", version) diff --git a/quic/datagramv2.go b/quic/datagramv2.go index d11bcfaf..03e60b1e 100644 --- a/quic/datagramv2.go +++ b/quic/datagramv2.go @@ -16,6 +16,8 @@ type datagramV2Type byte const ( udp datagramV2Type = iota ip + // Same as sessionDemuxChan capacity + packetChanCapacity = 16 ) func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) { @@ -35,24 +37,24 @@ type DatagramMuxerV2 struct { session quic.Connection logger *zerolog.Logger sessionDemuxChan chan<- *packet.Session - packetDemuxChan chan<- []byte + packetDemuxChan chan packet.RawPacket } func NewDatagramMuxerV2( quicSession quic.Connection, log *zerolog.Logger, sessionDemuxChan chan<- *packet.Session, - packetDemuxChan chan<- []byte) *DatagramMuxerV2 { +) *DatagramMuxerV2 { logger := log.With().Uint8("datagramVersion", 2).Logger() return &DatagramMuxerV2{ session: quicSession, logger: &logger, sessionDemuxChan: sessionDemuxChan, - packetDemuxChan: packetDemuxChan, + packetDemuxChan: make(chan packet.RawPacket, packetChanCapacity), } } -// MuxSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can +// SendToSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can // demultiplex the payload from multiple datagram sessions func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error { if len(session.Payload) > dm.mtu() { @@ -73,10 +75,10 @@ func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error { return nil } -// MuxPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing +// SendPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing // the payload as IP and look at the source and destination. -func (dm *DatagramMuxerV2) MuxPacket(packet []byte) error { - payloadWithVersion, err := suffixType(packet, ip) +func (dm *DatagramMuxerV2) SendPacket(pk packet.RawPacket) error { + payloadWithVersion, err := suffixType(pk.Data, ip) if err != nil { return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped") } @@ -102,6 +104,15 @@ func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error { } } +func (dm *DatagramMuxerV2) ReceivePacket(ctx context.Context) (packet.RawPacket, error) { + select { + case <-ctx.Done(): + return packet.RawPacket{}, ctx.Err() + case pk := <-dm.packetDemuxChan: + return pk, nil + } +} + func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error { if len(msgWithType) < 1 { return fmt.Errorf("QUIC datagram should have at least 1 byte") @@ -110,28 +121,38 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error msg := msgWithType[0 : len(msgWithType)-1] switch msgType { case udp: - sessionID, payload, err := extractSessionID(msg) - if err != nil { - return err - } - sessionDatagram := packet.Session{ - ID: sessionID, - Payload: payload, - } - select { - case dm.sessionDemuxChan <- &sessionDatagram: - return nil - case <-ctx.Done(): - return ctx.Err() - } + return dm.handleSession(ctx, msg) case ip: - select { - case dm.packetDemuxChan <- msg: - return nil - case <-ctx.Done(): - return ctx.Err() - } + return dm.handlePacket(ctx, msg) default: return fmt.Errorf("Unexpected datagram type %d", msgType) } } + +func (dm *DatagramMuxerV2) handleSession(ctx context.Context, session []byte) error { + sessionID, payload, err := extractSessionID(session) + if err != nil { + return err + } + sessionDatagram := packet.Session{ + ID: sessionID, + Payload: payload, + } + select { + case dm.sessionDemuxChan <- &sessionDatagram: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte) error { + select { + case <-ctx.Done(): + return ctx.Err() + case dm.packetDemuxChan <- packet.RawPacket{ + Data: pk, + }: + return nil + } +}