package packet

import (
	"fmt"
	"net"
	"net/netip"
	"testing"

	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/icmp"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

// TestDecodeIP encodes UDP packets over IPv4 and IPv6 and checks that the IP
// decoder recovers the source, destination and protocol, while the ICMP
// decoder refuses to decode them.
func TestDecodeIP(t *testing.T) {
	ipDecoder := NewIPDecoder()
	icmpDecoder := NewICMPDecoder()
	udps := []UDP{
		{
			IP: IP{
				Src:      netip.MustParseAddr("172.16.0.1"),
				Dst:      netip.MustParseAddr("10.0.0.1"),
				Protocol: layers.IPProtocolUDP,
			},
			SrcPort: 31678,
			DstPort: 53,
		},
		{
			IP: IP{
				Src:      netip.MustParseAddr("fd51:2391:523:f4ee::1"),
				Dst:      netip.MustParseAddr("fd51:2391:697:f4ee::2"),
				Protocol: layers.IPProtocolUDP,
			},
			SrcPort: 52139,
			DstPort: 1053,
		},
	}

	encoder := NewEncoder()
	for i := range udps {
		// Index into the slice rather than taking the address of the range
		// variable, so the pointer always refers to the slice element itself.
		udp := &udps[i]
		t.Run(udp.Src.String(), func(t *testing.T) {
			p, err := encoder.Encode(udp)
			require.NoError(t, err)

			ipPacket, err := ipDecoder.Decode(p.Data)
			require.NoError(t, err)
			assertIPLayer(t, &udp.IP, ipPacket)

			// A UDP packet must not be accepted by the ICMP decoder.
			icmpPacket, err := icmpDecoder.Decode(p.Data)
			require.Error(t, err)
			require.Nil(t, icmpPacket)
		})
	}
}

// TestDecodeICMP encodes ICMPv4 and ICMPv6 messages and checks that both the
// IP decoder and the ICMP decoder recover what was encoded, including a
// byte-for-byte comparison of the marshaled ICMP bodies.
func TestDecodeICMP(t *testing.T) {
	ipDecoder := NewIPDecoder()
	icmpDecoder := NewICMPDecoder()
	var (
		ipv4Packet = IP{
			Src:      netip.MustParseAddr("172.16.0.1"),
			Dst:      netip.MustParseAddr("10.0.0.1"),
			Protocol: layers.IPProtocolICMPv4,
		}
		ipv6Packet = IP{
			Src:      netip.MustParseAddr("fd51:2391:523:f4ee::1"),
			Dst:      netip.MustParseAddr("fd51:2391:697:f4ee::2"),
			Protocol: layers.IPProtocolICMPv6,
		}
		icmpID  = 100
		icmpSeq = 52819
	)
	tests := []struct {
		testCase string
		packet   *ICMP
	}{
		{
			testCase: "icmpv4 time exceed",
			packet: &ICMP{
				IP: &ipv4Packet,
				Message: &icmp.Message{
					Type: ipv4.ICMPTypeTimeExceeded,
					Code: 0,
					Body: &icmp.TimeExceeded{
						Data: []byte("original packet"),
					},
				},
			},
		},
		{
			testCase: "icmpv4 echo",
			packet: &ICMP{
				IP: &ipv4Packet,
				Message: &icmp.Message{
					Type: ipv4.ICMPTypeEcho,
					Code: 0,
					Body: &icmp.Echo{
						ID:   icmpID,
						Seq:  icmpSeq,
						Data: []byte("icmpv4 echo"),
					},
				},
			},
		},
		{
			testCase: "icmpv6 destination unreachable",
			packet: &ICMP{
				IP: &ipv6Packet,
				Message: &icmp.Message{
					Type: ipv6.ICMPTypeDestinationUnreachable,
					Code: 4,
					Body: &icmp.DstUnreach{
						Data: []byte("original packet"),
					},
				},
			},
		},
		{
			testCase: "icmpv6 echo",
			packet: &ICMP{
				IP: &ipv6Packet,
				Message: &icmp.Message{
					Type: ipv6.ICMPTypeEchoRequest,
					Code: 0,
					Body: &icmp.Echo{
						ID:   icmpID,
						Seq:  icmpSeq,
						Data: []byte("icmpv6 echo"),
					},
				},
			},
		},
	}

	encoder := NewEncoder()
	for _, test := range tests {
		// Run each case as a named subtest so a failure identifies the case.
		t.Run(test.testCase, func(t *testing.T) {
			p, err := encoder.Encode(test.packet)
			require.NoError(t, err)

			ipPacket, err := ipDecoder.Decode(p.Data)
			require.NoError(t, err)
			if ipPacket.Src.Is4() {
				assertIPLayer(t, &ipv4Packet, ipPacket)
			} else {
				assertIPLayer(t, &ipv6Packet, ipPacket)
			}

			icmpPacket, err := icmpDecoder.Decode(p.Data)
			require.NoError(t, err)
			require.Equal(t, ipPacket, icmpPacket.IP)

			require.Equal(t, test.packet.Type, icmpPacket.Type)
			require.Equal(t, test.packet.Code, icmpPacket.Code)
			require.Equal(t, test.packet.Body, icmpPacket.Body)

			// Also compare the marshaled form of the expected and decoded bodies.
			expectedBody, err := test.packet.Body.Marshal(test.packet.Type.Protocol())
			require.NoError(t, err)
			decodedBody, err := icmpPacket.Body.Marshal(test.packet.Type.Protocol())
			require.NoError(t, err)
			require.Equal(t, expectedBody, decodedBody)
		})
	}
}

// TestDecodeBadPackets makes sure decoders don't decode invalid packets.
func TestDecodeBadPackets(t *testing.T) {
	var (
		srcIPv4 = net.ParseIP("172.16.0.1")
		dstIPv4 = net.ParseIP("10.0.0.1")
	)

	// Version is deliberately set to 10, which is neither IPv4 nor IPv6.
	ipLayer := layers.IPv4{
		Version:  10,
		SrcIP:    srcIPv4,
		DstIP:    dstIPv4,
		Protocol: layers.IPProtocolICMPv4,
		TTL:      defaultTTL,
	}
	icmpLayer := layers.ICMPv4{
		TypeCode: layers.CreateICMPv4TypeCode(uint8(ipv4.ICMPTypeEcho), 0),
		Id:       100,
		Seq:      52819,
	}
	wrongIPVersion, err := createPacket(&ipLayer, &icmpLayer, nil, nil)
	require.NoError(t, err)

	tests := []struct {
		testCase string
		packet   []byte
	}{
		{
			testCase: "unknown IP version",
			packet:   wrongIPVersion,
		},
		{
			testCase: "invalid packet",
			packet:   []byte("not a packet"),
		},
		{
			testCase: "zero length packet",
			packet:   []byte{},
		},
	}

	ipDecoder := NewIPDecoder()
	icmpDecoder := NewICMPDecoder()
	for _, test := range tests {
		// Run each case as a named subtest so a failure identifies the case.
		t.Run(test.testCase, func(t *testing.T) {
			ipPacket, err := ipDecoder.Decode(test.packet)
			require.Error(t, err)
			require.Nil(t, ipPacket)

			icmpPacket, err := icmpDecoder.Decode(test.packet)
			require.Error(t, err)
			require.Nil(t, icmpPacket)
		})
	}
}

// createPacket serializes ipLayer, secondLayer, the optional thirdLayer and
// body (as the payload) into a single packet using serializeOpts. thirdLayer
// is skipped when it is nil. It returns the serialized bytes, or the error
// reported by gopacket.SerializeLayers.
func createPacket(ipLayer, secondLayer, thirdLayer gopacket.SerializableLayer, body []byte) ([]byte, error) {
	toSerialize := make([]gopacket.SerializableLayer, 0, 4)
	toSerialize = append(toSerialize, ipLayer, secondLayer)
	if thirdLayer != nil {
		toSerialize = append(toSerialize, thirdLayer)
	}
	toSerialize = append(toSerialize, gopacket.Payload(body))

	packet := gopacket.NewSerializeBuffer()
	if err := gopacket.SerializeLayers(packet, serializeOpts, toSerialize...); err != nil {
		return nil, err
	}
	return packet.Bytes(), nil
}

// assertIPLayer requires that actual has the same source address, destination
// address and protocol as expected.
func assertIPLayer(t *testing.T, expected, actual *IP) {
	// Mark as a helper so failures are reported at the caller's line.
	t.Helper()
	require.Equal(t, expected.Src, actual.Src)
	require.Equal(t, expected.Dst, actual.Dst)
	require.Equal(t, expected.Protocol, actual.Protocol)
}

// UDP is a test packet made of an IP header plus UDP source and destination
// ports.
type UDP struct {
	IP
	SrcPort, DstPort layers.UDPPort
}

// EncodeLayers returns the layers produced by the embedded IP followed by a
// UDP layer carrying SrcPort and DstPort. The first IP layer is registered as
// the network layer for the UDP checksum. An error is returned if the IP
// layers cannot be encoded, if there is no usable network layer, or if the
// UDP layer rejects the network layer.
func (u *UDP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
	ipLayers, err := u.IP.EncodeLayers()
	if err != nil {
		return nil, err
	}
	if len(ipLayers) == 0 {
		return nil, fmt.Errorf("encoding UDP: got %d IP layers, need at least one", len(ipLayers))
	}
	// Checked assertion: report an unexpected layer type instead of panicking.
	networkLayer, ok := ipLayers[0].(gopacket.NetworkLayer)
	if !ok {
		return nil, fmt.Errorf("encoding UDP: layer %T is not a network layer", ipLayers[0])
	}

	udpLayer := layers.UDP{
		SrcPort: u.SrcPort,
		DstPort: u.DstPort,
	}
	if err := udpLayer.SetNetworkLayerForChecksum(networkLayer); err != nil {
		return nil, fmt.Errorf("setting network layer for UDP checksum: %w", err)
	}
	return append(ipLayers, &udpLayer), nil
}