TUN-8861: Rename Session Limiter to Flow Limiter
## Summary Session is the concept used for UDP flows. Therefore, to make the session limiter ambiguous for both TCP and UDP, this commit renames it to flow limiter. Closes TUN-8861
This commit is contained in:
parent
8c2eda16c1
commit
4eb0f8ce5f
|
@ -12,7 +12,7 @@ import (
|
||||||
pkgerrors "github.com/pkg/errors"
|
pkgerrors "github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/stream"
|
"github.com/cloudflare/cloudflared/stream"
|
||||||
"github.com/cloudflare/cloudflared/tracing"
|
"github.com/cloudflare/cloudflared/tracing"
|
||||||
|
@ -107,7 +107,7 @@ func (moc *mockOriginProxy) ProxyTCP(
|
||||||
r *TCPRequest,
|
r *TCPRequest,
|
||||||
) error {
|
) error {
|
||||||
if r.CfTraceID == "flow-rate-limited" {
|
if r.CfTraceID == "flow-rate-limited" {
|
||||||
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited")
|
return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/tracing"
|
"github.com/cloudflare/cloudflared/tracing"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
@ -336,7 +336,7 @@ func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
|
if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
|
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
|
||||||
} else {
|
} else {
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
||||||
|
|
|
@ -17,7 +17,7 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
|
|
||||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||||
"github.com/cloudflare/cloudflared/tracing"
|
"github.com/cloudflare/cloudflared/tracing"
|
||||||
|
@ -185,7 +185,7 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R
|
||||||
|
|
||||||
var metadata []pogs.Metadata
|
var metadata []pogs.Metadata
|
||||||
// Check the type of error that was throw and add metadata that will help identify it on OTD.
|
// Check the type of error that was throw and add metadata that will help identify it on OTD.
|
||||||
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
|
if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
|
||||||
metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey)
|
metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/net/nettest"
|
"golang.org/x/net/nettest"
|
||||||
|
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/datagramsession"
|
"github.com/cloudflare/cloudflared/datagramsession"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
|
@ -508,7 +508,7 @@ func TestBuildHTTPRequest(t *testing.T) {
|
||||||
|
|
||||||
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
|
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
|
||||||
if tcpRequest.Dest == "rate-limit-me" {
|
if tcpRequest.Dest == "rate-limit-me" {
|
||||||
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "failed tcp stream")
|
return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "failed tcp stream")
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = rwa.AckConnection("")
|
_ = rwa.AckConnection("")
|
||||||
|
@ -828,7 +828,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
|
||||||
conn,
|
conn,
|
||||||
index,
|
index,
|
||||||
sessionManager,
|
sessionManager,
|
||||||
cfdsession.NewLimiter(0),
|
cfdflow.NewLimiter(0),
|
||||||
datagramMuxer,
|
datagramMuxer,
|
||||||
packetRouter,
|
packetRouter,
|
||||||
15 * time.Second,
|
15 * time.Second,
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/datagramsession"
|
"github.com/cloudflare/cloudflared/datagramsession"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
|
@ -46,8 +46,8 @@ type datagramV2Connection struct {
|
||||||
|
|
||||||
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
|
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
|
||||||
sessionManager datagramsession.Manager
|
sessionManager datagramsession.Manager
|
||||||
// sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
|
// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
|
||||||
sessionLimiter cfdsession.Limiter
|
flowLimiter cfdflow.Limiter
|
||||||
|
|
||||||
// datagramMuxer mux/demux datagrams from quic connection
|
// datagramMuxer mux/demux datagrams from quic connection
|
||||||
datagramMuxer *cfdquic.DatagramMuxerV2
|
datagramMuxer *cfdquic.DatagramMuxerV2
|
||||||
|
@ -65,7 +65,7 @@ func NewDatagramV2Connection(ctx context.Context,
|
||||||
index uint8,
|
index uint8,
|
||||||
rpcTimeout time.Duration,
|
rpcTimeout time.Duration,
|
||||||
streamWriteTimeout time.Duration,
|
streamWriteTimeout time.Duration,
|
||||||
sessionLimiter cfdsession.Limiter,
|
flowLimiter cfdflow.Limiter,
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
) DatagramSessionHandler {
|
) DatagramSessionHandler {
|
||||||
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
|
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
|
||||||
|
@ -77,7 +77,7 @@ func NewDatagramV2Connection(ctx context.Context,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
index: index,
|
index: index,
|
||||||
sessionManager: sessionManager,
|
sessionManager: sessionManager,
|
||||||
sessionLimiter: sessionLimiter,
|
flowLimiter: flowLimiter,
|
||||||
datagramMuxer: datagramMuxer,
|
datagramMuxer: datagramMuxer,
|
||||||
packetRouter: packetRouter,
|
packetRouter: packetRouter,
|
||||||
rpcTimeout: rpcTimeout,
|
rpcTimeout: rpcTimeout,
|
||||||
|
@ -121,7 +121,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
|
||||||
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
|
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
|
||||||
|
|
||||||
// Try to start a new session
|
// Try to start a new session
|
||||||
if err := q.sessionLimiter.Acquire(management.UDP.String()); err != nil {
|
if err := q.flowLimiter.Acquire(management.UDP.String()); err != nil {
|
||||||
log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
|
log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
|
||||||
|
|
||||||
err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
|
err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
|
||||||
|
@ -135,7 +135,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
|
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
|
||||||
tracing.EndWithErrorStatus(registerSpan, err)
|
tracing.EndWithErrorStatus(registerSpan, err)
|
||||||
q.sessionLimiter.Release()
|
q.flowLimiter.Release()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
registerSpan.SetAttributes(
|
registerSpan.SetAttributes(
|
||||||
|
@ -148,12 +148,12 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
|
||||||
originProxy.Close()
|
originProxy.Close()
|
||||||
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
|
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
|
||||||
tracing.EndWithErrorStatus(registerSpan, err)
|
tracing.EndWithErrorStatus(registerSpan, err)
|
||||||
q.sessionLimiter.Release()
|
q.flowLimiter.Release()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer q.sessionLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
|
defer q.flowLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
|
||||||
q.serveUDPSession(session, closeAfterIdleHint)
|
q.serveUDPSession(session, closeAfterIdleHint)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,8 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/mock/gomock"
|
"go.uber.org/mock/gomock"
|
||||||
|
|
||||||
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
"github.com/cloudflare/cloudflared/mocks"
|
"github.com/cloudflare/cloudflared/mocks"
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockQuicConnection struct {
|
type mockQuicConnection struct {
|
||||||
|
@ -75,7 +75,7 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
conn := &mockQuicConnection{}
|
conn := &mockQuicConnection{}
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
sessionLimiterMock := mocks.NewMockLimiter(ctrl)
|
flowLimiterMock := mocks.NewMockLimiter(ctrl)
|
||||||
|
|
||||||
datagramConn := NewDatagramV2Connection(
|
datagramConn := NewDatagramV2Connection(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
|
@ -84,13 +84,13 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
|
||||||
0,
|
0,
|
||||||
0*time.Second,
|
0*time.Second,
|
||||||
0*time.Second,
|
0*time.Second,
|
||||||
sessionLimiterMock,
|
flowLimiterMock,
|
||||||
&log,
|
&log,
|
||||||
)
|
)
|
||||||
|
|
||||||
sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions)
|
flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
|
||||||
sessionLimiterMock.EXPECT().Release().Times(0)
|
flowLimiterMock.EXPECT().Release().Times(0)
|
||||||
|
|
||||||
_, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
|
_, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
|
||||||
require.ErrorIs(t, err, cfdsession.ErrTooManyActiveSessions)
|
require.ErrorIs(t, err, cfdflow.ErrTooManyActiveFlows)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
package flow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
unlimitedActiveFlows = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTooManyActiveFlows = errors.New("too many active flows")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Limiter interface {
|
||||||
|
// Acquire tries to acquire a free slot for a flow, if the value of flows is already above
|
||||||
|
// the maximum it returns ErrTooManyActiveFlows.
|
||||||
|
Acquire(flowType string) error
|
||||||
|
// Release releases a slot for a flow.
|
||||||
|
Release()
|
||||||
|
// SetLimit allows to hot swap the limit value of the limiter.
|
||||||
|
SetLimit(uint64)
|
||||||
|
}
|
||||||
|
|
||||||
|
type flowLimiter struct {
|
||||||
|
limiterLock sync.Mutex
|
||||||
|
activeFlowsCounter uint64
|
||||||
|
maxActiveFlows uint64
|
||||||
|
unlimited bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLimiter(maxActiveFlows uint64) Limiter {
|
||||||
|
flowLimiter := &flowLimiter{
|
||||||
|
maxActiveFlows: maxActiveFlows,
|
||||||
|
unlimited: isUnlimited(maxActiveFlows),
|
||||||
|
}
|
||||||
|
|
||||||
|
return flowLimiter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *flowLimiter) Acquire(flowType string) error {
|
||||||
|
s.limiterLock.Lock()
|
||||||
|
defer s.limiterLock.Unlock()
|
||||||
|
|
||||||
|
if !s.unlimited && s.activeFlowsCounter >= s.maxActiveFlows {
|
||||||
|
flowRegistrationsDropped.WithLabelValues(flowType).Inc()
|
||||||
|
return ErrTooManyActiveFlows
|
||||||
|
}
|
||||||
|
|
||||||
|
s.activeFlowsCounter++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *flowLimiter) Release() {
|
||||||
|
s.limiterLock.Lock()
|
||||||
|
defer s.limiterLock.Unlock()
|
||||||
|
|
||||||
|
if s.activeFlowsCounter <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.activeFlowsCounter--
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *flowLimiter) SetLimit(newMaxActiveFlows uint64) {
|
||||||
|
s.limiterLock.Lock()
|
||||||
|
defer s.limiterLock.Unlock()
|
||||||
|
|
||||||
|
s.maxActiveFlows = newMaxActiveFlows
|
||||||
|
s.unlimited = isUnlimited(newMaxActiveFlows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUnlimited checks if the value received matches the configuration for the unlimited flow limiter.
|
||||||
|
func isUnlimited(value uint64) bool {
|
||||||
|
return value == unlimitedActiveFlows
|
||||||
|
}
|
|
@ -0,0 +1,119 @@
|
||||||
|
package flow_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/flow"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlowLimiter_Unlimited(t *testing.T) {
|
||||||
|
unlimitedLimiter := flow.NewLimiter(0)
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
err := unlimitedLimiter.Acquire("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlowLimiter_Limited(t *testing.T) {
|
||||||
|
maxFlows := uint64(5)
|
||||||
|
limiter := flow.NewLimiter(maxFlows)
|
||||||
|
|
||||||
|
for i := uint64(0); i < maxFlows; i++ {
|
||||||
|
err := limiter.Acquire("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := limiter.Acquire("should fail")
|
||||||
|
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlowLimiter_AcquireAndReleaseFlow(t *testing.T) {
|
||||||
|
maxFlows := uint64(5)
|
||||||
|
limiter := flow.NewLimiter(maxFlows)
|
||||||
|
|
||||||
|
// Acquire the maximum number of flows
|
||||||
|
for i := uint64(0); i < maxFlows; i++ {
|
||||||
|
err := limiter.Acquire("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate acquire 1 more flows fails
|
||||||
|
err := limiter.Acquire("should fail")
|
||||||
|
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||||
|
|
||||||
|
// Release the maximum number of flows
|
||||||
|
for i := uint64(0); i < maxFlows; i++ {
|
||||||
|
limiter.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate acquire 1 more flows works
|
||||||
|
err = limiter.Acquire("shouldn't fail")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Release a 10x the number of max flows
|
||||||
|
for i := uint64(0); i < 10*maxFlows; i++ {
|
||||||
|
limiter.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate it still can only acquire a value = number max flows.
|
||||||
|
for i := uint64(0); i < maxFlows; i++ {
|
||||||
|
err := limiter.Acquire("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
err = limiter.Acquire("should fail")
|
||||||
|
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlowLimiter_SetLimit(t *testing.T) {
|
||||||
|
maxFlows := uint64(5)
|
||||||
|
limiter := flow.NewLimiter(maxFlows)
|
||||||
|
|
||||||
|
// Acquire the maximum number of flows
|
||||||
|
for i := uint64(0); i < maxFlows; i++ {
|
||||||
|
err := limiter.Acquire("test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate acquire 1 more flows fails
|
||||||
|
err := limiter.Acquire("should fail")
|
||||||
|
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||||
|
|
||||||
|
// Set the flow limiter to support one more request
|
||||||
|
limiter.SetLimit(maxFlows + 1)
|
||||||
|
|
||||||
|
// Validate acquire 1 more flows now works
|
||||||
|
err = limiter.Acquire("shouldn't fail")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Validate acquire 1 more flows doesn't work because we already reached the limit
|
||||||
|
err = limiter.Acquire("should fail")
|
||||||
|
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||||
|
|
||||||
|
// Release all flows
|
||||||
|
for i := uint64(0); i < maxFlows+1; i++ {
|
||||||
|
limiter.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate 1 flow works again
|
||||||
|
err = limiter.Acquire("shouldn't fail")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Set the flow limit to 1
|
||||||
|
limiter.SetLimit(1)
|
||||||
|
|
||||||
|
// Validate acquire 1 more flows doesn't work
|
||||||
|
err = limiter.Acquire("should fail")
|
||||||
|
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||||
|
|
||||||
|
// Set the flow limit to unlimited
|
||||||
|
limiter.SetLimit(0)
|
||||||
|
|
||||||
|
// Validate it can acquire a lot of flows because it is now unlimited.
|
||||||
|
for i := uint64(0); i < 10*maxFlows; i++ {
|
||||||
|
err := limiter.Acquire("shouldn't fail")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package session
|
package flow
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
@ -6,17 +6,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
namespace = "session"
|
namespace = "flow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
labels = []string{"session_type"}
|
labels = []string{"flow_type"}
|
||||||
|
|
||||||
sessionRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
|
flowRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
Namespace: namespace,
|
Namespace: namespace,
|
||||||
Subsystem: "client",
|
Subsystem: "client",
|
||||||
Name: "registrations_rate_limited_total",
|
Name: "registrations_rate_limited_total",
|
||||||
Help: "Count registrations dropped due to high number of concurrent sessions being handled",
|
Help: "Count registrations dropped due to high number of concurrent flows being handled",
|
||||||
},
|
},
|
||||||
labels,
|
labels,
|
||||||
)
|
)
|
|
@ -1,9 +1,9 @@
|
||||||
// Code generated by MockGen. DO NOT EDIT.
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
// Source: ../session/limiter.go
|
// Source: ../flow/limiter.go
|
||||||
//
|
//
|
||||||
// Generated by this command:
|
// Generated by this command:
|
||||||
//
|
//
|
||||||
// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../session/limiter.go Limiter
|
// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter
|
||||||
//
|
//
|
||||||
|
|
||||||
// Package mocks is a generated GoMock package.
|
// Package mocks is a generated GoMock package.
|
||||||
|
@ -40,17 +40,17 @@ func (m *MockLimiter) EXPECT() *MockLimiterMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Acquire mocks base method.
|
// Acquire mocks base method.
|
||||||
func (m *MockLimiter) Acquire(sessionType string) error {
|
func (m *MockLimiter) Acquire(flowType string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Acquire", sessionType)
|
ret := m.ctrl.Call(m, "Acquire", flowType)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Acquire indicates an expected call of Acquire.
|
// Acquire indicates an expected call of Acquire.
|
||||||
func (mr *MockLimiterMockRecorder) Acquire(sessionType any) *MockLimiterAcquireCall {
|
func (mr *MockLimiterMockRecorder) Acquire(flowType any) *MockLimiterAcquireCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), sessionType)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), flowType)
|
||||||
return &MockLimiterAcquireCall{Call: call}
|
return &MockLimiterAcquireCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,4 +2,4 @@
|
||||||
|
|
||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../session/limiter.go Limiter"
|
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
@ -35,8 +35,8 @@ type Orchestrator struct {
|
||||||
// cloudflared Configuration
|
// cloudflared Configuration
|
||||||
config *Config
|
config *Config
|
||||||
tags []pogs.Tag
|
tags []pogs.Tag
|
||||||
// sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
|
// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
|
||||||
sessionLimiter cfdsession.Limiter
|
flowLimiter cfdflow.Limiter
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
|
|
||||||
// orchestrator must not handle any more updates after shutdownC is closed
|
// orchestrator must not handle any more updates after shutdownC is closed
|
||||||
|
@ -58,7 +58,7 @@ func NewOrchestrator(ctx context.Context,
|
||||||
internalRules: internalRules,
|
internalRules: internalRules,
|
||||||
config: config,
|
config: config,
|
||||||
tags: tags,
|
tags: tags,
|
||||||
sessionLimiter: cfdsession.NewLimiter(config.WarpRouting.MaxActiveFlows),
|
flowLimiter: cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows),
|
||||||
log: log,
|
log: log,
|
||||||
shutdownC: ctx.Done(),
|
shutdownC: ctx.Done(),
|
||||||
}
|
}
|
||||||
|
@ -142,10 +142,10 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
|
||||||
return errors.Wrap(err, "failed to start origin")
|
return errors.Wrap(err, "failed to start origin")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the sessions limit since the configuration might have changed
|
// Update the flow limit since the configuration might have changed
|
||||||
o.sessionLimiter.SetLimit(warpRouting.MaxActiveFlows)
|
o.flowLimiter.SetLimit(warpRouting.MaxActiveFlows)
|
||||||
|
|
||||||
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.sessionLimiter, o.config.WriteTimeout, o.log)
|
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.flowLimiter, o.config.WriteTimeout, o.log)
|
||||||
o.proxy.Store(proxy)
|
o.proxy.Store(proxy)
|
||||||
o.config.Ingress = &ingressRules
|
o.config.Ingress = &ingressRules
|
||||||
o.config.WarpRouting = warpRouting
|
o.config.WarpRouting = warpRouting
|
||||||
|
@ -217,10 +217,10 @@ func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) {
|
||||||
return proxy, nil
|
return proxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionLimiter returns the session limiter used across cloudflared, that can be hot reload when
|
// GetFlowLimiter returns the flow limiter used across cloudflared, that can be hot reload when
|
||||||
// the configuration changes.
|
// the configuration changes.
|
||||||
func (o *Orchestrator) GetSessionLimiter() cfdsession.Limiter {
|
func (o *Orchestrator) GetFlowLimiter() cfdflow.Limiter {
|
||||||
return o.sessionLimiter
|
return o.flowLimiter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Orchestrator) waitToCloseLastProxy() {
|
func (o *Orchestrator) waitToCloseLastProxy() {
|
||||||
|
|
|
@ -14,8 +14,8 @@ import (
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
|
||||||
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
"github.com/cloudflare/cloudflared/management"
|
"github.com/cloudflare/cloudflared/management"
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/carrier"
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
"github.com/cloudflare/cloudflared/cfio"
|
"github.com/cloudflare/cloudflared/cfio"
|
||||||
|
@ -37,7 +37,7 @@ type Proxy struct {
|
||||||
ingressRules ingress.Ingress
|
ingressRules ingress.Ingress
|
||||||
warpRouting *ingress.WarpRoutingService
|
warpRouting *ingress.WarpRoutingService
|
||||||
tags []pogs.Tag
|
tags []pogs.Tag
|
||||||
sessionLimiter cfdsession.Limiter
|
flowLimiter cfdflow.Limiter
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,14 +46,14 @@ func NewOriginProxy(
|
||||||
ingressRules ingress.Ingress,
|
ingressRules ingress.Ingress,
|
||||||
warpRouting ingress.WarpRoutingConfig,
|
warpRouting ingress.WarpRoutingConfig,
|
||||||
tags []pogs.Tag,
|
tags []pogs.Tag,
|
||||||
sessionLimiter cfdsession.Limiter,
|
flowLimiter cfdflow.Limiter,
|
||||||
writeTimeout time.Duration,
|
writeTimeout time.Duration,
|
||||||
log *zerolog.Logger,
|
log *zerolog.Logger,
|
||||||
) *Proxy {
|
) *Proxy {
|
||||||
proxy := &Proxy{
|
proxy := &Proxy{
|
||||||
ingressRules: ingressRules,
|
ingressRules: ingressRules,
|
||||||
tags: tags,
|
tags: tags,
|
||||||
sessionLimiter: sessionLimiter,
|
flowLimiter: flowLimiter,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,12 +160,12 @@ func (p *Proxy) ProxyTCP(
|
||||||
|
|
||||||
logger := newTCPLogger(p.log, req)
|
logger := newTCPLogger(p.log, req)
|
||||||
|
|
||||||
// Try to start a new session
|
// Try to start a new flow
|
||||||
if err := p.sessionLimiter.Acquire(management.TCP.String()); err != nil {
|
if err := p.flowLimiter.Acquire(management.TCP.String()); err != nil {
|
||||||
logger.Warn().Msg("Too many concurrent sessions being handled, rejecting tcp proxy")
|
logger.Warn().Msg("Too many concurrent flows being handled, rejecting tcp proxy")
|
||||||
return pkgerrors.Wrap(err, "failed to start tcp session due to rate limiting")
|
return pkgerrors.Wrap(err, "failed to start tcp flow due to rate limiting")
|
||||||
}
|
}
|
||||||
defer p.sessionLimiter.Release()
|
defer p.flowLimiter.Release()
|
||||||
|
|
||||||
serveCtx, cancel := context.WithCancel(ctx)
|
serveCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
|
@ -26,7 +26,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/mocks"
|
"github.com/cloudflare/cloudflared/mocks"
|
||||||
|
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cfio"
|
"github.com/cloudflare/cloudflared/cfio"
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
|
@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
|
||||||
|
|
||||||
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
|
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
|
||||||
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
||||||
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
||||||
t.Run("testProxySSE", testProxySSE(proxy))
|
t.Run("testProxySSE", testProxySSE(proxy))
|
||||||
|
@ -368,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
|
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
responseWriter := newMockHTTPRespWriter()
|
responseWriter := newMockHTTPRespWriter()
|
||||||
|
@ -416,7 +416,7 @@ func TestProxyError(t *testing.T) {
|
||||||
|
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
|
|
||||||
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
|
||||||
|
|
||||||
responseWriter := newMockHTTPRespWriter()
|
responseWriter := newMockHTTPRespWriter()
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||||
|
@ -484,8 +484,8 @@ func TestConnections(t *testing.T) {
|
||||||
// requestheaders to be sent in the call to proxy.Proxy
|
// requestheaders to be sent in the call to proxy.Proxy
|
||||||
requestHeaders http.Header
|
requestHeaders http.Header
|
||||||
|
|
||||||
// sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call
|
// flowLimiterResponse is the response of the cfdflow.Limiter#Acquire method call
|
||||||
sessionLimiterResponse error
|
flowLimiterResponse error
|
||||||
}
|
}
|
||||||
|
|
||||||
type want struct {
|
type want struct {
|
||||||
|
@ -675,7 +675,7 @@ func TestConnections(t *testing.T) {
|
||||||
requestHeaders: map[string][]string{
|
requestHeaders: map[string][]string{
|
||||||
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
||||||
},
|
},
|
||||||
sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions,
|
flowLimiterResponse: cfdflow.ErrTooManyActiveFlows,
|
||||||
},
|
},
|
||||||
want: want{
|
want: want{
|
||||||
message: []byte{},
|
message: []byte{},
|
||||||
|
@ -695,14 +695,14 @@ func TestConnections(t *testing.T) {
|
||||||
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
|
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
|
||||||
_ = ingressRule.StartOrigins(logger, ctx.Done())
|
_ = ingressRule.StartOrigins(logger, ctx.Done())
|
||||||
|
|
||||||
// Mock session limiter
|
// Mock flow limiter
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
sessionLimiter := mocks.NewMockLimiter(ctrl)
|
flowLimiter := mocks.NewMockLimiter(ctrl)
|
||||||
sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse)
|
flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse)
|
||||||
sessionLimiter.EXPECT().Release().AnyTimes()
|
flowLimiter.EXPECT().Release().AnyTimes()
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger)
|
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger)
|
||||||
proxy.warpRouting = test.args.warpRoutingService
|
proxy.warpRouting = test.args.warpRoutingService
|
||||||
|
|
||||||
dest := ln.Addr().String()
|
dest := ln.Addr().String()
|
||||||
|
|
|
@ -284,8 +284,8 @@ const (
|
||||||
ResponseDestinationUnreachable SessionRegistrationResp = 0x01
|
ResponseDestinationUnreachable SessionRegistrationResp = 0x01
|
||||||
// Session registration was unable to bind to a local UDP socket.
|
// Session registration was unable to bind to a local UDP socket.
|
||||||
ResponseUnableToBindSocket SessionRegistrationResp = 0x02
|
ResponseUnableToBindSocket SessionRegistrationResp = 0x02
|
||||||
// Session registration failed due to the number of session being higher than the limit.
|
// Session registration failed due to the number of flows being higher than the limit.
|
||||||
ResponseTooManyActiveSessions SessionRegistrationResp = 0x03
|
ResponseTooManyActiveFlows SessionRegistrationResp = 0x03
|
||||||
// Session registration failed with an unexpected error but provided a message.
|
// Session registration failed with an unexpected error but provided a message.
|
||||||
ResponseErrorWithMsg SessionRegistrationResp = 0xff
|
ResponseErrorWithMsg SessionRegistrationResp = 0xff
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/management"
|
"github.com/cloudflare/cloudflared/management"
|
||||||
|
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -20,7 +20,7 @@ var (
|
||||||
ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection")
|
ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection")
|
||||||
// ErrSessionAlreadyRegistered is returned when a registration already exists for this connection.
|
// ErrSessionAlreadyRegistered is returned when a registration already exists for this connection.
|
||||||
ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection")
|
ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection")
|
||||||
// ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active sessions.
|
// ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active flows.
|
||||||
ErrSessionRegistrationRateLimited = errors.New("flow registration rate limited")
|
ErrSessionRegistrationRateLimited = errors.New("flow registration rate limited")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,12 +44,12 @@ type sessionManager struct {
|
||||||
sessions map[RequestID]Session
|
sessions map[RequestID]Session
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
originDialer DialUDP
|
originDialer DialUDP
|
||||||
limiter cfdsession.Limiter
|
limiter cfdflow.Limiter
|
||||||
metrics Metrics
|
metrics Metrics
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdsession.Limiter) SessionManager {
|
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdflow.Limiter) SessionManager {
|
||||||
return &sessionManager{
|
return &sessionManager{
|
||||||
sessions: make(map[RequestID]Session),
|
sessions: make(map[RequestID]Session),
|
||||||
originDialer: originDialer,
|
originDialer: originDialer,
|
||||||
|
|
|
@ -13,14 +13,14 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/mocks"
|
"github.com/cloudflare/cloudflared/mocks"
|
||||||
|
|
||||||
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRegisterSession(t *testing.T) {
|
func TestRegisterSession(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0))
|
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
|
||||||
|
|
||||||
request := v3.UDPSessionRegistrationDatagram{
|
request := v3.UDPSessionRegistrationDatagram{
|
||||||
RequestID: testRequestID,
|
RequestID: testRequestID,
|
||||||
|
@ -76,7 +76,7 @@ func TestRegisterSession(t *testing.T) {
|
||||||
|
|
||||||
func TestGetSession_Empty(t *testing.T) {
|
func TestGetSession_Empty(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0))
|
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
|
||||||
|
|
||||||
_, err := manager.GetSession(testRequestID)
|
_, err := manager.GetSession(testRequestID)
|
||||||
if !errors.Is(err, v3.ErrSessionNotFound) {
|
if !errors.Is(err, v3.ErrSessionNotFound) {
|
||||||
|
@ -88,12 +88,12 @@ func TestRegisterSessionRateLimit(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
|
|
||||||
sessionLimiterMock := mocks.NewMockLimiter(ctrl)
|
flowLimiterMock := mocks.NewMockLimiter(ctrl)
|
||||||
|
|
||||||
sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions)
|
flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
|
||||||
sessionLimiterMock.EXPECT().Release().Times(0)
|
flowLimiterMock.EXPECT().Release().Times(0)
|
||||||
|
|
||||||
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, sessionLimiterMock)
|
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, flowLimiterMock)
|
||||||
|
|
||||||
request := v3.UDPSessionRegistrationDatagram{
|
request := v3.UDPSessionRegistrationDatagram{
|
||||||
RequestID: testRequestID,
|
RequestID: testRequestID,
|
||||||
|
|
|
@ -351,7 +351,7 @@ func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID, log
|
||||||
func (c *datagramConn) handleSessionRegistrationRateLimited(datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) {
|
func (c *datagramConn) handleSessionRegistrationRateLimited(datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) {
|
||||||
c.logger.Warn().Msg("Too many concurrent sessions being handled, rejecting udp proxy")
|
c.logger.Warn().Msg("Too many concurrent sessions being handled, rejecting udp proxy")
|
||||||
|
|
||||||
rateLimitResponse := ResponseTooManyActiveSessions
|
rateLimitResponse := ResponseTooManyActiveFlows
|
||||||
err := c.SendUDPSessionResponse(datagram.RequestID, rateLimitResponse)
|
err := c.SendUDPSessionResponse(datagram.RequestID, rateLimitResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Err(err).Msgf("unable to send flow registration error response (%d)", rateLimitResponse)
|
logger.Err(err).Msgf("unable to send flow registration error response (%d)", rateLimitResponse)
|
||||||
|
|
|
@ -20,10 +20,10 @@ import (
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
|
|
||||||
|
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/packet"
|
"github.com/cloudflare/cloudflared/packet"
|
||||||
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type noopEyeball struct {
|
type noopEyeball struct {
|
||||||
|
@ -88,7 +88,7 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP
|
||||||
|
|
||||||
func TestDatagramConn_New(t *testing.T) {
|
func TestDatagramConn_New(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
t.Fatal("expected valid connection")
|
t.Fatal("expected valid connection")
|
||||||
}
|
}
|
||||||
|
@ -97,7 +97,7 @@ func TestDatagramConn_New(t *testing.T) {
|
||||||
func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
|
func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
quic := newMockQuicConn()
|
quic := newMockQuicConn()
|
||||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||||
|
|
||||||
payload := []byte{0xef, 0xef}
|
payload := []byte{0xef, 0xef}
|
||||||
err := conn.SendUDPSessionDatagram(payload)
|
err := conn.SendUDPSessionDatagram(payload)
|
||||||
|
@ -112,7 +112,7 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
|
||||||
func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
|
func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
quic := newMockQuicConn()
|
quic := newMockQuicConn()
|
||||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||||
|
|
||||||
err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
|
err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -134,7 +134,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
|
||||||
func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
|
func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
quic := newMockQuicConn()
|
quic := newMockQuicConn()
|
||||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -150,7 +150,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
quic.ctx = ctx
|
quic.ctx = ctx
|
||||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||||
|
|
||||||
err := conn.Serve(context.Background())
|
err := conn.Serve(context.Background())
|
||||||
if !errors.Is(err, context.DeadlineExceeded) {
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
@ -161,7 +161,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
|
||||||
func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
|
func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
quic := &mockQuicConnReadError{err: net.ErrClosed}
|
quic := &mockQuicConnReadError{err: net.ErrClosed}
|
||||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||||
|
|
||||||
err := conn.Serve(context.Background())
|
err := conn.Serve(context.Background())
|
||||||
if !errors.Is(err, net.ErrClosed) {
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
@ -198,7 +198,7 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
require.EqualValues(t, testRequestID, resp.RequestID)
|
require.EqualValues(t, testRequestID, resp.RequestID)
|
||||||
require.EqualValues(t, v3.ResponseTooManyActiveSessions, resp.ResponseType)
|
require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
|
func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
|
||||||
|
|
|
@ -1,77 +0,0 @@
|
||||||
package session
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
unlimitedActiveSessions = 0
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrTooManyActiveSessions = errors.New("too many active sessions")
|
|
||||||
)
|
|
||||||
|
|
||||||
type Limiter interface {
|
|
||||||
// Acquire tries to acquire a free slot for a session, if the value of sessions is already above
|
|
||||||
// the maximum it returns ErrTooManyActiveSessions.
|
|
||||||
Acquire(sessionType string) error
|
|
||||||
// Release releases a slot for a session.
|
|
||||||
Release()
|
|
||||||
// SetLimit allows to hot swap the limit value of the limiter.
|
|
||||||
SetLimit(uint64)
|
|
||||||
}
|
|
||||||
|
|
||||||
type sessionLimiter struct {
|
|
||||||
limiterLock sync.Mutex
|
|
||||||
activeSessionsCounter uint64
|
|
||||||
maxActiveSessions uint64
|
|
||||||
unlimited bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLimiter(maxActiveSessions uint64) Limiter {
|
|
||||||
sessionLimiter := &sessionLimiter{
|
|
||||||
maxActiveSessions: maxActiveSessions,
|
|
||||||
unlimited: isUnlimited(maxActiveSessions),
|
|
||||||
}
|
|
||||||
|
|
||||||
return sessionLimiter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sessionLimiter) Acquire(sessionType string) error {
|
|
||||||
s.limiterLock.Lock()
|
|
||||||
defer s.limiterLock.Unlock()
|
|
||||||
|
|
||||||
if !s.unlimited && s.activeSessionsCounter >= s.maxActiveSessions {
|
|
||||||
sessionRegistrationsDropped.WithLabelValues(sessionType).Inc()
|
|
||||||
return ErrTooManyActiveSessions
|
|
||||||
}
|
|
||||||
|
|
||||||
s.activeSessionsCounter++
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sessionLimiter) Release() {
|
|
||||||
s.limiterLock.Lock()
|
|
||||||
defer s.limiterLock.Unlock()
|
|
||||||
|
|
||||||
if s.activeSessionsCounter <= 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.activeSessionsCounter--
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sessionLimiter) SetLimit(newMaxActiveSessions uint64) {
|
|
||||||
s.limiterLock.Lock()
|
|
||||||
defer s.limiterLock.Unlock()
|
|
||||||
|
|
||||||
s.maxActiveSessions = newMaxActiveSessions
|
|
||||||
s.unlimited = isUnlimited(newMaxActiveSessions)
|
|
||||||
}
|
|
||||||
|
|
||||||
// isUnlimited checks if the value received matches the configuration for the unlimited session limiter.
|
|
||||||
func isUnlimited(value uint64) bool {
|
|
||||||
return value == unlimitedActiveSessions
|
|
||||||
}
|
|
|
@ -1,119 +0,0 @@
|
||||||
package session_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/session"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSessionLimiter_Unlimited(t *testing.T) {
|
|
||||||
unlimitedLimiter := session.NewLimiter(0)
|
|
||||||
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
err := unlimitedLimiter.Acquire("test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSessionLimiter_Limited(t *testing.T) {
|
|
||||||
maxSessions := uint64(5)
|
|
||||||
limiter := session.NewLimiter(maxSessions)
|
|
||||||
|
|
||||||
for i := uint64(0); i < maxSessions; i++ {
|
|
||||||
err := limiter.Acquire("test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := limiter.Acquire("should fail")
|
|
||||||
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSessionLimiter_AcquireAndReleaseSession(t *testing.T) {
|
|
||||||
maxSessions := uint64(5)
|
|
||||||
limiter := session.NewLimiter(maxSessions)
|
|
||||||
|
|
||||||
// Acquire the maximum number of sessions
|
|
||||||
for i := uint64(0); i < maxSessions; i++ {
|
|
||||||
err := limiter.Acquire("test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate acquire 1 more sessions fails
|
|
||||||
err := limiter.Acquire("should fail")
|
|
||||||
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
|
|
||||||
|
|
||||||
// Release the maximum number of sessions
|
|
||||||
for i := uint64(0); i < maxSessions; i++ {
|
|
||||||
limiter.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate acquire 1 more sessions works
|
|
||||||
err = limiter.Acquire("shouldn't fail")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Release a 10x the number of max sessions
|
|
||||||
for i := uint64(0); i < 10*maxSessions; i++ {
|
|
||||||
limiter.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate it still can only acquire a value = number max sessions.
|
|
||||||
for i := uint64(0); i < maxSessions; i++ {
|
|
||||||
err := limiter.Acquire("test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
err = limiter.Acquire("should fail")
|
|
||||||
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSessionLimiter_SetLimit(t *testing.T) {
|
|
||||||
maxSessions := uint64(5)
|
|
||||||
limiter := session.NewLimiter(maxSessions)
|
|
||||||
|
|
||||||
// Acquire the maximum number of sessions
|
|
||||||
for i := uint64(0); i < maxSessions; i++ {
|
|
||||||
err := limiter.Acquire("test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate acquire 1 more sessions fails
|
|
||||||
err := limiter.Acquire("should fail")
|
|
||||||
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
|
|
||||||
|
|
||||||
// Set the session limiter to support one more request
|
|
||||||
limiter.SetLimit(maxSessions + 1)
|
|
||||||
|
|
||||||
// Validate acquire 1 more sessions now works
|
|
||||||
err = limiter.Acquire("shouldn't fail")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Validate acquire 1 more sessions doesn't work because we already reached the limit
|
|
||||||
err = limiter.Acquire("should fail")
|
|
||||||
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
|
|
||||||
|
|
||||||
// Release all sessions
|
|
||||||
for i := uint64(0); i < maxSessions+1; i++ {
|
|
||||||
limiter.Release()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate 1 session works again
|
|
||||||
err = limiter.Acquire("shouldn't fail")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Set the session limit to 1
|
|
||||||
limiter.SetLimit(1)
|
|
||||||
|
|
||||||
// Validate acquire 1 more sessions doesn't work
|
|
||||||
err = limiter.Acquire("should fail")
|
|
||||||
require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
|
|
||||||
|
|
||||||
// Set the session limit to unlimited
|
|
||||||
limiter.SetLimit(0)
|
|
||||||
|
|
||||||
// Validate it can acquire a lot of sessions because it is now unlimited.
|
|
||||||
for i := uint64(0); i < 10*maxSessions; i++ {
|
|
||||||
err := limiter.Acquire("shouldn't fail")
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -78,7 +78,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
|
||||||
edgeBindAddr := config.EdgeBindAddr
|
edgeBindAddr := config.EdgeBindAddr
|
||||||
|
|
||||||
datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
|
datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
|
||||||
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetSessionLimiter())
|
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())
|
||||||
|
|
||||||
edgeTunnelServer := EdgeTunnelServer{
|
edgeTunnelServer := EdgeTunnelServer{
|
||||||
config: config,
|
config: config,
|
||||||
|
|
|
@ -617,7 +617,7 @@ func (e *EdgeTunnelServer) serveQUIC(
|
||||||
connIndex,
|
connIndex,
|
||||||
e.config.RPCTimeout,
|
e.config.RPCTimeout,
|
||||||
e.config.WriteStreamTimeout,
|
e.config.WriteStreamTimeout,
|
||||||
e.orchestrator.GetSessionLimiter(),
|
e.orchestrator.GetFlowLimiter(),
|
||||||
connLogger.Logger(),
|
connLogger.Logger(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue