TUN-8861: Rename Session Limiter to Flow Limiter
## Summary

"Session" is the term already used for UDP flows, so keeping "session limiter" for a limiter that now covers both TCP and UDP is ambiguous. This commit therefore renames the session limiter to the flow limiter, so the name applies equally to TCP and UDP flows.

Closes TUN-8861
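Outside the diff itself, here is a minimal sketch of the acquire/release pattern the renamed limiter exists for. It assumes only the API this commit introduces in the `flow` package (`NewLimiter`, `Acquire`, `Release`, `ErrTooManyActiveFlows`); the surrounding program is illustrative, not part of the change.

```go
package main

import (
	"errors"
	"log"

	cfdflow "github.com/cloudflare/cloudflared/flow"
)

// proxyFlow sketches the acquire/release pattern used around TCP and UDP proxying.
func proxyFlow(limiter cfdflow.Limiter, flowType string) error {
	// Reserve a slot before starting the flow; Acquire fails once the configured
	// maximum number of concurrent flows is reached.
	if err := limiter.Acquire(flowType); err != nil {
		return err
	}
	// Give the slot back when the flow ends.
	defer limiter.Release()

	// ... proxy bytes between eyeball and origin here ...
	return nil
}

func main() {
	limiter := cfdflow.NewLimiter(1) // 0 would mean unlimited

	// Hold the only slot to simulate an in-flight flow.
	if err := limiter.Acquire("udp"); err != nil {
		log.Fatal(err)
	}

	// A second flow is rejected while the slot is held.
	if err := proxyFlow(limiter, "tcp"); errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
		log.Println("rejected: too many active flows")
	}
}
```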
Parent commit: 8c2eda16c1
This commit: 4eb0f8ce5f
@@ -12,7 +12,7 @@ import (
    pkgerrors "github.com/pkg/errors"
    "github.com/rs/zerolog"

-   cfdsession "github.com/cloudflare/cloudflared/session"
+   cfdflow "github.com/cloudflare/cloudflared/flow"

    "github.com/cloudflare/cloudflared/stream"
    "github.com/cloudflare/cloudflared/tracing"

@@ -107,7 +107,7 @@ func (moc *mockOriginProxy) ProxyTCP(
    r *TCPRequest,
) error {
    if r.CfTraceID == "flow-rate-limited" {
-       return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited")
+       return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
    }

    return nil
@@ -16,7 +16,7 @@ import (
    "github.com/rs/zerolog"
    "golang.org/x/net/http2"

-   cfdsession "github.com/cloudflare/cloudflared/session"
+   cfdflow "github.com/cloudflare/cloudflared/flow"

    "github.com/cloudflare/cloudflared/tracing"
    tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"

@@ -336,7 +336,7 @@ func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
        return false
    }

-   if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
+   if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
        rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
    } else {
        rp.setResponseMetaHeader(responseMetaHeaderCfd)
@@ -17,7 +17,7 @@ import (
    "github.com/rs/zerolog"
    "golang.org/x/sync/errgroup"

-   cfdsession "github.com/cloudflare/cloudflared/session"
+   cfdflow "github.com/cloudflare/cloudflared/flow"

    cfdquic "github.com/cloudflare/cloudflared/quic"
    "github.com/cloudflare/cloudflared/tracing"

@@ -185,7 +185,7 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R

    var metadata []pogs.Metadata
    // Check the type of error that was throw and add metadata that will help identify it on OTD.
-   if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
+   if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
        metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey)
    }
@@ -29,7 +29,7 @@ import (
    "github.com/stretchr/testify/require"
    "golang.org/x/net/nettest"

-   cfdsession "github.com/cloudflare/cloudflared/session"
+   cfdflow "github.com/cloudflare/cloudflared/flow"

    "github.com/cloudflare/cloudflared/datagramsession"
    "github.com/cloudflare/cloudflared/ingress"

@@ -508,7 +508,7 @@ func TestBuildHTTPRequest(t *testing.T) {

func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
    if tcpRequest.Dest == "rate-limit-me" {
-       return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "failed tcp stream")
+       return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "failed tcp stream")
    }

    _ = rwa.AckConnection("")

@@ -828,7 +828,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
        conn,
        index,
        sessionManager,
-       cfdsession.NewLimiter(0),
+       cfdflow.NewLimiter(0),
        datagramMuxer,
        packetRouter,
        15 * time.Second,
@@ -14,7 +14,7 @@ import (
    "go.opentelemetry.io/otel/trace"
    "golang.org/x/sync/errgroup"

-   cfdsession "github.com/cloudflare/cloudflared/session"
+   cfdflow "github.com/cloudflare/cloudflared/flow"

    "github.com/cloudflare/cloudflared/datagramsession"
    "github.com/cloudflare/cloudflared/ingress"

@@ -46,8 +46,8 @@ type datagramV2Connection struct {

    // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
    sessionManager datagramsession.Manager
-   // sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
-   sessionLimiter cfdsession.Limiter
+   // flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
+   flowLimiter cfdflow.Limiter

    // datagramMuxer mux/demux datagrams from quic connection
    datagramMuxer *cfdquic.DatagramMuxerV2

@@ -65,7 +65,7 @@ func NewDatagramV2Connection(ctx context.Context,
    index uint8,
    rpcTimeout time.Duration,
    streamWriteTimeout time.Duration,
-   sessionLimiter cfdsession.Limiter,
+   flowLimiter cfdflow.Limiter,
    logger *zerolog.Logger,
) DatagramSessionHandler {
    sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)

@@ -77,7 +77,7 @@ func NewDatagramV2Connection(ctx context.Context,
        conn:           conn,
        index:          index,
        sessionManager: sessionManager,
-       sessionLimiter: sessionLimiter,
+       flowLimiter:    flowLimiter,
        datagramMuxer:  datagramMuxer,
        packetRouter:   packetRouter,
        rpcTimeout:     rpcTimeout,

@@ -121,7 +121,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
    log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()

    // Try to start a new session
-   if err := q.sessionLimiter.Acquire(management.UDP.String()); err != nil {
+   if err := q.flowLimiter.Acquire(management.UDP.String()); err != nil {
        log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)

        err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")

@@ -135,7 +135,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
    if err != nil {
        log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
        tracing.EndWithErrorStatus(registerSpan, err)
-       q.sessionLimiter.Release()
+       q.flowLimiter.Release()
        return nil, err
    }
    registerSpan.SetAttributes(

@@ -148,12 +148,12 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
        originProxy.Close()
        log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
        tracing.EndWithErrorStatus(registerSpan, err)
-       q.sessionLimiter.Release()
+       q.flowLimiter.Release()
        return nil, err
    }

    go func() {
-       defer q.sessionLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
+       defer q.flowLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
        q.serveUDPSession(session, closeAfterIdleHint)
    }()
@@ -12,8 +12,8 @@ import (
    "github.com/stretchr/testify/require"
    "go.uber.org/mock/gomock"

+   cfdflow "github.com/cloudflare/cloudflared/flow"
    "github.com/cloudflare/cloudflared/mocks"
-   cfdsession "github.com/cloudflare/cloudflared/session"
)

type mockQuicConnection struct {

@@ -75,7 +75,7 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
    log := zerolog.Nop()
    conn := &mockQuicConnection{}
    ctrl := gomock.NewController(t)
-   sessionLimiterMock := mocks.NewMockLimiter(ctrl)
+   flowLimiterMock := mocks.NewMockLimiter(ctrl)

    datagramConn := NewDatagramV2Connection(
        context.Background(),

@@ -84,13 +84,13 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
        0,
        0*time.Second,
        0*time.Second,
-       sessionLimiterMock,
+       flowLimiterMock,
        &log,
    )

-   sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions)
-   sessionLimiterMock.EXPECT().Release().Times(0)
+   flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
+   flowLimiterMock.EXPECT().Release().Times(0)

    _, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
-   require.ErrorIs(t, err, cfdsession.ErrTooManyActiveSessions)
+   require.ErrorIs(t, err, cfdflow.ErrTooManyActiveFlows)
}
@@ -0,0 +1,77 @@
+package flow
+
+import (
+    "errors"
+    "sync"
+)
+
+const (
+    unlimitedActiveFlows = 0
+)
+
+var (
+    ErrTooManyActiveFlows = errors.New("too many active flows")
+)
+
+type Limiter interface {
+    // Acquire tries to acquire a free slot for a flow, if the value of flows is already above
+    // the maximum it returns ErrTooManyActiveFlows.
+    Acquire(flowType string) error
+    // Release releases a slot for a flow.
+    Release()
+    // SetLimit allows to hot swap the limit value of the limiter.
+    SetLimit(uint64)
+}
+
+type flowLimiter struct {
+    limiterLock        sync.Mutex
+    activeFlowsCounter uint64
+    maxActiveFlows     uint64
+    unlimited          bool
+}
+
+func NewLimiter(maxActiveFlows uint64) Limiter {
+    flowLimiter := &flowLimiter{
+        maxActiveFlows: maxActiveFlows,
+        unlimited:      isUnlimited(maxActiveFlows),
+    }
+
+    return flowLimiter
+}
+
+func (s *flowLimiter) Acquire(flowType string) error {
+    s.limiterLock.Lock()
+    defer s.limiterLock.Unlock()
+
+    if !s.unlimited && s.activeFlowsCounter >= s.maxActiveFlows {
+        flowRegistrationsDropped.WithLabelValues(flowType).Inc()
+        return ErrTooManyActiveFlows
+    }
+
+    s.activeFlowsCounter++
+    return nil
+}
+
+func (s *flowLimiter) Release() {
+    s.limiterLock.Lock()
+    defer s.limiterLock.Unlock()
+
+    if s.activeFlowsCounter <= 0 {
+        return
+    }
+
+    s.activeFlowsCounter--
+}
+
+func (s *flowLimiter) SetLimit(newMaxActiveFlows uint64) {
+    s.limiterLock.Lock()
+    defer s.limiterLock.Unlock()
+
+    s.maxActiveFlows = newMaxActiveFlows
+    s.unlimited = isUnlimited(newMaxActiveFlows)
+}
+
+// isUnlimited checks if the value received matches the configuration for the unlimited flow limiter.
+func isUnlimited(value uint64) bool {
+    return value == unlimitedActiveFlows
+}
@@ -0,0 +1,119 @@
+package flow_test
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/require"
+
+    "github.com/cloudflare/cloudflared/flow"
+)
+
+func TestFlowLimiter_Unlimited(t *testing.T) {
+    unlimitedLimiter := flow.NewLimiter(0)
+
+    for i := 0; i < 1000; i++ {
+        err := unlimitedLimiter.Acquire("test")
+        require.NoError(t, err)
+    }
+}
+
+func TestFlowLimiter_Limited(t *testing.T) {
+    maxFlows := uint64(5)
+    limiter := flow.NewLimiter(maxFlows)
+
+    for i := uint64(0); i < maxFlows; i++ {
+        err := limiter.Acquire("test")
+        require.NoError(t, err)
+    }
+
+    err := limiter.Acquire("should fail")
+    require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
+}
+
+func TestFlowLimiter_AcquireAndReleaseFlow(t *testing.T) {
+    maxFlows := uint64(5)
+    limiter := flow.NewLimiter(maxFlows)
+
+    // Acquire the maximum number of flows
+    for i := uint64(0); i < maxFlows; i++ {
+        err := limiter.Acquire("test")
+        require.NoError(t, err)
+    }
+
+    // Validate acquire 1 more flows fails
+    err := limiter.Acquire("should fail")
+    require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
+
+    // Release the maximum number of flows
+    for i := uint64(0); i < maxFlows; i++ {
+        limiter.Release()
+    }
+
+    // Validate acquire 1 more flows works
+    err = limiter.Acquire("shouldn't fail")
+    require.NoError(t, err)
+
+    // Release a 10x the number of max flows
+    for i := uint64(0); i < 10*maxFlows; i++ {
+        limiter.Release()
+    }
+
+    // Validate it still can only acquire a value = number max flows.
+    for i := uint64(0); i < maxFlows; i++ {
+        err := limiter.Acquire("test")
+        require.NoError(t, err)
+    }
+    err = limiter.Acquire("should fail")
+    require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
+}
+
+func TestFlowLimiter_SetLimit(t *testing.T) {
+    maxFlows := uint64(5)
+    limiter := flow.NewLimiter(maxFlows)
+
+    // Acquire the maximum number of flows
+    for i := uint64(0); i < maxFlows; i++ {
+        err := limiter.Acquire("test")
+        require.NoError(t, err)
+    }
+
+    // Validate acquire 1 more flows fails
+    err := limiter.Acquire("should fail")
+    require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
+
+    // Set the flow limiter to support one more request
+    limiter.SetLimit(maxFlows + 1)
+
+    // Validate acquire 1 more flows now works
+    err = limiter.Acquire("shouldn't fail")
+    require.NoError(t, err)
+
+    // Validate acquire 1 more flows doesn't work because we already reached the limit
+    err = limiter.Acquire("should fail")
+    require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
+
+    // Release all flows
+    for i := uint64(0); i < maxFlows+1; i++ {
+        limiter.Release()
+    }
+
+    // Validate 1 flow works again
+    err = limiter.Acquire("shouldn't fail")
+    require.NoError(t, err)
+
+    // Set the flow limit to 1
+    limiter.SetLimit(1)
+
+    // Validate acquire 1 more flows doesn't work
+    err = limiter.Acquire("should fail")
+    require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
+
+    // Set the flow limit to unlimited
+    limiter.SetLimit(0)
+
+    // Validate it can acquire a lot of flows because it is now unlimited.
+    for i := uint64(0); i < 10*maxFlows; i++ {
+        err := limiter.Acquire("shouldn't fail")
+        require.NoError(t, err)
+    }
+}
@@ -1,4 +1,4 @@
-package session
+package flow

import (
    "github.com/prometheus/client_golang/prometheus"

@@ -6,17 +6,17 @@ import (
)

const (
-   namespace = "session"
+   namespace = "flow"
)

var (
-   labels = []string{"session_type"}
+   labels = []string{"flow_type"}

-   sessionRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
+   flowRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
        Namespace: namespace,
        Subsystem: "client",
        Name:      "registrations_rate_limited_total",
-       Help:      "Count registrations dropped due to high number of concurrent sessions being handled",
+       Help:      "Count registrations dropped due to high number of concurrent flows being handled",
    },
        labels,
    )
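Note (editorial, not part of the diff): since the Prometheus namespace moves from `session` to `flow` and the label from `session_type` to `flow_type`, the counter should now be exported as `flow_client_registrations_rate_limited_total{flow_type=...}` rather than `session_client_registrations_rate_limited_total{session_type=...}`, so dashboards or alerts keyed on the old series will likely need updating.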
@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: ../session/limiter.go
+// Source: ../flow/limiter.go
//
// Generated by this command:
//
-// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../session/limiter.go Limiter
+// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter
//

// Package mocks is a generated GoMock package.

@@ -40,17 +40,17 @@ func (m *MockLimiter) EXPECT() *MockLimiterMockRecorder {
}

// Acquire mocks base method.
-func (m *MockLimiter) Acquire(sessionType string) error {
+func (m *MockLimiter) Acquire(flowType string) error {
    m.ctrl.T.Helper()
-   ret := m.ctrl.Call(m, "Acquire", sessionType)
+   ret := m.ctrl.Call(m, "Acquire", flowType)
    ret0, _ := ret[0].(error)
    return ret0
}

// Acquire indicates an expected call of Acquire.
-func (mr *MockLimiterMockRecorder) Acquire(sessionType any) *MockLimiterAcquireCall {
+func (mr *MockLimiterMockRecorder) Acquire(flowType any) *MockLimiterAcquireCall {
    mr.mock.ctrl.T.Helper()
-   call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), sessionType)
+   call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), flowType)
    return &MockLimiterAcquireCall{Call: call}
}
@@ -2,4 +2,4 @@

package mocks

-//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../session/limiter.go Limiter"
+//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"
@@ -10,7 +10,7 @@ import (
    "github.com/pkg/errors"
    "github.com/rs/zerolog"

-   cfdsession "github.com/cloudflare/cloudflared/session"
+   cfdflow "github.com/cloudflare/cloudflared/flow"

    "github.com/cloudflare/cloudflared/config"
    "github.com/cloudflare/cloudflared/connection"

@@ -35,9 +35,9 @@ type Orchestrator struct {
    // cloudflared Configuration
    config *Config
    tags   []pogs.Tag
-   // sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
-   sessionLimiter cfdsession.Limiter
-   log            *zerolog.Logger
+   // flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
+   flowLimiter cfdflow.Limiter
+   log         *zerolog.Logger

    // orchestrator must not handle any more updates after shutdownC is closed
    shutdownC <-chan struct{}

@@ -58,7 +58,7 @@ func NewOrchestrator(ctx context.Context,
        internalRules:  internalRules,
        config:         config,
        tags:           tags,
-       sessionLimiter: cfdsession.NewLimiter(config.WarpRouting.MaxActiveFlows),
+       flowLimiter:    cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows),
        log:            log,
        shutdownC:      ctx.Done(),
    }

@@ -142,10 +142,10 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
        return errors.Wrap(err, "failed to start origin")
    }

-   // Update the sessions limit since the configuration might have changed
-   o.sessionLimiter.SetLimit(warpRouting.MaxActiveFlows)
+   // Update the flow limit since the configuration might have changed
+   o.flowLimiter.SetLimit(warpRouting.MaxActiveFlows)

-   proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.sessionLimiter, o.config.WriteTimeout, o.log)
+   proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.flowLimiter, o.config.WriteTimeout, o.log)
    o.proxy.Store(proxy)
    o.config.Ingress = &ingressRules
    o.config.WarpRouting = warpRouting

@@ -217,10 +217,10 @@ func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) {
    return proxy, nil
}

-// GetSessionLimiter returns the session limiter used across cloudflared, that can be hot reload when
+// GetFlowLimiter returns the flow limiter used across cloudflared, that can be hot reload when
// the configuration changes.
-func (o *Orchestrator) GetSessionLimiter() cfdsession.Limiter {
-   return o.sessionLimiter
+func (o *Orchestrator) GetFlowLimiter() cfdflow.Limiter {
+   return o.flowLimiter
}

func (o *Orchestrator) waitToCloseLastProxy() {
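The hot-reload path above only calls `SetLimit` on the existing limiter rather than constructing a new one, so already-acquired flows keep counting against the updated limit. A small sketch of that behaviour (not part of the diff, values illustrative), using only the `flow` package added in this commit:

```go
package main

import (
	"fmt"

	cfdflow "github.com/cloudflare/cloudflared/flow"
)

func main() {
	limiter := cfdflow.NewLimiter(1)

	_ = limiter.Acquire("udp")          // occupies the single slot
	fmt.Println(limiter.Acquire("udp")) // too many active flows

	// A config update with a larger MaxActiveFlows only needs SetLimit; the
	// slot acquired above still counts against the new limit.
	limiter.SetLimit(2)
	fmt.Println(limiter.Acquire("udp")) // <nil>
}
```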
@@ -14,8 +14,8 @@ import (
    "go.opentelemetry.io/otel/attribute"
    "go.opentelemetry.io/otel/trace"

+   cfdflow "github.com/cloudflare/cloudflared/flow"
    "github.com/cloudflare/cloudflared/management"
-   cfdsession "github.com/cloudflare/cloudflared/session"

    "github.com/cloudflare/cloudflared/carrier"
    "github.com/cloudflare/cloudflared/cfio"

@@ -34,11 +34,11 @@ const (

// Proxy represents a means to Proxy between cloudflared and the origin services.
type Proxy struct {
-   ingressRules   ingress.Ingress
-   warpRouting    *ingress.WarpRoutingService
-   tags           []pogs.Tag
-   sessionLimiter cfdsession.Limiter
-   log            *zerolog.Logger
+   ingressRules ingress.Ingress
+   warpRouting  *ingress.WarpRoutingService
+   tags         []pogs.Tag
+   flowLimiter  cfdflow.Limiter
+   log          *zerolog.Logger
}

// NewOriginProxy returns a new instance of the Proxy struct.

@@ -46,15 +46,15 @@ func NewOriginProxy(
    ingressRules ingress.Ingress,
    warpRouting ingress.WarpRoutingConfig,
    tags []pogs.Tag,
-   sessionLimiter cfdsession.Limiter,
+   flowLimiter cfdflow.Limiter,
    writeTimeout time.Duration,
    log *zerolog.Logger,
) *Proxy {
    proxy := &Proxy{
-       ingressRules:   ingressRules,
-       tags:           tags,
-       sessionLimiter: sessionLimiter,
-       log:            log,
+       ingressRules: ingressRules,
+       tags:         tags,
+       flowLimiter:  flowLimiter,
+       log:          log,
    }

    proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout)

@@ -160,12 +160,12 @@ func (p *Proxy) ProxyTCP(

    logger := newTCPLogger(p.log, req)

-   // Try to start a new session
-   if err := p.sessionLimiter.Acquire(management.TCP.String()); err != nil {
-       logger.Warn().Msg("Too many concurrent sessions being handled, rejecting tcp proxy")
-       return pkgerrors.Wrap(err, "failed to start tcp session due to rate limiting")
+   // Try to start a new flow
+   if err := p.flowLimiter.Acquire(management.TCP.String()); err != nil {
+       logger.Warn().Msg("Too many concurrent flows being handled, rejecting tcp proxy")
+       return pkgerrors.Wrap(err, "failed to start tcp flow due to rate limiting")
    }
-   defer p.sessionLimiter.Release()
+   defer p.flowLimiter.Release()

    serveCtx, cancel := context.WithCancel(ctx)
    defer cancel()
@@ -26,7 +26,7 @@ import (

    "github.com/cloudflare/cloudflared/mocks"

-   cfdsession "github.com/cloudflare/cloudflared/session"
+   cfdflow "github.com/cloudflare/cloudflared/flow"

    "github.com/cloudflare/cloudflared/cfio"
    "github.com/cloudflare/cloudflared/config"

@@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {

    require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))

-   proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
+   proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
    t.Run("testProxyHTTP", testProxyHTTP(proxy))
    t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
    t.Run("testProxySSE", testProxySSE(proxy))

@@ -368,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
    ctx, cancel := context.WithCancel(context.Background())
    require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))

-   proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
+   proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)

    for _, test := range tests {
        responseWriter := newMockHTTPRespWriter()

@@ -416,7 +416,7 @@ func TestProxyError(t *testing.T) {

    log := zerolog.Nop()

-   proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
+   proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)

    responseWriter := newMockHTTPRespWriter()
    req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)

@@ -484,8 +484,8 @@ func TestConnections(t *testing.T) {
        // requestheaders to be sent in the call to proxy.Proxy
        requestHeaders http.Header

-       // sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call
-       sessionLimiterResponse error
+       // flowLimiterResponse is the response of the cfdflow.Limiter#Acquire method call
+       flowLimiterResponse error
    }

    type want struct {

@@ -675,7 +675,7 @@ func TestConnections(t *testing.T) {
                requestHeaders: map[string][]string{
                    "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
                },
-               sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions,
+               flowLimiterResponse: cfdflow.ErrTooManyActiveFlows,
            },
            want: want{
                message: []byte{},

@@ -695,14 +695,14 @@ func TestConnections(t *testing.T) {
            ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
            _ = ingressRule.StartOrigins(logger, ctx.Done())

-           // Mock session limiter
+           // Mock flow limiter
            ctrl := gomock.NewController(t)
            defer ctrl.Finish()
-           sessionLimiter := mocks.NewMockLimiter(ctrl)
-           sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse)
-           sessionLimiter.EXPECT().Release().AnyTimes()
+           flowLimiter := mocks.NewMockLimiter(ctrl)
+           flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse)
+           flowLimiter.EXPECT().Release().AnyTimes()

-           proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger)
+           proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger)
            proxy.warpRouting = test.args.warpRoutingService

            dest := ln.Addr().String()
@@ -284,8 +284,8 @@ const (
    ResponseDestinationUnreachable SessionRegistrationResp = 0x01
    // Session registration was unable to bind to a local UDP socket.
    ResponseUnableToBindSocket SessionRegistrationResp = 0x02
-   // Session registration failed due to the number of session being higher than the limit.
-   ResponseTooManyActiveSessions SessionRegistrationResp = 0x03
+   // Session registration failed due to the number of flows being higher than the limit.
+   ResponseTooManyActiveFlows SessionRegistrationResp = 0x03
    // Session registration failed with an unexpected error but provided a message.
    ResponseErrorWithMsg SessionRegistrationResp = 0xff
)
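Note (editorial, not part of the diff): the wire value of the rate-limited registration response stays `0x03`; only the Go identifier changes from `ResponseTooManyActiveSessions` to `ResponseTooManyActiveFlows`, so the rename does not affect the datagram protocol itself.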
@@ -10,7 +10,7 @@ import (

    "github.com/cloudflare/cloudflared/management"

-   cfdsession "github.com/cloudflare/cloudflared/session"
+   cfdflow "github.com/cloudflare/cloudflared/flow"
)

var (

@@ -20,7 +20,7 @@ var (
    ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection")
    // ErrSessionAlreadyRegistered is returned when a registration already exists for this connection.
    ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection")
-   // ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active sessions.
+   // ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active flows.
    ErrSessionRegistrationRateLimited = errors.New("flow registration rate limited")
)

@@ -44,12 +44,12 @@ type sessionManager struct {
    sessions     map[RequestID]Session
    mutex        sync.RWMutex
    originDialer DialUDP
-   limiter      cfdsession.Limiter
+   limiter      cfdflow.Limiter
    metrics      Metrics
    log          *zerolog.Logger
}

-func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdsession.Limiter) SessionManager {
+func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdflow.Limiter) SessionManager {
    return &sessionManager{
        sessions:     make(map[RequestID]Session),
        originDialer: originDialer,
@@ -13,14 +13,14 @@ import (

    "github.com/cloudflare/cloudflared/mocks"

+   cfdflow "github.com/cloudflare/cloudflared/flow"
    "github.com/cloudflare/cloudflared/ingress"
    v3 "github.com/cloudflare/cloudflared/quic/v3"
-   cfdsession "github.com/cloudflare/cloudflared/session"
)

func TestRegisterSession(t *testing.T) {
    log := zerolog.Nop()
-   manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0))
+   manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))

    request := v3.UDPSessionRegistrationDatagram{
        RequestID: testRequestID,

@@ -76,7 +76,7 @@ func TestRegisterSession(t *testing.T) {

func TestGetSession_Empty(t *testing.T) {
    log := zerolog.Nop()
-   manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0))
+   manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))

    _, err := manager.GetSession(testRequestID)
    if !errors.Is(err, v3.ErrSessionNotFound) {

@@ -88,12 +88,12 @@ func TestRegisterSessionRateLimit(t *testing.T) {
    log := zerolog.Nop()
    ctrl := gomock.NewController(t)

-   sessionLimiterMock := mocks.NewMockLimiter(ctrl)
+   flowLimiterMock := mocks.NewMockLimiter(ctrl)

-   sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions)
-   sessionLimiterMock.EXPECT().Release().Times(0)
+   flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
+   flowLimiterMock.EXPECT().Release().Times(0)

-   manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, sessionLimiterMock)
+   manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, flowLimiterMock)

    request := v3.UDPSessionRegistrationDatagram{
        RequestID: testRequestID,
@@ -351,7 +351,7 @@ func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID, log
func (c *datagramConn) handleSessionRegistrationRateLimited(datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) {
    c.logger.Warn().Msg("Too many concurrent sessions being handled, rejecting udp proxy")

-   rateLimitResponse := ResponseTooManyActiveSessions
+   rateLimitResponse := ResponseTooManyActiveFlows
    err := c.SendUDPSessionResponse(datagram.RequestID, rateLimitResponse)
    if err != nil {
        logger.Err(err).Msgf("unable to send flow registration error response (%d)", rateLimitResponse)
@@ -20,10 +20,10 @@ import (
    "golang.org/x/net/icmp"
    "golang.org/x/net/ipv4"

+   cfdflow "github.com/cloudflare/cloudflared/flow"
    "github.com/cloudflare/cloudflared/ingress"
    "github.com/cloudflare/cloudflared/packet"
    v3 "github.com/cloudflare/cloudflared/quic/v3"
-   cfdsession "github.com/cloudflare/cloudflared/session"
)

type noopEyeball struct {

@@ -88,7 +88,7 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP

func TestDatagramConn_New(t *testing.T) {
    log := zerolog.Nop()
-   conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
+   conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
    if conn == nil {
        t.Fatal("expected valid connection")
    }

@@ -97,7 +97,7 @@ func TestDatagramConn_New(t *testing.T) {
func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
    log := zerolog.Nop()
    quic := newMockQuicConn()
-   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
+   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    payload := []byte{0xef, 0xef}
    err := conn.SendUDPSessionDatagram(payload)

@@ -112,7 +112,7 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
    log := zerolog.Nop()
    quic := newMockQuicConn()
-   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
+   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
    require.NoError(t, err)

@@ -134,7 +134,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
    log := zerolog.Nop()
    quic := newMockQuicConn()
-   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
+   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
    defer cancel()

@@ -150,7 +150,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
    defer cancel()
    quic.ctx = ctx
-   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
+   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    err := conn.Serve(context.Background())
    if !errors.Is(err, context.DeadlineExceeded) {

@@ -161,7 +161,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
    log := zerolog.Nop()
    quic := &mockQuicConnReadError{err: net.ErrClosed}
-   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
+   conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)

    err := conn.Serve(context.Background())
    if !errors.Is(err, net.ErrClosed) {

@@ -198,7 +198,7 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
    }

    require.EqualValues(t, testRequestID, resp.RequestID)
-   require.EqualValues(t, v3.ResponseTooManyActiveSessions, resp.ResponseType)
+   require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
}

func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
@@ -1,77 +0,0 @@
-package session
-
-import (
-    "errors"
-    "sync"
-)
-
-const (
-    unlimitedActiveSessions = 0
-)
-
-var (
-    ErrTooManyActiveSessions = errors.New("too many active sessions")
-)
-
-type Limiter interface {
-    // Acquire tries to acquire a free slot for a session, if the value of sessions is already above
-    // the maximum it returns ErrTooManyActiveSessions.
-    Acquire(sessionType string) error
-    // Release releases a slot for a session.
-    Release()
-    // SetLimit allows to hot swap the limit value of the limiter.
-    SetLimit(uint64)
-}
-
-type sessionLimiter struct {
-    limiterLock           sync.Mutex
-    activeSessionsCounter uint64
-    maxActiveSessions     uint64
-    unlimited             bool
-}
-
-func NewLimiter(maxActiveSessions uint64) Limiter {
-    sessionLimiter := &sessionLimiter{
-        maxActiveSessions: maxActiveSessions,
-        unlimited:         isUnlimited(maxActiveSessions),
-    }
-
-    return sessionLimiter
-}
-
-func (s *sessionLimiter) Acquire(sessionType string) error {
-    s.limiterLock.Lock()
-    defer s.limiterLock.Unlock()
-
-    if !s.unlimited && s.activeSessionsCounter >= s.maxActiveSessions {
-        sessionRegistrationsDropped.WithLabelValues(sessionType).Inc()
-        return ErrTooManyActiveSessions
-    }
-
-    s.activeSessionsCounter++
-    return nil
-}
-
-func (s *sessionLimiter) Release() {
-    s.limiterLock.Lock()
-    defer s.limiterLock.Unlock()
-
-    if s.activeSessionsCounter <= 0 {
-        return
-    }
-
-    s.activeSessionsCounter--
-}
-
-func (s *sessionLimiter) SetLimit(newMaxActiveSessions uint64) {
-    s.limiterLock.Lock()
-    defer s.limiterLock.Unlock()
-
-    s.maxActiveSessions = newMaxActiveSessions
-    s.unlimited = isUnlimited(newMaxActiveSessions)
-}
-
-// isUnlimited checks if the value received matches the configuration for the unlimited session limiter.
-func isUnlimited(value uint64) bool {
-    return value == unlimitedActiveSessions
-}
@@ -1,119 +0,0 @@
-package session_test
-
-import (
-    "testing"
-
-    "github.com/stretchr/testify/require"
-
-    "github.com/cloudflare/cloudflared/session"
-)
-
-func TestSessionLimiter_Unlimited(t *testing.T) {
-    unlimitedLimiter := session.NewLimiter(0)
-
-    for i := 0; i < 1000; i++ {
-        err := unlimitedLimiter.Acquire("test")
-        require.NoError(t, err)
-    }
-}
-
-func TestSessionLimiter_Limited(t *testing.T) {
-    maxSessions := uint64(5)
-    limiter := session.NewLimiter(maxSessions)
-
-    for i := uint64(0); i < maxSessions; i++ {
-        err := limiter.Acquire("test")
-        require.NoError(t, err)
-    }
-
-    err := limiter.Acquire("should fail")
-    require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
-}
-
-func TestSessionLimiter_AcquireAndReleaseSession(t *testing.T) {
-    maxSessions := uint64(5)
-    limiter := session.NewLimiter(maxSessions)
-
-    // Acquire the maximum number of sessions
-    for i := uint64(0); i < maxSessions; i++ {
-        err := limiter.Acquire("test")
-        require.NoError(t, err)
-    }
-
-    // Validate acquire 1 more sessions fails
-    err := limiter.Acquire("should fail")
-    require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
-
-    // Release the maximum number of sessions
-    for i := uint64(0); i < maxSessions; i++ {
-        limiter.Release()
-    }
-
-    // Validate acquire 1 more sessions works
-    err = limiter.Acquire("shouldn't fail")
-    require.NoError(t, err)
-
-    // Release a 10x the number of max sessions
-    for i := uint64(0); i < 10*maxSessions; i++ {
-        limiter.Release()
-    }
-
-    // Validate it still can only acquire a value = number max sessions.
-    for i := uint64(0); i < maxSessions; i++ {
-        err := limiter.Acquire("test")
-        require.NoError(t, err)
-    }
-    err = limiter.Acquire("should fail")
-    require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
-}
-
-func TestSessionLimiter_SetLimit(t *testing.T) {
-    maxSessions := uint64(5)
-    limiter := session.NewLimiter(maxSessions)
-
-    // Acquire the maximum number of sessions
-    for i := uint64(0); i < maxSessions; i++ {
-        err := limiter.Acquire("test")
-        require.NoError(t, err)
-    }
-
-    // Validate acquire 1 more sessions fails
-    err := limiter.Acquire("should fail")
-    require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
-
-    // Set the session limiter to support one more request
-    limiter.SetLimit(maxSessions + 1)
-
-    // Validate acquire 1 more sessions now works
-    err = limiter.Acquire("shouldn't fail")
-    require.NoError(t, err)
-
-    // Validate acquire 1 more sessions doesn't work because we already reached the limit
-    err = limiter.Acquire("should fail")
-    require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
-
-    // Release all sessions
-    for i := uint64(0); i < maxSessions+1; i++ {
-        limiter.Release()
-    }
-
-    // Validate 1 session works again
-    err = limiter.Acquire("shouldn't fail")
-    require.NoError(t, err)
-
-    // Set the session limit to 1
-    limiter.SetLimit(1)
-
-    // Validate acquire 1 more sessions doesn't work
-    err = limiter.Acquire("should fail")
-    require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
-
-    // Set the session limit to unlimited
-    limiter.SetLimit(0)
-
-    // Validate it can acquire a lot of sessions because it is now unlimited.
-    for i := uint64(0); i < 10*maxSessions; i++ {
-        err := limiter.Acquire("shouldn't fail")
-        require.NoError(t, err)
-    }
-}
@@ -78,7 +78,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
    edgeBindAddr := config.EdgeBindAddr

    datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
-   sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetSessionLimiter())
+   sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())

    edgeTunnelServer := EdgeTunnelServer{
        config: config,

@@ -617,7 +617,7 @@ func (e *EdgeTunnelServer) serveQUIC(
        connIndex,
        e.config.RPCTimeout,
        e.config.WriteStreamTimeout,
-       e.orchestrator.GetSessionLimiter(),
+       e.orchestrator.GetFlowLimiter(),
        connLogger.Logger(),
    )
}