TUN-8822: Prevent concurrent usage of ICMPDecoder

## Summary
Some description...

Closes TUN-8822
This commit is contained in:
Gonçalo Garcia 2024-12-19 07:19:36 -08:00
parent 9bc6cbd06d
commit c6901551e7
2 changed files with 107 additions and 6 deletions

View File

@ -65,12 +65,11 @@ type datagramConn struct {
icmpRouter ingress.ICMPRouter
metrics Metrics
logger *zerolog.Logger
datagrams chan []byte
readErrors chan error
datagrams chan []byte
readErrors chan error
icmpEncoderPool sync.Pool // a pool of *packet.Encoder
icmpDecoder *packet.ICMPDecoder
icmpDecoderPool sync.Pool
}
func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
@ -89,7 +88,11 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRou
return packet.NewEncoder()
},
},
icmpDecoder: packet.NewICMPDecoder(),
icmpDecoderPool: sync.Pool{
New: func() any {
return packet.NewICMPDecoder()
},
},
}
}
@ -367,7 +370,16 @@ func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
// Decode the provided ICMPDatagram as an ICMP packet
rawPacket := packet.RawPacket{Data: datagram.Payload}
icmp, err := c.icmpDecoder.Decode(rawPacket)
cachedDecoder := c.icmpDecoderPool.Get()
defer c.icmpDecoderPool.Put(cachedDecoder)
decoder, ok := cachedDecoder.(*packet.ICMPDecoder)
if !ok {
c.logger.Error().Msg("Could not get ICMPDecoder from the pool. Dropping packet")
return
}
icmp, err := decoder.Decode(rawPacket)
if err != nil {
c.logger.Err(err).Msgf("unable to marshal icmp packet")
return

View File

@ -4,13 +4,17 @@ import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"sort"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
"golang.org/x/net/icmp"
@ -304,6 +308,91 @@ func TestDatagramConnServe(t *testing.T) {
assertContextClosed(t, ctx, done, cancel)
}
// This test exists because decoding multiple packets in parallel with the same decoder
// instances causes inteference resulting in multiple different raw packets being decoded
// as the same decoded packet.
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()
session := newMockSession()
sessionManager := mockSessionManager{session: &session}
router := newMockICMPRouter()
conn := v3.NewDatagramConn(quic, &sessionManager, router, 0, &noopMetrics{}, &log)
// Setup the muxer
ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(errors.New("other error"))
done := make(chan error, 1)
go func() {
done <- conn.Serve(ctx)
}()
packetCount := 100
packets := make([]*packet.ICMP, 100)
ipTemplate := "10.0.0.%d"
for i := 1; i <= packetCount; i++ {
packets[i-1] = &packet.ICMP{
IP: &packet.IP{
Src: netip.MustParseAddr("192.168.1.1"),
Dst: netip.MustParseAddr(fmt.Sprintf(ipTemplate, i)),
Protocol: layers.IPProtocolICMPv4,
TTL: 20,
},
Message: &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: 25821,
Seq: 58129,
Data: []byte("test"),
},
},
}
}
wg := sync.WaitGroup{}
var receivedPackets []*packet.ICMP
go func() {
for ctx.Err() == nil {
select {
case icmpPacket := <-router.recv:
receivedPackets = append(receivedPackets, icmpPacket)
wg.Done()
}
}
}()
for _, p := range packets {
// We increment here but only decrement when receiving the packet
wg.Add(1)
go func() {
datagram := newICMPDatagram(p)
quic.send <- datagram
}()
}
wg.Wait()
// If there were duplicates then we won't have the same number of IPs
packetSet := make(map[netip.Addr]*packet.ICMP, 0)
for _, p := range receivedPackets {
packetSet[p.Dst] = p
}
assert.Equal(t, len(packetSet), len(packets))
// Sort the slice by last byte of IP address (the one we increment for each destination)
// and then check that we have one match for each packet sent
sort.Slice(receivedPackets, func(i, j int) bool {
return receivedPackets[i].Dst.As4()[3] < receivedPackets[j].Dst.As4()[3]
})
for i, p := range receivedPackets {
assert.Equal(t, p.Dst, packets[i].Dst)
}
// Cancel the muxer Serve context and make sure it closes with the expected error
assertContextClosed(t, ctx, done, cancel)
}
func TestDatagramConnServe_RegisterTwice(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()