TUN-8822: Prevent concurrent usage of ICMPDecoder
## Summary Some description... Closes TUN-8822
This commit is contained in:
parent
9bc6cbd06d
commit
c6901551e7
|
@ -65,12 +65,11 @@ type datagramConn struct {
|
|||
icmpRouter ingress.ICMPRouter
|
||||
metrics Metrics
|
||||
logger *zerolog.Logger
|
||||
|
||||
datagrams chan []byte
|
||||
readErrors chan error
|
||||
datagrams chan []byte
|
||||
readErrors chan error
|
||||
|
||||
icmpEncoderPool sync.Pool // a pool of *packet.Encoder
|
||||
icmpDecoder *packet.ICMPDecoder
|
||||
icmpDecoderPool sync.Pool
|
||||
}
|
||||
|
||||
func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
|
||||
|
@ -89,7 +88,11 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRou
|
|||
return packet.NewEncoder()
|
||||
},
|
||||
},
|
||||
icmpDecoder: packet.NewICMPDecoder(),
|
||||
icmpDecoderPool: sync.Pool{
|
||||
New: func() any {
|
||||
return packet.NewICMPDecoder()
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -367,7 +370,16 @@ func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
|
|||
|
||||
// Decode the provided ICMPDatagram as an ICMP packet
|
||||
rawPacket := packet.RawPacket{Data: datagram.Payload}
|
||||
icmp, err := c.icmpDecoder.Decode(rawPacket)
|
||||
cachedDecoder := c.icmpDecoderPool.Get()
|
||||
defer c.icmpDecoderPool.Put(cachedDecoder)
|
||||
decoder, ok := cachedDecoder.(*packet.ICMPDecoder)
|
||||
if !ok {
|
||||
c.logger.Error().Msg("Could not get ICMPDecoder from the pool. Dropping packet")
|
||||
return
|
||||
}
|
||||
|
||||
icmp, err := decoder.Decode(rawPacket)
|
||||
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to marshal icmp packet")
|
||||
return
|
||||
|
|
|
@ -4,13 +4,17 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/icmp"
|
||||
|
@ -304,6 +308,91 @@ func TestDatagramConnServe(t *testing.T) {
|
|||
assertContextClosed(t, ctx, done, cancel)
|
||||
}
|
||||
|
||||
// This test exists because decoding multiple packets in parallel with the same decoder
|
||||
// instances causes inteference resulting in multiple different raw packets being decoded
|
||||
// as the same decoded packet.
|
||||
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
session := newMockSession()
|
||||
sessionManager := mockSessionManager{session: &session}
|
||||
router := newMockICMPRouter()
|
||||
conn := v3.NewDatagramConn(quic, &sessionManager, router, 0, &noopMetrics{}, &log)
|
||||
|
||||
// Setup the muxer
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(errors.New("other error"))
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
packetCount := 100
|
||||
packets := make([]*packet.ICMP, 100)
|
||||
ipTemplate := "10.0.0.%d"
|
||||
for i := 1; i <= packetCount; i++ {
|
||||
packets[i-1] = &packet.ICMP{
|
||||
IP: &packet.IP{
|
||||
Src: netip.MustParseAddr("192.168.1.1"),
|
||||
Dst: netip.MustParseAddr(fmt.Sprintf(ipTemplate, i)),
|
||||
Protocol: layers.IPProtocolICMPv4,
|
||||
TTL: 20,
|
||||
},
|
||||
Message: &icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &icmp.Echo{
|
||||
ID: 25821,
|
||||
Seq: 58129,
|
||||
Data: []byte("test"),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var receivedPackets []*packet.ICMP
|
||||
go func() {
|
||||
for ctx.Err() == nil {
|
||||
select {
|
||||
case icmpPacket := <-router.recv:
|
||||
receivedPackets = append(receivedPackets, icmpPacket)
|
||||
wg.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, p := range packets {
|
||||
// We increment here but only decrement when receiving the packet
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
datagram := newICMPDatagram(p)
|
||||
quic.send <- datagram
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If there were duplicates then we won't have the same number of IPs
|
||||
packetSet := make(map[netip.Addr]*packet.ICMP, 0)
|
||||
for _, p := range receivedPackets {
|
||||
packetSet[p.Dst] = p
|
||||
}
|
||||
assert.Equal(t, len(packetSet), len(packets))
|
||||
|
||||
// Sort the slice by last byte of IP address (the one we increment for each destination)
|
||||
// and then check that we have one match for each packet sent
|
||||
sort.Slice(receivedPackets, func(i, j int) bool {
|
||||
return receivedPackets[i].Dst.As4()[3] < receivedPackets[j].Dst.As4()[3]
|
||||
})
|
||||
for i, p := range receivedPackets {
|
||||
assert.Equal(t, p.Dst, packets[i].Dst)
|
||||
}
|
||||
|
||||
// Cancel the muxer Serve context and make sure it closes with the expected error
|
||||
assertContextClosed(t, ctx, done, cancel)
|
||||
}
|
||||
|
||||
func TestDatagramConnServe_RegisterTwice(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
|
|
Loading…
Reference in New Issue