257 lines
5.8 KiB
Go
257 lines
5.8 KiB
Go
package packet
|
|
|
|
import (
	"fmt"
	"net"
	"net/netip"
	"testing"

	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/icmp"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)
|
|
|
|
func TestDecodeIP(t *testing.T) {
|
|
ipDecoder := NewIPDecoder()
|
|
icmpDecoder := NewICMPDecoder()
|
|
udps := []UDP{
|
|
{
|
|
IP: IP{
|
|
Src: netip.MustParseAddr("172.16.0.1"),
|
|
Dst: netip.MustParseAddr("10.0.0.1"),
|
|
Protocol: layers.IPProtocolUDP,
|
|
},
|
|
SrcPort: 31678,
|
|
DstPort: 53,
|
|
},
|
|
{
|
|
|
|
IP: IP{
|
|
Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
|
|
Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"),
|
|
Protocol: layers.IPProtocolUDP,
|
|
},
|
|
SrcPort: 52139,
|
|
DstPort: 1053,
|
|
},
|
|
}
|
|
|
|
encoder := NewEncoder()
|
|
for _, udp := range udps {
|
|
p, err := encoder.Encode(&udp)
|
|
require.NoError(t, err)
|
|
|
|
ipPacket, err := ipDecoder.Decode(p)
|
|
require.NoError(t, err)
|
|
assertIPLayer(t, &udp.IP, ipPacket)
|
|
|
|
icmpPacket, err := icmpDecoder.Decode(p)
|
|
require.Error(t, err)
|
|
require.Nil(t, icmpPacket)
|
|
}
|
|
}
|
|
|
|
// TestDecodeICMP round-trips ICMPv4 and ICMPv6 messages through the encoder
// and both decoders. For each case it checks that:
//   - the IP decoder recovers the IP header the message was built with;
//   - the ICMP decoder yields the same IP header plus the original type, code
//     and body, and a checksum accepted by assertICMPChecksum.
//
// Each case runs as a named subtest so a failure identifies the message kind.
func TestDecodeICMP(t *testing.T) {
	ipDecoder := NewIPDecoder()
	icmpDecoder := NewICMPDecoder()
	var (
		ipv4Packet = IP{
			Src:      netip.MustParseAddr("172.16.0.1"),
			Dst:      netip.MustParseAddr("10.0.0.1"),
			Protocol: layers.IPProtocolICMPv4,
			TTL:      DefaultTTL,
		}
		ipv6Packet = IP{
			Src:      netip.MustParseAddr("fd51:2391:523:f4ee::1"),
			Dst:      netip.MustParseAddr("fd51:2391:697:f4ee::2"),
			Protocol: layers.IPProtocolICMPv6,
			TTL:      DefaultTTL,
		}
		icmpID  = 100
		icmpSeq = 52819
	)

	tests := []struct {
		testCase string
		packet   *ICMP
	}{
		{
			testCase: "icmpv4 time exceed",
			packet: &ICMP{
				IP: &ipv4Packet,
				Message: &icmp.Message{
					Type: ipv4.ICMPTypeTimeExceeded,
					Code: 0,
					Body: &icmp.TimeExceeded{
						Data: []byte("original packet"),
					},
				},
			},
		},
		{
			testCase: "icmpv4 echo",
			packet: &ICMP{
				IP: &ipv4Packet,
				Message: &icmp.Message{
					Type: ipv4.ICMPTypeEcho,
					Code: 0,
					Body: &icmp.Echo{
						ID:   icmpID,
						Seq:  icmpSeq,
						Data: []byte("icmpv4 echo"),
					},
				},
			},
		},
		{
			testCase: "icmpv6 destination unreachable",
			packet: &ICMP{
				IP: &ipv6Packet,
				Message: &icmp.Message{
					Type: ipv6.ICMPTypeDestinationUnreachable,
					Code: 4,
					Body: &icmp.DstUnreach{
						Data: []byte("original packet"),
					},
				},
			},
		},
		{
			testCase: "icmpv6 echo",
			packet: &ICMP{
				IP: &ipv6Packet,
				Message: &icmp.Message{
					Type: ipv6.ICMPTypeEchoRequest,
					Code: 0,
					Body: &icmp.Echo{
						ID:   icmpID,
						Seq:  icmpSeq,
						Data: []byte("icmpv6 echo"),
					},
				},
			},
		},
	}

	encoder := NewEncoder()
	for _, test := range tests {
		// t.Run executes synchronously (no t.Parallel), so sharing the encoder
		// and decoders and capturing the loop variable is safe here.
		t.Run(test.testCase, func(t *testing.T) {
			p, err := encoder.Encode(test.packet)
			require.NoError(t, err)

			ipPacket, err := ipDecoder.Decode(p)
			require.NoError(t, err)
			// Compare against the IP header this case was built from, rather than
			// picking the expectation based on what the decoder produced.
			assertIPLayer(t, test.packet.IP, ipPacket)

			icmpPacket, err := icmpDecoder.Decode(p)
			require.NoError(t, err)
			require.Equal(t, ipPacket, icmpPacket.IP)

			require.Equal(t, test.packet.Type, icmpPacket.Type)
			require.Equal(t, test.packet.Code, icmpPacket.Code)
			assertICMPChecksum(t, icmpPacket)
			require.Equal(t, test.packet.Body, icmpPacket.Body)

			// Both bodies must also serialize to identical wire bytes.
			expectedBody, err := test.packet.Body.Marshal(test.packet.Type.Protocol())
			require.NoError(t, err)
			decodedBody, err := icmpPacket.Body.Marshal(test.packet.Type.Protocol())
			require.NoError(t, err)
			require.Equal(t, expectedBody, decodedBody)
		})
	}
}
|
|
|
|
// TestDecodeBadPackets makes sure decoders don't decode invalid packets.
//
// Each case feeds malformed bytes to both the IP decoder and the ICMP decoder
// and requires that each returns an error and a nil packet. Cases run as named
// subtests so a failure identifies which malformed input was accepted.
func TestDecodeBadPackets(t *testing.T) {
	var (
		srcIPv4 = net.ParseIP("172.16.0.1")
		dstIPv4 = net.ParseIP("10.0.0.1")
	)

	// An otherwise well-formed IPv4/ICMP echo packet whose header declares
	// version 10, which neither decoder should accept.
	ipLayer := layers.IPv4{
		Version:  10,
		SrcIP:    srcIPv4,
		DstIP:    dstIPv4,
		Protocol: layers.IPProtocolICMPv4,
		TTL:      DefaultTTL,
	}
	icmpLayer := layers.ICMPv4{
		TypeCode: layers.CreateICMPv4TypeCode(uint8(ipv4.ICMPTypeEcho), 0),
		Id:       100,
		Seq:      52819,
	}
	wrongIPVersion, err := createPacket(&ipLayer, &icmpLayer, nil, nil)
	require.NoError(t, err)

	tests := []struct {
		testCase string
		packet   []byte
	}{
		{
			testCase: "unknown IP version",
			packet:   wrongIPVersion,
		},
		{
			testCase: "invalid packet",
			packet:   []byte("not a packet"),
		},
		{
			testCase: "zero length packet",
			packet:   []byte{},
		},
	}

	ipDecoder := NewIPDecoder()
	icmpDecoder := NewICMPDecoder()
	for _, test := range tests {
		// Subtests run synchronously, so sharing the decoders is safe.
		t.Run(test.testCase, func(t *testing.T) {
			ipPacket, err := ipDecoder.Decode(RawPacket{Data: test.packet})
			require.Error(t, err)
			require.Nil(t, ipPacket)

			icmpPacket, err := icmpDecoder.Decode(RawPacket{Data: test.packet})
			require.Error(t, err)
			require.Nil(t, icmpPacket)
		})
	}
}
|
|
|
|
func createPacket(ipLayer, secondLayer, thirdLayer gopacket.SerializableLayer, body []byte) ([]byte, error) {
|
|
payload := gopacket.Payload(body)
|
|
packet := gopacket.NewSerializeBuffer()
|
|
var err error
|
|
if thirdLayer != nil {
|
|
err = gopacket.SerializeLayers(packet, serializeOpts, ipLayer, secondLayer, thirdLayer, payload)
|
|
} else {
|
|
err = gopacket.SerializeLayers(packet, serializeOpts, ipLayer, secondLayer, payload)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return packet.Bytes(), nil
|
|
}
|
|
|
|
// assertIPLayer fails the test unless actual is non-nil and has the same
// source address, destination address, protocol and TTL as expected.
// It is marked as a helper so failures are reported at the caller's line.
func assertIPLayer(t *testing.T, expected, actual *IP) {
	t.Helper()
	// Fail cleanly instead of panicking on a nil dereference below.
	require.NotNil(t, actual)
	require.Equal(t, expected.Src, actual.Src)
	require.Equal(t, expected.Dst, actual.Dst)
	require.Equal(t, expected.Protocol, actual.Protocol)
	require.Equal(t, expected.TTL, actual.TTL)
}
|
|
|
|
type UDP struct {
|
|
IP
|
|
SrcPort, DstPort layers.UDPPort
|
|
}
|
|
|
|
// EncodeLayers returns the layers produced by the embedded IP's EncodeLayers,
// followed by a UDP layer with SrcPort and DstPort.
//
// The first IP layer is registered as the UDP layer's network layer so the UDP
// checksum pseudo-header can be derived from it during serialization.
//
// It returns an error if:
//   - the IP layers cannot be encoded;
//   - the IP encoding produced no layers;
//   - the first layer is not a gopacket.NetworkLayer;
//   - gopacket rejects that layer for checksumming.
func (u *UDP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
	ipLayers, err := u.IP.EncodeLayers()
	if err != nil {
		return nil, err
	}
	if len(ipLayers) == 0 {
		return nil, fmt.Errorf("encoding UDP: IP encoding produced no layers")
	}
	// Comma-ok assertion: an unexpected first layer becomes an error, not a panic.
	networkLayer, ok := ipLayers[0].(gopacket.NetworkLayer)
	if !ok {
		return nil, fmt.Errorf("encoding UDP: first IP layer %T is not a network layer", ipLayers[0])
	}

	udpLayer := layers.UDP{
		SrcPort: u.SrcPort,
		DstPort: u.DstPort,
	}
	if err := udpLayer.SetNetworkLayerForChecksum(networkLayer); err != nil {
		return nil, fmt.Errorf("encoding UDP: %w", err)
	}
	return append(ipLayers, &udpLayer), nil
}
|