TUN-6667: DatagramMuxerV2 provides a method to receive RawPacket

This commit is contained in:
cthuang 2022-08-17 18:23:04 +01:00
parent bad2e8e812
commit d2bc15e224
3 changed files with 111 additions and 43 deletions

View File

@ -47,7 +47,7 @@ type QUICConnection struct {
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
sessionManager datagramsession.Manager sessionManager datagramsession.Manager
// datagramMuxer mux/demux datagrams from quic connection // datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *quicpogs.DatagramMuxer datagramMuxer quicpogs.BaseDatagramMuxer
controlStreamHandler ControlStreamHandler controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
} }
@ -67,9 +67,9 @@ func NewQUICConnection(
return nil, &EdgeQuicDialError{Cause: err} return nil, &EdgeQuicDialError{Cause: err}
} }
demuxChan := make(chan *packet.Session, demuxChanCapacity) sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, demuxChan) datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan)
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, demuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
return &QUICConnection{ return &QUICConnection{
session: session, session: session,

View File

@ -9,13 +9,17 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"math/big" "math/big"
"net/netip"
"testing" "testing"
"time" "time"
"github.com/google/gopacket/layers"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/packet" "github.com/cloudflare/cloudflared/packet"
@ -68,15 +72,45 @@ func TestDatagram(t *testing.T) {
Payload: maxPayload, Payload: maxPayload,
}, },
} }
flowPayloads := [][]byte{
maxPayload, packets := []packet.ICMP{
{
IP: &packet.IP{
Src: netip.MustParseAddr("172.16.0.1"),
Dst: netip.MustParseAddr("192.168.0.1"),
Protocol: layers.IPProtocolICMPv4,
},
Message: &icmp.Message{
Type: ipv4.ICMPTypeTimeExceeded,
Code: 0,
Body: &icmp.TimeExceeded{
Data: []byte("original packet"),
},
},
},
{
IP: &packet.IP{
Src: netip.MustParseAddr("172.16.0.2"),
Dst: netip.MustParseAddr("192.168.0.2"),
Protocol: layers.IPProtocolICMPv4,
},
Message: &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: 6182,
Seq: 9151,
Data: []byte("Test ICMP echo"),
},
},
},
} }
testDatagram(t, 1, sessionToPayload, nil) testDatagram(t, 1, sessionToPayload, nil)
testDatagram(t, 2, sessionToPayload, flowPayloads) testDatagram(t, 2, sessionToPayload, packets)
} }
func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Session, packetPayloads [][]byte) { func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Session, packets []packet.ICMP) {
quicConfig := &quic.Config{ quicConfig := &quic.Config{
KeepAlivePeriod: 5 * time.Millisecond, KeepAlivePeriod: 5 * time.Millisecond,
EnableDatagrams: true, EnableDatagrams: true,
@ -103,12 +137,20 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan) muxer := NewDatagramMuxer(quicSession, &logger, sessionDemuxChan)
muxer.ServeReceive(ctx) muxer.ServeReceive(ctx)
case 2: case 2:
packetDemuxChan := make(chan []byte, len(packetPayloads)) muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan)
muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan, packetDemuxChan)
muxer.ServeReceive(ctx) muxer.ServeReceive(ctx)
for _, expectedPayload := range packetPayloads { icmpDecoder := packet.NewICMPDecoder()
require.Equal(t, expectedPayload, <-packetDemuxChan) for _, pk := range packets {
received, err := muxer.ReceivePacket(ctx)
require.NoError(t, err)
receivedICMP, err := icmpDecoder.Decode(received.Data)
require.NoError(t, err)
require.Equal(t, pk.IP, receivedICMP.IP)
require.Equal(t, pk.Type, receivedICMP.Type)
require.Equal(t, pk.Code, receivedICMP.Code)
require.Equal(t, pk.Body, receivedICMP.Body)
} }
default: default:
return fmt.Errorf("unknown datagram version %d", version) return fmt.Errorf("unknown datagram version %d", version)
@ -141,12 +183,17 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
case 1: case 1:
muxer = NewDatagramMuxer(quicSession, &logger, nil) muxer = NewDatagramMuxer(quicSession, &logger, nil)
case 2: case 2:
muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil, nil) muxerV2 := NewDatagramMuxerV2(quicSession, &logger, nil)
for _, payload := range packetPayloads { encoder := packet.NewEncoder()
require.NoError(t, muxerV2.MuxPacket(payload)) for _, pk := range packets {
encodedPacket, err := encoder.Encode(&pk)
require.NoError(t, err)
require.NoError(t, muxerV2.SendPacket(encodedPacket))
} }
// Payload larger than transport MTU, should not be sent // Payload larger than transport MTU, should not be sent
require.Error(t, muxerV2.MuxPacket(largePayload)) require.Error(t, muxerV2.SendPacket(packet.RawPacket{
Data: largePayload,
}))
muxer = muxerV2 muxer = muxerV2
default: default:
return fmt.Errorf("unknown datagram version %d", version) return fmt.Errorf("unknown datagram version %d", version)

View File

@ -16,6 +16,8 @@ type datagramV2Type byte
const ( const (
udp datagramV2Type = iota udp datagramV2Type = iota
ip ip
// Same as sessionDemuxChan capacity
packetChanCapacity = 16
) )
func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) { func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) {
@ -35,24 +37,24 @@ type DatagramMuxerV2 struct {
session quic.Connection session quic.Connection
logger *zerolog.Logger logger *zerolog.Logger
sessionDemuxChan chan<- *packet.Session sessionDemuxChan chan<- *packet.Session
packetDemuxChan chan<- []byte packetDemuxChan chan packet.RawPacket
} }
func NewDatagramMuxerV2( func NewDatagramMuxerV2(
quicSession quic.Connection, quicSession quic.Connection,
log *zerolog.Logger, log *zerolog.Logger,
sessionDemuxChan chan<- *packet.Session, sessionDemuxChan chan<- *packet.Session,
packetDemuxChan chan<- []byte) *DatagramMuxerV2 { ) *DatagramMuxerV2 {
logger := log.With().Uint8("datagramVersion", 2).Logger() logger := log.With().Uint8("datagramVersion", 2).Logger()
return &DatagramMuxerV2{ return &DatagramMuxerV2{
session: quicSession, session: quicSession,
logger: &logger, logger: &logger,
sessionDemuxChan: sessionDemuxChan, sessionDemuxChan: sessionDemuxChan,
packetDemuxChan: packetDemuxChan, packetDemuxChan: make(chan packet.RawPacket, packetChanCapacity),
} }
} }
// MuxSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can // SendToSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can
// demultiplex the payload from multiple datagram sessions // demultiplex the payload from multiple datagram sessions
func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error { func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error {
if len(session.Payload) > dm.mtu() { if len(session.Payload) > dm.mtu() {
@ -73,10 +75,10 @@ func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error {
return nil return nil
} }
// MuxPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing // SendPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing
// the payload as IP and look at the source and destination. // the payload as IP and look at the source and destination.
func (dm *DatagramMuxerV2) MuxPacket(packet []byte) error { func (dm *DatagramMuxerV2) SendPacket(pk packet.RawPacket) error {
payloadWithVersion, err := suffixType(packet, ip) payloadWithVersion, err := suffixType(pk.Data, ip)
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped") return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
} }
@ -102,6 +104,15 @@ func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
} }
} }
func (dm *DatagramMuxerV2) ReceivePacket(ctx context.Context) (packet.RawPacket, error) {
select {
case <-ctx.Done():
return packet.RawPacket{}, ctx.Err()
case pk := <-dm.packetDemuxChan:
return pk, nil
}
}
func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error { func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error {
if len(msgWithType) < 1 { if len(msgWithType) < 1 {
return fmt.Errorf("QUIC datagram should have at least 1 byte") return fmt.Errorf("QUIC datagram should have at least 1 byte")
@ -110,28 +121,38 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error
msg := msgWithType[0 : len(msgWithType)-1] msg := msgWithType[0 : len(msgWithType)-1]
switch msgType { switch msgType {
case udp: case udp:
sessionID, payload, err := extractSessionID(msg) return dm.handleSession(ctx, msg)
if err != nil {
return err
}
sessionDatagram := packet.Session{
ID: sessionID,
Payload: payload,
}
select {
case dm.sessionDemuxChan <- &sessionDatagram:
return nil
case <-ctx.Done():
return ctx.Err()
}
case ip: case ip:
select { return dm.handlePacket(ctx, msg)
case dm.packetDemuxChan <- msg:
return nil
case <-ctx.Done():
return ctx.Err()
}
default: default:
return fmt.Errorf("Unexpected datagram type %d", msgType) return fmt.Errorf("Unexpected datagram type %d", msgType)
} }
} }
func (dm *DatagramMuxerV2) handleSession(ctx context.Context, session []byte) error {
sessionID, payload, err := extractSessionID(session)
if err != nil {
return err
}
sessionDatagram := packet.Session{
ID: sessionID,
Payload: payload,
}
select {
case dm.sessionDemuxChan <- &sessionDatagram:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte) error {
select {
case <-ctx.Done():
return ctx.Err()
case dm.packetDemuxChan <- packet.RawPacket{
Data: pk,
}:
return nil
}
}