diff --git a/ingress/icmp_darwin.go b/ingress/icmp_darwin.go index 73ac4de1..78d0a87f 100644 --- a/ingress/icmp_darwin.go +++ b/ingress/icmp_darwin.go @@ -37,9 +37,9 @@ type icmpProxy struct { // then from the beginning to lastAssignment. // ICMP echo are short lived. By the time an ID is revisited, it should have been released. type echoIDTracker struct { - lock sync.RWMutex - // maps the source IP to an echo ID obtained from assignment - srcIPMapping map[netip.Addr]uint16 + lock sync.Mutex + // maps the source IP, destination IP and original echo ID to a unique echo ID obtained from assignment + mapping map[flow3Tuple]uint16 // assignment tracks if an ID is assigned using index as the ID // The size of the array is math.MaxUint16 because echo ID is 2 bytes assignment [math.MaxUint16]bool @@ -49,20 +49,18 @@ type echoIDTracker struct { func newEchoIDTracker() *echoIDTracker { return &echoIDTracker{ - srcIPMapping: make(map[netip.Addr]uint16), + mapping: make(map[flow3Tuple]uint16), } } -func (eit *echoIDTracker) get(srcIP netip.Addr) (uint16, bool) { - eit.lock.RLock() - defer eit.lock.RUnlock() - id, ok := eit.srcIPMapping[srcIP] - return id, ok -} - -func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) { +// Get assignment or assign a new ID. +func (eit *echoIDTracker) getOrAssign(key flow3Tuple) (id uint16, success bool) { eit.lock.Lock() defer eit.lock.Unlock() + id, exists := eit.mapping[key] + if exists { + return id, true + } if eit.nextAssignment == math.MaxUint16 { eit.nextAssignment = 0 @@ -71,14 +69,14 @@ func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) { for i, assigned := range eit.assignment[eit.nextAssignment:] { if !assigned { echoID := uint16(i) + eit.nextAssignment - eit.set(srcIP, echoID) + eit.set(key, echoID) return echoID, true } } for i, assigned := range eit.assignment[0:eit.nextAssignment] { if !assigned { echoID := uint16(i) - eit.set(srcIP, echoID) + eit.set(key, echoID) return echoID, true } } @@ -86,20 +84,20 @@ func (eit *echoIDTracker) assign(srcIP netip.Addr) (uint16, bool) { } // Caller should hold the lock -func (eit *echoIDTracker) set(srcIP netip.Addr, echoID uint16) { - eit.assignment[echoID] = true - eit.srcIPMapping[srcIP] = echoID - eit.nextAssignment = echoID + 1 +func (eit *echoIDTracker) set(key flow3Tuple, assignedEchoID uint16) { + eit.assignment[assignedEchoID] = true + eit.mapping[key] = assignedEchoID + eit.nextAssignment = assignedEchoID + 1 } -func (eit *echoIDTracker) release(srcIP netip.Addr, id uint16) bool { +func (eit *echoIDTracker) release(key flow3Tuple, assigned uint16) bool { eit.lock.Lock() defer eit.lock.Unlock() - currentID, exists := eit.srcIPMapping[srcIP] - if exists && id == currentID { - delete(eit.srcIPMapping, srcIP) - eit.assignment[id] = false + currentEchoID, exists := eit.mapping[key] + if exists && assigned == currentEchoID { + delete(eit.mapping, key) + eit.assignment[assigned] = false return true } return false @@ -134,33 +132,46 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) er if pk == nil { return errPacketNil } + originalEcho, err := getICMPEcho(pk.Message) + if err != nil { + return err + } + echoIDTrackerKey := flow3Tuple{ + srcIP: pk.Src, + dstIP: pk.Dst, + originalEchoID: originalEcho.ID, + } // TODO: TUN-6744 assign unique flow per (src, echo ID) - echoID, exists := ip.echoIDTracker.get(pk.Src) - if !exists { + assignedEchoID, success := ip.echoIDTracker.getOrAssign(echoIDTrackerKey) + if !success { + return fmt.Errorf("failed to assign unique echo ID") + } + newFunnelFunc := func() (packet.Funnel, error) { originalEcho, err := getICMPEcho(pk.Message) if err != nil { - return err + return nil, err } - echoID, exists = ip.echoIDTracker.assign(pk.Src) - if !exists { - return fmt.Errorf("failed to assign unique echo ID") - } - funnelID := echoFunnelID(echoID) originSender := originSender{ - conn: ip.conn, - echoIDTracker: ip.echoIDTracker, - srcIP: pk.Src, - echoID: echoID, + conn: ip.conn, + echoIDTracker: ip.echoIDTracker, + echoIDTrackerKey: echoIDTrackerKey, + assignedEchoID: assignedEchoID, } - icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, int(echoID), originalEcho.ID, ip.encoder) - if replaced := ip.srcFunnelTracker.Register(funnelID, icmpFlow); replaced { - ip.logger.Info().Str("src", pk.Src.String()).Msg("Replaced funnel") - } - return icmpFlow.sendToDst(pk.Dst, pk.Message) + icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, int(assignedEchoID), originalEcho.ID, ip.encoder) + return icmpFlow, nil } - funnel, exists := ip.srcFunnelTracker.Get(echoFunnelID(echoID)) - if !exists { - return packet.ErrFunnelNotFound + funnelID := echoFunnelID(assignedEchoID) + funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, newFunnelFunc) + if err != nil { + return err + } + if isNew { + ip.logger.Debug(). + Str("src", pk.Src.String()). + Str("dst", pk.Dst.String()). + Int("originalEchoID", originalEcho.ID). + Int("assignedEchoID", int(assignedEchoID)). + Msg("New flow") } icmpFlow, err := toICMPEchoFlow(funnel) if err != nil { @@ -199,7 +210,7 @@ func (ip *icmpProxy) Serve(ctx context.Context) error { ip.logger.Debug().Str("dst", from.String()).Msgf("Drop ICMP %s from reply", reply.msg.Type) continue } - if ip.sendReply(reply); err != nil { + if err := ip.sendReply(reply); err != nil { ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply") continue } @@ -227,7 +238,8 @@ func (ip *icmpProxy) handleFullPacket(decoder *packet.ICMPDecoder, rawPacket []b } func (ip *icmpProxy) sendReply(reply *echoReply) error { - funnel, ok := ip.srcFunnelTracker.Get(echoFunnelID(reply.echo.ID)) + funnelID := echoFunnelID(reply.echo.ID) + funnel, ok := ip.srcFunnelTracker.Get(funnelID) if !ok { return packet.ErrFunnelNotFound } @@ -240,10 +252,10 @@ func (ip *icmpProxy) sendReply(reply *echoReply) error { // originSender wraps icmp.PacketConn to implement packet.FunnelUniPipe interface type originSender struct { - conn *icmp.PacketConn - echoIDTracker *echoIDTracker - srcIP netip.Addr - echoID uint16 + conn *icmp.PacketConn + echoIDTracker *echoIDTracker + echoIDTrackerKey flow3Tuple + assignedEchoID uint16 } func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error { @@ -254,6 +266,6 @@ func (os *originSender) SendPacket(dst netip.Addr, pk packet.RawPacket) error { } func (os *originSender) Close() error { - os.echoIDTracker.release(os.srcIP, os.echoID) + os.echoIDTracker.release(os.echoIDTrackerKey, os.assignedEchoID) return nil } diff --git a/ingress/icmp_darwin_test.go b/ingress/icmp_darwin_test.go index 8baccb58..90b92710 100644 --- a/ingress/icmp_darwin_test.go +++ b/ingress/icmp_darwin_test.go @@ -12,80 +12,110 @@ import ( func TestSingleEchoIDTracker(t *testing.T) { tracker := newEchoIDTracker() - srcIP := netip.MustParseAddr("127.0.0.1") - echoID, ok := tracker.get(srcIP) - require.False(t, ok) - require.Equal(t, uint16(0), echoID) + key := flow3Tuple{ + srcIP: netip.MustParseAddr("172.16.0.1"), + dstIP: netip.MustParseAddr("172.16.0.2"), + originalEchoID: 5182, + } // not assigned yet, so nothing to release - require.False(t, tracker.release(srcIP, echoID)) + require.False(t, tracker.release(key, 0)) - echoID, ok = tracker.assign(srcIP) + echoID, ok := tracker.getOrAssign(key) require.True(t, ok) require.Equal(t, uint16(0), echoID) - echoID, ok = tracker.get(srcIP) + // Second time should return the same echo ID + echoID, ok = tracker.getOrAssign(key) require.True(t, ok) require.Equal(t, uint16(0), echoID) // releasing a different ID returns false - require.False(t, tracker.release(srcIP, 1999)) - require.True(t, tracker.release(srcIP, echoID)) + require.False(t, tracker.release(key, 1999)) + require.True(t, tracker.release(key, echoID)) // releasing the second time returns false - require.False(t, tracker.release(srcIP, echoID)) - - echoID, ok = tracker.get(srcIP) - require.False(t, ok) - require.Equal(t, uint16(0), echoID) + require.False(t, tracker.release(key, echoID)) // Move to the next IP - echoID, ok = tracker.assign(srcIP) + echoID, ok = tracker.getOrAssign(key) require.True(t, ok) require.Equal(t, uint16(1), echoID) } func TestFullEchoIDTracker(t *testing.T) { + var ( + dstIP = netip.MustParseAddr("192.168.0.1") + originalEchoID = 41820 + ) tracker := newEchoIDTracker() - firstIP := netip.MustParseAddr("172.16.0.1") - srcIP := firstIP + firstSrcIP := netip.MustParseAddr("172.16.0.1") + srcIP := firstSrcIP for i := uint16(0); i < math.MaxUint16; i++ { - echoID, ok := tracker.assign(srcIP) + key := flow3Tuple{ + srcIP: srcIP, + dstIP: dstIP, + originalEchoID: originalEchoID, + } + echoID, ok := tracker.getOrAssign(key) require.True(t, ok) require.Equal(t, i, echoID) - echoID, ok = tracker.get(srcIP) + echoID, ok = tracker.get(key) require.True(t, ok) require.Equal(t, i, echoID) + srcIP = srcIP.Next() } + key := flow3Tuple{ + srcIP: srcIP.Next(), + dstIP: dstIP, + originalEchoID: originalEchoID, + } // All echo IDs are assigned - echoID, ok := tracker.assign(srcIP.Next()) + echoID, ok := tracker.getOrAssign(key) require.False(t, ok) require.Equal(t, uint16(0), echoID) - srcIP = firstIP + srcIP = firstSrcIP for i := uint16(0); i < math.MaxUint16; i++ { - ok := tracker.release(srcIP, i) + key := flow3Tuple{ + srcIP: srcIP, + dstIP: dstIP, + originalEchoID: originalEchoID, + } + ok := tracker.release(key, i) require.True(t, ok) - echoID, ok = tracker.get(srcIP) + echoID, ok = tracker.get(key) require.False(t, ok) require.Equal(t, uint16(0), echoID) srcIP = srcIP.Next() } // The IDs are assignable again - srcIP = firstIP + srcIP = firstSrcIP for i := uint16(0); i < math.MaxUint16; i++ { - echoID, ok := tracker.assign(srcIP) + key := flow3Tuple{ + srcIP: srcIP, + dstIP: dstIP, + originalEchoID: originalEchoID, + } + echoID, ok := tracker.getOrAssign(key) require.True(t, ok) require.Equal(t, i, echoID) - echoID, ok = tracker.get(srcIP) + echoID, ok = tracker.get(key) require.True(t, ok) require.Equal(t, i, echoID) srcIP = srcIP.Next() } } + +func (eit *echoIDTracker) get(key flow3Tuple) (id uint16, exist bool) { + eit.lock.Lock() + defer eit.lock.Unlock() + id, exists := eit.mapping[key] + return id, exists +} diff --git a/ingress/icmp_linux.go b/ingress/icmp_linux.go index 3b7ae828..5f122e49 100644 --- a/ingress/icmp_linux.go +++ b/ingress/icmp_linux.go @@ -57,45 +57,57 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) er if pk == nil { return errPacketNil } - funnelID := srcIPFunnelID(pk.Src) - funnel, exists := ip.srcFunnelTracker.Get(funnelID) - if !exists { - originalEcho, err := getICMPEcho(pk.Message) - if err != nil { - return err - } + originalEcho, err := getICMPEcho(pk.Message) + if err != nil { + return err + } + newConnChan := make(chan *icmp.PacketConn, 1) + newFunnelFunc := func() (packet.Funnel, error) { conn, err := newICMPConn(ip.listenIP) if err != nil { - return errors.Wrap(err, "failed to open ICMP socket") + return nil, errors.Wrap(err, "failed to open ICMP socket") } + newConnChan <- conn localUDPAddr, ok := conn.LocalAddr().(*net.UDPAddr) if !ok { - return fmt.Errorf("ICMP listener address %s is not net.UDPAddr", conn.LocalAddr()) + return nil, fmt.Errorf("ICMP listener address %s is not net.UDPAddr", conn.LocalAddr()) } originSender := originSender{conn: conn} echoID := localUDPAddr.Port icmpFlow := newICMPEchoFlow(pk.Src, &originSender, responder, echoID, originalEcho.ID, packet.NewEncoder()) - if replaced := ip.srcFunnelTracker.Register(funnelID, icmpFlow); replaced { - ip.logger.Info().Str("src", pk.Src.String()).Msg("Replaced funnel") - } - if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil { - return errors.Wrap(err, "failed to send ICMP echo request") - } - go func() { - defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow) - if err := ip.listenResponse(icmpFlow, conn); err != nil { - ip.logger.Err(err). - Str("funnelID", funnelID.String()). - Int("echoID", echoID). - Msg("Failed to listen for ICMP echo response") - } - }() - return nil + return icmpFlow, nil + } + funnelID := flow3Tuple{ + srcIP: pk.Src, + dstIP: pk.Dst, + originalEchoID: originalEcho.ID, + } + funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, newFunnelFunc) + if err != nil { + return err } icmpFlow, err := toICMPEchoFlow(funnel) if err != nil { return err } + if isNew { + ip.logger.Debug(). + Str("src", pk.Src.String()). + Str("dst", pk.Dst.String()). + Int("originalEchoID", originalEcho.ID). + Msg("New flow") + conn := <-newConnChan + go func() { + defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow) + if err := ip.listenResponse(icmpFlow, conn); err != nil { + ip.logger.Debug().Err(err). + Str("src", pk.Src.String()). + Str("dst", pk.Dst.String()). + Int("originalEchoID", originalEcho.ID). + Msg("Failed to listen for ICMP echo response") + } + }() + } if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil { return errors.Wrap(err, "failed to send ICMP echo request") } @@ -146,12 +158,11 @@ func (os *originSender) Close() error { return os.conn.Close() } -type srcIPFunnelID netip.Addr - -func (sifd srcIPFunnelID) Type() string { - return "srcIP" +// Only linux uses flow3Tuple as FunnelID +func (ft flow3Tuple) Type() string { + return "srcIP_dstIP_echoID" } -func (sifd srcIPFunnelID) String() string { - return netip.Addr(sifd).String() +func (ft flow3Tuple) String() string { + return fmt.Sprintf("%s:%s:%d", ft.srcIP, ft.dstIP, ft.originalEchoID) } diff --git a/ingress/icmp_posix.go b/ingress/icmp_posix.go index a2a0a5b9..5c3c62b3 100644 --- a/ingress/icmp_posix.go +++ b/ingress/icmp_posix.go @@ -32,17 +32,10 @@ func netipAddr(addr net.Addr) (netip.Addr, bool) { return netip.AddrFromSlice(udpAddr.IP) } -type flowID struct { - srcIP netip.Addr - echoID int -} - -func (fi *flowID) Type() string { - return "srcIP_echoID" -} - -func (fi *flowID) String() string { - return fmt.Sprintf("%s:%d", fi.srcIP, fi.echoID) +type flow3Tuple struct { + srcIP netip.Addr + dstIP netip.Addr + originalEchoID int } // icmpEchoFlow implements the packet.Funnel interface. diff --git a/ingress/icmp_windows.go b/ingress/icmp_windows.go index 4170d5fc..8ad313b5 100644 --- a/ingress/icmp_windows.go +++ b/ingress/icmp_windows.go @@ -315,23 +315,22 @@ func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, d }, } - serializedPacket, err := ip.encodeICMPReply(&pk) + cachedEncoder := ip.encoderPool.Get() + // The encoded packet is a slice to of the encoder, so we shouldn't return the encoder back to the pool until + // the encoded packet is sent. + defer ip.encoderPool.Put(cachedEncoder) + encoder, ok := cachedEncoder.(*packet.Encoder) + if !ok { + return fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder) + } + + serializedPacket, err := encoder.Encode(&pk) if err != nil { return err } return responder.SendPacket(request.Src, serializedPacket) } -func (ip *icmpProxy) encodeICMPReply(pk *packet.ICMP) (packet.RawPacket, error) { - cachedEncoder := ip.encoderPool.Get() - defer ip.encoderPool.Put(cachedEncoder) - encoder, ok := cachedEncoder.(*packet.Encoder) - if !ok { - return packet.RawPacket{}, fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder) - } - return encoder.Encode(pk) -} - func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) ([]byte, error) { if dst.Is6() { if ip.srcSocketAddr == nil { diff --git a/ingress/origin_icmp_proxy_test.go b/ingress/origin_icmp_proxy_test.go index 5684a608..b961b7d4 100644 --- a/ingress/origin_icmp_proxy_test.go +++ b/ingress/origin_icmp_proxy_test.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "strings" + "sync" "testing" "github.com/google/gopacket/layers" @@ -97,6 +98,91 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) { <-proxyDone } +// TestConcurrentRequests makes sure icmpRouter can send concurrent requests to the same destination with different +// echo ID. This simulates concurrent ping to the same destination. +func TestConcurrentRequestsToSameDst(t *testing.T) { + const ( + concurrentPings = 5 + endSeq = 5 + ) + + router, err := NewICMPRouter(&noopLogger) + require.NoError(t, err) + + proxyDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + router.Serve(ctx) + close(proxyDone) + }() + + var wg sync.WaitGroup + // icmpv4 and icmpv6 each has concurrentPings + wg.Add(concurrentPings * 2) + for i := 0; i < concurrentPings; i++ { + echoID := 38451 + i + go func() { + defer wg.Done() + responder := echoFlowResponder{ + decoder: packet.NewICMPDecoder(), + respChan: make(chan []byte, 1), + } + for seq := 0; seq < endSeq; seq++ { + pk := &packet.ICMP{ + IP: &packet.IP{ + Src: localhostIP, + Dst: localhostIP, + Protocol: layers.IPProtocolICMPv4, + TTL: packet.DefaultTTL, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: echoID, + Seq: seq, + Data: []byte(fmt.Sprintf("icmpv4 echo id %d, seq %d", echoID, seq)), + }, + }, + } + require.NoError(t, router.Request(pk, &responder)) + responder.validate(t, pk) + } + }() + go func() { + defer wg.Done() + responder := echoFlowResponder{ + decoder: packet.NewICMPDecoder(), + respChan: make(chan []byte, 1), + } + for seq := 0; seq < endSeq; seq++ { + pk := &packet.ICMP{ + IP: &packet.IP{ + Src: localhostIPv6, + Dst: localhostIPv6, + Protocol: layers.IPProtocolICMPv6, + TTL: packet.DefaultTTL, + }, + Message: &icmp.Message{ + Type: ipv6.ICMPTypeEchoRequest, + Code: 0, + Body: &icmp.Echo{ + ID: echoID, + Seq: seq, + Data: []byte(fmt.Sprintf("icmpv6 echo id %d, seq %d", echoID, seq)), + }, + }, + } + require.NoError(t, router.Request(pk, &responder)) + responder.validate(t, pk) + } + }() + } + wg.Wait() + cancel() + <-proxyDone +} + // TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo func TestICMPRouterRejectNotEcho(t *testing.T) { msgs := []icmp.Message{ diff --git a/packet/funnel.go b/packet/funnel.go index 0cb1667a..f0124070 100644 --- a/packet/funnel.go +++ b/packet/funnel.go @@ -158,20 +158,19 @@ func (ft *FunnelTracker) Get(id FunnelID) (Funnel, bool) { } // Registers a funnel. It replaces the current funnel. -func (ft *FunnelTracker) Register(id FunnelID, funnel Funnel) (replaced bool) { +func (ft *FunnelTracker) GetOrRegister(id FunnelID, newFunnelFunc func() (Funnel, error)) (funnel Funnel, new bool, err error) { ft.lock.Lock() defer ft.lock.Unlock() currentFunnel, exists := ft.funnels[id] - if !exists { - ft.funnels[id] = funnel - return false + if exists { + return currentFunnel, false, nil } - replaced = !currentFunnel.Equal(funnel) - if replaced { - currentFunnel.Close() + newFunnel, err := newFunnelFunc() + if err != nil { + return nil, false, err } - ft.funnels[id] = funnel - return replaced + ft.funnels[id] = newFunnel + return newFunnel, true, nil } // Unregisters a funnel if the funnel equals to the current funnel