package packet

import (
	"bytes"
	"net/netip"
	"testing"

	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/icmp"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
)

func TestNewICMPTTLExceedPacket(t *testing.T) {
	ipv4Packet := IP{
		Src:      netip.MustParseAddr("192.168.1.1"),
		Dst:      netip.MustParseAddr("10.0.0.1"),
		Protocol: layers.IPProtocolICMPv4,
		TTL:      0,
	}
	icmpV4Packet := ICMP{
		IP: &ipv4Packet,
		Message: &icmp.Message{
			Type: ipv4.ICMPTypeEcho,
			Code: 0,
			Body: &icmp.Echo{
				ID:   25821,
				Seq:  58129,
				Data: []byte("test ttl=0"),
			},
		},
	}
	assertTTLExceedPacket(t, &icmpV4Packet)
	icmpV4Packet.Body = &icmp.Echo{
		ID:   3487,
		Seq:  19183,
		Data: make([]byte, ipv4MinMTU),
	}
	assertTTLExceedPacket(t, &icmpV4Packet)
	ipv6Packet := IP{
		Src:      netip.MustParseAddr("fd51:2391:523:f4ee::1"),
		Dst:      netip.MustParseAddr("fd51:2391:697:f4ee::2"),
		Protocol: layers.IPProtocolICMPv6,
		TTL:      0,
	}
	icmpV6Packet := ICMP{
		IP: &ipv6Packet,
		Message: &icmp.Message{
			Type: ipv6.ICMPTypeEchoRequest,
			Code: 0,
			Body: &icmp.Echo{
				ID:   25821,
				Seq:  58129,
				Data: []byte("test ttl=0"),
			},
		},
	}
	assertTTLExceedPacket(t, &icmpV6Packet)
	icmpV6Packet.Body = &icmp.Echo{
		ID:   1497,
		Seq:  39284,
		Data: make([]byte, ipv6MinMTU),
	}
	assertTTLExceedPacket(t, &icmpV6Packet)
}

// assertTTLExceedPacket encodes pk, builds the TTL-exceeded reply for it with
// NewICMPTTLExceedPacket and checks:
//   - the reply's source, destination, protocol and TTL,
//   - how much of the original datagram the reply body carries,
//   - the length of the encoded reply, and
//   - the ICMP checksum of the reply after decoding it again.
//
// The router address, minimum MTU and IP header length used for the expectations are
// chosen from pk's destination address family.
func assertTTLExceedPacket(t *testing.T, pk *ICMP) {
	// Mark as a helper so failures are attributed to the calling test case.
	t.Helper()

	encoder := NewEncoder()
	rawPacket, err := encoder.Encode(pk)
	require.NoError(t, err)

	minMTU := ipv4MinMTU
	headerLen := ipv4MinHeaderLen
	routerIP := netip.MustParseAddr("172.16.0.3")
	if pk.Dst.Is6() {
		minMTU = ipv6MinMTU
		headerLen = ipv6HeaderLen
		routerIP = netip.MustParseAddr("fd51:2391:697:f4ee::3")
	}
	// An original packet longer than the minimum MTU is expected to be cut short so the
	// encoded reply is exactly minMTU bytes; otherwise it is expected to be carried whole.
	truncated := len(rawPacket.Data) > minMTU

	ttlExceedPacket := NewICMPTTLExceedPacket(pk.IP, rawPacket, routerIP)
	require.Equal(t, routerIP, ttlExceedPacket.Src)
	require.Equal(t, pk.Src, ttlExceedPacket.Dst)
	require.Equal(t, pk.Protocol, ttlExceedPacket.Protocol)
	require.Equal(t, DefaultTTL, ttlExceedPacket.TTL)

	timeExceed, ok := ttlExceedPacket.Body.(*icmp.TimeExceeded)
	require.True(t, ok)
	if truncated {
		require.True(t, bytes.Equal(timeExceed.Data, rawPacket.Data[:minMTU-headerLen-icmpHeaderLen]))
	} else {
		require.True(t, bytes.Equal(timeExceed.Data, rawPacket.Data))
	}

	rawTTLExceedPacket, err := encoder.Encode(ttlExceedPacket)
	require.NoError(t, err)
	if truncated {
		require.Len(t, rawTTLExceedPacket.Data, minMTU)
	} else {
		require.Len(t, rawTTLExceedPacket.Data, headerLen+icmpHeaderLen+len(rawPacket.Data))
		// Everything after the IP and ICMP headers must be the untouched original packet.
		require.True(t, bytes.Equal(rawPacket.Data, rawTTLExceedPacket.Data[headerLen+icmpHeaderLen:]))
	}

	decoder := NewICMPDecoder()
	decodedPacket, err := decoder.Decode(rawTTLExceedPacket)
	require.NoError(t, err)
	assertICMPChecksum(t, decodedPacket)
}

