259 lines
6.8 KiB
Go
259 lines
6.8 KiB
Go
package quic
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/lucas-clemente/quic-go"
|
|
"github.com/pkg/errors"
|
|
"github.com/rs/zerolog"
|
|
|
|
"github.com/cloudflare/cloudflared/packet"
|
|
"github.com/cloudflare/cloudflared/tracing"
|
|
)
|
|
|
|
// DatagramV2Type is the one-byte tag that identifies what kind of content a V2 datagram carries.
type DatagramV2Type byte

// Datagram type tags. The numeric values follow declaration order starting at zero.
const (
	// DatagramTypeUDP marks a datagram carrying a UDP payload.
	DatagramTypeUDP = DatagramV2Type(iota)
	// DatagramTypeIP marks a datagram carrying a full IP packet.
	DatagramTypeIP
	// DatagramTypeIPWithTrace marks a full IP packet accompanied by a tracing ID.
	DatagramTypeIPWithTrace
	// DatagramTypeTracingSpan marks a datagram carrying tracing spans in protobuf format.
	DatagramTypeTracingSpan
)
|
|
|
|
type Packet interface {
|
|
Type() DatagramV2Type
|
|
Payload() []byte
|
|
Metadata() []byte
|
|
}
|
|
|
|
const (
	// typeIDLen is the number of bytes taken by the datagram type tag.
	typeIDLen = 1

	// packetChanCapacity is the buffer size of the demuxed packet channel.
	// It is kept equal to the sessionDemuxChan capacity.
	packetChanCapacity = 128
)
|
|
|
|
func SuffixType(b []byte, datagramType DatagramV2Type) ([]byte, error) {
|
|
if len(b)+typeIDLen > MaxDatagramFrameSize {
|
|
return nil, fmt.Errorf("datagram size %d exceeds max frame size %d", len(b), MaxDatagramFrameSize)
|
|
}
|
|
b = append(b, byte(datagramType))
|
|
return b, nil
|
|
}
|
|
|
|
// Maximum application payload to send to / receive from QUIC datagram frame
|
|
func (dm *DatagramMuxerV2) mtu() int {
|
|
return maxDatagramPayloadSize
|
|
}
|
|
|
|
type DatagramMuxerV2 struct {
|
|
session quic.Connection
|
|
logger *zerolog.Logger
|
|
sessionDemuxChan chan<- *packet.Session
|
|
packetDemuxChan chan Packet
|
|
}
|
|
|
|
func NewDatagramMuxerV2(
|
|
quicSession quic.Connection,
|
|
log *zerolog.Logger,
|
|
sessionDemuxChan chan<- *packet.Session,
|
|
) *DatagramMuxerV2 {
|
|
logger := log.With().Uint8("datagramVersion", 2).Logger()
|
|
return &DatagramMuxerV2{
|
|
session: quicSession,
|
|
logger: &logger,
|
|
sessionDemuxChan: sessionDemuxChan,
|
|
packetDemuxChan: make(chan Packet, packetChanCapacity),
|
|
}
|
|
}
|
|
|
|
// SendToSession suffix the session ID and datagram version to the payload so the other end of the QUIC connection can
|
|
// demultiplex the payload from multiple datagram sessions
|
|
func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error {
|
|
if len(session.Payload) > dm.mtu() {
|
|
packetTooBigDropped.Inc()
|
|
return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(session.Payload), dm.mtu())
|
|
}
|
|
msgWithID, err := SuffixSessionID(session.ID, session.Payload)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
|
|
}
|
|
msgWithIDAndType, err := SuffixType(msgWithID, DatagramTypeUDP)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
|
|
}
|
|
if err := dm.session.SendMessage(msgWithIDAndType); err != nil {
|
|
return errors.Wrap(err, "Failed to send datagram back to edge")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SendPacket sends a packet with datagram version in the suffix. If ctx is a TracedContext, it adds the tracing
|
|
// context between payload and datagram version.
|
|
// The other end of the QUIC connection can demultiplex by parsing the payload as IP and look at the source and destination.
|
|
func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
|
|
payloadWithMetadata, err := suffixMetadata(pk.Payload(), pk.Metadata())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
payloadWithMetadataAndType, err := SuffixType(payloadWithMetadata, pk.Type())
|
|
if err != nil {
|
|
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
|
|
}
|
|
if err := dm.session.SendMessage(payloadWithMetadataAndType); err != nil {
|
|
return errors.Wrap(err, "Failed to send datagram back to edge")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
|
|
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
|
|
for {
|
|
msg, err := dm.session.ReceiveMessage()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := dm.demux(ctx, msg); err != nil {
|
|
dm.logger.Error().Err(err).Msg("Failed to demux datagram")
|
|
if err == context.Canceled {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (dm *DatagramMuxerV2) ReceivePacket(ctx context.Context) (pk Packet, err error) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case pk := <-dm.packetDemuxChan:
|
|
return pk, nil
|
|
}
|
|
}
|
|
|
|
func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error {
|
|
if len(msgWithType) < typeIDLen {
|
|
return fmt.Errorf("QUIC datagram should have at least %d byte", typeIDLen)
|
|
}
|
|
msgType := DatagramV2Type(msgWithType[len(msgWithType)-typeIDLen])
|
|
msg := msgWithType[0 : len(msgWithType)-typeIDLen]
|
|
switch msgType {
|
|
case DatagramTypeUDP:
|
|
return dm.handleSession(ctx, msg)
|
|
default:
|
|
return dm.handlePacket(ctx, msg, msgType)
|
|
}
|
|
}
|
|
|
|
func (dm *DatagramMuxerV2) handleSession(ctx context.Context, session []byte) error {
|
|
sessionID, payload, err := extractSessionID(session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sessionDatagram := packet.Session{
|
|
ID: sessionID,
|
|
Payload: payload,
|
|
}
|
|
select {
|
|
case dm.sessionDemuxChan <- &sessionDatagram:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (dm *DatagramMuxerV2) handlePacket(ctx context.Context, pk []byte, msgType DatagramV2Type) error {
|
|
var demuxedPacket Packet
|
|
switch msgType {
|
|
case DatagramTypeIP:
|
|
demuxedPacket = RawPacket(packet.RawPacket{Data: pk})
|
|
case DatagramTypeIPWithTrace:
|
|
tracingIdentity, payload, err := extractTracingIdentity(pk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
demuxedPacket = &TracedPacket{
|
|
Packet: packet.RawPacket{Data: payload},
|
|
TracingIdentity: tracingIdentity,
|
|
}
|
|
case DatagramTypeTracingSpan:
|
|
tracingIdentity, spans, err := extractTracingIdentity(pk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
demuxedPacket = &TracingSpanPacket{
|
|
Spans: spans,
|
|
TracingIdentity: tracingIdentity,
|
|
}
|
|
default:
|
|
return fmt.Errorf("Unexpected datagram type %d", msgType)
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case dm.packetDemuxChan <- demuxedPacket:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func extractTracingIdentity(pk []byte) (tracingIdentity []byte, payload []byte, err error) {
|
|
if len(pk) < tracing.IdentityLength {
|
|
return nil, nil, fmt.Errorf("packet with tracing context should have at least %d bytes, got %v", tracing.IdentityLength, pk)
|
|
}
|
|
tracingIdentity = pk[len(pk)-tracing.IdentityLength:]
|
|
payload = pk[:len(pk)-tracing.IdentityLength]
|
|
return tracingIdentity, payload, nil
|
|
}
|
|
|
|
type RawPacket packet.RawPacket
|
|
|
|
func (rw RawPacket) Type() DatagramV2Type {
|
|
return DatagramTypeIP
|
|
}
|
|
|
|
func (rw RawPacket) Payload() []byte {
|
|
return rw.Data
|
|
}
|
|
|
|
func (rw RawPacket) Metadata() []byte {
|
|
return []byte{}
|
|
}
|
|
|
|
type TracedPacket struct {
|
|
Packet packet.RawPacket
|
|
TracingIdentity []byte
|
|
}
|
|
|
|
func (tp *TracedPacket) Type() DatagramV2Type {
|
|
return DatagramTypeIPWithTrace
|
|
}
|
|
|
|
func (tp *TracedPacket) Payload() []byte {
|
|
return tp.Packet.Data
|
|
}
|
|
|
|
func (tp *TracedPacket) Metadata() []byte {
|
|
return tp.TracingIdentity
|
|
}
|
|
|
|
type TracingSpanPacket struct {
|
|
Spans []byte
|
|
TracingIdentity []byte
|
|
}
|
|
|
|
func (tsp *TracingSpanPacket) Type() DatagramV2Type {
|
|
return DatagramTypeTracingSpan
|
|
}
|
|
|
|
func (tsp *TracingSpanPacket) Payload() []byte {
|
|
return tsp.Spans
|
|
}
|
|
|
|
func (tsp *TracingSpanPacket) Metadata() []byte {
|
|
return tsp.TracingIdentity
|
|
}
|