diff --git a/ingress/origin_icmp_proxy_test.go b/ingress/origin_icmp_proxy_test.go index 6046c575..e6f37a29 100644 --- a/ingress/origin_icmp_proxy_test.go +++ b/ingress/origin_icmp_proxy_test.go @@ -299,12 +299,7 @@ func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) { require.Equal(t, ipv6.ICMPTypeEchoReply, decoded.Type) } require.Equal(t, 0, decoded.Code) - if echoReq.Type == ipv4.ICMPTypeEcho { - require.NotZero(t, decoded.Checksum) - } else { - // For ICMPv6, the kernel will compute the checksum during transmission unless pseudo header is not nil - require.Zero(t, decoded.Checksum) - } + require.NotZero(t, decoded.Checksum) require.Equal(t, echoReq.Body, decoded.Body) } diff --git a/packet/decoder_test.go b/packet/decoder_test.go index c07896e4..b8770d74 100644 --- a/packet/decoder_test.go +++ b/packet/decoder_test.go @@ -152,6 +152,7 @@ func TestDecodeICMP(t *testing.T) { require.Equal(t, test.packet.Type, icmpPacket.Type) require.Equal(t, test.packet.Code, icmpPacket.Code) + assertICMPChecksum(t, icmpPacket) require.Equal(t, test.packet.Body, icmpPacket.Body) expectedBody, err := test.packet.Body.Marshal(test.packet.Type.Protocol()) require.NoError(t, err) diff --git a/packet/packet.go b/packet/packet.go index b691790f..de1e1d50 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -1,6 +1,7 @@ package packet import ( + "encoding/binary" "fmt" "net/netip" @@ -21,6 +22,7 @@ const ( // 0 = ttl exceed in transit, 1 = fragment reassembly time exceeded icmpTTLExceedInTransitCode = 0 DefaultTTL uint8 = 255 + pseudoHeaderLen = 40 ) // Packet represents an IP packet or a packet that is encapsulated by IP @@ -117,7 +119,18 @@ func (i *ICMP) EncodeLayers() ([]gopacket.SerializableLayer, error) { return nil, err } - msg, err := i.Marshal(nil) + var serializedPsh []byte = nil + if i.Protocol == layers.IPProtocolICMPv6 { + psh := &PseudoHeader{ + SrcIP: i.Src.As16(), + DstIP: i.Dst.As16(), + // i.Marshal re-calculates the UpperLayerPacketLength + UpperLayerPacketLength: 0, + NextHeader: uint8(i.Protocol), + } + serializedPsh = psh.Marshal() + } + msg, err := i.Marshal(serializedPsh) if err != nil { return nil, err } @@ -125,6 +138,29 @@ func (i *ICMP) EncodeLayers() ([]gopacket.SerializableLayer, error) { return append(ipLayers, icmpLayer), nil } +// https://www.rfc-editor.org/rfc/rfc2460#section-8.1 +type PseudoHeader struct { + SrcIP [16]byte + DstIP [16]byte + UpperLayerPacketLength uint32 + zero [3]byte + NextHeader uint8 +} + +func (ph *PseudoHeader) Marshal() []byte { + buf := make([]byte, pseudoHeaderLen) + index := 0 + copy(buf, ph.SrcIP[:]) + index += 16 + copy(buf[index:], ph.DstIP[:]) + index += 16 + binary.BigEndian.PutUint32(buf[index:], ph.UpperLayerPacketLength) + index += 4 + copy(buf[index:], ph.zero[:]) + buf[pseudoHeaderLen-1] = ph.NextHeader + return buf +} + func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP netip.Addr) *ICMP { var ( protocol layers.IPProtocol @@ -137,6 +173,7 @@ func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP n protocol = layers.IPProtocolICMPv6 icmpType = ipv6.ICMPTypeTimeExceeded } + return &ICMP{ IP: &IP{ Src: routerIP, diff --git a/packet/packet_test.go b/packet/packet_test.go index 8b0b314f..e486c744 100644 --- a/packet/packet_test.go +++ b/packet/packet_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "testing" + "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" "golang.org/x/net/icmp" @@ -101,4 +102,96 @@ func assertTTLExceedPacket(t *testing.T, pk *ICMP) { require.Len(t, rawTTLExceedPacket.Data, headerLen+icmpHeaderLen+len(rawPacket.Data)) require.True(t, bytes.Equal(rawPacket.Data, rawTTLExceedPacket.Data[headerLen+icmpHeaderLen:])) } + + decoder := NewICMPDecoder() + decodedPacket, err := decoder.Decode(rawTTLExceedPacket) + require.NoError(t, err) + assertICMPChecksum(t, decodedPacket) +} + +func assertICMPChecksum(t *testing.T, icmpPacket *ICMP) { + buf := gopacket.NewSerializeBuffer() + if icmpPacket.Protocol == layers.IPProtocolICMPv4 { + icmpv4 := layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode(uint8(icmpPacket.Type.(ipv4.ICMPType)), uint8(icmpPacket.Code)), + } + switch body := icmpPacket.Body.(type) { + case *icmp.Echo: + icmpv4.Id = uint16(body.ID) + icmpv4.Seq = uint16(body.Seq) + payload := gopacket.Payload(body.Data) + require.NoError(t, payload.SerializeTo(buf, serializeOpts)) + default: + require.NoError(t, serializeICMPAsPayload(icmpPacket.Message, buf)) + } + // SerializeTo sets the checksum in icmpv4 + require.NoError(t, icmpv4.SerializeTo(buf, serializeOpts)) + require.Equal(t, icmpv4.Checksum, uint16(icmpPacket.Checksum)) + } else { + switch body := icmpPacket.Body.(type) { + case *icmp.Echo: + payload := gopacket.Payload(body.Data) + require.NoError(t, payload.SerializeTo(buf, serializeOpts)) + echo := layers.ICMPv6Echo{ + Identifier: uint16(body.ID), + SeqNumber: uint16(body.Seq), + } + require.NoError(t, echo.SerializeTo(buf, serializeOpts)) + default: + require.NoError(t, serializeICMPAsPayload(icmpPacket.Message, buf)) + } + + icmpv6 := layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(uint8(icmpPacket.Type.(ipv6.ICMPType)), uint8(icmpPacket.Code)), + } + ipLayer := layers.IPv6{ + Version: 6, + SrcIP: icmpPacket.Src.AsSlice(), + DstIP: icmpPacket.Dst.AsSlice(), + NextHeader: icmpPacket.Protocol, + HopLimit: icmpPacket.TTL, + } + require.NoError(t, icmpv6.SetNetworkLayerForChecksum(&ipLayer)) + + // SerializeTo sets the checksum in icmpv4 + require.NoError(t, icmpv6.SerializeTo(buf, serializeOpts)) + require.Equal(t, icmpv6.Checksum, uint16(icmpPacket.Checksum)) + } +} + +func serializeICMPAsPayload(message *icmp.Message, buf gopacket.SerializeBuffer) error { + serializedBody, err := message.Body.Marshal(message.Type.Protocol()) + if err != nil { + return err + } + payload := gopacket.Payload(serializedBody) + return payload.SerializeTo(buf, serializeOpts) +} + +func TestChecksum(t *testing.T) { + data := []byte{0x63, 0x2c, 0x49, 0xd6, 0x00, 0x0d, 0xc1, 0xda} + pk := ICMP{ + IP: &IP{ + Src: netip.MustParseAddr("2606:4700:110:89c1:c63a:861:e08c:b049"), + Dst: netip.MustParseAddr("fde8:b693:d420:109b::2"), + Protocol: layers.IPProtocolICMPv6, + TTL: 3, + }, + Message: &icmp.Message{ + Type: ipv6.ICMPTypeEchoRequest, + Code: 0, + Body: &icmp.Echo{ + ID: 0x20a7, + Seq: 8, + Data: data, + }, + }, + } + encoder := NewEncoder() + encoded, err := encoder.Encode(&pk) + require.NoError(t, err) + + decoder := NewICMPDecoder() + decoded, err := decoder.Decode(encoded) + require.Equal(t, 0xff96, decoded.Checksum) } diff --git a/packet/router_test.go b/packet/router_test.go index 14e49986..48afdc77 100644 --- a/packet/router_test.go +++ b/packet/router_test.go @@ -96,6 +96,7 @@ func assertTTLExceed(t *testing.T, originalPacket *ICMP, expectedSrc netip.Addr, require.Equal(t, ipv6.ICMPTypeTimeExceeded, decoded.Type) } require.Equal(t, 0, decoded.Code) + assertICMPChecksum(t, decoded) timeExceed, ok := decoded.Body.(*icmp.TimeExceeded) require.True(t, ok) require.True(t, bytes.Equal(rawPacket.Data, timeExceed.Data))