Merge branch 'cloudflare:master' into master

Authored by Areg Vrtanesyan on 2025-10-08 16:59:01 +01:00, committed by GitHub
commit b8511df478
GPG Key ID: B5690EEEBB952194
9 changed files with 471 additions and 272 deletions

View File

@ -185,7 +185,7 @@ fuzz:
@go test -fuzz=FuzzIPDecoder -fuzztime=600s ./packet
@go test -fuzz=FuzzICMPDecoder -fuzztime=600s ./packet
@go test -fuzz=FuzzSessionWrite -fuzztime=600s ./quic/v3
@go test -fuzz=FuzzSessionServe -fuzztime=600s ./quic/v3
@go test -fuzz=FuzzSessionRead -fuzztime=600s ./quic/v3
@go test -fuzz=FuzzRegistrationDatagram -fuzztime=600s ./quic/v3
@go test -fuzz=FuzzPayloadDatagram -fuzztime=600s ./quic/v3
@go test -fuzz=FuzzRegistrationResponseDatagram -fuzztime=600s ./quic/v3
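
For reference, the fuzz targets invoked above follow Go's native fuzzing shape; a minimal sketch of such a target is below (the package and function names are hypothetical, not taken from this commit):

package fuzzexample

import "testing"

// FuzzExample is a hypothetical target showing the structure shared by the
// session fuzzers: seed corpus entries added with f.Add, and a fuzz function
// that bounds the input before exercising the code under test.
func FuzzExample(f *testing.F) {
	f.Add([]byte{0x01})
	f.Fuzz(func(t *testing.T, b []byte) {
		if len(b) > 1280 {
			b = b[:1280] // mirror the payload bound used by the session fuzzers
		}
		_ = b // the real targets decode or write b and rely on the fuzzer to surface panics
	})
}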

View File

@ -14,4 +14,7 @@ spec:
lifecycle: "Active"
owner: "teams/tunnel-teams-routing"
cf:
compliance:
fedramp-high: "pending"
fedramp-moderate: "yes"
FIPS: "required"

View File

@ -11,6 +11,8 @@ import (
"github.com/rs/zerolog"
)
const writeDeadlineUDP = 200 * time.Millisecond
// OriginTCPDialer provides a TCP dial operation to a requested address.
type OriginTCPDialer interface {
DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
@ -141,6 +143,21 @@ func (d *Dialer) DialUDP(dest netip.AddrPort) (net.Conn, error) {
if err != nil {
return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err)
}
return conn, nil
return &writeDeadlineConn{
Conn: conn,
}, nil
}
// writeDeadlineConn is a wrapper around a net.Conn that sets a write deadline of 200ms.
// This prevents the socket from blocking the caller on a write operation. In practice we only expect a
// write to block under high load or kernel issues.
type writeDeadlineConn struct {
net.Conn
}
func (w *writeDeadlineConn) Write(b []byte) (int, error) {
if err := w.SetWriteDeadline(time.Now().Add(writeDeadlineUDP)); err != nil {
return 0, err
}
return w.Conn.Write(b)
}
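
As an aside, a minimal standalone sketch of how such a per-write deadline wrapper behaves is shown below; it is not part of this diff, and the address and type names are made up for illustration. A write that cannot make progress returns a timeout error after roughly the configured deadline instead of blocking the caller.

package main

import (
	"fmt"
	"net"
	"os"
	"time"
)

// deadlineConn mirrors the writeDeadlineConn pattern above: refresh the write
// deadline before every write so a stuck socket fails fast.
type deadlineConn struct {
	net.Conn
	timeout time.Duration
}

func (c *deadlineConn) Write(b []byte) (int, error) {
	if err := c.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
		return 0, err
	}
	return c.Conn.Write(b)
}

func main() {
	conn, err := net.Dial("udp", "127.0.0.1:9999") // arbitrary example address
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	wrapped := &deadlineConn{Conn: conn, timeout: 200 * time.Millisecond}
	if _, err := wrapped.Write([]byte("ping")); err != nil && os.IsTimeout(err) {
		// A blocked write surfaces as a timeout error rather than hanging.
		fmt.Println("write timed out")
	}
}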

View File

@ -1,6 +1,7 @@
package v3_test
import (
"crypto/rand"
"encoding/binary"
"errors"
"net/netip"
@ -14,12 +15,18 @@ import (
func makePayload(size int) []byte {
payload := make([]byte, size)
for i := range len(payload) {
payload[i] = 0xfc
}
_, _ = rand.Read(payload)
return payload
}
func makePayloads(size int, count int) [][]byte {
payloads := make([][]byte, count)
for i := range payloads {
payloads[i] = makePayload(size)
}
return payloads
}
func TestSessionRegistration_MarshalUnmarshal(t *testing.T) {
payload := makePayload(1280)
tests := []*v3.UDPSessionRegistrationDatagram{

View File

@ -17,6 +17,9 @@ const (
// Allocating a 16-slot channel buffer here allows the writer to be slightly faster than the reader.
// This has previously worked well for datagramv2, so we start with the same value here.
demuxChanCapacity = 16
// This provides a small buffer for the PacketRouter to poll ICMP packets from the QUIC connection
// before writing them to the origin.
icmpDatagramChanCapacity = 128
logSrcKey = "src"
logDstKey = "dst"
@ -59,14 +62,15 @@ type QuicConnection interface {
}
type datagramConn struct {
conn QuicConnection
index uint8
sessionManager SessionManager
icmpRouter ingress.ICMPRouter
metrics Metrics
logger *zerolog.Logger
datagrams chan []byte
readErrors chan error
conn QuicConnection
index uint8
sessionManager SessionManager
icmpRouter ingress.ICMPRouter
metrics Metrics
logger *zerolog.Logger
datagrams chan []byte
icmpDatagramChan chan *ICMPDatagram
readErrors chan error
icmpEncoderPool sync.Pool // a pool of *packet.Encoder
icmpDecoderPool sync.Pool
@ -75,14 +79,15 @@ type datagramConn struct {
func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
log := logger.With().Uint8("datagramVersion", 3).Logger()
return &datagramConn{
conn: conn,
index: index,
sessionManager: sessionManager,
icmpRouter: icmpRouter,
metrics: metrics,
logger: &log,
datagrams: make(chan []byte, demuxChanCapacity),
readErrors: make(chan error, 2),
conn: conn,
index: index,
sessionManager: sessionManager,
icmpRouter: icmpRouter,
metrics: metrics,
logger: &log,
datagrams: make(chan []byte, demuxChanCapacity),
icmpDatagramChan: make(chan *ICMPDatagram, icmpDatagramChanCapacity),
readErrors: make(chan error, 2),
icmpEncoderPool: sync.Pool{
New: func() any {
return packet.NewEncoder()
@ -168,6 +173,9 @@ func (c *datagramConn) Serve(ctx context.Context) error {
readCtx, cancel := context.WithCancel(connCtx)
defer cancel()
go c.pollDatagrams(readCtx)
// The ICMP datagram processing routine also monitors the reader context since the ICMP datagrams produced
// by the reader are its input.
go c.processICMPDatagrams(readCtx)
for {
// We make sure to monitor the context of cloudflared and the underlying connection to return if any errors occur.
var datagram []byte
@ -181,58 +189,59 @@ func (c *datagramConn) Serve(ctx context.Context) error {
// Monitor for any hard errors from reading the connection
case err := <-c.readErrors:
return err
// Otherwise, wait and dequeue datagrams as they come in
// Wait and dequeue datagrams as they come in
case d := <-c.datagrams:
datagram = d
}
// Each incoming datagram will be processed in a new go routine to handle the demuxing and action associated.
go func() {
typ, err := ParseDatagramType(datagram)
typ, err := ParseDatagramType(datagram)
if err != nil {
c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ)
continue
}
switch typ {
case UDPSessionRegistrationType:
reg := &UDPSessionRegistrationDatagram{}
err := reg.UnmarshalBinary(datagram)
if err != nil {
c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ)
return
c.logger.Err(err).Msgf("unable to unmarshal session registration datagram")
continue
}
switch typ {
case UDPSessionRegistrationType:
reg := &UDPSessionRegistrationDatagram{}
err := reg.UnmarshalBinary(datagram)
if err != nil {
c.logger.Err(err).Msgf("unable to unmarshal session registration datagram")
return
}
logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger()
// We bind the new session to the quic connection context instead of cloudflared context to allow for the
// quic connection to close and close only the sessions bound to it. Closing of cloudflared will also
// initiate the close of the quic connection, so we don't have to worry about the application context
// in the scope of a session.
c.handleSessionRegistrationDatagram(connCtx, reg, &logger)
case UDPSessionPayloadType:
payload := &UDPSessionPayloadDatagram{}
err := payload.UnmarshalBinary(datagram)
if err != nil {
c.logger.Err(err).Msgf("unable to unmarshal session payload datagram")
return
}
logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger()
c.handleSessionPayloadDatagram(payload, &logger)
case ICMPType:
packet := &ICMPDatagram{}
err := packet.UnmarshalBinary(datagram)
if err != nil {
c.logger.Err(err).Msgf("unable to unmarshal icmp datagram")
return
}
c.handleICMPPacket(packet)
case UDPSessionRegistrationResponseType:
// cloudflared should never expect to receive UDP session responses as it will not initiate new
// sessions towards the edge.
c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType)
return
default:
c.logger.Error().Msgf("unknown datagram type received: %d", typ)
logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger()
// We bind the new session to the quic connection context instead of cloudflared context to allow for the
// quic connection to close and close only the sessions bound to it. Closing of cloudflared will also
// initiate the close of the quic connection, so we don't have to worry about the application context
// in the scope of a session.
//
// Additionally, we spin the registration out into its own goroutine so that serving the session does not
// block the demuxer.
go c.handleSessionRegistrationDatagram(connCtx, reg, &logger)
case UDPSessionPayloadType:
payload := &UDPSessionPayloadDatagram{}
err := payload.UnmarshalBinary(datagram)
if err != nil {
c.logger.Err(err).Msgf("unable to unmarshal session payload datagram")
continue
}
}()
logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger()
c.handleSessionPayloadDatagram(payload, &logger)
case ICMPType:
packet := &ICMPDatagram{}
err := packet.UnmarshalBinary(datagram)
if err != nil {
c.logger.Err(err).Msgf("unable to unmarshal icmp datagram")
continue
}
c.handleICMPPacket(packet)
case UDPSessionRegistrationResponseType:
// cloudflared should never expect to receive UDP session responses as it will not initiate new
// sessions towards the edge.
c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType)
continue
default:
c.logger.Error().Msgf("unknown datagram type received: %d", typ)
}
}
}
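
To make the control flow of the reworked Serve loop easier to follow, here is a small self-contained sketch (not part of this diff; all names are illustrative) of the same shape: select over the application context, the connection context, a hard read-error channel, and the datagram queue.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// serve stops when either context ends, propagates a hard read error, and
// otherwise dequeues datagrams for dispatch.
func serve(appCtx, connCtx context.Context, readErrs <-chan error, datagrams <-chan string) error {
	for {
		select {
		case <-appCtx.Done():
			return appCtx.Err()
		case <-connCtx.Done():
			return connCtx.Err()
		case err := <-readErrs:
			return err
		case d := <-datagrams:
			fmt.Println("dispatch:", d) // in the real code this is the datagram type switch
		}
	}
}

func main() {
	appCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	connCtx := context.Background()

	datagrams := make(chan string, 1)
	datagrams <- "session payload"
	err := serve(appCtx, connCtx, make(chan error), datagrams)
	fmt.Println("stopped by application context:", errors.Is(err, context.DeadlineExceeded))
}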
@ -243,24 +252,21 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
Str(logDstKey, datagram.Dest.String()).
Logger()
session, err := c.sessionManager.RegisterSession(datagram, c)
switch err {
case nil:
// Continue as normal
case ErrSessionAlreadyRegistered:
// Session is already registered and likely the response got lost
c.handleSessionAlreadyRegistered(datagram.RequestID, &log)
return
case ErrSessionBoundToOtherConn:
// Session is already registered but to a different connection
c.handleSessionMigration(datagram.RequestID, &log)
return
case ErrSessionRegistrationRateLimited:
// There are too many concurrent sessions so we return an error to force a retry later
c.handleSessionRegistrationRateLimited(datagram, &log)
return
default:
log.Err(err).Msg("flow registration failure")
c.handleSessionRegistrationFailure(datagram.RequestID, &log)
if err != nil {
switch err {
case ErrSessionAlreadyRegistered:
// Session is already registered and likely the response got lost
c.handleSessionAlreadyRegistered(datagram.RequestID, &log)
case ErrSessionBoundToOtherConn:
// Session is already registered but to a different connection
c.handleSessionMigration(datagram.RequestID, &log)
case ErrSessionRegistrationRateLimited:
// There are too many concurrent sessions so we return an error to force a retry later
c.handleSessionRegistrationRateLimited(datagram, &log)
default:
log.Err(err).Msg("flow registration failure")
c.handleSessionRegistrationFailure(datagram.RequestID, &log)
}
return
}
log = log.With().Str(logSrcKey, session.LocalAddr().String()).Logger()
@ -365,21 +371,42 @@ func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadD
logger.Err(err).Msgf("unable to find flow")
return
}
// We ignore the bytes written to the socket because any partial write must return an error.
_, err = s.Write(datagram.Payload)
if err != nil {
logger.Err(err).Msgf("unable to write payload for the flow")
return
}
s.Write(datagram.Payload)
}
// Handles incoming ICMP datagrams.
// Queues incoming ICMP datagrams onto a serialized channel to be handled by a single consumer.
func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
if c.icmpRouter == nil {
// ICMPRouter is disabled so we drop the current packet and ignore all incoming ICMP packets
return
}
select {
case c.icmpDatagramChan <- datagram:
default:
// If the ICMP datagram channel is full, drop any additional incoming datagrams.
c.logger.Warn().Msg("failed to write icmp packet to origin: dropped")
}
}
// Consumes from the ICMP datagram channel to write out the ICMP requests to an origin.
func (c *datagramConn) processICMPDatagrams(ctx context.Context) {
if c.icmpRouter == nil {
// ICMPRouter is disabled so we ignore all incoming ICMP packets
return
}
for {
select {
// If the provided context is closed we want to exit the write loop
case <-ctx.Done():
return
case datagram := <-c.icmpDatagramChan:
c.writeICMPPacket(datagram)
}
}
}
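
The handoff between handleICMPPacket and processICMPDatagrams is a bounded queue with a drop-on-full producer and a single context-bound consumer. A self-contained sketch of that pattern (illustrative only, not part of this diff):

package main

import (
	"context"
	"fmt"
	"time"
)

// enqueue never blocks the producer (the demuxer loop): if the buffer is full,
// the value is dropped.
func enqueue(ch chan<- int, v int) {
	select {
	case ch <- v:
	default:
		fmt.Println("dropped", v)
	}
}

// consume drains the queue until its context is cancelled.
func consume(ctx context.Context, ch <-chan int) {
	for {
		select {
		case <-ctx.Done():
			return
		case v := <-ch:
			fmt.Println("processed", v)
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	ch := make(chan int, 128) // mirrors icmpDatagramChanCapacity
	go consume(ctx, ch)
	for i := 0; i < 200; i++ {
		enqueue(ch, i)
	}
	<-ctx.Done()
}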
func (c *datagramConn) writeICMPPacket(datagram *ICMPDatagram) {
// Decode the provided ICMPDatagram as an ICMP packet
rawPacket := packet.RawPacket{Data: datagram.Payload}
cachedDecoder := c.icmpDecoderPool.Get()

View File

@ -13,6 +13,7 @@ import (
"testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
@ -92,7 +93,7 @@ func TestDatagramConn_New(t *testing.T) {
DefaultDialer: testDefaultDialer,
TCPWriteTimeout: 0,
}, &log)
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
conn := v3.NewDatagramConn(newMockQuicConn(t.Context()), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
if conn == nil {
t.Fatal("expected valid connection")
}
@ -104,7 +105,9 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
DefaultDialer: testDefaultDialer,
TCPWriteTimeout: 0,
}, &log)
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
payload := []byte{0xef, 0xef}
@ -123,7 +126,9 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
DefaultDialer: testDefaultDialer,
TCPWriteTimeout: 0,
}, &log)
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
@ -149,7 +154,9 @@ func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
DefaultDialer: testDefaultDialer,
TCPWriteTimeout: 0,
}, &log)
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
@ -166,7 +173,9 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
DefaultDialer: testDefaultDialer,
TCPWriteTimeout: 0,
}, &log)
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
defer cancel()
quic.ctx = ctx
@ -195,15 +204,17 @@ func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
sessionManager := &mockSessionManager{
expectedRegErr: v3.ErrSessionRegistrationRateLimited,
}
conn := v3.NewDatagramConn(quic, sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
// Setup the muxer
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
ctx, cancel := context.WithCancelCause(t.Context())
defer cancel(context.Canceled)
done := make(chan error, 1)
go func() {
done <- conn.Serve(ctx)
@ -223,9 +234,12 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
require.EqualValues(t, testRequestID, resp.RequestID)
require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
assertContextClosed(t, ctx, done, cancel)
}
func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
defer leaktest.Check(t)()
for _, test := range []struct {
name string
input []byte
@ -250,7 +264,9 @@ func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
logOutput := new(LockedBuffer)
log := zerolog.New(logOutput)
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
quic.send <- test.input
conn := v3.NewDatagramConn(quic, &mockSessionManager{}, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@ -289,8 +305,11 @@ func (b *LockedBuffer) String() string {
}
func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
expectedErr := errors.New("unable to register session")
sessionManager := mockSessionManager{expectedRegErr: expectedErr}
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@ -324,8 +343,11 @@ func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) {
}
func TestDatagramConnServe(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
session := newMockSession()
sessionManager := mockSessionManager{session: &session}
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@ -372,8 +394,11 @@ func TestDatagramConnServe(t *testing.T) {
instances causes interference resulting in multiple different raw packets being decoded
// as the same decoded packet.
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
session := newMockSession()
sessionManager := mockSessionManager{session: &session}
router := newMockICMPRouter()
@ -413,10 +438,14 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
wg := sync.WaitGroup{}
var receivedPackets []*packet.ICMP
go func() {
for ctx.Err() == nil {
icmpPacket := <-router.recv
receivedPackets = append(receivedPackets, icmpPacket)
wg.Done()
for {
select {
case <-ctx.Done():
return
case icmpPacket := <-router.recv:
receivedPackets = append(receivedPackets, icmpPacket)
wg.Done()
}
}
}()
@ -452,8 +481,11 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
}
func TestDatagramConnServe_RegisterTwice(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
session := newMockSession()
sessionManager := mockSessionManager{session: &session}
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@ -514,12 +546,17 @@ func TestDatagramConnServe_RegisterTwice(t *testing.T) {
}
func TestDatagramConnServe_MigrateConnection(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
session := newMockSession()
sessionManager := mockSessionManager{session: &session}
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
quic2 := newMockQuicConn()
conn2Ctx, conn2Cancel := context.WithCancelCause(t.Context())
defer conn2Cancel(context.Canceled)
quic2 := newMockQuicConn(conn2Ctx)
conn2 := v3.NewDatagramConn(quic2, &sessionManager, &noopICMPRouter{}, 1, &noopMetrics{}, &log)
// Setup the muxer
@ -597,8 +634,11 @@ func TestDatagramConnServe_MigrateConnection(t *testing.T) {
}
func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
// mockSessionManager will return the ErrSessionNotFound for any session attempting to be queried by the muxer
sessionManager := mockSessionManager{session: nil, expectedGetErr: v3.ErrSessionNotFound}
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@ -624,9 +664,12 @@ func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) {
assertContextClosed(t, ctx, done, cancel)
}
func TestDatagramConnServe_Payload(t *testing.T) {
func TestDatagramConnServe_Payloads(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
session := newMockSession()
sessionManager := mockSessionManager{session: &session}
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
@ -639,15 +682,26 @@ func TestDatagramConnServe_Payload(t *testing.T) {
done <- conn.Serve(ctx)
}()
// Send new session registration
expectedPayload := []byte{0xef, 0xef}
datagram := newSessionPayloadDatagram(testRequestID, expectedPayload)
quic.send <- datagram
// Send session payloads
expectedPayloads := makePayloads(256, 16)
go func() {
for _, payload := range expectedPayloads {
datagram := newSessionPayloadDatagram(testRequestID, payload)
quic.send <- datagram
}
}()
// Session should receive the payload
payload := <-session.recv
if !slices.Equal(expectedPayload, payload) {
t.Fatalf("expected session receieve the payload sent via the muxer")
// Session should receive the payloads (in-order)
for i, payload := range expectedPayloads {
select {
case recv := <-session.recv:
if !slices.Equal(recv, payload) {
t.Fatalf("expected session receieve the payload[%d] sent via the muxer: (%x) (%x)", i, recv[:16], payload[:16])
}
case err := <-ctx.Done():
// we expect the payload to arrive before the session context is cancelled
t.Fatal(err)
}
}
// Cancel the muxer Serve context and make sure it closes with the expected error
@ -655,8 +709,11 @@ func TestDatagramConnServe_Payload(t *testing.T) {
}
func TestDatagramConnServe_ICMPDatagram_TTLDecremented(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
router := newMockICMPRouter()
conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log)
@ -701,8 +758,11 @@ func TestDatagramConnServe_ICMPDatagram_TTLDecremented(t *testing.T) {
}
func TestDatagramConnServe_ICMPDatagram_TTLExceeded(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
quic := newMockQuicConn()
connCtx, connCancel := context.WithCancelCause(t.Context())
defer connCancel(context.Canceled)
quic := newMockQuicConn(connCtx)
router := newMockICMPRouter()
conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log)
@ -821,9 +881,9 @@ type mockQuicConn struct {
recv chan []byte
}
func newMockQuicConn() *mockQuicConn {
func newMockQuicConn(ctx context.Context) *mockQuicConn {
return &mockQuicConn{
ctx: context.Background(),
ctx: ctx,
send: make(chan []byte, 1),
recv: make(chan []byte, 1),
}
@ -841,7 +901,12 @@ func (m *mockQuicConn) SendDatagram(payload []byte) error {
}
func (m *mockQuicConn) ReceiveDatagram(_ context.Context) ([]byte, error) {
return <-m.send, nil
select {
case <-m.ctx.Done():
return nil, m.ctx.Err()
case b := <-m.send:
return b, nil
}
}
type mockQuicConnReadError struct {
@ -905,11 +970,10 @@ func (m *mockSession) Serve(ctx context.Context) error {
return v3.SessionCloseErr
}
func (m *mockSession) Write(payload []byte) (n int, err error) {
func (m *mockSession) Write(payload []byte) {
b := make([]byte, len(payload))
copy(b, payload)
m.recv <- b
return len(b), nil
}
func (m *mockSession) Close() error {

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"os"
"sync"
"sync/atomic"
"time"
@ -22,6 +23,11 @@ const (
// this value (maxDatagramPayloadLen).
maxOriginUDPPacketSize = 1500
// The maximum number of datagrams a session will queue up before it begins dropping datagrams.
// This channel buffer is small because we assume that the dedicated writer to the origin is typically
// fast enough to keep the channel empty.
writeChanCapacity = 16
logFlowID = "flowID"
logPacketSizeKey = "packetSize"
)
@ -49,7 +55,7 @@ func newSessionIdleErr(timeout time.Duration) error {
}
type Session interface {
io.WriteCloser
io.Closer
ID() RequestID
ConnectionID() uint8
RemoteAddr() net.Addr
@ -58,6 +64,7 @@ type Session interface {
Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger)
// Serve starts the event loop for processing UDP packets
Serve(ctx context.Context) error
Write(payload []byte)
}
type session struct {
@ -67,12 +74,18 @@ type session struct {
originAddr net.Addr
localAddr net.Addr
eyeball atomic.Pointer[DatagramConn]
writeChan chan []byte
// activeAtChan is used to communicate the last read/write time
activeAtChan chan time.Time
closeChan chan error
contextChan chan context.Context
metrics Metrics
log *zerolog.Logger
errChan chan error
// The close channel signal only exists for the write loop because the read loop is always waiting on a read
// from the UDP socket to the origin. To close the read loop we close the socket.
// Additionally, we can't close the writeChan to indicate that writes are complete because the producer (edge)
// side may still be trying to write to this session.
closeWrite chan struct{}
contextChan chan context.Context
metrics Metrics
log *zerolog.Logger
// A special close function that we wrap with sync.Once to make sure it is only called once
closeFn func() error
@ -89,10 +102,12 @@ func NewSession(
log *zerolog.Logger,
) Session {
logger := log.With().Str(logFlowID, id.String()).Logger()
// closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
// write to the channel without blocking since there is only ever one value read from the closeChan by the
writeChan := make(chan []byte, writeChanCapacity)
// errChan has three slots to allow for all writers (the closeFn, the read loop and the write loop) to
// write to the channel without blocking since there is only ever one value read from the errChan by the
// waitForCloseCondition.
closeChan := make(chan error, 2)
errChan := make(chan error, 3)
closeWrite := make(chan struct{})
session := &session{
id: id,
closeAfterIdle: closeAfterIdle,
@ -100,10 +115,12 @@ func NewSession(
originAddr: originAddr,
localAddr: localAddr,
eyeball: atomic.Pointer[DatagramConn]{},
writeChan: writeChan,
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
// drop instead of blocking because last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 1),
closeChan: closeChan,
errChan: errChan,
closeWrite: closeWrite,
// contextChan is an unbounded channel to help enforce one active migration of a session at a time.
contextChan: make(chan context.Context),
metrics: metrics,
@ -111,9 +128,12 @@ func NewSession(
closeFn: sync.OnceValue(func() error {
// We don't want to block on sending to the close channel if it is already full
select {
case closeChan <- SessionCloseErr:
case errChan <- SessionCloseErr:
default:
}
// Indicate to the write loop that the session is now closed
close(closeWrite)
// Close the socket directly to unblock the read loop and cause it to also end
return origin.Close()
}),
}
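
The closeFn above relies on sync.OnceValue so that the teardown (signalling the write loop and closing the origin socket) runs exactly once no matter how many paths call Close. A minimal sketch of that idiom, with hypothetical names:

package main

import (
	"fmt"
	"sync"
)

type resource struct {
	closeOnce func() error
}

func newResource() *resource {
	done := make(chan struct{})
	return &resource{
		closeOnce: sync.OnceValue(func() error {
			close(done) // safe: this closure runs at most once
			fmt.Println("torn down")
			return nil
		}),
	}
}

func (r *resource) Close() error { return r.closeOnce() }

func main() {
	r := newResource()
	_ = r.Close()
	_ = r.Close() // second call is a no-op and returns the cached result
}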
@ -154,66 +174,112 @@ func (s *session) Migrate(eyeball DatagramConn, ctx context.Context, logger *zer
}
func (s *session) Serve(ctx context.Context) error {
go func() {
// QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
// This makes it safe to share readBuffer between iterations
readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{}
// To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
// the required datagram header information. We can reuse this buffer for this session since the header is the
// same for the each read.
_ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen])
for {
// Read from the origin UDP socket
n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
if err != nil {
if errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrUnexpectedEOF) {
s.log.Debug().Msgf("flow (origin) connection closed: %v", err)
}
s.closeChan <- err
return
}
if n < 0 {
s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped")
continue
}
if n > maxDatagramPayloadLen {
connectionIndex := s.ConnectionID()
s.metrics.PayloadTooLarge(connectionIndex)
s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped")
continue
}
// We need to synchronize on the eyeball in-case that the connection was migrated. This should be rarely a point
// of lock contention, as a migration can only happen during startup of a session before traffic flow.
eyeball := *(s.eyeball.Load())
// Sending a packet to the session does block on the [quic.Connection], however, this is okay because it
// will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge.
err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n])
if err != nil {
s.closeChan <- err
return
}
// Mark the session as active since we proxied a valid packet from the origin.
s.markActive()
}
}()
go s.writeLoop()
go s.readLoop()
return s.waitForCloseCondition(ctx, s.closeAfterIdle)
}
func (s *session) Write(payload []byte) (n int, err error) {
n, err = s.origin.Write(payload)
if err != nil {
s.log.Err(err).Msg("failed to write payload to flow (remote)")
return n, err
// Read datagrams from the origin and write them to the connection.
func (s *session) readLoop() {
// QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
// This makes it safe to share readBuffer between iterations
readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{}
// To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
// the required datagram header information. We can reuse this buffer for this session since the header is the
// same for each read.
_ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen])
for {
// Read from the origin UDP socket
n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
if err != nil {
if isConnectionClosed(err) {
s.log.Debug().Msgf("flow (read) connection closed: %v", err)
}
s.closeSession(err)
return
}
if n < 0 {
s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped")
continue
}
if n > maxDatagramPayloadLen {
connectionIndex := s.ConnectionID()
s.metrics.PayloadTooLarge(connectionIndex)
s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped")
continue
}
// We need to synchronize on the eyeball in case the connection was migrated. This should rarely be a point
// of lock contention, as a migration can only happen during startup of a session before traffic flows.
eyeball := *(s.eyeball.Load())
// Sending a packet to the session does block on the [quic.Connection], however, this is okay because it
// will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge.
err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n])
if err != nil {
s.closeSession(err)
return
}
// Mark the session as active since we proxied a valid packet from the origin.
s.markActive()
}
// Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
if n < len(payload) {
s.log.Err(io.ErrShortWrite).Msg("failed to write the full payload to flow (remote)")
return n, io.ErrShortWrite
}
func (s *session) Write(payload []byte) {
select {
case s.writeChan <- payload:
default:
s.log.Error().Msg("failed to write flow payload to origin: dropped")
}
}
// Read datagrams from the write channel and write them to the origin.
func (s *session) writeLoop() {
for {
select {
case <-s.closeWrite:
// When the closeWrite channel is closed, we will no longer write to the origin and end this
// goroutine since the session is now closed.
return
case payload := <-s.writeChan:
n, err := s.origin.Write(payload)
if err != nil {
// Check if the write deadline to the connection was exceeded
if errors.Is(err, os.ErrDeadlineExceeded) {
s.log.Warn().Err(err).Msg("flow (write) deadline exceeded: dropping packet")
continue
}
if isConnectionClosed(err) {
s.log.Debug().Msgf("flow (write) connection closed: %v", err)
}
s.log.Err(err).Msg("failed to write flow payload to origin")
s.closeSession(err)
// If we fail to write to the origin socket, we need to end the writer and close the session
return
}
// Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
if n < len(payload) {
s.log.Err(io.ErrShortWrite).Msg("failed to write the full flow payload to origin")
continue
}
// Mark the session as active since we successfully proxied a packet to the origin.
s.markActive()
}
}
}
func isConnectionClosed(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
}
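
The write loop distinguishes a missed write deadline (drop the packet, keep the session) from a closed connection or other hard error (end the session). A standalone sketch of that classification, using only sentinel errors from the standard library (names here are illustrative):

package main

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
)

// classify mirrors the error handling in the write loop above.
func classify(err error) string {
	switch {
	case err == nil:
		return "ok"
	case errors.Is(err, os.ErrDeadlineExceeded):
		return "timeout: drop packet, keep session"
	case errors.Is(err, net.ErrClosed), errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF):
		return "connection closed: end session"
	default:
		return "unexpected error: end session"
	}
}

func main() {
	fmt.Println(classify(os.ErrDeadlineExceeded))
	fmt.Println(classify(net.ErrClosed))
	fmt.Println(classify(errors.New("boom")))
}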
// Send an error to the error channel to report that an error happened on either the tunnel or origin side of
// the proxied connection.
func (s *session) closeSession(err error) {
select {
case s.errChan <- err:
default:
// In the case that the errChan is already full, we skip it and return so as not to block
// the caller, because we should start cleaning up the session.
s.log.Warn().Msg("error channel was full")
}
// Mark the session as active since we proxied a packet to the origin.
s.markActive()
return n, err
}
// ResetIdleTimer will restart the current idle timer.
@ -240,7 +306,8 @@ func (s *session) Close() error {
func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
connCtx := ctx
// Closing the session at the end cancels read so Serve() can return
// Closing the session at the end cancels the read so Serve() can return; additionally, it closes the
// closeWrite channel, which signals the write loop to return.
defer s.Close()
if closeAfterIdle == 0 {
// Provided that the default caller doesn't specify one
@ -260,7 +327,9 @@ func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time
// still be active on the existing connection.
connCtx = newContext
continue
case reason := <-s.closeChan:
case reason := <-s.errChan:
// Any error returned here is from the read or write loops indicating that it can no longer process datagrams
// and as such the session needs to close.
return reason
case <-checkIdleTimer.C:
// The check idle timer will only return after an idle period since the last active

View File

@ -4,20 +4,24 @@ import (
"testing"
)
// FuzzSessionWrite verifies that we don't run into any panics when writing variable sized payloads to the origin.
// FuzzSessionWrite verifies that we don't run into any panics when writing a single variable sized payload to the origin.
func FuzzSessionWrite(f *testing.F) {
f.Fuzz(func(t *testing.T, b []byte) {
testSessionWrite(t, b)
})
}
// FuzzSessionServe verifies that we don't run into any panics when reading variable sized payloads from the origin.
func FuzzSessionServe(f *testing.F) {
f.Fuzz(func(t *testing.T, b []byte) {
// The origin transport read is bound to 1280 bytes
if len(b) > 1280 {
b = b[:1280]
}
testSessionServe_Origin(t, b)
testSessionWrite(t, [][]byte{b})
})
}
// FuzzSessionRead verifies that we don't run into any panics when reading a single variable sized payload from the origin.
func FuzzSessionRead(f *testing.F) {
f.Fuzz(func(t *testing.T, b []byte) {
// The origin transport read is bound to 1280 bytes
if len(b) > 1280 {
b = b[:1280]
}
testSessionRead(t, [][]byte{b})
})
}

View File

@ -31,60 +31,61 @@ func TestSessionNew(t *testing.T) {
}
}
func testSessionWrite(t *testing.T, payload []byte) {
func testSessionWrite(t *testing.T, payloads [][]byte) {
log := zerolog.Nop()
origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
// Start origin server read
serverRead := make(chan []byte, 1)
// Start origin server reads
serverRead := make(chan []byte, len(payloads))
go func() {
read := make([]byte, 1500)
_, _ = server.Read(read[:])
serverRead <- read
for range len(payloads) {
buf := make([]byte, 1500)
_, _ = server.Read(buf[:])
serverRead <- buf
}
close(serverRead)
}()
// Create session and write to origin
// Create a session
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
n, err := session.Write(payload)
defer session.Close()
if err != nil {
t.Fatal(err)
}
if n != len(payload) {
t.Fatal("unable to write the whole payload")
// Start the Serve to begin the writeLoop
ctx, cancel := context.WithCancelCause(t.Context())
defer cancel(context.Canceled)
done := make(chan error)
go func() {
done <- session.Serve(ctx)
}()
// Write the payloads to the session
for _, payload := range payloads {
session.Write(payload)
}
read := <-serverRead
if !slices.Equal(payload, read[:len(payload)]) {
t.Fatal("payload provided from origin and read value are not the same")
// Read from the origin to ensure the payloads were received (in-order)
for i, payload := range payloads {
read := <-serverRead
if !slices.Equal(payload, read[:len(payload)]) {
t.Fatalf("payload[%d] provided from origin and read value are not the same (%x) and (%x)", i, payload[:16], read[:16])
}
}
_, more := <-serverRead
if more {
t.Fatalf("expected the session to have all of the origin payloads received: %d", len(serverRead))
}
assertContextClosed(t, ctx, done, cancel)
}
func TestSessionWrite(t *testing.T) {
defer leaktest.Check(t)()
for i := range 1280 {
payloads := makePayloads(i, 16)
testSessionWrite(t, payloads)
}
}
func TestSessionWrite_Max(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(1280)
testSessionWrite(t, payload)
}
func TestSessionWrite_Min(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(0)
testSessionWrite(t, payload)
}
func TestSessionServe_OriginMax(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(1280)
testSessionServe_Origin(t, payload)
}
func TestSessionServe_OriginMin(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(0)
testSessionServe_Origin(t, payload)
}
func testSessionServe_Origin(t *testing.T, payload []byte) {
func testSessionRead(t *testing.T, payloads [][]byte) {
log := zerolog.Nop()
origin, server := net.Pipe()
defer origin.Close()
@ -100,37 +101,42 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
done <- session.Serve(ctx)
}()
// Write from the origin server
_, err := server.Write(payload)
if err != nil {
t.Fatal(err)
}
select {
case data := <-eyeball.recvData:
// check received data matches provided from origin
expectedData := makePayload(1500)
_ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
copy(expectedData[17:], payload)
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
t.Fatal("expected datagram did not equal expected")
// Write from the origin server to the eyeball
go func() {
for _, payload := range payloads {
_, _ = server.Write(payload)
}
}()
// Read from the eyeball to ensure the payloads were received (in-order)
for i, payload := range payloads {
select {
case data := <-eyeball.recvData:
// check that the received data matches what the origin provided
expectedData := makePayload(1500)
_ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
copy(expectedData[17:], payload)
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
t.Fatalf("expected datagram[%d] did not equal expected", i)
}
case err := <-ctx.Done():
// we expect the payload to arrive before the session context is cancelled
t.Fatal(err)
}
cancel(errExpectedContextCanceled)
case err := <-ctx.Done():
// we expect the payload to return before the context to cancel on the session
t.Fatal(err)
}
err = <-done
if !errors.Is(err, context.Canceled) {
t.Fatal(err)
}
if !errors.Is(context.Cause(ctx), errExpectedContextCanceled) {
t.Fatal(err)
assertContextClosed(t, ctx, done, cancel)
}
func TestSessionRead(t *testing.T) {
defer leaktest.Check(t)()
for i := range 1280 {
payloads := makePayloads(i, 16)
testSessionRead(t, payloads)
}
}
func TestSessionServe_OriginTooLarge(t *testing.T) {
func TestSessionRead_OriginTooLarge(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop()
eyeball := newMockEyeball()
@ -317,6 +323,8 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
closeAfterIdle := 2 * time.Second
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
err := session.Serve(t.Context())
// Session should idle timeout if no reads or writes occur
if !errors.Is(err, v3.SessionIdleErr{}) {
t.Fatal(err)
}