TUN-8861: Rename Session Limiter to Flow Limiter

## Summary
Session is the concept used for UDP flows. Therefore, to make
the session limiter ambiguous for both TCP and UDP, this commit
renames it to flow limiter.

Closes TUN-8861
This commit is contained in:
João "Pisco" Fernandes 2025-01-20 06:33:40 -08:00
parent 8c2eda16c1
commit 4eb0f8ce5f
23 changed files with 295 additions and 295 deletions

View File

@ -12,7 +12,7 @@ import (
pkgerrors "github.com/pkg/errors" pkgerrors "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
cfdsession "github.com/cloudflare/cloudflared/session" cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/stream" "github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
@ -107,7 +107,7 @@ func (moc *mockOriginProxy) ProxyTCP(
r *TCPRequest, r *TCPRequest,
) error { ) error {
if r.CfTraceID == "flow-rate-limited" { if r.CfTraceID == "flow-rate-limited" {
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited") return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
} }
return nil return nil

View File

@ -16,7 +16,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/net/http2" "golang.org/x/net/http2"
cfdsession "github.com/cloudflare/cloudflared/session" cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -336,7 +336,7 @@ func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
return false return false
} }
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) { if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited) rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
} else { } else {
rp.setResponseMetaHeader(responseMetaHeaderCfd) rp.setResponseMetaHeader(responseMetaHeaderCfd)

View File

@ -17,7 +17,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
cfdsession "github.com/cloudflare/cloudflared/session" cfdflow "github.com/cloudflare/cloudflared/flow"
cfdquic "github.com/cloudflare/cloudflared/quic" cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
@ -185,7 +185,7 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R
var metadata []pogs.Metadata var metadata []pogs.Metadata
// Check the type of error that was throw and add metadata that will help identify it on OTD. // Check the type of error that was throw and add metadata that will help identify it on OTD.
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) { if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey) metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey)
} }

View File

@ -29,7 +29,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/nettest" "golang.org/x/net/nettest"
cfdsession "github.com/cloudflare/cloudflared/session" cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/datagramsession" "github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
@ -508,7 +508,7 @@ func TestBuildHTTPRequest(t *testing.T) {
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error { func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
if tcpRequest.Dest == "rate-limit-me" { if tcpRequest.Dest == "rate-limit-me" {
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "failed tcp stream") return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "failed tcp stream")
} }
_ = rwa.AckConnection("") _ = rwa.AckConnection("")
@ -828,7 +828,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
conn, conn,
index, index,
sessionManager, sessionManager,
cfdsession.NewLimiter(0), cfdflow.NewLimiter(0),
datagramMuxer, datagramMuxer,
packetRouter, packetRouter,
15 * time.Second, 15 * time.Second,

View File

