TUN-9882: Add buffers for UDP and ICMP datagrams in datagram v3

Instead of spawning a goroutine to process each incoming datagram from the tunnel, a single consumer (the demuxer)
now processes the datagrams serially.
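
For illustration, a minimal, self-contained sketch of that dispatch shape (the type constants and handler functions
below are invented for the example and are not the cloudflared API):

package main

import (
    "fmt"
    "sync"
)

type datagramType byte

const (
    registrationType datagramType = iota
    payloadType
    icmpType
)

func handleRegistration(p []byte) { fmt.Println("registration:", p) }
func handlePayload(p []byte)      { fmt.Println("payload:", p) }
func handleICMP(p []byte)         { fmt.Println("icmp:", p) }

// demux is the single consumer: it dequeues in order and never spawns a goroutine per
// datagram, so payloads for the same flow cannot overtake each other at this stage.
func demux(datagrams <-chan []byte, wg *sync.WaitGroup) {
    defer wg.Done()
    for d := range datagrams {
        if len(d) == 0 {
            continue
        }
        switch datagramType(d[0]) {
        case registrationType:
            // long-lived: owns the session's Serve loop, so it gets its own goroutine
            go handleRegistration(d[1:])
        case payloadType:
            // short: hands the payload to the session's buffered write channel
            handlePayload(d[1:])
        case icmpType:
            // short: hands the packet to the shared ICMP channel
            handleICMP(d[1:])
        }
    }
}

func main() {
    datagrams := make(chan []byte, 16)
    var wg sync.WaitGroup
    wg.Add(1)
    go demux(datagrams, &wg)
    datagrams <- []byte{byte(payloadType), 0xef}
    datagrams <- []byte{byte(icmpType), 0x08}
    close(datagrams)
    wg.Wait()
}

The demuxer itself stays cheap: anything that could take a long time is either handed to a buffered channel or, for
registrations, to its own goroutine.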

Registration datagrams are still spun out into separate goroutines, since they are responsible for managing the
lifetime of a session once it is started via the `Serve` method.

UDP payload datagrams are handed off to a per-session channel so that writes can proceed in parallel across sessions
via a new write loop. The channel has a small buffer so the demuxer is not blocked from dequeuing other datagrams.

ICMP datagrams are funneled into a single channel shared across all possible origins, with a single consumer writing
them out to their respective destinations.
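
Both the per-session payload channel and the shared ICMP channel follow the same pattern: producers enqueue without
blocking (dropping when the buffer is full) and a single consumer drains the channel, so writes toward any one
destination stay sequential. A minimal sketch of that pattern (the session type, buffer size, and shutdown handling
below are illustrative only, not the real implementation):

package main

import (
    "fmt"
    "io"
    "os"
    "sync"
    "time"
)

// session is a toy stand-in for the real flow/session type.
type session struct {
    origin     io.Writer
    writeChan  chan []byte   // small buffer: keeps the demuxer from blocking
    closeWrite chan struct{} // closed when the session shuts down
}

func newSession(origin io.Writer) *session {
    return &session{
        origin:     origin,
        writeChan:  make(chan []byte, 16),
        closeWrite: make(chan struct{}),
    }
}

// Write is called by the demuxer; it never blocks and drops the payload if the buffer is full.
func (s *session) Write(payload []byte) {
    select {
    case s.writeChan <- payload:
    default:
        fmt.Fprintln(os.Stderr, "write buffer full: payload dropped")
    }
}

// writeLoop is the single consumer, so payloads reach the origin in the order they were enqueued.
func (s *session) writeLoop(wg *sync.WaitGroup) {
    defer wg.Done()
    for {
        select {
        case <-s.closeWrite:
            return // payloads still buffered at close time are dropped
        case p := <-s.writeChan:
            _, _ = s.origin.Write(p)
        }
    }
}

func main() {
    s := newSession(os.Stdout)
    var wg sync.WaitGroup
    wg.Add(1)
    go s.writeLoop(&wg)

    s.Write([]byte("datagram 1\n"))
    s.Write([]byte("datagram 2\n"))

    time.Sleep(10 * time.Millisecond) // give the toy writer a moment before shutting down
    close(s.closeWrite)
    wg.Wait()
}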

Each of these changes is aimed at preventing datagram reordering when dequeuing from the tunnel connection. With a
single demuxer serializing the writes per session, each session writes sequentially to its origin, while the sessions
themselves proceed in parallel.

Closes TUN-9882
Devin Carr 2025-10-07 16:14:01 -07:00
parent fff1fc7390
commit 1fb466941a
7 changed files with 443 additions and 270 deletions

View File

@@ -182,7 +182,7 @@ fuzz:
     @go test -fuzz=FuzzIPDecoder -fuzztime=600s ./packet
     @go test -fuzz=FuzzICMPDecoder -fuzztime=600s ./packet
     @go test -fuzz=FuzzSessionWrite -fuzztime=600s ./quic/v3
-    @go test -fuzz=FuzzSessionServe -fuzztime=600s ./quic/v3
+    @go test -fuzz=FuzzSessionRead -fuzztime=600s ./quic/v3
     @go test -fuzz=FuzzRegistrationDatagram -fuzztime=600s ./quic/v3
     @go test -fuzz=FuzzPayloadDatagram -fuzztime=600s ./quic/v3
     @go test -fuzz=FuzzRegistrationResponseDatagram -fuzztime=600s ./quic/v3

View File

@@ -1,6 +1,7 @@
 package v3_test

 import (
+    "crypto/rand"
     "encoding/binary"
     "errors"
     "net/netip"
@@ -14,12 +15,18 @@ import (
 func makePayload(size int) []byte {
     payload := make([]byte, size)
-    for i := range len(payload) {
-        payload[i] = 0xfc
-    }
+    _, _ = rand.Read(payload)
     return payload
 }

+func makePayloads(size int, count int) [][]byte {
+    payloads := make([][]byte, count)
+    for i := range payloads {
+        payloads[i] = makePayload(size)
+    }
+    return payloads
+}
+
 func TestSessionRegistration_MarshalUnmarshal(t *testing.T) {
     payload := makePayload(1280)
     tests := []*v3.UDPSessionRegistrationDatagram{

View File

@@ -17,6 +17,9 @@ const (
     // Allocating a 16 channel buffer here allows for the writer to be slightly faster than the reader.
     // This has worked previously well for datagramv2, so we will start with this as well
     demuxChanCapacity = 16
+    // This provides a small buffer for the PacketRouter to poll ICMP packets from the QUIC connection
+    // before writing them to the origin.
+    icmpDatagramChanCapacity = 128

     logSrcKey = "src"
     logDstKey = "dst"
@@ -66,6 +69,7 @@ type datagramConn struct {
     metrics          Metrics
     logger           *zerolog.Logger
     datagrams        chan []byte
+    icmpDatagramChan chan *ICMPDatagram
     readErrors       chan error

     icmpEncoderPool sync.Pool // a pool of *packet.Encoder
@@ -82,6 +86,7 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRou
     metrics:          metrics,
     logger:           &log,
     datagrams:        make(chan []byte, demuxChanCapacity),
+    icmpDatagramChan: make(chan *ICMPDatagram, icmpDatagramChanCapacity),
     readErrors:       make(chan error, 2),
     icmpEncoderPool: sync.Pool{
         New: func() any {
@@ -168,6 +173,9 @@ func (c *datagramConn) Serve(ctx context.Context) error {
     readCtx, cancel := context.WithCancel(connCtx)
     defer cancel()
     go c.pollDatagrams(readCtx)
+    // Processing ICMP datagrams also monitors the reader context since the ICMP datagrams from the reader are the input
+    // for the routine.
+    go c.processICMPDatagrams(readCtx)
     for {
         // We make sure to monitor the context of cloudflared and the underlying connection to return if any errors occur.
         var datagram []byte
@@ -181,17 +189,16 @@
         // Monitor for any hard errors from reading the connection
         case err := <-c.readErrors:
             return err
-        // Otherwise, wait and dequeue datagrams as they come in
+        // Wait and dequeue datagrams as they come in
         case d := <-c.datagrams:
             datagram = d
         }

         // Each incoming datagram will be processed in a new go routine to handle the demuxing and action associated.
-        go func() {
         typ, err := ParseDatagramType(datagram)
         if err != nil {
             c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ)
-            return
+            continue
         }
         switch typ {
         case UDPSessionRegistrationType:
@@ -199,20 +206,23 @@ func (c *datagramConn) Serve(ctx context.Context) error {
             err := reg.UnmarshalBinary(datagram)
             if err != nil {
                 c.logger.Err(err).Msgf("unable to unmarshal session registration datagram")
-                return
+                continue
             }
             logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger()
             // We bind the new session to the quic connection context instead of cloudflared context to allow for the
             // quic connection to close and close only the sessions bound to it. Closing of cloudflared will also
             // initiate the close of the quic connection, so we don't have to worry about the application context
             // in the scope of a session.
-            c.handleSessionRegistrationDatagram(connCtx, reg, &logger)
+            //
+            // Additionally, we spin out the registration into a separate go routine to handle the Serve'ing of the
+            // session in a separate routine from the demuxer.
+            go c.handleSessionRegistrationDatagram(connCtx, reg, &logger)
         case UDPSessionPayloadType:
             payload := &UDPSessionPayloadDatagram{}
             err := payload.UnmarshalBinary(datagram)
             if err != nil {
                 c.logger.Err(err).Msgf("unable to unmarshal session payload datagram")
-                return
+                continue
             }
             logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger()
             c.handleSessionPayloadDatagram(payload, &logger)
@@ -221,18 +231,17 @@
             err := packet.UnmarshalBinary(datagram)
             if err != nil {
                 c.logger.Err(err).Msgf("unable to unmarshal icmp datagram")
-                return
+                continue
             }
             c.handleICMPPacket(packet)
         case UDPSessionRegistrationResponseType:
             // cloudflared should never expect to receive UDP session responses as it will not initiate new
             // sessions towards the edge.
             c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType)
-            return
+            continue
         default:
             c.logger.Error().Msgf("unknown datagram type received: %d", typ)
         }
-        }()
     }
 }
@@ -243,24 +252,21 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
         Str(logDstKey, datagram.Dest.String()).
         Logger()
     session, err := c.sessionManager.RegisterSession(datagram, c)
-    if err != nil {
     switch err {
+    case nil:
+        // Continue as normal
     case ErrSessionAlreadyRegistered:
         // Session is already registered and likely the response got lost
         c.handleSessionAlreadyRegistered(datagram.RequestID, &log)
-        return
     case ErrSessionBoundToOtherConn:
         // Session is already registered but to a different connection
         c.handleSessionMigration(datagram.RequestID, &log)
-        return
     case ErrSessionRegistrationRateLimited:
         // There are too many concurrent sessions so we return an error to force a retry later
         c.handleSessionRegistrationRateLimited(datagram, &log)
-        return
     default:
         log.Err(err).Msg("flow registration failure")
         c.handleSessionRegistrationFailure(datagram.RequestID, &log)
-    }
         return
     }
     log = log.With().Str(logSrcKey, session.LocalAddr().String()).Logger()
@@ -365,21 +371,42 @@ func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadD
         logger.Err(err).Msgf("unable to find flow")
         return
     }
-    // We ignore the bytes written to the socket because any partial write must return an error.
-    _, err = s.Write(datagram.Payload)
-    if err != nil {
-        logger.Err(err).Msgf("unable to write payload for the flow")
-        return
-    }
+    s.Write(datagram.Payload)
 }

-// Handles incoming ICMP datagrams.
+// Handles incoming ICMP datagrams into a serialized channel to be handled by a single consumer.
 func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
     if c.icmpRouter == nil {
         // ICMPRouter is disabled so we drop the current packet and ignore all incoming ICMP packets
         return
     }
+    select {
+    case c.icmpDatagramChan <- datagram:
+    default:
+        // If the ICMP datagram channel is full, drop any additional incoming.
+        c.logger.Warn().Msg("failed to write icmp packet to origin: dropped")
+    }
+}
+
+// Consumes from the ICMP datagram channel to write out the ICMP requests to an origin.
+func (c *datagramConn) processICMPDatagrams(ctx context.Context) {
+    if c.icmpRouter == nil {
+        // ICMPRouter is disabled so we ignore all incoming ICMP packets
+        return
+    }
+    for {
+        select {
+        // If the provided context is closed we want to exit the write loop
+        case <-ctx.Done():
+            return
+        case datagram := <-c.icmpDatagramChan:
+            c.writeICMPPacket(datagram)
+        }
+    }
+}
+
+func (c *datagramConn) writeICMPPacket(datagram *ICMPDatagram) {
     // Decode the provided ICMPDatagram as an ICMP packet
     rawPacket := packet.RawPacket{Data: datagram.Payload}
     cachedDecoder := c.icmpDecoderPool.Get()

View File

@@ -13,6 +13,7 @@ import (
     "testing"
     "time"

+    "github.com/fortytw2/leaktest"
     "github.com/google/gopacket/layers"
     "github.com/rs/zerolog"
     "github.com/stretchr/testify/assert"
@@ -92,7 +93,7 @@ func TestDatagramConn_New(t *testing.T) {
         DefaultDialer:   testDefaultDialer,
         TCPWriteTimeout: 0,
     }, &log)
-    conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
+    conn := v3.NewDatagramConn(newMockQuicConn(t.Context()), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
     if conn == nil {
         t.Fatal("expected valid connection")
     }
@@ -104,7 +105,9 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
         DefaultDialer:   testDefaultDialer,
         TCPWriteTimeout: 0,
     }, &log)
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

     payload := []byte{0xef, 0xef}
@@ -123,7 +126,9 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
         DefaultDialer:   testDefaultDialer,
         TCPWriteTimeout: 0,
     }, &log)
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

     err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
@@ -149,7 +154,9 @@ func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
         DefaultDialer:   testDefaultDialer,
         TCPWriteTimeout: 0,
     }, &log)
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

     ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
@@ -166,7 +173,9 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
         DefaultDialer:   testDefaultDialer,
         TCPWriteTimeout: 0,
     }, &log)
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
     defer cancel()
     quic.ctx = ctx
@@ -195,15 +204,17 @@ func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
 func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     sessionManager := &mockSessionManager{
         expectedRegErr: v3.ErrSessionRegistrationRateLimited,
     }
     conn := v3.NewDatagramConn(quic, sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)

     // Setup the muxer
-    ctx, cancel := context.WithCancel(t.Context())
-    defer cancel()
+    ctx, cancel := context.WithCancelCause(t.Context())
+    defer cancel(context.Canceled)
     done := make(chan error, 1)
     go func() {
         done <- conn.Serve(ctx)
@@ -223,9 +234,12 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
     require.EqualValues(t, testRequestID, resp.RequestID)
     require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
+
+    assertContextClosed(t, ctx, done, cancel)
 }

 func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
+    defer leaktest.Check(t)()
     for _, test := range []struct {
         name  string
         input []byte
@@ -250,7 +264,9 @@ func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
         t.Run(test.name, func(t *testing.T) {
             logOutput := new(LockedBuffer)
             log := zerolog.New(logOutput)
-            quic := newMockQuicConn()
+            connCtx, connCancel := context.WithCancelCause(t.Context())
+            defer connCancel(context.Canceled)
+            quic := newMockQuicConn(connCtx)
             quic.send <- test.input
             conn := v3.NewDatagramConn(quic, &mockSessionManager{}, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -289,8 +305,11 @@ func (b *LockedBuffer) String() string {
 }

 func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     expectedErr := errors.New("unable to register session")
     sessionManager := mockSessionManager{expectedRegErr: expectedErr}
     conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -324,8 +343,11 @@ func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) {
 }

 func TestDatagramConnServe(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     session := newMockSession()
     sessionManager := mockSessionManager{session: &session}
     conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -372,8 +394,11 @@
 // instances causes inteference resulting in multiple different raw packets being decoded
 // as the same decoded packet.
 func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     session := newMockSession()
     sessionManager := mockSessionManager{session: &session}
     router := newMockICMPRouter()
@@ -413,11 +438,15 @@
     wg := sync.WaitGroup{}
     var receivedPackets []*packet.ICMP
     go func() {
-        for ctx.Err() == nil {
-            icmpPacket := <-router.recv
+        for {
+            select {
+            case <-ctx.Done():
+                return
+            case icmpPacket := <-router.recv:
                 receivedPackets = append(receivedPackets, icmpPacket)
                 wg.Done()
             }
+        }
     }()

     for _, p := range packets {
@@ -452,8 +481,11 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
 }

 func TestDatagramConnServe_RegisterTwice(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     session := newMockSession()
     sessionManager := mockSessionManager{session: &session}
     conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -514,12 +546,17 @@
 }

 func TestDatagramConnServe_MigrateConnection(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     session := newMockSession()
     sessionManager := mockSessionManager{session: &session}
     conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
-    quic2 := newMockQuicConn()
+    conn2Ctx, conn2Cancel := context.WithCancelCause(t.Context())
+    defer conn2Cancel(context.Canceled)
+    quic2 := newMockQuicConn(conn2Ctx)
     conn2 := v3.NewDatagramConn(quic2, &sessionManager, &noopICMPRouter{}, 1, &noopMetrics{}, &log)

     // Setup the muxer
@@ -597,8 +634,11 @@
 }

 func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     // mockSessionManager will return the ErrSessionNotFound for any session attempting to be queried by the muxer
     sessionManager := mockSessionManager{session: nil, expectedGetErr: v3.ErrSessionNotFound}
     conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -624,9 +664,12 @@
     assertContextClosed(t, ctx, done, cancel)
 }

-func TestDatagramConnServe_Payload(t *testing.T) {
+func TestDatagramConnServe_Payloads(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     session := newMockSession()
     sessionManager := mockSessionManager{session: &session}
     conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -639,15 +682,26 @@
         done <- conn.Serve(ctx)
     }()

-    // Send new session registration
-    expectedPayload := []byte{0xef, 0xef}
-    datagram := newSessionPayloadDatagram(testRequestID, expectedPayload)
-    quic.send <- datagram
+    // Send session payloads
+    expectedPayloads := makePayloads(256, 16)
+    go func() {
+        for _, payload := range expectedPayloads {
+            datagram := newSessionPayloadDatagram(testRequestID, payload)
+            quic.send <- datagram
+        }
+    }()

-    // Session should receive the payload
-    payload := <-session.recv
-    if !slices.Equal(expectedPayload, payload) {
-        t.Fatalf("expected session receieve the payload sent via the muxer")
+    // Session should receive the payloads (in-order)
+    for i, payload := range expectedPayloads {
+        select {
+        case recv := <-session.recv:
+            if !slices.Equal(recv, payload) {
+                t.Fatalf("expected session receieve the payload[%d] sent via the muxer: (%x) (%x)", i, recv[:16], payload[:16])
+            }
+        case err := <-ctx.Done():
+            // we expect the payload to return before the context to cancel on the session
+            t.Fatal(err)
+        }
     }

     // Cancel the muxer Serve context and make sure it closes with the expected error
@@ -655,8 +709,11 @@
 }

 func TestDatagramConnServe_ICMPDatagram_TTLDecremented(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     router := newMockICMPRouter()
     conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log)
@@ -701,8 +758,11 @@
 }

 func TestDatagramConnServe_ICMPDatagram_TTLExceeded(t *testing.T) {
+    defer leaktest.Check(t)()
     log := zerolog.Nop()
-    quic := newMockQuicConn()
+    connCtx, connCancel := context.WithCancelCause(t.Context())
+    defer connCancel(context.Canceled)
+    quic := newMockQuicConn(connCtx)
     router := newMockICMPRouter()
@@ -821,9 +881,9 @@ type mockQuicConn struct {
     recv chan []byte
 }

-func newMockQuicConn() *mockQuicConn {
+func newMockQuicConn(ctx context.Context) *mockQuicConn {
     return &mockQuicConn{
-        ctx:  context.Background(),
+        ctx:  ctx,
         send: make(chan []byte, 1),
         recv: make(chan []byte, 1),
     }
@@ -841,7 +901,12 @@
 }

 func (m *mockQuicConn) ReceiveDatagram(_ context.Context) ([]byte, error) {
-    return <-m.send, nil
+    select {
+    case <-m.ctx.Done():
+        return nil, m.ctx.Err()
+    case b := <-m.send:
+        return b, nil
+    }
 }

 type mockQuicConnReadError struct {
@@ -905,11 +970,10 @@ func (m *mockSession) Serve(ctx context.Context) error {
     return v3.SessionCloseErr
 }

-func (m *mockSession) Write(payload []byte) (n int, err error) {
+func (m *mockSession) Write(payload []byte) {
     b := make([]byte, len(payload))
     copy(b, payload)
     m.recv <- b
-    return len(b), nil
 }

 func (m *mockSession) Close() error {

View File

@@ -22,6 +22,11 @@ const (
     // this value (maxDatagramPayloadLen).
     maxOriginUDPPacketSize = 1500

+    // The maximum amount of datagrams a session will queue up before it begins dropping datagrams.
+    // This channel buffer is small because we assume that the dedicated writer to the origin is typically
+    // fast enought to keep the channel empty.
+    writeChanCapacity = 16
+
     logFlowID        = "flowID"
     logPacketSizeKey = "packetSize"
 )
@@ -49,7 +54,7 @@ func newSessionIdleErr(timeout time.Duration) error {
 }

 type Session interface {
-    io.WriteCloser
+    io.Closer
     ID() RequestID
     ConnectionID() uint8
     RemoteAddr() net.Addr
@@ -58,6 +63,7 @@
     Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger)
     // Serve starts the event loop for processing UDP packets
     Serve(ctx context.Context) error
+    Write(payload []byte)
 }

 type session struct {
@@ -67,9 +73,15 @@ type session struct {
     originAddr net.Addr
     localAddr  net.Addr
     eyeball    atomic.Pointer[DatagramConn]
+    writeChan  chan []byte
     // activeAtChan is used to communicate the last read/write time
     activeAtChan chan time.Time
-    closeChan    chan error
+    errChan      chan error
+    // The close channel signal only exists for the write loop because the read loop is always waiting on a read
+    // from the UDP socket to the origin. To close the read loop we close the socket.
+    // Additionally, we can't close the writeChan to indicate that writes are complete because the producer (edge)
+    // side may still be trying to write to this session.
+    closeWrite  chan struct{}
     contextChan chan context.Context
     metrics     Metrics
     log         *zerolog.Logger
@@ -89,10 +101,12 @@ func NewSession(
     log *zerolog.Logger,
 ) Session {
     logger := log.With().Str(logFlowID, id.String()).Logger()
-    // closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
-    // write to the channel without blocking since there is only ever one value read from the closeChan by the
+    writeChan := make(chan []byte, writeChanCapacity)
+    // errChan has three slots to allow for all writers (the closeFn, the read loop and the write loop) to
+    // write to the channel without blocking since there is only ever one value read from the errChan by the
     // waitForCloseCondition.
-    closeChan := make(chan error, 2)
+    errChan := make(chan error, 3)
+    closeWrite := make(chan struct{})
     session := &session{
         id:             id,
         closeAfterIdle: closeAfterIdle,
@@ -100,10 +114,12 @@
         originAddr:     originAddr,
         localAddr:      localAddr,
         eyeball:        atomic.Pointer[DatagramConn]{},
+        writeChan:      writeChan,
         // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
         // drop instead of blocking because last active time only needs to be an approximation
        activeAtChan: make(chan time.Time, 1),
-        closeChan:    closeChan,
+        errChan:      errChan,
+        closeWrite:   closeWrite,
         // contextChan is an unbounded channel to help enforce one active migration of a session at a time.
         contextChan: make(chan context.Context),
         metrics:     metrics,
@@ -111,9 +127,12 @@
         closeFn: sync.OnceValue(func() error {
             // We don't want to block on sending to the close channel if it is already full
             select {
-            case closeChan <- SessionCloseErr:
+            case errChan <- SessionCloseErr:
             default:
             }
+            // Indicate to the write loop that the session is now closed
+            close(closeWrite)
+            // Close the socket directly to unblock the read loop and cause it to also end
             return origin.Close()
         }),
     }
@@ -154,7 +173,13 @@ func (s *session) Migrate(eyeball DatagramConn, ctx context.Context, logger *zer
 }

 func (s *session) Serve(ctx context.Context) error {
-    go func() {
+    go s.writeLoop()
+    go s.readLoop()
+    return s.waitForCloseCondition(ctx, s.closeAfterIdle)
+}
+
+// Read datagrams from the origin and write them to the connection.
+func (s *session) readLoop() {
     // QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
     // This makes it safe to share readBuffer between iterations
     readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{}
@@ -166,11 +191,10 @@
         // Read from the origin UDP socket
         n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
         if err != nil {
-            if errors.Is(err, io.EOF) ||
-                errors.Is(err, io.ErrUnexpectedEOF) {
-                s.log.Debug().Msgf("flow (origin) connection closed: %v", err)
+            if isConnectionClosed(err) {
+                s.log.Debug().Msgf("flow (read) connection closed: %v", err)
             }
-            s.closeChan <- err
+            s.closeSession(err)
             return
         }
         if n < 0 {
@@ -190,30 +214,66 @@
         // will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge.
         err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n])
         if err != nil {
-            s.closeChan <- err
+            s.closeSession(err)
             return
         }
         // Mark the session as active since we proxied a valid packet from the origin.
         s.markActive()
     }
-    }()
-    return s.waitForCloseCondition(ctx, s.closeAfterIdle)
 }

-func (s *session) Write(payload []byte) (n int, err error) {
-    n, err = s.origin.Write(payload)
-    if err != nil {
-        s.log.Err(err).Msg("failed to write payload to flow (remote)")
-        return n, err
-    }
-    // Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
-    if n < len(payload) {
-        s.log.Err(io.ErrShortWrite).Msg("failed to write the full payload to flow (remote)")
-        return n, io.ErrShortWrite
-    }
-    // Mark the session as active since we proxied a packet to the origin.
-    s.markActive()
-    return n, err
+func (s *session) Write(payload []byte) {
+    select {
+    case s.writeChan <- payload:
+    default:
+        s.log.Error().Msg("failed to write flow payload to origin: dropped")
+    }
+}
+
+// Read datagrams from the write channel to the origin.
+func (s *session) writeLoop() {
+    for {
+        select {
+        case <-s.closeWrite:
+            // When the closeWrite channel is closed, we will no longer write to the origin and end this
+            // goroutine since the session is now closed.
+            return
+        case payload := <-s.writeChan:
+            n, err := s.origin.Write(payload)
+            if err != nil {
+                if isConnectionClosed(err) {
+                    s.log.Debug().Msgf("flow (write) connection closed: %v", err)
+                }
+                s.log.Err(err).Msg("failed to write flow payload to origin")
+                s.closeSession(err)
+                // If we fail to write to the origin socket, we need to end the writer and close the session
+                return
+            }
+            // Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
+            if n < len(payload) {
+                s.log.Err(io.ErrShortWrite).Msg("failed to write the full flow payload to origin")
+                continue
+            }
+            // Mark the session as active since we successfully proxied a packet to the origin.
+            s.markActive()
+        }
+    }
+}
+
+func isConnectionClosed(err error) bool {
+    return errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
+}
+
+// Send an error to the error channel to report that an error has either happened on the tunnel or origin side of the
+// proxied connection.
+func (s *session) closeSession(err error) {
+    select {
+    case s.errChan <- err:
+    default:
+        // In the case that the errChan is already full, we will skip over it and return as to not block
+        // the caller because we should start cleaning up the session.
+        s.log.Warn().Msg("error channel was full")
+    }
 }

 // ResetIdleTimer will restart the current idle timer.
@@ -240,7 +300,8 @@ func (s *session) Close() error {
 func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
     connCtx := ctx
-    // Closing the session at the end cancels read so Serve() can return
+    // Closing the session at the end cancels read so Serve() can return, additionally, it closes the
+    // closeWrite channel which indicates to the write loop to return.
     defer s.Close()
     if closeAfterIdle == 0 {
         // Provided that the default caller doesn't specify one
@@ -260,7 +321,9 @@
             // still be active on the existing connection.
             connCtx = newContext
             continue
-        case reason := <-s.closeChan:
+        case reason := <-s.errChan:
+            // Any error returned here is from the read or write loops indicating that it can no longer process datagrams
+            // and as such the session needs to close.
             return reason
         case <-checkIdleTimer.C:
             // The check idle timer will only return after an idle period since the last active

View File

@@ -4,20 +4,24 @@ import (
     "testing"
 )

-// FuzzSessionWrite verifies that we don't run into any panics when writing variable sized payloads to the origin.
+// FuzzSessionWrite verifies that we don't run into any panics when writing a single variable sized payload to the origin.
 func FuzzSessionWrite(f *testing.F) {
-    f.Fuzz(func(t *testing.T, b []byte) {
-        testSessionWrite(t, b)
-    })
-}
-
-// FuzzSessionServe verifies that we don't run into any panics when reading variable sized payloads from the origin.
-func FuzzSessionServe(f *testing.F) {
     f.Fuzz(func(t *testing.T, b []byte) {
         // The origin transport read is bound to 1280 bytes
         if len(b) > 1280 {
             b = b[:1280]
         }
-        testSessionServe_Origin(t, b)
+        testSessionWrite(t, [][]byte{b})
+    })
+}
+
+// FuzzSessionRead verifies that we don't run into any panics when reading a single variable sized payload from the origin.
+func FuzzSessionRead(f *testing.F) {
+    f.Fuzz(func(t *testing.T, b []byte) {
+        // The origin transport read is bound to 1280 bytes
+        if len(b) > 1280 {
+            b = b[:1280]
+        }
+        testSessionRead(t, [][]byte{b})
     })
 }

View File

@@ -31,60 +31,61 @@ func TestSessionNew(t *testing.T) {
     }
 }

-func testSessionWrite(t *testing.T, payload []byte) {
+func testSessionWrite(t *testing.T, payloads [][]byte) {
     log := zerolog.Nop()
     origin, server := net.Pipe()
     defer origin.Close()
     defer server.Close()
-    // Start origin server read
-    serverRead := make(chan []byte, 1)
+    // Start origin server reads
+    serverRead := make(chan []byte, len(payloads))
     go func() {
-        read := make([]byte, 1500)
-        _, _ = server.Read(read[:])
-        serverRead <- read
+        for range len(payloads) {
+            buf := make([]byte, 1500)
+            _, _ = server.Read(buf[:])
+            serverRead <- buf
+        }
+        close(serverRead)
     }()
-    // Create session and write to origin
+    // Create a session
     session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
-    n, err := session.Write(payload)
     defer session.Close()
-    if err != nil {
-        t.Fatal(err)
-    }
-    if n != len(payload) {
-        t.Fatal("unable to write the whole payload")
+    // Start the Serve to begin the writeLoop
+    ctx, cancel := context.WithCancelCause(t.Context())
+    defer cancel(context.Canceled)
+    done := make(chan error)
+    go func() {
+        done <- session.Serve(ctx)
+    }()
+    // Write the payloads to the session
+    for _, payload := range payloads {
+        session.Write(payload)
     }
-
-    read := <-serverRead
-    if !slices.Equal(payload, read[:len(payload)]) {
-        t.Fatal("payload provided from origin and read value are not the same")
+    // Read from the origin to ensure the payloads were received (in-order)
+    for i, payload := range payloads {
+        read := <-serverRead
+        if !slices.Equal(payload, read[:len(payload)]) {
+            t.Fatalf("payload[%d] provided from origin and read value are not the same (%x) and (%x)", i, payload[:16], read[:16])
+        }
     }
-}
+    _, more := <-serverRead
+    if more {
+        t.Fatalf("expected the session to have all of the origin payloads received: %d", len(serverRead))
+    }

-func TestSessionWrite_Max(t *testing.T) {
-    defer leaktest.Check(t)()
-    payload := makePayload(1280)
-    testSessionWrite(t, payload)
-}
-
-func TestSessionWrite_Min(t *testing.T) {
-    defer leaktest.Check(t)()
-    payload := makePayload(0)
-    testSessionWrite(t, payload)
-}
+    assertContextClosed(t, ctx, done, cancel)
+}

-func TestSessionServe_OriginMax(t *testing.T) {
+func TestSessionWrite(t *testing.T) {
     defer leaktest.Check(t)()
-    payload := makePayload(1280)
-    testSessionServe_Origin(t, payload)
-}
-
-func TestSessionServe_OriginMin(t *testing.T) {
-    defer leaktest.Check(t)()
-    payload := makePayload(0)
-    testSessionServe_Origin(t, payload)
+    for i := range 1280 {
+        payloads := makePayloads(i, 16)
+        testSessionWrite(t, payloads)
+    }
 }

-func testSessionServe_Origin(t *testing.T, payload []byte) {
+func testSessionRead(t *testing.T, payloads [][]byte) {
     log := zerolog.Nop()
     origin, server := net.Pipe()
     defer origin.Close()
@@ -100,12 +101,15 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
         done <- session.Serve(ctx)
     }()

-    // Write from the origin server
-    _, err := server.Write(payload)
-    if err != nil {
-        t.Fatal(err)
-    }
+    // Write from the origin server to the eyeball
+    go func() {
+        for _, payload := range payloads {
+            _, _ = server.Write(payload)
+        }
+    }()

+    // Read from the eyeball to ensure the payloads were received (in-order)
+    for i, payload := range payloads {
     select {
     case data := <-eyeball.recvData:
         // check received data matches provided from origin
@@ -113,24 +117,26 @@
         _ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
         copy(expectedData[17:], payload)
         if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
-            t.Fatal("expected datagram did not equal expected")
+            t.Fatalf("expected datagram[%d] did not equal expected", i)
         }
-        cancel(errExpectedContextCanceled)
     case err := <-ctx.Done():
         // we expect the payload to return before the context to cancel on the session
         t.Fatal(err)
     }
+    }

-    err = <-done
-    if !errors.Is(err, context.Canceled) {
-        t.Fatal(err)
-    }
-    if !errors.Is(context.Cause(ctx), errExpectedContextCanceled) {
-        t.Fatal(err)
-    }
+    assertContextClosed(t, ctx, done, cancel)
 }

-func TestSessionServe_OriginTooLarge(t *testing.T) {
+func TestSessionRead(t *testing.T) {
+    defer leaktest.Check(t)()
+    for i := range 1280 {
+        payloads := makePayloads(i, 16)
+        testSessionRead(t, payloads)
+    }
+}
+
+func TestSessionRead_OriginTooLarge(t *testing.T) {
     defer leaktest.Check(t)()
     log := zerolog.Nop()
     eyeball := newMockEyeball()
@@ -317,6 +323,8 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
     closeAfterIdle := 2 * time.Second
     session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
     err := session.Serve(t.Context())
+
+    // Session should idle timeout if no reads or writes occur
     if !errors.Is(err, v3.SessionIdleErr{}) {
         t.Fatal(err)
     }