Merge branch 'cloudflare:master' into master

Areg Vrtanesyan 2025-10-08 16:59:01 +01:00 committed by GitHub
commit b8511df478
9 changed files with 471 additions and 272 deletions

View File

@@ -36,7 +36,7 @@ ifdef PACKAGE_MANAGER
    VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
endif
ifdef CONTAINER_BUILD
    VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual"
endif
@@ -122,7 +122,7 @@ ifneq ($(TARGET_ARM), )
    ARM_COMMAND := GOARM=$(TARGET_ARM)
endif
ifeq ($(TARGET_ARM), 7)
    PACKAGE_ARCH := armhf
else
    PACKAGE_ARCH := $(TARGET_ARCH)
@@ -185,7 +185,7 @@ fuzz:
    @go test -fuzz=FuzzIPDecoder -fuzztime=600s ./packet
    @go test -fuzz=FuzzICMPDecoder -fuzztime=600s ./packet
    @go test -fuzz=FuzzSessionWrite -fuzztime=600s ./quic/v3
-   @go test -fuzz=FuzzSessionServe -fuzztime=600s ./quic/v3
+   @go test -fuzz=FuzzSessionRead -fuzztime=600s ./quic/v3
    @go test -fuzz=FuzzRegistrationDatagram -fuzztime=600s ./quic/v3
    @go test -fuzz=FuzzPayloadDatagram -fuzztime=600s ./quic/v3
    @go test -fuzz=FuzzRegistrationResponseDatagram -fuzztime=600s ./quic/v3
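Any of these harnesses can also be run directly with the standard Go fuzzing tooling rather than through the Makefile target; for example, a shorter local run of the renamed target (the 60s budget here is only an illustration):

    go test -fuzz=FuzzSessionRead -fuzztime=60s ./quic/v3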

View File

@@ -14,4 +14,7 @@ spec:
  lifecycle: "Active"
  owner: "teams/tunnel-teams-routing"
  cf:
+   compliance:
+     fedramp-high: "pending"
+     fedramp-moderate: "yes"
    FIPS: "required"

View File

@@ -11,6 +11,8 @@ import (
    "github.com/rs/zerolog"
)

+ const writeDeadlineUDP = 200 * time.Millisecond

// OriginTCPDialer provides a TCP dial operation to a requested address.
type OriginTCPDialer interface {
    DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
@@ -141,6 +143,21 @@ func (d *Dialer) DialUDP(dest netip.AddrPort) (net.Conn, error) {
    if err != nil {
        return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err)
    }
-   return conn, nil
+   return &writeDeadlineConn{
+       Conn: conn,
+   }, nil
}

+ // writeDeadlineConn is a wrapper around a net.Conn that sets a write deadline of 200ms.
+ // This is to prevent the socket from blocking on the write operation if it were to occur. However,
+ // we typically never expect this to occur except under high load or kernel issues.
+ type writeDeadlineConn struct {
+     net.Conn
+ }
+
+ func (w *writeDeadlineConn) Write(b []byte) (int, error) {
+     if err := w.SetWriteDeadline(time.Now().Add(writeDeadlineUDP)); err != nil {
+         return 0, err
+     }
+     return w.Conn.Write(b)
+ }
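The wrapper added above is a small, generic decorator pattern: every Write refreshes a deadline, so a stalled socket surfaces os.ErrDeadlineExceeded instead of blocking the caller indefinitely. A self-contained sketch of the same idea, with illustrative names that are not part of cloudflared:

    package main

    import (
        "errors"
        "fmt"
        "net"
        "os"
        "time"
    )

    // deadlineConn mirrors the writeDeadlineConn idea: refresh the write deadline on every Write.
    type deadlineConn struct {
        net.Conn
        timeout time.Duration
    }

    func (c *deadlineConn) Write(b []byte) (int, error) {
        if err := c.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
            return 0, err
        }
        return c.Conn.Write(b)
    }

    func main() {
        // UDP writes rarely block, so this demo mostly exercises the happy path.
        conn, err := net.Dial("udp", "127.0.0.1:5053")
        if err != nil {
            panic(err)
        }
        defer conn.Close()

        wrapped := &deadlineConn{Conn: conn, timeout: 200 * time.Millisecond}
        if _, err := wrapped.Write([]byte("ping")); err != nil {
            // A stalled socket would surface here as os.ErrDeadlineExceeded.
            fmt.Println("write failed:", errors.Is(err, os.ErrDeadlineExceeded), err)
        }
    }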

View File

@@ -1,6 +1,7 @@
package v3_test

import (
+   "crypto/rand"
    "encoding/binary"
    "errors"
    "net/netip"
@@ -14,12 +15,18 @@ import (
func makePayload(size int) []byte {
    payload := make([]byte, size)
-   for i := range len(payload) {
-       payload[i] = 0xfc
-   }
+   _, _ = rand.Read(payload)
    return payload
}

+ func makePayloads(size int, count int) [][]byte {
+     payloads := make([][]byte, count)
+     for i := range payloads {
+         payloads[i] = makePayload(size)
+     }
+     return payloads
+ }

func TestSessionRegistration_MarshalUnmarshal(t *testing.T) {
    payload := makePayload(1280)
    tests := []*v3.UDPSessionRegistrationDatagram{

View File

@@ -17,6 +17,9 @@ const (
    // Allocating a 16 channel buffer here allows for the writer to be slightly faster than the reader.
    // This has worked previously well for datagramv2, so we will start with this as well
    demuxChanCapacity = 16
+   // This provides a small buffer for the PacketRouter to poll ICMP packets from the QUIC connection
+   // before writing them to the origin.
+   icmpDatagramChanCapacity = 128

    logSrcKey = "src"
    logDstKey = "dst"
@@ -59,14 +62,15 @@ type QuicConnection interface {
}

type datagramConn struct {
    conn QuicConnection
    index uint8
    sessionManager SessionManager
    icmpRouter ingress.ICMPRouter
    metrics Metrics
    logger *zerolog.Logger
    datagrams chan []byte
+   icmpDatagramChan chan *ICMPDatagram
    readErrors chan error

    icmpEncoderPool sync.Pool // a pool of *packet.Encoder
    icmpDecoderPool sync.Pool
@@ -75,14 +79,15 @@ type datagramConn struct {
}

func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
    log := logger.With().Uint8("datagramVersion", 3).Logger()
    return &datagramConn{
        conn: conn,
        index: index,
        sessionManager: sessionManager,
        icmpRouter: icmpRouter,
        metrics: metrics,
        logger: &log,
        datagrams: make(chan []byte, demuxChanCapacity),
+       icmpDatagramChan: make(chan *ICMPDatagram, icmpDatagramChanCapacity),
        readErrors: make(chan error, 2),
        icmpEncoderPool: sync.Pool{
            New: func() any {
                return packet.NewEncoder()
@@ -168,6 +173,9 @@ func (c *datagramConn) Serve(ctx context.Context) error {
    readCtx, cancel := context.WithCancel(connCtx)
    defer cancel()
    go c.pollDatagrams(readCtx)
+   // Processing ICMP datagrams also monitors the reader context since the ICMP datagrams from the reader are the input
+   // for the routine.
+   go c.processICMPDatagrams(readCtx)
    for {
        // We make sure to monitor the context of cloudflared and the underlying connection to return if any errors occur.
        var datagram []byte
@@ -181,58 +189,59 @@
        // Monitor for any hard errors from reading the connection
        case err := <-c.readErrors:
            return err
-       // Otherwise, wait and dequeue datagrams as they come in
+       // Wait and dequeue datagrams as they come in
        case d := <-c.datagrams:
            datagram = d
        }
        // Each incoming datagram will be processed in a new go routine to handle the demuxing and action associated.
-       go func() {
-           typ, err := ParseDatagramType(datagram)
-           if err != nil {
-               c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ)
-               return
-           }
-           switch typ {
-           case UDPSessionRegistrationType:
-               reg := &UDPSessionRegistrationDatagram{}
-               err := reg.UnmarshalBinary(datagram)
-               if err != nil {
-                   c.logger.Err(err).Msgf("unable to unmarshal session registration datagram")
-                   return
-               }
-               logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger()
-               // We bind the new session to the quic connection context instead of cloudflared context to allow for the
-               // quic connection to close and close only the sessions bound to it. Closing of cloudflared will also
-               // initiate the close of the quic connection, so we don't have to worry about the application context
-               // in the scope of a session.
-               c.handleSessionRegistrationDatagram(connCtx, reg, &logger)
-           case UDPSessionPayloadType:
-               payload := &UDPSessionPayloadDatagram{}
-               err := payload.UnmarshalBinary(datagram)
-               if err != nil {
-                   c.logger.Err(err).Msgf("unable to unmarshal session payload datagram")
-                   return
-               }
-               logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger()
-               c.handleSessionPayloadDatagram(payload, &logger)
-           case ICMPType:
-               packet := &ICMPDatagram{}
-               err := packet.UnmarshalBinary(datagram)
-               if err != nil {
-                   c.logger.Err(err).Msgf("unable to unmarshal icmp datagram")
-                   return
-               }
-               c.handleICMPPacket(packet)
-           case UDPSessionRegistrationResponseType:
-               // cloudflared should never expect to receive UDP session responses as it will not initiate new
-               // sessions towards the edge.
-               c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType)
-               return
-           default:
-               c.logger.Error().Msgf("unknown datagram type received: %d", typ)
-           }
-       }()
+       typ, err := ParseDatagramType(datagram)
+       if err != nil {
+           c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ)
+           continue
+       }
+       switch typ {
+       case UDPSessionRegistrationType:
+           reg := &UDPSessionRegistrationDatagram{}
+           err := reg.UnmarshalBinary(datagram)
+           if err != nil {
+               c.logger.Err(err).Msgf("unable to unmarshal session registration datagram")
+               continue
+           }
+           logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger()
+           // We bind the new session to the quic connection context instead of cloudflared context to allow for the
+           // quic connection to close and close only the sessions bound to it. Closing of cloudflared will also
+           // initiate the close of the quic connection, so we don't have to worry about the application context
+           // in the scope of a session.
+           //
+           // Additionally, we spin out the registration into a separate go routine to handle the Serve'ing of the
+           // session in a separate routine from the demuxer.
+           go c.handleSessionRegistrationDatagram(connCtx, reg, &logger)
+       case UDPSessionPayloadType:
+           payload := &UDPSessionPayloadDatagram{}
+           err := payload.UnmarshalBinary(datagram)
+           if err != nil {
+               c.logger.Err(err).Msgf("unable to unmarshal session payload datagram")
+               continue
+           }
+           logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger()
+           c.handleSessionPayloadDatagram(payload, &logger)
+       case ICMPType:
+           packet := &ICMPDatagram{}
+           err := packet.UnmarshalBinary(datagram)
+           if err != nil {
+               c.logger.Err(err).Msgf("unable to unmarshal icmp datagram")
+               continue
+           }
+           c.handleICMPPacket(packet)
+       case UDPSessionRegistrationResponseType:
+           // cloudflared should never expect to receive UDP session responses as it will not initiate new
+           // sessions towards the edge.
+           c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType)
+           continue
+       default:
+           c.logger.Error().Msgf("unknown datagram type received: %d", typ)
+       }
    }
}
@@ -243,24 +252,21 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
        Str(logDstKey, datagram.Dest.String()).
        Logger()
    session, err := c.sessionManager.RegisterSession(datagram, c)
-   switch err {
-   case nil:
-       // Continue as normal
-   case ErrSessionAlreadyRegistered:
-       // Session is already registered and likely the response got lost
-       c.handleSessionAlreadyRegistered(datagram.RequestID, &log)
-       return
-   case ErrSessionBoundToOtherConn:
-       // Session is already registered but to a different connection
-       c.handleSessionMigration(datagram.RequestID, &log)
-       return
-   case ErrSessionRegistrationRateLimited:
-       // There are too many concurrent sessions so we return an error to force a retry later
-       c.handleSessionRegistrationRateLimited(datagram, &log)
-       return
-   default:
-       log.Err(err).Msg("flow registration failure")
-       c.handleSessionRegistrationFailure(datagram.RequestID, &log)
+   if err != nil {
+       switch err {
+       case ErrSessionAlreadyRegistered:
+           // Session is already registered and likely the response got lost
+           c.handleSessionAlreadyRegistered(datagram.RequestID, &log)
+       case ErrSessionBoundToOtherConn:
+           // Session is already registered but to a different connection
+           c.handleSessionMigration(datagram.RequestID, &log)
+       case ErrSessionRegistrationRateLimited:
+           // There are too many concurrent sessions so we return an error to force a retry later
+           c.handleSessionRegistrationRateLimited(datagram, &log)
+       default:
+           log.Err(err).Msg("flow registration failure")
+           c.handleSessionRegistrationFailure(datagram.RequestID, &log)
+       }
        return
    }
    log = log.With().Str(logSrcKey, session.LocalAddr().String()).Logger()
@@ -365,21 +371,42 @@ func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadD
        logger.Err(err).Msgf("unable to find flow")
        return
    }
-   // We ignore the bytes written to the socket because any partial write must return an error.
-   _, err = s.Write(datagram.Payload)
-   if err != nil {
-       logger.Err(err).Msgf("unable to write payload for the flow")
-       return
-   }
+   s.Write(datagram.Payload)
}

- // Handles incoming ICMP datagrams.
+ // Handles incoming ICMP datagrams into a serialized channel to be handled by a single consumer.
func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
    if c.icmpRouter == nil {
        // ICMPRouter is disabled so we drop the current packet and ignore all incoming ICMP packets
        return
    }
+   select {
+   case c.icmpDatagramChan <- datagram:
+   default:
+       // If the ICMP datagram channel is full, drop any additional incoming.
+       c.logger.Warn().Msg("failed to write icmp packet to origin: dropped")
+   }
+ }
+
+ // Consumes from the ICMP datagram channel to write out the ICMP requests to an origin.
+ func (c *datagramConn) processICMPDatagrams(ctx context.Context) {
+   if c.icmpRouter == nil {
+       // ICMPRouter is disabled so we ignore all incoming ICMP packets
+       return
+   }
+   for {
+       select {
+       // If the provided context is closed we want to exit the write loop
+       case <-ctx.Done():
+           return
+       case datagram := <-c.icmpDatagramChan:
+           c.writeICMPPacket(datagram)
+       }
+   }
+ }
+
+ func (c *datagramConn) writeICMPPacket(datagram *ICMPDatagram) {
    // Decode the provided ICMPDatagram as an ICMP packet
    rawPacket := packet.RawPacket{Data: datagram.Payload}
    cachedDecoder := c.icmpDecoderPool.Get()
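The ICMP handoff introduced above is a standard bounded-queue arrangement: producers enqueue with a non-blocking select and drop (with a log line) when the buffer is full, while a single consumer drains the channel until its context is cancelled. A stripped-down sketch of that pattern, independent of the cloudflared types (all names below are illustrative):

    package main

    import (
        "context"
        "fmt"
        "time"
    )

    func main() {
        queue := make(chan int, 4) // small buffer, in the same spirit as icmpDatagramChanCapacity
        ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
        defer cancel()

        // Single consumer: drains the queue until the context is done.
        go func() {
            for {
                select {
                case <-ctx.Done():
                    return
                case v := <-queue:
                    fmt.Println("processed", v)
                }
            }
        }()

        // Producer: never blocks; drops when the buffer is full.
        for i := 0; i < 32; i++ {
            select {
            case queue <- i:
            default:
                fmt.Println("dropped", i)
            }
        }
        <-ctx.Done()
    }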

View File

@@ -13,6 +13,7 @@ import (
    "testing"
    "time"

+   "github.com/fortytw2/leaktest"
    "github.com/google/gopacket/layers"
    "github.com/rs/zerolog"
    "github.com/stretchr/testify/assert"
@@ -92,7 +93,7 @@ func TestDatagramConn_New(t *testing.T) {
        DefaultDialer: testDefaultDialer,
        TCPWriteTimeout: 0,
    }, &log)
-   conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
+   conn := v3.NewDatagramConn(newMockQuicConn(t.Context()), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
    if conn == nil {
        t.Fatal("expected valid connection")
    }
@@ -104,7 +105,9 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
        DefaultDialer: testDefaultDialer,
        TCPWriteTimeout: 0,
    }, &log)
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    payload := []byte{0xef, 0xef}
@@ -123,7 +126,9 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
        DefaultDialer: testDefaultDialer,
        TCPWriteTimeout: 0,
    }, &log)
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
@@ -149,7 +154,9 @@ func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
        DefaultDialer: testDefaultDialer,
        TCPWriteTimeout: 0,
    }, &log)
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
@@ -166,7 +173,9 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
        DefaultDialer: testDefaultDialer,
        TCPWriteTimeout: 0,
    }, &log)
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
    defer cancel()
    quic.ctx = ctx
@@ -195,15 +204,17 @@ func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    sessionManager := &mockSessionManager{
        expectedRegErr: v3.ErrSessionRegistrationRateLimited,
    }
    conn := v3.NewDatagramConn(quic, sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    // Setup the muxer
-   ctx, cancel := context.WithCancel(t.Context())
-   defer cancel()
+   ctx, cancel := context.WithCancelCause(t.Context())
+   defer cancel(context.Canceled)
    done := make(chan error, 1)
    go func() {
        done <- conn.Serve(ctx)
@@ -223,9 +234,12 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
    require.EqualValues(t, testRequestID, resp.RequestID)
    require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
+
+   assertContextClosed(t, ctx, done, cancel)
}

func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
+   defer leaktest.Check(t)()
    for _, test := range []struct {
        name string
        input []byte
@@ -250,7 +264,9 @@ func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
        t.Run(test.name, func(t *testing.T) {
            logOutput := new(LockedBuffer)
            log := zerolog.New(logOutput)
-           quic := newMockQuicConn()
+           connCtx, connCancel := context.WithCancelCause(t.Context())
+           defer connCancel(context.Canceled)
+           quic := newMockQuicConn(connCtx)
            quic.send <- test.input
            conn := v3.NewDatagramConn(quic, &mockSessionManager{}, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -289,8 +305,11 @@ func (b *LockedBuffer) String() string {
}

func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    expectedErr := errors.New("unable to register session")
    sessionManager := mockSessionManager{expectedRegErr: expectedErr}
    conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -324,8 +343,11 @@ func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) {
}

func TestDatagramConnServe(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    session := newMockSession()
    sessionManager := mockSessionManager{session: &session}
    conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -372,8 +394,11 @@
// instances causes inteference resulting in multiple different raw packets being decoded
// as the same decoded packet.
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    session := newMockSession()
    sessionManager := mockSessionManager{session: &session}
    router := newMockICMPRouter()
@@ -413,10 +438,14 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
    wg := sync.WaitGroup{}
    var receivedPackets []*packet.ICMP
    go func() {
-       for ctx.Err() == nil {
-           icmpPacket := <-router.recv
-           receivedPackets = append(receivedPackets, icmpPacket)
-           wg.Done()
+       for {
+           select {
+           case <-ctx.Done():
+               return
+           case icmpPacket := <-router.recv:
+               receivedPackets = append(receivedPackets, icmpPacket)
+               wg.Done()
+           }
        }
    }()
@@ -452,8 +481,11 @@
}

func TestDatagramConnServe_RegisterTwice(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    session := newMockSession()
    sessionManager := mockSessionManager{session: &session}
    conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -514,12 +546,17 @@
}

func TestDatagramConnServe_MigrateConnection(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    session := newMockSession()
    sessionManager := mockSessionManager{session: &session}
    conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
-   quic2 := newMockQuicConn()
+   conn2Ctx, conn2Cancel := context.WithCancelCause(t.Context())
+   defer conn2Cancel(context.Canceled)
+   quic2 := newMockQuicConn(conn2Ctx)
    conn2 := v3.NewDatagramConn(quic2, &sessionManager, &noopICMPRouter{}, 1, &noopMetrics{}, &log)

    // Setup the muxer
@@ -597,8 +634,11 @@ func TestDatagramConnServe_MigrateConnection(t *testing.T) {
}

func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    // mockSessionManager will return the ErrSessionNotFound for any session attempting to be queried by the muxer
    sessionManager := mockSessionManager{session: nil, expectedGetErr: v3.ErrSessionNotFound}
    conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -624,9 +664,12 @@ func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) {
    assertContextClosed(t, ctx, done, cancel)
}

- func TestDatagramConnServe_Payload(t *testing.T) {
+ func TestDatagramConnServe_Payloads(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    session := newMockSession()
    sessionManager := mockSessionManager{session: &session}
    conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@@ -639,15 +682,26 @@
        done <- conn.Serve(ctx)
    }()

-   // Send new session registration
-   expectedPayload := []byte{0xef, 0xef}
-   datagram := newSessionPayloadDatagram(testRequestID, expectedPayload)
-   quic.send <- datagram
+   // Send session payloads
+   expectedPayloads := makePayloads(256, 16)
+   go func() {
+       for _, payload := range expectedPayloads {
+           datagram := newSessionPayloadDatagram(testRequestID, payload)
+           quic.send <- datagram
+       }
+   }()

-   // Session should receive the payload
-   payload := <-session.recv
-   if !slices.Equal(expectedPayload, payload) {
-       t.Fatalf("expected session receieve the payload sent via the muxer")
+   // Session should receive the payloads (in-order)
+   for i, payload := range expectedPayloads {
+       select {
+       case recv := <-session.recv:
+           if !slices.Equal(recv, payload) {
+               t.Fatalf("expected session receieve the payload[%d] sent via the muxer: (%x) (%x)", i, recv[:16], payload[:16])
+           }
+       case err := <-ctx.Done():
+           // we expect the payload to return before the context to cancel on the session
+           t.Fatal(err)
+       }
    }

    // Cancel the muxer Serve context and make sure it closes with the expected error
@@ -655,8 +709,11 @@
}

func TestDatagramConnServe_ICMPDatagram_TTLDecremented(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    router := newMockICMPRouter()
    conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log)
@@ -701,8 +758,11 @@
}

func TestDatagramConnServe_ICMPDatagram_TTLExceeded(t *testing.T) {
+   defer leaktest.Check(t)()
    log := zerolog.Nop()
-   quic := newMockQuicConn()
+   connCtx, connCancel := context.WithCancelCause(t.Context())
+   defer connCancel(context.Canceled)
+   quic := newMockQuicConn(connCtx)
    router := newMockICMPRouter()
    conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log)
@@ -821,9 +881,9 @@ type mockQuicConn struct {
    recv chan []byte
}

- func newMockQuicConn() *mockQuicConn {
+ func newMockQuicConn(ctx context.Context) *mockQuicConn {
    return &mockQuicConn{
-       ctx: context.Background(),
+       ctx: ctx,
        send: make(chan []byte, 1),
        recv: make(chan []byte, 1),
    }
@@ -841,7 +901,12 @@ func (m *mockQuicConn) SendDatagram(payload []byte) error {
}

func (m *mockQuicConn) ReceiveDatagram(_ context.Context) ([]byte, error) {
-   return <-m.send, nil
+   select {
+   case <-m.ctx.Done():
+       return nil, m.ctx.Err()
+   case b := <-m.send:
+       return b, nil
+   }
}

type mockQuicConnReadError struct {
@@ -905,11 +970,10 @@ func (m *mockSession) Serve(ctx context.Context) error {
    return v3.SessionCloseErr
}

- func (m *mockSession) Write(payload []byte) (n int, err error) {
+ func (m *mockSession) Write(payload []byte) {
    b := make([]byte, len(payload))
    copy(b, payload)
    m.recv <- b
-   return len(b), nil
}

func (m *mockSession) Close() error {
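The defer leaktest.Check(t)() lines added throughout follow the library's documented usage: Check snapshots the goroutines running at the start of the test, and the returned function, invoked at defer time, fails the test if goroutines created during the test are still alive. A minimal, hypothetical test showing the pattern (not part of this diff):

    package example_test

    import (
        "testing"
        "time"

        "github.com/fortytw2/leaktest"
    )

    func TestNoGoroutineLeak(t *testing.T) {
        defer leaktest.Check(t)() // fails the test if goroutines started here outlive it

        done := make(chan struct{})
        go func() {
            defer close(done)
            time.Sleep(10 * time.Millisecond)
        }()
        <-done // wait for the goroutine so leaktest sees it finished
    }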

View File

@@ -6,6 +6,7 @@ import (
    "fmt"
    "io"
    "net"
+   "os"
    "sync"
    "sync/atomic"
    "time"
@@ -22,6 +23,11 @@ const (
    // this value (maxDatagramPayloadLen).
    maxOriginUDPPacketSize = 1500

+   // The maximum amount of datagrams a session will queue up before it begins dropping datagrams.
+   // This channel buffer is small because we assume that the dedicated writer to the origin is typically
+   // fast enought to keep the channel empty.
+   writeChanCapacity = 16
+
    logFlowID = "flowID"
    logPacketSizeKey = "packetSize"
)
@@ -49,7 +55,7 @@ func newSessionIdleErr(timeout time.Duration) error {
}

type Session interface {
-   io.WriteCloser
+   io.Closer
    ID() RequestID
    ConnectionID() uint8
    RemoteAddr() net.Addr
@@ -58,6 +64,7 @@
    Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger)
    // Serve starts the event loop for processing UDP packets
    Serve(ctx context.Context) error
+   Write(payload []byte)
}

type session struct {
@@ -67,12 +74,18 @@
    originAddr net.Addr
    localAddr net.Addr
    eyeball atomic.Pointer[DatagramConn]
+   writeChan chan []byte
    // activeAtChan is used to communicate the last read/write time
    activeAtChan chan time.Time
-   closeChan chan error
+   errChan chan error
+   // The close channel signal only exists for the write loop because the read loop is always waiting on a read
+   // from the UDP socket to the origin. To close the read loop we close the socket.
+   // Additionally, we can't close the writeChan to indicate that writes are complete because the producer (edge)
+   // side may still be trying to write to this session.
+   closeWrite chan struct{}
    contextChan chan context.Context
    metrics Metrics
    log *zerolog.Logger

    // A special close function that we wrap with sync.Once to make sure it is only called once
    closeFn func() error
@@ -89,10 +102,12 @@
    log *zerolog.Logger,
) Session {
    logger := log.With().Str(logFlowID, id.String()).Logger()
-   // closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
-   // write to the channel without blocking since there is only ever one value read from the closeChan by the
+   writeChan := make(chan []byte, writeChanCapacity)
+   // errChan has three slots to allow for all writers (the closeFn, the read loop and the write loop) to
+   // write to the channel without blocking since there is only ever one value read from the errChan by the
    // waitForCloseCondition.
-   closeChan := make(chan error, 2)
+   errChan := make(chan error, 3)
+   closeWrite := make(chan struct{})
    session := &session{
        id: id,
        closeAfterIdle: closeAfterIdle,
@@ -100,10 +115,12 @@
        originAddr: originAddr,
        localAddr: localAddr,
        eyeball: atomic.Pointer[DatagramConn]{},
+       writeChan: writeChan,
        // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
        // drop instead of blocking because last active time only needs to be an approximation
        activeAtChan: make(chan time.Time, 1),
-       closeChan: closeChan,
+       errChan: errChan,
+       closeWrite: closeWrite,
        // contextChan is an unbounded channel to help enforce one active migration of a session at a time.
        contextChan: make(chan context.Context),
        metrics: metrics,
@@ -111,9 +128,12 @@
        closeFn: sync.OnceValue(func() error {
            // We don't want to block on sending to the close channel if it is already full
            select {
-           case closeChan <- SessionCloseErr:
+           case errChan <- SessionCloseErr:
            default:
            }
+           // Indicate to the write loop that the session is now closed
+           close(closeWrite)
+           // Close the socket directly to unblock the read loop and cause it to also end
            return origin.Close()
        }),
    }
@@ -154,66 +174,112 @@ func (s *session) Migrate(eyeball DatagramConn, ctx context.Context, logger *zer
}

func (s *session) Serve(ctx context.Context) error {
-   go func() {
-       // QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
-       // This makes it safe to share readBuffer between iterations
-       readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{}
-       // To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
-       // the required datagram header information. We can reuse this buffer for this session since the header is the
-       // same for the each read.
-       _ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen])
-       for {
-           // Read from the origin UDP socket
-           n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
-           if err != nil {
-               if errors.Is(err, io.EOF) ||
-                   errors.Is(err, io.ErrUnexpectedEOF) {
-                   s.log.Debug().Msgf("flow (origin) connection closed: %v", err)
-               }
-               s.closeChan <- err
-               return
-           }
-           if n < 0 {
-               s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped")
-               continue
-           }
-           if n > maxDatagramPayloadLen {
-               connectionIndex := s.ConnectionID()
-               s.metrics.PayloadTooLarge(connectionIndex)
-               s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped")
-               continue
-           }
-           // We need to synchronize on the eyeball in-case that the connection was migrated. This should be rarely a point
-           // of lock contention, as a migration can only happen during startup of a session before traffic flow.
-           eyeball := *(s.eyeball.Load())
-           // Sending a packet to the session does block on the [quic.Connection], however, this is okay because it
-           // will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge.
-           err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n])
-           if err != nil {
-               s.closeChan <- err
-               return
-           }
-           // Mark the session as active since we proxied a valid packet from the origin.
-           s.markActive()
-       }
-   }()
+   go s.writeLoop()
+   go s.readLoop()
    return s.waitForCloseCondition(ctx, s.closeAfterIdle)
}

- func (s *session) Write(payload []byte) (n int, err error) {
-   n, err = s.origin.Write(payload)
-   if err != nil {
-       s.log.Err(err).Msg("failed to write payload to flow (remote)")
-       return n, err
-   }
-   // Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
-   if n < len(payload) {
-       s.log.Err(io.ErrShortWrite).Msg("failed to write the full payload to flow (remote)")
-       return n, io.ErrShortWrite
-   }
-   // Mark the session as active since we proxied a packet to the origin.
-   s.markActive()
-   return n, err
+ // Read datagrams from the origin and write them to the connection.
+ func (s *session) readLoop() {
+   // QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
+   // This makes it safe to share readBuffer between iterations
+   readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{}
+   // To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
+   // the required datagram header information. We can reuse this buffer for this session since the header is the
+   // same for the each read.
+   _ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen])
+   for {
+       // Read from the origin UDP socket
+       n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
+       if err != nil {
+           if isConnectionClosed(err) {
+               s.log.Debug().Msgf("flow (read) connection closed: %v", err)
+           }
+           s.closeSession(err)
+           return
+       }
+       if n < 0 {
+           s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped")
+           continue
+       }
+       if n > maxDatagramPayloadLen {
+           connectionIndex := s.ConnectionID()
+           s.metrics.PayloadTooLarge(connectionIndex)
+           s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped")
+           continue
+       }
+       // We need to synchronize on the eyeball in-case that the connection was migrated. This should be rarely a point
+       // of lock contention, as a migration can only happen during startup of a session before traffic flow.
+       eyeball := *(s.eyeball.Load())
+       // Sending a packet to the session does block on the [quic.Connection], however, this is okay because it
+       // will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge.
+       err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n])
+       if err != nil {
+           s.closeSession(err)
+           return
+       }
+       // Mark the session as active since we proxied a valid packet from the origin.
+       s.markActive()
+   }
+ }
+
+ func (s *session) Write(payload []byte) {
+   select {
+   case s.writeChan <- payload:
+   default:
+       s.log.Error().Msg("failed to write flow payload to origin: dropped")
+   }
+ }
+
+ // Read datagrams from the write channel to the origin.
+ func (s *session) writeLoop() {
+   for {
+       select {
+       case <-s.closeWrite:
+           // When the closeWrite channel is closed, we will no longer write to the origin and end this
+           // goroutine since the session is now closed.
+           return
+       case payload := <-s.writeChan:
+           n, err := s.origin.Write(payload)
+           if err != nil {
+               // Check if this is a write deadline exceeded to the connection
+               if errors.Is(err, os.ErrDeadlineExceeded) {
+                   s.log.Warn().Err(err).Msg("flow (write) deadline exceeded: dropping packet")
+                   continue
+               }
+               if isConnectionClosed(err) {
+                   s.log.Debug().Msgf("flow (write) connection closed: %v", err)
+               }
+               s.log.Err(err).Msg("failed to write flow payload to origin")
+               s.closeSession(err)
+               // If we fail to write to the origin socket, we need to end the writer and close the session
+               return
+           }
+           // Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
+           if n < len(payload) {
+               s.log.Err(io.ErrShortWrite).Msg("failed to write the full flow payload to origin")
+               continue
+           }
+           // Mark the session as active since we successfully proxied a packet to the origin.
+           s.markActive()
+       }
+   }
+ }
+
+ func isConnectionClosed(err error) bool {
+   return errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
+ }
+
+ // Send an error to the error channel to report that an error has either happened on the tunnel or origin side of the
+ // proxied connection.
+ func (s *session) closeSession(err error) {
+   select {
+   case s.errChan <- err:
+   default:
+       // In the case that the errChan is already full, we will skip over it and return as to not block
+       // the caller because we should start cleaning up the session.
+       s.log.Warn().Msg("error channel was full")
+   }
}

// ResetIdleTimer will restart the current idle timer.
@@ -240,7 +306,8 @@ func (s *session) Close() error {
func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
    connCtx := ctx
-   // Closing the session at the end cancels read so Serve() can return
+   // Closing the session at the end cancels read so Serve() can return, additionally, it closes the
+   // closeWrite channel which indicates to the write loop to return.
    defer s.Close()
    if closeAfterIdle == 0 {
        // Provided that the default caller doesn't specify one
@@ -260,7 +327,9 @@
            // still be active on the existing connection.
            connCtx = newContext
            continue
-       case reason := <-s.closeChan:
+       case reason := <-s.errChan:
+           // Any error returned here is from the read or write loops indicating that it can no longer process datagrams
+           // and as such the session needs to close.
            return reason
        case <-checkIdleTimer.C:
            // The check idle timer will only return after an idle period since the last active
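Two small building blocks carry the session shutdown above: a close function wrapped in sync.OnceValue so it is safe to call from every path, and a channel that is closed once as a broadcast signal to stop the write loop. A reduced, illustrative sketch of how they compose (these types are not the cloudflared ones):

    package main

    import (
        "fmt"
        "sync"
    )

    type worker struct {
        closeWrite chan struct{}
        closeFn    func() error
    }

    func newWorker() *worker {
        w := &worker{closeWrite: make(chan struct{})}
        w.closeFn = sync.OnceValue(func() error {
            // Runs exactly once no matter how many callers race on Close.
            close(w.closeWrite) // broadcast: every receiver unblocks
            fmt.Println("resources released")
            return nil
        })
        return w
    }

    func (w *worker) Close() error { return w.closeFn() }

    func main() {
        w := newWorker()
        var wg sync.WaitGroup
        for i := 0; i < 3; i++ {
            wg.Add(1)
            go func() { defer wg.Done(); _ = w.Close() }()
        }
        wg.Wait()
        <-w.closeWrite // already closed: receive returns immediately
    }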

View File

@@ -4,20 +4,24 @@ import (
    "testing"
)

- // FuzzSessionWrite verifies that we don't run into any panics when writing variable sized payloads to the origin.
+ // FuzzSessionWrite verifies that we don't run into any panics when writing a single variable sized payload to the origin.
func FuzzSessionWrite(f *testing.F) {
-   f.Fuzz(func(t *testing.T, b []byte) {
-       testSessionWrite(t, b)
-   })
- }
-
- // FuzzSessionServe verifies that we don't run into any panics when reading variable sized payloads from the origin.
- func FuzzSessionServe(f *testing.F) {
    f.Fuzz(func(t *testing.T, b []byte) {
        // The origin transport read is bound to 1280 bytes
        if len(b) > 1280 {
            b = b[:1280]
        }
-       testSessionServe_Origin(t, b)
+       testSessionWrite(t, [][]byte{b})
+   })
+ }
+
+ // FuzzSessionRead verifies that we don't run into any panics when reading a single variable sized payload from the origin.
+ func FuzzSessionRead(f *testing.F) {
+   f.Fuzz(func(t *testing.T, b []byte) {
+       // The origin transport read is bound to 1280 bytes
+       if len(b) > 1280 {
+           b = b[:1280]
+       }
+       testSessionRead(t, [][]byte{b})
    })
}
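If seed inputs are wanted for harnesses of this shape, testing.F.Add is the standard way to provide them; a hypothetical seeded variant (the seed values and function name are made up, not part of this diff):

    package example_test

    import "testing"

    func FuzzPayloadNoPanic(f *testing.F) {
        // Seed the corpus with an empty payload and a maximum-size payload.
        f.Add([]byte{})
        f.Add(make([]byte, 1280))
        f.Fuzz(func(t *testing.T, b []byte) {
            // Clamp to the origin transport bound, mirroring the harnesses above.
            if len(b) > 1280 {
                b = b[:1280]
            }
            // Call the code under test here; the only assertion is "no panic".
            _ = append([]byte(nil), b...)
        })
    }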

View File

@@ -31,60 +31,61 @@ func TestSessionNew(t *testing.T) {
    }
}

- func testSessionWrite(t *testing.T, payload []byte) {
+ func testSessionWrite(t *testing.T, payloads [][]byte) {
    log := zerolog.Nop()
    origin, server := net.Pipe()
    defer origin.Close()
    defer server.Close()
-   // Start origin server read
-   serverRead := make(chan []byte, 1)
+   // Start origin server reads
+   serverRead := make(chan []byte, len(payloads))
    go func() {
-       read := make([]byte, 1500)
-       _, _ = server.Read(read[:])
-       serverRead <- read
+       for range len(payloads) {
+           buf := make([]byte, 1500)
+           _, _ = server.Read(buf[:])
+           serverRead <- buf
+       }
+       close(serverRead)
    }()
-   // Create session and write to origin
+   // Create a session
    session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
-   n, err := session.Write(payload)
    defer session.Close()
-   if err != nil {
-       t.Fatal(err)
-   }
-   if n != len(payload) {
-       t.Fatal("unable to write the whole payload")
+   // Start the Serve to begin the writeLoop
+   ctx, cancel := context.WithCancelCause(t.Context())
+   defer cancel(context.Canceled)
+   done := make(chan error)
+   go func() {
+       done <- session.Serve(ctx)
+   }()
+   // Write the payloads to the session
+   for _, payload := range payloads {
+       session.Write(payload)
    }
-   read := <-serverRead
-   if !slices.Equal(payload, read[:len(payload)]) {
-       t.Fatal("payload provided from origin and read value are not the same")
+   // Read from the origin to ensure the payloads were received (in-order)
+   for i, payload := range payloads {
+       read := <-serverRead
+       if !slices.Equal(payload, read[:len(payload)]) {
+           t.Fatalf("payload[%d] provided from origin and read value are not the same (%x) and (%x)", i, payload[:16], read[:16])
+       }
+   }
+   _, more := <-serverRead
+   if more {
+       t.Fatalf("expected the session to have all of the origin payloads received: %d", len(serverRead))
    }
+   assertContextClosed(t, ctx, done, cancel)
}

- func TestSessionWrite_Max(t *testing.T) {
-   defer leaktest.Check(t)()
-   payload := makePayload(1280)
-   testSessionWrite(t, payload)
- }
-
- func TestSessionWrite_Min(t *testing.T) {
-   defer leaktest.Check(t)()
-   payload := makePayload(0)
-   testSessionWrite(t, payload)
- }
-
- func TestSessionServe_OriginMax(t *testing.T) {
-   defer leaktest.Check(t)()
-   payload := makePayload(1280)
-   testSessionServe_Origin(t, payload)
- }
-
- func TestSessionServe_OriginMin(t *testing.T) {
-   defer leaktest.Check(t)()
-   payload := makePayload(0)
-   testSessionServe_Origin(t, payload)
- }
-
- func testSessionServe_Origin(t *testing.T, payload []byte) {
+ func TestSessionWrite(t *testing.T) {
+   defer leaktest.Check(t)()
+   for i := range 1280 {
+       payloads := makePayloads(i, 16)
+       testSessionWrite(t, payloads)
+   }
+ }
+
+ func testSessionRead(t *testing.T, payloads [][]byte) {
    log := zerolog.Nop()
    origin, server := net.Pipe()
    defer origin.Close()
@@ -100,37 +101,42 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
        done <- session.Serve(ctx)
    }()

-   // Write from the origin server
-   _, err := server.Write(payload)
-   if err != nil {
-       t.Fatal(err)
-   }
+   // Write from the origin server to the eyeball
+   go func() {
+       for _, payload := range payloads {
+           _, _ = server.Write(payload)
+       }
+   }()

-   select {
-   case data := <-eyeball.recvData:
-       // check received data matches provided from origin
-       expectedData := makePayload(1500)
-       _ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
-       copy(expectedData[17:], payload)
-       if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
-           t.Fatal("expected datagram did not equal expected")
+   // Read from the eyeball to ensure the payloads were received (in-order)
+   for i, payload := range payloads {
+       select {
+       case data := <-eyeball.recvData:
+           // check received data matches provided from origin
+           expectedData := makePayload(1500)
+           _ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
+           copy(expectedData[17:], payload)
+           if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
+               t.Fatalf("expected datagram[%d] did not equal expected", i)
+           }
+       case err := <-ctx.Done():
+           // we expect the payload to return before the context to cancel on the session
+           t.Fatal(err)
        }
-       cancel(errExpectedContextCanceled)
-   case err := <-ctx.Done():
-       // we expect the payload to return before the context to cancel on the session
-       t.Fatal(err)
    }

-   err = <-done
-   if !errors.Is(err, context.Canceled) {
-       t.Fatal(err)
-   }
-   if !errors.Is(context.Cause(ctx), errExpectedContextCanceled) {
-       t.Fatal(err)
-   }
+   assertContextClosed(t, ctx, done, cancel)
}

- func TestSessionServe_OriginTooLarge(t *testing.T) {
+ func TestSessionRead(t *testing.T) {
+   defer leaktest.Check(t)()
+   for i := range 1280 {
+       payloads := makePayloads(i, 16)
+       testSessionRead(t, payloads)
+   }
+ }
+
+ func TestSessionRead_OriginTooLarge(t *testing.T) {
    defer leaktest.Check(t)()
    log := zerolog.Nop()
    eyeball := newMockEyeball()
@@ -317,6 +323,8 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
    closeAfterIdle := 2 * time.Second
    session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
    err := session.Serve(t.Context())
+
+   // Session should idle timeout if no reads or writes occur
    if !errors.Is(err, v3.SessionIdleErr{}) {
        t.Fatal(err)
    }