@ -14,7 +14,7 @@ import (
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
cfdsession "github.com/cloudflare/cloudflared/session" cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/datagramsession" "github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
@ -46,8 +46,8 @@ type datagramV2Connection struct {
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
sessionManager datagramsession.Manager sessionManager datagramsession.Manager
// sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit. // flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
sessionLimiter cfdsession.Limiter flowLimiter cfdflow.Limiter
// datagramMuxer mux/demux datagrams from quic connection // datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *cfdquic.DatagramMuxerV2 datagramMuxer *cfdquic.DatagramMuxerV2
@ -65,7 +65,7 @@ func NewDatagramV2Connection(ctx context.Context,
index uint8, index uint8,
rpcTimeout time.Duration, rpcTimeout time.Duration,
streamWriteTimeout time.Duration, streamWriteTimeout time.Duration,
sessionLimiter cfdsession.Limiter, flowLimiter cfdflow.Limiter,
logger *zerolog.Logger, logger *zerolog.Logger,
) DatagramSessionHandler { ) DatagramSessionHandler {
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
@ -77,7 +77,7 @@ func NewDatagramV2Connection(ctx context.Context,
conn: conn, conn: conn,
index: index, index: index,
sessionManager: sessionManager, sessionManager: sessionManager,
sessionLimiter: sessionLimiter, flowLimiter: flowLimiter,
datagramMuxer: datagramMuxer, datagramMuxer: datagramMuxer,
packetRouter: packetRouter, packetRouter: packetRouter,
rpcTimeout: rpcTimeout, rpcTimeout: rpcTimeout,
@ -121,7 +121,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger() log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
// Try to start a new session // Try to start a new session
if err := q.sessionLimiter.Acquire(management.UDP.String()); err != nil { if err := q.flowLimiter.Acquire(management.UDP.String()); err != nil {
log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort) log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting") err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
@ -135,7 +135,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
if err != nil { if err != nil {
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
tracing.EndWithErrorStatus(registerSpan, err) tracing.EndWithErrorStatus(registerSpan, err)
q.sessionLimiter.Release() q.flowLimiter.Release()
return nil, err return nil, err
} }
registerSpan.SetAttributes( registerSpan.SetAttributes(
@ -148,12 +148,12 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
originProxy.Close() originProxy.Close()
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session") log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
tracing.EndWithErrorStatus(registerSpan, err) tracing.EndWithErrorStatus(registerSpan, err)
q.sessionLimiter.Release() q.flowLimiter.Release()
return nil, err return nil, err
} }
go func() { go func() {
defer q.sessionLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method. defer q.flowLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
q.serveUDPSession(session, closeAfterIdleHint) q.serveUDPSession(session, closeAfterIdleHint)
}() }()

View File

@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/mocks" "github.com/cloudflare/cloudflared/mocks"
cfdsession "github.com/cloudflare/cloudflared/session"
) )
type mockQuicConnection struct { type mockQuicConnection struct {
@ -75,7 +75,7 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
conn := &mockQuicConnection{} conn := &mockQuicConnection{}
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
sessionLimiterMock := mocks.NewMockLimiter(ctrl) flowLimiterMock := mocks.NewMockLimiter(ctrl)
datagramConn := NewDatagramV2Connection( datagramConn := NewDatagramV2Connection(
context.Background(), context.Background(),
@ -84,13 +84,13 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
0, 0,
0*time.Second, 0*time.Second,
0*time.Second, 0*time.Second,
sessionLimiterMock, flowLimiterMock,
&log, &log,
) )
sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions) flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
sessionLimiterMock.EXPECT().Release().Times(0) flowLimiterMock.EXPECT().Release().Times(0)
_, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "") _, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
require.ErrorIs(t, err, cfdsession.ErrTooManyActiveSessions) require.ErrorIs(t, err, cfdflow.ErrTooManyActiveFlows)
} }

77
flow/limiter.go Normal file
View File

@ -0,0 +1,77 @@
package flow
import (
"errors"
"sync"
)
const (
unlimitedActiveFlows = 0
)
var (
ErrTooManyActiveFlows = errors.New("too many active flows")
)
type Limiter interface {
// Acquire tries to acquire a free slot for a flow, if the value of flows is already above
// the maximum it returns ErrTooManyActiveFlows.
Acquire(flowType string) error
// Release releases a slot for a flow.
Release()
// SetLimit allows to hot swap the limit value of the limiter.
SetLimit(uint64)
}
type flowLimiter struct {
limiterLock sync.Mutex
activeFlowsCounter uint64
maxActiveFlows uint64
unlimited bool
}
func NewLimiter(maxActiveFlows uint64) Limiter {
flowLimiter := &flowLimiter{
maxActiveFlows: maxActiveFlows,
unlimited: isUnlimited(maxActiveFlows),
}
return flowLimiter
}
func (s *flowLimiter) Acquire(flowType string) error {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
if !s.unlimited && s.activeFlowsCounter >= s.maxActiveFlows {
flowRegistrationsDropped.WithLabelValues(flowType).Inc()
return ErrTooManyActiveFlows
}
s.activeFlowsCounter++
return nil
}
func (s *flowLimiter) Release() {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
if s.activeFlowsCounter <= 0 {
return
}
s.activeFlowsCounter--
}
func (s *flowLimiter) SetLimit(newMaxActiveFlows uint64) {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
s.maxActiveFlows = newMaxActiveFlows
s.unlimited = isUnlimited(newMaxActiveFlows)
}
// isUnlimited checks if the value received matches the configuration for the unlimited flow limiter.
func isUnlimited(value uint64) bool {
return value == unlimitedActiveFlows
}

119
flow/limiter_test.go Normal file
View File

@ -0,0 +1,119 @@
package flow_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/flow"
)
func TestFlowLimiter_Unlimited(t *testing.T) {
unlimitedLimiter := flow.NewLimiter(0)
for i := 0; i < 1000; i++ {
err := unlimitedLimiter.Acquire("test")
require.NoError(t, err)
}
}
func TestFlowLimiter_Limited(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
}
func TestFlowLimiter_AcquireAndReleaseFlow(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
// Acquire the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
// Validate acquire 1 more flows fails
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Release the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
limiter.Release()
}
// Validate acquire 1 more flows works
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Release a 10x the number of max flows
for i := uint64(0); i < 10*maxFlows; i++ {
limiter.Release()
}
// Validate it still can only acquire a value = number max flows.
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
}
func TestFlowLimiter_SetLimit(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
// Acquire the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
// Validate acquire 1 more flows fails
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Set the flow limiter to support one more request
limiter.SetLimit(maxFlows + 1)
// Validate acquire 1 more flows now works
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Validate acquire 1 more flows doesn't work because we already reached the limit
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Release all flows
for i := uint64(0); i < maxFlows+1; i++ {
limiter.Release()
}
// Validate 1 flow works again
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Set the flow limit to 1
limiter.SetLimit(1)
// Validate acquire 1 more flows doesn't work
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Set the flow limit to unlimited
limiter.SetLimit(0)
// Validate it can acquire a lot of flows because it is now unlimited.
for i := uint64(0); i < 10*maxFlows; i++ {
err := limiter.Acquire("shouldn't fail")
require.NoError(t, err)
}
}

View File

@ -1,4 +1,4 @@
package session package flow
import ( import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -6,17 +6,17 @@ import (
) )
const ( const (
namespace = "session" namespace = "flow"
) )
var ( var (
labels = []string{"session_type"} labels = []string{"flow_type"}
sessionRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{ flowRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace, Namespace: namespace,
Subsystem: "client", Subsystem: "client",
Name: "registrations_rate_limited_total", Name: "registrations_rate_limited_total",
Help: "Count registrations dropped due to high number of concurrent sessions being handled", Help: "Count registrations dropped due to high number of concurrent flows being handled",
}, },
labels, labels,
) )

View File

@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: ../session/limiter.go // Source: ../flow/limiter.go
// //
// Generated by this command: // Generated by this command:
// //
// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../session/limiter.go Limiter // mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter
// //
// Package mocks is a generated GoMock package. // Package mocks is a generated GoMock package.
@ -40,17 +40,17 @@ func (m *MockLimiter) EXPECT() *MockLimiterMockRecorder {
} }
// Acquire mocks base method. // Acquire mocks base method.
func (m *MockLimiter) Acquire(sessionType string) error { func (m *MockLimiter) Acquire(flowType string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Acquire", sessionType) ret := m.ctrl.Call(m, "Acquire", flowType)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Acquire indicates an expected call of Acquire. // Acquire indicates an expected call of Acquire.
func (mr *MockLimiterMockRecorder) Acquire(sessionType any) *MockLimiterAcquireCall { func (mr *MockLimiterMockRecorder) Acquire(flowType any) *MockLimiterAcquireCall {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), sessionType) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), flowType)
return &MockLimiterAcquireCall{Call: call} return &MockLimiterAcquireCall{Call: call}
} }

View File

@ -2,4 +2,4 @@
package mocks package mocks
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../session/limiter.go Limiter" //go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"

View File

@ -10,7 +10,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
cfdsession "github.com/cloudflare/cloudflared/session" cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
@ -35,9 +35,9 @@ type Orchestrator struct {
// cloudflared Configuration // cloudflared Configuration
config *Config config *Config
tags []pogs.Tag tags []pogs.Tag
// sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit. // flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
sessionLimiter cfdsession.Limiter flowLimiter cfdflow.Limiter
log *zerolog.Logger log *zerolog.Logger
// orchestrator must not handle any more updates after shutdownC is closed // orchestrator must not handle any more updates after shutdownC is closed
shutdownC <-chan struct{} shutdownC <-chan struct{}
@ -58,7 +58,7 @@ func NewOrchestrator(ctx context.Context,
internalRules: internalRules, internalRules: internalRules,
config: config, config: config,
tags: tags, tags: tags,
sessionLimiter: cfdsession.NewLimiter(config.WarpRouting.MaxActiveFlows), flowLimiter: cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows),
log: log, log: log,
shutdownC: ctx.Done(), shutdownC: ctx.Done(),
} }
@ -142,10 +142,10 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
return errors.Wrap(err, "failed to start origin") return errors.Wrap(err, "failed to start origin")
} }
// Update the sessions limit since the configuration might have changed // Update the flow limit since the configuration might have changed
o.sessionLimiter.SetLimit(warpRouting.MaxActiveFlows) o.flowLimiter.SetLimit(warpRouting.MaxActiveFlows)
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.sessionLimiter, o.config.WriteTimeout, o.log) proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.flowLimiter, o.config.WriteTimeout, o.log)
o.proxy.Store(proxy) o.proxy.Store(proxy)
o.config.Ingress = &ingressRules o.config.Ingress = &ingressRules
o.config.WarpRouting = warpRouting o.config.WarpRouting = warpRouting
@ -217,10 +217,10 @@ func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) {
return proxy, nil return proxy, nil
} }
// GetSessionLimiter returns the session limiter used across cloudflared, that can be hot reload when // GetFlowLimiter returns the flow limiter used across cloudflared, that can be hot reload when
// the configuration changes. // the configuration changes.
func (o *Orchestrator) GetSessionLimiter() cfdsession.Limiter { func (o *Orchestrator) GetFlowLimiter() cfdflow.Limiter {
return o.sessionLimiter return o.flowLimiter
} }
func (o *Orchestrator) waitToCloseLastProxy() { func (o *Orchestrator) waitToCloseLastProxy() {

View File

@ -14,8 +14,8 @@ import (
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
cfdsession "github.com/cloudflare/cloudflared/session"
"github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/cfio" "github.com/cloudflare/cloudflared/cfio"
@ -34,11 +34,11 @@ const (
// Proxy represents a means to Proxy between cloudflared and the origin services. // Proxy represents a means to Proxy between cloudflared and the origin services.
type Proxy struct { type Proxy struct {
ingressRules ingress.Ingress ingressRules ingress.Ingress
warpRouting *ingress.WarpRoutingService warpRouting *ingress.WarpRoutingService
tags []pogs.Tag tags []pogs.Tag
sessionLimiter cfdsession.Limiter flowLimiter cfdflow.Limiter
log *zerolog.Logger log *zerolog.Logger
} }
// NewOriginProxy returns a new instance of the Proxy struct. // NewOriginProxy returns a new instance of the Proxy struct.
@ -46,15 +46,15 @@ func NewOriginProxy(
ingressRules ingress.Ingress, ingressRules ingress.Ingress,
warpRouting ingress.WarpRoutingConfig, warpRouting ingress.WarpRoutingConfig,
tags []pogs.Tag, tags []pogs.Tag,
sessionLimiter cfdsession.Limiter, flowLimiter cfdflow.Limiter,
writeTimeout time.Duration, writeTimeout time.Duration,
log *zerolog.Logger, log *zerolog.Logger,
) *Proxy { ) *Proxy {
proxy := &Proxy{ proxy := &Proxy{
ingressRules: ingressRules, ingressRules: ingressRules,
tags: tags, tags: tags,
sessionLimiter: sessionLimiter, flowLimiter: flowLimiter,
log: log, log: log,
} }
proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout) proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout)
@ -160,12 +160,12 @@ func (p *Proxy) ProxyTCP(
logger := newTCPLogger(p.log, req) logger := newTCPLogger(p.log, req)
// Try to start a new session // Try to start a new flow
if err := p.sessionLimiter.Acquire(management.TCP.String()); err != nil { if err := p.flowLimiter.Acquire(management.TCP.String()); err != nil {
logger.Warn().Msg("Too many concurrent sessions being handled, rejecting tcp proxy") logger.Warn().Msg("Too many concurrent flows being handled, rejecting tcp proxy")
return pkgerrors.Wrap(err, "failed to start tcp session due to rate limiting") return pkgerrors.Wrap(err, "failed to start tcp flow due to rate limiting")
} }
defer p.sessionLimiter.Release() defer p.flowLimiter.Release()
serveCtx, cancel := context.WithCancel(ctx) serveCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()

View File

@ -26,7 +26,7 @@ import (
"github.com/cloudflare/cloudflared/mocks" "github.com/cloudflare/cloudflared/mocks"
cfdsession "github.com/cloudflare/cloudflared/session" cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/cfio" "github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log) proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy)) t.Run("testProxySSE", testProxySSE(proxy))
@ -368,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log) proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
for _, test := range tests { for _, test := range tests {
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
@ -416,7 +416,7 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log) proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
@ -484,8 +484,8 @@ func TestConnections(t *testing.T) {
// requestheaders to be sent in the call to proxy.Proxy // requestheaders to be sent in the call to proxy.Proxy
requestHeaders http.Header requestHeaders http.Header
// sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call // flowLimiterResponse is the response of the cfdflow.Limiter#Acquire method call
sessionLimiterResponse error flowLimiterResponse error
} }
type want struct { type want struct {
@ -675,7 +675,7 @@ func TestConnections(t *testing.T) {
requestHeaders: map[string][]string{ requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
}, },
sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions, flowLimiterResponse: cfdflow.ErrTooManyActiveFlows,
}, },
want: want{ want: want{
message: []byte{}, message: []byte{},
@ -695,14 +695,14 @@ func TestConnections(t *testing.T) {
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
_ = ingressRule.StartOrigins(logger, ctx.Done()) _ = ingressRule.StartOrigins(logger, ctx.Done())
// Mock session limiter // Mock flow limiter
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
sessionLimiter := mocks.NewMockLimiter(ctrl) flowLimiter := mocks.NewMockLimiter(ctrl)
sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse) flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse)
sessionLimiter.EXPECT().Release().AnyTimes() flowLimiter.EXPECT().Release().AnyTimes()
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger) proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger)
proxy.warpRouting = test.args.warpRoutingService proxy.warpRouting = test.args.warpRoutingService
dest := ln.Addr().String() dest := ln.Addr().String()

