TUN-8638: Add datagram v3 serializers and deserializers

Closes TUN-8638
This commit is contained in:
Devin Carr 2024-10-16 12:05:55 -07:00
parent a3ee49d8a9
commit abb3466c31
5 changed files with 872 additions and 0 deletions

372
quic/v3/datagram.go Normal file
View File

@ -0,0 +1,372 @@
package v3
import (
"encoding/binary"
"net/netip"
"time"
)
type DatagramType byte
const (
// UDP Registration
UDPSessionRegistrationType DatagramType = 0x0
// UDP Session Payload
UDPSessionPayloadType DatagramType = 0x1
// DatagramTypeICMP (supporting both ICMPv4 and ICMPv6)
ICMPType DatagramType = 0x2
// UDP Session Registration Response
UDPSessionRegistrationResponseType DatagramType = 0x3
)
const (
// Total number of bytes representing the [DatagramType]
datagramTypeLen = 1
// 1280 is the default datagram packet length used before MTU discovery: https://github.com/quic-go/quic-go/blob/v0.45.0/internal/protocol/params.go#L12
maxDatagramLen = 1280
)
func parseDatagramType(data []byte) (DatagramType, error) {
if len(data) < datagramTypeLen {
return 0, ErrDatagramHeaderTooSmall
}
return DatagramType(data[0]), nil
}
// UDPSessionRegistrationDatagram handles a request to initialize a UDP session on the remote client.
type UDPSessionRegistrationDatagram struct {
RequestID RequestID
Dest netip.AddrPort
Traced bool
IdleDurationHint time.Duration
Payload []byte
}
const (
sessionRegistrationFlagsIPMask byte = 0b0000_0001
sessionRegistrationFlagsTracedMask byte = 0b0000_0010
sessionRegistrationFlagsBundledMask byte = 0b0000_0100
sessionRegistrationIPv4DatagramHeaderLen = datagramTypeLen +
1 + // Flag length
2 + // Destination port length
2 + // Idle duration seconds length
datagramRequestIdLen + // Request ID length
4 // IPv4 address length
// The IPv4 and IPv6 address share space, so adding 12 to the header length gets the space taken by the IPv6 field.
sessionRegistrationIPv6DatagramHeaderLen = sessionRegistrationIPv4DatagramHeaderLen + 12
)
// The datagram structure for UDPSessionRegistrationDatagram is:
//
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// 0| Type | Flags | Destination Port |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// 4| Idle Duration Seconds | |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +
// 8| |
// + Session Identifier +
// 12| (16 Bytes) |
// + +
// 16| |
// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// 20| | Destination IPv4 Address |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- - - - - - - - - - - - - - - -+
// 24| Destination IPv4 Address cont | |
// +- - - - - - - - - - - - - - - - +
// 28| Destination IPv6 Address |
// + (extension of IPv4 region) +
// 32| |
// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// 36| | |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +
// . .
// . Bundle Payload .
// . .
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error) {
ipv6 := s.Dest.Addr().Is6()
var flags byte
if s.Traced {
flags |= sessionRegistrationFlagsTracedMask
}
hasPayload := len(s.Payload) > 0
if hasPayload {
flags |= sessionRegistrationFlagsBundledMask
}
var maxPayloadLen int
if ipv6 {
maxPayloadLen = maxDatagramLen - sessionRegistrationIPv6DatagramHeaderLen
flags |= sessionRegistrationFlagsIPMask
} else {
maxPayloadLen = maxDatagramLen - sessionRegistrationIPv4DatagramHeaderLen
}
// Make sure that the payload being bundled can actually fit in the payload destination
if len(s.Payload) > maxPayloadLen {
return nil, wrapMarshalErr(ErrDatagramPayloadTooLarge)
}
// Allocate the buffer with the right size for the destination IP family
if ipv6 {
data = make([]byte, sessionRegistrationIPv6DatagramHeaderLen+len(s.Payload))
} else {
data = make([]byte, sessionRegistrationIPv4DatagramHeaderLen+len(s.Payload))
}
data[0] = byte(UDPSessionRegistrationType)
data[1] = byte(flags)
binary.BigEndian.PutUint16(data[2:4], s.Dest.Port())
binary.BigEndian.PutUint16(data[4:6], uint16(s.IdleDurationHint.Seconds()))
err = s.RequestID.MarshalBinaryTo(data[6:22])
if err != nil {
return nil, wrapMarshalErr(err)
}
var end int
if ipv6 {
copy(data[22:38], s.Dest.Addr().AsSlice())
end = 38
} else {
copy(data[22:26], s.Dest.Addr().AsSlice())
end = 26
}
if hasPayload {
copy(data[end:], s.Payload)
}
return data, nil
}
func (s *UDPSessionRegistrationDatagram) UnmarshalBinary(data []byte) error {
datagramType, err := parseDatagramType(data)
if err != nil {
return err
}
if datagramType != UDPSessionRegistrationType {
return wrapUnmarshalErr(ErrInvalidDatagramType)
}
requestID, err := RequestIDFromSlice(data[6:22])
if err != nil {
return wrapUnmarshalErr(err)
}
traced := (data[1] & sessionRegistrationFlagsTracedMask) == sessionRegistrationFlagsTracedMask
bundled := (data[1] & sessionRegistrationFlagsBundledMask) == sessionRegistrationFlagsBundledMask
ipv6 := (data[1] & sessionRegistrationFlagsIPMask) == sessionRegistrationFlagsIPMask
port := binary.BigEndian.Uint16(data[2:4])
var datagramHeaderSize int
var dest netip.AddrPort
if ipv6 {
datagramHeaderSize = sessionRegistrationIPv6DatagramHeaderLen
dest = netip.AddrPortFrom(netip.AddrFrom16([16]byte(data[22:38])), port)
} else {
datagramHeaderSize = sessionRegistrationIPv4DatagramHeaderLen
dest = netip.AddrPortFrom(netip.AddrFrom4([4]byte(data[22:26])), port)
}
idle := time.Duration(binary.BigEndian.Uint16(data[4:6])) * time.Second
var payload []byte
if bundled && len(data) >= datagramHeaderSize && len(data[datagramHeaderSize:]) > 0 {
payload = data[datagramHeaderSize:]
}
*s = UDPSessionRegistrationDatagram{
RequestID: requestID,
Dest: dest,
Traced: traced,
IdleDurationHint: idle,
Payload: payload,
}
return nil
}
// UDPSessionPayloadDatagram provides the payload for a session to be send to either the origin or the client.
type UDPSessionPayloadDatagram struct {
RequestID RequestID
Payload []byte
}
const (
datagramPayloadHeaderLen = datagramTypeLen + datagramRequestIdLen
// The maximum size that a proxied UDP payload can be in a [UDPSessionPayloadDatagram]
maxPayloadPlusHeaderLen = maxDatagramLen - datagramPayloadHeaderLen
)
// The datagram structure for UDPSessionPayloadDatagram is:
//
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// 0| Type | |
// +-+-+-+-+-+-+-+-+ +
// 4| |
// + +
// 8| Session Identifier |
// + (16 Bytes) +
// 12| |
// + +-+-+-+-+-+-+-+-+
// 16| | |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +
// . .
// . Payload .
// . .
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// MarshalPayloadHeaderTo provides a way to insert the Session Payload header into an already existing byte slice
// without having to allocate and copy the payload into the destination.
//
// This method should be used in-place of MarshalBinary which will allocate in-place the required byte array to return.
func MarshalPayloadHeaderTo(requestID RequestID, payload []byte) error {
if len(payload) < 17 {
return wrapMarshalErr(ErrDatagramPayloadHeaderTooSmall)
}
payload[0] = byte(UDPSessionPayloadType)
return requestID.MarshalBinaryTo(payload[1:17])
}
func (s *UDPSessionPayloadDatagram) UnmarshalBinary(data []byte) error {
datagramType, err := parseDatagramType(data)
if err != nil {
return err
}
if datagramType != UDPSessionPayloadType {
return wrapUnmarshalErr(ErrInvalidDatagramType)
}
// Make sure that the slice provided is the right size to be parsed.
if len(data) < 17 || len(data) > maxPayloadPlusHeaderLen {
return wrapUnmarshalErr(ErrDatagramPayloadInvalidSize)
}
requestID, err := RequestIDFromSlice(data[1:17])
if err != nil {
return wrapUnmarshalErr(err)
}
*s = UDPSessionPayloadDatagram{
RequestID: requestID,
Payload: data[17:],
}
return nil
}
// UDPSessionRegistrationResponseDatagram is used to either return a successful registration or error to the client
// that requested the registration of a UDP session.
type UDPSessionRegistrationResponseDatagram struct {
RequestID RequestID
ResponseType SessionRegistrationResp
ErrorMsg string
}
const (
datagramRespTypeLen = 1
datagramRespErrMsgLen = 2
datagramSessionRegistrationResponseLen = datagramTypeLen + datagramRespTypeLen + datagramRequestIdLen + datagramRespErrMsgLen
// The maximum size that an error message can be in a [UDPSessionRegistrationResponseDatagram].
maxResponseErrorMessageLen = maxDatagramLen - datagramSessionRegistrationResponseLen
)
// SessionRegistrationResp represents all of the responses that a UDP session registration response
// can return back to the client.
type SessionRegistrationResp byte
const (
// Session was received and is ready to proxy.
ResponseOk SessionRegistrationResp = 0x00
// Session registration was unable to reach the requested origin destination.
ResponseDestinationUnreachable SessionRegistrationResp = 0x01
// Session registration was unable to bind to a local UDP socket.
ResponseUnableToBindSocket SessionRegistrationResp = 0x02
// Session registration failed with an unexpected error but provided a message.
ResponseErrorWithMsg SessionRegistrationResp = 0xff
)
// The datagram structure for UDPSessionRegistrationResponseDatagram is:
//
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// 0| Type | Resp Type | |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +
// 4| |
// + Session Identifier +
// 8| (16 Bytes) |
// + +
// 12| |
// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// 16| | Error Length |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// . .
// . .
// . .
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
func (s *UDPSessionRegistrationResponseDatagram) MarshalBinary() (data []byte, err error) {
if len(s.ErrorMsg) > maxResponseErrorMessageLen {
return nil, wrapMarshalErr(ErrDatagramResponseMsgInvalidSize)
}
errMsgLen := uint16(len(s.ErrorMsg))
data = make([]byte, datagramSessionRegistrationResponseLen+errMsgLen)
data[0] = byte(UDPSessionRegistrationResponseType)
data[1] = byte(s.ResponseType)
err = s.RequestID.MarshalBinaryTo(data[2:18])
if err != nil {
return nil, wrapMarshalErr(err)
}
if errMsgLen > 0 {
binary.BigEndian.PutUint16(data[18:20], errMsgLen)
copy(data[20:], []byte(s.ErrorMsg))
}
return data, nil
}
func (s *UDPSessionRegistrationResponseDatagram) UnmarshalBinary(data []byte) error {
datagramType, err := parseDatagramType(data)
if err != nil {
return wrapUnmarshalErr(err)
}
if datagramType != UDPSessionRegistrationResponseType {
return wrapUnmarshalErr(ErrInvalidDatagramType)
}
if len(data) < datagramSessionRegistrationResponseLen {
return wrapUnmarshalErr(ErrDatagramResponseInvalidSize)
}
respType := SessionRegistrationResp(data[1])
requestID, err := RequestIDFromSlice(data[2:18])
if err != nil {
return wrapUnmarshalErr(err)
}
errMsgLen := binary.BigEndian.Uint16(data[18:20])
if errMsgLen > maxResponseErrorMessageLen {
return wrapUnmarshalErr(ErrDatagramResponseMsgTooLargeMaximum)
}
if len(data[20:]) < int(errMsgLen) {
return wrapUnmarshalErr(ErrDatagramResponseMsgTooLargeDatagram)
}
var errMsg string
if errMsgLen > 0 {
errMsg = string(data[20:])
}
*s = UDPSessionRegistrationResponseDatagram{
RequestID: requestID,
ResponseType: respType,
ErrorMsg: errMsg,
}
return nil
}

