TUN-6744: On posix platforms, assign unique echo ID per (src, dst, echo ID)

This also refactor FunnelTracker to provide a GetOrRegister method to prevent race condition
This commit is contained in:
cthuang 2022-09-09 16:48:42 +01:00
parent e454994e3e
commit b639b6627a
7 changed files with 268 additions and 138 deletions

View File

@ -37,9 +37,9 @@ type icmpProxy struct {
// then from the beginning to lastAssignment.
// ICMP echo are short lived. By the time an ID is revisited, it should have been released.
type echoIDTracker struct {
lock sync.RWMutex
// maps the source IP to an echo ID obtained from assignment
srcIPMapping map[netip.Addr]uint16
lock sync.Mutex
// maps the source IP, destination IP and original echo ID to a unique echo ID obtained from assignment
mapping map[flow3Tuple]uint16
// assignment tracks if an ID is assigned using index as the ID
// The size of the array is math.MaxUint16 because echo ID is 2 bytes
assignment [math.MaxUint16]bool
@ -49,20 +49,18 @@ type echoIDTracker struct {
func newEchoIDTracker() *echoIDTracker {
return &echoIDTracker{
srcIPMapping: make(map[netip.Addr]uint16),
mapping: make(map[flow3Tuple]uint16),
}
}
func (eit *echoIDTracker) get(srcIP netip.Addr) (uint16, bool) {
eit.lock.RLock()
defer eit.lock.RUnlock()
id, ok := eit.srcIPMapping[srcIP]
return id, ok
}
func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) {
// Get assignment or assign a new ID.
func (eit *echoIDTracker) getOrAssign(key flow3Tuple) (id uint16, success bool) {
eit.lock.Lock()
defer eit.lock.Unlock()
id, exists := eit.mapping[key]
if exists {
return id, true
}
if eit.nextAssignment == math.MaxUint16 {
eit.nextAssignment = 0
@ -71,14 +69,14 @@ func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) {
for i, assigned := range eit.assignment[eit.nextAssignment:] {
if !assigned {
echoID := uint16(i) + eit.nextAssignment
eit.set(srcIP, echoID)
eit.set(key, echoID)
return echoID, true
}
}
for i, assigned := range eit.assignment[0:eit.nextAssignment] {
if !assigned {
echoID := uint16(i)
eit.set(srcIP, echoID)
eit.set(key, echoID)
return echoID, true
}
}
@ -86,20 +84,20 @@ func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) {
}
// Caller should hold the lock
func (eit *echoIDTracker) set(srcIP netip.Addr, echoID uint16) {
eit.assignment[echoID] = true
eit.srcIPMapping[srcIP] = echoID
eit.nextAssignment = echoID + 1
func (eit *echoIDTracker) set(key flow3Tuple, assignedEchoID uint16) {
eit.assignment[assignedEchoID] = true
eit.mapping[key] = assignedEchoID
eit.nextAssignment = assignedEchoID + 1
}
func (eit *echoIDTracker) release(srcIP netip.Addr, id uint16) bool {
func (eit *echoIDTracker) release(key flow3Tuple, assigned uint16) bool {
eit.lock.Lock()
defer eit.lock.Unlock()
currentID, exists := eit.srcIPMapping[srcIP]
if exists && id == currentID {
delete(eit.srcIPMapping, srcIP)
eit.assignment[id] = false
currentEchoID, exists := eit.mapping[key]
if exists && assigned == currentEchoID {
delete(eit.mapping, key)
eit.assignment[assigned] = false
return true
}
return false
@ -134,33 +132,46 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) er
if pk == nil {
return errPacketNil
}
// TODO: TUN-6744 assign unique flow per (src, echo ID)
echoID, exists := ip.echoIDTracker.get(pk.Src)
if !exists {
originalEcho, err := getICMPEcho(pk.Message)
if err != nil {
return err
}
echoID, exists = ip.echoIDTracker.assign(pk.Src)
if !exists {
echoIDTrackerKey := flow3Tuple{
srcIP: pk.Src,
dstIP: pk.Dst,
originalEchoID: originalEcho.ID,
}
// TODO: TUN-6744 assign unique flow per (src, echo ID)
assignedEchoID, success := ip.echoIDTracker.getOrAssign(echoIDTrackerKey)
if !success {
return fmt.Errorf("failed to assign unique echo ID")
}
funnelID := echoFunnelID(echoID)
newFunnelFunc := func() (packet.Funnel, error) {
originalEcho, err := getICMPEcho(pk.Message)
if err != nil {
return nil, err
}
originSender := originSender{
conn: ip.conn,
echoIDTracker: ip.echoIDTracker,
srcIP: pk.Src,
echoID: echoID,
echoIDTrackerKey: echoIDTrackerKey,
assignedEchoID: assignedEchoID,
}
icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, int(echoID), originalEcho.ID, ip.encoder)
if replaced := ip.srcFunnelTracker.Register(funnelID, icmpFlow); replaced {
ip.logger.Info().Str("src", pk.Src.String()).Msg("Replaced funnel")
icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, int(assignedEchoID), originalEcho.ID, ip.encoder)
return icmpFlow, nil
}
return icmpFlow.sendToDst(pk.Dst, pk.Message)
funnelID := echoFunnelID(assignedEchoID)
funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, newFunnelFunc)
if err != nil {
return err
}
funnel, exists := ip.srcFunnelTracker.Get(echoFunnelID(echoID))
if !exists {
return packet.ErrFunnelNotFound
if isNew {
ip.logger.Debug().
Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()).
Int("originalEchoID", originalEcho.ID).
Int("assignedEchoID", int(assignedEchoID)).
Msg("New flow")
}
icmpFlow, err := toICMPEchoFlow(funnel)
if err != nil {
@ -199,7 +210,7 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
ip.logger.Debug().Str("dst", from.String()).Msgf("Drop ICMP %s from reply", reply.msg.Type)
continue
}
if ip.sendReply(reply); err != nil {
if err := ip.sendReply(reply); err != nil {
ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply")
continue
}
@ -227,7 +238,8 @@ func (ip *icmpProxy) handleFullPacket(decoder *packet.ICMPDecoder, rawPacket []b
}
func (ip *icmpProxy) sendReply(reply *echoReply) error {
funnel, ok := ip.srcFunnelTracker.Get(echoFunnelID(reply.echo.ID))
funnelID := echoFunnelID(reply.echo.ID)
funnel, ok := ip.srcFunnelTracker.Get(funnelID)
if !ok {
return packet.ErrFunnelNotFound
}
@ -242,8 +254,8 @@ func (ip *icmpProxy) sendReply(reply *echoReply) error {
type originSender struct {
conn *icmp.PacketConn
echoIDTracker *echoIDTracker
srcIP netip.Addr
echoID uint16
echoIDTrackerKey flow3Tuple
assignedEchoID uint16
}
func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
@ -254,6 +266,6 @@ func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
}
func (os *originSender) Close() error {
os.echoIDTracker.release(os.srcIP, os.echoID)
os.echoIDTracker.release(os.echoIDTrackerKey, os.assignedEchoID)
return nil
}

View File

@ -12,80 +12,110 @@ import (
func TestSingleEchoIDTracker(t *testing.T) {
tracker := newEchoIDTracker()
srcIP := netip.MustParseAddr("127.0.0.1")
echoID, ok := tracker.get(srcIP)
require.False(t, ok)
require.Equal(t, uint16(0), echoID)
key := flow3Tuple{
srcIP: netip.MustParseAddr("172.16.0.1"),
dstIP: netip.MustParseAddr("172.16.0.2"),
originalEchoID: 5182,
}
// not assigned yet, so nothing to release
require.False(t, tracker.release(srcIP, echoID))
require.False(t, tracker.release(key, 0))
echoID, ok = tracker.assign(srcIP)
echoID, ok := tracker.getOrAssign(key)
require.True(t, ok)
require.Equal(t, uint16(0), echoID)
echoID, ok = tracker.get(srcIP)
// Second time should return the same echo ID
echoID, ok = tracker.getOrAssign(key)
require.True(t, ok)
require.Equal(t, uint16(0), echoID)
// releasing a different ID returns false
require.False(t, tracker.release(srcIP, 1999))
require.True(t, tracker.release(srcIP, echoID))
require.False(t, tracker.release(key, 1999))
require.True(t, tracker.release(key, echoID))
// releasing the second time returns false
require.False(t, tracker.release(srcIP, echoID))
echoID, ok = tracker.get(srcIP)
require.False(t, ok)
require.Equal(t, uint16(0), echoID)
require.False(t, tracker.release(key, echoID))
// Move to the next IP
echoID, ok = tracker.assign(srcIP)
echoID, ok = tracker.getOrAssign(key)
require.True(t, ok)
require.Equal(t, uint16(1), echoID)
}
func TestFullEchoIDTracker(t *testing.T) {
var (
dstIP = netip.MustParseAddr("192.168.0.1")
originalEchoID = 41820
)
tracker := newEchoIDTracker()
firstIP := netip.MustParseAddr("172.16.0.1")
srcIP := firstIP
firstSrcIP := netip.MustParseAddr("172.16.0.1")
srcIP := firstSrcIP
for i := uint16(0); i < math.MaxUint16; i++ {
echoID, ok := tracker.assign(srcIP)
key := flow3Tuple{
srcIP: srcIP,
dstIP: dstIP,
originalEchoID: originalEchoID,
}
echoID, ok := tracker.getOrAssign(key)
require.True(t, ok)
require.Equal(t, i, echoID)
echoID, ok = tracker.get(srcIP)
echoID, ok = tracker.get(key)
require.True(t, ok)
require.Equal(t, i, echoID)
srcIP = srcIP.Next()
}
key := flow3Tuple{
srcIP: srcIP.Next(),
dstIP: dstIP,
originalEchoID: originalEchoID,
}
// All echo IDs are assigned
echoID, ok := tracker.assign(srcIP.Next())
echoID, ok := tracker.getOrAssign(key)
require.False(t, ok)
require.Equal(t, uint16(0), echoID)
srcIP = firstIP
srcIP = firstSrcIP
for i := uint16(0); i < math.MaxUint16; i++ {
ok := tracker.release(srcIP, i)
key := flow3Tuple{
srcIP: srcIP,
dstIP: dstIP,
originalEchoID: originalEchoID,
}
ok := tracker.release(key, i)
require.True(t, ok)
echoID, ok = tracker.get(srcIP)
echoID, ok = tracker.get(key)
require.False(t, ok)
require.Equal(t, uint16(0), echoID)
srcIP = srcIP.Next()
}
// The IDs are assignable again
srcIP = firstIP
srcIP = firstSrcIP
for i := uint16(0); i < math.MaxUint16; i++ {
echoID, ok := tracker.assign(srcIP)
key := flow3Tuple{
srcIP: srcIP,
dstIP: dstIP,
originalEchoID: originalEchoID,
}
echoID, ok := tracker.getOrAssign(key)
require.True(t, ok)
require.Equal(t, i, echoID)
echoID, ok = tracker.get(srcIP)
echoID, ok = tracker.get(key)
require.True(t, ok)
require.Equal(t, i, echoID)
srcIP = srcIP.Next()
}
}
func (eit *echoIDTracker) get(key flow3Tuple) (id uint16, exist bool) {
eit.lock.Lock()
defer eit.lock.Unlock()
id, exists := eit.mapping[key]
return id, exists
}

View File

@ -57,45 +57,57 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) er
if pk == nil {
return errPacketNil
}
funnelID := srcIPFunnelID(pk.Src)
funnel, exists := ip.srcFunnelTracker.Get(funnelID)
if !exists {
originalEcho, err := getICMPEcho(pk.Message)
if err != nil {
return err
}
newConnChan := make(chan *icmp.PacketConn, 1)
newFunnelFunc := func() (packet.Funnel, error) {
conn, err := newICMPConn(ip.listenIP)
if err != nil {
return errors.Wrap(err, "failed to open ICMP socket")
return nil, errors.Wrap(err, "failed to open ICMP socket")
}
newConnChan <- conn
localUDPAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
return fmt.Errorf("ICMP listener address %s is not net.UDPAddr", conn.LocalAddr())
return nil, fmt.Errorf("ICMP listener address %s is not net.UDPAddr", conn.LocalAddr())
}
originSender := originSender{conn: conn}
echoID := localUDPAddr.Port
icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, echoID, originalEcho.ID, packet.NewEncoder())
if replaced := ip.srcFunnelTracker.Register(funnelID, icmpFlow); replaced {
ip.logger.Info().Str("src", pk.Src.String()).Msg("Replaced funnel")
return icmpFlow, nil
}
if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil {
return errors.Wrap(err, "failed to send ICMP echo request")
funnelID := flow3Tuple{
srcIP: pk.Src,
dstIP: pk.Dst,
originalEchoID: originalEcho.ID,
}
go func() {
defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow)
if err := ip.listenResponse(icmpFlow, conn); err != nil {
ip.logger.Err(err).
Str("funnelID", funnelID.String()).
Int("echoID", echoID).
Msg("Failed to listen for ICMP echo response")
}
}()
return nil
funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, newFunnelFunc)
if err != nil {
return err
}
icmpFlow, err := toICMPEchoFlow(funnel)
if err != nil {
return err
}
if isNew {
ip.logger.Debug().
Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()).
Int("originalEchoID", originalEcho.ID).
Msg("New flow")
conn := <-newConnChan
go func() {
defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow)
if err := ip.listenResponse(icmpFlow, conn); err != nil {
ip.logger.Debug().Err(err).
Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()).
Int("originalEchoID", originalEcho.ID).
Msg("Failed to listen for ICMP echo response")
}
}()
}
if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil {
return errors.Wrap(err, "failed to send ICMP echo request")
}
@ -146,12 +158,11 @@ func (os *originSender) Close() error {
return os.conn.Close()
}
type srcIPFunnelID netip.Addr
func (sifd srcIPFunnelID) Type() string {
return "srcIP"
// Only linux uses flow3Tuple as FunnelID
func (ft flow3Tuple) Type() string {
return "srcIP_dstIP_echoID"
}
func (sifd srcIPFunnelID) String() string {
return netip.Addr(sifd).String()
func (ft flow3Tuple) String() string {
return fmt.Sprintf("%s:%s:%d", ft.srcIP, ft.dstIP, ft.originalEchoID)
}

View File

@ -32,17 +32,10 @@ func netipAddr(addr net.Addr) (netip.Addr, bool) {
return netip.AddrFromSlice(udpAddr.IP)
}
type flowID struct {
type flow3Tuple struct {
srcIP netip.Addr
echoID int
}
func (fi *flowID) Type() string {
return "srcIP_echoID"
}
func (fi *flowID) String() string {
return fmt.Sprintf("%s:%d", fi.srcIP, fi.echoID)
dstIP netip.Addr
originalEchoID int
}
// icmpEchoFlow implements the packet.Funnel interface.

View File

@ -315,23 +315,22 @@ func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, d
},
}
serializedPacket, err := ip.encodeICMPReply(&pk)
cachedEncoder := ip.encoderPool.Get()
// The encoded packet is a slice to of the encoder, so we shouldn't return the encoder back to the pool until
// the encoded packet is sent.
defer ip.encoderPool.Put(cachedEncoder)
encoder, ok := cachedEncoder.(*packet.Encoder)
if !ok {
return fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder)
}
serializedPacket, err := encoder.Encode(&pk)
if err != nil {
return err
}
return responder.SendPacket(request.Src, serializedPacket)
}
func (ip *icmpProxy) encodeICMPReply(pk *packet.ICMP) (packet.RawPacket, error) {
cachedEncoder := ip.encoderPool.Get()
defer ip.encoderPool.Put(cachedEncoder)
encoder, ok := cachedEncoder.(*packet.Encoder)
if !ok {
return packet.RawPacket{}, fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder)
}
return encoder.Encode(pk)
}
func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) ([]byte, error) {
if dst.Is6() {
if ip.srcSocketAddr == nil {

View File

@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync"
"testing"
"github.com/google/gopacket/layers"
@ -97,6 +98,91 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
<-proxyDone
}
// TestConcurrentRequests makes sure icmpRouter can send concurrent requests to the same destination with different
// echo ID. This simulates concurrent ping to the same destination.
func TestConcurrentRequestsToSameDst(t *testing.T) {
const (
concurrentPings = 5
endSeq = 5
)
router, err := NewICMPRouter(&noopLogger)
require.NoError(t, err)
proxyDone := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
router.Serve(ctx)
close(proxyDone)
}()
var wg sync.WaitGroup
// icmpv4 and icmpv6 each has concurrentPings
wg.Add(concurrentPings * 2)
for i := 0; i < concurrentPings; i++ {
echoID := 38451 + i
go func() {
defer wg.Done()
responder := echoFlowResponder{
decoder: packet.NewICMPDecoder(),
respChan: make(chan []byte, 1),
}
for seq := 0; seq < endSeq; seq++ {
pk := &packet.ICMP{
IP: &packet.IP{
Src: localhostIP,
Dst: localhostIP,
Protocol: layers.IPProtocolICMPv4,
TTL: packet.DefaultTTL,
},
Message: &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: echoID,
Seq: seq,
Data: []byte(fmt.Sprintf("icmpv4 echo id %d, seq %d", echoID, seq)),
},
},
}
require.NoError(t, router.Request(pk, &responder))
responder.validate(t, pk)
}
}()
go func() {
defer wg.Done()
responder := echoFlowResponder{
decoder: packet.NewICMPDecoder(),
respChan: make(chan []byte, 1),
}
for seq := 0; seq < endSeq; seq++ {
pk := &packet.ICMP{
IP: &packet.IP{
Src: localhostIPv6,
Dst: localhostIPv6,
Protocol: layers.IPProtocolICMPv6,
TTL: packet.DefaultTTL,
},
Message: &icmp.Message{
Type: ipv6.ICMPTypeEchoRequest,
Code: 0,
Body: &icmp.Echo{
ID: echoID,
Seq: seq,
Data: []byte(fmt.Sprintf("icmpv6 echo id %d, seq %d", echoID, seq)),
},
},
}
require.NoError(t, router.Request(pk, &responder))
responder.validate(t, pk)
}
}()
}
wg.Wait()
cancel()
<-proxyDone
}
// TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo
func TestICMPRouterRejectNotEcho(t *testing.T) {
msgs := []icmp.Message{

View File

@ -158,20 +158,19 @@ func (ft *FunnelTracker) Get(id FunnelID) (Funnel, bool) {
}
// Registers a funnel. It replaces the current funnel.
func (ft *FunnelTracker) Register(id FunnelID, funnel Funnel) (replaced bool) {
func (ft *FunnelTracker) GetOrRegister(id FunnelID, newFunnelFunc func() (Funnel, error)) (funnel Funnel, new bool, err error) {
ft.lock.Lock()
defer ft.lock.Unlock()
currentFunnel, exists := ft.funnels[id]
if !exists {
ft.funnels[id] = funnel
return false
if exists {
return currentFunnel, false, nil
}
replaced = !currentFunnel.Equal(funnel)
if replaced {
currentFunnel.Close()
newFunnel, err := newFunnelFunc()
if err != nil {
return nil, false, err
}
ft.funnels[id] = funnel
return replaced
ft.funnels[id] = newFunnel
return newFunnel, true, nil
}
// Unregisters a funnel if the funnel equals to the current funnel