TUN-6744: On posix platforms, assign unique echo ID per (src, dst, echo ID)

This also refactor FunnelTracker to provide a GetOrRegister method to prevent race condition
This commit is contained in:
cthuang 2022-09-09 16:48:42 +01:00
parent e454994e3e
commit b639b6627a
7 changed files with 268 additions and 138 deletions

View File

@ -37,9 +37,9 @@ type icmpProxy struct {
// then from the beginning to lastAssignment. // then from the beginning to lastAssignment.
// ICMP echo are short lived. By the time an ID is revisited, it should have been released. // ICMP echo are short lived. By the time an ID is revisited, it should have been released.
type echoIDTracker struct { type echoIDTracker struct {
lock sync.RWMutex lock sync.Mutex
// maps the source IP to an echo ID obtained from assignment // maps the source IP, destination IP and original echo ID to a unique echo ID obtained from assignment
srcIPMapping map[netip.Addr]uint16 mapping map[flow3Tuple]uint16
// assignment tracks if an ID is assigned using index as the ID // assignment tracks if an ID is assigned using index as the ID
// The size of the array is math.MaxUint16 because echo ID is 2 bytes // The size of the array is math.MaxUint16 because echo ID is 2 bytes
assignment [math.MaxUint16]bool assignment [math.MaxUint16]bool
@ -49,20 +49,18 @@ type echoIDTracker struct {
func newEchoIDTracker() *echoIDTracker { func newEchoIDTracker() *echoIDTracker {
return &echoIDTracker{ return &echoIDTracker{
srcIPMapping: make(map[netip.Addr]uint16), mapping: make(map[flow3Tuple]uint16),
} }
} }
func (eit *echoIDTracker) get(srcIP netip.Addr) (uint16, bool) { // Get assignment or assign a new ID.
eit.lock.RLock() func (eit *echoIDTracker) getOrAssign(key flow3Tuple) (id uint16, success bool) {
defer eit.lock.RUnlock()
id, ok := eit.srcIPMapping[srcIP]
return id, ok
}
func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) {
eit.lock.Lock() eit.lock.Lock()
defer eit.lock.Unlock() defer eit.lock.Unlock()
id, exists := eit.mapping[key]
if exists {
return id, true
}
if eit.nextAssignment == math.MaxUint16 { if eit.nextAssignment == math.MaxUint16 {
eit.nextAssignment = 0 eit.nextAssignment = 0
@ -71,14 +69,14 @@ func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) {
for i, assigned := range eit.assignment[eit.nextAssignment:] { for i, assigned := range eit.assignment[eit.nextAssignment:] {
if !assigned { if !assigned {
echoID := uint16(i) + eit.nextAssignment echoID := uint16(i) + eit.nextAssignment
eit.set(srcIP, echoID) eit.set(key, echoID)
return echoID, true return echoID, true
} }
} }
for i, assigned := range eit.assignment[0:eit.nextAssignment] { for i, assigned := range eit.assignment[0:eit.nextAssignment] {
if !assigned { if !assigned {
echoID := uint16(i) echoID := uint16(i)
eit.set(srcIP, echoID) eit.set(key, echoID)
return echoID, true return echoID, true
} }
} }
@ -86,20 +84,20 @@ func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) {
} }
// Caller should hold the lock // Caller should hold the lock
func (eit *echoIDTracker) set(srcIP netip.Addr, echoID uint16) { func (eit *echoIDTracker) set(key flow3Tuple, assignedEchoID uint16) {
eit.assignment[echoID] = true eit.assignment[assignedEchoID] = true
eit.srcIPMapping[srcIP] = echoID eit.mapping[key] = assignedEchoID
eit.nextAssignment = echoID + 1 eit.nextAssignment = assignedEchoID + 1
} }
func (eit *echoIDTracker) release(srcIP netip.Addr, id uint16) bool { func (eit *echoIDTracker) release(key flow3Tuple, assigned uint16) bool {
eit.lock.Lock() eit.lock.Lock()
defer eit.lock.Unlock() defer eit.lock.Unlock()
currentID, exists := eit.srcIPMapping[srcIP] currentEchoID, exists := eit.mapping[key]
if exists && id == currentID { if exists && assigned == currentEchoID {
delete(eit.srcIPMapping, srcIP) delete(eit.mapping, key)
eit.assignment[id] = false eit.assignment[assigned] = false
return true return true
} }
return false return false
@ -134,33 +132,46 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) er
if pk == nil { if pk == nil {
return errPacketNil return errPacketNil
} }
originalEcho, err := getICMPEcho(pk.Message)
if err != nil {
return err
}
echoIDTrackerKey := flow3Tuple{
srcIP: pk.Src,
dstIP: pk.Dst,
originalEchoID: originalEcho.ID,
}
// TODO: TUN-6744 assign unique flow per (src, echo ID) // TODO: TUN-6744 assign unique flow per (src, echo ID)
echoID, exists := ip.echoIDTracker.get(pk.Src) assignedEchoID, success := ip.echoIDTracker.getOrAssign(echoIDTrackerKey)
if !exists { if !success {
return fmt.Errorf("failed to assign unique echo ID")
}
newFunnelFunc := func() (packet.Funnel, error) {
originalEcho, err := getICMPEcho(pk.Message) originalEcho, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
return err return nil, err
} }
echoID, exists = ip.echoIDTracker.assign(pk.Src)
if !exists {
return fmt.Errorf("failed to assign unique echo ID")
}
funnelID := echoFunnelID(echoID)
originSender := originSender{ originSender := originSender{
conn: ip.conn, conn: ip.conn,
echoIDTracker: ip.echoIDTracker, echoIDTracker: ip.echoIDTracker,
srcIP: pk.Src, echoIDTrackerKey: echoIDTrackerKey,
echoID: echoID, assignedEchoID: assignedEchoID,
} }
icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, int(echoID), originalEcho.ID, ip.encoder) icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, int(assignedEchoID), originalEcho.ID, ip.encoder)
if replaced := ip.srcFunnelTracker.Register(funnelID, icmpFlow); replaced { return icmpFlow, nil
ip.logger.Info().Str("src", pk.Src.String()).Msg("Replaced funnel")
}
return icmpFlow.sendToDst(pk.Dst, pk.Message)
} }
funnel, exists := ip.srcFunnelTracker.Get(echoFunnelID(echoID)) funnelID := echoFunnelID(assignedEchoID)
if !exists { funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, newFunnelFunc)
return packet.ErrFunnelNotFound if err != nil {
return err
}
if isNew {
ip.logger.Debug().
Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()).
Int("originalEchoID", originalEcho.ID).
Int("assignedEchoID", int(assignedEchoID)).
Msg("New flow")
} }
icmpFlow, err := toICMPEchoFlow(funnel) icmpFlow, err := toICMPEchoFlow(funnel)
if err != nil { if err != nil {
@ -199,7 +210,7 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
ip.logger.Debug().Str("dst", from.String()).Msgf("Drop ICMP %s from reply", reply.msg.Type) ip.logger.Debug().Str("dst", from.String()).Msgf("Drop ICMP %s from reply", reply.msg.Type)
continue continue
} }
if ip.sendReply(reply); err != nil { if err := ip.sendReply(reply); err != nil {
ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply") ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply")
continue continue
} }
@ -227,7 +238,8 @@ func (ip *icmpProxy) handleFullPacket(decoder *packet.ICMPDecoder, rawPacket []b
} }
func (ip *icmpProxy) sendReply(reply *echoReply) error { func (ip *icmpProxy) sendReply(reply *echoReply) error {
funnel, ok := ip.srcFunnelTracker.Get(echoFunnelID(reply.echo.ID)) funnelID := echoFunnelID(reply.echo.ID)
funnel, ok := ip.srcFunnelTracker.Get(funnelID)
if !ok { if !ok {
return packet.ErrFunnelNotFound return packet.ErrFunnelNotFound
} }
@ -240,10 +252,10 @@ func (ip *icmpProxy) sendReply(reply *echoReply) error {
// originSender wraps icmp.PacketConn to implement packet.FunnelUniPipe interface // originSender wraps icmp.PacketConn to implement packet.FunnelUniPipe interface
type originSender struct { type originSender struct {
conn *icmp.PacketConn conn *icmp.PacketConn
echoIDTracker *echoIDTracker echoIDTracker *echoIDTracker
srcIP netip.Addr echoIDTrackerKey flow3Tuple
echoID uint16 assignedEchoID uint16
} }
func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error { func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
@ -254,6 +266,6 @@ func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
} }
func (os *originSender) Close() error { func (os *originSender) Close() error {
os.echoIDTracker.release(os.srcIP, os.echoID) os.echoIDTracker.release(os.echoIDTrackerKey, os.assignedEchoID)
return nil return nil
} }

View File

@ -12,80 +12,110 @@ import (
func TestSingleEchoIDTracker(t *testing.T) { func TestSingleEchoIDTracker(t *testing.T) {
tracker := newEchoIDTracker() tracker := newEchoIDTracker()
srcIP := netip.MustParseAddr("127.0.0.1") key := flow3Tuple{
echoID, ok := tracker.get(srcIP) srcIP: netip.MustParseAddr("172.16.0.1"),
require.False(t, ok) dstIP: netip.MustParseAddr("172.16.0.2"),
require.Equal(t, uint16(0), echoID) originalEchoID: 5182,
}
// not assigned yet, so nothing to release // not assigned yet, so nothing to release
require.False(t, tracker.release(srcIP, echoID)) require.False(t, tracker.release(key, 0))
echoID, ok = tracker.assign(srcIP) echoID, ok := tracker.getOrAssign(key)
require.True(t, ok) require.True(t, ok)
require.Equal(t, uint16(0), echoID) require.Equal(t, uint16(0), echoID)
echoID, ok = tracker.get(srcIP) // Second time should return the same echo ID
echoID, ok = tracker.getOrAssign(key)
require.True(t, ok) require.True(t, ok)
require.Equal(t, uint16(0), echoID) require.Equal(t, uint16(0), echoID)
// releasing a different ID returns false // releasing a different ID returns false
require.False(t, tracker.release(srcIP, 1999)) require.False(t, tracker.release(key, 1999))
require.True(t, tracker.release(srcIP, echoID)) require.True(t, tracker.release(key, echoID))
// releasing the second time returns false // releasing the second time returns false
require.False(t, tracker.release(srcIP, echoID)) require.False(t, tracker.release(key, echoID))
echoID, ok = tracker.get(srcIP)
require.False(t, ok)
require.Equal(t, uint16(0), echoID)
// Move to the next IP // Move to the next IP
echoID, ok = tracker.assign(srcIP) echoID, ok = tracker.getOrAssign(key)
require.True(t, ok) require.True(t, ok)
require.Equal(t, uint16(1), echoID) require.Equal(t, uint16(1), echoID)
} }
func TestFullEchoIDTracker(t *testing.T) { func TestFullEchoIDTracker(t *testing.T) {
var (
dstIP = netip.MustParseAddr("192.168.0.1")
originalEchoID = 41820
)
tracker := newEchoIDTracker() tracker := newEchoIDTracker()
firstIP := netip.MustParseAddr("172.16.0.1") firstSrcIP := netip.MustParseAddr("172.16.0.1")
srcIP := firstIP srcIP := firstSrcIP
for i := uint16(0); i < math.MaxUint16; i++ { for i := uint16(0); i < math.MaxUint16; i++ {
echoID, ok := tracker.assign(srcIP) key := flow3Tuple{
srcIP: srcIP,
dstIP: dstIP,
originalEchoID: originalEchoID,
}
echoID, ok := tracker.getOrAssign(key)
require.True(t, ok) require.True(t, ok)
require.Equal(t, i, echoID) require.Equal(t, i, echoID)
echoID, ok = tracker.get(srcIP) echoID, ok = tracker.get(key)
require.True(t, ok) require.True(t, ok)
require.Equal(t, i, echoID) require.Equal(t, i, echoID)
srcIP = srcIP.Next() srcIP = srcIP.Next()
} }
key := flow3Tuple{
srcIP: srcIP.Next(),
dstIP: dstIP,
originalEchoID: originalEchoID,
}
// All echo IDs are assigned // All echo IDs are assigned
echoID, ok := tracker.assign(srcIP.Next()) echoID, ok := tracker.getOrAssign(key)
require.False(t, ok) require.False(t, ok)
require.Equal(t, uint16(0), echoID) require.Equal(t, uint16(0), echoID)
srcIP = firstIP srcIP = firstSrcIP
for i := uint16(0); i < math.MaxUint16; i++ { for i := uint16(0); i < math.MaxUint16; i++ {
ok := tracker.release(srcIP, i) key := flow3Tuple{
srcIP: srcIP,
dstIP: dstIP,
originalEchoID: originalEchoID,
}
ok := tracker.release(key, i)
require.True(t, ok) require.True(t, ok)
echoID, ok = tracker.get(srcIP) echoID, ok = tracker.get(key)
require.False(t, ok) require.False(t, ok)
require.Equal(t, uint16(0), echoID) require.Equal(t, uint16(0), echoID)
srcIP = srcIP.Next() srcIP = srcIP.Next()
} }
// The IDs are assignable again // The IDs are assignable again
srcIP = firstIP srcIP = firstSrcIP
for i := uint16(0); i < math.MaxUint16; i++ { for i := uint16(0); i < math.MaxUint16; i++ {
echoID, ok := tracker.assign(srcIP) key := flow3Tuple{
srcIP: srcIP,
dstIP: dstIP,
originalEchoID: originalEchoID,
}
echoID, ok := tracker.getOrAssign(key)
require.True(t, ok) require.True(t, ok)
require.Equal(t, i, echoID) require.Equal(t, i, echoID)
echoID, ok = tracker.get(srcIP) echoID, ok = tracker.get(key)
require.True(t, ok) require.True(t, ok)
require.Equal(t, i, echoID) require.Equal(t, i, echoID)
srcIP = srcIP.Next() srcIP = srcIP.Next()
} }
} }
func (eit *echoIDTracker) get(key flow3Tuple) (id uint16, exist bool) {
eit.lock.Lock()
defer eit.lock.Unlock()
id, exists := eit.mapping[key]
return id, exists
}

View File

@ -57,45 +57,57 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) er
if pk == nil { if pk == nil {
return errPacketNil return errPacketNil
} }
funnelID := srcIPFunnelID(pk.Src) originalEcho, err := getICMPEcho(pk.Message)
funnel, exists := ip.srcFunnelTracker.Get(funnelID) if err != nil {
if !exists { return err
originalEcho, err := getICMPEcho(pk.Message) }
if err != nil { newConnChan := make(chan *icmp.PacketConn, 1)
return err newFunnelFunc := func() (packet.Funnel, error) {
}
conn, err := newICMPConn(ip.listenIP) conn, err := newICMPConn(ip.listenIP)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to open ICMP socket") return nil, errors.Wrap(err, "failed to open ICMP socket")
} }
newConnChan <- conn
localUDPAddr, ok := conn.LocalAddr().(*net.UDPAddr) localUDPAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok { if !ok {
return fmt.Errorf("ICMP listener address %s is not net.UDPAddr", conn.LocalAddr()) return nil, fmt.Errorf("ICMP listener address %s is not net.UDPAddr", conn.LocalAddr())
} }
originSender := originSender{conn: conn} originSender := originSender{conn: conn}
echoID := localUDPAddr.Port echoID := localUDPAddr.Port
icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, echoID, originalEcho.ID, packet.NewEncoder()) icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, echoID, originalEcho.ID, packet.NewEncoder())
if replaced := ip.srcFunnelTracker.Register(funnelID, icmpFlow); replaced { return icmpFlow, nil
ip.logger.Info().Str("src", pk.Src.String()).Msg("Replaced funnel") }
} funnelID := flow3Tuple{
if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil { srcIP: pk.Src,
return errors.Wrap(err, "failed to send ICMP echo request") dstIP: pk.Dst,
} originalEchoID: originalEcho.ID,
go func() { }
defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow) funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, newFunnelFunc)
if err := ip.listenResponse(icmpFlow, conn); err != nil { if err != nil {
ip.logger.Err(err). return err
Str("funnelID", funnelID.String()).
Int("echoID", echoID).
Msg("Failed to listen for ICMP echo response")
}
}()
return nil
} }
icmpFlow, err := toICMPEchoFlow(funnel) icmpFlow, err := toICMPEchoFlow(funnel)
if err != nil { if err != nil {
return err return err
} }
if isNew {
ip.logger.Debug().
Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()).
Int("originalEchoID", originalEcho.ID).
Msg("New flow")
conn := <-newConnChan
go func() {
defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow)
if err := ip.listenResponse(icmpFlow, conn); err != nil {
ip.logger.Debug().Err(err).
Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()).
Int("originalEchoID", originalEcho.ID).
Msg("Failed to listen for ICMP echo response")
}
}()
}
if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil { if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil {
return errors.Wrap(err, "failed to send ICMP echo request") return errors.Wrap(err, "failed to send ICMP echo request")
} }
@ -146,12 +158,11 @@ func (os *originSender) Close() error {
return os.conn.Close() return os.conn.Close()
} }
type srcIPFunnelID netip.Addr // Only linux uses flow3Tuple as FunnelID
func (ft flow3Tuple) Type() string {
func (sifd srcIPFunnelID) Type() string { return "srcIP_dstIP_echoID"
return "srcIP"
} }
func (sifd srcIPFunnelID) String() string { func (ft flow3Tuple) String() string {
return netip.Addr(sifd).String() return fmt.Sprintf("%s:%s:%d", ft.srcIP, ft.dstIP, ft.originalEchoID)
} }