View File

@ -284,8 +284,8 @@ const (
ResponseDestinationUnreachable SessionRegistrationResp = 0x01 ResponseDestinationUnreachable SessionRegistrationResp = 0x01
// Session registration was unable to bind to a local UDP socket. // Session registration was unable to bind to a local UDP socket.
ResponseUnableToBindSocket SessionRegistrationResp = 0x02 ResponseUnableToBindSocket SessionRegistrationResp = 0x02
// Session registration failed due to the number of session being higher than the limit. // Session registration failed due to the number of flows being higher than the limit.
ResponseTooManyActiveSessions SessionRegistrationResp = 0x03 ResponseTooManyActiveFlows SessionRegistrationResp = 0x03
// Session registration failed with an unexpected error but provided a message. // Session registration failed with an unexpected error but provided a message.
ResponseErrorWithMsg SessionRegistrationResp = 0xff ResponseErrorWithMsg SessionRegistrationResp = 0xff
) )

View File

@ -10,7 +10,7 @@ import (
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
cfdsession "github.com/cloudflare/cloudflared/session" cfdflow "github.com/cloudflare/cloudflared/flow"
) )
var ( var (
@ -20,7 +20,7 @@ var (
ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection") ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection")
// ErrSessionAlreadyRegistered is returned when a registration already exists for this connection. // ErrSessionAlreadyRegistered is returned when a registration already exists for this connection.
ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection") ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection")
// ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active sessions. // ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active flows.
ErrSessionRegistrationRateLimited = errors.New("flow registration rate limited") ErrSessionRegistrationRateLimited = errors.New("flow registration rate limited")
) )
@ -44,12 +44,12 @@ type sessionManager struct {
sessions map[RequestID]Session sessions map[RequestID]Session
mutex sync.RWMutex mutex sync.RWMutex
originDialer DialUDP originDialer DialUDP
limiter cfdsession.Limiter limiter cfdflow.Limiter
metrics Metrics metrics Metrics
log *zerolog.Logger log *zerolog.Logger
} }
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdsession.Limiter) SessionManager { func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdflow.Limiter) SessionManager {
return &sessionManager{ return &sessionManager{
sessions: make(map[RequestID]Session), sessions: make(map[RequestID]Session),
originDialer: originDialer, originDialer: originDialer,

View File

@ -13,14 +13,14 @@ import (
"github.com/cloudflare/cloudflared/mocks" "github.com/cloudflare/cloudflared/mocks"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
v3 "github.com/cloudflare/cloudflared/quic/v3" v3 "github.com/cloudflare/cloudflared/quic/v3"
cfdsession "github.com/cloudflare/cloudflared/session"
) )
func TestRegisterSession(t *testing.T) { func TestRegisterSession(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)) manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
request := v3.UDPSessionRegistrationDatagram{ request := v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID, RequestID: testRequestID,
@ -76,7 +76,7 @@ func TestRegisterSession(t *testing.T) {
func TestGetSession_Empty(t *testing.T) { func TestGetSession_Empty(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)) manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
_, err := manager.GetSession(testRequestID) _, err := manager.GetSession(testRequestID)
if !errors.Is(err, v3.ErrSessionNotFound) { if !errors.Is(err, v3.ErrSessionNotFound) {
@ -88,12 +88,12 @@ func TestRegisterSessionRateLimit(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
sessionLimiterMock := mocks.NewMockLimiter(ctrl) flowLimiterMock := mocks.NewMockLimiter(ctrl)
sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions) flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
sessionLimiterMock.EXPECT().Release().Times(0) flowLimiterMock.EXPECT().Release().Times(0)
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, sessionLimiterMock) manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, flowLimiterMock)
request := v3.UDPSessionRegistrationDatagram{ request := v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID, RequestID: testRequestID,

View File

@ -351,7 +351,7 @@ func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID, log
func (c *datagramConn) handleSessionRegistrationRateLimited(datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) { func (c *datagramConn) handleSessionRegistrationRateLimited(datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) {
c.logger.Warn().Msg("Too many concurrent sessions being handled, rejecting udp proxy") c.logger.Warn().Msg("Too many concurrent sessions being handled, rejecting udp proxy")
rateLimitResponse := ResponseTooManyActiveSessions rateLimitResponse := ResponseTooManyActiveFlows
err := c.SendUDPSessionResponse(datagram.RequestID, rateLimitResponse) err := c.SendUDPSessionResponse(datagram.RequestID, rateLimitResponse)
if err != nil { if err != nil {
logger.Err(err).Msgf("unable to send flow registration error response (%d)", rateLimitResponse) logger.Err(err).Msgf("unable to send flow registration error response (%d)", rateLimitResponse)

View File

@ -20,10 +20,10 @@ import (
"golang.org/x/net/icmp" "golang.org/x/net/icmp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/packet" "github.com/cloudflare/cloudflared/packet"
v3 "github.com/cloudflare/cloudflared/quic/v3" v3 "github.com/cloudflare/cloudflared/quic/v3"
cfdsession "github.com/cloudflare/cloudflared/session"
) )
type noopEyeball struct { type noopEyeball struct {
@ -88,7 +88,7 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP
func TestDatagramConn_New(t *testing.T) { func TestDatagramConn_New(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
if conn == nil { if conn == nil {
t.Fatal("expected valid connection") t.Fatal("expected valid connection")
} }
@ -97,7 +97,7 @@ func TestDatagramConn_New(t *testing.T) {
func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
quic := newMockQuicConn() quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
payload := []byte{0xef, 0xef} payload := []byte{0xef, 0xef}
err := conn.SendUDPSessionDatagram(payload) err := conn.SendUDPSessionDatagram(payload)
@ -112,7 +112,7 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
quic := newMockQuicConn() quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable) err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
require.NoError(t, err) require.NoError(t, err)
@ -134,7 +134,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
func TestDatagramConnServe_ApplicationClosed(t *testing.T) { func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
quic := newMockQuicConn() quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel() defer cancel()
@ -150,7 +150,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel() defer cancel()
quic.ctx = ctx quic.ctx = ctx
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.Serve(context.Background()) err := conn.Serve(context.Background())
if !errors.Is(err, context.DeadlineExceeded) { if !errors.Is(err, context.DeadlineExceeded) {
@ -161,7 +161,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) { func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
quic := &mockQuicConnReadError{err: net.ErrClosed} quic := &mockQuicConnReadError{err: net.ErrClosed}
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.Serve(context.Background()) err := conn.Serve(context.Background())
if !errors.Is(err, net.ErrClosed) { if !errors.Is(err, net.ErrClosed) {
@ -198,7 +198,7 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
} }
require.EqualValues(t, testRequestID, resp.RequestID) require.EqualValues(t, testRequestID, resp.RequestID)
require.EqualValues(t, v3.ResponseTooManyActiveSessions, resp.ResponseType) require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
} }
func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) { func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {

View File

@ -1,77 +0,0 @@
package session
import (
"errors"
"sync"
)
const (
unlimitedActiveSessions = 0
)
var (
ErrTooManyActiveSessions = errors.New("too many active sessions")
)
type Limiter interface {
// Acquire tries to acquire a free slot for a session, if the value of sessions is already above
// the maximum it returns ErrTooManyActiveSessions.
Acquire(sessionType string) error
// Release releases a slot for a session.
Release()
// SetLimit allows to hot swap the limit value of the limiter.
SetLimit(uint64)
}
type sessionLimiter struct {
limiterLock sync.Mutex
activeSessionsCounter uint64
maxActiveSessions uint64
unlimited bool
}
func NewLimiter(maxActiveSessions uint64) Limiter {
sessionLimiter := &sessionLimiter{
maxActiveSessions: maxActiveSessions,
unlimited: isUnlimited(maxActiveSessions),
}
return sessionLimiter
}
func (s *sessionLimiter) Acquire(sessionType string) error {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
if !s.unlimited && s.activeSessionsCounter >= s.maxActiveSessions {
sessionRegistrationsDropped.WithLabelValues(sessionType).Inc()
return ErrTooManyActiveSessions
}
s.activeSessionsCounter++
return nil
}
func (s *sessionLimiter) Release() {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
if s.activeSessionsCounter <= 0 {
return
}
s.activeSessionsCounter--
}
func (s *sessionLimiter) SetLimit(newMaxActiveSessions uint64) {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
s.maxActiveSessions = newMaxActiveSessions
s.unlimited = isUnlimited(newMaxActiveSessions)
}
// isUnlimited checks if the value received matches the configuration for the unlimited session limiter.
func isUnlimited(value uint64) bool {
return value == unlimitedActiveSessions
}

View File

@ -1,119 +0,0 @@
package session_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/session"
)
func TestSessionLimiter_Unlimited(t *testing.T) {
unlimitedLimiter := session.NewLimiter(0)
for i := 0; i < 1000; i++ {
err := unlimitedLimiter.Acquire("test")
require.NoError(t, err)
}
}
func TestSessionLimiter_Limited(t *testing.T) {
maxSessions := uint64(5)
limiter := session.NewLimiter(maxSessions)
for i := uint64(0); i < maxSessions; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
}
func TestSessionLimiter_AcquireAndReleaseSession(t *testing.T) {
maxSessions := uint64(5)
limiter := session.NewLimiter(maxSessions)
// Acquire the maximum number of sessions
for i := uint64(0); i < maxSessions; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
// Validate acquire 1 more sessions fails
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
// Release the maximum number of sessions
for i := uint64(0); i < maxSessions; i++ {
limiter.Release()
}
// Validate acquire 1 more sessions works
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Release a 10x the number of max sessions
for i := uint64(0); i < 10*maxSessions; i++ {
limiter.Release()
}
// Validate it still can only acquire a value = number max sessions.
for i := uint64(0); i < maxSessions; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
}
func TestSessionLimiter_SetLimit(t *testing.T) {
maxSessions := uint64(5)
limiter := session.NewLimiter(maxSessions)
// Acquire the maximum number of sessions
for i := uint64(0); i < maxSessions; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
// Validate acquire 1 more sessions fails
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
// Set the session limiter to support one more request
limiter.SetLimit(maxSessions + 1)
// Validate acquire 1 more sessions now works
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Validate acquire 1 more sessions doesn't work because we already reached the limit
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
// Release all sessions
for i := uint64(0); i < maxSessions+1; i++ {
limiter.Release()
}
// Validate 1 session works again
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Set the session limit to 1
limiter.SetLimit(1)
// Validate acquire 1 more sessions doesn't work
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
// Set the session limit to unlimited
limiter.SetLimit(0)
// Validate it can acquire a lot of sessions because it is now unlimited.
for i := uint64(0); i < 10*maxSessions; i++ {
err := limiter.Acquire("shouldn't fail")
require.NoError(t, err)
}
}

View File

@ -78,7 +78,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
edgeBindAddr := config.EdgeBindAddr edgeBindAddr := config.EdgeBindAddr
datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer) datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetSessionLimiter()) sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())
edgeTunnelServer := EdgeTunnelServer{ edgeTunnelServer := EdgeTunnelServer{
config: config, config: config,

View File

@ -617,7 +617,7 @@ func (e *EdgeTunnelServer) serveQUIC(
connIndex, connIndex,
e.config.RPCTimeout, e.config.RPCTimeout,
e.config.WriteStreamTimeout, e.config.WriteStreamTimeout,
e.orchestrator.GetSessionLimiter(), e.orchestrator.GetFlowLimiter(),
connLogger.Logger(), connLogger.Logger(),
) )
} }