diff --git a/Makefile b/Makefile index d92a8707..7eaccec7 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ ifdef PACKAGE_MANAGER VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)" endif -ifdef CONTAINER_BUILD +ifdef CONTAINER_BUILD VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual" endif @@ -119,7 +119,7 @@ ifneq ($(TARGET_ARM), ) ARM_COMMAND := GOARM=$(TARGET_ARM) endif -ifeq ($(TARGET_ARM), 7) +ifeq ($(TARGET_ARM), 7) PACKAGE_ARCH := armhf else PACKAGE_ARCH := $(TARGET_ARCH) @@ -182,7 +182,7 @@ fuzz: @go test -fuzz=FuzzIPDecoder -fuzztime=600s ./packet @go test -fuzz=FuzzICMPDecoder -fuzztime=600s ./packet @go test -fuzz=FuzzSessionWrite -fuzztime=600s ./quic/v3 - @go test -fuzz=FuzzSessionServe -fuzztime=600s ./quic/v3 + @go test -fuzz=FuzzSessionRead -fuzztime=600s ./quic/v3 @go test -fuzz=FuzzRegistrationDatagram -fuzztime=600s ./quic/v3 @go test -fuzz=FuzzPayloadDatagram -fuzztime=600s ./quic/v3 @go test -fuzz=FuzzRegistrationResponseDatagram -fuzztime=600s ./quic/v3 diff --git a/quic/v3/datagram_test.go b/quic/v3/datagram_test.go index 834c4ae4..2eeddb83 100644 --- a/quic/v3/datagram_test.go +++ b/quic/v3/datagram_test.go @@ -1,6 +1,7 @@ package v3_test import ( + "crypto/rand" "encoding/binary" "errors" "net/netip" @@ -14,12 +15,18 @@ import ( func makePayload(size int) []byte { payload := make([]byte, size) - for i := range len(payload) { - payload[i] = 0xfc - } + _, _ = rand.Read(payload) return payload } +func makePayloads(size int, count int) [][]byte { + payloads := make([][]byte, count) + for i := range payloads { + payloads[i] = makePayload(size) + } + return payloads +} + func TestSessionRegistration_MarshalUnmarshal(t *testing.T) { payload := makePayload(1280) tests := []*v3.UDPSessionRegistrationDatagram{ diff --git a/quic/v3/muxer.go b/quic/v3/muxer.go index 6a614814..0d0411d1 100644 --- a/quic/v3/muxer.go +++ b/quic/v3/muxer.go @@ -17,6 +17,9 @@ const ( // Allocating a 16 channel buffer here allows for the writer to be slightly faster than the reader. // This has worked previously well for datagramv2, so we will start with this as well demuxChanCapacity = 16 + // This provides a small buffer for the PacketRouter to poll ICMP packets from the QUIC connection + // before writing them to the origin. 
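// Editor's sketch (not part of this diff): the Makefile hunk above appends `-X` linker flags to
// VERSION_FLAGS, e.g. `-X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual"` when
// CONTAINER_BUILD is set. A minimal, self-contained illustration of that mechanism follows; the
// package path and variable below are hypothetical, not cloudflared's.
package main

import "fmt"

// Runtime can be overridden at link time, e.g.:
//
//	go build -ldflags '-X "main.Runtime=virtual"' .
var Runtime = "host"

func main() {
	fmt.Println("runtime:", Runtime)
}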
+ icmpDatagramChanCapacity = 128 logSrcKey = "src" logDstKey = "dst" @@ -59,14 +62,15 @@ type QuicConnection interface { } type datagramConn struct { - conn QuicConnection - index uint8 - sessionManager SessionManager - icmpRouter ingress.ICMPRouter - metrics Metrics - logger *zerolog.Logger - datagrams chan []byte - readErrors chan error + conn QuicConnection + index uint8 + sessionManager SessionManager + icmpRouter ingress.ICMPRouter + metrics Metrics + logger *zerolog.Logger + datagrams chan []byte + icmpDatagramChan chan *ICMPDatagram + readErrors chan error icmpEncoderPool sync.Pool // a pool of *packet.Encoder icmpDecoderPool sync.Pool @@ -75,14 +79,15 @@ type datagramConn struct { func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn { log := logger.With().Uint8("datagramVersion", 3).Logger() return &datagramConn{ - conn: conn, - index: index, - sessionManager: sessionManager, - icmpRouter: icmpRouter, - metrics: metrics, - logger: &log, - datagrams: make(chan []byte, demuxChanCapacity), - readErrors: make(chan error, 2), + conn: conn, + index: index, + sessionManager: sessionManager, + icmpRouter: icmpRouter, + metrics: metrics, + logger: &log, + datagrams: make(chan []byte, demuxChanCapacity), + icmpDatagramChan: make(chan *ICMPDatagram, icmpDatagramChanCapacity), + readErrors: make(chan error, 2), icmpEncoderPool: sync.Pool{ New: func() any { return packet.NewEncoder() @@ -168,6 +173,9 @@ func (c *datagramConn) Serve(ctx context.Context) error { readCtx, cancel := context.WithCancel(connCtx) defer cancel() go c.pollDatagrams(readCtx) + // Processing ICMP datagrams also monitors the reader context since the ICMP datagrams from the reader are the input + // for the routine. + go c.processICMPDatagrams(readCtx) for { // We make sure to monitor the context of cloudflared and the underlying connection to return if any errors occur. var datagram []byte @@ -181,58 +189,59 @@ func (c *datagramConn) Serve(ctx context.Context) error { // Monitor for any hard errors from reading the connection case err := <-c.readErrors: return err - // Otherwise, wait and dequeue datagrams as they come in + // Wait and dequeue datagrams as they come in case d := <-c.datagrams: datagram = d } // Each incoming datagram will be processed in a new go routine to handle the demuxing and action associated. - go func() { - typ, err := ParseDatagramType(datagram) + typ, err := ParseDatagramType(datagram) + if err != nil { + c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ) + continue + } + switch typ { + case UDPSessionRegistrationType: + reg := &UDPSessionRegistrationDatagram{} + err := reg.UnmarshalBinary(datagram) if err != nil { - c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ) - return + c.logger.Err(err).Msgf("unable to unmarshal session registration datagram") + continue } - switch typ { - case UDPSessionRegistrationType: - reg := &UDPSessionRegistrationDatagram{} - err := reg.UnmarshalBinary(datagram) - if err != nil { - c.logger.Err(err).Msgf("unable to unmarshal session registration datagram") - return - } - logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger() - // We bind the new session to the quic connection context instead of cloudflared context to allow for the - // quic connection to close and close only the sessions bound to it. 
Closing of cloudflared will also - // initiate the close of the quic connection, so we don't have to worry about the application context - // in the scope of a session. - c.handleSessionRegistrationDatagram(connCtx, reg, &logger) - case UDPSessionPayloadType: - payload := &UDPSessionPayloadDatagram{} - err := payload.UnmarshalBinary(datagram) - if err != nil { - c.logger.Err(err).Msgf("unable to unmarshal session payload datagram") - return - } - logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger() - c.handleSessionPayloadDatagram(payload, &logger) - case ICMPType: - packet := &ICMPDatagram{} - err := packet.UnmarshalBinary(datagram) - if err != nil { - c.logger.Err(err).Msgf("unable to unmarshal icmp datagram") - return - } - c.handleICMPPacket(packet) - case UDPSessionRegistrationResponseType: - // cloudflared should never expect to receive UDP session responses as it will not initiate new - // sessions towards the edge. - c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType) - return - default: - c.logger.Error().Msgf("unknown datagram type received: %d", typ) + logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger() + // We bind the new session to the quic connection context instead of cloudflared context to allow for the + // quic connection to close and close only the sessions bound to it. Closing of cloudflared will also + // initiate the close of the quic connection, so we don't have to worry about the application context + // in the scope of a session. + // + // Additionally, we spin out the registration into a separate go routine to handle the Serve'ing of the + // session in a separate routine from the demuxer. + go c.handleSessionRegistrationDatagram(connCtx, reg, &logger) + case UDPSessionPayloadType: + payload := &UDPSessionPayloadDatagram{} + err := payload.UnmarshalBinary(datagram) + if err != nil { + c.logger.Err(err).Msgf("unable to unmarshal session payload datagram") + continue } - }() + logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger() + c.handleSessionPayloadDatagram(payload, &logger) + case ICMPType: + packet := &ICMPDatagram{} + err := packet.UnmarshalBinary(datagram) + if err != nil { + c.logger.Err(err).Msgf("unable to unmarshal icmp datagram") + continue + } + c.handleICMPPacket(packet) + case UDPSessionRegistrationResponseType: + // cloudflared should never expect to receive UDP session responses as it will not initiate new + // sessions towards the edge. + c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType) + continue + default: + c.logger.Error().Msgf("unknown datagram type received: %d", typ) + } } } @@ -243,24 +252,21 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da Str(logDstKey, datagram.Dest.String()). 
Logger() session, err := c.sessionManager.RegisterSession(datagram, c) - switch err { - case nil: - // Continue as normal - case ErrSessionAlreadyRegistered: - // Session is already registered and likely the response got lost - c.handleSessionAlreadyRegistered(datagram.RequestID, &log) - return - case ErrSessionBoundToOtherConn: - // Session is already registered but to a different connection - c.handleSessionMigration(datagram.RequestID, &log) - return - case ErrSessionRegistrationRateLimited: - // There are too many concurrent sessions so we return an error to force a retry later - c.handleSessionRegistrationRateLimited(datagram, &log) - return - default: - log.Err(err).Msg("flow registration failure") - c.handleSessionRegistrationFailure(datagram.RequestID, &log) + if err != nil { + switch err { + case ErrSessionAlreadyRegistered: + // Session is already registered and likely the response got lost + c.handleSessionAlreadyRegistered(datagram.RequestID, &log) + case ErrSessionBoundToOtherConn: + // Session is already registered but to a different connection + c.handleSessionMigration(datagram.RequestID, &log) + case ErrSessionRegistrationRateLimited: + // There are too many concurrent sessions so we return an error to force a retry later + c.handleSessionRegistrationRateLimited(datagram, &log) + default: + log.Err(err).Msg("flow registration failure") + c.handleSessionRegistrationFailure(datagram.RequestID, &log) + } return } log = log.With().Str(logSrcKey, session.LocalAddr().String()).Logger() @@ -365,21 +371,42 @@ func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadD logger.Err(err).Msgf("unable to find flow") return } - // We ignore the bytes written to the socket because any partial write must return an error. - _, err = s.Write(datagram.Payload) - if err != nil { - logger.Err(err).Msgf("unable to write payload for the flow") - return - } + s.Write(datagram.Payload) } -// Handles incoming ICMP datagrams. +// Handles incoming ICMP datagrams into a serialized channel to be handled by a single consumer. func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) { if c.icmpRouter == nil { // ICMPRouter is disabled so we drop the current packet and ignore all incoming ICMP packets return } + select { + case c.icmpDatagramChan <- datagram: + default: + // If the ICMP datagram channel is full, drop any additional incoming. + c.logger.Warn().Msg("failed to write icmp packet to origin: dropped") + } +} +// Consumes from the ICMP datagram channel to write out the ICMP requests to an origin. 
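// Editor's sketch (not part of this diff): handleICMPPacket above enqueues with a non-blocking send
// and drops when the buffer is full, while processICMPDatagrams below is the single consumer that
// drains the channel until the read context is done. A standalone, simplified version of that shape,
// with hypothetical names and a trivial payload type:
package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	defer cancel()

	queue := make(chan int, 4) // small bound, like icmpDatagramChanCapacity

	// Producer: never blocks the demuxer; drops when the consumer falls behind.
	enqueue := func(v int) {
		select {
		case queue <- v:
		default:
			fmt.Println("dropped", v) // a real implementation would log or meter the drop
		}
	}

	go func() {
		for i := 0; i < 100; i++ {
			enqueue(i)
		}
	}()

	// Single consumer: serializes the writes until the context is done.
	for {
		select {
		case <-ctx.Done():
			return
		case v := <-queue:
			fmt.Println("write to origin:", v)
		}
	}
}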
+func (c *datagramConn) processICMPDatagrams(ctx context.Context) { + if c.icmpRouter == nil { + // ICMPRouter is disabled so we ignore all incoming ICMP packets + return + } + + for { + select { + // If the provided context is closed we want to exit the write loop + case <-ctx.Done(): + return + case datagram := <-c.icmpDatagramChan: + c.writeICMPPacket(datagram) + } + } +} + +func (c *datagramConn) writeICMPPacket(datagram *ICMPDatagram) { // Decode the provided ICMPDatagram as an ICMP packet rawPacket := packet.RawPacket{Data: datagram.Payload} cachedDecoder := c.icmpDecoderPool.Get() diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go index 729abd3c..fab59328 100644 --- a/quic/v3/muxer_test.go +++ b/quic/v3/muxer_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/fortytw2/leaktest" "github.com/google/gopacket/layers" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -92,7 +93,7 @@ func TestDatagramConn_New(t *testing.T) { DefaultDialer: testDefaultDialer, TCPWriteTimeout: 0, }, &log) - conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(newMockQuicConn(t.Context()), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) if conn == nil { t.Fatal("expected valid connection") } @@ -104,7 +105,9 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { DefaultDialer: testDefaultDialer, TCPWriteTimeout: 0, }, &log) - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) payload := []byte{0xef, 0xef} @@ -123,7 +126,9 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { DefaultDialer: testDefaultDialer, TCPWriteTimeout: 0, }, &log) - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable) @@ -149,7 +154,9 @@ func TestDatagramConnServe_ApplicationClosed(t *testing.T) { DefaultDialer: testDefaultDialer, TCPWriteTimeout: 0, }, &log) - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) @@ -166,7 +173,9 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) { DefaultDialer: testDefaultDialer, TCPWriteTimeout: 0, }, &log) - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) defer cancel() quic.ctx = ctx @@ -195,15 +204,17 @@ func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) { func 
TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) { log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) sessionManager := &mockSessionManager{ expectedRegErr: v3.ErrSessionRegistrationRateLimited, } conn := v3.NewDatagramConn(quic, sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) // Setup the muxer - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(context.Canceled) done := make(chan error, 1) go func() { done <- conn.Serve(ctx) @@ -223,9 +234,12 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) { require.EqualValues(t, testRequestID, resp.RequestID) require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType) + + assertContextClosed(t, ctx, done, cancel) } func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) { + defer leaktest.Check(t)() for _, test := range []struct { name string input []byte @@ -250,7 +264,9 @@ func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) { t.Run(test.name, func(t *testing.T) { logOutput := new(LockedBuffer) log := zerolog.New(logOutput) - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) quic.send <- test.input conn := v3.NewDatagramConn(quic, &mockSessionManager{}, &noopICMPRouter{}, 0, &noopMetrics{}, &log) @@ -289,8 +305,11 @@ func (b *LockedBuffer) String() string { } func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) expectedErr := errors.New("unable to register session") sessionManager := mockSessionManager{expectedRegErr: expectedErr} conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) @@ -324,8 +343,11 @@ func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) { } func TestDatagramConnServe(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) session := newMockSession() sessionManager := mockSessionManager{session: &session} conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) @@ -372,8 +394,11 @@ func TestDatagramConnServe(t *testing.T) { // instances causes inteference resulting in multiple different raw packets being decoded // as the same decoded packet. 
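// Editor's sketch (not part of this diff): the tests above add `defer leaktest.Check(t)()`.
// leaktest (github.com/fortytw2/leaktest) snapshots the running goroutines when Check is called and,
// via the returned func deferred here, fails the test if new goroutines are still alive at the end.
// Hedged example of that third-party API, with a hypothetical test name:
package v3_test

import (
	"testing"

	"github.com/fortytw2/leaktest"
)

func TestExample_NoGoroutineLeak(t *testing.T) {
	defer leaktest.Check(t)()

	done := make(chan struct{})
	go func() {
		<-done // this goroutine exits before the test returns, so leaktest stays quiet
	}()
	close(done)
}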
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) session := newMockSession() sessionManager := mockSessionManager{session: &session} router := newMockICMPRouter() @@ -413,10 +438,14 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) { wg := sync.WaitGroup{} var receivedPackets []*packet.ICMP go func() { - for ctx.Err() == nil { - icmpPacket := <-router.recv - receivedPackets = append(receivedPackets, icmpPacket) - wg.Done() + for { + select { + case <-ctx.Done(): + return + case icmpPacket := <-router.recv: + receivedPackets = append(receivedPackets, icmpPacket) + wg.Done() + } } }() @@ -452,8 +481,11 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) { } func TestDatagramConnServe_RegisterTwice(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) session := newMockSession() sessionManager := mockSessionManager{session: &session} conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) @@ -514,12 +546,17 @@ func TestDatagramConnServe_RegisterTwice(t *testing.T) { } func TestDatagramConnServe_MigrateConnection(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) session := newMockSession() sessionManager := mockSessionManager{session: &session} conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) - quic2 := newMockQuicConn() + conn2Ctx, conn2Cancel := context.WithCancelCause(t.Context()) + defer conn2Cancel(context.Canceled) + quic2 := newMockQuicConn(conn2Ctx) conn2 := v3.NewDatagramConn(quic2, &sessionManager, &noopICMPRouter{}, 1, &noopMetrics{}, &log) // Setup the muxer @@ -597,8 +634,11 @@ func TestDatagramConnServe_MigrateConnection(t *testing.T) { } func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) // mockSessionManager will return the ErrSessionNotFound for any session attempting to be queried by the muxer sessionManager := mockSessionManager{session: nil, expectedGetErr: v3.ErrSessionNotFound} conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) @@ -624,9 +664,12 @@ func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) { assertContextClosed(t, ctx, done, cancel) } -func TestDatagramConnServe_Payload(t *testing.T) { +func TestDatagramConnServe_Payloads(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) session := newMockSession() sessionManager := mockSessionManager{session: &session} conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) @@ -639,15 +682,26 @@ func TestDatagramConnServe_Payload(t *testing.T) { 
done <- conn.Serve(ctx) }() - // Send new session registration - expectedPayload := []byte{0xef, 0xef} - datagram := newSessionPayloadDatagram(testRequestID, expectedPayload) - quic.send <- datagram + // Send session payloads + expectedPayloads := makePayloads(256, 16) + go func() { + for _, payload := range expectedPayloads { + datagram := newSessionPayloadDatagram(testRequestID, payload) + quic.send <- datagram + } + }() - // Session should receive the payload - payload := <-session.recv - if !slices.Equal(expectedPayload, payload) { - t.Fatalf("expected session receieve the payload sent via the muxer") + // Session should receive the payloads (in-order) + for i, payload := range expectedPayloads { + select { + case recv := <-session.recv: + if !slices.Equal(recv, payload) { + t.Fatalf("expected session to receive the payload[%d] sent via the muxer: (%x) (%x)", i, recv[:16], payload[:16]) + } + case err := <-ctx.Done(): + // we expect the payload to return before the context to cancel on the session + t.Fatal(err) + } } // Cancel the muxer Serve context and make sure it closes with the expected error @@ -655,8 +709,11 @@ } func TestDatagramConnServe_ICMPDatagram_TTLDecremented(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) router := newMockICMPRouter() conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log) @@ -701,8 +758,11 @@ } func TestDatagramConnServe_ICMPDatagram_TTLExceeded(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - quic := newMockQuicConn() + connCtx, connCancel := context.WithCancelCause(t.Context()) + defer connCancel(context.Canceled) + quic := newMockQuicConn(connCtx) router := newMockICMPRouter() conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log) @@ -821,9 +881,9 @@ type mockQuicConn struct { recv chan []byte } -func newMockQuicConn() *mockQuicConn { +func newMockQuicConn(ctx context.Context) *mockQuicConn { return &mockQuicConn{ - ctx: context.Background(), + ctx: ctx, send: make(chan []byte, 1), recv: make(chan []byte, 1), } @@ -841,7 +901,12 @@ func (m *mockQuicConn) SendDatagram(payload []byte) error { } func (m *mockQuicConn) ReceiveDatagram(_ context.Context) ([]byte, error) { - return <-m.send, nil + select { + case <-m.ctx.Done(): + return nil, m.ctx.Err() + case b := <-m.send: + return b, nil + } } type mockQuicConnReadError struct { @@ -905,11 +970,10 @@ func (m *mockSession) Serve(ctx context.Context) error { return v3.SessionCloseErr } -func (m *mockSession) Write(payload []byte) (n int, err error) { +func (m *mockSession) Write(payload []byte) { b := make([]byte, len(payload)) copy(b, payload) m.recv <- b - return len(b), nil } func (m *mockSession) Close() error { diff --git a/quic/v3/session.go b/quic/v3/session.go index 82f71b0d..74a34542 100644 --- a/quic/v3/session.go +++ b/quic/v3/session.go @@ -22,6 +22,11 @@ const ( // this value (maxDatagramPayloadLen). maxOriginUDPPacketSize = 1500 + // The maximum number of datagrams a session will queue up before it begins dropping datagrams. + // This channel buffer is small because we assume that the dedicated writer to the origin is typically + // fast enough to keep the channel empty.
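// Editor's sketch (not part of this diff): the tests above replace context.WithCancel with
// context.WithCancelCause and cancel with an explicit cause (for example context.Canceled or a test
// error), so context.Cause lets assertions such as assertContextClosed distinguish the expected
// shutdown from an unexpected one. Minimal stdlib illustration with hypothetical names:
package main

import (
	"context"
	"errors"
	"fmt"
)

var errExpectedShutdown = errors.New("expected shutdown")

func main() {
	ctx, cancel := context.WithCancelCause(context.Background())
	cancel(errExpectedShutdown)

	<-ctx.Done()
	fmt.Println(ctx.Err())                                          // context.Canceled
	fmt.Println(errors.Is(context.Cause(ctx), errExpectedShutdown)) // true
}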
+ writeChanCapacity = 16 + logFlowID = "flowID" logPacketSizeKey = "packetSize" ) @@ -49,7 +54,7 @@ func newSessionIdleErr(timeout time.Duration) error { } type Session interface { - io.WriteCloser + io.Closer ID() RequestID ConnectionID() uint8 RemoteAddr() net.Addr @@ -58,6 +63,7 @@ type Session interface { Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger) // Serve starts the event loop for processing UDP packets Serve(ctx context.Context) error + Write(payload []byte) } type session struct { @@ -67,12 +73,18 @@ type session struct { originAddr net.Addr localAddr net.Addr eyeball atomic.Pointer[DatagramConn] + writeChan chan []byte // activeAtChan is used to communicate the last read/write time activeAtChan chan time.Time - closeChan chan error - contextChan chan context.Context - metrics Metrics - log *zerolog.Logger + errChan chan error + // The close channel signal only exists for the write loop because the read loop is always waiting on a read + // from the UDP socket to the origin. To close the read loop we close the socket. + // Additionally, we can't close the writeChan to indicate that writes are complete because the producer (edge) + // side may still be trying to write to this session. + closeWrite chan struct{} + contextChan chan context.Context + metrics Metrics + log *zerolog.Logger // A special close function that we wrap with sync.Once to make sure it is only called once closeFn func() error @@ -89,10 +101,12 @@ func NewSession( log *zerolog.Logger, ) Session { logger := log.With().Str(logFlowID, id.String()).Logger() - // closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to - // write to the channel without blocking since there is only ever one value read from the closeChan by the + writeChan := make(chan []byte, writeChanCapacity) + // errChan has three slots to allow for all writers (the closeFn, the read loop and the write loop) to + // write to the channel without blocking since there is only ever one value read from the errChan by the // waitForCloseCondition. - closeChan := make(chan error, 2) + errChan := make(chan error, 3) + closeWrite := make(chan struct{}) session := &session{ id: id, closeAfterIdle: closeAfterIdle, @@ -100,10 +114,12 @@ func NewSession( originAddr: originAddr, localAddr: localAddr, eyeball: atomic.Pointer[DatagramConn]{}, + writeChan: writeChan, // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will // drop instead of blocking because last active time only needs to be an approximation activeAtChan: make(chan time.Time, 1), - closeChan: closeChan, + errChan: errChan, + closeWrite: closeWrite, // contextChan is an unbounded channel to help enforce one active migration of a session at a time. 
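// Editor's sketch (not part of this diff): why errChan above is buffered with one slot per potential
// writer (closeFn, read loop, write loop) and written to with a select/default. Every writer can
// report its error once without blocking even though only the first value is ever consumed, so
// shutdown never deadlocks. Hypothetical, simplified names:
package main

import (
	"errors"
	"fmt"
)

func main() {
	const writers = 3
	errCh := make(chan error, writers)

	report := func(err error) {
		select {
		case errCh <- err:
		default: // already full; the session is shutting down anyway
		}
	}

	report(errors.New("read loop failed"))
	report(errors.New("write loop failed"))
	report(errors.New("close requested"))

	// Only the first value matters, mirroring waitForCloseCondition.
	fmt.Println("closing because:", <-errCh)
}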
contextChan: make(chan context.Context), metrics: metrics, @@ -111,9 +127,12 @@ func NewSession( closeFn: sync.OnceValue(func() error { // We don't want to block on sending to the close channel if it is already full select { - case closeChan <- SessionCloseErr: + case errChan <- SessionCloseErr: default: } + // Indicate to the write loop that the session is now closed + close(closeWrite) + // Close the socket directly to unblock the read loop and cause it to also end return origin.Close() }), } @@ -154,66 +173,107 @@ func (s *session) Migrate(eyeball DatagramConn, ctx context.Context, logger *zer } func (s *session) Serve(ctx context.Context) error { - go func() { - // QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975 - // This makes it safe to share readBuffer between iterations - readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{} - // To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with - // the required datagram header information. We can reuse this buffer for this session since the header is the - // same for the each read. - _ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen]) - for { - // Read from the origin UDP socket - n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:]) - if err != nil { - if errors.Is(err, io.EOF) || - errors.Is(err, io.ErrUnexpectedEOF) { - s.log.Debug().Msgf("flow (origin) connection closed: %v", err) - } - s.closeChan <- err - return - } - if n < 0 { - s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped") - continue - } - if n > maxDatagramPayloadLen { - connectionIndex := s.ConnectionID() - s.metrics.PayloadTooLarge(connectionIndex) - s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped") - continue - } - // We need to synchronize on the eyeball in-case that the connection was migrated. This should be rarely a point - // of lock contention, as a migration can only happen during startup of a session before traffic flow. - eyeball := *(s.eyeball.Load()) - // Sending a packet to the session does block on the [quic.Connection], however, this is okay because it - // will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge. - err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n]) - if err != nil { - s.closeChan <- err - return - } - // Mark the session as active since we proxied a valid packet from the origin. - s.markActive() - } - }() + go s.writeLoop() + go s.readLoop() return s.waitForCloseCondition(ctx, s.closeAfterIdle) } -func (s *session) Write(payload []byte) (n int, err error) { - n, err = s.origin.Write(payload) - if err != nil { - s.log.Err(err).Msg("failed to write payload to flow (remote)") - return n, err +// Read datagrams from the origin and write them to the connection. +func (s *session) readLoop() { + // QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975 + // This makes it safe to share readBuffer between iterations + readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{} + // To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with + // the required datagram header information. 
We can reuse this buffer for this session since the header is the + // same for the each read. + _ = MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen]) + for { + // Read from the origin UDP socket + n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:]) + if err != nil { + if isConnectionClosed(err) { + s.log.Debug().Msgf("flow (read) connection closed: %v", err) + } + s.closeSession(err) + return + } + if n < 0 { + s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped") + continue + } + if n > maxDatagramPayloadLen { + connectionIndex := s.ConnectionID() + s.metrics.PayloadTooLarge(connectionIndex) + s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped") + continue + } + // We need to synchronize on the eyeball in-case that the connection was migrated. This should be rarely a point + // of lock contention, as a migration can only happen during startup of a session before traffic flow. + eyeball := *(s.eyeball.Load()) + // Sending a packet to the session does block on the [quic.Connection], however, this is okay because it + // will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge. + err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n]) + if err != nil { + s.closeSession(err) + return + } + // Mark the session as active since we proxied a valid packet from the origin. + s.markActive() } - // Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer - if n < len(payload) { - s.log.Err(io.ErrShortWrite).Msg("failed to write the full payload to flow (remote)") - return n, io.ErrShortWrite +} + +func (s *session) Write(payload []byte) { + select { + case s.writeChan <- payload: + default: + s.log.Error().Msg("failed to write flow payload to origin: dropped") + } +} + +// Read datagrams from the write channel to the origin. +func (s *session) writeLoop() { + for { + select { + case <-s.closeWrite: + // When the closeWrite channel is closed, we will no longer write to the origin and end this + // goroutine since the session is now closed. + return + case payload := <-s.writeChan: + n, err := s.origin.Write(payload) + if err != nil { + if isConnectionClosed(err) { + s.log.Debug().Msgf("flow (write) connection closed: %v", err) + } + s.log.Err(err).Msg("failed to write flow payload to origin") + s.closeSession(err) + // If we fail to write to the origin socket, we need to end the writer and close the session + return + } + // Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer + if n < len(payload) { + s.log.Err(io.ErrShortWrite).Msg("failed to write the full flow payload to origin") + continue + } + // Mark the session as active since we successfully proxied a packet to the origin. + s.markActive() + } + } +} + +func isConnectionClosed(err error) bool { + return errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) +} + +// Send an error to the error channel to report that an error has either happened on the tunnel or origin side of the +// proxied connection. +func (s *session) closeSession(err error) { + select { + case s.errChan <- err: + default: + // In the case that the errChan is already full, we will skip over it and return as to not block + // the caller because we should start cleaning up the session. 
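// Editor's sketch (not part of this diff): the shutdown choreography described above — a
// sync.OnceValue'd close function closes a signal channel so the write loop returns and closes the
// origin socket so the blocked read loop unblocks, and calling it twice is harmless. Hypothetical,
// heavily simplified names:
package main

import (
	"fmt"
	"net"
	"sync"
)

type miniSession struct {
	origin     net.Conn
	closeWrite chan struct{}
	closeFn    func() error
}

func newMiniSession(origin net.Conn) *miniSession {
	s := &miniSession{origin: origin, closeWrite: make(chan struct{})}
	s.closeFn = sync.OnceValue(func() error {
		close(s.closeWrite)     // tells the write loop to stop
		return s.origin.Close() // unblocks the read loop's blocking Read
	})
	return s
}

func main() {
	a, b := net.Pipe()
	defer b.Close()

	s := newMiniSession(a)
	fmt.Println(s.closeFn()) // <nil>
	fmt.Println(s.closeFn()) // <nil> again; the body ran only once
}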
+ s.log.Warn().Msg("error channel was full") } - // Mark the session as active since we proxied a packet to the origin. - s.markActive() - return n, err } // ResetIdleTimer will restart the current idle timer. @@ -240,7 +300,8 @@ func (s *session) Close() error { func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error { connCtx := ctx - // Closing the session at the end cancels read so Serve() can return + // Closing the session at the end cancels read so Serve() can return, additionally, it closes the + // closeWrite channel which indicates to the write loop to return. defer s.Close() if closeAfterIdle == 0 { // Provided that the default caller doesn't specify one @@ -260,7 +321,9 @@ func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time // still be active on the existing connection. connCtx = newContext continue - case reason := <-s.closeChan: + case reason := <-s.errChan: + // Any error returned here is from the read or write loops indicating that it can no longer process datagrams + // and as such the session needs to close. return reason case <-checkIdleTimer.C: // The check idle timer will only return after an idle period since the last active diff --git a/quic/v3/session_fuzz_test.go b/quic/v3/session_fuzz_test.go index 0e4952c0..a8f35dc9 100644 --- a/quic/v3/session_fuzz_test.go +++ b/quic/v3/session_fuzz_test.go @@ -4,20 +4,24 @@ import ( "testing" ) -// FuzzSessionWrite verifies that we don't run into any panics when writing variable sized payloads to the origin. +// FuzzSessionWrite verifies that we don't run into any panics when writing a single variable sized payload to the origin. func FuzzSessionWrite(f *testing.F) { - f.Fuzz(func(t *testing.T, b []byte) { - testSessionWrite(t, b) - }) -} - -// FuzzSessionServe verifies that we don't run into any panics when reading variable sized payloads from the origin. -func FuzzSessionServe(f *testing.F) { f.Fuzz(func(t *testing.T, b []byte) { // The origin transport read is bound to 1280 bytes if len(b) > 1280 { b = b[:1280] } - testSessionServe_Origin(t, b) + testSessionWrite(t, [][]byte{b}) + }) +} + +// FuzzSessionRead verifies that we don't run into any panics when reading a single variable sized payload from the origin. 
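// Editor's sketch (not part of this diff): the general shape of a Go fuzz target like FuzzSessionRead
// below, including an optional seed corpus via f.Add; the Makefile hunk near the top of this diff runs
// it with `go test -fuzz=FuzzSessionRead -fuzztime=600s ./quic/v3`. The target name here is hypothetical:
package v3_test

import "testing"

func FuzzExamplePayload(f *testing.F) {
	f.Add([]byte{0x00})       // optional seed inputs
	f.Add(make([]byte, 1280)) // seed at the 1280-byte transport bound used above
	f.Fuzz(func(t *testing.T, b []byte) {
		if len(b) > 1280 {
			b = b[:1280] // clamp, mirroring the bound in the targets above
		}
		// exercise the code under test with b; a panic or t.Fatal fails the fuzz run
		_ = b
	})
}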
+func FuzzSessionRead(f *testing.F) { + f.Fuzz(func(t *testing.T, b []byte) { + // The origin transport read is bound to 1280 bytes + if len(b) > 1280 { + b = b[:1280] + } + testSessionRead(t, [][]byte{b}) }) } diff --git a/quic/v3/session_test.go b/quic/v3/session_test.go index ce058ee4..c15a0318 100644 --- a/quic/v3/session_test.go +++ b/quic/v3/session_test.go @@ -31,60 +31,61 @@ func TestSessionNew(t *testing.T) { } } -func testSessionWrite(t *testing.T, payload []byte) { +func testSessionWrite(t *testing.T, payloads [][]byte) { log := zerolog.Nop() origin, server := net.Pipe() defer origin.Close() defer server.Close() - // Start origin server read - serverRead := make(chan []byte, 1) + // Start origin server reads + serverRead := make(chan []byte, len(payloads)) go func() { - read := make([]byte, 1500) - _, _ = server.Read(read[:]) - serverRead <- read + for range len(payloads) { + buf := make([]byte, 1500) + _, _ = server.Read(buf[:]) + serverRead <- buf + } + close(serverRead) }() - // Create session and write to origin + + // Create a session session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) - n, err := session.Write(payload) defer session.Close() - if err != nil { - t.Fatal(err) - } - if n != len(payload) { - t.Fatal("unable to write the whole payload") + // Start the Serve to begin the writeLoop + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(context.Canceled) + done := make(chan error) + go func() { + done <- session.Serve(ctx) + }() + // Write the payloads to the session + for _, payload := range payloads { + session.Write(payload) } - read := <-serverRead - if !slices.Equal(payload, read[:len(payload)]) { - t.Fatal("payload provided from origin and read value are not the same") + // Read from the origin to ensure the payloads were received (in-order) + for i, payload := range payloads { + read := <-serverRead + if !slices.Equal(payload, read[:len(payload)]) { + t.Fatalf("payload[%d] provided from origin and read value are not the same (%x) and (%x)", i, payload[:16], read[:16]) + } + } + _, more := <-serverRead + if more { + t.Fatalf("expected the session to have all of the origin payloads received: %d", len(serverRead)) + } + + assertContextClosed(t, ctx, done, cancel) +} + +func TestSessionWrite(t *testing.T) { + defer leaktest.Check(t)() + for i := range 1280 { + payloads := makePayloads(i, 16) + testSessionWrite(t, payloads) } } -func TestSessionWrite_Max(t *testing.T) { - defer leaktest.Check(t)() - payload := makePayload(1280) - testSessionWrite(t, payload) -} - -func TestSessionWrite_Min(t *testing.T) { - defer leaktest.Check(t)() - payload := makePayload(0) - testSessionWrite(t, payload) -} - -func TestSessionServe_OriginMax(t *testing.T) { - defer leaktest.Check(t)() - payload := makePayload(1280) - testSessionServe_Origin(t, payload) -} - -func TestSessionServe_OriginMin(t *testing.T) { - defer leaktest.Check(t)() - payload := makePayload(0) - testSessionServe_Origin(t, payload) -} - -func testSessionServe_Origin(t *testing.T, payload []byte) { +func testSessionRead(t *testing.T, payloads [][]byte) { log := zerolog.Nop() origin, server := net.Pipe() defer origin.Close() @@ -100,37 +101,42 @@ func testSessionServe_Origin(t *testing.T, payload []byte) { done <- session.Serve(ctx) }() - // Write from the origin server - _, err := server.Write(payload) - if err != nil { - t.Fatal(err) - } - - select { - case data := <-eyeball.recvData: - // check received data matches 
provided from origin - expectedData := makePayload(1500) - _ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:]) - copy(expectedData[17:], payload) - if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) { - t.Fatal("expected datagram did not equal expected") + // Write from the origin server to the eyeball + go func() { + for _, payload := range payloads { + _, _ = server.Write(payload) + } + }() + + // Read from the eyeball to ensure the payloads were received (in-order) + for i, payload := range payloads { + select { + case data := <-eyeball.recvData: + // check received data matches provided from origin + expectedData := makePayload(1500) + _ = v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:]) + copy(expectedData[17:], payload) + if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) { + t.Fatalf("expected datagram[%d] did not equal expected", i) + } + case err := <-ctx.Done(): + // we expect the payload to return before the context to cancel on the session + t.Fatal(err) } - cancel(errExpectedContextCanceled) - case err := <-ctx.Done(): - // we expect the payload to return before the context to cancel on the session - t.Fatal(err) } - err = <-done - if !errors.Is(err, context.Canceled) { - t.Fatal(err) - } - if !errors.Is(context.Cause(ctx), errExpectedContextCanceled) { - t.Fatal(err) + assertContextClosed(t, ctx, done, cancel) +} + +func TestSessionRead(t *testing.T) { + defer leaktest.Check(t)() + for i := range 1280 { + payloads := makePayloads(i, 16) + testSessionRead(t, payloads) } } -func TestSessionServe_OriginTooLarge(t *testing.T) { +func TestSessionRead_OriginTooLarge(t *testing.T) { defer leaktest.Check(t)() log := zerolog.Nop() eyeball := newMockEyeball() @@ -317,6 +323,8 @@ func TestSessionServe_IdleTimeout(t *testing.T) { closeAfterIdle := 2 * time.Second session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) err := session.Serve(t.Context()) + + // Session should idle timeout if no reads or writes occur if !errors.Is(err, v3.SessionIdleErr{}) { t.Fatal(err) }
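// Editor's sketch (not part of this diff): the idle-timeout test above relies on the session's
// "last active" bookkeeping, where activity is reported through a one-slot channel with a
// non-blocking send so hot read/write paths never stall and the idle checker only needs an
// approximate timestamp. Standalone illustration with hypothetical names:
package main

import (
	"fmt"
	"time"
)

func main() {
	activeAt := make(chan time.Time, 1) // like the session's activeAtChan

	markActive := func() {
		select {
		case activeAt <- time.Now():
		default: // a pending timestamp is close enough; drop instead of blocking
		}
	}

	markActive()
	markActive() // the second send is dropped, never blocked

	lastActive := <-activeAt
	fmt.Println("recently active:", time.Since(lastActive) < time.Second)
}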