View File

@ -32,17 +32,10 @@ func netipAddr(addr net.Addr) (netip.Addr, bool) {
return netip.AddrFromSlice(udpAddr.IP) return netip.AddrFromSlice(udpAddr.IP)
} }
type flowID struct { type flow3Tuple struct {
srcIP netip.Addr srcIP netip.Addr
echoID int dstIP netip.Addr
} originalEchoID int
func (fi *flowID) Type() string {
return "srcIP_echoID"
}
func (fi *flowID) String() string {
return fmt.Sprintf("%s:%d", fi.srcIP, fi.echoID)
} }
// icmpEchoFlow implements the packet.Funnel interface. // icmpEchoFlow implements the packet.Funnel interface.

View File

@ -315,23 +315,22 @@ func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, d
}, },
} }
serializedPacket, err := ip.encodeICMPReply(&pk) cachedEncoder := ip.encoderPool.Get()
// The encoded packet is a slice to of the encoder, so we shouldn't return the encoder back to the pool until
// the encoded packet is sent.
defer ip.encoderPool.Put(cachedEncoder)
encoder, ok := cachedEncoder.(*packet.Encoder)
if !ok {
return fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder)
}
serializedPacket, err := encoder.Encode(&pk)
if err != nil { if err != nil {
return err return err
} }
return responder.SendPacket(request.Src, serializedPacket) return responder.SendPacket(request.Src, serializedPacket)
} }
func (ip *icmpProxy) encodeICMPReply(pk *packet.ICMP) (packet.RawPacket, error) {
cachedEncoder := ip.encoderPool.Get()
defer ip.encoderPool.Put(cachedEncoder)
encoder, ok := cachedEncoder.(*packet.Encoder)
if !ok {
return packet.RawPacket{}, fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder)
}
return encoder.Encode(pk)
}
func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) ([]byte, error) { func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) ([]byte, error) {
if dst.Is6() { if dst.Is6() {
if ip.srcSocketAddr == nil { if ip.srcSocketAddr == nil {

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@ -97,6 +98,91 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
<-proxyDone <-proxyDone
} }
// TestConcurrentRequests makes sure icmpRouter can send concurrent requests to the same destination with different
// echo ID. This simulates concurrent ping to the same destination.
func TestConcurrentRequestsToSameDst(t *testing.T) {
const (
concurrentPings = 5
endSeq = 5
)
router, err := NewICMPRouter(&noopLogger)
require.NoError(t, err)
proxyDone := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
router.Serve(ctx)
close(proxyDone)
}()
var wg sync.WaitGroup
// icmpv4 and icmpv6 each has concurrentPings
wg.Add(concurrentPings * 2)
for i := 0; i < concurrentPings; i++ {
echoID := 38451 + i
go func() {
defer wg.Done()
responder := echoFlowResponder{
decoder: packet.NewICMPDecoder(),
respChan: make(chan []byte, 1),
}
for seq := 0; seq < endSeq; seq++ {
pk := &packet.ICMP{
IP: &packet.IP{
Src: localhostIP,
Dst: localhostIP,
Protocol: layers.IPProtocolICMPv4,
TTL: packet.DefaultTTL,
},
Message: &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: echoID,
Seq: seq,
Data: []byte(fmt.Sprintf("icmpv4 echo id %d, seq %d", echoID, seq)),
},
},
}
require.NoError(t, router.Request(pk, &responder))
responder.validate(t, pk)
}
}()
go func() {
defer wg.Done()
responder := echoFlowResponder{
decoder: packet.NewICMPDecoder(),
respChan: make(chan []byte, 1),
}
for seq := 0; seq < endSeq; seq++ {
pk := &packet.ICMP{
IP: &packet.IP{
Src: localhostIPv6,
Dst: localhostIPv6,
Protocol: layers.IPProtocolICMPv6,
TTL: packet.DefaultTTL,
},
Message: &icmp.Message{
Type: ipv6.ICMPTypeEchoRequest,
Code: 0,
Body: &icmp.Echo{
ID: echoID,
Seq: seq,
Data: []byte(fmt.Sprintf("icmpv6 echo id %d, seq %d", echoID, seq)),
},
},
}
require.NoError(t, router.Request(pk, &responder))
responder.validate(t, pk)
}
}()
}
wg.Wait()
cancel()
<-proxyDone
}
// TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo // TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo
func TestICMPRouterRejectNotEcho(t *testing.T) { func TestICMPRouterRejectNotEcho(t *testing.T) {
msgs := []icmp.Message{ msgs := []icmp.Message{

View File

@ -158,20 +158,19 @@ func (ft *FunnelTracker) Get(id FunnelID) (Funnel, bool) {
} }
// Registers a funnel. It replaces the current funnel. // Registers a funnel. It replaces the current funnel.
func (ft *FunnelTracker) Register(id FunnelID, funnel Funnel) (replaced bool) { func (ft *FunnelTracker) GetOrRegister(id FunnelID, newFunnelFunc func() (Funnel, error)) (funnel Funnel, new bool, err error) {
ft.lock.Lock() ft.lock.Lock()
defer ft.lock.Unlock() defer ft.lock.Unlock()
currentFunnel, exists := ft.funnels[id] currentFunnel, exists := ft.funnels[id]
if !exists { if exists {
ft.funnels[id] = funnel return currentFunnel, false, nil
return false
} }
replaced = !currentFunnel.Equal(funnel) newFunnel, err := newFunnelFunc()
if replaced { if err != nil {
currentFunnel.Close() return nil, false, err
} }
ft.funnels[id] = funnel ft.funnels[id] = newFunnel
return replaced return newFunnel, true, nil
} }
// Unregisters a funnel if the funnel equals to the current funnel // Unregisters a funnel if the funnel equals to the current funnel