TUN-8640: Refactor ICMPRouter to support new ICMPResponders

A new ICMPResponder interface is introduced to provide different
implementations of how the ICMP flows should return to the QUIC
connection muxer.

Improves usages of netip.AddrPort to leverage the embedded zone
field for IPv6 addresses.

Closes TUN-8640
This commit is contained in:
Devin Carr 2024-11-27 12:46:08 -08:00
parent 46dc6316f9
commit 9da15b5d96
19 changed files with 199 additions and 236 deletions

View File

@ -510,7 +510,7 @@ func StartServer(
// Disable ICMP packet routing for quick tunnels // Disable ICMP packet routing for quick tunnels
if quickTunnelURL != "" { if quickTunnelURL != "" {
tunnelConfig.PacketConfig = nil tunnelConfig.ICMPRouterServer = nil
} }
internalRules := []ingress.Rule{} internalRules := []ingress.Rule{}

View File

@ -252,11 +252,11 @@ func prepareTunnelConfig(
QUICConnectionLevelFlowControlLimit: c.Uint64(quicConnLevelFlowControlLimit), QUICConnectionLevelFlowControlLimit: c.Uint64(quicConnLevelFlowControlLimit),
QUICStreamLevelFlowControlLimit: c.Uint64(quicStreamLevelFlowControlLimit), QUICStreamLevelFlowControlLimit: c.Uint64(quicStreamLevelFlowControlLimit),
} }
packetConfig, err := newPacketConfig(c, log) icmpRouter, err := newICMPRouter(c, log)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("ICMP proxy feature is disabled") log.Warn().Err(err).Msg("ICMP proxy feature is disabled")
} else { } else {
tunnelConfig.PacketConfig = packetConfig tunnelConfig.ICMPRouterServer = icmpRouter
} }
orchestratorConfig := &orchestration.Config{ orchestratorConfig := &orchestration.Config{
Ingress: &ingressRules, Ingress: &ingressRules,
@ -351,7 +351,7 @@ func adjustIPVersionByBindAddress(ipVersion allregions.ConfigIPVersion, ip net.I
} }
} }
func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRouterConfig, error) { func newICMPRouter(c *cli.Context, logger *zerolog.Logger) (ingress.ICMPRouterServer, error) {
ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger) ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy") return nil, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy")
@ -368,16 +368,11 @@ func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRou
logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src) logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src)
} }
icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, zone, logger, icmpFunnelTimeout) icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, logger, icmpFunnelTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &ingress.GlobalRouterConfig{ return icmpRouter, nil
ICMPRouter: icmpRouter,
IPv4Src: ipv4Src,
IPv6Src: ipv6Src,
Zone: zone,
}, nil
} }
func determineICMPv4Src(userDefinedSrc string, logger *zerolog.Logger) (netip.Addr, error) { func determineICMPv4Src(userDefinedSrc string, logger *zerolog.Logger) (netip.Addr, error) {
@ -407,13 +402,12 @@ type interfaceIP struct {
func determineICMPv6Src(userDefinedSrc string, logger *zerolog.Logger, ipv4Src netip.Addr) (addr netip.Addr, zone string, err error) { func determineICMPv6Src(userDefinedSrc string, logger *zerolog.Logger, ipv4Src netip.Addr) (addr netip.Addr, zone string, err error) {
if userDefinedSrc != "" { if userDefinedSrc != "" {
userDefinedIP, zone, _ := strings.Cut(userDefinedSrc, "%") addr, err := netip.ParseAddr(userDefinedSrc)
addr, err := netip.ParseAddr(userDefinedIP)
if err != nil { if err != nil {
return netip.Addr{}, "", err return netip.Addr{}, "", err
} }
if addr.Is6() { if addr.Is6() {
return addr, zone, nil return addr, addr.Zone(), nil
} }
return netip.Addr{}, "", fmt.Errorf("expect IPv6, but %s is IPv4", userDefinedSrc) return netip.Addr{}, "", fmt.Errorf("expect IPv6, but %s is IPv4", userDefinedSrc)
} }

View File

@ -7,7 +7,6 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -18,7 +17,6 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/packet"
cfdquic "github.com/cloudflare/cloudflared/quic" cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -417,28 +415,3 @@ func (np *nopCloserReadWriter) Close() error {
return nil return nil
} }
// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface
type muxerWrapper struct {
muxer *cfdquic.DatagramMuxerV2
}
func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
return rp.muxer.SendPacket(cfdquic.RawPacket(pk))
}
func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) {
pk, err := rp.muxer.ReceivePacket(ctx)
if err != nil {
return packet.RawPacket{}, err
}
rawPacket, ok := pk.(cfdquic.RawPacket)
if ok {
return packet.RawPacket(rawPacket), nil
}
return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk)
}
func (rp *muxerWrapper) Close() error {
return nil
}

View File

@ -752,7 +752,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
sessionDemuxChan := make(chan *packet.Session, 4) sessionDemuxChan := make(chan *packet.Session, 4)
datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, &log, sessionDemuxChan) datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, &log, sessionDemuxChan)
sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan) sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan)
packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, &log) packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, 0, &log)
datagramConn := &datagramV2Connection{ datagramConn := &datagramV2Connection{
conn, conn,

View File

@ -54,7 +54,8 @@ type datagramV2Connection struct {
func NewDatagramV2Connection(ctx context.Context, func NewDatagramV2Connection(ctx context.Context,
conn quic.Connection, conn quic.Connection,
packetConfig *ingress.GlobalRouterConfig, icmpRouter ingress.ICMPRouter,
index uint8,
rpcTimeout time.Duration, rpcTimeout time.Duration,
streamWriteTimeout time.Duration, streamWriteTimeout time.Duration,
logger *zerolog.Logger, logger *zerolog.Logger,
@ -62,7 +63,7 @@ func NewDatagramV2Connection(ctx context.Context,
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, logger, sessionDemuxChan) datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, logger, sessionDemuxChan)
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
packetRouter := ingress.NewPacketRouter(packetConfig, datagramMuxer, logger) packetRouter := ingress.NewPacketRouter(icmpRouter, datagramMuxer, index, logger)
return &datagramV2Connection{ return &datagramV2Connection{
conn, conn,

View File

@ -28,8 +28,6 @@ type icmpProxy struct {
srcFunnelTracker *packet.FunnelTracker srcFunnelTracker *packet.FunnelTracker
echoIDTracker *echoIDTracker echoIDTracker *echoIDTracker
conn *icmp.PacketConn conn *icmp.PacketConn
// Response is handled in one-by-one, so encoder can be shared between funnels
encoder *packet.Encoder
logger *zerolog.Logger logger *zerolog.Logger
idleTimeout time.Duration idleTimeout time.Duration
} }
@ -114,8 +112,8 @@ func (snf echoFunnelID) String() string {
return strconv.FormatUint(uint64(snf), 10) return strconv.FormatUint(uint64(snf), 10)
} }
func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) {
conn, err := newICMPConn(listenIP, zone) conn, err := newICMPConn(listenIP)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -123,16 +121,15 @@ func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idle
return &icmpProxy{ return &icmpProxy{
srcFunnelTracker: packet.NewFunnelTracker(), srcFunnelTracker: packet.NewFunnelTracker(),
echoIDTracker: newEchoIDTracker(), echoIDTracker: newEchoIDTracker(),
encoder: packet.NewEncoder(),
conn: conn, conn: conn,
logger: logger, logger: logger,
idleTimeout: idleTimeout, idleTimeout: idleTimeout,
}, nil }, nil
} }
func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error { func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error {
_, span := responder.requestSpan(ctx, pk) _, span := responder.RequestSpan(ctx, pk)
defer responder.exportSpan() defer responder.ExportSpan()
originalEcho, err := getICMPEcho(pk.Message) originalEcho, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
@ -154,7 +151,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
} }
span.SetAttributes(attribute.Int("assignedEchoID", int(assignedEchoID))) span.SetAttributes(attribute.Int("assignedEchoID", int(assignedEchoID)))
shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID) shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder, pk, originalEcho.ID)
newFunnelFunc := func() (packet.Funnel, error) { newFunnelFunc := func() (packet.Funnel, error) {
originalEcho, err := getICMPEcho(pk.Message) originalEcho, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
@ -164,7 +161,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
ip.echoIDTracker.release(echoIDTrackerKey, assignedEchoID) ip.echoIDTracker.release(echoIDTrackerKey, assignedEchoID)
return nil return nil
} }
icmpFlow := newICMPEchoFlow(pk.Src, closeCallback, ip.conn, responder, int(assignedEchoID), originalEcho.ID, ip.encoder) icmpFlow := newICMPEchoFlow(pk.Src, closeCallback, ip.conn, responder, int(assignedEchoID), originalEcho.ID)
return icmpFlow, nil return icmpFlow, nil
} }
funnelID := echoFunnelID(assignedEchoID) funnelID := echoFunnelID(assignedEchoID)
@ -265,8 +262,8 @@ func (ip *icmpProxy) sendReply(ctx context.Context, reply *echoReply) error {
return err return err
} }
_, span := icmpFlow.responder.replySpan(ctx, ip.logger) _, span := icmpFlow.responder.ReplySpan(ctx, ip.logger)
defer icmpFlow.responder.exportSpan() defer icmpFlow.responder.ExportSpan()
if err := icmpFlow.returnToSrc(reply); err != nil { if err := icmpFlow.returnToSrc(reply); err != nil {
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)

View File

@ -18,7 +18,7 @@ var errICMPProxyNotImplemented = fmt.Errorf("ICMP proxy is not implemented on %s
type icmpProxy struct{} type icmpProxy struct{}
func (ip icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error { func (ip icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error {
return errICMPProxyNotImplemented return errICMPProxyNotImplemented
} }
@ -26,6 +26,6 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
return errICMPProxyNotImplemented return errICMPProxyNotImplemented
} }
func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) {
return nil, errICMPProxyNotImplemented return nil, errICMPProxyNotImplemented
} }

View File

@ -37,25 +37,23 @@ var (
type icmpProxy struct { type icmpProxy struct {
srcFunnelTracker *packet.FunnelTracker srcFunnelTracker *packet.FunnelTracker
listenIP netip.Addr listenIP netip.Addr
ipv6Zone string
logger *zerolog.Logger logger *zerolog.Logger
idleTimeout time.Duration idleTimeout time.Duration
} }
func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) {
if err := testPermission(listenIP, zone, logger); err != nil { if err := testPermission(listenIP, logger); err != nil {
return nil, err return nil, err
} }
return &icmpProxy{ return &icmpProxy{
srcFunnelTracker: packet.NewFunnelTracker(), srcFunnelTracker: packet.NewFunnelTracker(),
listenIP: listenIP, listenIP: listenIP,
ipv6Zone: zone,
logger: logger, logger: logger,
idleTimeout: idleTimeout, idleTimeout: idleTimeout,
}, nil }, nil
} }
func testPermission(listenIP netip.Addr, zone string, logger *zerolog.Logger) error { func testPermission(listenIP netip.Addr, logger *zerolog.Logger) error {
// Opens a non-privileged ICMP socket. On Linux the group ID of the process needs to be in ping_group_range // Opens a non-privileged ICMP socket. On Linux the group ID of the process needs to be in ping_group_range
// Only check ping_group_range once for IPv4 // Only check ping_group_range once for IPv4
if listenIP.Is4() { if listenIP.Is4() {
@ -64,7 +62,7 @@ func testPermission(listenIP netip.Addr, zone string, logger *zerolog.Logger) er
return err return err
} }
} }
conn, err := newICMPConn(listenIP, zone) conn, err := newICMPConn(listenIP)
if err != nil { if err != nil {
return err return err
} }
@ -98,9 +96,9 @@ func checkInPingGroup() error {
return fmt.Errorf("did not find group range in %s", pingGroupPath) return fmt.Errorf("did not find group range in %s", pingGroupPath)
} }
func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error { func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error {
ctx, span := responder.requestSpan(ctx, pk) ctx, span := responder.RequestSpan(ctx, pk)
defer responder.exportSpan() defer responder.ExportSpan()
originalEcho, err := getICMPEcho(pk.Message) originalEcho, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
@ -109,9 +107,9 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
} }
observeICMPRequest(ip.logger, span, pk.Src.String(), pk.Dst.String(), originalEcho.ID, originalEcho.Seq) observeICMPRequest(ip.logger, span, pk.Src.String(), pk.Dst.String(), originalEcho.ID, originalEcho.Seq)
shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID) shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder, pk, originalEcho.ID)
newFunnelFunc := func() (packet.Funnel, error) { newFunnelFunc := func() (packet.Funnel, error) {
conn, err := newICMPConn(ip.listenIP, ip.ipv6Zone) conn, err := newICMPConn(ip.listenIP)
if err != nil { if err != nil {
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return nil, errors.Wrap(err, "failed to open ICMP socket") return nil, errors.Wrap(err, "failed to open ICMP socket")
@ -127,7 +125,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
span.SetAttributes(attribute.Int("port", localUDPAddr.Port)) span.SetAttributes(attribute.Int("port", localUDPAddr.Port))
echoID := localUDPAddr.Port echoID := localUDPAddr.Port
icmpFlow := newICMPEchoFlow(pk.Src, closeCallback, conn, responder, echoID, originalEcho.ID, packet.NewEncoder()) icmpFlow := newICMPEchoFlow(pk.Src, closeCallback, conn, responder, echoID, originalEcho.ID)
return icmpFlow, nil return icmpFlow, nil
} }
funnelID := flow3Tuple{ funnelID := flow3Tuple{
@ -181,8 +179,8 @@ func (ip *icmpProxy) listenResponse(ctx context.Context, flow *icmpEchoFlow) {
// Listens for ICMP response and handles error logging // Listens for ICMP response and handles error logging
func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (done bool) { func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (done bool) {
_, span := flow.responder.replySpan(ctx, ip.logger) _, span := flow.responder.ReplySpan(ctx, ip.logger)
defer flow.responder.exportSpan() defer flow.responder.ExportSpan()
span.SetAttributes( span.SetAttributes(
attribute.Int("originalEchoID", flow.originalEchoID), attribute.Int("originalEchoID", flow.originalEchoID),

View File

@ -18,15 +18,11 @@ import (
) )
// Opens a non-privileged ICMP socket on Linux and Darwin // Opens a non-privileged ICMP socket on Linux and Darwin
func newICMPConn(listenIP netip.Addr, zone string) (*icmp.PacketConn, error) { func newICMPConn(listenIP netip.Addr) (*icmp.PacketConn, error) {
if listenIP.Is4() { if listenIP.Is4() {
return icmp.ListenPacket("udp4", listenIP.String()) return icmp.ListenPacket("udp4", listenIP.String())
} }
listenAddr := listenIP.String() return icmp.ListenPacket("udp6", listenIP.String())
if zone != "" {
listenAddr = listenAddr + "%" + zone
}
return icmp.ListenPacket("udp6", listenAddr)
} }
func netipAddr(addr net.Addr) (netip.Addr, bool) { func netipAddr(addr net.Addr) (netip.Addr, bool) {
@ -34,7 +30,8 @@ func netipAddr(addr net.Addr) (netip.Addr, bool) {
if !ok { if !ok {
return netip.Addr{}, false return netip.Addr{}, false
} }
return netip.AddrFromSlice(udpAddr.IP)
return udpAddr.AddrPort().Addr(), true
} }
type flow3Tuple struct { type flow3Tuple struct {
@ -50,14 +47,12 @@ type icmpEchoFlow struct {
closed *atomic.Bool closed *atomic.Bool
src netip.Addr src netip.Addr
originConn *icmp.PacketConn originConn *icmp.PacketConn
responder *packetResponder responder ICMPResponder
assignedEchoID int assignedEchoID int
originalEchoID int originalEchoID int
// it's up to the user to ensure respEncoder is not used concurrently
respEncoder *packet.Encoder
} }
func newICMPEchoFlow(src netip.Addr, closeCallback func() error, originConn *icmp.PacketConn, responder *packetResponder, assignedEchoID, originalEchoID int, respEncoder *packet.Encoder) *icmpEchoFlow { func newICMPEchoFlow(src netip.Addr, closeCallback func() error, originConn *icmp.PacketConn, responder ICMPResponder, assignedEchoID, originalEchoID int) *icmpEchoFlow {
return &icmpEchoFlow{ return &icmpEchoFlow{
ActivityTracker: packet.NewActivityTracker(), ActivityTracker: packet.NewActivityTracker(),
closeCallback: closeCallback, closeCallback: closeCallback,
@ -67,7 +62,6 @@ func newICMPEchoFlow(src netip.Addr, closeCallback func() error, originConn *icm
responder: responder, responder: responder,
assignedEchoID: assignedEchoID, assignedEchoID: assignedEchoID,
originalEchoID: originalEchoID, originalEchoID: originalEchoID,
respEncoder: respEncoder,
} }
} }
@ -139,11 +133,7 @@ func (ief *icmpEchoFlow) returnToSrc(reply *echoReply) error {
}, },
Message: reply.msg, Message: reply.msg,
} }
serializedPacket, err := ief.respEncoder.Encode(&pk) return ief.responder.ReturnPacket(&pk)
if err != nil {
return err
}
return ief.responder.returnPacket(serializedPacket)
} }
type echoReply struct { type echoReply struct {
@ -184,7 +174,7 @@ func toICMPEchoFlow(funnel packet.Funnel) (*icmpEchoFlow, error) {
return icmpFlow, nil return icmpFlow, nil
} }
func createShouldReplaceFunnelFunc(logger *zerolog.Logger, muxer muxer, pk *packet.ICMP, originalEchoID int) func(packet.Funnel) bool { func createShouldReplaceFunnelFunc(logger *zerolog.Logger, responder ICMPResponder, pk *packet.ICMP, originalEchoID int) func(packet.Funnel) bool {
return func(existing packet.Funnel) bool { return func(existing packet.Funnel) bool {
existingFlow, err := toICMPEchoFlow(existing) existingFlow, err := toICMPEchoFlow(existing)
if err != nil { if err != nil {
@ -199,7 +189,7 @@ func createShouldReplaceFunnelFunc(logger *zerolog.Logger, muxer muxer, pk *pack
// If the existing flow has a different muxer, there's a new quic connection where return packets should be // If the existing flow has a different muxer, there's a new quic connection where return packets should be
// routed. Otherwise, return packets will be send to the first observed incoming connection, rather than the // routed. Otherwise, return packets will be send to the first observed incoming connection, rather than the
// most recently observed connection. // most recently observed connection.
if existingFlow.responder.datagramMuxer != muxer { if existingFlow.responder.ConnectionIndex() != responder.ConnectionIndex() {
logger.Debug(). logger.Debug().
Str("src", pk.Src.String()). Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()). Str("dst", pk.Dst.String()).

View File

@ -27,7 +27,7 @@ func TestFunnelIdleTimeout(t *testing.T) {
startSeq = 8129 startSeq = 8129
) )
logger := zerolog.New(os.Stderr) logger := zerolog.New(os.Stderr)
proxy, err := newICMPProxy(localhostIP, "", &logger, idleTimeout) proxy, err := newICMPProxy(localhostIP, &logger, idleTimeout)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -56,24 +56,19 @@ func TestFunnelIdleTimeout(t *testing.T) {
}, },
} }
muxer := newMockMuxer(0) muxer := newMockMuxer(0)
responder := packetResponder{ responder := newPacketResponder(muxer, 0, packet.NewEncoder())
datagramMuxer: muxer, require.NoError(t, proxy.Request(ctx, &pk, responder))
}
require.NoError(t, proxy.Request(ctx, &pk, &responder))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
// Send second request, should reuse the funnel // Send second request, should reuse the funnel
require.NoError(t, proxy.Request(ctx, &pk, &packetResponder{ require.NoError(t, proxy.Request(ctx, &pk, responder))
datagramMuxer: muxer,
}))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
// New muxer on a different connection should use a new flow
time.Sleep(idleTimeout * 2) time.Sleep(idleTimeout * 2)
newMuxer := newMockMuxer(0) newMuxer := newMockMuxer(0)
newResponder := packetResponder{ newResponder := newPacketResponder(newMuxer, 1, packet.NewEncoder())
datagramMuxer: newMuxer, require.NoError(t, proxy.Request(ctx, &pk, newResponder))
}
require.NoError(t, proxy.Request(ctx, &pk, &newResponder))
validateEchoFlow(t, <-newMuxer.cfdToEdge, &pk) validateEchoFlow(t, <-newMuxer.cfdToEdge, &pk)
time.Sleep(idleTimeout * 2) time.Sleep(idleTimeout * 2)
@ -90,7 +85,7 @@ func TestReuseFunnel(t *testing.T) {
startSeq = 8129 startSeq = 8129
) )
logger := zerolog.New(os.Stderr) logger := zerolog.New(os.Stderr)
proxy, err := newICMPProxy(localhostIP, "", &logger, idleTimeout) proxy, err := newICMPProxy(localhostIP, &logger, idleTimeout)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -124,18 +119,14 @@ func TestReuseFunnel(t *testing.T) {
originalEchoID: echoID, originalEchoID: echoID,
} }
muxer := newMockMuxer(0) muxer := newMockMuxer(0)
responder := packetResponder{ responder := newPacketResponder(muxer, 0, packet.NewEncoder())
datagramMuxer: muxer, require.NoError(t, proxy.Request(ctx, &pk, responder))
}
require.NoError(t, proxy.Request(ctx, &pk, &responder))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
funnel1, found := getFunnel(t, proxy, tuple) funnel1, found := getFunnel(t, proxy, tuple)
require.True(t, found) require.True(t, found)
// Send second request, should reuse the funnel // Send second request, should reuse the funnel
require.NoError(t, proxy.Request(ctx, &pk, &packetResponder{ require.NoError(t, proxy.Request(ctx, &pk, responder))
datagramMuxer: muxer,
}))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
funnel2, found := getFunnel(t, proxy, tuple) funnel2, found := getFunnel(t, proxy, tuple)
require.True(t, found) require.True(t, found)

View File

@ -13,7 +13,6 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"runtime/debug" "runtime/debug"
"sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
@ -222,11 +221,9 @@ type icmpProxy struct {
// This is a ICMPv6 if srcSocketAddr is not nil // This is a ICMPv6 if srcSocketAddr is not nil
srcSocketAddr *sockAddrIn6 srcSocketAddr *sockAddrIn6
logger *zerolog.Logger logger *zerolog.Logger
// A pool of reusable *packet.Encoder
encoderPool sync.Pool
} }
func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) {
var ( var (
srcSocketAddr *sockAddrIn6 srcSocketAddr *sockAddrIn6
handle uintptr handle uintptr
@ -250,11 +247,6 @@ func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idle
handle: handle, handle: handle,
srcSocketAddr: srcSocketAddr, srcSocketAddr: srcSocketAddr,
logger: logger, logger: logger,
encoderPool: sync.Pool{
New: func() any {
return packet.NewEncoder()
},
},
}, nil }, nil
} }
@ -267,15 +259,15 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
// Request sends an ICMP echo request and wait for a reply or timeout. // Request sends an ICMP echo request and wait for a reply or timeout.
// The async version of Win32 APIs take a callback whose memory is not garbage collected, so we use the synchronous version. // The async version of Win32 APIs take a callback whose memory is not garbage collected, so we use the synchronous version.
// It's possible that a slow request will block other requests, so we set the timeout to only 1s. // It's possible that a slow request will block other requests, so we set the timeout to only 1s.
func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error { func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
ip.logger.Error().Interface("error", r).Msgf("Recover panic from sending icmp request/response, error %s", debug.Stack()) ip.logger.Error().Interface("error", r).Msgf("Recover panic from sending icmp request/response, error %s", debug.Stack())
} }
}() }()
_, requestSpan := responder.requestSpan(ctx, pk) _, requestSpan := responder.RequestSpan(ctx, pk)
defer responder.exportSpan() defer responder.ExportSpan()
echo, err := getICMPEcho(pk.Message) echo, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
@ -290,9 +282,9 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
return err return err
} }
tracing.End(requestSpan) tracing.End(requestSpan)
responder.exportSpan() responder.ExportSpan()
_, replySpan := responder.replySpan(ctx, ip.logger) _, replySpan := responder.ReplySpan(ctx, ip.logger)
err = ip.handleEchoReply(pk, echo, resp, responder) err = ip.handleEchoReply(pk, echo, resp, responder)
if err != nil { if err != nil {
ip.logger.Err(err).Msg("Failed to send ICMP reply") ip.logger.Err(err).Msg("Failed to send ICMP reply")
@ -308,7 +300,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
return nil return nil
} }
func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, resp echoResp, responder *packetResponder) error { func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, resp echoResp, responder ICMPResponder) error {
var replyType icmp.Type var replyType icmp.Type
if request.Dst.Is4() { if request.Dst.Is4() {
replyType = ipv4.ICMPTypeEchoReply replyType = ipv4.ICMPTypeEchoReply
@ -333,21 +325,7 @@ func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, r
}, },
}, },
} }
return responder.ReturnPacket(&pk)
cachedEncoder := ip.encoderPool.Get()
// The encoded packet is a slice to of the encoder, so we shouldn't return the encoder back to the pool until
// the encoded packet is sent.
defer ip.encoderPool.Put(cachedEncoder)
encoder, ok := cachedEncoder.(*packet.Encoder)
if !ok {
return fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder)
}
serializedPacket, err := encoder.Encode(&pk)
if err != nil {
return err
}
return responder.returnPacket(serializedPacket)
} }
func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) (echoResp, error) { func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) (echoResp, error) {

View File

@ -132,7 +132,7 @@ func TestSendEchoErrors(t *testing.T) {
} }
func testSendEchoErrors(t *testing.T, listenIP netip.Addr) { func testSendEchoErrors(t *testing.T, listenIP netip.Addr) {
proxy, err := newICMPProxy(listenIP, "", &noopLogger, time.Second) proxy, err := newICMPProxy(listenIP, &noopLogger, time.Second)
require.NoError(t, err) require.NoError(t, err)
echo := icmp.Echo{ echo := icmp.Echo{

View File

@ -14,6 +14,7 @@ import (
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"github.com/cloudflare/cloudflared/packet" "github.com/cloudflare/cloudflared/packet"
"github.com/cloudflare/cloudflared/tracing"
) )
const ( const (
@ -26,17 +27,46 @@ var (
errPacketNil = fmt.Errorf("packet is nil") errPacketNil = fmt.Errorf("packet is nil")
) )
// ICMPRouterServer is a parent interface over-top of ICMPRouter that allows for the operation of the proxy origin listeners.
type ICMPRouterServer interface {
ICMPRouter
// Serve runs the ICMPRouter proxy origin listeners for any of the IPv4 or IPv6 interfaces configured.
Serve(ctx context.Context) error
}
// ICMPRouter manages out-going ICMP requests towards the origin.
type ICMPRouter interface {
// Request will send an ICMP packet towards the origin with an ICMPResponder to attach to the ICMP flow for the
// response to utilize.
Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error
// ConvertToTTLExceeded will take an ICMP packet and create a ICMP TTL Exceeded response origininating from the
// ICMPRouter's IP interface.
ConvertToTTLExceeded(pk *packet.ICMP, rawPacket packet.RawPacket) *packet.ICMP
}
// ICMPResponder manages how to handle incoming ICMP messages coming from the origin to the edge.
type ICMPResponder interface {
ConnectionIndex() uint8
ReturnPacket(pk *packet.ICMP) error
AddTraceContext(tracedCtx *tracing.TracedContext, serializedIdentity []byte)
RequestSpan(ctx context.Context, pk *packet.ICMP) (context.Context, trace.Span)
ReplySpan(ctx context.Context, logger *zerolog.Logger) (context.Context, trace.Span)
ExportSpan()
}
type icmpRouter struct { type icmpRouter struct {
ipv4Proxy *icmpProxy ipv4Proxy *icmpProxy
ipv4Src netip.Addr
ipv6Proxy *icmpProxy ipv6Proxy *icmpProxy
ipv6Src netip.Addr
} }
// NewICMPRouter doesn't return an error if either ipv4 proxy or ipv6 proxy can be created. The machine might only // NewICMPRouter doesn't return an error if either ipv4 proxy or ipv6 proxy can be created. The machine might only
// support one of them. // support one of them.
// funnelIdleTimeout controls how long to wait to close a funnel without send/return // funnelIdleTimeout controls how long to wait to close a funnel without send/return
func NewICMPRouter(ipv4Addr, ipv6Addr netip.Addr, ipv6Zone string, logger *zerolog.Logger, funnelIdleTimeout time.Duration) (*icmpRouter, error) { func NewICMPRouter(ipv4Addr, ipv6Addr netip.Addr, logger *zerolog.Logger, funnelIdleTimeout time.Duration) (ICMPRouterServer, error) {
ipv4Proxy, ipv4Err := newICMPProxy(ipv4Addr, "", logger, funnelIdleTimeout) ipv4Proxy, ipv4Err := newICMPProxy(ipv4Addr, logger, funnelIdleTimeout)
ipv6Proxy, ipv6Err := newICMPProxy(ipv6Addr, ipv6Zone, logger, funnelIdleTimeout) ipv6Proxy, ipv6Err := newICMPProxy(ipv6Addr, logger, funnelIdleTimeout)
if ipv4Err != nil && ipv6Err != nil { if ipv4Err != nil && ipv6Err != nil {
err := fmt.Errorf("cannot create ICMPv4 proxy: %v nor ICMPv6 proxy: %v", ipv4Err, ipv6Err) err := fmt.Errorf("cannot create ICMPv4 proxy: %v nor ICMPv6 proxy: %v", ipv4Err, ipv6Err)
logger.Debug().Err(err).Msg("ICMP proxy feature is disabled") logger.Debug().Err(err).Msg("ICMP proxy feature is disabled")
@ -52,7 +82,9 @@ func NewICMPRouter(ipv4Addr, ipv6Addr netip.Addr, ipv6Zone string, logger *zerol
} }
return &icmpRouter{ return &icmpRouter{
ipv4Proxy: ipv4Proxy, ipv4Proxy: ipv4Proxy,
ipv4Src: ipv4Addr,
ipv6Proxy: ipv6Proxy, ipv6Proxy: ipv6Proxy,
ipv6Src: ipv6Addr,
}, nil }, nil
} }
@ -76,7 +108,7 @@ func (ir *icmpRouter) Serve(ctx context.Context) error {
return fmt.Errorf("ICMPv4 proxy and ICMPv6 proxy are both nil") return fmt.Errorf("ICMPv4 proxy and ICMPv6 proxy are both nil")
} }
func (ir *icmpRouter) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error { func (ir *icmpRouter) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error {
if pk == nil { if pk == nil {
return errPacketNil return errPacketNil
} }
@ -92,6 +124,16 @@ func (ir *icmpRouter) Request(ctx context.Context, pk *packet.ICMP, responder *p
return fmt.Errorf("ICMPv6 proxy was not instantiated") return fmt.Errorf("ICMPv6 proxy was not instantiated")
} }
func (ir *icmpRouter) ConvertToTTLExceeded(pk *packet.ICMP, rawPacket packet.RawPacket) *packet.ICMP {
var srcIP netip.Addr
if pk.Dst.Is4() {
srcIP = ir.ipv4Src
} else {
srcIP = ir.ipv6Src
}
return packet.NewICMPTTLExceedPacket(pk.IP, rawPacket, srcIP)
}
func getICMPEcho(msg *icmp.Message) (*icmp.Echo, error) { func getICMPEcho(msg *icmp.Message) (*icmp.Echo, error) {
echo, ok := msg.Body.(*icmp.Echo) echo, ok := msg.Body.(*icmp.Echo)
if !ok { if !ok {

View File

@ -50,7 +50,7 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
endSeq = 20 endSeq = 20
) )
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout) router, err := NewICMPRouter(localhostIP, localhostIPv6, &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err) require.NoError(t, err)
proxyDone := make(chan struct{}) proxyDone := make(chan struct{})
@ -61,9 +61,7 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
}() }()
muxer := newMockMuxer(1) muxer := newMockMuxer(1)
responder := packetResponder{ responder := newPacketResponder(muxer, 0, packet.NewEncoder())
datagramMuxer: muxer,
}
protocol := layers.IPProtocolICMPv6 protocol := layers.IPProtocolICMPv6
if sendIPv4 { if sendIPv4 {
@ -98,7 +96,7 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
}, },
}, },
} }
require.NoError(t, router.Request(ctx, &pk, &responder)) require.NoError(t, router.Request(ctx, &pk, responder))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
} }
} }
@ -114,7 +112,7 @@ func TestTraceICMPRouterEcho(t *testing.T) {
tracingCtx := "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1" tracingCtx := "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1"
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout) router, err := NewICMPRouter(localhostIP, localhostIPv6, &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err) require.NoError(t, err)
proxyDone := make(chan struct{}) proxyDone := make(chan struct{})
@ -131,11 +129,8 @@ func TestTraceICMPRouterEcho(t *testing.T) {
serializedIdentity, err := tracingIdentity.MarshalBinary() serializedIdentity, err := tracingIdentity.MarshalBinary()
require.NoError(t, err) require.NoError(t, err)
responder := packetResponder{ responder := newPacketResponder(muxer, 0, packet.NewEncoder())
datagramMuxer: muxer, responder.AddTraceContext(tracing.NewTracedContext(ctx, tracingIdentity.String(), &noopLogger), serializedIdentity)
tracedCtx: tracing.NewTracedContext(ctx, tracingIdentity.String(), &noopLogger),
serializedIdentity: serializedIdentity,
}
echo := &icmp.Echo{ echo := &icmp.Echo{
ID: 12910, ID: 12910,
@ -156,7 +151,7 @@ func TestTraceICMPRouterEcho(t *testing.T) {
}, },
} }
require.NoError(t, router.Request(ctx, &pk, &responder)) require.NoError(t, router.Request(ctx, &pk, responder))
firstPK := <-muxer.cfdToEdge firstPK := <-muxer.cfdToEdge
var requestSpan *quicpogs.TracingSpanPacket var requestSpan *quicpogs.TracingSpanPacket
// The order of receiving reply or request span is not deterministic // The order of receiving reply or request span is not deterministic
@ -194,10 +189,8 @@ func TestTraceICMPRouterEcho(t *testing.T) {
echo.Seq++ echo.Seq++
pk.Body = echo pk.Body = echo
// Only first request for a flow is traced. The edge will not send tracing context for the second request // Only first request for a flow is traced. The edge will not send tracing context for the second request
newResponder := packetResponder{ newResponder := newPacketResponder(muxer, 0, packet.NewEncoder())
datagramMuxer: muxer, require.NoError(t, router.Request(ctx, &pk, newResponder))
}
require.NoError(t, router.Request(ctx, &pk, &newResponder))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
select { select {
@ -221,7 +214,7 @@ func TestConcurrentRequestsToSameDst(t *testing.T) {
endSeq = 5 endSeq = 5
) )
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout) router, err := NewICMPRouter(localhostIP, localhostIPv6, &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err) require.NoError(t, err)
proxyDone := make(chan struct{}) proxyDone := make(chan struct{})
@ -240,9 +233,7 @@ func TestConcurrentRequestsToSameDst(t *testing.T) {
defer wg.Done() defer wg.Done()
muxer := newMockMuxer(1) muxer := newMockMuxer(1)
responder := packetResponder{ responder := newPacketResponder(muxer, 0, packet.NewEncoder())
datagramMuxer: muxer,
}
for seq := 0; seq < endSeq; seq++ { for seq := 0; seq < endSeq; seq++ {
pk := &packet.ICMP{ pk := &packet.ICMP{
IP: &packet.IP{ IP: &packet.IP{
@ -261,16 +252,14 @@ func TestConcurrentRequestsToSameDst(t *testing.T) {
}, },
}, },
} }
require.NoError(t, router.Request(ctx, pk, &responder)) require.NoError(t, router.Request(ctx, pk, responder))
validateEchoFlow(t, <-muxer.cfdToEdge, pk) validateEchoFlow(t, <-muxer.cfdToEdge, pk)
} }
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
muxer := newMockMuxer(1) muxer := newMockMuxer(1)
responder := packetResponder{ responder := newPacketResponder(muxer, 0, packet.NewEncoder())
datagramMuxer: muxer,
}
for seq := 0; seq < endSeq; seq++ { for seq := 0; seq < endSeq; seq++ {
pk := &packet.ICMP{ pk := &packet.ICMP{
IP: &packet.IP{ IP: &packet.IP{
@ -289,7 +278,7 @@ func TestConcurrentRequestsToSameDst(t *testing.T) {
}, },
}, },
} }
require.NoError(t, router.Request(ctx, pk, &responder)) require.NoError(t, router.Request(ctx, pk, responder))
validateEchoFlow(t, <-muxer.cfdToEdge, pk) validateEchoFlow(t, <-muxer.cfdToEdge, pk)
} }
}() }()
@ -358,13 +347,11 @@ func TestICMPRouterRejectNotEcho(t *testing.T) {
} }
func testICMPRouterRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.Message) { func testICMPRouterRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.Message) {
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout) router, err := NewICMPRouter(localhostIP, localhostIPv6, &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err) require.NoError(t, err)
muxer := newMockMuxer(1) muxer := newMockMuxer(1)
responder := packetResponder{ responder := newPacketResponder(muxer, 0, packet.NewEncoder())
datagramMuxer: muxer,
}
protocol := layers.IPProtocolICMPv4 protocol := layers.IPProtocolICMPv4
if srcDstIP.Is6() { if srcDstIP.Is6() {
protocol = layers.IPProtocolICMPv6 protocol = layers.IPProtocolICMPv6
@ -379,7 +366,7 @@ func testICMPRouterRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.
}, },
Message: &m, Message: &m,
} }
require.Error(t, router.Request(context.Background(), &pk, &responder)) require.Error(t, router.Request(context.Background(), &pk, responder))
} }
} }

View File

@ -3,7 +3,6 @@ package ingress
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
@ -23,29 +22,23 @@ type muxer interface {
// PacketRouter routes packets between Upstream and ICMPRouter. Currently it rejects all other type of ICMP packets // PacketRouter routes packets between Upstream and ICMPRouter. Currently it rejects all other type of ICMP packets
type PacketRouter struct { type PacketRouter struct {
globalConfig *GlobalRouterConfig icmpRouter ICMPRouter
muxer muxer muxer muxer
connIndex uint8
logger *zerolog.Logger logger *zerolog.Logger
icmpDecoder *packet.ICMPDecoder
encoder *packet.Encoder encoder *packet.Encoder
} decoder *packet.ICMPDecoder
// GlobalRouterConfig is the configuration shared by all instance of Router.
type GlobalRouterConfig struct {
ICMPRouter *icmpRouter
IPv4Src netip.Addr
IPv6Src netip.Addr
Zone string
} }
// NewPacketRouter creates a PacketRouter that handles ICMP packets. Packets are read from muxer but dropped if globalConfig is nil. // NewPacketRouter creates a PacketRouter that handles ICMP packets. Packets are read from muxer but dropped if globalConfig is nil.
func NewPacketRouter(globalConfig *GlobalRouterConfig, muxer muxer, logger *zerolog.Logger) *PacketRouter { func NewPacketRouter(icmpRouter ICMPRouter, muxer muxer, connIndex uint8, logger *zerolog.Logger) *PacketRouter {
return &PacketRouter{ return &PacketRouter{
globalConfig: globalConfig, icmpRouter: icmpRouter,
muxer: muxer, muxer: muxer,
connIndex: connIndex,
logger: logger, logger: logger,
icmpDecoder: packet.NewICMPDecoder(),
encoder: packet.NewEncoder(), encoder: packet.NewEncoder(),
decoder: packet.NewICMPDecoder(),
} }
} }
@ -59,14 +52,13 @@ func (r *PacketRouter) Serve(ctx context.Context) error {
} }
} }
func (r *PacketRouter) nextPacket(ctx context.Context) (packet.RawPacket, *packetResponder, error) { func (r *PacketRouter) nextPacket(ctx context.Context) (packet.RawPacket, ICMPResponder, error) {
pk, err := r.muxer.ReceivePacket(ctx) pk, err := r.muxer.ReceivePacket(ctx)
if err != nil { if err != nil {
return packet.RawPacket{}, nil, err return packet.RawPacket{}, nil, err
} }
responder := &packetResponder{ responder := newPacketResponder(r.muxer, r.connIndex, packet.NewEncoder())
datagramMuxer: r.muxer,
}
switch pk.Type() { switch pk.Type() {
case quicpogs.DatagramTypeIP: case quicpogs.DatagramTypeIP:
return packet.RawPacket{Data: pk.Payload()}, responder, nil return packet.RawPacket{Data: pk.Payload()}, responder, nil
@ -75,8 +67,8 @@ func (r *PacketRouter) nextPacket(ctx context.Context) (packet.RawPacket, *packe
if err := identity.UnmarshalBinary(pk.Metadata()); err != nil { if err := identity.UnmarshalBinary(pk.Metadata()); err != nil {
r.logger.Err(err).Bytes("tracingIdentity", pk.Metadata()).Msg("Failed to unmarshal tracing identity") r.logger.Err(err).Bytes("tracingIdentity", pk.Metadata()).Msg("Failed to unmarshal tracing identity")
} else { } else {
responder.tracedCtx = tracing.NewTracedContext(ctx, identity.String(), r.logger) tracedCtx := tracing.NewTracedContext(ctx, identity.String(), r.logger)
responder.serializedIdentity = pk.Metadata() responder.AddTraceContext(tracedCtx, pk.Metadata())
} }
return packet.RawPacket{Data: pk.Payload()}, responder, nil return packet.RawPacket{Data: pk.Payload()}, responder, nil
default: default:
@ -84,27 +76,27 @@ func (r *PacketRouter) nextPacket(ctx context.Context) (packet.RawPacket, *packe
} }
} }
func (r *PacketRouter) handlePacket(ctx context.Context, rawPacket packet.RawPacket, responder *packetResponder) { func (r *PacketRouter) handlePacket(ctx context.Context, rawPacket packet.RawPacket, responder ICMPResponder) {
// ICMP Proxy feature is disabled, drop packets // ICMP Proxy feature is disabled, drop packets
if r.globalConfig == nil { if r.icmpRouter == nil {
return return
} }
icmpPacket, err := r.icmpDecoder.Decode(rawPacket) icmpPacket, err := r.decoder.Decode(rawPacket)
if err != nil { if err != nil {
r.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram") r.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram")
return return
} }
if icmpPacket.TTL <= 1 { if icmpPacket.TTL <= 1 {
if err := r.sendTTLExceedMsg(ctx, icmpPacket, rawPacket, r.encoder); err != nil { if err := r.sendTTLExceedMsg(icmpPacket, rawPacket); err != nil {
r.logger.Err(err).Msg("Failed to return ICMP TTL exceed error") r.logger.Err(err).Msg("Failed to return ICMP TTL exceed error")
} }
return return
} }
icmpPacket.TTL-- icmpPacket.TTL--
if err := r.globalConfig.ICMPRouter.Request(ctx, icmpPacket, responder); err != nil { if err := r.icmpRouter.Request(ctx, icmpPacket, responder); err != nil {
r.logger.Err(err). r.logger.Err(err).
Str("src", icmpPacket.Src.String()). Str("src", icmpPacket.Src.String()).
Str("dst", icmpPacket.Dst.String()). Str("dst", icmpPacket.Dst.String()).
@ -113,16 +105,9 @@ func (r *PacketRouter) handlePacket(ctx context.Context, rawPacket packet.RawPac
} }
} }
func (r *PacketRouter) sendTTLExceedMsg(ctx context.Context, pk *packet.ICMP, rawPacket packet.RawPacket, encoder *packet.Encoder) error { func (r *PacketRouter) sendTTLExceedMsg(pk *packet.ICMP, rawPacket packet.RawPacket) error {
var srcIP netip.Addr icmpTTLPacket := r.icmpRouter.ConvertToTTLExceeded(pk, rawPacket)
if pk.Dst.Is4() { encodedTTLExceed, err := r.encoder.Encode(icmpTTLPacket)
srcIP = r.globalConfig.IPv4Src
} else {
srcIP = r.globalConfig.IPv6Src
}
ttlExceedPacket := packet.NewICMPTTLExceedPacket(pk.IP, rawPacket, srcIP)
encodedTTLExceed, err := encoder.Encode(ttlExceedPacket)
if err != nil { if err != nil {
return err return err
} }
@ -132,22 +117,45 @@ func (r *PacketRouter) sendTTLExceedMsg(ctx context.Context, pk *packet.ICMP, ra
// packetResponder should not be used concurrently. This assumption is upheld because reply packets are ready one-by-one // packetResponder should not be used concurrently. This assumption is upheld because reply packets are ready one-by-one
type packetResponder struct { type packetResponder struct {
datagramMuxer muxer datagramMuxer muxer
connIndex uint8
encoder *packet.Encoder
tracedCtx *tracing.TracedContext tracedCtx *tracing.TracedContext
serializedIdentity []byte serializedIdentity []byte
// hadReply tracks if there has been any reply for this flow // hadReply tracks if there has been any reply for this flow
hadReply bool hadReply bool
} }
func newPacketResponder(datagramMuxer muxer, connIndex uint8, encoder *packet.Encoder) ICMPResponder {
return &packetResponder{
datagramMuxer: datagramMuxer,
connIndex: connIndex,
encoder: encoder,
}
}
func (pr *packetResponder) tracingEnabled() bool { func (pr *packetResponder) tracingEnabled() bool {
return pr.tracedCtx != nil return pr.tracedCtx != nil
} }
func (pr *packetResponder) returnPacket(rawPacket packet.RawPacket) error { func (pr *packetResponder) ConnectionIndex() uint8 {
return pr.connIndex
}
func (pr *packetResponder) ReturnPacket(pk *packet.ICMP) error {
rawPacket, err := pr.encoder.Encode(pk)
if err != nil {
return err
}
pr.hadReply = true pr.hadReply = true
return pr.datagramMuxer.SendPacket(quicpogs.RawPacket(rawPacket)) return pr.datagramMuxer.SendPacket(quicpogs.RawPacket(rawPacket))
} }
func (pr *packetResponder) requestSpan(ctx context.Context, pk *packet.ICMP) (context.Context, trace.Span) { func (pr *packetResponder) AddTraceContext(tracedCtx *tracing.TracedContext, serializedIdentity []byte) {
pr.tracedCtx = tracedCtx
pr.serializedIdentity = serializedIdentity
}
func (pr *packetResponder) RequestSpan(ctx context.Context, pk *packet.ICMP) (context.Context, trace.Span) {
if !pr.tracingEnabled() { if !pr.tracingEnabled() {
return ctx, tracing.NewNoopSpan() return ctx, tracing.NewNoopSpan()
} }
@ -157,14 +165,14 @@ func (pr *packetResponder) requestSpan(ctx context.Context, pk *packet.ICMP) (co
)) ))
} }
func (pr *packetResponder) replySpan(ctx context.Context, logger *zerolog.Logger) (context.Context, trace.Span) { func (pr *packetResponder) ReplySpan(ctx context.Context, logger *zerolog.Logger) (context.Context, trace.Span) {
if !pr.tracingEnabled() || pr.hadReply { if !pr.tracingEnabled() || pr.hadReply {
return ctx, tracing.NewNoopSpan() return ctx, tracing.NewNoopSpan()
} }
return pr.tracedCtx.Tracer().Start(pr.tracedCtx, "icmp-echo-reply") return pr.tracedCtx.Tracer().Start(pr.tracedCtx, "icmp-echo-reply")
} }
func (pr *packetResponder) exportSpan() { func (pr *packetResponder) ExportSpan() {
if !pr.tracingEnabled() { if !pr.tracingEnabled() {
return return
} }

View File

@ -19,16 +19,17 @@ import (
) )
var ( var (
packetConfig = &GlobalRouterConfig{ defaultRouter = &icmpRouter{
ICMPRouter: nil, ipv4Proxy: nil,
IPv4Src: netip.MustParseAddr("172.16.0.1"), ipv4Src: netip.MustParseAddr("172.16.0.1"),
IPv6Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"), ipv6Proxy: nil,
ipv6Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
} }
) )
func TestRouterReturnTTLExceed(t *testing.T) { func TestRouterReturnTTLExceed(t *testing.T) {
muxer := newMockMuxer(0) muxer := newMockMuxer(0)
router := NewPacketRouter(packetConfig, muxer, &noopLogger) router := NewPacketRouter(defaultRouter, muxer, 0, &noopLogger)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
routerStopped := make(chan struct{}) routerStopped := make(chan struct{})
go func() { go func() {
@ -53,7 +54,7 @@ func TestRouterReturnTTLExceed(t *testing.T) {
}, },
}, },
} }
assertTTLExceed(t, &pk, router.globalConfig.IPv4Src, muxer) assertTTLExceed(t, &pk, defaultRouter.ipv4Src, muxer)
pk = packet.ICMP{ pk = packet.ICMP{
IP: &packet.IP{ IP: &packet.IP{
Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"), Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"),
@ -71,7 +72,7 @@ func TestRouterReturnTTLExceed(t *testing.T) {
}, },
}, },
} }
assertTTLExceed(t, &pk, router.globalConfig.IPv6Src, muxer) assertTTLExceed(t, &pk, defaultRouter.ipv6Src, muxer)
cancel() cancel()
<-routerStopped <-routerStopped

View File

@ -1,6 +1,8 @@
package packet package packet
import "github.com/google/gopacket" import (
"github.com/google/gopacket"
)
var ( var (
serializeOpts = gopacket.SerializeOptions{ serializeOpts = gopacket.SerializeOptions{
@ -25,7 +27,7 @@ func NewEncoder() *Encoder {
} }
} }
func (e Encoder) Encode(packet Packet) (RawPacket, error) { func (e *Encoder) Encode(packet Packet) (RawPacket, error) {
encodedLayers, err := packet.EncodeLayers() encodedLayers, err := packet.EncodeLayers()
if err != nil { if err != nil {
return RawPacket{}, err return RawPacket{}, err

View File

@ -119,9 +119,9 @@ func (s *Supervisor) Run(
ctx context.Context, ctx context.Context,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
) error { ) error {
if s.config.PacketConfig != nil { if s.config.ICMPRouterServer != nil {
go func() { go func() {
if err := s.config.PacketConfig.ICMPRouter.Serve(ctx); err != nil { if err := s.config.ICMPRouterServer.Serve(ctx); err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
s.log.Logger().Info().Err(err).Msg("icmp router terminated") s.log.Logger().Info().Err(err).Msg("icmp router terminated")
} else { } else {

View File

@ -63,7 +63,7 @@ type TunnelConfig struct {
NamedTunnel *connection.TunnelProperties NamedTunnel *connection.TunnelProperties
ProtocolSelector connection.ProtocolSelector ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config EdgeTLSConfigs map[connection.Protocol]*tls.Config
PacketConfig *ingress.GlobalRouterConfig ICMPRouterServer ingress.ICMPRouterServer
RPCTimeout time.Duration RPCTimeout time.Duration
WriteStreamTimeout time.Duration WriteStreamTimeout time.Duration
@ -615,7 +615,8 @@ func (e *EdgeTunnelServer) serveQUIC(
datagramSessionManager = connection.NewDatagramV2Connection( datagramSessionManager = connection.NewDatagramV2Connection(
ctx, ctx,
conn, conn,
e.config.PacketConfig, e.config.ICMPRouterServer,
connIndex,
e.config.RPCTimeout, e.config.RPCTimeout,
e.config.WriteStreamTimeout, e.config.WriteStreamTimeout,
connLogger.Logger(), connLogger.Logger(),