TUN-6855: Add DatagramV2Type for IP packet with trace and tracing spans

This commit is contained in:
cthuang 2022-10-13 21:30:43 +01:00
parent 61007dd2dd
commit 225c344ceb
6 changed files with 347 additions and 34 deletions

View File

@ -93,7 +93,8 @@ func NewQUICConnection(
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
datagramMuxer := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan) datagramMuxer := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan)
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
packetRouter := packet.NewRouter(packetRouterConfig, datagramMuxer, &returnPipe{muxer: datagramMuxer}, logger, orchestrator.WarpRoutingEnabled) muxer := muxerWrapper{muxer: datagramMuxer}
packetRouter := packet.NewRouter(packetRouterConfig, &muxer, &muxer, logger, orchestrator.WarpRoutingEnabled)
return &QUICConnection{ return &QUICConnection{
session: session, session: session,
@ -498,16 +499,28 @@ func (np *nopCloserReadWriter) Close() error {
return nil return nil
} }
// returnPipe wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface // muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface
type returnPipe struct { type muxerWrapper struct {
muxer *quicpogs.DatagramMuxerV2 muxer *quicpogs.DatagramMuxerV2
} }
func (rp *returnPipe) SendPacket(dst netip.Addr, pk packet.RawPacket) error { func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
return rp.muxer.SendPacket(pk) return rp.muxer.SendPacket(quicpogs.RawPacket(pk))
} }
func (rp *returnPipe) Close() error { func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) {
pk, err := rp.muxer.ReceivePacket(ctx)
if err != nil {
return packet.RawPacket{}, err
}
rawPacket, ok := pk.(quicpogs.RawPacket)
if ok {
return packet.RawPacket(rawPacket), nil
}
return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk)
}
func (rp *muxerWrapper) Close() error {
return nil return nil
} }

View File

@ -113,9 +113,12 @@ func extractSessionID(b []byte) (uuid.UUID, []byte, error) {
// SuffixSessionID appends the session ID at the end of the payload. Suffix is more performant than prefix because // SuffixSessionID appends the session ID at the end of the payload. Suffix is more performant than prefix because
// the payload slice might already have enough capacity to append the session ID at the end // the payload slice might already have enough capacity to append the session ID at the end
func SuffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) { func SuffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) {
if len(b)+len(sessionID) > MaxDatagramFrameSize { return suffixMetadata(b, sessionID[:])
}
func suffixMetadata(payload, metadata []byte) ([]byte, error) {
if len(payload)+len(metadata) > MaxDatagramFrameSize {
return nil, fmt.Errorf("datagram size exceed %d", MaxDatagramFrameSize) return nil, fmt.Errorf("datagram size exceed %d", MaxDatagramFrameSize)
} }
b = append(b, sessionID[:]...) return append(payload, metadata...), nil
return b, nil
} }

View File

@ -1,6 +1,7 @@
package quic package quic
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
@ -23,6 +24,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/packet" "github.com/cloudflare/cloudflared/packet"
"github.com/cloudflare/cloudflared/tracing"
) )
var ( var (
@ -121,6 +123,15 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
logger := zerolog.Nop() logger := zerolog.Nop()
tracingIdentity, err := tracing.NewIdentity("ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1")
require.NoError(t, err)
serializedTracingID, err := tracingIdentity.MarshalBinary()
require.NoError(t, err)
tracingSpan := &TracingSpanPacket{
Spans: []byte("tracing"),
TracingIdentity: serializedTracingID,
}
errGroup, ctx := errgroup.WithContext(context.Background()) errGroup, ctx := errgroup.WithContext(context.Background())
// Run edge side of datagram muxer // Run edge side of datagram muxer
errGroup.Go(func() error { errGroup.Go(func() error {
@ -140,18 +151,17 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan) muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan)
muxer.ServeReceive(ctx) muxer.ServeReceive(ctx)
icmpDecoder := packet.NewICMPDecoder()
for _, pk := range packets { for _, pk := range packets {
received, err := muxer.ReceivePacket(ctx) received, err := muxer.ReceivePacket(ctx)
require.NoError(t, err) require.NoError(t, err)
validateIPPacket(t, received, &pk)
receivedICMP, err := icmpDecoder.Decode(received) received, err = muxer.ReceivePacket(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, pk.IP, receivedICMP.IP) validateIPPacketWithTracing(t, received, &pk, serializedTracingID)
require.Equal(t, pk.Type, receivedICMP.Type)
require.Equal(t, pk.Code, receivedICMP.Code)
require.Equal(t, pk.Body, receivedICMP.Body)
} }
received, err := muxer.ReceivePacket(ctx)
require.NoError(t, err)
validateTracingSpans(t, received, tracingSpan)
default: default:
return fmt.Errorf("unknown datagram version %d", version) return fmt.Errorf("unknown datagram version %d", version)
} }
@ -188,10 +198,15 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
for _, pk := range packets { for _, pk := range packets {
encodedPacket, err := encoder.Encode(&pk) encodedPacket, err := encoder.Encode(&pk)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, muxerV2.SendPacket(encodedPacket)) require.NoError(t, muxerV2.SendPacket(RawPacket(encodedPacket)))
require.NoError(t, muxerV2.SendPacket(&TracedPacket{
Packet: encodedPacket,
TracingIdentity: serializedTracingID,
}))
} }
require.NoError(t, muxerV2.SendPacket(tracingSpan))
// Payload larger than transport MTU, should not be sent // Payload larger than transport MTU, should not be sent
require.Error(t, muxerV2.SendPacket(packet.RawPacket{ require.Error(t, muxerV2.SendPacket(RawPacket{
Data: largePayload, Data: largePayload,
})) }))
muxer = muxerV2 muxer = muxerV2
@ -217,6 +232,38 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
require.NoError(t, errGroup.Wait()) require.NoError(t, errGroup.Wait())
} }
func validateIPPacket(t *testing.T, receivedPacket Packet, expectedICMP *packet.ICMP) {
require.Equal(t, DatagramTypeIP, receivedPacket.Type())
rawPacket := receivedPacket.(RawPacket)
decoder := packet.NewICMPDecoder()
receivedICMP, err := decoder.Decode(packet.RawPacket(rawPacket))
require.NoError(t, err)
validateICMP(t, expectedICMP, receivedICMP)
}
func validateIPPacketWithTracing(t *testing.T, receivedPacket Packet, expectedICMP *packet.ICMP, serializedTracingID []byte) {
require.Equal(t, DatagramTypeIPWithTrace, receivedPacket.Type())
tracedPacket := receivedPacket.(*TracedPacket)
decoder := packet.NewICMPDecoder()
receivedICMP, err := decoder.Decode(tracedPacket.Packet)
require.NoError(t, err)
validateICMP(t, expectedICMP, receivedICMP)
require.True(t, bytes.Equal(tracedPacket.TracingIdentity, serializedTracingID))
}
func validateICMP(t *testing.T, expected, actual *packet.ICMP) {
require.Equal(t, expected.IP, actual.IP)
require.Equal(t, expected.Type, actual.Type)
require.Equal(t, expected.Code, actual.Code)
require.Equal(t, expected.Body, actual.Body)
}
func validateTracingSpans(t *testing.T, receivedPacket Packet, expectedSpan *TracingSpanPacket) {
require.Equal(t, DatagramTypeTracingSpan, receivedPacket.Type())
tracingSpans := receivedPacket.(*TracingSpanPacket)
require.Equal(t, tracingSpans, expectedSpan)
}
func newQUICListener(t *testing.T, config *quic.Config) quic.Listener { func newQUICListener(t *testing.T, config *quic.Config) quic.Listener {
// Create a simple tls config. // Create a simple tls config.
tlsConfig := generateTLSConfig() tlsConfig := generateTLSConfig()

View File

@ -9,15 +9,28 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/packet" "github.com/cloudflare/cloudflared/packet"
"github.com/cloudflare/cloudflared/tracing"
) )
type DatagramV2Type byte type DatagramV2Type byte
const ( const (
// UDP payload
DatagramTypeUDP DatagramV2Type = iota DatagramTypeUDP DatagramV2Type = iota
// Full IP packet
DatagramTypeIP DatagramTypeIP
// DatagramTypeIP + tracing ID
DatagramTypeIPWithTrace
// Tracing spans in protobuf format
DatagramTypeTracingSpan
) )
type Packet interface {
Type() DatagramV2Type
Payload() []byte
Metadata() []byte
}
const ( const (
typeIDLen = 1 typeIDLen = 1
// Same as sessionDemuxChan capacity // Same as sessionDemuxChan capacity
@ -41,7 +54,7 @@ type DatagramMuxerV2 struct {
session quic.Connection session quic.Connection
logger *zerolog.Logger logger *zerolog.Logger
sessionDemuxChan chan<- *packet.Session sessionDemuxChan chan<- *packet.Session
packetDemuxChan chan packet.RawPacket packetDemuxChan chan Packet
} }
func NewDatagramMuxerV2( func NewDatagramMuxerV2(
@ -54,7 +67,7 @@ func NewDatagramMuxerV2(
session: quicSession, session: quicSession,
logger: &logger, logger: &logger,
sessionDemuxChan: sessionDemuxChan, sessionDemuxChan: sessionDemuxChan,
packetDemuxChan: make(chan packet.RawPacket, packetChanCapacity), packetDemuxChan: make(chan Packet, packetChanCapacity),
} }
} }
@ -79,14 +92,19 @@ func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error {
return nil return nil
} }
// SendPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing // SendPacket sends a packet with datagram version in the suffix. If ctx is a TracedContext, it adds the tracing
// the payload as IP and look at the source and destination. // context between payload and datagram version.
func (dm *DatagramMuxerV2) SendPacket(pk packet.RawPacket) error { // The other end of the QUIC connection can demultiplex by parsing the payload as IP and look at the source and destination.
payloadWithVersion, err := SuffixType(pk.Data, DatagramTypeIP) func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
payloadWithMetadata, err := suffixMetadata(pk.Payload(), pk.Metadata())
if err != nil {
return err
}
payloadWithMetadataAndType, err := SuffixType(payloadWithMetadata, pk.Type())
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped") return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
} }
if err := dm.session.SendMessage(payloadWithVersion); err != nil { if err := dm.session.SendMessage(payloadWithMetadataAndType); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge") return errors.Wrap(err, "Failed to send datagram back to edge")
} }
return nil return nil
@ -108,10 +126,10 @@ func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
} }
} }
func (dm *DatagramMuxerV2) ReceivePacket(ctx context.Context) (packet.RawPacket, error) { func (dm *DatagramMuxerV2) ReceivePacket(ctx context.Context) (pk Packet, err error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return packet.RawPacket{}, ctx.Err() return nil, ctx.Err()
case pk := <-dm.packetDemuxChan: case pk := <-dm.packetDemuxChan:
return pk, nil return pk, nil
} }
@ -126,10 +144,8 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error
switch msgType { switch msgType {
case DatagramTypeUDP: case DatagramTypeUDP:
return dm.handleSession(ctx, msg) return dm.handleSession(ctx, msg)
case DatagramTypeIP:
return dm.handlePacket(ctx, msg)
default: default:
return fmt.Errorf("Unexpected datagram type %d", msgType) return dm.handlePacket(ctx, msg, msgType)
} }
} }
@ -150,13 +166,93 @@ func (dm *DatagramMuxerV2) handleSession(ctx context.Context, session []byte) er
} }
} }
func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte) error { func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType DatagramV2Type) error {
var demuxedPacket Packet
switch msgType {
case DatagramTypeIP:
demuxedPacket = RawPacket(packet.RawPacket{Data: pk})
case DatagramTypeIPWithTrace:
tracingIdentity, payload, err := extractTracingIdentity(pk)
if err != nil {
return err
}
demuxedPacket = &TracedPacket{
Packet: packet.RawPacket{Data: payload},
TracingIdentity: tracingIdentity,
}
case DatagramTypeTracingSpan:
tracingIdentity, spans, err := extractTracingIdentity(pk)
if err != nil {
return err
}
demuxedPacket = &TracingSpanPacket{
Spans: spans,
TracingIdentity: tracingIdentity,
}
default:
return fmt.Errorf("Unexpected datagram type %d", msgType)
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case dm.packetDemuxChan <- packet.RawPacket{ case dm.packetDemuxChan <- demuxedPacket:
Data: pk,
}:
return nil return nil
} }
} }
func extractTracingIdentity(pk []byte) (tracingIdentity []byte, payload []byte, err error) {
if len(pk) < tracing.IdentityLength {
return nil, nil, fmt.Errorf("packet with tracing context should have at least %d bytes, got %v", tracing.IdentityLength, pk)
}
tracingIdentity = pk[len(pk)-tracing.IdentityLength:]
payload = pk[:len(pk)-tracing.IdentityLength]
return tracingIdentity, payload, nil
}
type RawPacket packet.RawPacket
func (rw RawPacket) Type() DatagramV2Type {
return DatagramTypeIP
}
func (rw RawPacket) Payload() []byte {
return rw.Data
}
func (rw RawPacket) Metadata() []byte {
return []byte{}
}
type TracedPacket struct {
Packet packet.RawPacket
TracingIdentity []byte
}
func (tp *TracedPacket) Type() DatagramV2Type {
return DatagramTypeIPWithTrace
}
func (tp *TracedPacket) Payload() []byte {
return tp.Packet.Data
}
func (tp *TracedPacket) Metadata() []byte {
return tp.TracingIdentity
}
type TracingSpanPacket struct {
Spans []byte
TracingIdentity []byte
}
func (tsp *TracingSpanPacket) Type() DatagramV2Type {
return DatagramTypeTracingSpan
}
func (tsp *TracingSpanPacket) Payload() []byte {
return tsp.Spans
}
func (tsp *TracingSpanPacket) Metadata() []byte {
return tsp.TracingIdentity
}

