// cloudflared-mirror/packet/router.go
//
// 127 lines
// 3.5 KiB
// Go

package packet
import (
"context"
"net"
"net/netip"
"github.com/rs/zerolog"
)
var (
	// Fallback source addresses for ICMP error messages, used by NewRouter
	// when the local IP of this machine cannot be determined. Both literals
	// are taken from address ranges reserved for documentation, so they never
	// collide with a real host.
	icmpv4ErrFallbackSrc = netip.MustParseAddr("192.0.2.30")
	icmpv6ErrFallbackSrc = netip.MustParseAddr("2001:db8::")
)
// ICMPRouter sends ICMP messages and listens for their responses.
type ICMPRouter interface {
	// Serve starts listening for responses to the requests until context is done.
	Serve(ctx context.Context) error
	// Request sends an ICMP message. Router.Serve passes its return pipe as
	// responder; presumably responses are delivered through it — confirm
	// against the implementation.
	Request(pk *ICMP, responder FunnelUniPipe) error
}
// Upstream is a source of raw packets.
type Upstream interface {
	// ReceivePacket waits for the next raw packet from upstream.
	// Router.Serve stops and returns any error reported by this call.
	ReceivePacket(ctx context.Context) (RawPacket, error)
}
// Router reads raw packets from an upstream, decodes them as ICMP and either
// forwards them through an ICMPRouter or, when the TTL has run out, replies
// with an ICMP TTL exceed message.
type Router struct {
	// upstream supplies the raw packets consumed by Serve.
	upstream Upstream

	// returnPipe carries locally generated TTL exceed replies and is handed
	// to icmpProxy as the responder for forwarded requests.
	returnPipe FunnelUniPipe

	// icmpProxy sends packets whose TTL is still greater than 1.
	icmpProxy ICMPRouter

	// ipv4Src and ipv6Src are the source addresses used for locally generated
	// ICMP error messages.
	ipv4Src netip.Addr
	ipv6Src netip.Addr

	// logger records decode, send and reply failures.
	logger *zerolog.Logger
}
// NewRouter assembles a Router from its collaborators and works out which
// local addresses to use as the source of ICMP error messages.
//
// The local IPv4 and IPv6 addresses are discovered with findLocalAddr by
// dialing UDP port 53 at 1.1.1.1 and 2606:4700:4700::1111 respectively. When a
// lookup fails, a warning is logged and the matching fallback address
// (icmpv4ErrFallbackSrc / icmpv6ErrFallbackSrc) is used instead, so
// construction never fails.
func NewRouter(upstream Upstream, returnPipe FunnelUniPipe, icmpProxy ICMPRouter, logger *zerolog.Logger) *Router {
	router := &Router{
		upstream:   upstream,
		returnPipe: returnPipe,
		icmpProxy:  icmpProxy,
		logger:     logger,
	}

	var err error
	if router.ipv4Src, err = findLocalAddr(net.ParseIP("1.1.1.1"), 53); err != nil {
		logger.Warn().Err(err).Msgf("Failed to determine the IPv4 for this machine. It will use %s as source IP for error messages such as ICMP TTL exceed", icmpv4ErrFallbackSrc)
		router.ipv4Src = icmpv4ErrFallbackSrc
	}
	if router.ipv6Src, err = findLocalAddr(net.ParseIP("2606:4700:4700::1111"), 53); err != nil {
		logger.Warn().Err(err).Msgf("Failed to determine the IPv6 for this machine. It will use %s as source IP for error messages such as ICMP TTL exceed", icmpv6ErrFallbackSrc)
		router.ipv6Src = icmpv6ErrFallbackSrc
	}
	return router
}
// Serve pulls raw packets from the upstream in a loop and handles each one:
//
//   - a packet that fails ICMP decoding is logged and dropped;
//   - a packet whose TTL is 1 or less is answered with an ICMP TTL exceed
//     message through the return pipe;
//   - any other packet has its TTL decremented and is forwarded via icmpProxy.
//
// Per-packet failures are only logged. The loop ends, returning the error,
// only when ReceivePacket itself fails.
func (r *Router) Serve(ctx context.Context) error {
	var (
		decoder = NewICMPDecoder()
		encoder = NewEncoder()
	)

	for {
		raw, err := r.upstream.ReceivePacket(ctx)
		if err != nil {
			return err
		}

		pk, err := decoder.Decode(raw)
		switch {
		case err != nil:
			r.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram")

		case pk.TTL <= 1:
			// The packet may not travel any further: report back to the sender.
			if err := r.sendTTLExceedMsg(pk, raw, encoder); err != nil {
				r.logger.Err(err).Msg("Failed to return ICMP TTL exceed error")
			}

		default:
			// This router counts as one hop.
			pk.TTL--
			if err := r.icmpProxy.Request(pk, r.returnPipe); err != nil {
				r.logger.Err(err).
					Str("src", pk.Src.String()).
					Str("dst", pk.Dst.String()).
					Interface("type", pk.Type).
					Msg("Failed to send ICMP packet")
			}
		}
	}
}
// sendTTLExceedMsg builds an ICMP TTL exceed packet for pk, encodes it and
// sends it to pk.Src through the return pipe.
//
// The source address of the reply is r.ipv4Src when pk.Dst is an IPv4 address
// and r.ipv6Src otherwise. Encoding and send errors are returned to the caller.
func (r *Router) sendTTLExceedMsg(pk *ICMP, rawPacket RawPacket, encoder *Encoder) error {
	replySrc := r.ipv6Src
	if pk.Dst.Is4() {
		replySrc = r.ipv4Src
	}

	encoded, err := encoder.Encode(NewICMPTTLExceedPacket(pk.IP, rawPacket, replySrc))
	if err != nil {
		return err
	}
	return r.returnPipe.SendPacket(pk.Src, encoded)
}
// findLocalAddr dials UDP towards dst:port and reports the local address the
// OS picked for that connection. The connection is closed before returning.
//
// It returns the zero netip.Addr and an error when the dial fails or when the
// connection's local address cannot be parsed as an address:port pair.
func findLocalAddr(dst net.IP, port int) (netip.Addr, error) {
	remote := &net.UDPAddr{IP: dst, Port: port}

	conn, err := net.DialUDP("udp", nil, remote)
	if err != nil {
		return netip.Addr{}, err
	}
	defer conn.Close()

	local, err := netip.ParseAddrPort(conn.LocalAddr().String())
	if err != nil {
		return netip.Addr{}, err
	}
	return local.Addr(), nil
}