TUN-6696: Refactor flow into funnel and close idle funnels
A funnel is an abstraction for 1 source to many destinations. As part of this refactoring, shared logic between Darwin and Linux are moved into icmp_posix
This commit is contained in:
parent
e380333520
commit
2ffff0687b
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -346,13 +347,31 @@ func (pr *packetRouter) serve(ctx context.Context) error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pr.icmpProxy.Request(icmpPacket, pr.muxer); err != nil {
|
flowPipe := muxerResponder{muxer: pr.muxer}
|
||||||
pr.logger.Err(err).Str("src", icmpPacket.Src.String()).Str("dst", icmpPacket.Dst.String()).Msg("Failed to send ICMP packet")
|
if err := pr.icmpProxy.Request(icmpPacket, &flowPipe); err != nil {
|
||||||
|
pr.logger.Err(err).
|
||||||
|
Str("src", icmpPacket.Src.String()).
|
||||||
|
Str("dst", icmpPacket.Dst.String()).
|
||||||
|
Interface("type", icmpPacket.Type).
|
||||||
|
Msg("Failed to send ICMP packet")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// muxerResponder wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface
|
||||||
|
type muxerResponder struct {
|
||||||
|
muxer *quicpogs.DatagramMuxerV2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *muxerResponder) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
|
||||||
|
return mr.muxer.SendPacket(pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *muxerResponder) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||||
// the client.
|
// the client.
|
||||||
type streamReadWriteAcker struct {
|
type streamReadWriteAcker struct {
|
||||||
|
|
|
@ -2,6 +2,11 @@
|
||||||
|
|
||||||
package ingress
|
package ingress
|
||||||
|
|
||||||
|
// This file implements ICMPProxy for Darwin. It uses a non-privileged ICMP socket to send echo requests and listen for
|
||||||
|
// echo replies. The source IP of the requests are rewritten to the bind IP of the socket and the socket reads all
|
||||||
|
// messages, so we use echo ID to distinguish the replies. Each (source IP, destination IP, echo ID) is assigned a
|
||||||
|
// unique echo ID.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -10,9 +15,8 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
|
|
||||||
|
@ -20,13 +24,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: TUN-6654 Extend support to IPv6
|
// TODO: TUN-6654 Extend support to IPv6
|
||||||
// On Darwin, a non-privileged ICMP socket can read messages from all echo IDs, so we use it for all sources.
|
|
||||||
type icmpProxy struct {
|
type icmpProxy struct {
|
||||||
// TODO: TUN-6588 clean up flows
|
srcFunnelTracker *packet.FunnelTracker
|
||||||
srcFlowTracker *packet.FlowTracker
|
|
||||||
echoIDTracker *echoIDTracker
|
echoIDTracker *echoIDTracker
|
||||||
conn *icmp.PacketConn
|
conn *icmp.PacketConn
|
||||||
|
// Response is handled in one-by-one, so encoder can be shared between funnels
|
||||||
|
encoder *packet.Encoder
|
||||||
logger *zerolog.Logger
|
logger *zerolog.Logger
|
||||||
|
idleTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// echoIDTracker tracks which ID has been assigned. It first loops through assignment from lastAssignment to then end,
|
// echoIDTracker tracks which ID has been assigned. It first loops through assignment from lastAssignment to then end,
|
||||||
|
@ -92,8 +97,8 @@ func (eit *echoIDTracker) release(srcIP netip.Addr, id uint16) bool {
|
||||||
eit.lock.Lock()
|
eit.lock.Lock()
|
||||||
defer eit.lock.Unlock()
|
defer eit.lock.Unlock()
|
||||||
|
|
||||||
currentID, ok := eit.srcIPMapping[srcIP]
|
currentID, exists := eit.srcIPMapping[srcIP]
|
||||||
if ok && id == currentID {
|
if exists && id == currentID {
|
||||||
delete(eit.srcIPMapping, srcIP)
|
delete(eit.srcIPMapping, srcIP)
|
||||||
eit.assignment[id] = false
|
eit.assignment[id] = false
|
||||||
return true
|
return true
|
||||||
|
@ -101,39 +106,68 @@ func (eit *echoIDTracker) release(srcIP netip.Addr, id uint16) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
type echoFlowID uint16
|
type echoFunnelID uint16
|
||||||
|
|
||||||
func (snf echoFlowID) Type() string {
|
func (snf echoFunnelID) Type() string {
|
||||||
return "echoID"
|
return "echoID"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (snf echoFlowID) String() string {
|
func (snf echoFunnelID) String() string {
|
||||||
return strconv.FormatUint(uint64(snf), 10)
|
return strconv.FormatUint(uint64(snf), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger) (ICMPProxy, error) {
|
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (ICMPProxy, error) {
|
||||||
conn, err := newICMPConn(listenIP)
|
conn, err := newICMPConn(listenIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &icmpProxy{
|
return &icmpProxy{
|
||||||
srcFlowTracker: packet.NewFlowTracker(),
|
srcFunnelTracker: packet.NewFunnelTracker(),
|
||||||
echoIDTracker: newEchoIDTracker(),
|
echoIDTracker: newEchoIDTracker(),
|
||||||
|
encoder: packet.NewEncoder(),
|
||||||
conn: conn,
|
conn: conn,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
idleTimeout: idleTimeout,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) error {
|
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) error {
|
||||||
if pk == nil {
|
if pk == nil {
|
||||||
return errPacketNil
|
return errPacketNil
|
||||||
}
|
}
|
||||||
switch body := pk.Message.Body.(type) {
|
// TODO: TUN-6744 assign unique flow per (src, echo ID)
|
||||||
case *icmp.Echo:
|
echoID, exists := ip.echoIDTracker.get(pk.Src)
|
||||||
return ip.sendICMPEchoRequest(pk, body, responder)
|
if !exists {
|
||||||
default:
|
originalEcho, err := getICMPEcho(pk.Message)
|
||||||
return fmt.Errorf("sending ICMP %s is not implemented", pk.Type)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
echoID, exists = ip.echoIDTracker.assign(pk.Src)
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("failed to assign unique echo ID")
|
||||||
|
}
|
||||||
|
funnelID := echoFunnelID(echoID)
|
||||||
|
originSender := originSender{
|
||||||
|
conn: ip.conn,
|
||||||
|
echoIDTracker: ip.echoIDTracker,
|
||||||
|
srcIP: pk.Src,
|
||||||
|
echoID: echoID,
|
||||||
|
}
|
||||||
|
icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, int(echoID), originalEcho.ID, ip.encoder)
|
||||||
|
if replaced := ip.srcFunnelTracker.Register(funnelID, icmpFlow); replaced {
|
||||||
|
ip.logger.Info().Str("src", pk.Src.String()).Msg("Replaced funnel")
|
||||||
|
}
|
||||||
|
return icmpFlow.sendToDst(pk.Dst, pk.Message)
|
||||||
|
}
|
||||||
|
funnel, exists := ip.srcFunnelTracker.Get(echoFunnelID(echoID))
|
||||||
|
if !exists {
|
||||||
|
return packet.ErrFunnelNotFound
|
||||||
|
}
|
||||||
|
icmpFlow, err := toICMPEchoFlow(funnel)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return icmpFlow.sendToDst(pk.Dst, pk.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve listens for responses to the requests until context is done
|
// Serve listens for responses to the requests until context is done
|
||||||
|
@ -142,88 +176,54 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
ip.conn.Close()
|
ip.conn.Close()
|
||||||
}()
|
}()
|
||||||
|
go func() {
|
||||||
|
ip.srcFunnelTracker.ScheduleCleanup(ctx, ip.idleTimeout)
|
||||||
|
}()
|
||||||
buf := make([]byte, mtu)
|
buf := make([]byte, mtu)
|
||||||
encoder := packet.NewEncoder()
|
|
||||||
for {
|
for {
|
||||||
n, src, err := ip.conn.ReadFrom(buf)
|
n, src, err := ip.conn.ReadFrom(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: TUN-6654 Check for IPv6
|
if err := ip.handleResponse(src, buf[:n]); err != nil {
|
||||||
msg, err := icmp.ParseMessage(int(layers.IPProtocolICMPv4), buf[:n])
|
ip.logger.Err(err).Str("src", src.String()).Msg("Failed to handle ICMP response")
|
||||||
if err != nil {
|
|
||||||
ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to parse ICMP message")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch body := msg.Body.(type) {
|
|
||||||
case *icmp.Echo:
|
|
||||||
if err := ip.handleEchoResponse(encoder, msg, body); err != nil {
|
|
||||||
ip.logger.Error().Err(err).
|
|
||||||
Str("src", src.String()).
|
|
||||||
Str("flowID", echoFlowID(body.ID).String()).
|
|
||||||
Msg("Failed to handle ICMP response")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
ip.logger.Warn().
|
|
||||||
Str("icmpType", fmt.Sprintf("%s", msg.Type)).
|
|
||||||
Msgf("Responding to this type of ICMP is not implemented")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) sendICMPEchoRequest(pk *packet.ICMP, echo *icmp.Echo, responder packet.FlowResponder) error {
|
func (ip *icmpProxy) handleResponse(from net.Addr, rawMsg []byte) error {
|
||||||
echoID, ok := ip.echoIDTracker.get(pk.Src)
|
reply, err := parseReply(from, rawMsg)
|
||||||
if !ok {
|
if err != nil {
|
||||||
echoID, ok = ip.echoIDTracker.assign(pk.Src)
|
return err
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("failed to assign unique echo ID")
|
|
||||||
}
|
}
|
||||||
flowID := echoFlowID(echoID)
|
funnel, exists := ip.srcFunnelTracker.Get(echoFunnelID(reply.echo.ID))
|
||||||
flow := packet.Flow{
|
if !exists {
|
||||||
Src: pk.Src,
|
return packet.ErrFunnelNotFound
|
||||||
Dst: pk.Dst,
|
|
||||||
Responder: responder,
|
|
||||||
}
|
}
|
||||||
if replaced := ip.srcFlowTracker.Register(flowID, &flow, true); replaced {
|
icmpFlow, err := toICMPEchoFlow(funnel)
|
||||||
ip.logger.Info().Str("src", flow.Src.String()).Str("dst", flow.Dst.String()).Msg("Replaced flow")
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
return icmpFlow.returnToSrc(reply)
|
||||||
}
|
}
|
||||||
|
|
||||||
echo.ID = int(echoID)
|
// originSender wraps icmp.PacketConn to implement packet.FunnelUniPipe interface
|
||||||
var pseudoHeader []byte = nil
|
type originSender struct {
|
||||||
serializedMsg, err := pk.Marshal(pseudoHeader)
|
conn *icmp.PacketConn
|
||||||
if err != nil {
|
echoIDTracker *echoIDTracker
|
||||||
return errors.Wrap(err, "Failed to encode ICMP message")
|
srcIP netip.Addr
|
||||||
|
echoID uint16
|
||||||
}
|
}
|
||||||
// The address needs to be of type UDPAddr when conn is created without priviledge
|
|
||||||
_, err = ip.conn.WriteTo(serializedMsg, &net.UDPAddr{
|
func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
|
||||||
IP: pk.Dst.AsSlice(),
|
_, err := os.conn.WriteTo(pk.Data, &net.UDPAddr{
|
||||||
|
IP: dst.AsSlice(),
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) handleEchoResponse(encoder *packet.Encoder, msg *icmp.Message, echo *icmp.Echo) error {
|
func (os *originSender) Close() error {
|
||||||
flowID := echoFlowID(echo.ID)
|
os.echoIDTracker.release(os.srcIP, os.echoID)
|
||||||
flow, ok := ip.srcFlowTracker.Get(flowID)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("flow not found")
|
|
||||||
}
|
|
||||||
icmpPacket := packet.ICMP{
|
|
||||||
IP: &packet.IP{
|
|
||||||
Src: flow.Dst,
|
|
||||||
Dst: flow.Src,
|
|
||||||
Protocol: layers.IPProtocol(msg.Type.Protocol()),
|
|
||||||
},
|
|
||||||
Message: msg,
|
|
||||||
}
|
|
||||||
serializedPacket, err := encoder.Encode(&icmpPacket)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Failed to encode ICMP message")
|
|
||||||
}
|
|
||||||
if err := flow.Responder.SendPacket(serializedPacket); err != nil {
|
|
||||||
return errors.Wrap(err, "Failed to send packet to the edge")
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,6 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger) (ICMPProxy, error) {
|
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (ICMPProxy, error) {
|
||||||
return nil, fmt.Errorf("ICMP proxy is not implemented on %s", runtime.GOOS)
|
return nil, fmt.Errorf("ICMP proxy is not implemented on %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,16 +2,18 @@
|
||||||
|
|
||||||
package ingress
|
package ingress
|
||||||
|
|
||||||
|
// This file implements ICMPProxy for Linux. Each (source IP, destination IP, echo ID) opens a non-privileged ICMP socket.
|
||||||
|
// The source IP of the requests are rewritten to the bind IP of the socket and echo ID rewritten to the port number of
|
||||||
|
// the socket. The kernel ensures the socket only reads replies whose echo ID matches the port number.
|
||||||
|
// For more information about the socket, see https://man7.org/linux/man-pages/man7/icmp.7.html and https://lwn.net/Articles/422330/
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
|
@ -19,30 +21,27 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/packet"
|
"github.com/cloudflare/cloudflared/packet"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The request echo ID is rewritten to the port of the socket. The kernel uses the reply echo ID to demultiplex
|
|
||||||
// We can open a socket for each source so multiple sources requesting the same destination doesn't collide
|
|
||||||
type icmpProxy struct {
|
type icmpProxy struct {
|
||||||
srcToFlowTracker *srcToFlowTracker
|
srcFunnelTracker *packet.FunnelTracker
|
||||||
listenIP netip.Addr
|
listenIP netip.Addr
|
||||||
logger *zerolog.Logger
|
logger *zerolog.Logger
|
||||||
shutdownC chan struct{}
|
idleTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger) (ICMPProxy, error) {
|
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (ICMPProxy, error) {
|
||||||
if err := testPermission(listenIP); err != nil {
|
if err := testPermission(listenIP); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &icmpProxy{
|
return &icmpProxy{
|
||||||
srcToFlowTracker: newSrcToConnTracker(),
|
srcFunnelTracker: packet.NewFunnelTracker(),
|
||||||
listenIP: listenIP,
|
listenIP: listenIP,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
shutdownC: make(chan struct{}),
|
idleTimeout: idleTimeout,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func testPermission(listenIP netip.Addr) error {
|
func testPermission(listenIP netip.Addr) error {
|
||||||
// Opens a non-privileged ICMP socket. On Linux the group ID of the process needs to be in ping_group_range
|
// Opens a non-privileged ICMP socket. On Linux the group ID of the process needs to be in ping_group_range
|
||||||
// For more information, see https://man7.org/linux/man-pages/man7/icmp.7.html and https://lwn.net/Articles/422330/
|
|
||||||
conn, err := newICMPConn(listenIP)
|
conn, err := newICMPConn(listenIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO: TUN-6715 check if cloudflared is in ping_group_range if the check failed. If not log instruction to
|
// TODO: TUN-6715 check if cloudflared is in ping_group_range if the check failed. If not log instruction to
|
||||||
|
@ -54,213 +53,105 @@ func testPermission(listenIP netip.Addr) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) error {
|
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) error {
|
||||||
if pk == nil {
|
if pk == nil {
|
||||||
return errPacketNil
|
return errPacketNil
|
||||||
}
|
}
|
||||||
echo, err := getICMPEcho(pk)
|
funnelID := srcIPFunnelID(pk.Src)
|
||||||
|
funnel, exists := ip.srcFunnelTracker.Get(funnelID)
|
||||||
|
if !exists {
|
||||||
|
originalEcho, err := getICMPEcho(pk.Message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return ip.sendICMPEchoRequest(pk, echo, responder)
|
conn, err := newICMPConn(ip.listenIP)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to open ICMP socket")
|
||||||
|
}
|
||||||
|
localUDPAddr, ok := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("ICMP listener address %s is not net.UDPAddr", conn.LocalAddr())
|
||||||
|
}
|
||||||
|
originSender := originSender{conn: conn}
|
||||||
|
echoID := localUDPAddr.Port
|
||||||
|
icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, echoID, originalEcho.ID, packet.NewEncoder())
|
||||||
|
if replaced := ip.srcFunnelTracker.Register(funnelID, icmpFlow); replaced {
|
||||||
|
ip.logger.Info().Str("src", pk.Src.String()).Msg("Replaced funnel")
|
||||||
|
}
|
||||||
|
if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to send ICMP echo request")
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow)
|
||||||
|
if err := ip.listenResponse(icmpFlow, conn); err != nil {
|
||||||
|
ip.logger.Err(err).
|
||||||
|
Str("funnelID", funnelID.String()).
|
||||||
|
Int("echoID", echoID).
|
||||||
|
Msg("Failed to listen for ICMP echo response")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
icmpFlow, err := toICMPEchoFlow(funnel)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to send ICMP echo request")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) Serve(ctx context.Context) error {
|
func (ip *icmpProxy) Serve(ctx context.Context) error {
|
||||||
<-ctx.Done()
|
ip.srcFunnelTracker.ScheduleCleanup(ctx, ip.idleTimeout)
|
||||||
close(ip.shutdownC)
|
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) sendICMPEchoRequest(pk *packet.ICMP, echo *icmp.Echo, responder packet.FlowResponder) error {
|
func (ip *icmpProxy) listenResponse(flow *icmpEchoFlow, conn *icmp.PacketConn) error {
|
||||||
icmpFlow, ok := ip.srcToFlowTracker.get(pk.Src)
|
|
||||||
if ok {
|
|
||||||
return icmpFlow.send(pk)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := newICMPConn(ip.listenIP)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
flow := packet.Flow{
|
|
||||||
Src: pk.Src,
|
|
||||||
Dst: pk.Dst,
|
|
||||||
Responder: responder,
|
|
||||||
}
|
|
||||||
icmpFlow = newICMPFlow(conn, &flow, uint16(echo.ID), ip.logger)
|
|
||||||
go func() {
|
|
||||||
defer ip.srcToFlowTracker.delete(pk.Src)
|
|
||||||
|
|
||||||
if err := icmpFlow.serve(ip.shutdownC, defaultCloseAfterIdle); err != nil {
|
|
||||||
ip.logger.Debug().Err(err).Uint16("flowID", icmpFlow.echoID).Msg("flow terminated")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
ip.srcToFlowTracker.set(pk.Src, icmpFlow)
|
|
||||||
return icmpFlow.send(pk)
|
|
||||||
}
|
|
||||||
|
|
||||||
type srcIPFlowID netip.Addr
|
|
||||||
|
|
||||||
func (sifd srcIPFlowID) Type() string {
|
|
||||||
return "srcIP"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sifd srcIPFlowID) String() string {
|
|
||||||
return netip.Addr(sifd).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
type srcToFlowTracker struct {
|
|
||||||
lock sync.RWMutex
|
|
||||||
// srcIPToConn tracks source IP to ICMP connection
|
|
||||||
srcToFlow map[netip.Addr]*icmpFlow
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSrcToConnTracker() *srcToFlowTracker {
|
|
||||||
return &srcToFlowTracker{
|
|
||||||
srcToFlow: make(map[netip.Addr]*icmpFlow),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sft *srcToFlowTracker) get(srcIP netip.Addr) (*icmpFlow, bool) {
|
|
||||||
sft.lock.RLock()
|
|
||||||
defer sft.lock.RUnlock()
|
|
||||||
|
|
||||||
flow, ok := sft.srcToFlow[srcIP]
|
|
||||||
return flow, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sft *srcToFlowTracker) set(srcIP netip.Addr, flow *icmpFlow) {
|
|
||||||
sft.lock.Lock()
|
|
||||||
defer sft.lock.Unlock()
|
|
||||||
|
|
||||||
sft.srcToFlow[srcIP] = flow
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sft *srcToFlowTracker) delete(srcIP netip.Addr) {
|
|
||||||
sft.lock.Lock()
|
|
||||||
defer sft.lock.Unlock()
|
|
||||||
|
|
||||||
delete(sft.srcToFlow, srcIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
type icmpFlow struct {
|
|
||||||
conn *icmp.PacketConn
|
|
||||||
flow *packet.Flow
|
|
||||||
echoID uint16
|
|
||||||
// last active unix time. Unit is seconds
|
|
||||||
lastActive int64
|
|
||||||
logger *zerolog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func newICMPFlow(conn *icmp.PacketConn, flow *packet.Flow, echoID uint16, logger *zerolog.Logger) *icmpFlow {
|
|
||||||
return &icmpFlow{
|
|
||||||
conn: conn,
|
|
||||||
flow: flow,
|
|
||||||
echoID: echoID,
|
|
||||||
lastActive: time.Now().Unix(),
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *icmpFlow) serve(shutdownC chan struct{}, closeAfterIdle time.Duration) error {
|
|
||||||
errC := make(chan error)
|
|
||||||
go func() {
|
|
||||||
errC <- f.listenResponse()
|
|
||||||
}()
|
|
||||||
|
|
||||||
checkIdleTicker := time.NewTicker(closeAfterIdle)
|
|
||||||
defer f.conn.Close()
|
|
||||||
defer checkIdleTicker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case err := <-errC:
|
|
||||||
return err
|
|
||||||
case <-shutdownC:
|
|
||||||
return nil
|
|
||||||
case <-checkIdleTicker.C:
|
|
||||||
now := time.Now().Unix()
|
|
||||||
lastActive := atomic.LoadInt64(&f.lastActive)
|
|
||||||
if now > lastActive+int64(closeAfterIdle.Seconds()) {
|
|
||||||
return errFlowInactive
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *icmpFlow) send(pk *packet.ICMP) error {
|
|
||||||
f.updateLastActive()
|
|
||||||
|
|
||||||
// For IPv4, the pseudoHeader is not used because the checksum is always calculated
|
|
||||||
var pseudoHeader []byte = nil
|
|
||||||
serializedMsg, err := pk.Marshal(pseudoHeader)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Failed to encode ICMP message")
|
|
||||||
}
|
|
||||||
// The address needs to be of type UDPAddr when conn is created without priviledge
|
|
||||||
_, err = f.conn.WriteTo(serializedMsg, &net.UDPAddr{
|
|
||||||
IP: pk.Dst.AsSlice(),
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *icmpFlow) listenResponse() error {
|
|
||||||
buf := make([]byte, mtu)
|
buf := make([]byte, mtu)
|
||||||
encoder := packet.NewEncoder()
|
|
||||||
for {
|
for {
|
||||||
n, src, err := f.conn.ReadFrom(buf)
|
n, src, err := conn.ReadFrom(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
f.updateLastActive()
|
|
||||||
|
|
||||||
if err := f.handleResponse(encoder, src, buf[:n]); err != nil {
|
if err := ip.handleResponse(flow, src, buf[:n]); err != nil {
|
||||||
f.logger.Err(err).Str("dst", src.String()).Msg("Failed to handle ICMP response")
|
ip.logger.Err(err).Str("dst", src.String()).Msg("Failed to handle ICMP response")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *icmpFlow) handleResponse(encoder *packet.Encoder, from net.Addr, rawPacket []byte) error {
|
func (ip *icmpProxy) handleResponse(flow *icmpEchoFlow, from net.Addr, rawMsg []byte) error {
|
||||||
// TODO: TUN-6654 Check for IPv6
|
reply, err := parseReply(from, rawMsg)
|
||||||
msg, err := icmp.ParseMessage(int(layers.IPProtocolICMPv4), rawPacket)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return flow.returnToSrc(reply)
|
||||||
echo, ok := msg.Body.(*icmp.Echo)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("received unexpected icmp type %s from non-privileged ICMP socket", msg.Type)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
addrPort, err := netip.ParseAddrPort(from.String())
|
// originSender wraps icmp.PacketConn to implement packet.FunnelUniPipe interface
|
||||||
if err != nil {
|
type originSender struct {
|
||||||
|
conn *icmp.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
|
||||||
|
_, err := os.conn.WriteTo(pk.Data, &net.UDPAddr{
|
||||||
|
IP: dst.AsSlice(),
|
||||||
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
icmpPacket := packet.ICMP{
|
|
||||||
IP: &packet.IP{
|
func (os *originSender) Close() error {
|
||||||
Src: addrPort.Addr(),
|
return os.conn.Close()
|
||||||
Dst: f.flow.Src,
|
|
||||||
Protocol: layers.IPProtocol(msg.Type.Protocol()),
|
|
||||||
},
|
|
||||||
Message: &icmp.Message{
|
|
||||||
Type: msg.Type,
|
|
||||||
Code: msg.Code,
|
|
||||||
Body: &icmp.Echo{
|
|
||||||
ID: int(f.echoID),
|
|
||||||
Seq: echo.Seq,
|
|
||||||
Data: echo.Data,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
serializedPacket, err := encoder.Encode(&icmpPacket)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Failed to encode ICMP message")
|
|
||||||
}
|
|
||||||
if err := f.flow.Responder.SendPacket(serializedPacket); err != nil {
|
|
||||||
return errors.Wrap(err, "Failed to send packet to the edge")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *icmpFlow) updateLastActive() {
|
type srcIPFunnelID netip.Addr
|
||||||
atomic.StoreInt64(&f.lastActive, time.Now().Unix())
|
|
||||||
|
func (sifd srcIPFunnelID) Type() string {
|
||||||
|
return "srcIP"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sifd srcIPFunnelID) String() string {
|
||||||
|
return netip.Addr(sifd).String()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
//go:build linux
|
|
||||||
|
|
||||||
package ingress
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/packet"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCloseIdleFlow(t *testing.T) {
|
|
||||||
const (
|
|
||||||
echoID = 19234
|
|
||||||
idleTimeout = time.Millisecond * 100
|
|
||||||
)
|
|
||||||
conn, err := newICMPConn(localhostIP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
flow := packet.Flow{
|
|
||||||
Src: netip.MustParseAddr("172.16.0.1"),
|
|
||||||
}
|
|
||||||
icmpFlow := newICMPFlow(conn, &flow, echoID, &noopLogger)
|
|
||||||
shutdownC := make(chan struct{})
|
|
||||||
flowErr := make(chan error)
|
|
||||||
go func() {
|
|
||||||
flowErr <- icmpFlow.serve(shutdownC, idleTimeout)
|
|
||||||
}()
|
|
||||||
|
|
||||||
require.Equal(t, errFlowInactive, <-flowErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseConnStopFlow(t *testing.T) {
|
|
||||||
const (
|
|
||||||
echoID = 19234
|
|
||||||
)
|
|
||||||
conn, err := newICMPConn(localhostIP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
flow := packet.Flow{
|
|
||||||
Src: netip.MustParseAddr("172.16.0.1"),
|
|
||||||
}
|
|
||||||
icmpFlow := newICMPFlow(conn, &flow, echoID, &noopLogger)
|
|
||||||
shutdownC := make(chan struct{})
|
|
||||||
conn.Close()
|
|
||||||
|
|
||||||
err = icmpFlow.serve(shutdownC, defaultCloseAfterIdle)
|
|
||||||
require.True(t, errors.Is(err, net.ErrClosed))
|
|
||||||
}
|
|
|
@ -0,0 +1,142 @@
|
||||||
|
//go:build darwin || linux
|
||||||
|
|
||||||
|
package ingress
|
||||||
|
|
||||||
|
// This file extracts logic shared by Linux and Darwin implementation if ICMPProxy.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/packet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Opens a non-privileged ICMP socket on Linux and Darwin
|
||||||
|
func newICMPConn(listenIP netip.Addr) (*icmp.PacketConn, error) {
|
||||||
|
network := "udp6"
|
||||||
|
if listenIP.Is4() {
|
||||||
|
network = "udp4"
|
||||||
|
}
|
||||||
|
return icmp.ListenPacket(network, listenIP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func netipAddr(addr net.Addr) (netip.Addr, bool) {
|
||||||
|
udpAddr, ok := addr.(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
return netip.AddrFromSlice(udpAddr.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
type flowID struct {
|
||||||
|
srcIP netip.Addr
|
||||||
|
echoID int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fi *flowID) Type() string {
|
||||||
|
return "srcIP_echoID"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fi *flowID) String() string {
|
||||||
|
return fmt.Sprintf("%s:%d", fi.srcIP, fi.echoID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// icmpEchoFlow implements the packet.Funnel interface.
|
||||||
|
type icmpEchoFlow struct {
|
||||||
|
*packet.RawPacketFunnel
|
||||||
|
assignedEchoID int
|
||||||
|
originalEchoID int
|
||||||
|
// it's up to the user to ensure respEncoder is not used concurrently
|
||||||
|
respEncoder *packet.Encoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func newICMPEchoFlow(src netip.Addr, sendPipe, returnPipe packet.FunnelUniPipe, assignedEchoID, originalEchoID int, respEncoder *packet.Encoder) *icmpEchoFlow {
|
||||||
|
return &icmpEchoFlow{
|
||||||
|
RawPacketFunnel: packet.NewRawPacketFunnel(src, sendPipe, returnPipe),
|
||||||
|
assignedEchoID: assignedEchoID,
|
||||||
|
originalEchoID: originalEchoID,
|
||||||
|
respEncoder: respEncoder,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendToDst rewrites the echo ID to the one assigned to this flow
|
||||||
|
func (ief *icmpEchoFlow) sendToDst(dst netip.Addr, msg *icmp.Message) error {
|
||||||
|
originalEcho, err := getICMPEcho(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sendMsg := icmp.Message{
|
||||||
|
Type: msg.Type,
|
||||||
|
Code: msg.Code,
|
||||||
|
Body: &icmp.Echo{
|
||||||
|
ID: ief.assignedEchoID,
|
||||||
|
Seq: originalEcho.Seq,
|
||||||
|
Data: originalEcho.Data,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// For IPv4, the pseudoHeader is not used because the checksum is always calculated
|
||||||
|
var pseudoHeader []byte = nil
|
||||||
|
serializedPacket, err := sendMsg.Marshal(pseudoHeader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return ief.SendToDst(dst, packet.RawPacket{Data: serializedPacket})
|
||||||
|
}
|
||||||
|
|
||||||
|
// returnToSrc rewrites the echo ID to the original echo ID from the eyeball
|
||||||
|
func (ief *icmpEchoFlow) returnToSrc(reply *echoReply) error {
|
||||||
|
reply.echo.ID = ief.originalEchoID
|
||||||
|
reply.msg.Body = reply.echo
|
||||||
|
pk := packet.ICMP{
|
||||||
|
IP: &packet.IP{
|
||||||
|
Src: reply.from,
|
||||||
|
Dst: ief.Src,
|
||||||
|
Protocol: layers.IPProtocol(reply.msg.Type.Protocol()),
|
||||||
|
},
|
||||||
|
Message: reply.msg,
|
||||||
|
}
|
||||||
|
serializedPacket, err := ief.respEncoder.Encode(&pk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return ief.ReturnToSrc(serializedPacket)
|
||||||
|
}
|
||||||
|
|
||||||
|
type echoReply struct {
|
||||||
|
from netip.Addr
|
||||||
|
msg *icmp.Message
|
||||||
|
echo *icmp.Echo
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseReply(from net.Addr, rawMsg []byte) (*echoReply, error) {
|
||||||
|
// TODO: TUN-6654 Check for IPv6
|
||||||
|
msg, err := icmp.ParseMessage(int(layers.IPProtocolICMPv4), rawMsg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
echo, err := getICMPEcho(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fromAddr, ok := netipAddr(from)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("cannot convert %s to netip.Addr", from)
|
||||||
|
}
|
||||||
|
return &echoReply{
|
||||||
|
from: fromAddr,
|
||||||
|
msg: msg,
|
||||||
|
echo: echo,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toICMPEchoFlow(funnel packet.Funnel) (*icmpEchoFlow, error) {
|
||||||
|
icmpFlow, ok := funnel.(*icmpEchoFlow)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("%v is not *ICMPEchoFunnel", funnel)
|
||||||
|
}
|
||||||
|
return icmpFlow, nil
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
//go:build darwin || linux
|
||||||
|
|
||||||
|
package ingress
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/packet"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFunnelIdleTimeout(t *testing.T) {
|
||||||
|
const (
|
||||||
|
idleTimeout = time.Second
|
||||||
|
echoID = 42573
|
||||||
|
startSeq = 8129
|
||||||
|
)
|
||||||
|
logger := zerolog.New(os.Stderr)
|
||||||
|
proxy, err := newICMPProxy(localhostIP, &logger, idleTimeout)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
proxyDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
proxy.Serve(ctx)
|
||||||
|
close(proxyDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Send a packet to register the flow
|
||||||
|
pk := packet.ICMP{
|
||||||
|
IP: &packet.IP{
|
||||||
|
Src: localhostIP,
|
||||||
|
Dst: localhostIP,
|
||||||
|
Protocol: layers.IPProtocolICMPv4,
|
||||||
|
},
|
||||||
|
Message: &icmp.Message{
|
||||||
|
Type: ipv4.ICMPTypeEcho,
|
||||||
|
Code: 0,
|
||||||
|
Body: &icmp.Echo{
|
||||||
|
ID: echoID,
|
||||||
|
Seq: startSeq,
|
||||||
|
Data: []byte(t.Name()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
responder := echoFlowResponder{
|
||||||
|
decoder: packet.NewICMPDecoder(),
|
||||||
|
respChan: make(chan []byte),
|
||||||
|
}
|
||||||
|
require.NoError(t, proxy.Request(&pk, &responder))
|
||||||
|
responder.validate(t, &pk)
|
||||||
|
|
||||||
|
// Send second request, should reuse the funnel
|
||||||
|
require.NoError(t, proxy.Request(&pk, nil))
|
||||||
|
responder.validate(t, &pk)
|
||||||
|
|
||||||
|
time.Sleep(idleTimeout * 2)
|
||||||
|
newResponder := echoFlowResponder{
|
||||||
|
decoder: packet.NewICMPDecoder(),
|
||||||
|
respChan: make(chan []byte),
|
||||||
|
}
|
||||||
|
require.NoError(t, proxy.Request(&pk, &newResponder))
|
||||||
|
newResponder.validate(t, &pk)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
<-proxyDone
|
||||||
|
}
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
@ -143,7 +144,7 @@ type icmpProxy struct {
|
||||||
encoderPool sync.Pool
|
encoderPool sync.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger) (ICMPProxy, error) {
|
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (ICMPProxy, error) {
|
||||||
handle, _, err := IcmpCreateFile_proc.Call()
|
handle, _, err := IcmpCreateFile_proc.Call()
|
||||||
// Windows procedure calls always return non-nil error constructed from the result of GetLastError.
|
// Windows procedure calls always return non-nil error constructed from the result of GetLastError.
|
||||||
// Caller need to inspect the primary returned value
|
// Caller need to inspect the primary returned value
|
||||||
|
@ -167,7 +168,7 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) error {
|
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) error {
|
||||||
if pk == nil {
|
if pk == nil {
|
||||||
return errPacketNil
|
return errPacketNil
|
||||||
}
|
}
|
||||||
|
@ -176,7 +177,7 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) er
|
||||||
ip.logger.Error().Interface("error", r).Msgf("Recover panic from sending icmp request/response, error %s", debug.Stack())
|
ip.logger.Error().Interface("error", r).Msgf("Recover panic from sending icmp request/response, error %s", debug.Stack())
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
echo, err := getICMPEcho(pk)
|
echo, err := getICMPEcho(pk.Message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -193,7 +194,7 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) er
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) handleEchoResponse(request *packet.ICMP, echoReq *icmp.Echo, resp *echoResp, responder packet.FlowResponder) error {
|
func (ip *icmpProxy) handleEchoResponse(request *packet.ICMP, echoReq *icmp.Echo, resp *echoResp, responder packet.FunnelUniPipe) error {
|
||||||
var replyType icmp.Type
|
var replyType icmp.Type
|
||||||
if request.Dst.Is4() {
|
if request.Dst.Is4() {
|
||||||
replyType = ipv4.ICMPTypeEchoReply
|
replyType = ipv4.ICMPTypeEchoReply
|
||||||
|
@ -222,7 +223,7 @@ func (ip *icmpProxy) handleEchoResponse(request *packet.ICMP, echoReq *icmp.Echo
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return responder.SendPacket(serializedPacket)
|
return responder.SendPacket(request.Src, serializedPacket)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ip *icmpProxy) encodeICMPReply(pk *packet.ICMP) (packet.RawPacket, error) {
|
func (ip *icmpProxy) encodeICMPReply(pk *packet.ICMP) (packet.RawPacket, error) {
|
||||||
|
@ -262,7 +263,7 @@ func (ip *icmpProxy) icmpSendEcho(dst netip.Addr, echo *icmp.Echo) (*echoResp, e
|
||||||
}
|
}
|
||||||
replyCount, _, err := IcmpSendEcho_proc.Call(ip.handle, uintptr(inAddr), uintptr(unsafe.Pointer(&echo.Data[0])),
|
replyCount, _, err := IcmpSendEcho_proc.Call(ip.handle, uintptr(inAddr), uintptr(unsafe.Pointer(&echo.Data[0])),
|
||||||
uintptr(dataSize), noIPHeaderOption, uintptr(unsafe.Pointer(&replyBuf[0])),
|
uintptr(dataSize), noIPHeaderOption, uintptr(unsafe.Pointer(&replyBuf[0])),
|
||||||
replySize, icmpTimeoutMs)
|
replySize, icmpRequestTimeoutMs)
|
||||||
if replyCount == 0 {
|
if replyCount == 0 {
|
||||||
// status is returned in 5th to 8th byte of reply buffer
|
// status is returned in 5th to 8th byte of reply buffer
|
||||||
if status, err := unmarshalIPStatus(replyBuf[4:8]); err == nil {
|
if status, err := unmarshalIPStatus(replyBuf[4:8]); err == nil {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
|
@ -72,7 +73,7 @@ func TestParseEchoReply(t *testing.T) {
|
||||||
|
|
||||||
// TestSendEchoErrors makes sure icmpSendEcho handles error cases
|
// TestSendEchoErrors makes sure icmpSendEcho handles error cases
|
||||||
func TestSendEchoErrors(t *testing.T) {
|
func TestSendEchoErrors(t *testing.T) {
|
||||||
proxy, err := newICMPProxy(localhostIP, &noopLogger)
|
proxy, err := newICMPProxy(localhostIP, &noopLogger, time.Second)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
winProxy := proxy.(*icmpProxy)
|
winProxy := proxy.(*icmpProxy)
|
||||||
|
|
||||||
|
|
|
@ -13,13 +13,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultCloseAfterIdle = time.Second * 15
|
// funnelIdleTimeout controls how long to wait to close a funnel without send/return
|
||||||
|
funnelIdleTimeout = time.Second * 10
|
||||||
mtu = 1500
|
mtu = 1500
|
||||||
icmpTimeoutMs = 1000
|
// icmpRequestTimeoutMs controls how long to wait for a reply
|
||||||
|
icmpRequestTimeoutMs = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errFlowInactive = fmt.Errorf("flow is inactive")
|
|
||||||
errPacketNil = fmt.Errorf("packet is nil")
|
errPacketNil = fmt.Errorf("packet is nil")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,26 +29,17 @@ type ICMPProxy interface {
|
||||||
// Serve starts listening for responses to the requests until context is done
|
// Serve starts listening for responses to the requests until context is done
|
||||||
Serve(ctx context.Context) error
|
Serve(ctx context.Context) error
|
||||||
// Request sends an ICMP message
|
// Request sends an ICMP message
|
||||||
Request(pk *packet.ICMP, responder packet.FlowResponder) error
|
Request(pk *packet.ICMP, responder packet.FunnelUniPipe) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICMPProxy(listenIP netip.Addr, logger *zerolog.Logger) (ICMPProxy, error) {
|
func NewICMPProxy(listenIP netip.Addr, logger *zerolog.Logger) (ICMPProxy, error) {
|
||||||
return newICMPProxy(listenIP, logger)
|
return newICMPProxy(listenIP, logger, funnelIdleTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Opens a non-privileged ICMP socket on Linux and Darwin
|
func getICMPEcho(msg *icmp.Message) (*icmp.Echo, error) {
|
||||||
func newICMPConn(listenIP netip.Addr) (*icmp.PacketConn, error) {
|
echo, ok := msg.Body.(*icmp.Echo)
|
||||||
network := "udp6"
|
|
||||||
if listenIP.Is4() {
|
|
||||||
network = "udp4"
|
|
||||||
}
|
|
||||||
return icmp.ListenPacket(network, listenIP.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func getICMPEcho(pk *packet.ICMP) (*icmp.Echo, error) {
|
|
||||||
echo, ok := pk.Message.Body.(*icmp.Echo)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("expect ICMP echo, got %s", pk.Type)
|
return nil, fmt.Errorf("expect ICMP echo, got %s", msg.Type)
|
||||||
}
|
}
|
||||||
return echo, nil
|
return echo, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package ingress
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -47,27 +48,49 @@ func TestICMPProxyEcho(t *testing.T) {
|
||||||
respChan: make(chan []byte, 1),
|
respChan: make(chan []byte, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := packet.IP{
|
ips := []packet.IP{
|
||||||
|
{
|
||||||
Src: localhostIP,
|
Src: localhostIP,
|
||||||
Dst: localhostIP,
|
Dst: localhostIP,
|
||||||
Protocol: layers.IPProtocolICMPv4,
|
Protocol: layers.IPProtocolICMPv4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for i := 0; i < endSeq; i++ {
|
|
||||||
|
addrs, err := net.InterfaceAddrs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, addr := range addrs {
|
||||||
|
if ipnet, ok := addr.(*net.IPNet); ok {
|
||||||
|
ip := ipnet.IP
|
||||||
|
if !ipnet.IP.IsLoopback() && ip.IsPrivate() && ip.To4() != nil {
|
||||||
|
localIP := netip.MustParseAddr(ipnet.IP.String())
|
||||||
|
ips = append(ips, packet.IP{
|
||||||
|
Src: localIP,
|
||||||
|
Dst: localIP,
|
||||||
|
Protocol: layers.IPProtocolICMPv4,
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for seq := 0; seq < endSeq; seq++ {
|
||||||
|
for i, ip := range ips {
|
||||||
pk := packet.ICMP{
|
pk := packet.ICMP{
|
||||||
IP: &ip,
|
IP: &ip,
|
||||||
Message: &icmp.Message{
|
Message: &icmp.Message{
|
||||||
Type: ipv4.ICMPTypeEcho,
|
Type: ipv4.ICMPTypeEcho,
|
||||||
Code: 0,
|
Code: 0,
|
||||||
Body: &icmp.Echo{
|
Body: &icmp.Echo{
|
||||||
ID: echoID,
|
ID: echoID + i,
|
||||||
Seq: i,
|
Seq: seq,
|
||||||
Data: []byte(fmt.Sprintf("icmp echo seq %d", i)),
|
Data: []byte(fmt.Sprintf("icmp echo seq %d", seq)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.NoError(t, proxy.Request(&pk, &responder))
|
require.NoError(t, proxy.Request(&pk, &responder))
|
||||||
responder.validate(t, &pk)
|
responder.validate(t, &pk)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
cancel()
|
cancel()
|
||||||
<-proxyDone
|
<-proxyDone
|
||||||
}
|
}
|
||||||
|
@ -123,13 +146,18 @@ type echoFlowResponder struct {
|
||||||
respChan chan []byte
|
respChan chan []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (efr *echoFlowResponder) SendPacket(pk packet.RawPacket) error {
|
func (efr *echoFlowResponder) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
|
||||||
copiedPacket := make([]byte, len(pk.Data))
|
copiedPacket := make([]byte, len(pk.Data))
|
||||||
copy(copiedPacket, pk.Data)
|
copy(copiedPacket, pk.Data)
|
||||||
efr.respChan <- copiedPacket
|
efr.respChan <- copiedPacket
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (efr *echoFlowResponder) Close() error {
|
||||||
|
close(efr.respChan)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) {
|
func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) {
|
||||||
pk := <-efr.respChan
|
pk := <-efr.respChan
|
||||||
decoded, err := efr.decoder.Decode(packet.RawPacket{Data: pk})
|
decoded, err := efr.decoder.Decode(packet.RawPacket{Data: pk})
|
||||||
|
|
|
@ -1,94 +0,0 @@
|
||||||
package packet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrFlowNotFound = errors.New("flow not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
// FlowID represents a key type that can be used by FlowTracker
|
|
||||||
type FlowID interface {
|
|
||||||
// Type returns the name of the type that implements the FlowID
|
|
||||||
Type() string
|
|
||||||
fmt.Stringer
|
|
||||||
}
|
|
||||||
|
|
||||||
type Flow struct {
|
|
||||||
Src netip.Addr
|
|
||||||
Dst netip.Addr
|
|
||||||
Responder FlowResponder
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSameFlow(f1, f2 *Flow) bool {
|
|
||||||
if f1 == nil || f2 == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return *f1 == *f2
|
|
||||||
}
|
|
||||||
|
|
||||||
// FlowResponder sends response packets to the flow
|
|
||||||
type FlowResponder interface {
|
|
||||||
// SendPacket returns a packet to the flow. It must not modify the packet,
|
|
||||||
// and after return it must not read the packet
|
|
||||||
SendPacket(pk RawPacket) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// FlowTracker tracks flow from the perspective of eyeball to origin
|
|
||||||
type FlowTracker struct {
|
|
||||||
lock sync.RWMutex
|
|
||||||
flows map[FlowID]*Flow
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewFlowTracker() *FlowTracker {
|
|
||||||
return &FlowTracker{
|
|
||||||
flows: make(map[FlowID]*Flow),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sft *FlowTracker) Get(id FlowID) (*Flow, bool) {
|
|
||||||
sft.lock.RLock()
|
|
||||||
defer sft.lock.RUnlock()
|
|
||||||
flow, ok := sft.flows[id]
|
|
||||||
return flow, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// Registers a flow. If shouldReplace = true, replace the current flow
|
|
||||||
func (sft *FlowTracker) Register(id FlowID, flow *Flow, shouldReplace bool) (replaced bool) {
|
|
||||||
sft.lock.Lock()
|
|
||||||
defer sft.lock.Unlock()
|
|
||||||
currentFlow, ok := sft.flows[id]
|
|
||||||
if !ok {
|
|
||||||
sft.flows[id] = flow
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldReplace && !isSameFlow(currentFlow, flow) {
|
|
||||||
sft.flows[id] = flow
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unregisters a flow. If force = true, delete it even if it maps to a different flow
|
|
||||||
func (sft *FlowTracker) Unregister(id FlowID, flow *Flow, force bool) (forceDeleted bool) {
|
|
||||||
sft.lock.Lock()
|
|
||||||
defer sft.lock.Unlock()
|
|
||||||
currentFlow, ok := sft.flows[id]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if isSameFlow(currentFlow, flow) {
|
|
||||||
delete(sft.flows, id)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if force {
|
|
||||||
delete(sft.flows, id)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -0,0 +1,191 @@
|
||||||
|
package packet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrFunnelNotFound = errors.New("funnel not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Funnel is an abstraction to pipe from 1 src to 1 or more destinations
|
||||||
|
type Funnel interface {
|
||||||
|
// SendToDst sends a raw packet to a destination
|
||||||
|
SendToDst(dst netip.Addr, pk RawPacket) error
|
||||||
|
// ReturnToSrc returns a raw packet to the source
|
||||||
|
ReturnToSrc(pk RawPacket) error
|
||||||
|
// LastActive returns the last time SendToDst or ReturnToSrc is called
|
||||||
|
LastActive() time.Time
|
||||||
|
// Close closes the funnel. Further call to SendToDst or ReturnToSrc should return an error
|
||||||
|
Close() error
|
||||||
|
// Equal compares if 2 funnels are equivalent
|
||||||
|
Equal(other Funnel) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// FunnelUniPipe is a unidirectional pipe for sending raw packets
|
||||||
|
type FunnelUniPipe interface {
|
||||||
|
// SendPacket sends a packet to/from the funnel. It must not modify the packet,
|
||||||
|
// and after return it must not read the packet
|
||||||
|
SendPacket(dst netip.Addr, pk RawPacket) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RawPacketFunnel is an implementation of Funnel that sends raw packets. It can be embedded in other structs to
|
||||||
|
// satisfy the Funnel interface.
|
||||||
|
type RawPacketFunnel struct {
|
||||||
|
Src netip.Addr
|
||||||
|
// last active unix time. Unit is seconds
|
||||||
|
lastActive int64
|
||||||
|
sendPipe FunnelUniPipe
|
||||||
|
returnPipe FunnelUniPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRawPacketFunnel(src netip.Addr, sendPipe, returnPipe FunnelUniPipe) *RawPacketFunnel {
|
||||||
|
return &RawPacketFunnel{
|
||||||
|
Src: src,
|
||||||
|
lastActive: time.Now().Unix(),
|
||||||
|
sendPipe: sendPipe,
|
||||||
|
returnPipe: returnPipe,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rpf *RawPacketFunnel) SendToDst(dst netip.Addr, pk RawPacket) error {
|
||||||
|
rpf.updateLastActive()
|
||||||
|
return rpf.sendPipe.SendPacket(dst, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rpf *RawPacketFunnel) ReturnToSrc(pk RawPacket) error {
|
||||||
|
rpf.updateLastActive()
|
||||||
|
return rpf.returnPipe.SendPacket(rpf.Src, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rpf *RawPacketFunnel) updateLastActive() {
|
||||||
|
atomic.StoreInt64(&rpf.lastActive, time.Now().Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rpf *RawPacketFunnel) LastActive() time.Time {
|
||||||
|
lastActive := atomic.LoadInt64(&rpf.lastActive)
|
||||||
|
return time.Unix(lastActive, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rpf *RawPacketFunnel) Close() error {
|
||||||
|
sendPipeErr := rpf.sendPipe.Close()
|
||||||
|
returnPipeErr := rpf.returnPipe.Close()
|
||||||
|
if sendPipeErr != nil {
|
||||||
|
return sendPipeErr
|
||||||
|
}
|
||||||
|
if returnPipeErr != nil {
|
||||||
|
return returnPipeErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rpf *RawPacketFunnel) Equal(other Funnel) bool {
|
||||||
|
otherRawFunnel, ok := other.(*RawPacketFunnel)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if rpf.Src != otherRawFunnel.Src {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if rpf.sendPipe != otherRawFunnel.sendPipe {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if rpf.returnPipe != otherRawFunnel.returnPipe {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// FunnelID represents a key type that can be used by FunnelTracker
|
||||||
|
type FunnelID interface {
|
||||||
|
// Type returns the name of the type that implements the FunnelID
|
||||||
|
Type() string
|
||||||
|
fmt.Stringer
|
||||||
|
}
|
||||||
|
|
||||||
|
// FunnelTracker tracks funnel from the perspective of eyeball to origin
|
||||||
|
type FunnelTracker struct {
|
||||||
|
lock sync.RWMutex
|
||||||
|
funnels map[FunnelID]Funnel
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFunnelTracker() *FunnelTracker {
|
||||||
|
return &FunnelTracker{
|
||||||
|
funnels: make(map[FunnelID]Funnel),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ft *FunnelTracker) ScheduleCleanup(ctx context.Context, idleTimeout time.Duration) {
|
||||||
|
checkIdleTicker := time.NewTicker(idleTimeout)
|
||||||
|
defer checkIdleTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-checkIdleTicker.C:
|
||||||
|
ft.cleanup(idleTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ft *FunnelTracker) cleanup(idleTimeout time.Duration) {
|
||||||
|
ft.lock.Lock()
|
||||||
|
defer ft.lock.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for id, funnel := range ft.funnels {
|
||||||
|
lastActive := funnel.LastActive()
|
||||||
|
if now.After(lastActive.Add(idleTimeout)) {
|
||||||
|
funnel.Close()
|
||||||
|
delete(ft.funnels, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ft *FunnelTracker) Get(id FunnelID) (Funnel, bool) {
|
||||||
|
ft.lock.RLock()
|
||||||
|
defer ft.lock.RUnlock()
|
||||||
|
funnel, ok := ft.funnels[id]
|
||||||
|
return funnel, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registers a funnel. It replaces the current funnel.
|
||||||
|
func (ft *FunnelTracker) Register(id FunnelID, funnel Funnel) (replaced bool) {
|
||||||
|
ft.lock.Lock()
|
||||||
|
defer ft.lock.Unlock()
|
||||||
|
currentFunnel, exists := ft.funnels[id]
|
||||||
|
if !exists {
|
||||||
|
ft.funnels[id] = funnel
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
replaced = !currentFunnel.Equal(funnel)
|
||||||
|
if replaced {
|
||||||
|
currentFunnel.Close()
|
||||||
|
}
|
||||||
|
ft.funnels[id] = funnel
|
||||||
|
return replaced
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregisters a funnel if the funnel equals to the current funnel
|
||||||
|
func (ft *FunnelTracker) Unregister(id FunnelID, funnel Funnel) (deleted bool) {
|
||||||
|
ft.lock.Lock()
|
||||||
|
defer ft.lock.Unlock()
|
||||||
|
currentFunnel, exists := ft.funnels[id]
|
||||||
|
if !exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if currentFunnel.Equal(funnel) {
|
||||||
|
currentFunnel.Close()
|
||||||
|
delete(ft.funnels, id)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
Loading…
Reference in New Issue