diff --git a/connection/quic.go b/connection/quic.go index 9905e90b..4eb405c7 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -93,7 +93,8 @@ func NewQUICConnection( sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) datagramMuxer := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) - packetRouter := packet.NewRouter(packetRouterConfig, datagramMuxer, &returnPipe{muxer: datagramMuxer}, logger, orchestrator.WarpRoutingEnabled) + muxer := muxerWrapper{muxer: datagramMuxer} + packetRouter := packet.NewRouter(packetRouterConfig, &muxer, &muxer, logger, orchestrator.WarpRoutingEnabled) return &QUICConnection{ session: session, @@ -498,16 +499,28 @@ func (np *nopCloserReadWriter) Close() error { return nil } -// returnPipe wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface -type returnPipe struct { +// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface +type muxerWrapper struct { muxer *quicpogs.DatagramMuxerV2 } -func (rp *returnPipe) SendPacket(dst netip.Addr, pk packet.RawPacket) error { - return rp.muxer.SendPacket(pk) +func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error { + return rp.muxer.SendPacket(quicpogs.RawPacket(pk)) } -func (rp *returnPipe) Close() error { +func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) { + pk, err := rp.muxer.ReceivePacket(ctx) + if err != nil { + return packet.RawPacket{}, err + } + rawPacket, ok := pk.(quicpogs.RawPacket) + if ok { + return packet.RawPacket(rawPacket), nil + } + return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk) +} + +func (rp *muxerWrapper) Close() error { return nil } diff --git a/quic/datagram.go b/quic/datagram.go index 23408cfe..c4637dc0 100644 --- a/quic/datagram.go +++ b/quic/datagram.go @@ -113,9 +113,12 @@ func extractSessionID(b []byte) (uuid.UUID, []byte, error) { // SuffixSessionID appends the session ID at the end of the payload. Suffix is more performant than prefix because // the payload slice might already have enough capacity to append the session ID at the end func SuffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) { - if len(b)+len(sessionID) > MaxDatagramFrameSize { + return suffixMetadata(b, sessionID[:]) +} + +func suffixMetadata(payload, metadata []byte) ([]byte, error) { + if len(payload)+len(metadata) > MaxDatagramFrameSize { return nil, fmt.Errorf("datagram size exceed %d", MaxDatagramFrameSize) } - b = append(b, sessionID[:]...) - return b, nil + return append(payload, metadata...), nil } diff --git a/quic/datagram_test.go b/quic/datagram_test.go index 69bb0b71..54307c79 100644 --- a/quic/datagram_test.go +++ b/quic/datagram_test.go @@ -1,6 +1,7 @@ package quic import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -23,6 +24,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/cloudflare/cloudflared/packet" + "github.com/cloudflare/cloudflared/tracing" ) var ( @@ -121,6 +123,15 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi logger := zerolog.Nop() + tracingIdentity, err := tracing.NewIdentity("ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1") + require.NoError(t, err) + serializedTracingID, err := tracingIdentity.MarshalBinary() + require.NoError(t, err) + tracingSpan := &TracingSpanPacket{ + Spans: []byte("tracing"), + TracingIdentity: serializedTracingID, + } + errGroup, ctx := errgroup.WithContext(context.Background()) // Run edge side of datagram muxer errGroup.Go(func() error { @@ -140,18 +151,17 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi muxer := NewDatagramMuxerV2(quicSession, &logger, sessionDemuxChan) muxer.ServeReceive(ctx) - icmpDecoder := packet.NewICMPDecoder() for _, pk := range packets { received, err := muxer.ReceivePacket(ctx) require.NoError(t, err) - - receivedICMP, err := icmpDecoder.Decode(received) + validateIPPacket(t, received, &pk) + received, err = muxer.ReceivePacket(ctx) require.NoError(t, err) - require.Equal(t, pk.IP, receivedICMP.IP) - require.Equal(t, pk.Type, receivedICMP.Type) - require.Equal(t, pk.Code, receivedICMP.Code) - require.Equal(t, pk.Body, receivedICMP.Body) + validateIPPacketWithTracing(t, received, &pk, serializedTracingID) } + received, err := muxer.ReceivePacket(ctx) + require.NoError(t, err) + validateTracingSpans(t, received, tracingSpan) default: return fmt.Errorf("unknown datagram version %d", version) } @@ -188,10 +198,15 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi for _, pk := range packets { encodedPacket, err := encoder.Encode(&pk) require.NoError(t, err) - require.NoError(t, muxerV2.SendPacket(encodedPacket)) + require.NoError(t, muxerV2.SendPacket(RawPacket(encodedPacket))) + require.NoError(t, muxerV2.SendPacket(&TracedPacket{ + Packet: encodedPacket, + TracingIdentity: serializedTracingID, + })) } + require.NoError(t, muxerV2.SendPacket(tracingSpan)) // Payload larger than transport MTU, should not be sent - require.Error(t, muxerV2.SendPacket(packet.RawPacket{ + require.Error(t, muxerV2.SendPacket(RawPacket{ Data: largePayload, })) muxer = muxerV2 @@ -217,6 +232,38 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi require.NoError(t, errGroup.Wait()) } +func validateIPPacket(t *testing.T, receivedPacket Packet, expectedICMP *packet.ICMP) { + require.Equal(t, DatagramTypeIP, receivedPacket.Type()) + rawPacket := receivedPacket.(RawPacket) + decoder := packet.NewICMPDecoder() + receivedICMP, err := decoder.Decode(packet.RawPacket(rawPacket)) + require.NoError(t, err) + validateICMP(t, expectedICMP, receivedICMP) +} + +func validateIPPacketWithTracing(t *testing.T, receivedPacket Packet, expectedICMP *packet.ICMP, serializedTracingID []byte) { + require.Equal(t, DatagramTypeIPWithTrace, receivedPacket.Type()) + tracedPacket := receivedPacket.(*TracedPacket) + decoder := packet.NewICMPDecoder() + receivedICMP, err := decoder.Decode(tracedPacket.Packet) + require.NoError(t, err) + validateICMP(t, expectedICMP, receivedICMP) + require.True(t, bytes.Equal(tracedPacket.TracingIdentity, serializedTracingID)) +} + +func validateICMP(t *testing.T, expected, actual *packet.ICMP) { + require.Equal(t, expected.IP, actual.IP) + require.Equal(t, expected.Type, actual.Type) + require.Equal(t, expected.Code, actual.Code) + require.Equal(t, expected.Body, actual.Body) +} + +func validateTracingSpans(t *testing.T, receivedPacket Packet, expectedSpan *TracingSpanPacket) { + require.Equal(t, DatagramTypeTracingSpan, receivedPacket.Type()) + tracingSpans := receivedPacket.(*TracingSpanPacket) + require.Equal(t, tracingSpans, expectedSpan) +} + func newQUICListener(t *testing.T, config *quic.Config) quic.Listener { // Create a simple tls config. tlsConfig := generateTLSConfig() diff --git a/quic/datagramv2.go b/quic/datagramv2.go index 373d8731..01012618 100644 --- a/quic/datagramv2.go +++ b/quic/datagramv2.go @@ -9,15 +9,28 @@ import ( "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/packet" + "github.com/cloudflare/cloudflared/tracing" ) type DatagramV2Type byte const ( + // UDP payload DatagramTypeUDP DatagramV2Type = iota + // Full IP packet DatagramTypeIP + // DatagramTypeIP + tracing ID + DatagramTypeIPWithTrace + // Tracing spans in protobuf format + DatagramTypeTracingSpan ) +type Packet interface { + Type() DatagramV2Type + Payload() []byte + Metadata() []byte +} + const ( typeIDLen = 1 // Same as sessionDemuxChan capacity @@ -41,7 +54,7 @@ type DatagramMuxerV2 struct { session quic.Connection logger *zerolog.Logger sessionDemuxChan chan<- *packet.Session - packetDemuxChan chan packet.RawPacket + packetDemuxChan chan Packet } func NewDatagramMuxerV2( @@ -54,7 +67,7 @@ func NewDatagramMuxerV2( session: quicSession, logger: &logger, sessionDemuxChan: sessionDemuxChan, - packetDemuxChan: make(chan packet.RawPacket, packetChanCapacity), + packetDemuxChan: make(chan Packet, packetChanCapacity), } } @@ -79,14 +92,19 @@ func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error { return nil } -// SendPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing -// the payload as IP and look at the source and destination. -func (dm *DatagramMuxerV2) SendPacket(pk packet.RawPacket) error { - payloadWithVersion, err := SuffixType(pk.Data, DatagramTypeIP) +// SendPacket sends a packet with datagram version in the suffix. If ctx is a TracedContext, it adds the tracing +// context between payload and datagram version. +// The other end of the QUIC connection can demultiplex by parsing the payload as IP and look at the source and destination. +func (dm *DatagramMuxerV2) SendPacket(pk Packet) error { + payloadWithMetadata, err := suffixMetadata(pk.Payload(), pk.Metadata()) + if err != nil { + return err + } + payloadWithMetadataAndType, err := SuffixType(payloadWithMetadata, pk.Type()) if err != nil { return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped") } - if err := dm.session.SendMessage(payloadWithVersion); err != nil { + if err := dm.session.SendMessage(payloadWithMetadataAndType); err != nil { return errors.Wrap(err, "Failed to send datagram back to edge") } return nil @@ -108,10 +126,10 @@ func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error { } } -func (dm *DatagramMuxerV2) ReceivePacket(ctx context.Context) (packet.RawPacket, error) { +func (dm *DatagramMuxerV2) ReceivePacket(ctx context.Context) (pk Packet, err error) { select { case <-ctx.Done(): - return packet.RawPacket{}, ctx.Err() + return nil, ctx.Err() case pk := <-dm.packetDemuxChan: return pk, nil } @@ -126,10 +144,8 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error switch msgType { case DatagramTypeUDP: return dm.handleSession(ctx, msg) - case DatagramTypeIP: - return dm.handlePacket(ctx, msg) default: - return fmt.Errorf("Unexpected datagram type %d", msgType) + return dm.handlePacket(ctx, msg, msgType) } } @@ -150,13 +166,93 @@ func (dm *DatagramMuxerV2) handleSession(ctx context.Context, session []byte) er } } -func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte) error { +func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType DatagramV2Type) error { + var demuxedPacket Packet + switch msgType { + case DatagramTypeIP: + demuxedPacket = RawPacket(packet.RawPacket{Data: pk}) + case DatagramTypeIPWithTrace: + tracingIdentity, payload, err := extractTracingIdentity(pk) + if err != nil { + return err + } + demuxedPacket = &TracedPacket{ + Packet: packet.RawPacket{Data: payload}, + TracingIdentity: tracingIdentity, + } + case DatagramTypeTracingSpan: + tracingIdentity, spans, err := extractTracingIdentity(pk) + if err != nil { + return err + } + demuxedPacket = &TracingSpanPacket{ + Spans: spans, + TracingIdentity: tracingIdentity, + } + default: + return fmt.Errorf("Unexpected datagram type %d", msgType) + } select { case <-ctx.Done(): return ctx.Err() - case dm.packetDemuxChan <- packet.RawPacket{ - Data: pk, - }: + case dm.packetDemuxChan <- demuxedPacket: return nil } } + +func extractTracingIdentity(pk []byte) (tracingIdentity []byte, payload []byte, err error) { + if len(pk) < tracing.IdentityLength { + return nil, nil, fmt.Errorf("packet with tracing context should have at least %d bytes, got %v", tracing.IdentityLength, pk) + } + tracingIdentity = pk[len(pk)-tracing.IdentityLength:] + payload = pk[:len(pk)-tracing.IdentityLength] + return tracingIdentity, payload, nil +} + +type RawPacket packet.RawPacket + +func (rw RawPacket) Type() DatagramV2Type { + return DatagramTypeIP +} + +func (rw RawPacket) Payload() []byte { + return rw.Data +} + +func (rw RawPacket) Metadata() []byte { + return []byte{} +} + +type TracedPacket struct { + Packet packet.RawPacket + TracingIdentity []byte +} + +func (tp *TracedPacket) Type() DatagramV2Type { + return DatagramTypeIPWithTrace +} + +func (tp *TracedPacket) Payload() []byte { + return tp.Packet.Data +} + +func (tp *TracedPacket) Metadata() []byte { + return tp.TracingIdentity +} + +type TracingSpanPacket struct { + Spans []byte + TracingIdentity []byte +} + +func (tsp *TracingSpanPacket) Type() DatagramV2Type { + return DatagramTypeTracingSpan +} + +func (tsp *TracingSpanPacket) Payload() []byte { + return tsp.Spans +} + +func (tsp *TracingSpanPacket) Metadata() []byte { + return tsp.TracingIdentity +} diff --git a/tracing/identity.go b/tracing/identity.go new file mode 100644 index 00000000..e6bbb4cb --- /dev/null +++ b/tracing/identity.go @@ -0,0 +1,102 @@ +package tracing + +import ( + "bytes" + "encoding/binary" + "fmt" + "strconv" + "strings" +) + +const ( + // 16 bytes for tracing ID, 8 bytes for span ID and 1 byte for flags + IdentityLength = 16 + 8 + 1 +) + +type Identity struct { + // Based on https://www.jaegertracing.io/docs/1.36/client-libraries/#value + // parent span ID is always 0 for our case + traceIDUpper uint64 + traceIDLower uint64 + spanID uint64 + flags uint8 +} + +// TODO: TUN-6604 Remove this. To reconstruct into Jaeger propagation format, convert tracingContext to tracing.Identity +func (tc *Identity) String() string { + return fmt.Sprintf("%x%x:%x:0:%x", tc.traceIDUpper, tc.traceIDLower, tc.spanID, tc.flags) +} + +func (tc *Identity) MarshalBinary() ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, IdentityLength)) + for _, field := range []interface{}{ + tc.traceIDUpper, + tc.traceIDLower, + tc.spanID, + tc.flags, + } { + if err := binary.Write(buf, binary.BigEndian, field); err != nil { + return nil, err + } + } + return buf.Bytes(), nil +} + +func (tc *Identity) UnmarshalBinary(data []byte) error { + if len(data) < IdentityLength { + return fmt.Errorf("expect tracingContext to have at least %d bytes, got %d", IdentityLength, len(data)) + } + + buf := bytes.NewBuffer(data) + for _, field := range []interface{}{ + &tc.traceIDUpper, + &tc.traceIDLower, + &tc.spanID, + &tc.flags, + } { + if err := binary.Read(buf, binary.BigEndian, field); err != nil { + return err + } + } + + return nil +} + +func NewIdentity(trace string) (*Identity, error) { + parts := strings.Split(trace, separator) + if len(parts) != 4 { + return nil, fmt.Errorf("trace '%s' doesn't have exactly 4 parts separated by %s", trace, separator) + } + const base = 16 + tracingID := parts[0] + if len(tracingID) == 0 { + return nil, fmt.Errorf("missing tracing ID") + } + if len(tracingID) != 32 { + // Correctly left pad the trace to a length of 32 + left := traceID128bitsWidth - len(tracingID) + tracingID = strings.Repeat("0", left) + tracingID + } + traceIDUpper, err := strconv.ParseUint(tracingID[:16], base, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse first 16 bytes of tracing ID as uint64, err: %w", err) + } + traceIDLower, err := strconv.ParseUint(tracingID[16:], base, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse last 16 bytes of tracing ID as uint64, err: %w", err) + } + spanID, err := strconv.ParseUint(parts[1], base, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse span ID as uint64, err: %w", err) + } + flags, err := strconv.ParseUint(parts[3], base, 8) + if err != nil { + return nil, fmt.Errorf("failed to parse flag as uint8, err: %w", err) + } + return &Identity{ + traceIDUpper: traceIDUpper, + traceIDLower: traceIDLower, + spanID: spanID, + flags: uint8(flags), + }, nil +} diff --git a/tracing/identity_test.go b/tracing/identity_test.go new file mode 100644 index 00000000..3bdb7448 --- /dev/null +++ b/tracing/identity_test.go @@ -0,0 +1,52 @@ +package tracing + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewIdentity(t *testing.T) { + testCases := []struct { + testCase string + trace string + valid bool + }{ + { + testCase: "full length trace", + trace: "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1", + valid: true, + }, + { + testCase: "short trace ID", + trace: "ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1", + valid: true, + }, + { + testCase: "no trace", + trace: "", + valid: false, + }, + { + testCase: "missing flags", + trace: "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0", + valid: false, + }, + { + testCase: "missing separator", + trace: "ec31ad8a01fde11fdcabe2efdce3687352726f6cabc144f501", + valid: false, + }, + } + + for _, testCase := range testCases { + identity, err := NewIdentity(testCase.trace) + if testCase.valid { + require.NoError(t, err, testCase.testCase) + require.Equal(t, testCase.trace, identity.String()) + } else { + require.Error(t, err) + require.Nil(t, identity) + } + } +}