TUN-8822: Prevent concurrent usage of ICMPDecoder
## Summary Some description... Closes TUN-8822
This commit is contained in:
parent
9bc6cbd06d
commit
c6901551e7
|
@ -65,12 +65,11 @@ type datagramConn struct {
|
||||||
icmpRouter ingress.ICMPRouter
|
icmpRouter ingress.ICMPRouter
|
||||||
metrics Metrics
|
metrics Metrics
|
||||||
logger *zerolog.Logger
|
logger *zerolog.Logger
|
||||||
|
datagrams chan []byte
|
||||||
datagrams chan []byte
|
readErrors chan error
|
||||||
readErrors chan error
|
|
||||||
|
|
||||||
icmpEncoderPool sync.Pool // a pool of *packet.Encoder
|
icmpEncoderPool sync.Pool // a pool of *packet.Encoder
|
||||||
icmpDecoder *packet.ICMPDecoder
|
icmpDecoderPool sync.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
|
func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
|
||||||
|
@ -89,7 +88,11 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRou
|
||||||
return packet.NewEncoder()
|
return packet.NewEncoder()
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
icmpDecoder: packet.NewICMPDecoder(),
|
icmpDecoderPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return packet.NewICMPDecoder()
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -367,7 +370,16 @@ func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
|
||||||
|
|
||||||
// Decode the provided ICMPDatagram as an ICMP packet
|
// Decode the provided ICMPDatagram as an ICMP packet
|
||||||
rawPacket := packet.RawPacket{Data: datagram.Payload}
|
rawPacket := packet.RawPacket{Data: datagram.Payload}
|
||||||
icmp, err := c.icmpDecoder.Decode(rawPacket)
|
cachedDecoder := c.icmpDecoderPool.Get()
|
||||||
|
defer c.icmpDecoderPool.Put(cachedDecoder)
|
||||||
|
decoder, ok := cachedDecoder.(*packet.ICMPDecoder)
|
||||||
|
if !ok {
|
||||||
|
c.logger.Error().Msg("Could not get ICMPDecoder from the pool. Dropping packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
icmp, err := decoder.Decode(rawPacket)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logger.Err(err).Msgf("unable to marshal icmp packet")
|
c.logger.Err(err).Msgf("unable to marshal icmp packet")
|
||||||
return
|
return
|
||||||
|
|
|
@ -4,13 +4,17 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
|
@ -304,6 +308,91 @@ func TestDatagramConnServe(t *testing.T) {
|
||||||
assertContextClosed(t, ctx, done, cancel)
|
assertContextClosed(t, ctx, done, cancel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test exists because decoding multiple packets in parallel with the same decoder
|
||||||
|
// instances causes inteference resulting in multiple different raw packets being decoded
|
||||||
|
// as the same decoded packet.
|
||||||
|
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
|
||||||
|
log := zerolog.Nop()
|
||||||
|
quic := newMockQuicConn()
|
||||||
|
session := newMockSession()
|
||||||
|
sessionManager := mockSessionManager{session: &session}
|
||||||
|
router := newMockICMPRouter()
|
||||||
|
conn := v3.NewDatagramConn(quic, &sessionManager, router, 0, &noopMetrics{}, &log)
|
||||||
|
|
||||||
|
// Setup the muxer
|
||||||
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
defer cancel(errors.New("other error"))
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- conn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
packetCount := 100
|
||||||
|
packets := make([]*packet.ICMP, 100)
|
||||||
|
ipTemplate := "10.0.0.%d"
|
||||||
|
for i := 1; i <= packetCount; i++ {
|
||||||
|
packets[i-1] = &packet.ICMP{
|
||||||
|
IP: &packet.IP{
|
||||||
|
Src: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Dst: netip.MustParseAddr(fmt.Sprintf(ipTemplate, i)),
|
||||||
|
Protocol: layers.IPProtocolICMPv4,
|
||||||
|
TTL: 20,
|
||||||
|
},
|
||||||
|
Message: &icmp.Message{
|
||||||
|
Type: ipv4.ICMPTypeEcho,
|
||||||
|
Code: 0,
|
||||||
|
Body: &icmp.Echo{
|
||||||
|
ID: 25821,
|
||||||
|
Seq: 58129,
|
||||||
|
Data: []byte("test"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
var receivedPackets []*packet.ICMP
|
||||||
|
go func() {
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
select {
|
||||||
|
case icmpPacket := <-router.recv:
|
||||||
|
receivedPackets = append(receivedPackets, icmpPacket)
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, p := range packets {
|
||||||
|
// We increment here but only decrement when receiving the packet
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
datagram := newICMPDatagram(p)
|
||||||
|
quic.send <- datagram
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// If there were duplicates then we won't have the same number of IPs
|
||||||
|
packetSet := make(map[netip.Addr]*packet.ICMP, 0)
|
||||||
|
for _, p := range receivedPackets {
|
||||||
|
packetSet[p.Dst] = p
|
||||||
|
}
|
||||||
|
assert.Equal(t, len(packetSet), len(packets))
|
||||||
|
|
||||||
|
// Sort the slice by last byte of IP address (the one we increment for each destination)
|
||||||
|
// and then check that we have one match for each packet sent
|
||||||
|
sort.Slice(receivedPackets, func(i, j int) bool {
|
||||||
|
return receivedPackets[i].Dst.As4()[3] < receivedPackets[j].Dst.As4()[3]
|
||||||
|
})
|
||||||
|
for i, p := range receivedPackets {
|
||||||
|
assert.Equal(t, p.Dst, packets[i].Dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel the muxer Serve context and make sure it closes with the expected error
|
||||||
|
assertContextClosed(t, ctx, done, cancel)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDatagramConnServe_RegisterTwice(t *testing.T) {
|
func TestDatagramConnServe_RegisterTwice(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
quic := newMockQuicConn()
|
quic := newMockQuicConn()
|
||||||
|
|
Loading…
Reference in New Issue