TUN-8638: Add datagram v3 serializers and deserializers
Closes TUN-8638
This commit is contained in:
parent
a3ee49d8a9
commit
abb3466c31
|
@ -0,0 +1,372 @@
|
|||
package v3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DatagramType byte
|
||||
|
||||
const (
|
||||
// UDP Registration
|
||||
UDPSessionRegistrationType DatagramType = 0x0
|
||||
// UDP Session Payload
|
||||
UDPSessionPayloadType DatagramType = 0x1
|
||||
// DatagramTypeICMP (supporting both ICMPv4 and ICMPv6)
|
||||
ICMPType DatagramType = 0x2
|
||||
// UDP Session Registration Response
|
||||
UDPSessionRegistrationResponseType DatagramType = 0x3
|
||||
)
|
||||
|
||||
const (
|
||||
// Total number of bytes representing the [DatagramType]
|
||||
datagramTypeLen = 1
|
||||
|
||||
// 1280 is the default datagram packet length used before MTU discovery: https://github.com/quic-go/quic-go/blob/v0.45.0/internal/protocol/params.go#L12
|
||||
maxDatagramLen = 1280
|
||||
)
|
||||
|
||||
func parseDatagramType(data []byte) (DatagramType, error) {
|
||||
if len(data) < datagramTypeLen {
|
||||
return 0, ErrDatagramHeaderTooSmall
|
||||
}
|
||||
return DatagramType(data[0]), nil
|
||||
}
|
||||
|
||||
// UDPSessionRegistrationDatagram handles a request to initialize a UDP session on the remote client.
|
||||
type UDPSessionRegistrationDatagram struct {
|
||||
RequestID RequestID
|
||||
Dest netip.AddrPort
|
||||
Traced bool
|
||||
IdleDurationHint time.Duration
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
const (
|
||||
sessionRegistrationFlagsIPMask byte = 0b0000_0001
|
||||
sessionRegistrationFlagsTracedMask byte = 0b0000_0010
|
||||
sessionRegistrationFlagsBundledMask byte = 0b0000_0100
|
||||
|
||||
sessionRegistrationIPv4DatagramHeaderLen = datagramTypeLen +
|
||||
1 + // Flag length
|
||||
2 + // Destination port length
|
||||
2 + // Idle duration seconds length
|
||||
datagramRequestIdLen + // Request ID length
|
||||
4 // IPv4 address length
|
||||
|
||||
// The IPv4 and IPv6 address share space, so adding 12 to the header length gets the space taken by the IPv6 field.
|
||||
sessionRegistrationIPv6DatagramHeaderLen = sessionRegistrationIPv4DatagramHeaderLen + 12
|
||||
)
|
||||
|
||||
// The datagram structure for UDPSessionRegistrationDatagram is:
|
||||
//
|
||||
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// 0| Type | Flags | Destination Port |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// 4| Idle Duration Seconds | |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +
|
||||
// 8| |
|
||||
// + Session Identifier +
|
||||
// 12| (16 Bytes) |
|
||||
// + +
|
||||
// 16| |
|
||||
// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// 20| | Destination IPv4 Address |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- - - - - - - - - - - - - - - -+
|
||||
// 24| Destination IPv4 Address cont | |
|
||||
// +- - - - - - - - - - - - - - - - +
|
||||
// 28| Destination IPv6 Address |
|
||||
// + (extension of IPv4 region) +
|
||||
// 32| |
|
||||
// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// 36| | |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +
|
||||
// . .
|
||||
// . Bundle Payload .
|
||||
// . .
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|
||||
func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error) {
|
||||
ipv6 := s.Dest.Addr().Is6()
|
||||
var flags byte
|
||||
if s.Traced {
|
||||
flags |= sessionRegistrationFlagsTracedMask
|
||||
}
|
||||
hasPayload := len(s.Payload) > 0
|
||||
if hasPayload {
|
||||
flags |= sessionRegistrationFlagsBundledMask
|
||||
}
|
||||
var maxPayloadLen int
|
||||
if ipv6 {
|
||||
maxPayloadLen = maxDatagramLen - sessionRegistrationIPv6DatagramHeaderLen
|
||||
flags |= sessionRegistrationFlagsIPMask
|
||||
} else {
|
||||
maxPayloadLen = maxDatagramLen - sessionRegistrationIPv4DatagramHeaderLen
|
||||
}
|
||||
// Make sure that the payload being bundled can actually fit in the payload destination
|
||||
if len(s.Payload) > maxPayloadLen {
|
||||
return nil, wrapMarshalErr(ErrDatagramPayloadTooLarge)
|
||||
}
|
||||
// Allocate the buffer with the right size for the destination IP family
|
||||
if ipv6 {
|
||||
data = make([]byte, sessionRegistrationIPv6DatagramHeaderLen+len(s.Payload))
|
||||
} else {
|
||||
data = make([]byte, sessionRegistrationIPv4DatagramHeaderLen+len(s.Payload))
|
||||
}
|
||||
data[0] = byte(UDPSessionRegistrationType)
|
||||
data[1] = byte(flags)
|
||||
binary.BigEndian.PutUint16(data[2:4], s.Dest.Port())
|
||||
binary.BigEndian.PutUint16(data[4:6], uint16(s.IdleDurationHint.Seconds()))
|
||||
err = s.RequestID.MarshalBinaryTo(data[6:22])
|
||||
if err != nil {
|
||||
return nil, wrapMarshalErr(err)
|
||||
}
|
||||
var end int
|
||||
if ipv6 {
|
||||
copy(data[22:38], s.Dest.Addr().AsSlice())
|
||||
end = 38
|
||||
} else {
|
||||
copy(data[22:26], s.Dest.Addr().AsSlice())
|
||||
end = 26
|
||||
}
|
||||
|
||||
if hasPayload {
|
||||
copy(data[end:], s.Payload)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *UDPSessionRegistrationDatagram) UnmarshalBinary(data []byte) error {
|
||||
datagramType, err := parseDatagramType(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if datagramType != UDPSessionRegistrationType {
|
||||
return wrapUnmarshalErr(ErrInvalidDatagramType)
|
||||
}
|
||||
|
||||
requestID, err := RequestIDFromSlice(data[6:22])
|
||||
if err != nil {
|
||||
return wrapUnmarshalErr(err)
|
||||
}
|
||||
|
||||
traced := (data[1] & sessionRegistrationFlagsTracedMask) == sessionRegistrationFlagsTracedMask
|
||||
bundled := (data[1] & sessionRegistrationFlagsBundledMask) == sessionRegistrationFlagsBundledMask
|
||||
ipv6 := (data[1] & sessionRegistrationFlagsIPMask) == sessionRegistrationFlagsIPMask
|
||||
|
||||
port := binary.BigEndian.Uint16(data[2:4])
|
||||
var datagramHeaderSize int
|
||||
var dest netip.AddrPort
|
||||
if ipv6 {
|
||||
datagramHeaderSize = sessionRegistrationIPv6DatagramHeaderLen
|
||||
dest = netip.AddrPortFrom(netip.AddrFrom16([16]byte(data[22:38])), port)
|
||||
} else {
|
||||
datagramHeaderSize = sessionRegistrationIPv4DatagramHeaderLen
|
||||
dest = netip.AddrPortFrom(netip.AddrFrom4([4]byte(data[22:26])), port)
|
||||
}
|
||||
|
||||
idle := time.Duration(binary.BigEndian.Uint16(data[4:6])) * time.Second
|
||||
|
||||
var payload []byte
|
||||
if bundled && len(data) >= datagramHeaderSize && len(data[datagramHeaderSize:]) > 0 {
|
||||
payload = data[datagramHeaderSize:]
|
||||
}
|
||||
|
||||
*s = UDPSessionRegistrationDatagram{
|
||||
RequestID: requestID,
|
||||
Dest: dest,
|
||||
Traced: traced,
|
||||
IdleDurationHint: idle,
|
||||
Payload: payload,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UDPSessionPayloadDatagram provides the payload for a session to be send to either the origin or the client.
|
||||
type UDPSessionPayloadDatagram struct {
|
||||
RequestID RequestID
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
const (
|
||||
datagramPayloadHeaderLen = datagramTypeLen + datagramRequestIdLen
|
||||
|
||||
// The maximum size that a proxied UDP payload can be in a [UDPSessionPayloadDatagram]
|
||||
maxPayloadPlusHeaderLen = maxDatagramLen - datagramPayloadHeaderLen
|
||||
)
|
||||
|
||||
// The datagram structure for UDPSessionPayloadDatagram is:
|
||||
//
|
||||
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// 0| Type | |
|
||||
// +-+-+-+-+-+-+-+-+ +
|
||||
// 4| |
|
||||
// + +
|
||||
// 8| Session Identifier |
|
||||
// + (16 Bytes) +
|
||||
// 12| |
|
||||
// + +-+-+-+-+-+-+-+-+
|
||||
// 16| | |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +
|
||||
// . .
|
||||
// . Payload .
|
||||
// . .
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|
||||
// MarshalPayloadHeaderTo provides a way to insert the Session Payload header into an already existing byte slice
|
||||
// without having to allocate and copy the payload into the destination.
|
||||
//
|
||||
// This method should be used in-place of MarshalBinary which will allocate in-place the required byte array to return.
|
||||
func MarshalPayloadHeaderTo(requestID RequestID, payload []byte) error {
|
||||
if len(payload) < 17 {
|
||||
return wrapMarshalErr(ErrDatagramPayloadHeaderTooSmall)
|
||||
}
|
||||
payload[0] = byte(UDPSessionPayloadType)
|
||||
return requestID.MarshalBinaryTo(payload[1:17])
|
||||
}
|
||||
|
||||
func (s *UDPSessionPayloadDatagram) UnmarshalBinary(data []byte) error {
|
||||
datagramType, err := parseDatagramType(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if datagramType != UDPSessionPayloadType {
|
||||
return wrapUnmarshalErr(ErrInvalidDatagramType)
|
||||
}
|
||||
|
||||
// Make sure that the slice provided is the right size to be parsed.
|
||||
if len(data) < 17 || len(data) > maxPayloadPlusHeaderLen {
|
||||
return wrapUnmarshalErr(ErrDatagramPayloadInvalidSize)
|
||||
}
|
||||
|
||||
requestID, err := RequestIDFromSlice(data[1:17])
|
||||
if err != nil {
|
||||
return wrapUnmarshalErr(err)
|
||||
}
|
||||
|
||||
*s = UDPSessionPayloadDatagram{
|
||||
RequestID: requestID,
|
||||
Payload: data[17:],
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UDPSessionRegistrationResponseDatagram is used to either return a successful registration or error to the client
|
||||
// that requested the registration of a UDP session.
|
||||
type UDPSessionRegistrationResponseDatagram struct {
|
||||
RequestID RequestID
|
||||
ResponseType SessionRegistrationResp
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
const (
|
||||
datagramRespTypeLen = 1
|
||||
datagramRespErrMsgLen = 2
|
||||
|
||||
datagramSessionRegistrationResponseLen = datagramTypeLen + datagramRespTypeLen + datagramRequestIdLen + datagramRespErrMsgLen
|
||||
|
||||
// The maximum size that an error message can be in a [UDPSessionRegistrationResponseDatagram].
|
||||
maxResponseErrorMessageLen = maxDatagramLen - datagramSessionRegistrationResponseLen
|
||||
)
|
||||
|
||||
// SessionRegistrationResp represents all of the responses that a UDP session registration response
|
||||
// can return back to the client.
|
||||
type SessionRegistrationResp byte
|
||||
|
||||
const (
|
||||
// Session was received and is ready to proxy.
|
||||
ResponseOk SessionRegistrationResp = 0x00
|
||||
// Session registration was unable to reach the requested origin destination.
|
||||
ResponseDestinationUnreachable SessionRegistrationResp = 0x01
|
||||
// Session registration was unable to bind to a local UDP socket.
|
||||
ResponseUnableToBindSocket SessionRegistrationResp = 0x02
|
||||
// Session registration failed with an unexpected error but provided a message.
|
||||
ResponseErrorWithMsg SessionRegistrationResp = 0xff
|
||||
)
|
||||
|
||||
// The datagram structure for UDPSessionRegistrationResponseDatagram is:
|
||||
//
|
||||
// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// 0| Type | Resp Type | |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +
|
||||
// 4| |
|
||||
// + Session Identifier +
|
||||
// 8| (16 Bytes) |
|
||||
// + +
|
||||
// 12| |
|
||||
// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// 16| | Error Length |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// . .
|
||||
// . .
|
||||
// . .
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|
||||
func (s *UDPSessionRegistrationResponseDatagram) MarshalBinary() (data []byte, err error) {
|
||||
if len(s.ErrorMsg) > maxResponseErrorMessageLen {
|
||||
return nil, wrapMarshalErr(ErrDatagramResponseMsgInvalidSize)
|
||||
}
|
||||
errMsgLen := uint16(len(s.ErrorMsg))
|
||||
|
||||
data = make([]byte, datagramSessionRegistrationResponseLen+errMsgLen)
|
||||
data[0] = byte(UDPSessionRegistrationResponseType)
|
||||
data[1] = byte(s.ResponseType)
|
||||
err = s.RequestID.MarshalBinaryTo(data[2:18])
|
||||
if err != nil {
|
||||
return nil, wrapMarshalErr(err)
|
||||
}
|
||||
|
||||
if errMsgLen > 0 {
|
||||
binary.BigEndian.PutUint16(data[18:20], errMsgLen)
|
||||
copy(data[20:], []byte(s.ErrorMsg))
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *UDPSessionRegistrationResponseDatagram) UnmarshalBinary(data []byte) error {
|
||||
datagramType, err := parseDatagramType(data)
|
||||
if err != nil {
|
||||
return wrapUnmarshalErr(err)
|
||||
}
|
||||
if datagramType != UDPSessionRegistrationResponseType {
|
||||
return wrapUnmarshalErr(ErrInvalidDatagramType)
|
||||
}
|
||||
|
||||
if len(data) < datagramSessionRegistrationResponseLen {
|
||||
return wrapUnmarshalErr(ErrDatagramResponseInvalidSize)
|
||||
}
|
||||
|
||||
respType := SessionRegistrationResp(data[1])
|
||||
|
||||
requestID, err := RequestIDFromSlice(data[2:18])
|
||||
if err != nil {
|
||||
return wrapUnmarshalErr(err)
|
||||
}
|
||||
|
||||
errMsgLen := binary.BigEndian.Uint16(data[18:20])
|
||||
if errMsgLen > maxResponseErrorMessageLen {
|
||||
return wrapUnmarshalErr(ErrDatagramResponseMsgTooLargeMaximum)
|
||||
}
|
||||
|
||||
if len(data[20:]) < int(errMsgLen) {
|
||||
return wrapUnmarshalErr(ErrDatagramResponseMsgTooLargeDatagram)
|
||||
}
|
||||
|
||||
var errMsg string
|
||||
if errMsgLen > 0 {
|
||||
errMsg = string(data[20:])
|
||||
}
|
||||
|
||||
*s = UDPSessionRegistrationResponseDatagram{
|
||||
RequestID: requestID,
|
||||
ResponseType: respType,
|
||||
ErrorMsg: errMsg,
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package v3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidDatagramType error = errors.New("invalid datagram type expected")
|
||||
ErrDatagramHeaderTooSmall error = fmt.Errorf("datagram should have at least %d bytes", datagramTypeLen)
|
||||
ErrDatagramPayloadTooLarge error = errors.New("payload length is too large to be bundled in datagram")
|
||||
ErrDatagramPayloadHeaderTooSmall error = errors.New("payload length is too small to fit the datagram header")
|
||||
ErrDatagramPayloadInvalidSize error = errors.New("datagram provided is an invalid size")
|
||||
ErrDatagramResponseMsgInvalidSize error = errors.New("datagram response message is an invalid size")
|
||||
ErrDatagramResponseInvalidSize error = errors.New("datagram response is an invalid size")
|
||||
ErrDatagramResponseMsgTooLargeMaximum error = fmt.Errorf("datagram response error message length exceeds the length of the datagram maximum: %d", maxResponseErrorMessageLen)
|
||||
ErrDatagramResponseMsgTooLargeDatagram error = fmt.Errorf("datagram response error message length exceeds the length of the provided datagram")
|
||||
)
|
||||
|
||||
func wrapMarshalErr(err error) error {
|
||||
return fmt.Errorf("datagram marshal error: %w", err)
|
||||
}
|
||||
|
||||
func wrapUnmarshalErr(err error) error {
|
||||
return fmt.Errorf("datagram unmarshal error: %w", err)
|
||||
}
|
|
@ -0,0 +1,352 @@
|
|||
package v3_test
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||
)
|
||||
|
||||
func makePayload(size int) []byte {
|
||||
payload := make([]byte, size)
|
||||
for i := range len(payload) {
|
||||
payload[i] = 0xfc
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func TestSessionRegistration_MarshalUnmarshal(t *testing.T) {
|
||||
payload := makePayload(1254)
|
||||
tests := []*v3.UDPSessionRegistrationDatagram{
|
||||
// Default (IPv4)
|
||||
{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 5 * time.Second,
|
||||
Payload: nil,
|
||||
},
|
||||
// Request ID (max)
|
||||
{
|
||||
RequestID: mustRequestID([16]byte{
|
||||
^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
|
||||
^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
|
||||
^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
|
||||
^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
|
||||
}),
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 5 * time.Second,
|
||||
Payload: nil,
|
||||
},
|
||||
// IPv6
|
||||
{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("[fc00::0]:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 5 * time.Second,
|
||||
Payload: nil,
|
||||
},
|
||||
// Traced
|
||||
{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: true,
|
||||
IdleDurationHint: 5 * time.Second,
|
||||
Payload: nil,
|
||||
},
|
||||
// IdleDurationHint (max)
|
||||
{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 65535 * time.Second,
|
||||
Payload: nil,
|
||||
},
|
||||
// Payload
|
||||
{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 5 * time.Second,
|
||||
Payload: []byte{0xff, 0xaa, 0xcc, 0x44},
|
||||
},
|
||||
// Payload (max: 1254) for IPv4
|
||||
{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 5 * time.Second,
|
||||
Payload: payload,
|
||||
},
|
||||
// Payload (max: 1242) for IPv4
|
||||
{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 5 * time.Second,
|
||||
Payload: payload[:1242],
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
marshaled, err := tt.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
unmarshaled := v3.UDPSessionRegistrationDatagram{}
|
||||
err = unmarshaled.UnmarshalBinary(marshaled)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !compareRegistrationDatagrams(t, tt, &unmarshaled) {
|
||||
t.Errorf("not equal:\n%+v\n%+v", tt, &unmarshaled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionRegistration_MarshalBinary(t *testing.T) {
|
||||
t.Run("idle hint too large", func(t *testing.T) {
|
||||
// idle hint duration overflows back to 1
|
||||
datagram := &v3.UDPSessionRegistrationDatagram{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 65537 * time.Second,
|
||||
Payload: nil,
|
||||
}
|
||||
expected := &v3.UDPSessionRegistrationDatagram{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("1.1.1.1:8080"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 1 * time.Second,
|
||||
Payload: nil,
|
||||
}
|
||||
marshaled, err := datagram.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
unmarshaled := v3.UDPSessionRegistrationDatagram{}
|
||||
err = unmarshaled.UnmarshalBinary(marshaled)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !compareRegistrationDatagrams(t, expected, &unmarshaled) {
|
||||
t.Errorf("not equal:\n%+v\n%+v", expected, &unmarshaled)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTypeUnmarshalErrors(t *testing.T) {
|
||||
t.Run("invalid length", func(t *testing.T) {
|
||||
d1 := v3.UDPSessionRegistrationDatagram{}
|
||||
err := d1.UnmarshalBinary([]byte{})
|
||||
if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) {
|
||||
t.Errorf("expected invalid length to throw error")
|
||||
}
|
||||
|
||||
d2 := v3.UDPSessionPayloadDatagram{}
|
||||
err = d2.UnmarshalBinary([]byte{})
|
||||
if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) {
|
||||
t.Errorf("expected invalid length to throw error")
|
||||
}
|
||||
|
||||
d3 := v3.UDPSessionRegistrationResponseDatagram{}
|
||||
err = d3.UnmarshalBinary([]byte{})
|
||||
if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) {
|
||||
t.Errorf("expected invalid length to throw error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid types", func(t *testing.T) {
|
||||
d1 := v3.UDPSessionRegistrationDatagram{}
|
||||
err := d1.UnmarshalBinary([]byte{byte(v3.UDPSessionRegistrationResponseType)})
|
||||
if !errors.Is(err, v3.ErrInvalidDatagramType) {
|
||||
t.Errorf("expected invalid type to throw error")
|
||||
}
|
||||
|
||||
d2 := v3.UDPSessionPayloadDatagram{}
|
||||
err = d2.UnmarshalBinary([]byte{byte(v3.UDPSessionRegistrationType)})
|
||||
if !errors.Is(err, v3.ErrInvalidDatagramType) {
|
||||
t.Errorf("expected invalid type to throw error")
|
||||
}
|
||||
|
||||
d3 := v3.UDPSessionRegistrationResponseDatagram{}
|
||||
err = d3.UnmarshalBinary([]byte{byte(v3.UDPSessionPayloadType)})
|
||||
if !errors.Is(err, v3.ErrInvalidDatagramType) {
|
||||
t.Errorf("expected invalid type to throw error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionPayload(t *testing.T) {
|
||||
t.Run("basic", func(t *testing.T) {
|
||||
payload := makePayload(128)
|
||||
err := v3.MarshalPayloadHeaderTo(testRequestID, payload[0:17])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
unmarshaled := v3.UDPSessionPayloadDatagram{}
|
||||
err = unmarshaled.UnmarshalBinary(payload)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.Equal(t, testRequestID, unmarshaled.RequestID)
|
||||
require.Equal(t, payload[17:], unmarshaled.Payload)
|
||||
})
|
||||
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
payload := makePayload(17)
|
||||
err := v3.MarshalPayloadHeaderTo(testRequestID, payload)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
unmarshaled := v3.UDPSessionPayloadDatagram{}
|
||||
err = unmarshaled.UnmarshalBinary(payload)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.Equal(t, testRequestID, unmarshaled.RequestID)
|
||||
require.Equal(t, payload[17:], unmarshaled.Payload)
|
||||
})
|
||||
|
||||
t.Run("header size too small", func(t *testing.T) {
|
||||
payload := makePayload(16)
|
||||
err := v3.MarshalPayloadHeaderTo(testRequestID, payload)
|
||||
if !errors.Is(err, v3.ErrDatagramPayloadHeaderTooSmall) {
|
||||
t.Errorf("expected an error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("payload size too small", func(t *testing.T) {
|
||||
payload := makePayload(17)
|
||||
err := v3.MarshalPayloadHeaderTo(testRequestID, payload)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
unmarshaled := v3.UDPSessionPayloadDatagram{}
|
||||
err = unmarshaled.UnmarshalBinary(payload[:16])
|
||||
if !errors.Is(err, v3.ErrDatagramPayloadInvalidSize) {
|
||||
t.Errorf("expected an error: %s", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("payload size too large", func(t *testing.T) {
|
||||
datagram := makePayload(17 + 1264) // 1263 is the largest payload size allowed
|
||||
err := v3.MarshalPayloadHeaderTo(testRequestID, datagram)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
unmarshaled := v3.UDPSessionPayloadDatagram{}
|
||||
err = unmarshaled.UnmarshalBinary(datagram[:])
|
||||
if !errors.Is(err, v3.ErrDatagramPayloadInvalidSize) {
|
||||
t.Errorf("expected an error: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionRegistrationResponse(t *testing.T) {
|
||||
validRespTypes := []v3.SessionRegistrationResp{
|
||||
v3.ResponseOk,
|
||||
v3.ResponseDestinationUnreachable,
|
||||
v3.ResponseUnableToBindSocket,
|
||||
v3.ResponseErrorWithMsg,
|
||||
}
|
||||
t.Run("basic", func(t *testing.T) {
|
||||
for _, responseType := range validRespTypes {
|
||||
datagram := &v3.UDPSessionRegistrationResponseDatagram{
|
||||
RequestID: testRequestID,
|
||||
ResponseType: responseType,
|
||||
ErrorMsg: "test",
|
||||
}
|
||||
marshaled, err := datagram.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
|
||||
err = unmarshaled.UnmarshalBinary(marshaled)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.Equal(t, datagram, unmarshaled)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unsupported resp type is valid", func(t *testing.T) {
|
||||
datagram := &v3.UDPSessionRegistrationResponseDatagram{
|
||||
RequestID: testRequestID,
|
||||
ResponseType: v3.SessionRegistrationResp(0xfc),
|
||||
ErrorMsg: "",
|
||||
}
|
||||
marshaled, err := datagram.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
|
||||
err = unmarshaled.UnmarshalBinary(marshaled)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
require.Equal(t, datagram, unmarshaled)
|
||||
})
|
||||
|
||||
t.Run("too small to unmarshal", func(t *testing.T) {
|
||||
payload := makePayload(17)
|
||||
payload[0] = byte(v3.UDPSessionRegistrationResponseType)
|
||||
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
|
||||
err := unmarshaled.UnmarshalBinary(payload)
|
||||
if !errors.Is(err, v3.ErrDatagramResponseInvalidSize) {
|
||||
t.Errorf("expected an error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error message too long", func(t *testing.T) {
|
||||
message := ""
|
||||
for i := 0; i < 1280; i++ {
|
||||
message += "a"
|
||||
}
|
||||
datagram := &v3.UDPSessionRegistrationResponseDatagram{
|
||||
RequestID: testRequestID,
|
||||
ResponseType: v3.SessionRegistrationResp(0xfc),
|
||||
ErrorMsg: message,
|
||||
}
|
||||
_, err := datagram.MarshalBinary()
|
||||
if !errors.Is(err, v3.ErrDatagramResponseMsgInvalidSize) {
|
||||
t.Errorf("expected an error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error message too large to unmarshal", func(t *testing.T) {
|
||||
payload := makePayload(1280)
|
||||
payload[0] = byte(v3.UDPSessionRegistrationResponseType)
|
||||
binary.BigEndian.PutUint16(payload[18:20], 1280) // larger than the datagram size could be
|
||||
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
|
||||
err := unmarshaled.UnmarshalBinary(payload)
|
||||
if !errors.Is(err, v3.ErrDatagramResponseMsgTooLargeMaximum) {
|
||||
t.Errorf("expected an error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error message larger than provided buffer", func(t *testing.T) {
|
||||
payload := makePayload(1000)
|
||||
payload[0] = byte(v3.UDPSessionRegistrationResponseType)
|
||||
binary.BigEndian.PutUint16(payload[18:20], 1001) // larger than the datagram size provided
|
||||
unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
|
||||
err := unmarshaled.UnmarshalBinary(payload)
|
||||
if !errors.Is(err, v3.ErrDatagramResponseMsgTooLargeDatagram) {
|
||||
t.Errorf("expected an error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func compareRegistrationDatagrams(t *testing.T, l *v3.UDPSessionRegistrationDatagram, r *v3.UDPSessionRegistrationDatagram) bool {
|
||||
require.Equal(t, l.Payload, r.Payload)
|
||||
return l.RequestID == r.RequestID &&
|
||||
l.Dest == r.Dest &&
|
||||
l.IdleDurationHint == r.IdleDurationHint &&
|
||||
l.Traced == r.Traced
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
package v3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
datagramRequestIdLen = 16
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidRequestIDLen is returned when the provided request id can not be parsed from the provided byte slice.
|
||||
ErrInvalidRequestIDLen error = errors.New("invalid request id length provided")
|
||||
// ErrInvalidPayloadDestLen is returned when the provided destination byte slice cannot fit the whole request id.
|
||||
ErrInvalidPayloadDestLen error = errors.New("invalid payload size provided")
|
||||
)
|
||||
|
||||
// RequestID is the request-id-v2 identifier, it is used to distinguish between specific flows or sessions proxied
|
||||
// from the edge to cloudflared.
|
||||
type RequestID uint128
|
||||
|
||||
type uint128 struct {
|
||||
hi uint64
|
||||
lo uint64
|
||||
}
|
||||
|
||||
// RequestIDFromSlice reads a request ID from a byte slice.
|
||||
func RequestIDFromSlice(data []byte) (RequestID, error) {
|
||||
if len(data) != datagramRequestIdLen {
|
||||
return RequestID{}, ErrInvalidRequestIDLen
|
||||
}
|
||||
|
||||
return RequestID{
|
||||
hi: binary.BigEndian.Uint64(data[:8]),
|
||||
lo: binary.BigEndian.Uint64(data[8:]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compare returns an integer comparing two IPs.
|
||||
// The result will be 0 if id == id2, -1 if id < id2, and +1 if id > id2.
|
||||
// The definition of "less than" is the same as the [RequestID.Less] method.
|
||||
func (id RequestID) Compare(id2 RequestID) int {
|
||||
hi1, hi2 := id.hi, id2.hi
|
||||
if hi1 < hi2 {
|
||||
return -1
|
||||
}
|
||||
if hi1 > hi2 {
|
||||
return 1
|
||||
}
|
||||
lo1, lo2 := id.lo, id2.lo
|
||||
if lo1 < lo2 {
|
||||
return -1
|
||||
}
|
||||
if lo1 > lo2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Less reports whether id sorts before id2.
|
||||
func (id RequestID) Less(id2 RequestID) bool { return id.Compare(id2) == -1 }
|
||||
|
||||
// MarshalBinaryTo writes the id to the provided destination byte slice; the byte slice must be of at least size 16.
|
||||
func (id RequestID) MarshalBinaryTo(data []byte) error {
|
||||
if len(data) < datagramRequestIdLen {
|
||||
return ErrInvalidPayloadDestLen
|
||||
}
|
||||
binary.BigEndian.PutUint64(data[:8], id.hi)
|
||||
binary.BigEndian.PutUint64(data[8:], id.lo)
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
package v3_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testRequestIDBytes = [16]byte{
|
||||
0x00, 0x11, 0x22, 0x33,
|
||||
0x44, 0x55, 0x66, 0x77,
|
||||
0x88, 0x99, 0xaa, 0xbb,
|
||||
0xcc, 0xdd, 0xee, 0xff,
|
||||
}
|
||||
testRequestID = mustRequestID(testRequestIDBytes)
|
||||
)
|
||||
|
||||
func mustRequestID(data [16]byte) v3.RequestID {
|
||||
id, err := v3.RequestIDFromSlice(data[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func TestRequestIDParsing(t *testing.T) {
|
||||
buf1 := make([]byte, 16)
|
||||
n, err := rand.Read(buf1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 16 {
|
||||
t.Fatalf("did not read 16 bytes: %d", n)
|
||||
}
|
||||
id, err := v3.RequestIDFromSlice(buf1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
buf2 := make([]byte, 16)
|
||||
err = id.MarshalBinaryTo(buf2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !slices.Equal(buf1, buf2) {
|
||||
t.Fatalf("buf1 != buf2: %+v %+v", buf1, buf2)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue