cloudflared-mirror/packet/packet.go

175 lines
4.4 KiB
Go
Raw Normal View History

package packet
import (
"fmt"
"net/netip"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
const (
ipv4MinHeaderLen = 20
ipv6HeaderLen = 40
ipv4MinMTU = 576
ipv6MinMTU = 1280
icmpHeaderLen = 8
// https://www.rfc-editor.org/rfc/rfc792 and https://datatracker.ietf.org/doc/html/rfc4443#section-3.3 define 2 codes.
// 0 = ttl exceed in transit, 1 = fragment reassembly time exceeded
icmpTTLExceedInTransitCode = 0
DefaultTTL uint8 = 255
)
// Packet represents an IP packet or a packet that is encapsulated by IP
type Packet interface {
// IPLayer returns the IP of the packet
IPLayer() *IP
// EncodeLayers returns the layers that make up this packet. They can be passed to an Encoder to serialize into RawPacket
EncodeLayers() ([]gopacket.SerializableLayer, error)
}
// IP represents a generic IP packet. It can be embedded in more specific IP protocols
type IP struct {
Src netip.Addr
Dst netip.Addr
Protocol layers.IPProtocol
TTL uint8
}
func newIPv4(ipLayer *layers.IPv4) (*IP, error) {
src, ok := netip.AddrFromSlice(ipLayer.SrcIP)
if !ok {
return nil, fmt.Errorf("cannot convert source IP %s to netip.Addr", ipLayer.SrcIP)
}
dst, ok := netip.AddrFromSlice(ipLayer.DstIP)
if !ok {
return nil, fmt.Errorf("cannot convert source IP %s to netip.Addr", ipLayer.DstIP)
}
return &IP{
Src: src,
Dst: dst,
Protocol: ipLayer.Protocol,
TTL: ipLayer.TTL,
}, nil
}
func newIPv6(ipLayer *layers.IPv6) (*IP, error) {
src, ok := netip.AddrFromSlice(ipLayer.SrcIP)
if !ok {
return nil, fmt.Errorf("cannot convert source IP %s to netip.Addr", ipLayer.SrcIP)
}
dst, ok := netip.AddrFromSlice(ipLayer.DstIP)
if !ok {
return nil, fmt.Errorf("cannot convert source IP %s to netip.Addr", ipLayer.DstIP)
}
return &IP{
Src: src,
Dst: dst,
Protocol: ipLayer.NextHeader,
TTL: ipLayer.HopLimit,
}, nil
}
func (ip *IP) IPLayer() *IP {
return ip
}
func (ip *IP) isIPv4() bool {
return ip.Src.Is4()
}
func (ip *IP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
if ip.isIPv4() {
return []gopacket.SerializableLayer{
&layers.IPv4{
Version: 4,
SrcIP: ip.Src.AsSlice(),
DstIP: ip.Dst.AsSlice(),
Protocol: layers.IPProtocol(ip.Protocol),
TTL: ip.TTL,
},
}, nil
} else {
return []gopacket.SerializableLayer{
&layers.IPv6{
Version: 6,
SrcIP: ip.Src.AsSlice(),
DstIP: ip.Dst.AsSlice(),
NextHeader: layers.IPProtocol(ip.Protocol),
HopLimit: ip.TTL,
},
}, nil
}
}
// ICMP represents is an IP packet + ICMP message
type ICMP struct {
*IP
*icmp.Message
}
func (i *ICMP) EncodeLayers() ([]gopacket.SerializableLayer, error) {
ipLayers, err := i.IP.EncodeLayers()
if err != nil {
return nil, err
}
msg, err := i.Marshal(nil)
if err != nil {
return nil, err
}
icmpLayer := gopacket.Payload(msg)
return append(ipLayers, icmpLayer), nil
}
func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP netip.Addr) *ICMP {
var (
protocol layers.IPProtocol
icmpType icmp.Type
)
if originalIP.Dst.Is4() {
protocol = layers.IPProtocolICMPv4
icmpType = ipv4.ICMPTypeTimeExceeded
} else {
protocol = layers.IPProtocolICMPv6
icmpType = ipv6.ICMPTypeTimeExceeded
}
return &ICMP{
IP: &IP{
Src: routerIP,
Dst: originalIP.Src,
Protocol: protocol,
TTL: DefaultTTL,
},
Message: &icmp.Message{
Type: icmpType,
Code: icmpTTLExceedInTransitCode,
Body: &icmp.TimeExceeded{
Data: originalDatagram(originalPacket, originalIP.Dst.Is4()),
},
},
}
}
// originalDatagram returns a slice of the original datagram for ICMP error messages
// https://www.rfc-editor.org/rfc/rfc1812#section-4.3.2.3 suggests to copy as much without exceeding 576 bytes.
// https://datatracker.ietf.org/doc/html/rfc4443#section-3.3 suggests to copy as much without exceeding 1280 bytes
func originalDatagram(originalPacket RawPacket, isIPv4 bool) []byte {
var upperBound int
if isIPv4 {
upperBound = ipv4MinMTU - ipv4MinHeaderLen - icmpHeaderLen
if upperBound > len(originalPacket.Data) {
upperBound = len(originalPacket.Data)
}
} else {
upperBound = ipv6MinMTU - ipv6HeaderLen - icmpHeaderLen
if upperBound > len(originalPacket.Data) {
upperBound = len(originalPacket.Data)
}
}
return originalPacket.Data[:upperBound]
}