cloudflared-mirror/packet/decoder_test.go

256 lines
5.8 KiB
Go

package packet
import (
"net"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func TestDecodeIP(t *testing.T) {
ipDecoder := NewIPDecoder()
icmpDecoder := NewICMPDecoder()
udps := []UDP{
{
IP: IP{
Src: netip.MustParseAddr("172.16.0.1"),
Dst: netip.MustParseAddr("10.0.0.1"),
Protocol: layers.IPProtocolUDP,
},
SrcPort: 31678,
DstPort: 53,
},
{
IP: IP{
Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"),
Protocol: layers.IPProtocolUDP,
},
SrcPort: 52139,
DstPort: 1053,
},
}
encoder := NewEncoder()
for _, udp := range udps {
p, err := encoder.Encode(&udp)
require.NoError(t, err)
ipPacket, err := ipDecoder.Decode(p)
require.NoError(t, err)
assertIPLayer(t, &udp.IP, ipPacket)
icmpPacket, err := icmpDecoder.Decode(p)
require.Error(t, err)
require.Nil(t, icmpPacket)
}
}
// TestDecodeICMP round-trips ICMPv4 and ICMPv6 messages through the encoder
// and both decoders. For every case it checks that the decoded IP header
// matches the header that was encoded and that the decoded ICMP type, code and
// body (both as a value and in marshaled form) match the original message.
//
// Each case runs as its own subtest so a failure is reported under the case
// name and does not prevent the remaining cases from running.
func TestDecodeICMP(t *testing.T) {
	ipDecoder := NewIPDecoder()
	icmpDecoder := NewICMPDecoder()
	var (
		ipv4Packet = IP{
			Src:      netip.MustParseAddr("172.16.0.1"),
			Dst:      netip.MustParseAddr("10.0.0.1"),
			Protocol: layers.IPProtocolICMPv4,
			TTL:      DefaultTTL,
		}
		ipv6Packet = IP{
			Src:      netip.MustParseAddr("fd51:2391:523:f4ee::1"),
			Dst:      netip.MustParseAddr("fd51:2391:697:f4ee::2"),
			Protocol: layers.IPProtocolICMPv6,
			TTL:      DefaultTTL,
		}
		icmpID  = 100
		icmpSeq = 52819
	)
	tests := []struct {
		testCase string
		packet   *ICMP
	}{
		{
			testCase: "icmpv4 time exceed",
			packet: &ICMP{
				IP: &ipv4Packet,
				Message: &icmp.Message{
					Type: ipv4.ICMPTypeTimeExceeded,
					Code: 0,
					Body: &icmp.TimeExceeded{
						Data: []byte("original packet"),
					},
				},
			},
		},
		{
			testCase: "icmpv4 echo",
			packet: &ICMP{
				IP: &ipv4Packet,
				Message: &icmp.Message{
					Type: ipv4.ICMPTypeEcho,
					Code: 0,
					Body: &icmp.Echo{
						ID:   icmpID,
						Seq:  icmpSeq,
						Data: []byte("icmpv4 echo"),
					},
				},
			},
		},
		{
			testCase: "icmpv6 destination unreachable",
			packet: &ICMP{
				IP: &ipv6Packet,
				Message: &icmp.Message{
					Type: ipv6.ICMPTypeDestinationUnreachable,
					Code: 4,
					Body: &icmp.DstUnreach{
						Data: []byte("original packet"),
					},
				},
			},
		},
		{
			testCase: "icmpv6 echo",
			packet: &ICMP{
				IP: &ipv6Packet,
				Message: &icmp.Message{
					Type: ipv6.ICMPTypeEchoRequest,
					Code: 0,
					Body: &icmp.Echo{
						ID:   icmpID,
						Seq:  icmpSeq,
						Data: []byte("icmpv6 echo"),
					},
				},
			},
		},
	}

	encoder := NewEncoder()
	for _, test := range tests {
		t.Run(test.testCase, func(t *testing.T) {
			p, err := encoder.Encode(test.packet)
			require.NoError(t, err)

			ipPacket, err := ipDecoder.Decode(p)
			require.NoError(t, err)
			// Compare against the header that was actually encoded rather
			// than picking the expectation from the decoder's own output.
			assertIPLayer(t, test.packet.IP, ipPacket)

			icmpPacket, err := icmpDecoder.Decode(p)
			require.NoError(t, err)
			require.Equal(t, ipPacket, icmpPacket.IP)
			require.Equal(t, test.packet.Type, icmpPacket.Type)
			require.Equal(t, test.packet.Code, icmpPacket.Code)
			require.Equal(t, test.packet.Body, icmpPacket.Body)

			// Also compare the wire form of the bodies.
			proto := test.packet.Type.Protocol()
			expectedBody, err := test.packet.Body.Marshal(proto)
			require.NoError(t, err)
			decodedBody, err := icmpPacket.Body.Marshal(proto)
			require.NoError(t, err)
			require.Equal(t, expectedBody, decodedBody)
		})
	}
}
// TestDecodeBadPackets makes sure decoders don't decode invalid packets.
//
// It feeds three malformed inputs (an IPv4 header whose Version field is 10,
// arbitrary non-packet bytes, and an empty slice) to both the IP and the ICMP
// decoder and requires each to return an error and a nil packet. Every input
// runs as a named subtest so a failure identifies the offending case and the
// other cases still run.
func TestDecodeBadPackets(t *testing.T) {
	var (
		srcIPv4 = net.ParseIP("172.16.0.1")
		dstIPv4 = net.ParseIP("10.0.0.1")
	)

	ipLayer := layers.IPv4{
		Version:  10,
		SrcIP:    srcIPv4,
		DstIP:    dstIPv4,
		Protocol: layers.IPProtocolICMPv4,
		TTL:      DefaultTTL,
	}
	icmpLayer := layers.ICMPv4{
		TypeCode: layers.CreateICMPv4TypeCode(uint8(ipv4.ICMPTypeEcho), 0),
		Id:       100,
		Seq:      52819,
	}
	wrongIPVersion, err := createPacket(&ipLayer, &icmpLayer, nil, nil)
	require.NoError(t, err)

	tests := []struct {
		testCase string
		packet   []byte
	}{
		{
			testCase: "unknown IP version",
			packet:   wrongIPVersion,
		},
		{
			testCase: "invalid packet",
			packet:   []byte("not a packet"),
		},
		{
			testCase: "zero length packet",
			packet:   []byte{},
		},
	}

	ipDecoder := NewIPDecoder()
	icmpDecoder := NewICMPDecoder()
	for _, test := range tests {
		t.Run(test.testCase, func(t *testing.T) {
			raw := RawPacket{Data: test.packet}

			ipPacket, err := ipDecoder.Decode(raw)
			require.Error(t, err)
			require.Nil(t, ipPacket)

			icmpPacket, err := icmpDecoder.Decode(raw)
			require.Error(t, err)
			require.Nil(t, icmpPacket)
		})
	}
}
// createPacket serializes ipLayer, secondLayer, thirdLayer (only when it is
// non-nil) and finally body as the payload, using serializeOpts, and returns
// the resulting bytes. A serialization failure is returned with a nil slice.
func createPacket(ipLayer, secondLayer, thirdLayer gopacket.SerializableLayer, body []byte) ([]byte, error) {
	stack := []gopacket.SerializableLayer{ipLayer, secondLayer}
	if thirdLayer != nil {
		stack = append(stack, thirdLayer)
	}
	// The payload always goes last, after whichever layers are present.
	stack = append(stack, gopacket.Payload(body))

	buf := gopacket.NewSerializeBuffer()
	if err := gopacket.SerializeLayers(buf, serializeOpts, stack...); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// assertIPLayer fails the test immediately unless actual has the same Src,
// Dst, Protocol and TTL as expected. Other fields of IP are not compared.
func assertIPLayer(t *testing.T, expected, actual *IP) {
	// Mark this function as a helper so failures are attributed to the
	// calling test line rather than to this function.
	t.Helper()
	require.Equal(t, expected.Src, actual.Src)
	require.Equal(t, expected.Dst, actual.Dst)
	require.Equal(t, expected.Protocol, actual.Protocol)
	require.Equal(t, expected.TTL, actual.TTL)
}
// UDP describes a UDP packet for the tests in this file: an embedded IP header
// plus the UDP source and destination ports.
type UDP struct {
	IP
	// SrcPort is the UDP source port.
	SrcPort layers.UDPPort
	// DstPort is the UDP destination port.
	DstPort layers.UDPPort
}
// EncodeLayers returns the layers produced by the embedded IP header followed
// by a UDP layer carrying SrcPort and DstPort.
//
// The first IP layer is registered on the UDP layer for checksum computation.
// An error from encoding the IP header, or from that registration, is
// returned with a nil slice.
func (u *UDP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
	ipLayers, err := u.IP.EncodeLayers()
	if err != nil {
		return nil, err
	}
	udpLayer := layers.UDP{
		SrcPort: u.SrcPort,
		DstPort: u.DstPort,
	}
	// SetNetworkLayerForChecksum reports an error for unsupported network
	// layers; propagate it instead of silently dropping it.
	if err := udpLayer.SetNetworkLayerForChecksum(ipLayers[0].(gopacket.NetworkLayer)); err != nil {
		return nil, err
	}
	return append(ipLayers, &udpLayer), nil
}