102
tracing/identity.go Normal file
View File

@ -0,0 +1,102 @@
package tracing
import (
"bytes"
"encoding/binary"
"fmt"
"strconv"
"strings"
)
const (
// 16 bytes for tracing ID, 8 bytes for span ID and 1 byte for flags
IdentityLength = 16 + 8 + 1
)
type Identity struct {
// Based on https://www.jaegertracing.io/docs/1.36/client-libraries/#value
// parent span ID is always 0 for our case
traceIDUpper uint64
traceIDLower uint64
spanID uint64
flags uint8
}
// TODO: TUN-6604 Remove this. To reconstruct into Jaeger propagation format, convert tracingContext to tracing.Identity
func (tc *Identity) String() string {
return fmt.Sprintf("%x%x:%x:0:%x", tc.traceIDUpper, tc.traceIDLower, tc.spanID, tc.flags)
}
func (tc *Identity) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(make([]byte, 0, IdentityLength))
for _, field := range []interface{}{
tc.traceIDUpper,
tc.traceIDLower,
tc.spanID,
tc.flags,
} {
if err := binary.Write(buf, binary.BigEndian, field); err != nil {
return nil, err
}
}
return buf.Bytes(), nil
}
func (tc *Identity) UnmarshalBinary(data []byte) error {
if len(data) < IdentityLength {
return fmt.Errorf("expect tracingContext to have at least %d bytes, got %d", IdentityLength, len(data))
}
buf := bytes.NewBuffer(data)
for _, field := range []interface{}{
&tc.traceIDUpper,
&tc.traceIDLower,
&tc.spanID,
&tc.flags,
} {
if err := binary.Read(buf, binary.BigEndian, field); err != nil {
return err
}
}
return nil
}
func NewIdentity(trace string) (*Identity, error) {
parts := strings.Split(trace, separator)
if len(parts) != 4 {
return nil, fmt.Errorf("trace '%s' doesn't have exactly 4 parts separated by %s", trace, separator)
}
const base = 16
tracingID := parts[0]
if len(tracingID) == 0 {
return nil, fmt.Errorf("missing tracing ID")
}
if len(tracingID) != 32 {
// Correctly left pad the trace to a length of 32
left := traceID128bitsWidth - len(tracingID)
tracingID = strings.Repeat("0", left) + tracingID
}
traceIDUpper, err := strconv.ParseUint(tracingID[:16], base, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse first 16 bytes of tracing ID as uint64, err: %w", err)
}
traceIDLower, err := strconv.ParseUint(tracingID[16:], base, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse last 16 bytes of tracing ID as uint64, err: %w", err)
}
spanID, err := strconv.ParseUint(parts[1], base, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse span ID as uint64, err: %w", err)
}
flags, err := strconv.ParseUint(parts[3], base, 8)
if err != nil {
return nil, fmt.Errorf("failed to parse flag as uint8, err: %w", err)
}
return &Identity{
traceIDUpper: traceIDUpper,
traceIDLower: traceIDLower,
spanID: spanID,
flags: uint8(flags),
}, nil
}

52
tracing/identity_test.go Normal file
View File

@ -0,0 +1,52 @@
package tracing
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewIdentity(t *testing.T) {
testCases := []struct {
testCase string
trace string
valid bool
}{
{
testCase: "full length trace",
trace: "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1",
valid: true,
},
{
testCase: "short trace ID",
trace: "ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1",
valid: true,
},
{
testCase: "no trace",
trace: "",
valid: false,
},
{
testCase: "missing flags",
trace: "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0",
valid: false,
},
{
testCase: "missing separator",
trace: "ec31ad8a01fde11fdcabe2efdce3687352726f6cabc144f501",
valid: false,
},
}
for _, testCase := range testCases {
identity, err := NewIdentity(testCase.trace)
if testCase.valid {
require.NoError(t, err, testCase.testCase)
require.Equal(t, testCase.trace, identity.String())
} else {
require.Error(t, err)
require.Nil(t, identity)
}
}
}