TUN-8861: Rename Session Limiter to Flow Limiter
## Summary Session is the concept used for UDP flows. Therefore, to make the session limiter ambiguous for both TCP and UDP, this commit renames it to flow limiter. Closes TUN-8861
This commit is contained in:
		
							parent
							
								
									8c2eda16c1
								
							
						
					
					
						commit
						4eb0f8ce5f
					
				| 
						 | 
				
			
			@ -12,7 +12,7 @@ import (
 | 
			
		|||
	pkgerrors "github.com/pkg/errors"
 | 
			
		||||
	"github.com/rs/zerolog"
 | 
			
		||||
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/stream"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tracing"
 | 
			
		||||
| 
						 | 
				
			
			@ -107,7 +107,7 @@ func (moc *mockOriginProxy) ProxyTCP(
 | 
			
		|||
	r *TCPRequest,
 | 
			
		||||
) error {
 | 
			
		||||
	if r.CfTraceID == "flow-rate-limited" {
 | 
			
		||||
		return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited")
 | 
			
		||||
		return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,7 +16,7 @@ import (
 | 
			
		|||
	"github.com/rs/zerolog"
 | 
			
		||||
	"golang.org/x/net/http2"
 | 
			
		||||
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tracing"
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
| 
						 | 
				
			
			@ -336,7 +336,7 @@ func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
 | 
			
		|||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
 | 
			
		||||
	if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
 | 
			
		||||
		rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
 | 
			
		||||
	} else {
 | 
			
		||||
		rp.setResponseMetaHeader(responseMetaHeaderCfd)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,7 +17,7 @@ import (
 | 
			
		|||
	"github.com/rs/zerolog"
 | 
			
		||||
	"golang.org/x/sync/errgroup"
 | 
			
		||||
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
 | 
			
		||||
	cfdquic "github.com/cloudflare/cloudflared/quic"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tracing"
 | 
			
		||||
| 
						 | 
				
			
			@ -185,7 +185,7 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R
 | 
			
		|||
 | 
			
		||||
		var metadata []pogs.Metadata
 | 
			
		||||
		// Check the type of error that was throw and add metadata that will help identify it on OTD.
 | 
			
		||||
		if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
 | 
			
		||||
		if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
 | 
			
		||||
			metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,7 +29,7 @@ import (
 | 
			
		|||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	"golang.org/x/net/nettest"
 | 
			
		||||
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/datagramsession"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/ingress"
 | 
			
		||||
| 
						 | 
				
			
			@ -508,7 +508,7 @@ func TestBuildHTTPRequest(t *testing.T) {
 | 
			
		|||
 | 
			
		||||
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
 | 
			
		||||
	if tcpRequest.Dest == "rate-limit-me" {
 | 
			
		||||
		return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "failed tcp stream")
 | 
			
		||||
		return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "failed tcp stream")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_ = rwa.AckConnection("")
 | 
			
		||||
| 
						 | 
				
			
			@ -828,7 +828,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
 | 
			
		|||
		conn,
 | 
			
		||||
		index,
 | 
			
		||||
		sessionManager,
 | 
			
		||||
		cfdsession.NewLimiter(0),
 | 
			
		||||
		cfdflow.NewLimiter(0),
 | 
			
		||||
		datagramMuxer,
 | 
			
		||||
		packetRouter,
 | 
			
		||||
		15 * time.Second,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,7 +14,7 @@ import (
 | 
			
		|||
	"go.opentelemetry.io/otel/trace"
 | 
			
		||||
	"golang.org/x/sync/errgroup"
 | 
			
		||||
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/datagramsession"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/ingress"
 | 
			
		||||
| 
						 | 
				
			
			@ -46,8 +46,8 @@ type datagramV2Connection struct {
 | 
			
		|||
 | 
			
		||||
	// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
 | 
			
		||||
	sessionManager datagramsession.Manager
 | 
			
		||||
	// sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
 | 
			
		||||
	sessionLimiter cfdsession.Limiter
 | 
			
		||||
	// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
 | 
			
		||||
	flowLimiter cfdflow.Limiter
 | 
			
		||||
 | 
			
		||||
	// datagramMuxer mux/demux datagrams from quic connection
 | 
			
		||||
	datagramMuxer *cfdquic.DatagramMuxerV2
 | 
			
		||||
| 
						 | 
				
			
			@ -65,7 +65,7 @@ func NewDatagramV2Connection(ctx context.Context,
 | 
			
		|||
	index uint8,
 | 
			
		||||
	rpcTimeout time.Duration,
 | 
			
		||||
	streamWriteTimeout time.Duration,
 | 
			
		||||
	sessionLimiter cfdsession.Limiter,
 | 
			
		||||
	flowLimiter cfdflow.Limiter,
 | 
			
		||||
	logger *zerolog.Logger,
 | 
			
		||||
) DatagramSessionHandler {
 | 
			
		||||
	sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
 | 
			
		||||
| 
						 | 
				
			
			@ -77,7 +77,7 @@ func NewDatagramV2Connection(ctx context.Context,
 | 
			
		|||
		conn:               conn,
 | 
			
		||||
		index:              index,
 | 
			
		||||
		sessionManager:     sessionManager,
 | 
			
		||||
		sessionLimiter:     sessionLimiter,
 | 
			
		||||
		flowLimiter:        flowLimiter,
 | 
			
		||||
		datagramMuxer:      datagramMuxer,
 | 
			
		||||
		packetRouter:       packetRouter,
 | 
			
		||||
		rpcTimeout:         rpcTimeout,
 | 
			
		||||
| 
						 | 
				
			
			@ -121,7 +121,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
 | 
			
		|||
	log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
 | 
			
		||||
 | 
			
		||||
	// Try to start a new session
 | 
			
		||||
	if err := q.sessionLimiter.Acquire(management.UDP.String()); err != nil {
 | 
			
		||||
	if err := q.flowLimiter.Acquire(management.UDP.String()); err != nil {
 | 
			
		||||
		log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
 | 
			
		||||
 | 
			
		||||
		err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
 | 
			
		||||
| 
						 | 
				
			
			@ -135,7 +135,7 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
 | 
			
		||||
		tracing.EndWithErrorStatus(registerSpan, err)
 | 
			
		||||
		q.sessionLimiter.Release()
 | 
			
		||||
		q.flowLimiter.Release()
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	registerSpan.SetAttributes(
 | 
			
		||||
| 
						 | 
				
			
			@ -148,12 +148,12 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
 | 
			
		|||
		originProxy.Close()
 | 
			
		||||
		log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
 | 
			
		||||
		tracing.EndWithErrorStatus(registerSpan, err)
 | 
			
		||||
		q.sessionLimiter.Release()
 | 
			
		||||
		q.flowLimiter.Release()
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer q.sessionLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
 | 
			
		||||
		defer q.flowLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
 | 
			
		||||
		q.serveUDPSession(session, closeAfterIdleHint)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,8 +12,8 @@ import (
 | 
			
		|||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	"go.uber.org/mock/gomock"
 | 
			
		||||
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/mocks"
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type mockQuicConnection struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -75,7 +75,7 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
 | 
			
		|||
	log := zerolog.Nop()
 | 
			
		||||
	conn := &mockQuicConnection{}
 | 
			
		||||
	ctrl := gomock.NewController(t)
 | 
			
		||||
	sessionLimiterMock := mocks.NewMockLimiter(ctrl)
 | 
			
		||||
	flowLimiterMock := mocks.NewMockLimiter(ctrl)
 | 
			
		||||
 | 
			
		||||
	datagramConn := NewDatagramV2Connection(
 | 
			
		||||
		context.Background(),
 | 
			
		||||
| 
						 | 
				
			
			@ -84,13 +84,13 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
 | 
			
		|||
		0,
 | 
			
		||||
		0*time.Second,
 | 
			
		||||
		0*time.Second,
 | 
			
		||||
		sessionLimiterMock,
 | 
			
		||||
		flowLimiterMock,
 | 
			
		||||
		&log,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions)
 | 
			
		||||
	sessionLimiterMock.EXPECT().Release().Times(0)
 | 
			
		||||
	flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
 | 
			
		||||
	flowLimiterMock.EXPECT().Release().Times(0)
 | 
			
		||||
 | 
			
		||||
	_, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
 | 
			
		||||
	require.ErrorIs(t, err, cfdsession.ErrTooManyActiveSessions)
 | 
			
		||||
	require.ErrorIs(t, err, cfdflow.ErrTooManyActiveFlows)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,77 @@
 | 
			
		|||
package flow
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	unlimitedActiveFlows = 0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ErrTooManyActiveFlows = errors.New("too many active flows")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Limiter interface {
 | 
			
		||||
	// Acquire tries to acquire a free slot for a flow, if the value of flows is already above
 | 
			
		||||
	// the maximum it returns ErrTooManyActiveFlows.
 | 
			
		||||
	Acquire(flowType string) error
 | 
			
		||||
	// Release releases a slot for a flow.
 | 
			
		||||
	Release()
 | 
			
		||||
	// SetLimit allows to hot swap the limit value of the limiter.
 | 
			
		||||
	SetLimit(uint64)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type flowLimiter struct {
 | 
			
		||||
	limiterLock        sync.Mutex
 | 
			
		||||
	activeFlowsCounter uint64
 | 
			
		||||
	maxActiveFlows     uint64
 | 
			
		||||
	unlimited          bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewLimiter(maxActiveFlows uint64) Limiter {
 | 
			
		||||
	flowLimiter := &flowLimiter{
 | 
			
		||||
		maxActiveFlows: maxActiveFlows,
 | 
			
		||||
		unlimited:      isUnlimited(maxActiveFlows),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return flowLimiter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *flowLimiter) Acquire(flowType string) error {
 | 
			
		||||
	s.limiterLock.Lock()
 | 
			
		||||
	defer s.limiterLock.Unlock()
 | 
			
		||||
 | 
			
		||||
	if !s.unlimited && s.activeFlowsCounter >= s.maxActiveFlows {
 | 
			
		||||
		flowRegistrationsDropped.WithLabelValues(flowType).Inc()
 | 
			
		||||
		return ErrTooManyActiveFlows
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.activeFlowsCounter++
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *flowLimiter) Release() {
 | 
			
		||||
	s.limiterLock.Lock()
 | 
			
		||||
	defer s.limiterLock.Unlock()
 | 
			
		||||
 | 
			
		||||
	if s.activeFlowsCounter <= 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.activeFlowsCounter--
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *flowLimiter) SetLimit(newMaxActiveFlows uint64) {
 | 
			
		||||
	s.limiterLock.Lock()
 | 
			
		||||
	defer s.limiterLock.Unlock()
 | 
			
		||||
 | 
			
		||||
	s.maxActiveFlows = newMaxActiveFlows
 | 
			
		||||
	s.unlimited = isUnlimited(newMaxActiveFlows)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// isUnlimited checks if the value received matches the configuration for the unlimited flow limiter.
 | 
			
		||||
func isUnlimited(value uint64) bool {
 | 
			
		||||
	return value == unlimitedActiveFlows
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,119 @@
 | 
			
		|||
package flow_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestFlowLimiter_Unlimited(t *testing.T) {
 | 
			
		||||
	unlimitedLimiter := flow.NewLimiter(0)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 1000; i++ {
 | 
			
		||||
		err := unlimitedLimiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFlowLimiter_Limited(t *testing.T) {
 | 
			
		||||
	maxFlows := uint64(5)
 | 
			
		||||
	limiter := flow.NewLimiter(maxFlows)
 | 
			
		||||
 | 
			
		||||
	for i := uint64(0); i < maxFlows; i++ {
 | 
			
		||||
		err := limiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFlowLimiter_AcquireAndReleaseFlow(t *testing.T) {
 | 
			
		||||
	maxFlows := uint64(5)
 | 
			
		||||
	limiter := flow.NewLimiter(maxFlows)
 | 
			
		||||
 | 
			
		||||
	// Acquire the maximum number of flows
 | 
			
		||||
	for i := uint64(0); i < maxFlows; i++ {
 | 
			
		||||
		err := limiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more flows fails
 | 
			
		||||
	err := limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
 | 
			
		||||
 | 
			
		||||
	// Release the maximum number of flows
 | 
			
		||||
	for i := uint64(0); i < maxFlows; i++ {
 | 
			
		||||
		limiter.Release()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more flows works
 | 
			
		||||
	err = limiter.Acquire("shouldn't fail")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// Release a 10x the number of max flows
 | 
			
		||||
	for i := uint64(0); i < 10*maxFlows; i++ {
 | 
			
		||||
		limiter.Release()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate it still can only acquire a value = number max flows.
 | 
			
		||||
	for i := uint64(0); i < maxFlows; i++ {
 | 
			
		||||
		err := limiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
	err = limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFlowLimiter_SetLimit(t *testing.T) {
 | 
			
		||||
	maxFlows := uint64(5)
 | 
			
		||||
	limiter := flow.NewLimiter(maxFlows)
 | 
			
		||||
 | 
			
		||||
	// Acquire the maximum number of flows
 | 
			
		||||
	for i := uint64(0); i < maxFlows; i++ {
 | 
			
		||||
		err := limiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more flows fails
 | 
			
		||||
	err := limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
 | 
			
		||||
 | 
			
		||||
	// Set the flow limiter to support one more request
 | 
			
		||||
	limiter.SetLimit(maxFlows + 1)
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more flows now works
 | 
			
		||||
	err = limiter.Acquire("shouldn't fail")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more flows doesn't work because we already reached the limit
 | 
			
		||||
	err = limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
 | 
			
		||||
 | 
			
		||||
	// Release all flows
 | 
			
		||||
	for i := uint64(0); i < maxFlows+1; i++ {
 | 
			
		||||
		limiter.Release()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate 1 flow works again
 | 
			
		||||
	err = limiter.Acquire("shouldn't fail")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// Set the flow limit to 1
 | 
			
		||||
	limiter.SetLimit(1)
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more flows doesn't work
 | 
			
		||||
	err = limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
 | 
			
		||||
 | 
			
		||||
	// Set the flow limit to unlimited
 | 
			
		||||
	limiter.SetLimit(0)
 | 
			
		||||
 | 
			
		||||
	// Validate it can acquire a lot of flows because it is now unlimited.
 | 
			
		||||
	for i := uint64(0); i < 10*maxFlows; i++ {
 | 
			
		||||
		err := limiter.Acquire("shouldn't fail")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1,4 +1,4 @@
 | 
			
		|||
package session
 | 
			
		||||
package flow
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
| 
						 | 
				
			
			@ -6,17 +6,17 @@ import (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	namespace = "session"
 | 
			
		||||
	namespace = "flow"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	labels = []string{"session_type"}
 | 
			
		||||
	labels = []string{"flow_type"}
 | 
			
		||||
 | 
			
		||||
	sessionRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
 | 
			
		||||
	flowRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
 | 
			
		||||
		Namespace: namespace,
 | 
			
		||||
		Subsystem: "client",
 | 
			
		||||
		Name:      "registrations_rate_limited_total",
 | 
			
		||||
		Help:      "Count registrations dropped due to high number of concurrent sessions being handled",
 | 
			
		||||
		Help:      "Count registrations dropped due to high number of concurrent flows being handled",
 | 
			
		||||
	},
 | 
			
		||||
		labels,
 | 
			
		||||
	)
 | 
			
		||||
| 
						 | 
				
			
			@ -1,9 +1,9 @@
 | 
			
		|||
// Code generated by MockGen. DO NOT EDIT.
 | 
			
		||||
// Source: ../session/limiter.go
 | 
			
		||||
// Source: ../flow/limiter.go
 | 
			
		||||
//
 | 
			
		||||
// Generated by this command:
 | 
			
		||||
//
 | 
			
		||||
//	mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../session/limiter.go Limiter
 | 
			
		||||
//	mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
// Package mocks is a generated GoMock package.
 | 
			
		||||
| 
						 | 
				
			
			@ -40,17 +40,17 @@ func (m *MockLimiter) EXPECT() *MockLimiterMockRecorder {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// Acquire mocks base method.
 | 
			
		||||
func (m *MockLimiter) Acquire(sessionType string) error {
 | 
			
		||||
func (m *MockLimiter) Acquire(flowType string) error {
 | 
			
		||||
	m.ctrl.T.Helper()
 | 
			
		||||
	ret := m.ctrl.Call(m, "Acquire", sessionType)
 | 
			
		||||
	ret := m.ctrl.Call(m, "Acquire", flowType)
 | 
			
		||||
	ret0, _ := ret[0].(error)
 | 
			
		||||
	return ret0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Acquire indicates an expected call of Acquire.
 | 
			
		||||
func (mr *MockLimiterMockRecorder) Acquire(sessionType any) *MockLimiterAcquireCall {
 | 
			
		||||
func (mr *MockLimiterMockRecorder) Acquire(flowType any) *MockLimiterAcquireCall {
 | 
			
		||||
	mr.mock.ctrl.T.Helper()
 | 
			
		||||
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), sessionType)
 | 
			
		||||
	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), flowType)
 | 
			
		||||
	return &MockLimiterAcquireCall{Call: call}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,4 +2,4 @@
 | 
			
		|||
 | 
			
		||||
package mocks
 | 
			
		||||
 | 
			
		||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../session/limiter.go Limiter"
 | 
			
		||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,7 +10,7 @@ import (
 | 
			
		|||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/rs/zerolog"
 | 
			
		||||
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/config"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/connection"
 | 
			
		||||
| 
						 | 
				
			
			@ -35,9 +35,9 @@ type Orchestrator struct {
 | 
			
		|||
	// cloudflared Configuration
 | 
			
		||||
	config *Config
 | 
			
		||||
	tags   []pogs.Tag
 | 
			
		||||
	// sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
 | 
			
		||||
	sessionLimiter cfdsession.Limiter
 | 
			
		||||
	log            *zerolog.Logger
 | 
			
		||||
	// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
 | 
			
		||||
	flowLimiter cfdflow.Limiter
 | 
			
		||||
	log         *zerolog.Logger
 | 
			
		||||
 | 
			
		||||
	// orchestrator must not handle any more updates after shutdownC is closed
 | 
			
		||||
	shutdownC <-chan struct{}
 | 
			
		||||
| 
						 | 
				
			
			@ -58,7 +58,7 @@ func NewOrchestrator(ctx context.Context,
 | 
			
		|||
		internalRules:  internalRules,
 | 
			
		||||
		config:         config,
 | 
			
		||||
		tags:           tags,
 | 
			
		||||
		sessionLimiter: cfdsession.NewLimiter(config.WarpRouting.MaxActiveFlows),
 | 
			
		||||
		flowLimiter:    cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows),
 | 
			
		||||
		log:            log,
 | 
			
		||||
		shutdownC:      ctx.Done(),
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -142,10 +142,10 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
 | 
			
		|||
		return errors.Wrap(err, "failed to start origin")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Update the sessions limit since the configuration might have changed
 | 
			
		||||
	o.sessionLimiter.SetLimit(warpRouting.MaxActiveFlows)
 | 
			
		||||
	// Update the flow limit since the configuration might have changed
 | 
			
		||||
	o.flowLimiter.SetLimit(warpRouting.MaxActiveFlows)
 | 
			
		||||
 | 
			
		||||
	proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.sessionLimiter, o.config.WriteTimeout, o.log)
 | 
			
		||||
	proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.flowLimiter, o.config.WriteTimeout, o.log)
 | 
			
		||||
	o.proxy.Store(proxy)
 | 
			
		||||
	o.config.Ingress = &ingressRules
 | 
			
		||||
	o.config.WarpRouting = warpRouting
 | 
			
		||||
| 
						 | 
				
			
			@ -217,10 +217,10 @@ func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) {
 | 
			
		|||
	return proxy, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetSessionLimiter returns the session limiter used across cloudflared, that can be hot reload when
 | 
			
		||||
// GetFlowLimiter returns the flow limiter used across cloudflared, that can be hot reload when
 | 
			
		||||
// the configuration changes.
 | 
			
		||||
func (o *Orchestrator) GetSessionLimiter() cfdsession.Limiter {
 | 
			
		||||
	return o.sessionLimiter
 | 
			
		||||
func (o *Orchestrator) GetFlowLimiter() cfdflow.Limiter {
 | 
			
		||||
	return o.flowLimiter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *Orchestrator) waitToCloseLastProxy() {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,8 +14,8 @@ import (
 | 
			
		|||
	"go.opentelemetry.io/otel/attribute"
 | 
			
		||||
	"go.opentelemetry.io/otel/trace"
 | 
			
		||||
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/management"
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/carrier"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/cfio"
 | 
			
		||||
| 
						 | 
				
			
			@ -34,11 +34,11 @@ const (
 | 
			
		|||
 | 
			
		||||
// Proxy represents a means to Proxy between cloudflared and the origin services.
 | 
			
		||||
type Proxy struct {
 | 
			
		||||
	ingressRules   ingress.Ingress
 | 
			
		||||
	warpRouting    *ingress.WarpRoutingService
 | 
			
		||||
	tags           []pogs.Tag
 | 
			
		||||
	sessionLimiter cfdsession.Limiter
 | 
			
		||||
	log            *zerolog.Logger
 | 
			
		||||
	ingressRules ingress.Ingress
 | 
			
		||||
	warpRouting  *ingress.WarpRoutingService
 | 
			
		||||
	tags         []pogs.Tag
 | 
			
		||||
	flowLimiter  cfdflow.Limiter
 | 
			
		||||
	log          *zerolog.Logger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewOriginProxy returns a new instance of the Proxy struct.
 | 
			
		||||
| 
						 | 
				
			
			@ -46,15 +46,15 @@ func NewOriginProxy(
 | 
			
		|||
	ingressRules ingress.Ingress,
 | 
			
		||||
	warpRouting ingress.WarpRoutingConfig,
 | 
			
		||||
	tags []pogs.Tag,
 | 
			
		||||
	sessionLimiter cfdsession.Limiter,
 | 
			
		||||
	flowLimiter cfdflow.Limiter,
 | 
			
		||||
	writeTimeout time.Duration,
 | 
			
		||||
	log *zerolog.Logger,
 | 
			
		||||
) *Proxy {
 | 
			
		||||
	proxy := &Proxy{
 | 
			
		||||
		ingressRules:   ingressRules,
 | 
			
		||||
		tags:           tags,
 | 
			
		||||
		sessionLimiter: sessionLimiter,
 | 
			
		||||
		log:            log,
 | 
			
		||||
		ingressRules: ingressRules,
 | 
			
		||||
		tags:         tags,
 | 
			
		||||
		flowLimiter:  flowLimiter,
 | 
			
		||||
		log:          log,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout)
 | 
			
		||||
| 
						 | 
				
			
			@ -160,12 +160,12 @@ func (p *Proxy) ProxyTCP(
 | 
			
		|||
 | 
			
		||||
	logger := newTCPLogger(p.log, req)
 | 
			
		||||
 | 
			
		||||
	// Try to start a new session
 | 
			
		||||
	if err := p.sessionLimiter.Acquire(management.TCP.String()); err != nil {
 | 
			
		||||
		logger.Warn().Msg("Too many concurrent sessions being handled, rejecting tcp proxy")
 | 
			
		||||
		return pkgerrors.Wrap(err, "failed to start tcp session due to rate limiting")
 | 
			
		||||
	// Try to start a new flow
 | 
			
		||||
	if err := p.flowLimiter.Acquire(management.TCP.String()); err != nil {
 | 
			
		||||
		logger.Warn().Msg("Too many concurrent flows being handled, rejecting tcp proxy")
 | 
			
		||||
		return pkgerrors.Wrap(err, "failed to start tcp flow due to rate limiting")
 | 
			
		||||
	}
 | 
			
		||||
	defer p.sessionLimiter.Release()
 | 
			
		||||
	defer p.flowLimiter.Release()
 | 
			
		||||
 | 
			
		||||
	serveCtx, cancel := context.WithCancel(ctx)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,7 +26,7 @@ import (
 | 
			
		|||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/mocks"
 | 
			
		||||
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/cfio"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/config"
 | 
			
		||||
| 
						 | 
				
			
			@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
 | 
			
		|||
 | 
			
		||||
	require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
 | 
			
		||||
 | 
			
		||||
	proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
 | 
			
		||||
	proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
 | 
			
		||||
	t.Run("testProxyHTTP", testProxyHTTP(proxy))
 | 
			
		||||
	t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
 | 
			
		||||
	t.Run("testProxySSE", testProxySSE(proxy))
 | 
			
		||||
| 
						 | 
				
			
			@ -368,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
 | 
			
		|||
	ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
	require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
 | 
			
		||||
 | 
			
		||||
	proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
 | 
			
		||||
	proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
 | 
			
		||||
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		responseWriter := newMockHTTPRespWriter()
 | 
			
		||||
| 
						 | 
				
			
			@ -416,7 +416,7 @@ func TestProxyError(t *testing.T) {
 | 
			
		|||
 | 
			
		||||
	log := zerolog.Nop()
 | 
			
		||||
 | 
			
		||||
	proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
 | 
			
		||||
	proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
 | 
			
		||||
 | 
			
		||||
	responseWriter := newMockHTTPRespWriter()
 | 
			
		||||
	req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
 | 
			
		||||
| 
						 | 
				
			
			@ -484,8 +484,8 @@ func TestConnections(t *testing.T) {
 | 
			
		|||
		// requestheaders to be sent in the call to proxy.Proxy
 | 
			
		||||
		requestHeaders http.Header
 | 
			
		||||
 | 
			
		||||
		// sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call
 | 
			
		||||
		sessionLimiterResponse error
 | 
			
		||||
		// flowLimiterResponse is the response of the cfdflow.Limiter#Acquire method call
 | 
			
		||||
		flowLimiterResponse error
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	type want struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -675,7 +675,7 @@ func TestConnections(t *testing.T) {
 | 
			
		|||
				requestHeaders: map[string][]string{
 | 
			
		||||
					"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
 | 
			
		||||
				},
 | 
			
		||||
				sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions,
 | 
			
		||||
				flowLimiterResponse: cfdflow.ErrTooManyActiveFlows,
 | 
			
		||||
			},
 | 
			
		||||
			want: want{
 | 
			
		||||
				message: []byte{},
 | 
			
		||||
| 
						 | 
				
			
			@ -695,14 +695,14 @@ func TestConnections(t *testing.T) {
 | 
			
		|||
			ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
 | 
			
		||||
			_ = ingressRule.StartOrigins(logger, ctx.Done())
 | 
			
		||||
 | 
			
		||||
			// Mock session limiter
 | 
			
		||||
			// Mock flow limiter
 | 
			
		||||
			ctrl := gomock.NewController(t)
 | 
			
		||||
			defer ctrl.Finish()
 | 
			
		||||
			sessionLimiter := mocks.NewMockLimiter(ctrl)
 | 
			
		||||
			sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse)
 | 
			
		||||
			sessionLimiter.EXPECT().Release().AnyTimes()
 | 
			
		||||
			flowLimiter := mocks.NewMockLimiter(ctrl)
 | 
			
		||||
			flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse)
 | 
			
		||||
			flowLimiter.EXPECT().Release().AnyTimes()
 | 
			
		||||
 | 
			
		||||
			proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger)
 | 
			
		||||
			proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger)
 | 
			
		||||
			proxy.warpRouting = test.args.warpRoutingService
 | 
			
		||||
 | 
			
		||||
			dest := ln.Addr().String()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -284,8 +284,8 @@ const (
 | 
			
		|||
	ResponseDestinationUnreachable SessionRegistrationResp = 0x01
 | 
			
		||||
	// Session registration was unable to bind to a local UDP socket.
 | 
			
		||||
	ResponseUnableToBindSocket SessionRegistrationResp = 0x02
 | 
			
		||||
	// Session registration failed due to the number of session being higher than the limit.
 | 
			
		||||
	ResponseTooManyActiveSessions SessionRegistrationResp = 0x03
 | 
			
		||||
	// Session registration failed due to the number of flows being higher than the limit.
 | 
			
		||||
	ResponseTooManyActiveFlows SessionRegistrationResp = 0x03
 | 
			
		||||
	// Session registration failed with an unexpected error but provided a message.
 | 
			
		||||
	ResponseErrorWithMsg SessionRegistrationResp = 0xff
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,7 +10,7 @@ import (
 | 
			
		|||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/management"
 | 
			
		||||
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
| 
						 | 
				
			
			@ -20,7 +20,7 @@ var (
 | 
			
		|||
	ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection")
 | 
			
		||||
	// ErrSessionAlreadyRegistered is returned when a registration already exists for this connection.
 | 
			
		||||
	ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection")
 | 
			
		||||
	// ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active sessions.
 | 
			
		||||
	// ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active flows.
 | 
			
		||||
	ErrSessionRegistrationRateLimited = errors.New("flow registration rate limited")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -44,12 +44,12 @@ type sessionManager struct {
 | 
			
		|||
	sessions     map[RequestID]Session
 | 
			
		||||
	mutex        sync.RWMutex
 | 
			
		||||
	originDialer DialUDP
 | 
			
		||||
	limiter      cfdsession.Limiter
 | 
			
		||||
	limiter      cfdflow.Limiter
 | 
			
		||||
	metrics      Metrics
 | 
			
		||||
	log          *zerolog.Logger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdsession.Limiter) SessionManager {
 | 
			
		||||
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdflow.Limiter) SessionManager {
 | 
			
		||||
	return &sessionManager{
 | 
			
		||||
		sessions:     make(map[RequestID]Session),
 | 
			
		||||
		originDialer: originDialer,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,14 +13,14 @@ import (
 | 
			
		|||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/mocks"
 | 
			
		||||
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/ingress"
 | 
			
		||||
	v3 "github.com/cloudflare/cloudflared/quic/v3"
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestRegisterSession(t *testing.T) {
 | 
			
		||||
	log := zerolog.Nop()
 | 
			
		||||
	manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0))
 | 
			
		||||
	manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
 | 
			
		||||
 | 
			
		||||
	request := v3.UDPSessionRegistrationDatagram{
 | 
			
		||||
		RequestID:        testRequestID,
 | 
			
		||||
| 
						 | 
				
			
			@ -76,7 +76,7 @@ func TestRegisterSession(t *testing.T) {
 | 
			
		|||
 | 
			
		||||
func TestGetSession_Empty(t *testing.T) {
 | 
			
		||||
	log := zerolog.Nop()
 | 
			
		||||
	manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0))
 | 
			
		||||
	manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
 | 
			
		||||
 | 
			
		||||
	_, err := manager.GetSession(testRequestID)
 | 
			
		||||
	if !errors.Is(err, v3.ErrSessionNotFound) {
 | 
			
		||||
| 
						 | 
				
			
			@ -88,12 +88,12 @@ func TestRegisterSessionRateLimit(t *testing.T) {
 | 
			
		|||
	log := zerolog.Nop()
 | 
			
		||||
	ctrl := gomock.NewController(t)
 | 
			
		||||
 | 
			
		||||
	sessionLimiterMock := mocks.NewMockLimiter(ctrl)
 | 
			
		||||
	flowLimiterMock := mocks.NewMockLimiter(ctrl)
 | 
			
		||||
 | 
			
		||||
	sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions)
 | 
			
		||||
	sessionLimiterMock.EXPECT().Release().Times(0)
 | 
			
		||||
	flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
 | 
			
		||||
	flowLimiterMock.EXPECT().Release().Times(0)
 | 
			
		||||
 | 
			
		||||
	manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, sessionLimiterMock)
 | 
			
		||||
	manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, flowLimiterMock)
 | 
			
		||||
 | 
			
		||||
	request := v3.UDPSessionRegistrationDatagram{
 | 
			
		||||
		RequestID:        testRequestID,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -351,7 +351,7 @@ func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID, log
 | 
			
		|||
func (c *datagramConn) handleSessionRegistrationRateLimited(datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) {
 | 
			
		||||
	c.logger.Warn().Msg("Too many concurrent sessions being handled, rejecting udp proxy")
 | 
			
		||||
 | 
			
		||||
	rateLimitResponse := ResponseTooManyActiveSessions
 | 
			
		||||
	rateLimitResponse := ResponseTooManyActiveFlows
 | 
			
		||||
	err := c.SendUDPSessionResponse(datagram.RequestID, rateLimitResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Err(err).Msgf("unable to send flow registration error response (%d)", rateLimitResponse)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,10 +20,10 @@ import (
 | 
			
		|||
	"golang.org/x/net/icmp"
 | 
			
		||||
	"golang.org/x/net/ipv4"
 | 
			
		||||
 | 
			
		||||
	cfdflow "github.com/cloudflare/cloudflared/flow"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/ingress"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/packet"
 | 
			
		||||
	v3 "github.com/cloudflare/cloudflared/quic/v3"
 | 
			
		||||
	cfdsession "github.com/cloudflare/cloudflared/session"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type noopEyeball struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -88,7 +88,7 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP
 | 
			
		|||
 | 
			
		||||
func TestDatagramConn_New(t *testing.T) {
 | 
			
		||||
	log := zerolog.Nop()
 | 
			
		||||
	conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
	conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
	if conn == nil {
 | 
			
		||||
		t.Fatal("expected valid connection")
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -97,7 +97,7 @@ func TestDatagramConn_New(t *testing.T) {
 | 
			
		|||
func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
 | 
			
		||||
	log := zerolog.Nop()
 | 
			
		||||
	quic := newMockQuicConn()
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
 | 
			
		||||
	payload := []byte{0xef, 0xef}
 | 
			
		||||
	err := conn.SendUDPSessionDatagram(payload)
 | 
			
		||||
| 
						 | 
				
			
			@ -112,7 +112,7 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
 | 
			
		|||
func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
 | 
			
		||||
	log := zerolog.Nop()
 | 
			
		||||
	quic := newMockQuicConn()
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
 | 
			
		||||
	err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
| 
						 | 
				
			
			@ -134,7 +134,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
 | 
			
		|||
func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
 | 
			
		||||
	log := zerolog.Nop()
 | 
			
		||||
	quic := newMockQuicConn()
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
| 
						 | 
				
			
			@ -150,7 +150,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
 | 
			
		|||
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
	quic.ctx = ctx
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
 | 
			
		||||
	err := conn.Serve(context.Background())
 | 
			
		||||
	if !errors.Is(err, context.DeadlineExceeded) {
 | 
			
		||||
| 
						 | 
				
			
			@ -161,7 +161,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
 | 
			
		|||
func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
 | 
			
		||||
	log := zerolog.Nop()
 | 
			
		||||
	quic := &mockQuicConnReadError{err: net.ErrClosed}
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdsession.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
	conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
 | 
			
		||||
 | 
			
		||||
	err := conn.Serve(context.Background())
 | 
			
		||||
	if !errors.Is(err, net.ErrClosed) {
 | 
			
		||||
| 
						 | 
				
			
			@ -198,7 +198,7 @@ func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	require.EqualValues(t, testRequestID, resp.RequestID)
 | 
			
		||||
	require.EqualValues(t, v3.ResponseTooManyActiveSessions, resp.ResponseType)
 | 
			
		||||
	require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,77 +0,0 @@
 | 
			
		|||
package session
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	unlimitedActiveSessions = 0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ErrTooManyActiveSessions = errors.New("too many active sessions")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Limiter interface {
 | 
			
		||||
	// Acquire tries to acquire a free slot for a session, if the value of sessions is already above
 | 
			
		||||
	// the maximum it returns ErrTooManyActiveSessions.
 | 
			
		||||
	Acquire(sessionType string) error
 | 
			
		||||
	// Release releases a slot for a session.
 | 
			
		||||
	Release()
 | 
			
		||||
	// SetLimit allows to hot swap the limit value of the limiter.
 | 
			
		||||
	SetLimit(uint64)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type sessionLimiter struct {
 | 
			
		||||
	limiterLock           sync.Mutex
 | 
			
		||||
	activeSessionsCounter uint64
 | 
			
		||||
	maxActiveSessions     uint64
 | 
			
		||||
	unlimited             bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewLimiter(maxActiveSessions uint64) Limiter {
 | 
			
		||||
	sessionLimiter := &sessionLimiter{
 | 
			
		||||
		maxActiveSessions: maxActiveSessions,
 | 
			
		||||
		unlimited:         isUnlimited(maxActiveSessions),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sessionLimiter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *sessionLimiter) Acquire(sessionType string) error {
 | 
			
		||||
	s.limiterLock.Lock()
 | 
			
		||||
	defer s.limiterLock.Unlock()
 | 
			
		||||
 | 
			
		||||
	if !s.unlimited && s.activeSessionsCounter >= s.maxActiveSessions {
 | 
			
		||||
		sessionRegistrationsDropped.WithLabelValues(sessionType).Inc()
 | 
			
		||||
		return ErrTooManyActiveSessions
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.activeSessionsCounter++
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *sessionLimiter) Release() {
 | 
			
		||||
	s.limiterLock.Lock()
 | 
			
		||||
	defer s.limiterLock.Unlock()
 | 
			
		||||
 | 
			
		||||
	if s.activeSessionsCounter <= 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.activeSessionsCounter--
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *sessionLimiter) SetLimit(newMaxActiveSessions uint64) {
 | 
			
		||||
	s.limiterLock.Lock()
 | 
			
		||||
	defer s.limiterLock.Unlock()
 | 
			
		||||
 | 
			
		||||
	s.maxActiveSessions = newMaxActiveSessions
 | 
			
		||||
	s.unlimited = isUnlimited(newMaxActiveSessions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// isUnlimited checks if the value received matches the configuration for the unlimited session limiter.
 | 
			
		||||
func isUnlimited(value uint64) bool {
 | 
			
		||||
	return value == unlimitedActiveSessions
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1,119 +0,0 @@
 | 
			
		|||
package session_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/session"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestSessionLimiter_Unlimited(t *testing.T) {
 | 
			
		||||
	unlimitedLimiter := session.NewLimiter(0)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 1000; i++ {
 | 
			
		||||
		err := unlimitedLimiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSessionLimiter_Limited(t *testing.T) {
 | 
			
		||||
	maxSessions := uint64(5)
 | 
			
		||||
	limiter := session.NewLimiter(maxSessions)
 | 
			
		||||
 | 
			
		||||
	for i := uint64(0); i < maxSessions; i++ {
 | 
			
		||||
		err := limiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSessionLimiter_AcquireAndReleaseSession(t *testing.T) {
 | 
			
		||||
	maxSessions := uint64(5)
 | 
			
		||||
	limiter := session.NewLimiter(maxSessions)
 | 
			
		||||
 | 
			
		||||
	// Acquire the maximum number of sessions
 | 
			
		||||
	for i := uint64(0); i < maxSessions; i++ {
 | 
			
		||||
		err := limiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more sessions fails
 | 
			
		||||
	err := limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
 | 
			
		||||
 | 
			
		||||
	// Release the maximum number of sessions
 | 
			
		||||
	for i := uint64(0); i < maxSessions; i++ {
 | 
			
		||||
		limiter.Release()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more sessions works
 | 
			
		||||
	err = limiter.Acquire("shouldn't fail")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// Release a 10x the number of max sessions
 | 
			
		||||
	for i := uint64(0); i < 10*maxSessions; i++ {
 | 
			
		||||
		limiter.Release()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate it still can only acquire a value = number max sessions.
 | 
			
		||||
	for i := uint64(0); i < maxSessions; i++ {
 | 
			
		||||
		err := limiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
	err = limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSessionLimiter_SetLimit(t *testing.T) {
 | 
			
		||||
	maxSessions := uint64(5)
 | 
			
		||||
	limiter := session.NewLimiter(maxSessions)
 | 
			
		||||
 | 
			
		||||
	// Acquire the maximum number of sessions
 | 
			
		||||
	for i := uint64(0); i < maxSessions; i++ {
 | 
			
		||||
		err := limiter.Acquire("test")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more sessions fails
 | 
			
		||||
	err := limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
 | 
			
		||||
 | 
			
		||||
	// Set the session limiter to support one more request
 | 
			
		||||
	limiter.SetLimit(maxSessions + 1)
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more sessions now works
 | 
			
		||||
	err = limiter.Acquire("shouldn't fail")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more sessions doesn't work because we already reached the limit
 | 
			
		||||
	err = limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
 | 
			
		||||
 | 
			
		||||
	// Release all sessions
 | 
			
		||||
	for i := uint64(0); i < maxSessions+1; i++ {
 | 
			
		||||
		limiter.Release()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Validate 1 session works again
 | 
			
		||||
	err = limiter.Acquire("shouldn't fail")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	// Set the session limit to 1
 | 
			
		||||
	limiter.SetLimit(1)
 | 
			
		||||
 | 
			
		||||
	// Validate acquire 1 more sessions doesn't work
 | 
			
		||||
	err = limiter.Acquire("should fail")
 | 
			
		||||
	require.ErrorIs(t, err, session.ErrTooManyActiveSessions)
 | 
			
		||||
 | 
			
		||||
	// Set the session limit to unlimited
 | 
			
		||||
	limiter.SetLimit(0)
 | 
			
		||||
 | 
			
		||||
	// Validate it can acquire a lot of sessions because it is now unlimited.
 | 
			
		||||
	for i := uint64(0); i < 10*maxSessions; i++ {
 | 
			
		||||
		err := limiter.Acquire("shouldn't fail")
 | 
			
		||||
		require.NoError(t, err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -78,7 +78,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
 | 
			
		|||
	edgeBindAddr := config.EdgeBindAddr
 | 
			
		||||
 | 
			
		||||
	datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
 | 
			
		||||
	sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetSessionLimiter())
 | 
			
		||||
	sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())
 | 
			
		||||
 | 
			
		||||
	edgeTunnelServer := EdgeTunnelServer{
 | 
			
		||||
		config:            config,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -617,7 +617,7 @@ func (e *EdgeTunnelServer) serveQUIC(
 | 
			
		|||
			connIndex,
 | 
			
		||||
			e.config.RPCTimeout,
 | 
			
		||||
			e.config.WriteStreamTimeout,
 | 
			
		||||
			e.orchestrator.GetSessionLimiter(),
 | 
			
		||||
			e.orchestrator.GetFlowLimiter(),
 | 
			
		||||
			connLogger.Logger(),
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue