TUN-8861: Add session limiter to TCP session manager

## Summary
In order to make cloudflared behavior more predictable and
prevent an exhaustion of resources, we have decided to add
session limits that can be configured by the user. This commit
adds the session limiter to the HTTP/TCP handling path.
For now the limiter is set to run only in unlimited mode.
This commit is contained in:
João "Pisco" Fernandes 2025-01-14 14:05:18 +00:00
parent bf4954e96a
commit 8bfe111cab
12 changed files with 275 additions and 102 deletions

View File

@ -2,14 +2,18 @@ package connection
import (
"context"
"crypto/rand"
"fmt"
"io"
"math/rand"
"math/big"
"net/http"
"time"
pkgerrors "github.com/pkg/errors"
"github.com/rs/zerolog"
cfdsession "github.com/cloudflare/cloudflared/session"
"github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -77,7 +81,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
return wsFlakyEndpoint(w, req)
default:
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path)
}
}
switch req.URL.Path {
@ -95,7 +99,6 @@ func (moc *mockOriginProxy) ProxyHTTP(
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
}
return nil
}
func (moc *mockOriginProxy) ProxyTCP(
@ -103,6 +106,10 @@ func (moc *mockOriginProxy) ProxyTCP(
rwa ReadWriteAcker,
r *TCPRequest,
) error {
if r.CfTraceID == "flow-rate-limited" {
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited")
}
return nil
}
@ -178,7 +185,8 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
rInt, _ := rand.Int(rand.Reader, big.NewInt(50))
closedAfter := time.Millisecond * time.Duration(rInt.Int64())
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
stream.Pipe(wsConn, originConn, &log)
cancel()

View File

@ -22,8 +22,9 @@ var (
var (
// pre-generate possible values for res
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false)
responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true)
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false)
)
// HTTPHeader is a custom header struct that expects only ever one value for the header.
@ -35,10 +36,11 @@ type HTTPHeader struct {
type responseMetaHeader struct {
Source string `json:"src"`
FlowRateLimited bool `json:"flow_rate_limited,omitempty"`
}
func mustInitRespMetaHeader(src string) string {
header, err := json.Marshal(responseMetaHeader{Source: src})
func mustInitRespMetaHeader(src string, flowRateLimited bool) string {
header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited})
if err != nil {
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
}
@ -112,7 +114,7 @@ func SerializeHeaders(h1Headers http.Header) string {
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
const unableToDeserializeErr = "Unable to deserialize headers"
var deserialized []HTTPHeader
deserialized := make([]HTTPHeader, 0)
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
if len(serializedPair) == 0 {
continue

View File

@ -16,6 +16,8 @@ import (
"github.com/rs/zerolog"
"golang.org/x/net/http2"
cfdsession "github.com/cloudflare/cloudflared/session"
"github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@ -156,7 +158,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
// WriteErrorResponse will return false if status was already written. we need to abort handler.
if !respWriter.WriteErrorResponse() {
if !respWriter.WriteErrorResponse(requestErr) {
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
panic(http.ErrAbortHandler)
}
@ -209,8 +211,9 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
w: w,
log: log,
}
respWriter.WriteErrorResponse()
return nil, fmt.Errorf("%T doesn't implement http.Flusher", w)
err := fmt.Errorf("%T doesn't implement http.Flusher", w)
respWriter.WriteErrorResponse(err)
return nil, err
}
return &http2RespWriter{
@ -295,7 +298,7 @@ func (rp *http2RespWriter) WriteHeader(status int) {
rp.log.Warn().Msg("WriteHeader after hijack")
return
}
rp.WriteRespHeaders(status, rp.respHeaders)
_ = rp.WriteRespHeaders(status, rp.respHeaders)
}
func (rp *http2RespWriter) hijacked() bool {
@ -328,12 +331,16 @@ func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return conn, readWriter, nil
}
func (rp *http2RespWriter) WriteErrorResponse() bool {
func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
if rp.statusWritten {
return false
}
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
} else {
rp.setResponseMetaHeader(responseMetaHeaderCfd)
}
rp.w.WriteHeader(http.StatusBadGateway)
rp.statusWritten = true

View File

@ -20,6 +20,8 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@ -65,19 +67,18 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)
endpoint := fmt.Sprintf("http://localhost:8080/ok")
reqBody := []byte(`{
"version": 2,
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
`)
reader := bytes.NewReader(reqBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader)
require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
@ -85,11 +86,11 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
bdy, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
require.NoError(t, err)
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
cancel()
wg.Wait()
}
func TestServeHTTP(t *testing.T) {
@ -134,7 +135,7 @@ func TestServeHTTP(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
@ -153,6 +154,7 @@ func TestServeHTTP(t *testing.T) {
require.NoError(t, err)
require.Equal(t, test.expectedBody, respBody)
}
_ = resp.Body.Close()
if test.isProxyError {
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
} else {
@ -281,10 +283,11 @@ func TestServeWS(t *testing.T) {
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody))
cancel()
resp := respWriter.Result()
defer resp.Body.Close()
// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
@ -304,7 +307,7 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
cfdHTTP2Conn.Serve(ctx)
_ = cfdHTTP2Conn.Serve(ctx)
}()
edgeTransport := http2.Transport{}
@ -319,13 +322,16 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
readPipe, writePipe := io.Pipe()
reqCtx, reqCancel := context.WithCancel(ctx)
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
require.NoError(t, err)
assert.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
assert.NoError(t, err)
_ = resp.Body.Close()
// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, http.StatusOK, resp.StatusCode)
wg.Add(1)
go func() {
@ -378,7 +384,7 @@ func TestServeControlStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -391,7 +397,8 @@ func TestServeControlStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
edgeHTTP2Conn.RoundTrip(req)
// nolint: bodyclose
_, _ = edgeHTTP2Conn.RoundTrip(req)
}()
<-rpcClientFactory.registered
@ -431,7 +438,7 @@ func TestFailRegistration(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -442,9 +449,10 @@ func TestFailRegistration(t *testing.T) {
require.NoError(t, err)
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
assert.NotNil(t, http2Conn.controlStreamErr)
require.Error(t, http2Conn.controlStreamErr)
cancel()
wg.Wait()
}
@ -481,7 +489,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -494,6 +502,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
// nolint: bodyclose
_, _ = edgeHTTP2Conn.RoundTrip(req)
}()
@ -524,6 +533,36 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
})
}
func TestServeTCP_RateLimited(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
http2Conn, edgeConn := newTestHTTP2Connection()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = http2Conn.Serve(ctx)
}()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err)
req.Header.Set(InternalTCPProxySrcHeader, "tcp")
req.Header.Set(tracing.TracerContextName, "flow-rate-limited")
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))
cancel()
wg.Wait()
}
func benchmarkServeHTTP(b *testing.B, test testRequest) {
http2Conn, edgeConn := newTestHTTP2Connection()
@ -532,7 +571,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)

View File

@ -17,6 +17,8 @@ import (
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
cfdsession "github.com/cloudflare/cloudflared/session"
cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -108,7 +110,6 @@ func (q *quicConnection) Serve(ctx context.Context) error {
}
cancel()
return err
})
errGroup.Go(func() error {
defer cancel()
@ -129,7 +130,7 @@ func (q *quicConnection) serveControlStream(ctx context.Context, controlStream q
// Close the connection with no errors specified.
func (q *quicConnection) Close() {
q.conn.CloseWithError(0, "")
_ = q.conn.CloseWithError(0, "")
}
func (q *quicConnection) acceptStream(ctx context.Context) error {
@ -182,7 +183,13 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R
return err
}
if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
var metadata []pogs.Metadata
// Check the type of error that was throw and add metadata that will help identify it on OTD.
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey)
}
if writeRespErr := stream.WriteConnectResponseData(err, metadata...); writeRespErr != nil {
return writeRespErr
}
}
@ -278,7 +285,7 @@ func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header)
func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
// Make sure to send WriteHeader response if not called yet
if !hrw.connectResponseSent {
hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
_ = hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
}
return hrw.RequestServerStream.Write(p)
}
@ -291,7 +298,7 @@ func (hrw *httpResponseAdapter) Header() http.Header {
func (hrw *httpResponseAdapter) Flush() {}
func (hrw *httpResponseAdapter) WriteHeader(status int) {
hrw.WriteRespHeaders(status, hrw.headers)
_ = hrw.WriteRespHeaders(status, hrw.headers)
}
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
@ -304,7 +311,7 @@ func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
_ = hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
}
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {

View File

@ -8,6 +8,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
@ -21,7 +22,7 @@ import (
"github.com/gobwas/ws/wsutil"
"github.com/google/uuid"
"github.com/pkg/errors"
pkgerrors "github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
@ -506,6 +507,10 @@ func TestBuildHTTPRequest(t *testing.T) {
}
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
if tcpRequest.Dest == "rate-limit-me" {
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "failed tcp stream")
}
_ = rwa.AckConnection("")
_, _ = io.Copy(rwa, rwa)
return nil
@ -597,6 +602,59 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) {
}
}
// TestTCPProxy_FlowRateLimited tests if the pogs.ConnectResponse returns the expected error and metadata, when a
// new flow is rate limited.
func TestTCPProxy_FlowRateLimited(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
// Start a UDP Listener for QUIC.
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
require.NoError(t, err)
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
require.NoError(t, err)
defer udpListener.Close()
quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16}
quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig)
require.NoError(t, err)
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
session, err := quicListener.Accept(ctx)
assert.NoError(t, err)
quicStream, err := session.OpenStreamSync(context.Background())
assert.NoError(t, err)
stream := cfdquic.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
reqClientStream := rpcquic.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData("rate-limit-me", pogs.ConnectionTypeTCP)
assert.NoError(t, err)
response, err := reqClientStream.ReadConnectResponseData()
assert.NoError(t, err)
// Got Rate Limited
assert.NotEmpty(t, response.Error)
assert.Contains(t, response.Metadata, pogs.ErrorFlowConnectRateLimitedKey)
}()
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(0))
connDone := make(chan struct{})
go func() {
defer close(connDone)
_ = tunnelConn.Serve(ctx)
}()
<-serverDone
cancel()
<-connDone
}
func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) {
logger := zerolog.Nop()
conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger)

View File

@ -141,7 +141,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
return errors.Wrap(err, "failed to start origin")
}
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log)
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.sessionLimiter, o.config.WriteTimeout, o.log)
o.proxy.Store(proxy)
o.config.Ingress = &ingressRules
o.config.WarpRouting = warpRouting

View File

@ -9,10 +9,14 @@ import (
"time"
"github.com/pkg/errors"
pkgerrors "github.com/pkg/errors"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"github.com/cloudflare/cloudflared/management"
cfdsession "github.com/cloudflare/cloudflared/session"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/connection"
@ -32,8 +36,8 @@ const (
type Proxy struct {
ingressRules ingress.Ingress
warpRouting *ingress.WarpRoutingService
management *ingress.ManagementService
tags []pogs.Tag
sessionLimiter cfdsession.Limiter
log *zerolog.Logger
}
@ -42,12 +46,14 @@ func NewOriginProxy(
ingressRules ingress.Ingress,
warpRouting ingress.WarpRoutingConfig,
tags []pogs.Tag,
sessionLimiter cfdsession.Limiter,
writeTimeout time.Duration,
log *zerolog.Logger,
) *Proxy {
proxy := &Proxy{
ingressRules: ingressRules,
tags: tags,
sessionLimiter: sessionLimiter,
log: log,
}
@ -64,7 +70,7 @@ func (p *Proxy) applyIngressMiddleware(rule *ingress.Rule, r *http.Request, w co
}
if result.ShouldFilterRequest {
w.WriteRespHeaders(result.StatusCode, nil)
_ = w.WriteRespHeaders(result.StatusCode, nil)
return fmt.Errorf("request filtered by middleware handler (%s) due to: %s", handler.Name(), result.Reason), true
}
}
@ -152,10 +158,18 @@ func (p *Proxy) ProxyTCP(
return err
}
logger := newTCPLogger(p.log, req)
// Try to start a new session
if err := p.sessionLimiter.Acquire(management.TCP.String()); err != nil {
logger.Warn().Msg("Too many concurrent sessions being handled, rejecting tcp proxy")
return pkgerrors.Wrap(err, "failed to start tcp session due to rate limiting")
}
defer p.sessionLimiter.Release()
serveCtx, cancel := context.WithCancel(ctx)
defer cancel()
logger := newTCPLogger(p.log, req)
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger)
logger.Debug().Msg("tcp proxy stream started")

View File

@ -21,8 +21,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"go.uber.org/mock/gomock"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/mocks"
cfdsession "github.com/cloudflare/cloudflared/session"
"github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
@ -71,11 +76,6 @@ func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
}
// respHeaders is a test function to read respHeaders
func (w *mockHTTPRespWriter) headers() http.Header {
return w.Header()
}
func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
@ -113,7 +113,7 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
return w.reader.Read(data)
}
func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (w *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy))
@ -246,7 +246,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
_ = responseWriter.Close()
close(finished)
errGroup.Wait()
_ = errGroup.Wait()
}
}
@ -267,7 +267,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
defer wg.Done()
log := zerolog.Nop()
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
require.Equal(t, err.Error(), "context canceled")
require.Equal(t, "context canceled", err.Error())
require.Equal(t, http.StatusOK, responseWriter.Code)
}()
@ -275,7 +275,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
for i := 0; i < pushCount; i++ {
line := responseWriter.ReadBytes()
expect := fmt.Sprintf("%d\n\n", i)
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line)
}
cancel()
@ -290,7 +290,9 @@ func TestProxySSEAllData(t *testing.T) {
responseWriter := newMockSSERespWriter()
// responseWriter uses an unbuffered channel, so we call in a different go-routine
go cfio.Copy(responseWriter, eyeballReader)
go func() {
_, _ = cfio.Copy(responseWriter, eyeballReader)
}()
result := string(<-responseWriter.writeNotification)
require.Equal(t, "data\r\r", result)
@ -366,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
ctx, cancel := context.WithCancel(context.Background())
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
for _, test := range tests {
responseWriter := newMockHTTPRespWriter()
@ -414,25 +416,20 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop()
proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err)
require.NoError(t, err)
assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
require.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
}
type replayer struct {
sync.RWMutex
writeDone chan struct{}
rw *bytes.Buffer
}
func newReplayer(buffer *bytes.Buffer) {
}
func (r *replayer) Read(p []byte) (int, error) {
r.RLock()
defer r.RUnlock()
@ -471,7 +468,7 @@ func (r *replayer) Bytes() []byte {
// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
func TestConnections(t *testing.T) {
logger := logger.Create(nil)
replayer := &replayer{rw: &bytes.Buffer{}}
replayer := &replayer{rw: bytes.NewBuffer([]byte{})}
type args struct {
ingressServiceScheme string
originService func(*testing.T, net.Listener)
@ -486,6 +483,9 @@ func TestConnections(t *testing.T) {
// requestheaders to be sent in the call to proxy.Proxy
requestHeaders http.Header
// sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call
sessionLimiterResponse error
}
type want struct {
@ -663,6 +663,25 @@ func TestConnections(t *testing.T) {
err: true,
},
},
{
name: "tcp-* proxy rate limited flow",
args: args{
ingressServiceScheme: "tcp://",
originService: runEchoTCPService,
eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions,
},
want: want{
message: []byte{},
err: true,
},
},
}
for _, test := range tests {
@ -674,8 +693,16 @@ func TestConnections(t *testing.T) {
test.args.originService(t, ln)
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
ingressRule.StartOrigins(logger, ctx.Done())
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
_ = ingressRule.StartOrigins(logger, ctx.Done())
// Mock session limiter
ctrl := gomock.NewController(t)
defer ctrl.Finish()
sessionLimiter := mocks.NewMockLimiter(ctrl)
sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse)
sessionLimiter.EXPECT().Release().AnyTimes()
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger)
proxy.warpRouting = test.args.warpRoutingService
dest := ln.Addr().String()
@ -693,7 +720,7 @@ func TestConnections(t *testing.T) {
respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
go func() {
resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
replayer.Write(resp)
_, _ = replayer.Write(resp)
}()
}
if test.args.connectionType == connection.TypeTCP {
@ -705,9 +732,9 @@ func TestConnections(t *testing.T) {
}
cancel()
assert.Equal(t, test.want.err, err != nil)
assert.Equal(t, test.want.message, replayer.Bytes())
assert.Equal(t, test.want.headers, respWriter.Header())
require.Equal(t, test.want.err, err != nil)
require.Equal(t, test.want.message, replayer.Bytes())
require.Equal(t, test.want.headers, respWriter.Header())
replayer.rw.Reset()
})
}
@ -720,7 +747,9 @@ type requestBody struct {
func newWSRequestBody(data []byte) *requestBody {
pr, pw := io.Pipe()
go wsutil.WriteClientBinary(pw, data)
go func() {
_ = wsutil.WriteClientBinary(pw, data)
}()
return &requestBody{
pr: pr,
pw: pw,
@ -728,7 +757,9 @@ func newWSRequestBody(data []byte) *requestBody {
}
func newTCPRequestBody(data []byte) *requestBody {
pr, pw := io.Pipe()
go pw.Write(data)
go func() {
_, _ = pw.Write(data)
}()
return &requestBody{
pr: pr,
pw: pw,
@ -740,8 +771,8 @@ func (r *requestBody) Read(p []byte) (n int, err error) {
}
func (r *requestBody) Close() error {
r.pw.Close()
r.pr.Close()
_ = r.pw.Close()
_ = r.pr.Close()
return nil
}
@ -774,6 +805,7 @@ func (p *pipedRequestBody) roundtrip(addr string) []byte {
panic(err)
}
defer conn.Close()
defer resp.Body.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
@ -917,7 +949,9 @@ func runEchoTCPService(t *testing.T, l net.Listener) {
go func() {
for {
conn, err := l.Accept()
require.NoError(t, err)
if err != nil {
panic(err)
}
defer conn.Close()
for {
@ -971,12 +1005,15 @@ func runEchoWSService(t *testing.T, l net.Listener) {
}
}
// nolint: gosec
server := http.Server{
Handler: http.HandlerFunc(ws),
}
go func() {
err := server.Serve(l)
require.NoError(t, err)
if err != nil {
panic(err)
}
}()
}

View File

@ -18,6 +18,11 @@ const (
ConnectionTypeTCP
)
var (
// ErrorFlowConnectRateLimitedKey is the Metadata entry that allows to know if a request was rate limited on connect.
ErrorFlowConnectRateLimitedKey = Metadata{Key: "FlowConnectRateLimited", Val: "true"}
)
func (c ConnectionType) String() string {
switch c {
case ConnectionTypeHTTP:

View File

@ -38,6 +38,7 @@ func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata
if respErr != nil {
connectResponse = &pogs.ConnectResponse{
Error: respErr.Error(),
Metadata: metadata,
}
} else {
connectResponse = &pogs.ConnectResponse{

View File

@ -98,12 +98,7 @@ func TestConnectResponseMeta(t *testing.T) {
reqClientStream := RequestClientStream{noopCloser{b}}
respMeta, err := reqClientStream.ReadConnectResponseData()
require.NoError(t, err)
if respMeta.Error == "" {
assert.Equal(t, test.metadata, respMeta.Metadata)
} else {
assert.Equal(t, 0, len(respMeta.Metadata))
}
require.Equal(t, test.metadata, respMeta.Metadata)
})
}
}
@ -153,21 +148,21 @@ func TestRegisterUdpSession(t *testing.T) {
}()
rpcClientStream, err := NewCloudflaredClient(context.Background(), clientStream, 5*time.Second)
assert.NoError(t, err)
require.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.NoError(t, reg.Err)
require.NoError(t, err)
require.NoError(t, reg.Err)
// Different sessionID, the RPC server should reject the registraion
// Different sessionID, the RPC server should reject the registration
reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.Error(t, reg.Err)
require.NoError(t, err)
require.Error(t, reg.Err)
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
require.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
// Different sessionID, the RPC server should reject the unregistraion
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
// Different sessionID, the RPC server should reject the unregistration
require.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
@ -200,10 +195,10 @@ func TestManageConfiguration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
rpcClientStream, err := NewCloudflaredClient(ctx, clientStream, 5*time.Second)
assert.NoError(t, err)
require.NoError(t, err)
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
assert.NoError(t, err)
require.NoError(t, err)
require.Equal(t, version, result.LastAppliedVersion)
require.NoError(t, result.Err)