TUN-6791: Calculate ICMPv6 checksum
This commit is contained in:
parent
7f487c2651
commit
3449ea35f2
|
@ -299,12 +299,7 @@ func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) {
|
||||||
require.Equal(t, ipv6.ICMPTypeEchoReply, decoded.Type)
|
require.Equal(t, ipv6.ICMPTypeEchoReply, decoded.Type)
|
||||||
}
|
}
|
||||||
require.Equal(t, 0, decoded.Code)
|
require.Equal(t, 0, decoded.Code)
|
||||||
if echoReq.Type == ipv4.ICMPTypeEcho {
|
|
||||||
require.NotZero(t, decoded.Checksum)
|
require.NotZero(t, decoded.Checksum)
|
||||||
} else {
|
|
||||||
// For ICMPv6, the kernel will compute the checksum during transmission unless pseudo header is not nil
|
|
||||||
require.Zero(t, decoded.Checksum)
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, echoReq.Body, decoded.Body)
|
require.Equal(t, echoReq.Body, decoded.Body)
|
||||||
}
|
}
|
||||||
|
|
|
@ -152,6 +152,7 @@ func TestDecodeICMP(t *testing.T) {
|
||||||
|
|
||||||
require.Equal(t, test.packet.Type, icmpPacket.Type)
|
require.Equal(t, test.packet.Type, icmpPacket.Type)
|
||||||
require.Equal(t, test.packet.Code, icmpPacket.Code)
|
require.Equal(t, test.packet.Code, icmpPacket.Code)
|
||||||
|
assertICMPChecksum(t, icmpPacket)
|
||||||
require.Equal(t, test.packet.Body, icmpPacket.Body)
|
require.Equal(t, test.packet.Body, icmpPacket.Body)
|
||||||
expectedBody, err := test.packet.Body.Marshal(test.packet.Type.Protocol())
|
expectedBody, err := test.packet.Body.Marshal(test.packet.Type.Protocol())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package packet
|
package packet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
@ -21,6 +22,7 @@ const (
|
||||||
// 0 = ttl exceed in transit, 1 = fragment reassembly time exceeded
|
// 0 = ttl exceed in transit, 1 = fragment reassembly time exceeded
|
||||||
icmpTTLExceedInTransitCode = 0
|
icmpTTLExceedInTransitCode = 0
|
||||||
DefaultTTL uint8 = 255
|
DefaultTTL uint8 = 255
|
||||||
|
pseudoHeaderLen = 40
|
||||||
)
|
)
|
||||||
|
|
||||||
// Packet represents an IP packet or a packet that is encapsulated by IP
|
// Packet represents an IP packet or a packet that is encapsulated by IP
|
||||||
|
@ -117,7 +119,18 @@ func (i *ICMP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := i.Marshal(nil)
|
var serializedPsh []byte = nil
|
||||||
|
if i.Protocol == layers.IPProtocolICMPv6 {
|
||||||
|
psh := &PseudoHeader{
|
||||||
|
SrcIP: i.Src.As16(),
|
||||||
|
DstIP: i.Dst.As16(),
|
||||||
|
// i.Marshal re-calculates the UpperLayerPacketLength
|
||||||
|
UpperLayerPacketLength: 0,
|
||||||
|
NextHeader: uint8(i.Protocol),
|
||||||
|
}
|
||||||
|
serializedPsh = psh.Marshal()
|
||||||
|
}
|
||||||
|
msg, err := i.Marshal(serializedPsh)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -125,6 +138,29 @@ func (i *ICMP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
|
||||||
return append(ipLayers, icmpLayer), nil
|
return append(ipLayers, icmpLayer), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://www.rfc-editor.org/rfc/rfc2460#section-8.1
|
||||||
|
type PseudoHeader struct {
|
||||||
|
SrcIP [16]byte
|
||||||
|
DstIP [16]byte
|
||||||
|
UpperLayerPacketLength uint32
|
||||||
|
zero [3]byte
|
||||||
|
NextHeader uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ph *PseudoHeader) Marshal() []byte {
|
||||||
|
buf := make([]byte, pseudoHeaderLen)
|
||||||
|
index := 0
|
||||||
|
copy(buf, ph.SrcIP[:])
|
||||||
|
index += 16
|
||||||
|
copy(buf[index:], ph.DstIP[:])
|
||||||
|
index += 16
|
||||||
|
binary.BigEndian.PutUint32(buf[index:], ph.UpperLayerPacketLength)
|
||||||
|
index += 4
|
||||||
|
copy(buf[index:], ph.zero[:])
|
||||||
|
buf[pseudoHeaderLen-1] = ph.NextHeader
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP netip.Addr) *ICMP {
|
func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP netip.Addr) *ICMP {
|
||||||
var (
|
var (
|
||||||
protocol layers.IPProtocol
|
protocol layers.IPProtocol
|
||||||
|
@ -137,6 +173,7 @@ func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP n
|
||||||
protocol = layers.IPProtocolICMPv6
|
protocol = layers.IPProtocolICMPv6
|
||||||
icmpType = ipv6.ICMPTypeTimeExceeded
|
icmpType = ipv6.ICMPTypeTimeExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ICMP{
|
return &ICMP{
|
||||||
IP: &IP{
|
IP: &IP{
|
||||||
Src: routerIP,
|
Src: routerIP,
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
|
@ -101,4 +102,96 @@ func assertTTLExceedPacket(t *testing.T, pk *ICMP) {
|
||||||
require.Len(t, rawTTLExceedPacket.Data, headerLen+icmpHeaderLen+len(rawPacket.Data))
|
require.Len(t, rawTTLExceedPacket.Data, headerLen+icmpHeaderLen+len(rawPacket.Data))
|
||||||
require.True(t, bytes.Equal(rawPacket.Data, rawTTLExceedPacket.Data[headerLen+icmpHeaderLen:]))
|
require.True(t, bytes.Equal(rawPacket.Data, rawTTLExceedPacket.Data[headerLen+icmpHeaderLen:]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
decoder := NewICMPDecoder()
|
||||||
|
decodedPacket, err := decoder.Decode(rawTTLExceedPacket)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assertICMPChecksum(t, decodedPacket)
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertICMPChecksum(t *testing.T, icmpPacket *ICMP) {
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
if icmpPacket.Protocol == layers.IPProtocolICMPv4 {
|
||||||
|
icmpv4 := layers.ICMPv4{
|
||||||
|
TypeCode: layers.CreateICMPv4TypeCode(uint8(icmpPacket.Type.(ipv4.ICMPType)), uint8(icmpPacket.Code)),
|
||||||
|
}
|
||||||
|
switch body := icmpPacket.Body.(type) {
|
||||||
|
case *icmp.Echo:
|
||||||
|
icmpv4.Id = uint16(body.ID)
|
||||||
|
icmpv4.Seq = uint16(body.Seq)
|
||||||
|
payload := gopacket.Payload(body.Data)
|
||||||
|
require.NoError(t, payload.SerializeTo(buf, serializeOpts))
|
||||||
|
default:
|
||||||
|
require.NoError(t, serializeICMPAsPayload(icmpPacket.Message, buf))
|
||||||
|
}
|
||||||
|
// SerializeTo sets the checksum in icmpv4
|
||||||
|
require.NoError(t, icmpv4.SerializeTo(buf, serializeOpts))
|
||||||
|
require.Equal(t, icmpv4.Checksum, uint16(icmpPacket.Checksum))
|
||||||
|
} else {
|
||||||
|
switch body := icmpPacket.Body.(type) {
|
||||||
|
case *icmp.Echo:
|
||||||
|
payload := gopacket.Payload(body.Data)
|
||||||
|
require.NoError(t, payload.SerializeTo(buf, serializeOpts))
|
||||||
|
echo := layers.ICMPv6Echo{
|
||||||
|
Identifier: uint16(body.ID),
|
||||||
|
SeqNumber: uint16(body.Seq),
|
||||||
|
}
|
||||||
|
require.NoError(t, echo.SerializeTo(buf, serializeOpts))
|
||||||
|
default:
|
||||||
|
require.NoError(t, serializeICMPAsPayload(icmpPacket.Message, buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpv6 := layers.ICMPv6{
|
||||||
|
TypeCode: layers.CreateICMPv6TypeCode(uint8(icmpPacket.Type.(ipv6.ICMPType)), uint8(icmpPacket.Code)),
|
||||||
|
}
|
||||||
|
ipLayer := layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
SrcIP: icmpPacket.Src.AsSlice(),
|
||||||
|
DstIP: icmpPacket.Dst.AsSlice(),
|
||||||
|
NextHeader: icmpPacket.Protocol,
|
||||||
|
HopLimit: icmpPacket.TTL,
|
||||||
|
}
|
||||||
|
require.NoError(t, icmpv6.SetNetworkLayerForChecksum(&ipLayer))
|
||||||
|
|
||||||
|
// SerializeTo sets the checksum in icmpv4
|
||||||
|
require.NoError(t, icmpv6.SerializeTo(buf, serializeOpts))
|
||||||
|
require.Equal(t, icmpv6.Checksum, uint16(icmpPacket.Checksum))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializeICMPAsPayload(message *icmp.Message, buf gopacket.SerializeBuffer) error {
|
||||||
|
serializedBody, err := message.Body.Marshal(message.Type.Protocol())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := gopacket.Payload(serializedBody)
|
||||||
|
return payload.SerializeTo(buf, serializeOpts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChecksum(t *testing.T) {
|
||||||
|
data := []byte{0x63, 0x2c, 0x49, 0xd6, 0x00, 0x0d, 0xc1, 0xda}
|
||||||
|
pk := ICMP{
|
||||||
|
IP: &IP{
|
||||||
|
Src: netip.MustParseAddr("2606:4700:110:89c1:c63a:861:e08c:b049"),
|
||||||
|
Dst: netip.MustParseAddr("fde8:b693:d420:109b::2"),
|
||||||
|
Protocol: layers.IPProtocolICMPv6,
|
||||||
|
TTL: 3,
|
||||||
|
},
|
||||||
|
Message: &icmp.Message{
|
||||||
|
Type: ipv6.ICMPTypeEchoRequest,
|
||||||
|
Code: 0,
|
||||||
|
Body: &icmp.Echo{
|
||||||
|
ID: 0x20a7,
|
||||||
|
Seq: 8,
|
||||||
|
Data: data,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
encoder := NewEncoder()
|
||||||
|
encoded, err := encoder.Encode(&pk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
decoder := NewICMPDecoder()
|
||||||
|
decoded, err := decoder.Decode(encoded)
|
||||||
|
require.Equal(t, 0xff96, decoded.Checksum)
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,6 +96,7 @@ func assertTTLExceed(t *testing.T, originalPacket *ICMP, expectedSrc netip.Addr,
|
||||||
require.Equal(t, ipv6.ICMPTypeTimeExceeded, decoded.Type)
|
require.Equal(t, ipv6.ICMPTypeTimeExceeded, decoded.Type)
|
||||||
}
|
}
|
||||||
require.Equal(t, 0, decoded.Code)
|
require.Equal(t, 0, decoded.Code)
|
||||||
|
assertICMPChecksum(t, decoded)
|
||||||
timeExceed, ok := decoded.Body.(*icmp.TimeExceeded)
|
timeExceed, ok := decoded.Body.(*icmp.TimeExceeded)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.True(t, bytes.Equal(rawPacket.Data, timeExceed.Data))
|
require.True(t, bytes.Equal(rawPacket.Data, timeExceed.Data))
|
||||||
|
|
Loading…
Reference in New Issue