Merge branch 'cloudflare:master' into master
This commit is contained in:
commit
b8511df478
6
Makefile
6
Makefile
|
|
@ -36,7 +36,7 @@ ifdef PACKAGE_MANAGER
|
|||
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
|
||||
endif
|
||||
|
||||
ifdef CONTAINER_BUILD
|
||||
ifdef CONTAINER_BUILD
|
||||
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual"
|
||||
endif
|
||||
|
||||
|
|
@ -122,7 +122,7 @@ ifneq ($(TARGET_ARM), )
|
|||
ARM_COMMAND := GOARM=$(TARGET_ARM)
|
||||
endif
|
||||
|
||||
ifeq ($(TARGET_ARM), 7)
|
||||
ifeq ($(TARGET_ARM), 7)
|
||||
PACKAGE_ARCH := armhf
|
||||
else
|
||||
PACKAGE_ARCH := $(TARGET_ARCH)
|
||||
|
|
@ -185,7 +185,7 @@ fuzz:
|
|||
@go test -fuzz=FuzzIPDecoder -fuzztime=600s ./packet
|
||||
@go test -fuzz=FuzzICMPDecoder -fuzztime=600s ./packet
|
||||
@go test -fuzz=FuzzSessionWrite -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzSessionServe -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzSessionRead -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzRegistrationDatagram -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzPayloadDatagram -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzRegistrationResponseDatagram -fuzztime=600s ./quic/v3
|
||||
|
|
|
|||
|
|
@ -14,4 +14,7 @@ spec:
|
|||
lifecycle: "Active"
|
||||
owner: "teams/tunnel-teams-routing"
|
||||
cf:
|
||||
compliance:
|
||||
fedramp-high: "pending"
|
||||
fedramp-moderate: "yes"
|
||||
FIPS: "required"
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const writeDeadlineUDP = 200 * time.Millisecond
|
||||
|
||||
// OriginTCPDialer provides a TCP dial operation to a requested address.
|
||||
type OriginTCPDialer interface {
|
||||
DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error)
|
||||
|
|
@ -141,6 +143,21 @@ func (d *Dialer) DialUDP(dest netip.AddrPort) (net.Conn, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
return &writeDeadlineConn{
|
||||
Conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// writeDeadlineConn is a wrapper around a net.Conn that sets a write deadline of 200ms.
|
||||
// This is to prevent the socket from blocking on the write operation if it were to occur. However,
|
||||
// we typically never expect this to occur except under high load or kernel issues.
|
||||
type writeDeadlineConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (w *writeDeadlineConn) Write(b []byte) (int, error) {
|
||||
if err := w.SetWriteDeadline(time.Now().Add(writeDeadlineUDP)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return w.Conn.Write(b)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package v3_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
|
@ -14,12 +15,18 @@ import (
|
|||
|
||||
func makePayload(size int) []byte {
|
||||
payload := make([]byte, size)
|
||||
for i := range len(payload) {
|
||||
payload[i] = 0xfc
|
||||
}
|
||||
_, _ = rand.Read(payload)
|
||||
return payload
|
||||
}
|
||||
|
||||
func makePayloads(size int, count int) [][]byte {
|
||||
payloads := make([][]byte, count)
|
||||
for i := range payloads {
|
||||
payloads[i] = makePayload(size)
|
||||
}
|
||||
return payloads
|
||||
}
|
||||
|
||||
func TestSessionRegistration_MarshalUnmarshal(t *testing.T) {
|
||||
payload := makePayload(1280)
|
||||
tests := []*v3.UDPSessionRegistrationDatagram{
|
||||
|
|
|
|||
197
quic/v3/muxer.go
197
quic/v3/muxer.go
|
|
@ -17,6 +17,9 @@ const (
|
|||
// Allocating a 16 channel buffer here allows for the writer to be slightly faster than the reader.
|
||||
// This has worked previously well for datagramv2, so we will start with this as well
|
||||
demuxChanCapacity = 16
|
||||
// This provides a small buffer for the PacketRouter to poll ICMP packets from the QUIC connection
|
||||
// before writing them to the origin.
|
||||
icmpDatagramChanCapacity = 128
|
||||
|
||||
logSrcKey = "src"
|
||||
logDstKey = "dst"
|
||||
|
|
@ -59,14 +62,15 @@ type QuicConnection interface {
|
|||
}
|
||||
|
||||
type datagramConn struct {
|
||||
conn QuicConnection
|
||||
index uint8
|
||||
sessionManager SessionManager
|
||||
icmpRouter ingress.ICMPRouter
|
||||
metrics Metrics
|
||||
logger *zerolog.Logger
|
||||
datagrams chan []byte
|
||||
readErrors chan error
|
||||
conn QuicConnection
|
||||
index uint8
|
||||
sessionManager SessionManager
|
||||
icmpRouter ingress.ICMPRouter
|
||||
metrics Metrics
|
||||
logger *zerolog.Logger
|
||||
datagrams chan []byte
|
||||
icmpDatagramChan chan *ICMPDatagram
|
||||
readErrors chan error
|
||||
|
||||
icmpEncoderPool sync.Pool // a pool of *packet.Encoder
|
||||
icmpDecoderPool sync.Pool
|
||||
|
|
@ -75,14 +79,15 @@ type datagramConn struct {
|
|||
func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
|
||||
log := logger.With().Uint8("datagramVersion", 3).Logger()
|
||||
return &datagramConn{
|
||||
conn: conn,
|
||||
index: index,
|
||||
sessionManager: sessionManager,
|
||||
icmpRouter: icmpRouter,
|
||||
metrics: metrics,
|
||||
logger: &log,
|
||||
datagrams: make(chan []byte, demuxChanCapacity),
|
||||
readErrors: make(chan error, 2),
|
||||
conn: conn,
|
||||
index: index,
|
||||
sessionManager: sessionManager,
|
||||
icmpRouter: icmpRouter,
|
||||
metrics: metrics,
|
||||
logger: &log,
|
||||
datagrams: make(chan []byte, demuxChanCapacity),
|
||||
icmpDatagramChan: make(chan *ICMPDatagram, icmpDatagramChanCapacity),
|
||||
readErrors: make(chan error, 2),
|
||||
icmpEncoderPool: sync.Pool{
|
||||
New: func() any {
|
||||
return packet.NewEncoder()
|
||||
|
|
@ -168,6 +173,9 @@ func (c *datagramConn) Serve(ctx context.Context) error {
|
|||
readCtx, cancel := context.WithCancel(connCtx)
|
||||
defer cancel()
|
||||
go c.pollDatagrams(readCtx)
|
||||
// Processing ICMP datagrams also monitors the reader context since the ICMP datagrams from the reader are the input
|
||||
// for the routine.
|
||||
go c.processICMPDatagrams(readCtx)
|
||||
for {
|
||||
// We make sure to monitor the context of cloudflared and the underlying connection to return if any errors occur.
|
||||
var datagram []byte
|
||||
|
|
@ -181,58 +189,59 @@ func (c *datagramConn) Serve(ctx context.Context) error {
|
|||
// Monitor for any hard errors from reading the connection
|
||||
case err := <-c.readErrors:
|
||||
return err
|
||||
// Otherwise, wait and dequeue datagrams as they come in
|
||||
// Wait and dequeue datagrams as they come in
|
||||
case d := <-c.datagrams:
|
||||
datagram = d
|
||||
}
|
||||
|
||||
// Each incoming datagram will be processed in a new go routine to handle the demuxing and action associated.
|
||||
go func() {
|
||||
typ, err := ParseDatagramType(datagram)
|
||||
typ, err := ParseDatagramType(datagram)
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ)
|
||||
continue
|
||||
}
|
||||
switch typ {
|
||||
case UDPSessionRegistrationType:
|
||||
reg := &UDPSessionRegistrationDatagram{}
|
||||
err := reg.UnmarshalBinary(datagram)
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ)
|
||||
return
|
||||
c.logger.Err(err).Msgf("unable to unmarshal session registration datagram")
|
||||
continue
|
||||
}
|
||||
switch typ {
|
||||
case UDPSessionRegistrationType:
|
||||
reg := &UDPSessionRegistrationDatagram{}
|
||||
err := reg.UnmarshalBinary(datagram)
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to unmarshal session registration datagram")
|
||||
return
|
||||
}
|
||||
logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger()
|
||||
// We bind the new session to the quic connection context instead of cloudflared context to allow for the
|
||||
// quic connection to close and close only the sessions bound to it. Closing of cloudflared will also
|
||||
// initiate the close of the quic connection, so we don't have to worry about the application context
|
||||
// in the scope of a session.
|
||||
c.handleSessionRegistrationDatagram(connCtx, reg, &logger)
|
||||
case UDPSessionPayloadType:
|
||||
payload := &UDPSessionPayloadDatagram{}
|
||||
err := payload.UnmarshalBinary(datagram)
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to unmarshal session payload datagram")
|
||||
return
|
||||
}
|
||||
logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger()
|
||||
c.handleSessionPayloadDatagram(payload, &logger)
|
||||
case ICMPType:
|
||||
packet := &ICMPDatagram{}
|
||||
err := packet.UnmarshalBinary(datagram)
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to unmarshal icmp datagram")
|
||||
return
|
||||
}
|
||||
c.handleICMPPacket(packet)
|
||||
case UDPSessionRegistrationResponseType:
|
||||
// cloudflared should never expect to receive UDP session responses as it will not initiate new
|
||||
// sessions towards the edge.
|
||||
c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType)
|
||||
return
|
||||
default:
|
||||
c.logger.Error().Msgf("unknown datagram type received: %d", typ)
|
||||
logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger()
|
||||
// We bind the new session to the quic connection context instead of cloudflared context to allow for the
|
||||
// quic connection to close and close only the sessions bound to it. Closing of cloudflared will also
|
||||
// initiate the close of the quic connection, so we don't have to worry about the application context
|
||||
// in the scope of a session.
|
||||
//
|
||||
// Additionally, we spin out the registration into a separate go routine to handle the Serve'ing of the
|
||||
// session in a separate routine from the demuxer.
|
||||
go c.handleSessionRegistrationDatagram(connCtx, reg, &logger)
|
||||
case UDPSessionPayloadType:
|
||||
payload := &UDPSessionPayloadDatagram{}
|
||||
err := payload.UnmarshalBinary(datagram)
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to unmarshal session payload datagram")
|
||||
continue
|
||||
}
|
||||
}()
|
||||
logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger()
|
||||
c.handleSessionPayloadDatagram(payload, &logger)
|
||||
case ICMPType:
|
||||
packet := &ICMPDatagram{}
|
||||
err := packet.UnmarshalBinary(datagram)
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to unmarshal icmp datagram")
|
||||
continue
|
||||
}
|
||||
c.handleICMPPacket(packet)
|
||||
case UDPSessionRegistrationResponseType:
|
||||
// cloudflared should never expect to receive UDP session responses as it will not initiate new
|
||||
// sessions towards the edge.
|
||||
c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType)
|
||||
continue
|
||||
default:
|
||||
c.logger.Error().Msgf("unknown datagram type received: %d", typ)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -243,24 +252,21 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
|
|||
Str(logDstKey, datagram.Dest.String()).
|
||||
Logger()
|
||||
session, err := c.sessionManager.RegisterSession(datagram, c)
|
||||
switch err {
|
||||
case nil:
|
||||
// Continue as normal
|
||||
case ErrSessionAlreadyRegistered:
|
||||
// Session is already registered and likely the response got lost
|
||||
c.handleSessionAlreadyRegistered(datagram.RequestID, &log)
|
||||
return
|
||||
case ErrSessionBoundToOtherConn:
|
||||
// Session is already registered but to a different connection
|
||||
c.handleSessionMigration(datagram.RequestID, &log)
|
||||
return
|
||||
case ErrSessionRegistrationRateLimited:
|
||||
// There are too many concurrent sessions so we return an error to force a retry later
|
||||
c.handleSessionRegistrationRateLimited(datagram, &log)
|
||||
return
|
||||
default:
|
||||
log.Err(err).Msg("flow registration failure")
|
||||
c.handleSessionRegistrationFailure(datagram.RequestID, &log)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrSessionAlreadyRegistered:
|
||||
// Session is already registered and likely the response got lost
|
||||
c.handleSessionAlreadyRegistered(datagram.RequestID, &log)
|
||||
case ErrSessionBoundToOtherConn:
|
||||
// Session is already registered but to a different connection
|
||||
c.handleSessionMigration(datagram.RequestID, &log)
|
||||
case ErrSessionRegistrationRateLimited:
|
||||
// There are too many concurrent sessions so we return an error to force a retry later
|
||||
c.handleSessionRegistrationRateLimited(datagram, &log)
|
||||
default:
|
||||
log.Err(err).Msg("flow registration failure")
|
||||
c.handleSessionRegistrationFailure(datagram.RequestID, &log)
|
||||
}
|
||||
return
|
||||
}
|
||||
log = log.With().Str(logSrcKey, session.LocalAddr().String()).Logger()
|
||||
|
|
@ -365,21 +371,42 @@ func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadD
|
|||
logger.Err(err).Msgf("unable to find flow")
|
||||
return
|
||||
}
|
||||
// We ignore the bytes written to the socket because any partial write must return an error.
|
||||
_, err = s.Write(datagram.Payload)
|
||||
if err != nil {
|
||||
logger.Err(err).Msgf("unable to write payload for the flow")
|
||||
return
|
||||
}
|
||||
s.Write(datagram.Payload)
|
||||
}
|
||||
|
||||
// Handles incoming ICMP datagrams.
|
||||
// Handles incoming ICMP datagrams into a serialized channel to be handled by a single consumer.
|
||||
func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
|
||||
if c.icmpRouter == nil {
|
||||
// ICMPRouter is disabled so we drop the current packet and ignore all incoming ICMP packets
|
||||
return
|
||||
}
|
||||
select {
|
||||
case c.icmpDatagramChan <- datagram:
|
||||
default:
|
||||
// If the ICMP datagram channel is full, drop any additional incoming.
|
||||
c.logger.Warn().Msg("failed to write icmp packet to origin: dropped")
|
||||
}
|
||||
}
|
||||
|
||||
// Consumes from the ICMP datagram channel to write out the ICMP requests to an origin.
|
||||
func (c *datagramConn) processICMPDatagrams(ctx context.Context) {
|
||||
if c.icmpRouter == nil {
|
||||
// ICMPRouter is disabled so we ignore all incoming ICMP packets
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
// If the provided context is closed we want to exit the write loop
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case datagram := <-c.icmpDatagramChan:
|
||||
c.writeICMPPacket(datagram)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *datagramConn) writeICMPPacket(datagram *ICMPDatagram) {
|
||||
// Decode the provided ICMPDatagram as an ICMP packet
|
||||
rawPacket := packet.RawPacket{Data: datagram.Payload}
|
||||
cachedDecoder := c.icmpDecoderPool.Get()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fortytw2/leaktest"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -92,7 +93,7 @@ func TestDatagramConn_New(t *testing.T) {
|
|||
DefaultDialer: testDefaultDialer,
|
||||
TCPWriteTimeout: 0,
|
||||
}, &log)
|
||||
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
conn := v3.NewDatagramConn(newMockQuicConn(t.Context()), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
if conn == nil {
|
||||
t.Fatal("expected valid connection")
|
||||
}
|
||||
|
|
@ -104,7 +105,9 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
|
|||
DefaultDialer: testDefaultDialer,
|
||||
TCPWriteTimeout: 0,
|
||||
}, &log)
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
payload := []byte{0xef, 0xef}
|
||||
|
|
@ -123,7 +126,9 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
|
|||
DefaultDialer: testDefaultDialer,
|
||||
TCPWriteTimeout: 0,
|
||||
}, &log)
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
|
||||
|
|
@ -149,7 +154,9 @@ func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
|
|||
DefaultDialer: testDefaultDialer,
|
||||
TCPWriteTimeout: 0,
|
||||
}, &log)
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
|
||||
|
|
@ -166,7 +173,9 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
|
|||
DefaultDialer: testDefaultDialer,
|
||||
TCPWriteTimeout: 0,
|
||||
}, &log)
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second)
|
||||
defer cancel()
|
||||
quic.ctx = ctx
|
||||
|
|
@ -195,15 +204,17 @@ func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
|
|||
|
||||
func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
sessionManager := &mockSessionManager{
|
||||
expectedRegErr: v3.ErrSessionRegistrationRateLimited,
|
||||
}
|
||||
conn := v3.NewDatagramConn(quic, sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
// Setup the muxer
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(context.Canceled)
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- conn.Serve(ctx)
|
||||
|
|
@ -223,9 +234,12 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
|
|||
|
||||
require.EqualValues(t, testRequestID, resp.RequestID)
|
||||
require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
|
||||
|
||||
assertContextClosed(t, ctx, done, cancel)
|
||||
}
|
||||
|
||||
func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
input []byte
|
||||
|
|
@ -250,7 +264,9 @@ func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
|
|||
t.Run(test.name, func(t *testing.T) {
|
||||
logOutput := new(LockedBuffer)
|
||||
log := zerolog.New(logOutput)
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
quic.send <- test.input
|
||||
conn := v3.NewDatagramConn(quic, &mockSessionManager{}, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
|
|
@ -289,8 +305,11 @@ func (b *LockedBuffer) String() string {
|
|||
}
|
||||
|
||||
func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
expectedErr := errors.New("unable to register session")
|
||||
sessionManager := mockSessionManager{expectedRegErr: expectedErr}
|
||||
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
|
@ -324,8 +343,11 @@ func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDatagramConnServe(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
session := newMockSession()
|
||||
sessionManager := mockSessionManager{session: &session}
|
||||
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
|
@ -372,8 +394,11 @@ func TestDatagramConnServe(t *testing.T) {
|
|||
// instances causes inteference resulting in multiple different raw packets being decoded
|
||||
// as the same decoded packet.
|
||||
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
session := newMockSession()
|
||||
sessionManager := mockSessionManager{session: &session}
|
||||
router := newMockICMPRouter()
|
||||
|
|
@ -413,10 +438,14 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
|
|||
wg := sync.WaitGroup{}
|
||||
var receivedPackets []*packet.ICMP
|
||||
go func() {
|
||||
for ctx.Err() == nil {
|
||||
icmpPacket := <-router.recv
|
||||
receivedPackets = append(receivedPackets, icmpPacket)
|
||||
wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case icmpPacket := <-router.recv:
|
||||
receivedPackets = append(receivedPackets, icmpPacket)
|
||||
wg.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
@ -452,8 +481,11 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDatagramConnServe_RegisterTwice(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
session := newMockSession()
|
||||
sessionManager := mockSessionManager{session: &session}
|
||||
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
|
@ -514,12 +546,17 @@ func TestDatagramConnServe_RegisterTwice(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDatagramConnServe_MigrateConnection(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
session := newMockSession()
|
||||
sessionManager := mockSessionManager{session: &session}
|
||||
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
quic2 := newMockQuicConn()
|
||||
conn2Ctx, conn2Cancel := context.WithCancelCause(t.Context())
|
||||
defer conn2Cancel(context.Canceled)
|
||||
quic2 := newMockQuicConn(conn2Ctx)
|
||||
conn2 := v3.NewDatagramConn(quic2, &sessionManager, &noopICMPRouter{}, 1, &noopMetrics{}, &log)
|
||||
|
||||
// Setup the muxer
|
||||
|
|
@ -597,8 +634,11 @@ func TestDatagramConnServe_MigrateConnection(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
// mockSessionManager will return the ErrSessionNotFound for any session attempting to be queried by the muxer
|
||||
sessionManager := mockSessionManager{session: nil, expectedGetErr: v3.ErrSessionNotFound}
|
||||
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
|
@ -624,9 +664,12 @@ func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) {
|
|||
assertContextClosed(t, ctx, done, cancel)
|
||||
}
|
||||
|
||||
func TestDatagramConnServe_Payload(t *testing.T) {
|
||||
func TestDatagramConnServe_Payloads(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
session := newMockSession()
|
||||
sessionManager := mockSessionManager{session: &session}
|
||||
conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
|
@ -639,15 +682,26 @@ func TestDatagramConnServe_Payload(t *testing.T) {
|
|||
done <- conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
// Send new session registration
|
||||
expectedPayload := []byte{0xef, 0xef}
|
||||
datagram := newSessionPayloadDatagram(testRequestID, expectedPayload)
|
||||
quic.send <- datagram
|
||||
// Send session payloads
|
||||
expectedPayloads := makePayloads(256, 16)
|
||||
go func() {
|
||||
for _, payload := range expectedPayloads {
|
||||
datagram := newSessionPayloadDatagram(testRequestID, payload)
|
||||
quic.send <- datagram
|
||||
}
|
||||
}()
|
||||
|
||||
// Session should receive the payload
|
||||
payload := <-session.recv
|
||||
if !slices.Equal(expectedPayload, payload) {
|
||||
t.Fatalf("expected session receieve the payload sent via the muxer")
|
||||
// Session should receive the payloads (in-order)
|
||||
for i, payload := range expectedPayloads {
|
||||
select {
|
||||
case recv := <-session.recv:
|
||||
if !slices.Equal(recv, payload) {
|
||||
t.Fatalf("expected session receieve the payload[%d] sent via the muxer: (%x) (%x)", i, recv[:16], payload[:16])
|
||||
}
|
||||
case err := <-ctx.Done():
|
||||
// we expect the payload to return before the context to cancel on the session
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel the muxer Serve context and make sure it closes with the expected error
|
||||
|
|
@ -655,8 +709,11 @@ func TestDatagramConnServe_Payload(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDatagramConnServe_ICMPDatagram_TTLDecremented(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
router := newMockICMPRouter()
|
||||
conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log)
|
||||
|
||||
|
|
@ -701,8 +758,11 @@ func TestDatagramConnServe_ICMPDatagram_TTLDecremented(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDatagramConnServe_ICMPDatagram_TTLExceeded(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
connCtx, connCancel := context.WithCancelCause(t.Context())
|
||||
defer connCancel(context.Canceled)
|
||||
quic := newMockQuicConn(connCtx)
|
||||
router := newMockICMPRouter()
|
||||
conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log)
|
||||
|
||||
|
|
@ -821,9 +881,9 @@ type mockQuicConn struct {
|
|||
recv chan []byte
|
||||
}
|
||||
|
||||
func newMockQuicConn() *mockQuicConn {
|
||||
func newMockQuicConn(ctx context.Context) *mockQuicConn {
|
||||
return &mockQuicConn{
|
||||
ctx: context.Background(),
|
||||
ctx: ctx,
|
||||
send: make(chan []byte, 1),
|
||||
recv: make(chan []byte, 1),
|
||||
}
|
||||
|
|
@ -841,7 +901,12 @@ func (m *mockQuicConn) SendDatagram(payload []byte) error {
|
|||
}
|
||||
|
||||
func (m *mockQuicConn) ReceiveDatagram(_ context.Context) ([]byte, error) {
|
||||
return <-m.send, nil
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return nil, m.ctx.Err()
|
||||
case b := <-m.send:
|
||||
return b, nil
|
||||
}
|
||||
}
|
||||
|
||||
type mockQuicConnReadError struct {
|
||||
|
|
@ -905,11 +970,10 @@ func (m *mockSession) Serve(ctx context.Context) error {
|
|||
return v3.SessionCloseErr
|
||||
}
|
||||
|
||||
func (m *mockSession) Write(payload []byte) (n int, err error) {
|
||||
func (m *mockSession) Write(payload []byte) {
|
||||
b := make([]byte, len(payload))
|
||||
copy(b, payload)
|
||||
m.recv <- b
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (m *mockSession) Close() error {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
|
@ -22,6 +23,11 @@ const (
|
|||
// this value (maxDatagramPayloadLen).
|
||||
maxOriginUDPPacketSize = 1500
|
||||
|
||||
// The maximum amount of datagrams a session will queue up before it begins dropping datagrams.
|
||||
// This channel buffer is small because we assume that the dedicated writer to the origin is typically
|
||||
// fast enought to keep the channel empty.
|
||||
writeChanCapacity = 16
|
||||
|
||||
logFlowID = "flowID"
|
||||
logPacketSizeKey = "packetSize"
|
||||
)
|
||||
|
|
@ -49,7 +55,7 @@ func newSessionIdleErr(timeout time.Duration) error {
|
|||
}
|
||||
|
||||
type Session interface {
|
||||
io.WriteCloser
|
||||
io.Closer
|
||||
ID() RequestID
|
||||
ConnectionID() uint8
|
||||
RemoteAddr() net.Addr
|
||||
|
|
@ -58,6 +64,7 @@ type Session interface {
|
|||
Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger)
|
||||
// Serve starts the event loop for processing UDP packets
|
||||
Serve(ctx context.Context) error
|
||||
Write(payload []byte)
|
||||
}
|
||||
|
||||
type session struct {
|
||||
|
|
@ -67,12 +74,18 @@ type session struct {
|
|||
originAddr net.Addr
|
||||
localAddr net.Addr
|
||||
eyeball atomic.Pointer[DatagramConn]
|
||||
writeChan chan []byte
|
||||
// activeAtChan is used to communicate the last read/write time
|
||||
activeAtChan chan time.Time
|
||||
closeChan chan error
|
||||
contextChan chan context.Context
|
||||
metrics Metrics
|
||||
log *zerolog.Logger
|
||||
errChan chan error
|
||||
// The close channel signal only exists for the write loop because the read loop is always waiting on a read
|
||||
// from the UDP socket to the origin. To close the read loop we close the socket.
|
||||
// Additionally, we can't close the writeChan to indicate that writes are complete because the producer (edge)
|
||||
// side may still be trying to write to this session.
|
||||
closeWrite chan struct{}
|
||||
contextChan chan context.Context
|
||||
metrics Metrics
|
||||
log *zerolog.Logger
|
||||
|
||||
// A special close function that we wrap with sync.Once to make sure it is only called once
|
||||
closeFn func() error
|
||||
|
|
@ -89,10 +102,12 @@ func NewSession(
|
|||
log *zerolog.Logger,
|
||||
) Session {
|
||||
logger := log.With().Str(logFlowID, id.String()).Logger()
|
||||
// closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
|
||||
// write to the channel without blocking since there is only ever one value read from the closeChan by the
|
||||
writeChan := make(chan []byte, writeChanCapacity)
|
||||
// errChan has three slots to allow for all writers (the closeFn, the read loop and the write loop) to
|
||||
// write to the channel without blocking since there is only ever one value read from the errChan by the
|
||||
// waitForCloseCondition.
|
||||
closeChan := make(chan error, 2)
|
||||
errChan := make(chan error, 3)
|
||||
closeWrite := make(chan struct{})
|
||||
session := &session{
|
||||
id: id,
|
||||
closeAfterIdle: closeAfterIdle,
|
||||
|
|
@ -100,10 +115,12 @@ func NewSession(
|
|||
originAddr: originAddr,
|
||||
localAddr: localAddr,
|
||||
eyeball: atomic.Pointer[DatagramConn]{},
|
||||
writeChan: writeChan,
|
||||
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
||||
// drop instead of blocking because last active time only needs to be an approximation
|
||||
activeAtChan: make(chan time.Time, 1),
|
||||
closeChan: closeChan,
|
||||
errChan: errChan,
|
||||
closeWrite: closeWrite,
|
||||
// contextChan is an unbounded channel to help enforce one active migration of a session at a time.
|
||||
contextChan: make(chan context.Context),
|
||||
metrics: metrics,
|
||||
|
|
@ -111,9 +128,12 @@ func NewSession(
|
|||
closeFn: sync.OnceValue(func() error {
|
||||
// We don't want to block on sending to the close channel if it is already full
|
||||
select {
|
||||
case closeChan <- SessionCloseErr:
|
||||
case errChan <- SessionCloseErr:
|
||||
default:
|
||||
}
|
||||
// Indicate to the write loop that the session is now closed
|
||||
close(closeWrite)
|
||||
// Close the socket directly to unblock the read loop and cause it to also end
|
||||
return origin.Close()
|
||||
}),
|
||||
}
|
||||
|
|
@ -154,66 +174,112 @@ func (s *session) Migrate(eyeball DatagramConn, ctx context.Context, logger *zer
|
|||
}
|
||||
|
||||
func (s *session) Serve(ctx context.Context) error {
|
||||
go func() {
|
||||
// QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
||||
// This makes it safe to share readBuffer between iterations
|
||||
readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{}
|
||||
// To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
|
||||
// the required datagram header information. We can reuse this buffer for this session since the header is the
|
||||
// same for the each read.
|
||||
_ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen])
|
||||
for {
|
||||
// Read from the origin UDP socket
|
||||
n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
s.log.Debug().Msgf("flow (origin) connection closed: %v", err)
|
||||
}
|
||||
s.closeChan <- err
|
||||
return
|
||||
}
|
||||
if n < 0 {
|
||||
s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped")
|
||||
continue
|
||||
}
|
||||
if n > maxDatagramPayloadLen {
|
||||
connectionIndex := s.ConnectionID()
|
||||
s.metrics.PayloadTooLarge(connectionIndex)
|
||||
s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped")
|
||||
continue
|
||||
}
|
||||
// We need to synchronize on the eyeball in-case that the connection was migrated. This should be rarely a point
|
||||
// of lock contention, as a migration can only happen during startup of a session before traffic flow.
|
||||
eyeball := *(s.eyeball.Load())
|
||||
// Sending a packet to the session does block on the [quic.Connection], however, this is okay because it
|
||||
// will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge.
|
||||
err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n])
|
||||
if err != nil {
|
||||
s.closeChan <- err
|
||||
return
|
||||
}
|
||||
// Mark the session as active since we proxied a valid packet from the origin.
|
||||
s.markActive()
|
||||
}
|
||||
}()
|
||||
go s.writeLoop()
|
||||
go s.readLoop()
|
||||
return s.waitForCloseCondition(ctx, s.closeAfterIdle)
|
||||
}
|
||||
|
||||
func (s *session) Write(payload []byte) (n int, err error) {
|
||||
n, err = s.origin.Write(payload)
|
||||
if err != nil {
|
||||
s.log.Err(err).Msg("failed to write payload to flow (remote)")
|
||||
return n, err
|
||||
// Read datagrams from the origin and write them to the connection.
|
||||
func (s *session) readLoop() {
|
||||
// QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
||||
// This makes it safe to share readBuffer between iterations
|
||||
readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{}
|
||||
// To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
|
||||
// the required datagram header information. We can reuse this buffer for this session since the header is the
|
||||
// same for the each read.
|
||||
_ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen])
|
||||
for {
|
||||
// Read from the origin UDP socket
|
||||
n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
|
||||
if err != nil {
|
||||
if isConnectionClosed(err) {
|
||||
s.log.Debug().Msgf("flow (read) connection closed: %v", err)
|
||||
}
|
||||
s.closeSession(err)
|
||||
return
|
||||
}
|
||||
if n < 0 {
|
||||
s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped")
|
||||
continue
|
||||
}
|
||||
if n > maxDatagramPayloadLen {
|
||||
connectionIndex := s.ConnectionID()
|
||||
s.metrics.PayloadTooLarge(connectionIndex)
|
||||
s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped")
|
||||
continue
|
||||
}
|
||||
// We need to synchronize on the eyeball in-case that the connection was migrated. This should be rarely a point
|
||||
// of lock contention, as a migration can only happen during startup of a session before traffic flow.
|
||||
eyeball := *(s.eyeball.Load())
|
||||
// Sending a packet to the session does block on the [quic.Connection], however, this is okay because it
|
||||
// will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge.
|
||||
err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n])
|
||||
if err != nil {
|
||||
s.closeSession(err)
|
||||
return
|
||||
}
|
||||
// Mark the session as active since we proxied a valid packet from the origin.
|
||||
s.markActive()
|
||||
}
|
||||
// Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
|
||||
if n < len(payload) {
|
||||
s.log.Err(io.ErrShortWrite).Msg("failed to write the full payload to flow (remote)")
|
||||
return n, io.ErrShortWrite
|
||||
}
|
||||
|
||||
func (s *session) Write(payload []byte) {
|
||||
select {
|
||||
case s.writeChan <- payload:
|
||||
default:
|
||||
s.log.Error().Msg("failed to write flow payload to origin: dropped")
|
||||
}
|
||||
}
|
||||
|
||||
// Read datagrams from the write channel to the origin.
|
||||
func (s *session) writeLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-s.closeWrite:
|
||||
// When the closeWrite channel is closed, we will no longer write to the origin and end this
|
||||
// goroutine since the session is now closed.
|
||||
return
|
||||
case payload := <-s.writeChan:
|
||||
n, err := s.origin.Write(payload)
|
||||
if err != nil {
|
||||
// Check if this is a write deadline exceeded to the connection
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
s.log.Warn().Err(err).Msg("flow (write) deadline exceeded: dropping packet")
|
||||
continue
|
||||
}
|
||||
if isConnectionClosed(err) {
|
||||
s.log.Debug().Msgf("flow (write) connection closed: %v", err)
|
||||
}
|
||||
s.log.Err(err).Msg("failed to write flow payload to origin")
|
||||
s.closeSession(err)
|
||||
// If we fail to write to the origin socket, we need to end the writer and close the session
|
||||
return
|
||||
}
|
||||
// Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
|
||||
if n < len(payload) {
|
||||
s.log.Err(io.ErrShortWrite).Msg("failed to write the full flow payload to origin")
|
||||
continue
|
||||
}
|
||||
// Mark the session as active since we successfully proxied a packet to the origin.
|
||||
s.markActive()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isConnectionClosed(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
// Send an error to the error channel to report that an error has either happened on the tunnel or origin side of the
|
||||
// proxied connection.
|
||||
func (s *session) closeSession(err error) {
|
||||
select {
|
||||
case s.errChan <- err:
|
||||
default:
|
||||
// In the case that the errChan is already full, we will skip over it and return as to not block
|
||||
// the caller because we should start cleaning up the session.
|
||||
s.log.Warn().Msg("error channel was full")
|
||||
}
|
||||
// Mark the session as active since we proxied a packet to the origin.
|
||||
s.markActive()
|
||||
return n, err
|
||||
}
|
||||
|
||||
// ResetIdleTimer will restart the current idle timer.
|
||||
|
|
@ -240,7 +306,8 @@ func (s *session) Close() error {
|
|||
|
||||
func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
|
||||
connCtx := ctx
|
||||
// Closing the session at the end cancels read so Serve() can return
|
||||
// Closing the session at the end cancels read so Serve() can return, additionally, it closes the
|
||||
// closeWrite channel which indicates to the write loop to return.
|
||||
defer s.Close()
|
||||
if closeAfterIdle == 0 {
|
||||
// Provided that the default caller doesn't specify one
|
||||
|
|
@ -260,7 +327,9 @@ func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time
|
|||
// still be active on the existing connection.
|
||||
connCtx = newContext
|
||||
continue
|
||||
case reason := <-s.closeChan:
|
||||
case reason := <-s.errChan:
|
||||
// Any error returned here is from the read or write loops indicating that it can no longer process datagrams
|
||||
// and as such the session needs to close.
|
||||
return reason
|
||||
case <-checkIdleTimer.C:
|
||||
// The check idle timer will only return after an idle period since the last active
|
||||
|
|
|
|||
|
|
@ -4,20 +4,24 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
// FuzzSessionWrite verifies that we don't run into any panics when writing variable sized payloads to the origin.
|
||||
// FuzzSessionWrite verifies that we don't run into any panics when writing a single variable sized payload to the origin.
|
||||
func FuzzSessionWrite(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
testSessionWrite(t, b)
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzSessionServe verifies that we don't run into any panics when reading variable sized payloads from the origin.
|
||||
func FuzzSessionServe(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
// The origin transport read is bound to 1280 bytes
|
||||
if len(b) > 1280 {
|
||||
b = b[:1280]
|
||||
}
|
||||
testSessionServe_Origin(t, b)
|
||||
testSessionWrite(t, [][]byte{b})
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzSessionRead verifies that we don't run into any panics when reading a single variable sized payload from the origin.
|
||||
func FuzzSessionRead(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
// The origin transport read is bound to 1280 bytes
|
||||
if len(b) > 1280 {
|
||||
b = b[:1280]
|
||||
}
|
||||
testSessionRead(t, [][]byte{b})
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,60 +31,61 @@ func TestSessionNew(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testSessionWrite(t *testing.T, payload []byte) {
|
||||
func testSessionWrite(t *testing.T, payloads [][]byte) {
|
||||
log := zerolog.Nop()
|
||||
origin, server := net.Pipe()
|
||||
defer origin.Close()
|
||||
defer server.Close()
|
||||
// Start origin server read
|
||||
serverRead := make(chan []byte, 1)
|
||||
// Start origin server reads
|
||||
serverRead := make(chan []byte, len(payloads))
|
||||
go func() {
|
||||
read := make([]byte, 1500)
|
||||
_, _ = server.Read(read[:])
|
||||
serverRead <- read
|
||||
for range len(payloads) {
|
||||
buf := make([]byte, 1500)
|
||||
_, _ = server.Read(buf[:])
|
||||
serverRead <- buf
|
||||
}
|
||||
close(serverRead)
|
||||
}()
|
||||
// Create session and write to origin
|
||||
|
||||
// Create a session
|
||||
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
n, err := session.Write(payload)
|
||||
defer session.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatal("unable to write the whole payload")
|
||||
// Start the Serve to begin the writeLoop
|
||||
ctx, cancel := context.WithCancelCause(t.Context())
|
||||
defer cancel(context.Canceled)
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
done <- session.Serve(ctx)
|
||||
}()
|
||||
// Write the payloads to the session
|
||||
for _, payload := range payloads {
|
||||
session.Write(payload)
|
||||
}
|
||||
|
||||
read := <-serverRead
|
||||
if !slices.Equal(payload, read[:len(payload)]) {
|
||||
t.Fatal("payload provided from origin and read value are not the same")
|
||||
// Read from the origin to ensure the payloads were received (in-order)
|
||||
for i, payload := range payloads {
|
||||
read := <-serverRead
|
||||
if !slices.Equal(payload, read[:len(payload)]) {
|
||||
t.Fatalf("payload[%d] provided from origin and read value are not the same (%x) and (%x)", i, payload[:16], read[:16])
|
||||
}
|
||||
}
|
||||
_, more := <-serverRead
|
||||
if more {
|
||||
t.Fatalf("expected the session to have all of the origin payloads received: %d", len(serverRead))
|
||||
}
|
||||
|
||||
assertContextClosed(t, ctx, done, cancel)
|
||||
}
|
||||
|
||||
func TestSessionWrite(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
for i := range 1280 {
|
||||
payloads := makePayloads(i, 16)
|
||||
testSessionWrite(t, payloads)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionWrite_Max(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
payload := makePayload(1280)
|
||||
testSessionWrite(t, payload)
|
||||
}
|
||||
|
||||
func TestSessionWrite_Min(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
payload := makePayload(0)
|
||||
testSessionWrite(t, payload)
|
||||
}
|
||||
|
||||
func TestSessionServe_OriginMax(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
payload := makePayload(1280)
|
||||
testSessionServe_Origin(t, payload)
|
||||
}
|
||||
|
||||
func TestSessionServe_OriginMin(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
payload := makePayload(0)
|
||||
testSessionServe_Origin(t, payload)
|
||||
}
|
||||
|
||||
func testSessionServe_Origin(t *testing.T, payload []byte) {
|
||||
func testSessionRead(t *testing.T, payloads [][]byte) {
|
||||
log := zerolog.Nop()
|
||||
origin, server := net.Pipe()
|
||||
defer origin.Close()
|
||||
|
|
@ -100,37 +101,42 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
|
|||
done <- session.Serve(ctx)
|
||||
}()
|
||||
|
||||
// Write from the origin server
|
||||
_, err := server.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case data := <-eyeball.recvData:
|
||||
// check received data matches provided from origin
|
||||
expectedData := makePayload(1500)
|
||||
_ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
|
||||
copy(expectedData[17:], payload)
|
||||
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
|
||||
t.Fatal("expected datagram did not equal expected")
|
||||
// Write from the origin server to the eyeball
|
||||
go func() {
|
||||
for _, payload := range payloads {
|
||||
_, _ = server.Write(payload)
|
||||
}
|
||||
}()
|
||||
|
||||
// Read from the eyeball to ensure the payloads were received (in-order)
|
||||
for i, payload := range payloads {
|
||||
select {
|
||||
case data := <-eyeball.recvData:
|
||||
// check received data matches provided from origin
|
||||
expectedData := makePayload(1500)
|
||||
_ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
|
||||
copy(expectedData[17:], payload)
|
||||
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
|
||||
t.Fatalf("expected datagram[%d] did not equal expected", i)
|
||||
}
|
||||
case err := <-ctx.Done():
|
||||
// we expect the payload to return before the context to cancel on the session
|
||||
t.Fatal(err)
|
||||
}
|
||||
cancel(errExpectedContextCanceled)
|
||||
case err := <-ctx.Done():
|
||||
// we expect the payload to return before the context to cancel on the session
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = <-done
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !errors.Is(context.Cause(ctx), errExpectedContextCanceled) {
|
||||
t.Fatal(err)
|
||||
assertContextClosed(t, ctx, done, cancel)
|
||||
}
|
||||
|
||||
func TestSessionRead(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
for i := range 1280 {
|
||||
payloads := makePayloads(i, 16)
|
||||
testSessionRead(t, payloads)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionServe_OriginTooLarge(t *testing.T) {
|
||||
func TestSessionRead_OriginTooLarge(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
eyeball := newMockEyeball()
|
||||
|
|
@ -317,6 +323,8 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
|
|||
closeAfterIdle := 2 * time.Second
|
||||
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
err := session.Serve(t.Context())
|
||||
|
||||
// Session should idle timeout if no reads or writes occur
|
||||
if !errors.Is(err, v3.SessionIdleErr{}) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue