TUN-9322: Add metric for unsupported RPC commands for datagram v3

Additionally adds support for the connection index as a label for the
datagram v3 specific tunnel metrics.

Closes TUN-9322
This commit is contained in:
Devin Carr 2025-05-13 16:11:09 +00:00
parent ce27840573
commit 02705c44b2
10 changed files with 133 additions and 96 deletions

View File

@ -2,11 +2,11 @@ package connection
import ( import (
"context" "context"
"fmt"
"net" "net"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -16,10 +16,17 @@ import (
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
var (
ErrUnsupportedRPCUDPRegistration = errors.New("datagram v3 does not support RegisterUdpSession RPC")
ErrUnsupportedRPCUDPUnregistration = errors.New("datagram v3 does not support UnregisterUdpSession RPC")
)
type datagramV3Connection struct { type datagramV3Connection struct {
conn quic.Connection conn quic.Connection
index uint8
// datagramMuxer mux/demux datagrams from quic connection // datagramMuxer mux/demux datagrams from quic connection
datagramMuxer cfdquic.DatagramConn datagramMuxer cfdquic.DatagramConn
metrics cfdquic.Metrics
logger *zerolog.Logger logger *zerolog.Logger
} }
@ -40,7 +47,9 @@ func NewDatagramV3Connection(ctx context.Context,
return &datagramV3Connection{ return &datagramV3Connection{
conn, conn,
index,
datagramMuxer, datagramMuxer,
metrics,
logger, logger,
} }
} }
@ -50,9 +59,11 @@ func (d *datagramV3Connection) Serve(ctx context.Context) error {
} }
func (d *datagramV3Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) { func (d *datagramV3Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) {
return nil, fmt.Errorf("datagram v3 does not support RegisterUdpSession RPC") d.metrics.UnsupportedRemoteCommand(d.index, "register_udp_session")
return nil, ErrUnsupportedRPCUDPRegistration
} }
func (d *datagramV3Connection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { func (d *datagramV3Connection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
return fmt.Errorf("datagram v3 does not support UnregisterUdpSession RPC") d.metrics.UnsupportedRemoteCommand(d.index, "unregister_udp_session")
return ErrUnsupportedRPCUDPUnregistration
} }

View File

@ -84,7 +84,7 @@ func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time
// Closing dstConn cancels read so dstToTransport routine in Serve() can return // Closing dstConn cancels read so dstToTransport routine in Serve() can return
defer s.dstConn.Close() defer s.dstConn.Close()
if closeAfterIdle == 0 { if closeAfterIdle == 0 {
// provide deafult is caller doesn't specify one // provide default is caller doesn't specify one
closeAfterIdle = defaultCloseIdleAfter closeAfterIdle = defaultCloseIdleAfter
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -54,22 +55,22 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
closedByRemote, err := session.Serve(ctx, closeAfterIdle) closedByRemote, err := session.Serve(ctx, closeAfterIdle)
switch closeBy { switch closeBy {
case closeByContext: case closeByContext:
require.Equal(t, context.Canceled, err) assert.Equal(t, context.Canceled, err)
require.False(t, closedByRemote) assert.False(t, closedByRemote)
case closeByCallingClose: case closeByCallingClose:
require.Equal(t, localCloseReason, err) assert.Equal(t, localCloseReason, err)
require.Equal(t, localCloseReason.byRemote, closedByRemote) assert.Equal(t, localCloseReason.byRemote, closedByRemote)
case closeByTimeout: case closeByTimeout:
require.Equal(t, SessionIdleErr(closeAfterIdle), err) assert.Equal(t, SessionIdleErr(closeAfterIdle), err)
require.False(t, closedByRemote) assert.False(t, closedByRemote)
} }
close(sessionDone) close(sessionDone)
}() }()
go func() { go func() {
n, err := session.transportToDst(payload) n, err := session.transportToDst(payload)
require.NoError(t, err) assert.NoError(t, err)
require.Equal(t, len(payload), n) assert.Equal(t, len(payload), n)
}() }()
readBuffer := make([]byte, len(payload)+1) readBuffer := make([]byte, len(payload)+1)
@ -84,6 +85,8 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
cancel() cancel()
case closeByCallingClose: case closeByCallingClose:
session.close(localCloseReason) session.close(localCloseReason)
default:
// ignore
} }
<-sessionDone <-sessionDone
@ -128,7 +131,7 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
errGroup, ctx := errgroup.WithContext(ctx) errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
session.Serve(ctx, closeAfterIdle) _, _ = session.Serve(ctx, closeAfterIdle)
if time.Now().Before(startTime.Add(activeTime)) { if time.Now().Before(startTime.Add(activeTime)) {
return fmt.Errorf("session closed while it's still active") return fmt.Errorf("session closed while it's still active")
} }

View File

@ -12,10 +12,13 @@ import (
const ( const (
namespace = "quic" namespace = "quic"
ConnectionIndexMetricLabel = "conn_index"
frameTypeMetricLabel = "frame_type"
packetTypeMetricLabel = "packet_type"
reasonMetricLabel = "reason"
) )
var ( var (
clientConnLabels = []string{"conn_index"}
clientMetrics = struct { clientMetrics = struct {
totalConnections prometheus.Counter totalConnections prometheus.Counter
closedConnections prometheus.Counter closedConnections prometheus.Counter
@ -35,7 +38,7 @@ var (
congestionState *prometheus.GaugeVec congestionState *prometheus.GaugeVec
}{ }{
totalConnections: prometheus.NewCounter( totalConnections: prometheus.NewCounter(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "total_connections", Name: "total_connections",
@ -43,7 +46,7 @@ var (
}, },
), ),
closedConnections: prometheus.NewCounter( closedConnections: prometheus.NewCounter(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "closed_connections", Name: "closed_connections",
@ -57,70 +60,70 @@ var (
Name: "max_udp_payload", Name: "max_udp_payload",
Help: "Maximum UDP payload size in bytes for a QUIC packet", Help: "Maximum UDP payload size in bytes for a QUIC packet",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
sentFrames: prometheus.NewCounterVec( sentFrames: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "sent_frames", Name: "sent_frames",
Help: "Number of frames that have been sent through a connection", Help: "Number of frames that have been sent through a connection",
}, },
append(clientConnLabels, "frame_type"), []string{ConnectionIndexMetricLabel, frameTypeMetricLabel},
), ),
sentBytes: prometheus.NewCounterVec( sentBytes: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "sent_bytes", Name: "sent_bytes",
Help: "Number of bytes that have been sent through a connection", Help: "Number of bytes that have been sent through a connection",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
receivedFrames: prometheus.NewCounterVec( receivedFrames: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "received_frames", Name: "received_frames",
Help: "Number of frames that have been received through a connection", Help: "Number of frames that have been received through a connection",
}, },
append(clientConnLabels, "frame_type"), []string{ConnectionIndexMetricLabel, frameTypeMetricLabel},
), ),
receivedBytes: prometheus.NewCounterVec( receivedBytes: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "receive_bytes", Name: "receive_bytes",
Help: "Number of bytes that have been received through a connection", Help: "Number of bytes that have been received through a connection",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
bufferedPackets: prometheus.NewCounterVec( bufferedPackets: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "buffered_packets", Name: "buffered_packets",
Help: "Number of bytes that have been buffered on a connection", Help: "Number of bytes that have been buffered on a connection",
}, },
append(clientConnLabels, "packet_type"), []string{ConnectionIndexMetricLabel, packetTypeMetricLabel},
), ),
droppedPackets: prometheus.NewCounterVec( droppedPackets: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "dropped_packets", Name: "dropped_packets",
Help: "Number of bytes that have been dropped on a connection", Help: "Number of bytes that have been dropped on a connection",
}, },
append(clientConnLabels, "packet_type", "reason"), []string{ConnectionIndexMetricLabel, packetTypeMetricLabel, reasonMetricLabel},
), ),
lostPackets: prometheus.NewCounterVec( lostPackets: prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "lost_packets", Name: "lost_packets",
Help: "Number of packets that have been lost from a connection", Help: "Number of packets that have been lost from a connection",
}, },
append(clientConnLabels, "reason"), []string{ConnectionIndexMetricLabel, reasonMetricLabel},
), ),
minRTT: prometheus.NewGaugeVec( minRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
@ -129,7 +132,7 @@ var (
Name: "min_rtt", Name: "min_rtt",
Help: "Lowest RTT measured on a connection in millisec", Help: "Lowest RTT measured on a connection in millisec",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
latestRTT: prometheus.NewGaugeVec( latestRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
@ -138,7 +141,7 @@ var (
Name: "latest_rtt", Name: "latest_rtt",
Help: "Latest RTT measured on a connection", Help: "Latest RTT measured on a connection",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
smoothedRTT: prometheus.NewGaugeVec( smoothedRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
@ -147,7 +150,7 @@ var (
Name: "smoothed_rtt", Name: "smoothed_rtt",
Help: "Calculated smoothed RTT measured on a connection in millisec", Help: "Calculated smoothed RTT measured on a connection in millisec",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
mtu: prometheus.NewGaugeVec( mtu: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
@ -156,7 +159,7 @@ var (
Name: "mtu", Name: "mtu",
Help: "Current maximum transmission unit (MTU) of a connection", Help: "Current maximum transmission unit (MTU) of a connection",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
congestionWindow: prometheus.NewGaugeVec( congestionWindow: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
@ -165,7 +168,7 @@ var (
Name: "congestion_window", Name: "congestion_window",
Help: "Current congestion window size", Help: "Current congestion window size",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
congestionState: prometheus.NewGaugeVec( congestionState: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
@ -174,13 +177,13 @@ var (
Name: "congestion_state", Name: "congestion_state",
Help: "Current congestion control state. See https://pkg.go.dev/github.com/quic-go/quic-go@v0.45.0/logging#CongestionState for what each value maps to", Help: "Current congestion control state. See https://pkg.go.dev/github.com/quic-go/quic-go@v0.45.0/logging#CongestionState for what each value maps to",
}, },
clientConnLabels, []string{ConnectionIndexMetricLabel},
), ),
} }
registerClient = sync.Once{} registerClient = sync.Once{}
packetTooBigDropped = prometheus.NewCounter(prometheus.CounterOpts{ packetTooBigDropped = prometheus.NewCounter(prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "packet_too_big_dropped", Name: "packet_too_big_dropped",

View File

@ -2,82 +2,98 @@ package v3
import ( import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/cloudflare/cloudflared/quic"
) )
const ( const (
namespace = "cloudflared" namespace = "cloudflared"
subsystem = "udp" subsystem = "udp"
commandMetricLabel = "command"
) )
type Metrics interface { type Metrics interface {
IncrementFlows() IncrementFlows(connIndex uint8)
DecrementFlows() DecrementFlows(connIndex uint8)
PayloadTooLarge() PayloadTooLarge(connIndex uint8)
RetryFlowResponse() RetryFlowResponse(connIndex uint8)
MigrateFlow() MigrateFlow(connIndex uint8)
UnsupportedRemoteCommand(connIndex uint8, command string)
} }
type metrics struct { type metrics struct {
activeUDPFlows prometheus.Gauge activeUDPFlows *prometheus.GaugeVec
totalUDPFlows prometheus.Counter totalUDPFlows *prometheus.CounterVec
payloadTooLarge prometheus.Counter payloadTooLarge *prometheus.CounterVec
retryFlowResponses prometheus.Counter retryFlowResponses *prometheus.CounterVec
migratedFlows prometheus.Counter migratedFlows *prometheus.CounterVec
unsupportedRemoteCommands *prometheus.CounterVec
} }
func (m *metrics) IncrementFlows() { func (m *metrics) IncrementFlows(connIndex uint8) {
m.totalUDPFlows.Inc() m.totalUDPFlows.WithLabelValues(string(connIndex)).Inc()
m.activeUDPFlows.Inc() m.activeUDPFlows.WithLabelValues(string(connIndex)).Inc()
} }
func (m *metrics) DecrementFlows() { func (m *metrics) DecrementFlows(connIndex uint8) {
m.activeUDPFlows.Dec() m.activeUDPFlows.WithLabelValues(string(connIndex)).Dec()
} }
func (m *metrics) PayloadTooLarge() { func (m *metrics) PayloadTooLarge(connIndex uint8) {
m.payloadTooLarge.Inc() m.payloadTooLarge.WithLabelValues(string(connIndex)).Inc()
} }
func (m *metrics) RetryFlowResponse() { func (m *metrics) RetryFlowResponse(connIndex uint8) {
m.retryFlowResponses.Inc() m.retryFlowResponses.WithLabelValues(string(connIndex)).Inc()
} }
func (m *metrics) MigrateFlow() { func (m *metrics) MigrateFlow(connIndex uint8) {
m.migratedFlows.Inc() m.migratedFlows.WithLabelValues(string(connIndex)).Inc()
}
func (m *metrics) UnsupportedRemoteCommand(connIndex uint8, command string) {
m.unsupportedRemoteCommands.WithLabelValues(string(connIndex), command).Inc()
} }
func NewMetrics(registerer prometheus.Registerer) Metrics { func NewMetrics(registerer prometheus.Registerer) Metrics {
m := &metrics{ m := &metrics{
activeUDPFlows: prometheus.NewGauge(prometheus.GaugeOpts{ activeUDPFlows: prometheus.NewGaugeVec(prometheus.GaugeOpts{
Namespace: namespace, Namespace: namespace,
Subsystem: subsystem, Subsystem: subsystem,
Name: "active_flows", Name: "active_flows",
Help: "Concurrent count of UDP flows that are being proxied to any origin", Help: "Concurrent count of UDP flows that are being proxied to any origin",
}), }, []string{quic.ConnectionIndexMetricLabel}),
totalUDPFlows: prometheus.NewCounter(prometheus.CounterOpts{ totalUDPFlows: prometheus.NewCounterVec(prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: subsystem, Subsystem: subsystem,
Name: "total_flows", Name: "total_flows",
Help: "Total count of UDP flows that have been proxied to any origin", Help: "Total count of UDP flows that have been proxied to any origin",
}), }, []string{quic.ConnectionIndexMetricLabel}),
payloadTooLarge: prometheus.NewCounter(prometheus.CounterOpts{ payloadTooLarge: prometheus.NewCounterVec(prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: subsystem, Subsystem: subsystem,
Name: "payload_too_large", Name: "payload_too_large",
Help: "Total count of UDP flows that have had origin payloads that are too large to proxy", Help: "Total count of UDP flows that have had origin payloads that are too large to proxy",
}), }, []string{quic.ConnectionIndexMetricLabel}),
retryFlowResponses: prometheus.NewCounter(prometheus.CounterOpts{ retryFlowResponses: prometheus.NewCounterVec(prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: subsystem, Subsystem: subsystem,
Name: "retry_flow_responses", Name: "retry_flow_responses",
Help: "Total count of UDP flows that have had to send their registration response more than once", Help: "Total count of UDP flows that have had to send their registration response more than once",
}), }, []string{quic.ConnectionIndexMetricLabel}),
migratedFlows: prometheus.NewCounter(prometheus.CounterOpts{ migratedFlows: prometheus.NewCounterVec(prometheus.CounterOpts{ //nolint:promlinter
Namespace: namespace, Namespace: namespace,
Subsystem: subsystem, Subsystem: subsystem,
Name: "migrated_flows", Name: "migrated_flows",
Help: "Total count of UDP flows have been migrated across local connections", Help: "Total count of UDP flows have been migrated across local connections",
}), }, []string{quic.ConnectionIndexMetricLabel}),
unsupportedRemoteCommands: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "unsupported_remote_command_total",
Help: "Total count of unsupported remote RPC commands for the ",
}, []string{quic.ConnectionIndexMetricLabel, commandMetricLabel}),
} }
registerer.MustRegister( registerer.MustRegister(
m.activeUDPFlows, m.activeUDPFlows,
@ -85,6 +101,7 @@ func NewMetrics(registerer prometheus.Registerer) Metrics {
m.payloadTooLarge, m.payloadTooLarge,
m.retryFlowResponses, m.retryFlowResponses,
m.migratedFlows, m.migratedFlows,
m.unsupportedRemoteCommands,
) )
return m return m
} }

View File

@ -2,8 +2,9 @@ package v3_test
type noopMetrics struct{} type noopMetrics struct{}
func (noopMetrics) IncrementFlows() {} func (noopMetrics) IncrementFlows(connIndex uint8) {}
func (noopMetrics) DecrementFlows() {} func (noopMetrics) DecrementFlows(connIndex uint8) {}
func (noopMetrics) PayloadTooLarge() {} func (noopMetrics) PayloadTooLarge(connIndex uint8) {}
func (noopMetrics) RetryFlowResponse() {} func (noopMetrics) RetryFlowResponse(connIndex uint8) {}
func (noopMetrics) MigrateFlow() {} func (noopMetrics) MigrateFlow(connIndex uint8) {}
func (noopMetrics) UnsupportedRemoteCommand(connIndex uint8, command string) {}

View File

@ -264,10 +264,10 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
return return
} }
log = log.With().Str(logSrcKey, session.LocalAddr().String()).Logger() log = log.With().Str(logSrcKey, session.LocalAddr().String()).Logger()
c.metrics.IncrementFlows() c.metrics.IncrementFlows(c.index)
// Make sure to eventually remove the session from the session manager when the session is closed // Make sure to eventually remove the session from the session manager when the session is closed
defer c.sessionManager.UnregisterSession(session.ID()) defer c.sessionManager.UnregisterSession(session.ID())
defer c.metrics.DecrementFlows() defer c.metrics.DecrementFlows(c.index)
// Respond that we are able to process the new session // Respond that we are able to process the new session
err = c.SendUDPSessionResponse(datagram.RequestID, ResponseOk) err = c.SendUDPSessionResponse(datagram.RequestID, ResponseOk)
@ -315,7 +315,7 @@ func (c *datagramConn) handleSessionAlreadyRegistered(requestID RequestID, logge
// The session is already running in another routine so we want to restart the idle timeout since no proxied // The session is already running in another routine so we want to restart the idle timeout since no proxied
// packets have come down yet. // packets have come down yet.
session.ResetIdleTimer() session.ResetIdleTimer()
c.metrics.RetryFlowResponse() c.metrics.RetryFlowResponse(c.index)
logger.Debug().Msgf("flow registration response retry") logger.Debug().Msgf("flow registration response retry")
} }

View File

@ -781,12 +781,12 @@ func newICMPDatagram(pk *packet.ICMP) []byte {
// Cancel the provided context and make sure it closes with the expected cancellation error // Cancel the provided context and make sure it closes with the expected cancellation error
func assertContextClosed(t *testing.T, ctx context.Context, done <-chan error, cancel context.CancelCauseFunc) { func assertContextClosed(t *testing.T, ctx context.Context, done <-chan error, cancel context.CancelCauseFunc) {
cancel(expectedContextCanceled) cancel(errExpectedContextCanceled)
err := <-done err := <-done
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
t.Fatal(err) t.Fatal(err)
} }
if !errors.Is(context.Cause(ctx), expectedContextCanceled) { if !errors.Is(context.Cause(ctx), errExpectedContextCanceled) {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@ -27,11 +27,11 @@ const (
) )
// SessionCloseErr indicates that the session's Close method was called. // SessionCloseErr indicates that the session's Close method was called.
var SessionCloseErr error = errors.New("flow was closed directly") var SessionCloseErr error = errors.New("flow was closed directly") //nolint:errname
// SessionIdleErr is returned when the session was closed because there was no communication // SessionIdleErr is returned when the session was closed because there was no communication
// in either direction over the session for the timeout period. // in either direction over the session for the timeout period.
type SessionIdleErr struct { type SessionIdleErr struct { //nolint:errname
timeout time.Duration timeout time.Duration
} }
@ -149,7 +149,8 @@ func (s *session) Migrate(eyeball DatagramConn, ctx context.Context, logger *zer
} }
// The session is already running so we want to restart the idle timeout since no proxied packets have come down yet. // The session is already running so we want to restart the idle timeout since no proxied packets have come down yet.
s.markActive() s.markActive()
s.metrics.MigrateFlow() connectionIndex := eyeball.ID()
s.metrics.MigrateFlow(connectionIndex)
} }
func (s *session) Serve(ctx context.Context) error { func (s *session) Serve(ctx context.Context) error {
@ -160,7 +161,7 @@ func (s *session) Serve(ctx context.Context) error {
// To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with // To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
// the required datagram header information. We can reuse this buffer for this session since the header is the // the required datagram header information. We can reuse this buffer for this session since the header is the
// same for the each read. // same for the each read.
MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen]) _ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen])
for { for {
// Read from the origin UDP socket // Read from the origin UDP socket
n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:]) n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
@ -177,7 +178,8 @@ func (s *session) Serve(ctx context.Context) error {
continue continue
} }
if n > maxDatagramPayloadLen { if n > maxDatagramPayloadLen {
s.metrics.PayloadTooLarge() connectionIndex := s.ConnectionID()
s.metrics.PayloadTooLarge(connectionIndex)
s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped") s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped")
continue continue
} }
@ -241,7 +243,7 @@ func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time
// Closing the session at the end cancels read so Serve() can return // Closing the session at the end cancels read so Serve() can return
defer s.Close() defer s.Close()
if closeAfterIdle == 0 { if closeAfterIdle == 0 {
// provide deafult is caller doesn't specify one // Provided that the default caller doesn't specify one
closeAfterIdle = defaultCloseIdleAfter closeAfterIdle = defaultCloseIdleAfter
} }

View File

@ -17,7 +17,7 @@ import (
) )
var ( var (
expectedContextCanceled = errors.New("expected context canceled") errExpectedContextCanceled = errors.New("expected context canceled")
testOriginAddr = net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:0")) testOriginAddr = net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:0"))
testLocalAddr = net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:0")) testLocalAddr = net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:0"))
@ -40,7 +40,7 @@ func testSessionWrite(t *testing.T, payload []byte) {
serverRead := make(chan []byte, 1) serverRead := make(chan []byte, 1)
go func() { go func() {
read := make([]byte, 1500) read := make([]byte, 1500)
server.Read(read[:]) _, _ = server.Read(read[:])
serverRead <- read serverRead <- read
}() }()
// Create session and write to origin // Create session and write to origin
@ -110,12 +110,12 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
case data := <-eyeball.recvData: case data := <-eyeball.recvData:
// check received data matches provided from origin // check received data matches provided from origin
expectedData := makePayload(1500) expectedData := makePayload(1500)
v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:]) _ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
copy(expectedData[17:], payload) copy(expectedData[17:], payload)
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) { if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
t.Fatal("expected datagram did not equal expected") t.Fatal("expected datagram did not equal expected")
} }
cancel(expectedContextCanceled) cancel(errExpectedContextCanceled)
case err := <-ctx.Done(): case err := <-ctx.Done():
// we expect the payload to return before the context to cancel on the session // we expect the payload to return before the context to cancel on the session
t.Fatal(err) t.Fatal(err)
@ -125,7 +125,7 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
t.Fatal(err) t.Fatal(err)
} }
if !errors.Is(context.Cause(ctx), expectedContextCanceled) { if !errors.Is(context.Cause(ctx), errExpectedContextCanceled) {
t.Fatal(err) t.Fatal(err)
} }
} }
@ -198,7 +198,7 @@ func TestSessionServe_Migrate(t *testing.T) {
// Origin sends data // Origin sends data
payload2 := []byte{0xde} payload2 := []byte{0xde}
pipe1.Write(payload2) _, _ = pipe1.Write(payload2)
// Expect write to eyeball2 // Expect write to eyeball2
data := <-eyeball2.recvData data := <-eyeball2.recvData
@ -249,13 +249,13 @@ func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
t.Fatalf("expected session to still be running") t.Fatalf("expected session to still be running")
default: default:
} }
if context.Cause(eyeball1Ctx) != contextCancelErr { if !errors.Is(context.Cause(eyeball1Ctx), contextCancelErr) {
t.Fatalf("first eyeball context should be cancelled manually: %+v", context.Cause(eyeball1Ctx)) t.Fatalf("first eyeball context should be cancelled manually: %+v", context.Cause(eyeball1Ctx))
} }
// Origin sends data // Origin sends data
payload2 := []byte{0xde} payload2 := []byte{0xde}
pipe1.Write(payload2) _, _ = pipe1.Write(payload2)
// Expect write to eyeball2 // Expect write to eyeball2
data := <-eyeball2.recvData data := <-eyeball2.recvData