// assertICMPChecksum recomputes the checksum of icmpPacket with gopacket and requires it
// to equal the checksum stored in icmpPacket.
//
// The packet body is written into a fresh serialize buffer first and the ICMP header is
// serialized afterwards. Echo bodies are rebuilt field by field; any other body is
// marshalled through serializeICMPAsPayload. For ICMPv6 an IPv6 layer built from the
// packet's addresses is attached to the header for the checksum computation.
func assertICMPChecksum(t *testing.T, icmpPacket *ICMP) {
	// Mark as a helper so failures are attributed to the calling test case.
	t.Helper()

	buf := gopacket.NewSerializeBuffer()
	if icmpPacket.Protocol == layers.IPProtocolICMPv4 {
		// Checked assertion: a mismatched type fails the test instead of panicking.
		icmpType, ok := icmpPacket.Type.(ipv4.ICMPType)
		require.True(t, ok, "unexpected ICMP type %T for an ICMPv4 packet", icmpPacket.Type)
		icmpv4 := layers.ICMPv4{
			TypeCode: layers.CreateICMPv4TypeCode(uint8(icmpType), uint8(icmpPacket.Code)),
		}
		switch body := icmpPacket.Body.(type) {
		case *icmp.Echo:
			icmpv4.Id = uint16(body.ID)
			icmpv4.Seq = uint16(body.Seq)
			payload := gopacket.Payload(body.Data)
			require.NoError(t, payload.SerializeTo(buf, serializeOpts))
		default:
			require.NoError(t, serializeICMPAsPayload(icmpPacket.Message, buf))
		}
		// SerializeTo sets the checksum in icmpv4
		require.NoError(t, icmpv4.SerializeTo(buf, serializeOpts))
		require.Equal(t, icmpv4.Checksum, uint16(icmpPacket.Checksum))
	} else {
		switch body := icmpPacket.Body.(type) {
		case *icmp.Echo:
			payload := gopacket.Payload(body.Data)
			require.NoError(t, payload.SerializeTo(buf, serializeOpts))
			echo := layers.ICMPv6Echo{
				Identifier: uint16(body.ID),
				SeqNumber:  uint16(body.Seq),
			}
			require.NoError(t, echo.SerializeTo(buf, serializeOpts))
		default:
			require.NoError(t, serializeICMPAsPayload(icmpPacket.Message, buf))
		}

		// Checked assertion: a mismatched type fails the test instead of panicking.
		icmpType, ok := icmpPacket.Type.(ipv6.ICMPType)
		require.True(t, ok, "unexpected ICMP type %T for an ICMPv6 packet", icmpPacket.Type)
		icmpv6 := layers.ICMPv6{
			TypeCode: layers.CreateICMPv6TypeCode(uint8(icmpType), uint8(icmpPacket.Code)),
		}
		ipLayer := layers.IPv6{
			Version:    6,
			SrcIP:      icmpPacket.Src.AsSlice(),
			DstIP:      icmpPacket.Dst.AsSlice(),
			NextHeader: icmpPacket.Protocol,
			HopLimit:   icmpPacket.TTL,
		}
		require.NoError(t, icmpv6.SetNetworkLayerForChecksum(&ipLayer))

		// SerializeTo sets the checksum in icmpv6
		require.NoError(t, icmpv6.SerializeTo(buf, serializeOpts))
		require.Equal(t, icmpv6.Checksum, uint16(icmpPacket.Checksum))
	}
}

// serializeICMPAsPayload marshals the body of message, using the protocol number
// reported by the message type, and writes the resulting bytes into buf as a raw
// gopacket payload. It returns the marshalling error if there is one, otherwise the
// result of serializing the payload.
func serializeICMPAsPayload(message *icmp.Message, buf gopacket.SerializeBuffer) error {
	rawBody, marshalErr := message.Body.Marshal(message.Type.Protocol())
	if marshalErr != nil {
		return marshalErr
	}
	bodyLayer := gopacket.Payload(rawBody)
	return bodyLayer.SerializeTo(buf, serializeOpts)
}

// TestChecksum encodes a fixed ICMPv6 echo request, decodes the encoded packet again and
// requires the decoded checksum to equal the known value 0xff96.
func TestChecksum(t *testing.T) {
	data := []byte{0x63, 0x2c, 0x49, 0xd6, 0x00, 0x0d, 0xc1, 0xda}
	pk := ICMP{
		IP: &IP{
			Src:      netip.MustParseAddr("2606:4700:110:89c1:c63a:861:e08c:b049"),
			Dst:      netip.MustParseAddr("fde8:b693:d420:109b::2"),
			Protocol: layers.IPProtocolICMPv6,
			TTL:      3,
		},
		Message: &icmp.Message{
			Type: ipv6.ICMPTypeEchoRequest,
			Code: 0,
			Body: &icmp.Echo{
				ID:   0x20a7,
				Seq:  8,
				Data: data,
			},
		},
	}
	encoder := NewEncoder()
	encoded, err := encoder.Encode(&pk)
	require.NoError(t, err)

	decoder := NewICMPDecoder()
	decoded, err := decoder.Decode(encoded)
	// Stop on a decode error instead of reading the checksum from an unusable result.
	require.NoError(t, err)
	require.Equal(t, 0xff96, decoded.Checksum)
}