diff --git a/quic/v3/datagram.go b/quic/v3/datagram.go new file mode 100644 index 00000000..d5c2ac1b --- /dev/null +++ b/quic/v3/datagram.go @@ -0,0 +1,372 @@ +package v3 + +import ( + "encoding/binary" + "net/netip" + "time" +) + +type DatagramType byte + +const ( + // UDP Registration + UDPSessionRegistrationType DatagramType = 0x0 + // UDP Session Payload + UDPSessionPayloadType DatagramType = 0x1 + // DatagramTypeICMP (supporting both ICMPv4 and ICMPv6) + ICMPType DatagramType = 0x2 + // UDP Session Registration Response + UDPSessionRegistrationResponseType DatagramType = 0x3 +) + +const ( + // Total number of bytes representing the [DatagramType] + datagramTypeLen = 1 + + // 1280 is the default datagram packet length used before MTU discovery: https://github.com/quic-go/quic-go/blob/v0.45.0/internal/protocol/params.go#L12 + maxDatagramLen = 1280 +) + +func parseDatagramType(data []byte) (DatagramType, error) { + if len(data) < datagramTypeLen { + return 0, ErrDatagramHeaderTooSmall + } + return DatagramType(data[0]), nil +} + +// UDPSessionRegistrationDatagram handles a request to initialize a UDP session on the remote client. +type UDPSessionRegistrationDatagram struct { + RequestID RequestID + Dest netip.AddrPort + Traced bool + IdleDurationHint time.Duration + Payload []byte +} + +const ( + sessionRegistrationFlagsIPMask byte = 0b0000_0001 + sessionRegistrationFlagsTracedMask byte = 0b0000_0010 + sessionRegistrationFlagsBundledMask byte = 0b0000_0100 + + sessionRegistrationIPv4DatagramHeaderLen = datagramTypeLen + + 1 + // Flag length + 2 + // Destination port length + 2 + // Idle duration seconds length + datagramRequestIdLen + // Request ID length + 4 // IPv4 address length + + // The IPv4 and IPv6 address share space, so adding 12 to the header length gets the space taken by the IPv6 field. + sessionRegistrationIPv6DatagramHeaderLen = sessionRegistrationIPv4DatagramHeaderLen + 12 +) + +// The datagram structure for UDPSessionRegistrationDatagram is: +// +// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 0| Type | Flags | Destination Port | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 4| Idle Duration Seconds | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +// 8| | +// + Session Identifier + +// 12| (16 Bytes) | +// + + +// 16| | +// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 20| | Destination IPv4 Address | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- - - - - - - - - - - - - - - -+ +// 24| Destination IPv4 Address cont | | +// +- - - - - - - - - - - - - - - - + +// 28| Destination IPv6 Address | +// + (extension of IPv4 region) + +// 32| | +// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 36| | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +// . . +// . Bundle Payload . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error) { + ipv6 := s.Dest.Addr().Is6() + var flags byte + if s.Traced { + flags |= sessionRegistrationFlagsTracedMask + } + hasPayload := len(s.Payload) > 0 + if hasPayload { + flags |= sessionRegistrationFlagsBundledMask + } + var maxPayloadLen int + if ipv6 { + maxPayloadLen = maxDatagramLen - sessionRegistrationIPv6DatagramHeaderLen + flags |= sessionRegistrationFlagsIPMask + } else { + maxPayloadLen = maxDatagramLen - sessionRegistrationIPv4DatagramHeaderLen + } + // Make sure that the payload being bundled can actually fit in the payload destination + if len(s.Payload) > maxPayloadLen { + return nil, wrapMarshalErr(ErrDatagramPayloadTooLarge) + } + // Allocate the buffer with the right size for the destination IP family + if ipv6 { + data = make([]byte, sessionRegistrationIPv6DatagramHeaderLen+len(s.Payload)) + } else { + data = make([]byte, sessionRegistrationIPv4DatagramHeaderLen+len(s.Payload)) + } + data[0] = byte(UDPSessionRegistrationType) + data[1] = byte(flags) + binary.BigEndian.PutUint16(data[2:4], s.Dest.Port()) + binary.BigEndian.PutUint16(data[4:6], uint16(s.IdleDurationHint.Seconds())) + err = s.RequestID.MarshalBinaryTo(data[6:22]) + if err != nil { + return nil, wrapMarshalErr(err) + } + var end int + if ipv6 { + copy(data[22:38], s.Dest.Addr().AsSlice()) + end = 38 + } else { + copy(data[22:26], s.Dest.Addr().AsSlice()) + end = 26 + } + + if hasPayload { + copy(data[end:], s.Payload) + } + + return data, nil +} + +func (s *UDPSessionRegistrationDatagram) UnmarshalBinary(data []byte) error { + datagramType, err := parseDatagramType(data) + if err != nil { + return err + } + if datagramType != UDPSessionRegistrationType { + return wrapUnmarshalErr(ErrInvalidDatagramType) + } + + requestID, err := RequestIDFromSlice(data[6:22]) + if err != nil { + return wrapUnmarshalErr(err) + } + + traced := (data[1] & sessionRegistrationFlagsTracedMask) == sessionRegistrationFlagsTracedMask + bundled := (data[1] & sessionRegistrationFlagsBundledMask) == sessionRegistrationFlagsBundledMask + ipv6 := (data[1] & sessionRegistrationFlagsIPMask) == sessionRegistrationFlagsIPMask + + port := binary.BigEndian.Uint16(data[2:4]) + var datagramHeaderSize int + var dest netip.AddrPort + if ipv6 { + datagramHeaderSize = sessionRegistrationIPv6DatagramHeaderLen + dest = netip.AddrPortFrom(netip.AddrFrom16([16]byte(data[22:38])), port) + } else { + datagramHeaderSize = sessionRegistrationIPv4DatagramHeaderLen + dest = netip.AddrPortFrom(netip.AddrFrom4([4]byte(data[22:26])), port) + } + + idle := time.Duration(binary.BigEndian.Uint16(data[4:6])) * time.Second + + var payload []byte + if bundled && len(data) >= datagramHeaderSize && len(data[datagramHeaderSize:]) > 0 { + payload = data[datagramHeaderSize:] + } + + *s = UDPSessionRegistrationDatagram{ + RequestID: requestID, + Dest: dest, + Traced: traced, + IdleDurationHint: idle, + Payload: payload, + } + return nil +} + +// UDPSessionPayloadDatagram provides the payload for a session to be send to either the origin or the client. +type UDPSessionPayloadDatagram struct { + RequestID RequestID + Payload []byte +} + +const ( + datagramPayloadHeaderLen = datagramTypeLen + datagramRequestIdLen + + // The maximum size that a proxied UDP payload can be in a [UDPSessionPayloadDatagram] + maxPayloadPlusHeaderLen = maxDatagramLen - datagramPayloadHeaderLen +) + +// The datagram structure for UDPSessionPayloadDatagram is: +// +// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 0| Type | | +// +-+-+-+-+-+-+-+-+ + +// 4| | +// + + +// 8| Session Identifier | +// + (16 Bytes) + +// 12| | +// + +-+-+-+-+-+-+-+-+ +// 16| | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +// . . +// . Payload . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +// MarshalPayloadHeaderTo provides a way to insert the Session Payload header into an already existing byte slice +// without having to allocate and copy the payload into the destination. +// +// This method should be used in-place of MarshalBinary which will allocate in-place the required byte array to return. +func MarshalPayloadHeaderTo(requestID RequestID, payload []byte) error { + if len(payload) < 17 { + return wrapMarshalErr(ErrDatagramPayloadHeaderTooSmall) + } + payload[0] = byte(UDPSessionPayloadType) + return requestID.MarshalBinaryTo(payload[1:17]) +} + +func (s *UDPSessionPayloadDatagram) UnmarshalBinary(data []byte) error { + datagramType, err := parseDatagramType(data) + if err != nil { + return err + } + if datagramType != UDPSessionPayloadType { + return wrapUnmarshalErr(ErrInvalidDatagramType) + } + + // Make sure that the slice provided is the right size to be parsed. + if len(data) < 17 || len(data) > maxPayloadPlusHeaderLen { + return wrapUnmarshalErr(ErrDatagramPayloadInvalidSize) + } + + requestID, err := RequestIDFromSlice(data[1:17]) + if err != nil { + return wrapUnmarshalErr(err) + } + + *s = UDPSessionPayloadDatagram{ + RequestID: requestID, + Payload: data[17:], + } + return nil +} + +// UDPSessionRegistrationResponseDatagram is used to either return a successful registration or error to the client +// that requested the registration of a UDP session. +type UDPSessionRegistrationResponseDatagram struct { + RequestID RequestID + ResponseType SessionRegistrationResp + ErrorMsg string +} + +const ( + datagramRespTypeLen = 1 + datagramRespErrMsgLen = 2 + + datagramSessionRegistrationResponseLen = datagramTypeLen + datagramRespTypeLen + datagramRequestIdLen + datagramRespErrMsgLen + + // The maximum size that an error message can be in a [UDPSessionRegistrationResponseDatagram]. + maxResponseErrorMessageLen = maxDatagramLen - datagramSessionRegistrationResponseLen +) + +// SessionRegistrationResp represents all of the responses that a UDP session registration response +// can return back to the client. +type SessionRegistrationResp byte + +const ( + // Session was received and is ready to proxy. + ResponseOk SessionRegistrationResp = 0x00 + // Session registration was unable to reach the requested origin destination. + ResponseDestinationUnreachable SessionRegistrationResp = 0x01 + // Session registration was unable to bind to a local UDP socket. + ResponseUnableToBindSocket SessionRegistrationResp = 0x02 + // Session registration failed with an unexpected error but provided a message. + ResponseErrorWithMsg SessionRegistrationResp = 0xff +) + +// The datagram structure for UDPSessionRegistrationResponseDatagram is: +// +// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 0| Type | Resp Type | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +// 4| | +// + Session Identifier + +// 8| (16 Bytes) | +// + + +// 12| | +// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 16| | Error Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// . . +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +func (s *UDPSessionRegistrationResponseDatagram) MarshalBinary() (data []byte, err error) { + if len(s.ErrorMsg) > maxResponseErrorMessageLen { + return nil, wrapMarshalErr(ErrDatagramResponseMsgInvalidSize) + } + errMsgLen := uint16(len(s.ErrorMsg)) + + data = make([]byte, datagramSessionRegistrationResponseLen+errMsgLen) + data[0] = byte(UDPSessionRegistrationResponseType) + data[1] = byte(s.ResponseType) + err = s.RequestID.MarshalBinaryTo(data[2:18]) + if err != nil { + return nil, wrapMarshalErr(err) + } + + if errMsgLen > 0 { + binary.BigEndian.PutUint16(data[18:20], errMsgLen) + copy(data[20:], []byte(s.ErrorMsg)) + } + + return data, nil +} + +func (s *UDPSessionRegistrationResponseDatagram) UnmarshalBinary(data []byte) error { + datagramType, err := parseDatagramType(data) + if err != nil { + return wrapUnmarshalErr(err) + } + if datagramType != UDPSessionRegistrationResponseType { + return wrapUnmarshalErr(ErrInvalidDatagramType) + } + + if len(data) < datagramSessionRegistrationResponseLen { + return wrapUnmarshalErr(ErrDatagramResponseInvalidSize) + } + + respType := SessionRegistrationResp(data[1]) + + requestID, err := RequestIDFromSlice(data[2:18]) + if err != nil { + return wrapUnmarshalErr(err) + } + + errMsgLen := binary.BigEndian.Uint16(data[18:20]) + if errMsgLen > maxResponseErrorMessageLen { + return wrapUnmarshalErr(ErrDatagramResponseMsgTooLargeMaximum) + } + + if len(data[20:]) < int(errMsgLen) { + return wrapUnmarshalErr(ErrDatagramResponseMsgTooLargeDatagram) + } + + var errMsg string + if errMsgLen > 0 { + errMsg = string(data[20:]) + } + + *s = UDPSessionRegistrationResponseDatagram{ + RequestID: requestID, + ResponseType: respType, + ErrorMsg: errMsg, + } + return nil +} diff --git a/quic/v3/datagram_errors.go b/quic/v3/datagram_errors.go new file mode 100644 index 00000000..244915db --- /dev/null +++ b/quic/v3/datagram_errors.go @@ -0,0 +1,26 @@ +package v3 + +import ( + "errors" + "fmt" +) + +var ( + ErrInvalidDatagramType error = errors.New("invalid datagram type expected") + ErrDatagramHeaderTooSmall error = fmt.Errorf("datagram should have at least %d bytes", datagramTypeLen) + ErrDatagramPayloadTooLarge error = errors.New("payload length is too large to be bundled in datagram") + ErrDatagramPayloadHeaderTooSmall error = errors.New("payload length is too small to fit the datagram header") + ErrDatagramPayloadInvalidSize error = errors.New("datagram provided is an invalid size") + ErrDatagramResponseMsgInvalidSize error = errors.New("datagram response message is an invalid size") + ErrDatagramResponseInvalidSize error = errors.New("datagram response is an invalid size") + ErrDatagramResponseMsgTooLargeMaximum error = fmt.Errorf("datagram response error message length exceeds the length of the datagram maximum: %d", maxResponseErrorMessageLen) + ErrDatagramResponseMsgTooLargeDatagram error = fmt.Errorf("datagram response error message length exceeds the length of the provided datagram") +) + +func wrapMarshalErr(err error) error { + return fmt.Errorf("datagram marshal error: %w", err) +} + +func wrapUnmarshalErr(err error) error { + return fmt.Errorf("datagram unmarshal error: %w", err) +} diff --git a/quic/v3/datagram_test.go b/quic/v3/datagram_test.go new file mode 100644 index 00000000..b2e77f89 --- /dev/null +++ b/quic/v3/datagram_test.go @@ -0,0 +1,352 @@ +package v3_test + +import ( + "encoding/binary" + "errors" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" + + v3 "github.com/cloudflare/cloudflared/quic/v3" +) + +func makePayload(size int) []byte { + payload := make([]byte, size) + for i := range len(payload) { + payload[i] = 0xfc + } + return payload +} + +func TestSessionRegistration_MarshalUnmarshal(t *testing.T) { + payload := makePayload(1254) + tests := []*v3.UDPSessionRegistrationDatagram{ + // Default (IPv4) + { + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: false, + IdleDurationHint: 5 * time.Second, + Payload: nil, + }, + // Request ID (max) + { + RequestID: mustRequestID([16]byte{ + ^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0), + ^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0), + ^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0), + ^uint8(0), ^uint8(0), ^uint8(0), ^uint8(0), + }), + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: false, + IdleDurationHint: 5 * time.Second, + Payload: nil, + }, + // IPv6 + { + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("[fc00::0]:8080"), + Traced: false, + IdleDurationHint: 5 * time.Second, + Payload: nil, + }, + // Traced + { + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: true, + IdleDurationHint: 5 * time.Second, + Payload: nil, + }, + // IdleDurationHint (max) + { + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: false, + IdleDurationHint: 65535 * time.Second, + Payload: nil, + }, + // Payload + { + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: false, + IdleDurationHint: 5 * time.Second, + Payload: []byte{0xff, 0xaa, 0xcc, 0x44}, + }, + // Payload (max: 1254) for IPv4 + { + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: false, + IdleDurationHint: 5 * time.Second, + Payload: payload, + }, + // Payload (max: 1242) for IPv4 + { + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: false, + IdleDurationHint: 5 * time.Second, + Payload: payload[:1242], + }, + } + for _, tt := range tests { + marshaled, err := tt.MarshalBinary() + if err != nil { + t.Error(err) + } + unmarshaled := v3.UDPSessionRegistrationDatagram{} + err = unmarshaled.UnmarshalBinary(marshaled) + if err != nil { + t.Error(err) + } + if !compareRegistrationDatagrams(t, tt, &unmarshaled) { + t.Errorf("not equal:\n%+v\n%+v", tt, &unmarshaled) + } + } +} + +func TestSessionRegistration_MarshalBinary(t *testing.T) { + t.Run("idle hint too large", func(t *testing.T) { + // idle hint duration overflows back to 1 + datagram := &v3.UDPSessionRegistrationDatagram{ + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: false, + IdleDurationHint: 65537 * time.Second, + Payload: nil, + } + expected := &v3.UDPSessionRegistrationDatagram{ + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("1.1.1.1:8080"), + Traced: false, + IdleDurationHint: 1 * time.Second, + Payload: nil, + } + marshaled, err := datagram.MarshalBinary() + if err != nil { + t.Error(err) + } + unmarshaled := v3.UDPSessionRegistrationDatagram{} + err = unmarshaled.UnmarshalBinary(marshaled) + if err != nil { + t.Error(err) + } + if !compareRegistrationDatagrams(t, expected, &unmarshaled) { + t.Errorf("not equal:\n%+v\n%+v", expected, &unmarshaled) + } + }) +} + +func TestTypeUnmarshalErrors(t *testing.T) { + t.Run("invalid length", func(t *testing.T) { + d1 := v3.UDPSessionRegistrationDatagram{} + err := d1.UnmarshalBinary([]byte{}) + if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) { + t.Errorf("expected invalid length to throw error") + } + + d2 := v3.UDPSessionPayloadDatagram{} + err = d2.UnmarshalBinary([]byte{}) + if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) { + t.Errorf("expected invalid length to throw error") + } + + d3 := v3.UDPSessionRegistrationResponseDatagram{} + err = d3.UnmarshalBinary([]byte{}) + if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) { + t.Errorf("expected invalid length to throw error") + } + }) + + t.Run("invalid types", func(t *testing.T) { + d1 := v3.UDPSessionRegistrationDatagram{} + err := d1.UnmarshalBinary([]byte{byte(v3.UDPSessionRegistrationResponseType)}) + if !errors.Is(err, v3.ErrInvalidDatagramType) { + t.Errorf("expected invalid type to throw error") + } + + d2 := v3.UDPSessionPayloadDatagram{} + err = d2.UnmarshalBinary([]byte{byte(v3.UDPSessionRegistrationType)}) + if !errors.Is(err, v3.ErrInvalidDatagramType) { + t.Errorf("expected invalid type to throw error") + } + + d3 := v3.UDPSessionRegistrationResponseDatagram{} + err = d3.UnmarshalBinary([]byte{byte(v3.UDPSessionPayloadType)}) + if !errors.Is(err, v3.ErrInvalidDatagramType) { + t.Errorf("expected invalid type to throw error") + } + }) +} + +func TestSessionPayload(t *testing.T) { + t.Run("basic", func(t *testing.T) { + payload := makePayload(128) + err := v3.MarshalPayloadHeaderTo(testRequestID, payload[0:17]) + if err != nil { + t.Error(err) + } + unmarshaled := v3.UDPSessionPayloadDatagram{} + err = unmarshaled.UnmarshalBinary(payload) + if err != nil { + t.Error(err) + } + require.Equal(t, testRequestID, unmarshaled.RequestID) + require.Equal(t, payload[17:], unmarshaled.Payload) + }) + + t.Run("empty", func(t *testing.T) { + payload := makePayload(17) + err := v3.MarshalPayloadHeaderTo(testRequestID, payload) + if err != nil { + t.Error(err) + } + unmarshaled := v3.UDPSessionPayloadDatagram{} + err = unmarshaled.UnmarshalBinary(payload) + if err != nil { + t.Error(err) + } + require.Equal(t, testRequestID, unmarshaled.RequestID) + require.Equal(t, payload[17:], unmarshaled.Payload) + }) + + t.Run("header size too small", func(t *testing.T) { + payload := makePayload(16) + err := v3.MarshalPayloadHeaderTo(testRequestID, payload) + if !errors.Is(err, v3.ErrDatagramPayloadHeaderTooSmall) { + t.Errorf("expected an error") + } + }) + + t.Run("payload size too small", func(t *testing.T) { + payload := makePayload(17) + err := v3.MarshalPayloadHeaderTo(testRequestID, payload) + if err != nil { + t.Error(err) + } + unmarshaled := v3.UDPSessionPayloadDatagram{} + err = unmarshaled.UnmarshalBinary(payload[:16]) + if !errors.Is(err, v3.ErrDatagramPayloadInvalidSize) { + t.Errorf("expected an error: %s", err) + } + }) + + t.Run("payload size too large", func(t *testing.T) { + datagram := makePayload(17 + 1264) // 1263 is the largest payload size allowed + err := v3.MarshalPayloadHeaderTo(testRequestID, datagram) + if err != nil { + t.Error(err) + } + unmarshaled := v3.UDPSessionPayloadDatagram{} + err = unmarshaled.UnmarshalBinary(datagram[:]) + if !errors.Is(err, v3.ErrDatagramPayloadInvalidSize) { + t.Errorf("expected an error: %s", err) + } + }) +} + +func TestSessionRegistrationResponse(t *testing.T) { + validRespTypes := []v3.SessionRegistrationResp{ + v3.ResponseOk, + v3.ResponseDestinationUnreachable, + v3.ResponseUnableToBindSocket, + v3.ResponseErrorWithMsg, + } + t.Run("basic", func(t *testing.T) { + for _, responseType := range validRespTypes { + datagram := &v3.UDPSessionRegistrationResponseDatagram{ + RequestID: testRequestID, + ResponseType: responseType, + ErrorMsg: "test", + } + marshaled, err := datagram.MarshalBinary() + if err != nil { + t.Error(err) + } + unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{} + err = unmarshaled.UnmarshalBinary(marshaled) + if err != nil { + t.Error(err) + } + require.Equal(t, datagram, unmarshaled) + } + }) + + t.Run("unsupported resp type is valid", func(t *testing.T) { + datagram := &v3.UDPSessionRegistrationResponseDatagram{ + RequestID: testRequestID, + ResponseType: v3.SessionRegistrationResp(0xfc), + ErrorMsg: "", + } + marshaled, err := datagram.MarshalBinary() + if err != nil { + t.Error(err) + } + unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{} + err = unmarshaled.UnmarshalBinary(marshaled) + if err != nil { + t.Error(err) + } + require.Equal(t, datagram, unmarshaled) + }) + + t.Run("too small to unmarshal", func(t *testing.T) { + payload := makePayload(17) + payload[0] = byte(v3.UDPSessionRegistrationResponseType) + unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{} + err := unmarshaled.UnmarshalBinary(payload) + if !errors.Is(err, v3.ErrDatagramResponseInvalidSize) { + t.Errorf("expected an error") + } + }) + + t.Run("error message too long", func(t *testing.T) { + message := "" + for i := 0; i < 1280; i++ { + message += "a" + } + datagram := &v3.UDPSessionRegistrationResponseDatagram{ + RequestID: testRequestID, + ResponseType: v3.SessionRegistrationResp(0xfc), + ErrorMsg: message, + } + _, err := datagram.MarshalBinary() + if !errors.Is(err, v3.ErrDatagramResponseMsgInvalidSize) { + t.Errorf("expected an error") + } + }) + + t.Run("error message too large to unmarshal", func(t *testing.T) { + payload := makePayload(1280) + payload[0] = byte(v3.UDPSessionRegistrationResponseType) + binary.BigEndian.PutUint16(payload[18:20], 1280) // larger than the datagram size could be + unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{} + err := unmarshaled.UnmarshalBinary(payload) + if !errors.Is(err, v3.ErrDatagramResponseMsgTooLargeMaximum) { + t.Errorf("expected an error: %v", err) + } + }) + + t.Run("error message larger than provided buffer", func(t *testing.T) { + payload := makePayload(1000) + payload[0] = byte(v3.UDPSessionRegistrationResponseType) + binary.BigEndian.PutUint16(payload[18:20], 1001) // larger than the datagram size provided + unmarshaled := &v3.UDPSessionRegistrationResponseDatagram{} + err := unmarshaled.UnmarshalBinary(payload) + if !errors.Is(err, v3.ErrDatagramResponseMsgTooLargeDatagram) { + t.Errorf("expected an error: %v", err) + } + }) +} + +func compareRegistrationDatagrams(t *testing.T, l *v3.UDPSessionRegistrationDatagram, r *v3.UDPSessionRegistrationDatagram) bool { + require.Equal(t, l.Payload, r.Payload) + return l.RequestID == r.RequestID && + l.Dest == r.Dest && + l.IdleDurationHint == r.IdleDurationHint && + l.Traced == r.Traced +} diff --git a/quic/v3/request.go b/quic/v3/request.go new file mode 100644 index 00000000..29509e83 --- /dev/null +++ b/quic/v3/request.go @@ -0,0 +1,72 @@ +package v3 + +import ( + "encoding/binary" + "errors" +) + +const ( + datagramRequestIdLen = 16 +) + +var ( + // ErrInvalidRequestIDLen is returned when the provided request id can not be parsed from the provided byte slice. + ErrInvalidRequestIDLen error = errors.New("invalid request id length provided") + // ErrInvalidPayloadDestLen is returned when the provided destination byte slice cannot fit the whole request id. + ErrInvalidPayloadDestLen error = errors.New("invalid payload size provided") +) + +// RequestID is the request-id-v2 identifier, it is used to distinguish between specific flows or sessions proxied +// from the edge to cloudflared. +type RequestID uint128 + +type uint128 struct { + hi uint64 + lo uint64 +} + +// RequestIDFromSlice reads a request ID from a byte slice. +func RequestIDFromSlice(data []byte) (RequestID, error) { + if len(data) != datagramRequestIdLen { + return RequestID{}, ErrInvalidRequestIDLen + } + + return RequestID{ + hi: binary.BigEndian.Uint64(data[:8]), + lo: binary.BigEndian.Uint64(data[8:]), + }, nil +} + +// Compare returns an integer comparing two IPs. +// The result will be 0 if id == id2, -1 if id < id2, and +1 if id > id2. +// The definition of "less than" is the same as the [RequestID.Less] method. +func (id RequestID) Compare(id2 RequestID) int { + hi1, hi2 := id.hi, id2.hi + if hi1 < hi2 { + return -1 + } + if hi1 > hi2 { + return 1 + } + lo1, lo2 := id.lo, id2.lo + if lo1 < lo2 { + return -1 + } + if lo1 > lo2 { + return 1 + } + return 0 +} + +// Less reports whether id sorts before id2. +func (id RequestID) Less(id2 RequestID) bool { return id.Compare(id2) == -1 } + +// MarshalBinaryTo writes the id to the provided destination byte slice; the byte slice must be of at least size 16. +func (id RequestID) MarshalBinaryTo(data []byte) error { + if len(data) < datagramRequestIdLen { + return ErrInvalidPayloadDestLen + } + binary.BigEndian.PutUint64(data[:8], id.hi) + binary.BigEndian.PutUint64(data[8:], id.lo) + return nil +} diff --git a/quic/v3/request_test.go b/quic/v3/request_test.go new file mode 100644 index 00000000..519c2dd2 --- /dev/null +++ b/quic/v3/request_test.go @@ -0,0 +1,50 @@ +package v3_test + +import ( + "crypto/rand" + "slices" + "testing" + + v3 "github.com/cloudflare/cloudflared/quic/v3" +) + +var ( + testRequestIDBytes = [16]byte{ + 0x00, 0x11, 0x22, 0x33, + 0x44, 0x55, 0x66, 0x77, + 0x88, 0x99, 0xaa, 0xbb, + 0xcc, 0xdd, 0xee, 0xff, + } + testRequestID = mustRequestID(testRequestIDBytes) +) + +func mustRequestID(data [16]byte) v3.RequestID { + id, err := v3.RequestIDFromSlice(data[:]) + if err != nil { + panic(err) + } + return id +} + +func TestRequestIDParsing(t *testing.T) { + buf1 := make([]byte, 16) + n, err := rand.Read(buf1) + if err != nil { + t.Fatal(err) + } + if n != 16 { + t.Fatalf("did not read 16 bytes: %d", n) + } + id, err := v3.RequestIDFromSlice(buf1) + if err != nil { + t.Fatal(err) + } + buf2 := make([]byte, 16) + err = id.MarshalBinaryTo(buf2) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(buf1, buf2) { + t.Fatalf("buf1 != buf2: %+v %+v", buf1, buf2) + } +}