cloudflared-mirror/quic/v3/datagram_test.go

353 lines
10 KiB
Go

package v3_test
import (
"encoding/binary"
"errors"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
v3 "github.com/cloudflare/cloudflared/quic/v3"
)
// makePayload allocates a byte slice of exactly size bytes in which every
// element holds the filler value 0xfc.
func makePayload(size int) []byte {
	buf := make([]byte, size)
	for idx := range buf {
		buf[idx] = 0xfc
	}
	return buf
}
// TestSessionRegistration_MarshalUnmarshal round-trips registration datagrams
// through MarshalBinary and UnmarshalBinary and checks that every field
// survives unchanged. Cases cover the default IPv4 form, the all-ones request
// ID, an IPv6 destination, the traced flag, the largest idle hint (65535s),
// a small payload, and the largest payload for each address family.
func TestSessionRegistration_MarshalUnmarshal(t *testing.T) {
	payload := makePayload(1254)
	tests := []*v3.UDPSessionRegistrationDatagram{
		// Default (IPv4)
		{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("1.1.1.1:8080"),
			Traced:           false,
			IdleDurationHint: 5 * time.Second,
			Payload:          nil,
		},
		// Request ID (max)
		{
			RequestID: mustRequestID([16]byte{
				^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
				^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
				^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
				^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0),
			}),
			Dest:             netip.MustParseAddrPort("1.1.1.1:8080"),
			Traced:           false,
			IdleDurationHint: 5 * time.Second,
			Payload:          nil,
		},
		// IPv6
		{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("[fc00::0]:8080"),
			Traced:           false,
			IdleDurationHint: 5 * time.Second,
			Payload:          nil,
		},
		// Traced
		{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("1.1.1.1:8080"),
			Traced:           true,
			IdleDurationHint: 5 * time.Second,
			Payload:          nil,
		},
		// IdleDurationHint (max)
		{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("1.1.1.1:8080"),
			Traced:           false,
			IdleDurationHint: 65535 * time.Second,
			Payload:          nil,
		},
		// Payload
		{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("1.1.1.1:8080"),
			Traced:           false,
			IdleDurationHint: 5 * time.Second,
			Payload:          []byte{0xff, 0xaa, 0xcc, 0x44},
		},
		// Payload (max: 1254) for IPv4
		{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("1.1.1.1:8080"),
			Traced:           false,
			IdleDurationHint: 5 * time.Second,
			Payload:          payload,
		},
		// Payload (max: 1242) for IPv6: the 16-byte address takes 12 more
		// bytes than an IPv4 address, leaving 12 fewer bytes for payload.
		// This case previously used an IPv4 destination by mistake, so the
		// IPv6 boundary was never exercised.
		{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("[fc00::0]:8080"),
			Traced:           false,
			IdleDurationHint: 5 * time.Second,
			Payload:          payload[:1242],
		},
	}
	for _, tt := range tests {
		marshaled, err := tt.MarshalBinary()
		// Stop on failure: unmarshalling a nil buffer would only add noise.
		require.NoError(t, err)
		unmarshaled := v3.UDPSessionRegistrationDatagram{}
		require.NoError(t, unmarshaled.UnmarshalBinary(marshaled))
		if !compareRegistrationDatagrams(t, tt, &unmarshaled) {
			t.Errorf("not equal:\n%+v\n%+v", tt, &unmarshaled)
		}
	}
}
// TestSessionRegistration_MarshalBinary covers marshaling edge cases for
// registration datagrams that are not simple round-trips.
func TestSessionRegistration_MarshalBinary(t *testing.T) {
	t.Run("idle hint too large", func(t *testing.T) {
		// The idle hint is encoded in 16 bits of seconds, so 65537s wraps
		// around to 1s after a round-trip (65537 mod 65536 == 1).
		datagram := &v3.UDPSessionRegistrationDatagram{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("1.1.1.1:8080"),
			Traced:           false,
			IdleDurationHint: 65537 * time.Second,
			Payload:          nil,
		}
		expected := &v3.UDPSessionRegistrationDatagram{
			RequestID:        testRequestID,
			Dest:             netip.MustParseAddrPort("1.1.1.1:8080"),
			Traced:           false,
			IdleDurationHint: 1 * time.Second,
			Payload:          nil,
		}
		marshaled, err := datagram.MarshalBinary()
		// Fatal on error: continuing would unmarshal a nil buffer.
		require.NoError(t, err)
		unmarshaled := v3.UDPSessionRegistrationDatagram{}
		require.NoError(t, unmarshaled.UnmarshalBinary(marshaled))
		if !compareRegistrationDatagrams(t, expected, &unmarshaled) {
			t.Errorf("not equal:\n%+v\n%+v", expected, &unmarshaled)
		}
	})
}
// TestTypeUnmarshalErrors checks the common pre-checks of every datagram
// decoder. An empty buffer must fail with ErrDatagramHeaderTooSmall. A buffer
// whose first byte names a different datagram type must fail with
// ErrInvalidDatagramType.
func TestTypeUnmarshalErrors(t *testing.T) {
	// unmarshaler is the method shared by all datagram types under test.
	type unmarshaler interface {
		UnmarshalBinary(data []byte) error
	}

	t.Run("invalid length", func(t *testing.T) {
		decoders := []unmarshaler{
			&v3.UDPSessionRegistrationDatagram{},
			&v3.UDPSessionPayloadDatagram{},
			&v3.UDPSessionRegistrationResponseDatagram{},
		}
		for _, dec := range decoders {
			if got := dec.UnmarshalBinary([]byte{}); !errors.Is(got, v3.ErrDatagramHeaderTooSmall) {
				t.Errorf("expected invalid length to throw error")
			}
		}
	})

	t.Run("invalid types", func(t *testing.T) {
		mismatches := []struct {
			dec       unmarshaler
			wrongType byte
		}{
			{&v3.UDPSessionRegistrationDatagram{}, byte(v3.UDPSessionRegistrationResponseType)},
			{&v3.UDPSessionPayloadDatagram{}, byte(v3.UDPSessionRegistrationType)},
			{&v3.UDPSessionRegistrationResponseDatagram{}, byte(v3.UDPSessionPayloadType)},
		}
		for _, m := range mismatches {
			if got := m.dec.UnmarshalBinary([]byte{m.wrongType}); !errors.Is(got, v3.ErrInvalidDatagramType) {
				t.Errorf("expected invalid type to throw error")
			}
		}
	})
}
// TestSessionPayload covers payload datagrams. It writes a header with
// MarshalPayloadHeaderTo into the front of a buffer, decodes the buffer with
// UDPSessionPayloadDatagram.UnmarshalBinary, and checks the size limits on
// both sides.
func TestSessionPayload(t *testing.T) {
	// Size of the payload datagram header: one type byte plus the 16-byte
	// request ID.
	const headerLen = 17

	t.Run("basic", func(t *testing.T) {
		payload := makePayload(128)
		require.NoError(t, v3.MarshalPayloadHeaderTo(testRequestID, payload[0:headerLen]))
		unmarshaled := v3.UDPSessionPayloadDatagram{}
		require.NoError(t, unmarshaled.UnmarshalBinary(payload))
		require.Equal(t, testRequestID, unmarshaled.RequestID)
		require.Equal(t, payload[headerLen:], unmarshaled.Payload)
	})
	t.Run("empty", func(t *testing.T) {
		// A header-only datagram is valid and carries an empty payload.
		payload := makePayload(headerLen)
		require.NoError(t, v3.MarshalPayloadHeaderTo(testRequestID, payload))
		unmarshaled := v3.UDPSessionPayloadDatagram{}
		require.NoError(t, unmarshaled.UnmarshalBinary(payload))
		require.Equal(t, testRequestID, unmarshaled.RequestID)
		require.Equal(t, payload[headerLen:], unmarshaled.Payload)
	})
	t.Run("header size too small", func(t *testing.T) {
		payload := makePayload(headerLen - 1)
		err := v3.MarshalPayloadHeaderTo(testRequestID, payload)
		if !errors.Is(err, v3.ErrDatagramPayloadHeaderTooSmall) {
			t.Errorf("expected an error: %v", err)
		}
	})
	t.Run("payload size too small", func(t *testing.T) {
		payload := makePayload(headerLen)
		require.NoError(t, v3.MarshalPayloadHeaderTo(testRequestID, payload))
		unmarshaled := v3.UDPSessionPayloadDatagram{}
		// Truncate one byte into the header.
		err := unmarshaled.UnmarshalBinary(payload[:headerLen-1])
		if !errors.Is(err, v3.ErrDatagramPayloadInvalidSize) {
			t.Errorf("expected an error: %v", err)
		}
	})
	t.Run("payload size too large", func(t *testing.T) {
		datagram := makePayload(headerLen + 1264) // 1263 is the largest payload size allowed
		require.NoError(t, v3.MarshalPayloadHeaderTo(testRequestID, datagram))
		unmarshaled := v3.UDPSessionPayloadDatagram{}
		err := unmarshaled.UnmarshalBinary(datagram)
		if !errors.Is(err, v3.ErrDatagramPayloadInvalidSize) {
			t.Errorf("expected an error: %v", err)
		}
	})
}
// TestSessionRegistrationResponse exercises registration response datagrams.
// It round-trips every known response type and an unrecognized one, then
// checks the size validation applied to the datagram and to its error
// message when marshaling and when unmarshaling.
func TestSessionRegistrationResponse(t *testing.T) {
	validRespTypes := []v3.SessionRegistrationResp{
		v3.ResponseOk,
		v3.ResponseDestinationUnreachable,
		v3.ResponseUnableToBindSocket,
		v3.ResponseErrorWithMsg,
	}
	t.Run("basic", func(t *testing.T) {
		for _, responseType := range validRespTypes {
			datagram := &v3.UDPSessionRegistrationResponseDatagram{
				RequestID:    testRequestID,
				ResponseType: responseType,
				ErrorMsg:     "test",
			}
			marshaled, err := datagram.MarshalBinary()
			require.NoError(t, err)
			unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
			require.NoError(t, unmarshaled.UnmarshalBinary(marshaled))
			require.Equal(t, datagram, unmarshaled)
		}
	})
	t.Run("unsupported resp type is valid", func(t *testing.T) {
		// Unknown response types are not rejected by the codec; they must
		// round-trip unchanged.
		datagram := &v3.UDPSessionRegistrationResponseDatagram{
			RequestID:    testRequestID,
			ResponseType: v3.SessionRegistrationResp(0xfc),
			ErrorMsg:     "",
		}
		marshaled, err := datagram.MarshalBinary()
		require.NoError(t, err)
		unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
		require.NoError(t, unmarshaled.UnmarshalBinary(marshaled))
		require.Equal(t, datagram, unmarshaled)
	})
	t.Run("too small to unmarshal", func(t *testing.T) {
		// 17 bytes with the correct type byte: shorter than the header, which
		// has a message-length field at offset 18.
		payload := makePayload(17)
		payload[0] = byte(v3.UDPSessionRegistrationResponseType)
		unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
		err := unmarshaled.UnmarshalBinary(payload)
		if !errors.Is(err, v3.ErrDatagramResponseInvalidSize) {
			t.Errorf("expected an error: %v", err)
		}
	})
	t.Run("error message too long", func(t *testing.T) {
		// Build the 1280-byte message in a single allocation instead of
		// 1280 quadratic string concatenations.
		msg := make([]byte, 1280)
		for i := range msg {
			msg[i] = 'a'
		}
		datagram := &v3.UDPSessionRegistrationResponseDatagram{
			RequestID:    testRequestID,
			ResponseType: v3.SessionRegistrationResp(0xfc),
			ErrorMsg:     string(msg),
		}
		_, err := datagram.MarshalBinary()
		if !errors.Is(err, v3.ErrDatagramResponseMsgInvalidSize) {
			t.Errorf("expected an error: %v", err)
		}
	})
	t.Run("error message too large to unmarshal", func(t *testing.T) {
		payload := makePayload(1280)
		payload[0] = byte(v3.UDPSessionRegistrationResponseType)
		binary.BigEndian.PutUint16(payload[18:20], 1280) // larger than the datagram size could be
		unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
		err := unmarshaled.UnmarshalBinary(payload)
		if !errors.Is(err, v3.ErrDatagramResponseMsgTooLargeMaximum) {
			t.Errorf("expected an error: %v", err)
		}
	})
	t.Run("error message larger than provided buffer", func(t *testing.T) {
		payload := makePayload(1000)
		payload[0] = byte(v3.UDPSessionRegistrationResponseType)
		binary.BigEndian.PutUint16(payload[18:20], 1001) // larger than the datagram size provided
		unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{}
		err := unmarshaled.UnmarshalBinary(payload)
		if !errors.Is(err, v3.ErrDatagramResponseMsgTooLargeDatagram) {
			t.Errorf("expected an error: %v", err)
		}
	})
}
// compareRegistrationDatagrams reports whether l and r have the same
// RequestID, Dest, IdleDurationHint and Traced values. Payloads are checked
// with require.Equal, so a payload mismatch fails the test immediately rather
// than returning false.
func compareRegistrationDatagrams(t *testing.T, l *v3.UDPSessionRegistrationDatagram, r *v3.UDPSessionRegistrationDatagram) bool {
	// Mark as a helper so a failing require.Equal is attributed to the
	// caller's line, not this one.
	t.Helper()
	require.Equal(t, l.Payload, r.Payload)
	return l.RequestID == r.RequestID &&
		l.Dest == r.Dest &&
		l.IdleDurationHint == r.IdleDurationHint &&
		l.Traced == r.Traced
}