View File

@ -0,0 +1,26 @@
package v3
import (
"errors"
"fmt"
)
var (
ErrInvalidDatagramType error = errors.New("invalid datagram type expected")
ErrDatagramHeaderTooSmall error = fmt.Errorf("datagram should have at least %d bytes", datagramTypeLen)
ErrDatagramPayloadTooLarge error = errors.New("payload length is too large to be bundled in datagram")
ErrDatagramPayloadHeaderTooSmall error = errors.New("payload length is too small to fit the datagram header")
ErrDatagramPayloadInvalidSize error = errors.New("datagram provided is an invalid size")
ErrDatagramResponseMsgInvalidSize error = errors.New("datagram response message is an invalid size")
ErrDatagramResponseInvalidSize error = errors.New("datagram response is an invalid size")
ErrDatagramResponseMsgTooLargeMaximum error = fmt.Errorf("datagram response error message length exceeds the length of the datagram maximum: %d", maxResponseErrorMessageLen)
ErrDatagramResponseMsgTooLargeDatagram error = fmt.Errorf("datagram response error message length exceeds the length of the provided datagram")
)
func wrapMarshalErr(err error) error {
return fmt.Errorf("datagram marshal error: %w", err)
}
func wrapUnmarshalErr(err error) error {
return fmt.Errorf("datagram unmarshal error: %w", err)
}

352
quic/v3/datagram_test.go Normal file
View File

@ -0,0 +1,352 @@
package v3_test
import (
"encoding/binary"
"errors"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
v3 "github.com/cloudflare/cloudflared/quic/v3"
)
func makePayload(size int) []byte {
payload := make([]byte, size)
for i := range len(payload) {
payload[i] = 0xfc
}
return payload
}
func TestSessionRegistration_MarshalUnmarshal(t *testing.T) {
payload := makePayload(1254)
tests := []*v3.UDPSessionRegistrationDatagram{
// Default (IPv4)
{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: false,
IdleDurationHint: 5 * time.Second,
Payload: nil,
},
// Request ID (max)
{
RequestID: mustRequestID([16]byte{
^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
}),
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: false,
IdleDurationHint: 5 * time.Second,
Payload: nil,
},
// IPv6
{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("[fc00::0]:8080"),
Traced: false,
IdleDurationHint: 5 * time.Second,
Payload: nil,
},
// Traced
{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: true,
IdleDurationHint: 5 * time.Second,
Payload: nil,
},
// IdleDurationHint (max)
{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: false,
IdleDurationHint: 65535 * time.Second,
Payload: nil,
},
// Payload
{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: false,
IdleDurationHint: 5 * time.Second,
Payload: []byte{0xff, 0xaa, 0xcc, 0x44},
},
// Payload (max: 1254) for IPv4
{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: false,
IdleDurationHint: 5 * time.Second,
Payload: payload,
},
// Payload (max: 1242) for IPv4
{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: false,
IdleDurationHint: 5 * time.Second,
Payload: payload[:1242],
},
}
for _, tt := range tests {
marshaled, err := tt.MarshalBinary()
if err != nil {
t.Error(err)
}
unmarshaled := v3.UDPSessionRegistrationDatagram{}
err = unmarshaled.UnmarshalBinary(marshaled)
if err != nil {
t.Error(err)
}
if !compareRegistrationDatagrams(t, tt, &unmarshaled) {
t.Errorf("not equal:\n%+v\n%+v", tt, &unmarshaled)
}
}
}
func TestSessionRegistration_MarshalBinary(t *testing.T) {
t.Run("idle hint too large", func(t *testing.T) {
// idle hint duration overflows back to 1
datagram := &v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: false,
IdleDurationHint: 65537 * time.Second,
Payload: nil,
}
expected := &v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
Traced: false,
IdleDurationHint: 1 * time.Second,
Payload: nil,
}
marshaled, err := datagram.MarshalBinary()
if err != nil {
t.Error(err)
}
unmarshaled := v3.UDPSessionRegistrationDatagram{}
err = unmarshaled.UnmarshalBinary(marshaled)
if err != nil {
t.Error(err)
}
if !compareRegistrationDatagrams(t, expected, &unmarshaled) {
t.Errorf("not equal:\n%+v\n%+v", expected, &unmarshaled)
}
})
}
func TestTypeUnmarshalErrors(t *testing.T) {
t.Run("invalid length", func(t *testing.T) {
d1 := v3.UDPSessionRegistrationDatagram{}
err := d1.UnmarshalBinary([]byte{})
if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) {
t.Errorf("expected invalid length to throw error")
}
d2 := v3.UDPSessionPayloadDatagram{}
err = d2.UnmarshalBinary([]byte{})
if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) {
t.Errorf("expected invalid length to throw error")
}
d3 := v3.UDPSessionRegistrationResponseDatagram{}
err = d3.UnmarshalBinary([]byte{})
if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) {
t.Errorf("expected invalid length to throw error")
}
})
t.Run("invalid types", func(t *testing.T) {
d1 := v3.UDPSessionRegistrationDatagram{}
err := d1.UnmarshalBinary([]byte{byte(v3.UDPSessionRegistrationResponseType)})
if !errors.Is(err, v3.ErrInvalidDatagramType) {
t.Errorf("expected invalid type to throw error")
}
d2 := v3.UDPSessionPayloadDatagram{}
err = d2.UnmarshalBinary([]byte{byte(v3.UDPSessionRegistrationType)})
if !errors.Is(err, v3.ErrInvalidDatagramType) {
t.Errorf("expected invalid type to throw error")
}
d3 := v3.UDPSessionRegistrationResponseDatagram{}
err = d3.UnmarshalBinary([]byte{byte(v3.UDPSessionPayloadType)})
if !errors.Is(err, v3.ErrInvalidDatagramType) {
t.Errorf("expected invalid type to throw error")
}
})
}
func TestSessionPayload(t *testing.T) {
t.Run("basic", func(t *testing.T) {
payload := makePayload(128)
err := v3.MarshalPayloadHeaderTo(testRequestID, payload[0:17])
if err != nil {
t.Error(err)
}
unmarshaled := v3.UDPSessionPayloadDatagram{}
err = unmarshaled.UnmarshalBinary(payload)
if err != nil {
t.Error(err)
}
require.Equal(t, testRequestID, unmarshaled.RequestID)
require.Equal(t, payload[17:], unmarshaled.Payload)
})
t.Run("empty", func(t *testing.T) {
payload := makePayload(17)
err := v3.MarshalPayloadHeaderTo(testRequestID, payload)
if err != nil {
t.Error(err)
}
unmarshaled := v3.UDPSessionPayloadDatagram{}
err = unmarshaled.UnmarshalBinary(payload)
if err != nil {
t.Error(err)
}
require.Equal(t, testRequestID, unmarshaled.RequestID)
require.Equal(t, payload[17:], unmarshaled.Payload)
})
t.Run("header size too small", func(t *testing.T) {
payload := makePayload(16)
err := v3.MarshalPayloadHeaderTo(testRequestID, payload)
if !errors.Is(err, v3.ErrDatagramPayloadHeaderTooSmall) {
t.Errorf("expected an error")
}
})
t.Run("payload size too small", func(t *testing.T) {
payload := makePayload(17)
err := v3.MarshalPayloadHeaderTo(testRequestID, payload)
if err != nil {
t.Error(err)
}
unmarshaled := v3.UDPSessionPayloadDatagram{}
err = unmarshaled.UnmarshalBinary(payload[:16])
if !errors.Is(err, v3.ErrDatagramPayloadInvalidSize) {
t.Errorf("expected an error: %s", err)
}
})
t.Run("payload size too large", func(t *testing.T) {
datagram := makePayload(17 + 1264) // 1263 is the largest payload size allowed
err := v3.MarshalPayloadHeaderTo(testRequestID, datagram)
if err != nil {
t.Error(err)
}
unmarshaled := v3.UDPSessionPayloadDatagram{}
err = unmarshaled.UnmarshalBinary(datagram[:])
if !errors.Is(err, v3.ErrDatagramPayloadInvalidSize) {
t.Errorf("expected an error: %s", err)
}
})
}
func TestSessionRegistrationResponse(t *testing.T) {
validRespTypes := []v3.SessionRegistrationResp{
v3.ResponseOk,
v3.ResponseDestinationUnreachable,
v3.ResponseUnableToBindSocket,
v3.ResponseErrorWithMsg,
}
t.Run("basic", func(t *testing.T) {
for _, responseType := range validRespTypes {
datagram := &v3.UDPSessionRegistrationResponseDatagram{
RequestID: testRequestID,
ResponseType: responseType,
ErrorMsg: "test",
}
marshaled, err := datagram.MarshalBinary()
if err != nil {
t.Error(err)
}
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
err = unmarshaled.UnmarshalBinary(marshaled)
if err != nil {
t.Error(err)
}
require.Equal(t, datagram, unmarshaled)
}
})
t.Run("unsupported resp type is valid", func(t *testing.T) {
datagram := &v3.UDPSessionRegistrationResponseDatagram{
RequestID: testRequestID,
ResponseType: v3.SessionRegistrationResp(0xfc),
ErrorMsg: "",
}
marshaled, err := datagram.MarshalBinary()
if err != nil {
t.Error(err)
}
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
err = unmarshaled.UnmarshalBinary(marshaled)
if err != nil {
t.Error(err)
}
require.Equal(t, datagram, unmarshaled)
})
t.Run("too small to unmarshal", func(t *testing.T) {
payload := makePayload(17)
payload[0] = byte(v3.UDPSessionRegistrationResponseType)
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
err := unmarshaled.UnmarshalBinary(payload)
if !errors.Is(err, v3.ErrDatagramResponseInvalidSize) {
t.Errorf("expected an error")
}
})
t.Run("error message too long", func(t *testing.T) {
message := ""
for i := 0; i < 1280; i++ {
message += "a"
}
datagram := &v3.UDPSessionRegistrationResponseDatagram{
RequestID: testRequestID,
ResponseType: v3.SessionRegistrationResp(0xfc),
ErrorMsg: message,
}
_, err := datagram.MarshalBinary()
if !errors.Is(err, v3.ErrDatagramResponseMsgInvalidSize) {
t.Errorf("expected an error")
}
})
t.Run("error message too large to unmarshal", func(t *testing.T) {
payload := makePayload(1280)
payload[0] = byte(v3.UDPSessionRegistrationResponseType)
binary.BigEndian.PutUint16(payload[18:20], 1280) // larger than the datagram size could be
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
err := unmarshaled.UnmarshalBinary(payload)
if !errors.Is(err, v3.ErrDatagramResponseMsgTooLargeMaximum) {
t.Errorf("expected an error: %v", err)
}
})
t.Run("error message larger than provided buffer", func(t *testing.T) {
payload := makePayload(1000)
payload[0] = byte(v3.UDPSessionRegistrationResponseType)
binary.BigEndian.PutUint16(payload[18:20], 1001) // larger than the datagram size provided
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
err := unmarshaled.UnmarshalBinary(payload)
if !errors.Is(err, v3.ErrDatagramResponseMsgTooLargeDatagram) {
t.Errorf("expected an error: %v", err)
}
})
}
func compareRegistrationDatagrams(t *testing.T, l *v3.UDPSessionRegistrationDatagram, r *v3.UDPSessionRegistrationDatagram) bool {
require.Equal(t, l.Payload, r.Payload)
return l.RequestID == r.RequestID &&
l.Dest == r.Dest &&
l.IdleDurationHint == r.IdleDurationHint &&
l.Traced == r.Traced
}

72
quic/v3/request.go Normal file
View File

@ -0,0 +1,72 @@
package v3
import (
"encoding/binary"
"errors"
)
const (
datagramRequestIdLen = 16
)
var (
// ErrInvalidRequestIDLen is returned when the provided request id can not be parsed from the provided byte slice.
ErrInvalidRequestIDLen error = errors.New("invalid request id length provided")
// ErrInvalidPayloadDestLen is returned when the provided destination byte slice cannot fit the whole request id.
ErrInvalidPayloadDestLen error = errors.New("invalid payload size provided")
)
// RequestID is the request-id-v2 identifier, it is used to distinguish between specific flows or sessions proxied
// from the edge to cloudflared.
type RequestID uint128
type uint128 struct {
hi uint64
lo uint64
}
// RequestIDFromSlice reads a request ID from a byte slice.
func RequestIDFromSlice(data []byte) (RequestID, error) {
if len(data) != datagramRequestIdLen {
return RequestID{}, ErrInvalidRequestIDLen
}
return RequestID{
hi: binary.BigEndian.Uint64(data[:8]),
lo: binary.BigEndian.Uint64(data[8:]),
}, nil
}
// Compare returns an integer comparing two IPs.
// The result will be 0 if id == id2, -1 if id < id2, and +1 if id > id2.
// The definition of "less than" is the same as the [RequestID.Less] method.
func (id RequestID) Compare(id2 RequestID) int {
hi1, hi2 := id.hi, id2.hi
if hi1 < hi2 {
return -1
}
if hi1 > hi2 {
return 1
}
lo1, lo2 := id.lo, id2.lo
if lo1 < lo2 {
return -1
}
if lo1 > lo2 {
return 1
}
return 0
}
// Less reports whether id sorts before id2.
func (id RequestID) Less(id2 RequestID) bool { return id.Compare(id2) == -1 }
// MarshalBinaryTo writes the id to the provided destination byte slice; the byte slice must be of at least size 16.
func (id RequestID) MarshalBinaryTo(data []byte) error {
if len(data) < datagramRequestIdLen {
return ErrInvalidPayloadDestLen
}
binary.BigEndian.PutUint64(data[:8], id.hi)
binary.BigEndian.PutUint64(data[8:], id.lo)
return nil
}

50
quic/v3/request_test.go Normal file
View File

@ -0,0 +1,50 @@
package v3_test
import (
"crypto/rand"
"slices"
"testing"
v3 "github.com/cloudflare/cloudflared/quic/v3"
)
var (
testRequestIDBytes = [16]byte{
0x00, 0x11, 0x22, 0x33,
0x44, 0x55, 0x66, 0x77,
0x88, 0x99, 0xaa, 0xbb,
0xcc, 0xdd, 0xee, 0xff,
}
testRequestID = mustRequestID(testRequestIDBytes)
)
func mustRequestID(data [16]byte) v3.RequestID {
id, err := v3.RequestIDFromSlice(data[:])
if err != nil {
panic(err)
}
return id
}
func TestRequestIDParsing(t *testing.T) {
buf1 := make([]byte, 16)
n, err := rand.Read(buf1)
if err != nil {
t.Fatal(err)
}
if n != 16 {
t.Fatalf("did not read 16 bytes: %d", n)
}
id, err := v3.RequestIDFromSlice(buf1)
if err != nil {
t.Fatal(err)
}
buf2 := make([]byte, 16)
err = id.MarshalBinaryTo(buf2)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(buf1, buf2) {
t.Fatalf("buf1 != buf2: %+v %+v", buf1, buf2)
}
}