TUN-8861: Add session limiter to TCP session manager
## Summary In order to make cloudflared behavior more predictable and prevent an exhaustion of resources, we have decided to add session limits that can be configured by the user. This commit adds the session limiter to the HTTP/TCP handling path. For now the limiter is set to run only in unlimited mode.
This commit is contained in:
parent
bf4954e96a
commit
8bfe111cab
|
@ -2,14 +2,18 @@ package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
pkgerrors "github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
cfdsession "github.com/cloudflare/cloudflared/session"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/stream"
|
"github.com/cloudflare/cloudflared/stream"
|
||||||
"github.com/cloudflare/cloudflared/tracing"
|
"github.com/cloudflare/cloudflared/tracing"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
@ -77,7 +81,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
|
||||||
return wsFlakyEndpoint(w, req)
|
return wsFlakyEndpoint(w, req)
|
||||||
default:
|
default:
|
||||||
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
|
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
|
||||||
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
|
return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch req.URL.Path {
|
switch req.URL.Path {
|
||||||
|
@ -95,7 +99,6 @@ func (moc *mockOriginProxy) ProxyHTTP(
|
||||||
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
|
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (moc *mockOriginProxy) ProxyTCP(
|
func (moc *mockOriginProxy) ProxyTCP(
|
||||||
|
@ -103,6 +106,10 @@ func (moc *mockOriginProxy) ProxyTCP(
|
||||||
rwa ReadWriteAcker,
|
rwa ReadWriteAcker,
|
||||||
r *TCPRequest,
|
r *TCPRequest,
|
||||||
) error {
|
) error {
|
||||||
|
if r.CfTraceID == "flow-rate-limited" {
|
||||||
|
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,7 +185,8 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
|
||||||
|
|
||||||
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
|
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
|
||||||
|
|
||||||
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
|
rInt, _ := rand.Int(rand.Reader, big.NewInt(50))
|
||||||
|
closedAfter := time.Millisecond * time.Duration(rInt.Int64())
|
||||||
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
|
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
|
||||||
stream.Pipe(wsConn, originConn, &log)
|
stream.Pipe(wsConn, originConn, &log)
|
||||||
cancel()
|
cancel()
|
||||||
|
|
|
@ -22,8 +22,9 @@ var (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// pre-generate possible values for res
|
// pre-generate possible values for res
|
||||||
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
|
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false)
|
||||||
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true)
|
||||||
|
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false)
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPHeader is a custom header struct that expects only ever one value for the header.
|
// HTTPHeader is a custom header struct that expects only ever one value for the header.
|
||||||
|
@ -35,10 +36,11 @@ type HTTPHeader struct {
|
||||||
|
|
||||||
type responseMetaHeader struct {
|
type responseMetaHeader struct {
|
||||||
Source string `json:"src"`
|
Source string `json:"src"`
|
||||||
|
FlowRateLimited bool `json:"flow_rate_limited,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustInitRespMetaHeader(src string) string {
|
func mustInitRespMetaHeader(src string, flowRateLimited bool) string {
|
||||||
header, err := json.Marshal(responseMetaHeader{Source: src})
|
header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
|
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
|
||||||
}
|
}
|
||||||
|
@ -112,7 +114,7 @@ func SerializeHeaders(h1Headers http.Header) string {
|
||||||
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
|
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
|
||||||
const unableToDeserializeErr = "Unable to deserialize headers"
|
const unableToDeserializeErr = "Unable to deserialize headers"
|
||||||
|
|
||||||
var deserialized []HTTPHeader
|
deserialized := make([]HTTPHeader, 0)
|
||||||
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
|
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
|
||||||
if len(serializedPair) == 0 {
|
if len(serializedPair) == 0 {
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -16,6 +16,8 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
|
||||||
|
cfdsession "github.com/cloudflare/cloudflared/session"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/tracing"
|
"github.com/cloudflare/cloudflared/tracing"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
@ -156,7 +158,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
|
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
|
||||||
|
|
||||||
// WriteErrorResponse will return false if status was already written. we need to abort handler.
|
// WriteErrorResponse will return false if status was already written. we need to abort handler.
|
||||||
if !respWriter.WriteErrorResponse() {
|
if !respWriter.WriteErrorResponse(requestErr) {
|
||||||
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
|
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
|
||||||
panic(http.ErrAbortHandler)
|
panic(http.ErrAbortHandler)
|
||||||
}
|
}
|
||||||
|
@ -209,8 +211,9 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
|
||||||
w: w,
|
w: w,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
respWriter.WriteErrorResponse()
|
err := fmt.Errorf("%T doesn't implement http.Flusher", w)
|
||||||
return nil, fmt.Errorf("%T doesn't implement http.Flusher", w)
|
respWriter.WriteErrorResponse(err)
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &http2RespWriter{
|
return &http2RespWriter{
|
||||||
|
@ -295,7 +298,7 @@ func (rp *http2RespWriter) WriteHeader(status int) {
|
||||||
rp.log.Warn().Msg("WriteHeader after hijack")
|
rp.log.Warn().Msg("WriteHeader after hijack")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rp.WriteRespHeaders(status, rp.respHeaders)
|
_ = rp.WriteRespHeaders(status, rp.respHeaders)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) hijacked() bool {
|
func (rp *http2RespWriter) hijacked() bool {
|
||||||
|
@ -328,12 +331,16 @@ func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
return conn, readWriter, nil
|
return conn, readWriter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) WriteErrorResponse() bool {
|
func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
|
||||||
if rp.statusWritten {
|
if rp.statusWritten {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
|
||||||
|
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
|
||||||
|
} else {
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
||||||
|
}
|
||||||
rp.w.WriteHeader(http.StatusBadGateway)
|
rp.w.WriteHeader(http.StatusBadGateway)
|
||||||
rp.statusWritten = true
|
rp.statusWritten = true
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/tracing"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
@ -65,19 +67,18 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
http2Conn.Serve(ctx)
|
_ = http2Conn.Serve(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
endpoint := fmt.Sprintf("http://localhost:8080/ok")
|
|
||||||
reqBody := []byte(`{
|
reqBody := []byte(`{
|
||||||
"version": 2,
|
"version": 2,
|
||||||
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
|
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
|
||||||
`)
|
`)
|
||||||
reader := bytes.NewReader(reqBody)
|
reader := bytes.NewReader(reqBody)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
|
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
|
||||||
|
|
||||||
|
@ -85,11 +86,11 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
bdy, err := io.ReadAll(resp.Body)
|
bdy, err := io.ReadAll(resp.Body)
|
||||||
|
defer resp.Body.Close()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
|
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
|
||||||
cancel()
|
cancel()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServeHTTP(t *testing.T) {
|
func TestServeHTTP(t *testing.T) {
|
||||||
|
@ -134,7 +135,7 @@ func TestServeHTTP(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
http2Conn.Serve(ctx)
|
_ = http2Conn.Serve(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
|
@ -153,6 +154,7 @@ func TestServeHTTP(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, test.expectedBody, respBody)
|
require.Equal(t, test.expectedBody, respBody)
|
||||||
}
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
if test.isProxyError {
|
if test.isProxyError {
|
||||||
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
|
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
|
||||||
} else {
|
} else {
|
||||||
|
@ -281,10 +283,11 @@ func TestServeWS(t *testing.T) {
|
||||||
|
|
||||||
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
|
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody))
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
resp := respWriter.Result()
|
resp := respWriter.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
// http2RespWriter should rewrite status 101 to 200
|
// http2RespWriter should rewrite status 101 to 200
|
||||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||||
|
@ -304,7 +307,7 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
|
||||||
serverDone := make(chan struct{})
|
serverDone := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer close(serverDone)
|
defer close(serverDone)
|
||||||
cfdHTTP2Conn.Serve(ctx)
|
_ = cfdHTTP2Conn.Serve(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
edgeTransport := http2.Transport{}
|
edgeTransport := http2.Transport{}
|
||||||
|
@ -319,13 +322,16 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
|
||||||
readPipe, writePipe := io.Pipe()
|
readPipe, writePipe := io.Pipe()
|
||||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||||
|
|
||||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
// http2RespWriter should rewrite status 101 to 200
|
// http2RespWriter should rewrite status 101 to 200
|
||||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -378,7 +384,7 @@ func TestServeControlStream(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
http2Conn.Serve(ctx)
|
_ = http2Conn.Serve(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||||
|
@ -391,7 +397,8 @@ func TestServeControlStream(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
edgeHTTP2Conn.RoundTrip(req)
|
// nolint: bodyclose
|
||||||
|
_, _ = edgeHTTP2Conn.RoundTrip(req)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
<-rpcClientFactory.registered
|
<-rpcClientFactory.registered
|
||||||
|
@ -431,7 +438,7 @@ func TestFailRegistration(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
http2Conn.Serve(ctx)
|
_ = http2Conn.Serve(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||||
|
@ -442,9 +449,10 @@ func TestFailRegistration(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||||
|
|
||||||
assert.NotNil(t, http2Conn.controlStreamErr)
|
require.Error(t, http2Conn.controlStreamErr)
|
||||||
cancel()
|
cancel()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
@ -481,7 +489,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
http2Conn.Serve(ctx)
|
_ = http2Conn.Serve(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||||
|
@ -494,6 +502,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
// nolint: bodyclose
|
||||||
_, _ = edgeHTTP2Conn.RoundTrip(req)
|
_, _ = edgeHTTP2Conn.RoundTrip(req)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -524,6 +533,36 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServeTCP_RateLimited(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = http2Conn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set(InternalTCPProxySrcHeader, "tcp")
|
||||||
|
req.Header.Set(tracing.TracerContextName, "flow-rate-limited")
|
||||||
|
|
||||||
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||||
|
require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
||||||
http2Conn, edgeConn := newTestHTTP2Connection()
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
@ -532,7 +571,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
http2Conn.Serve(ctx)
|
_ = http2Conn.Serve(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
||||||
|
|
|
@ -17,6 +17,8 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
cfdsession "github.com/cloudflare/cloudflared/session"
|
||||||
|
|
||||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||||
"github.com/cloudflare/cloudflared/tracing"
|
"github.com/cloudflare/cloudflared/tracing"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
@ -108,7 +110,6 @@ func (q *quicConnection) Serve(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
return err
|
return err
|
||||||
|
|
||||||
})
|
})
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -129,7 +130,7 @@ func (q *quicConnection) serveControlStream(ctx context.Context, controlStream q
|
||||||
|
|
||||||
// Close the connection with no errors specified.
|
// Close the connection with no errors specified.
|
||||||
func (q *quicConnection) Close() {
|
func (q *quicConnection) Close() {
|
||||||
q.conn.CloseWithError(0, "")
|
_ = q.conn.CloseWithError(0, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *quicConnection) acceptStream(ctx context.Context) error {
|
func (q *quicConnection) acceptStream(ctx context.Context) error {
|
||||||
|
@ -182,7 +183,13 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
|
var metadata []pogs.Metadata
|
||||||
|
// Check the type of error that was throw and add metadata that will help identify it on OTD.
|
||||||
|
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
|
||||||
|
metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
if writeRespErr := stream.WriteConnectResponseData(err, metadata...); writeRespErr != nil {
|
||||||
return writeRespErr
|
return writeRespErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -278,7 +285,7 @@ func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header)
|
||||||
func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
|
func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
|
||||||
// Make sure to send WriteHeader response if not called yet
|
// Make sure to send WriteHeader response if not called yet
|
||||||
if !hrw.connectResponseSent {
|
if !hrw.connectResponseSent {
|
||||||
hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
|
_ = hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
|
||||||
}
|
}
|
||||||
return hrw.RequestServerStream.Write(p)
|
return hrw.RequestServerStream.Write(p)
|
||||||
}
|
}
|
||||||
|
@ -291,7 +298,7 @@ func (hrw *httpResponseAdapter) Header() http.Header {
|
||||||
func (hrw *httpResponseAdapter) Flush() {}
|
func (hrw *httpResponseAdapter) Flush() {}
|
||||||
|
|
||||||
func (hrw *httpResponseAdapter) WriteHeader(status int) {
|
func (hrw *httpResponseAdapter) WriteHeader(status int) {
|
||||||
hrw.WriteRespHeaders(status, hrw.headers)
|
_ = hrw.WriteRespHeaders(status, hrw.headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
@ -304,7 +311,7 @@ func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
|
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
|
||||||
hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
_ = hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
|
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
@ -21,7 +22,7 @@ import (
|
||||||
|
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pkg/errors"
|
pkgerrors "github.com/pkg/errors"
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -506,6 +507,10 @@ func TestBuildHTTPRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
|
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
|
||||||
|
if tcpRequest.Dest == "rate-limit-me" {
|
||||||
|
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "failed tcp stream")
|
||||||
|
}
|
||||||
|
|
||||||
_ = rwa.AckConnection("")
|
_ = rwa.AckConnection("")
|
||||||
_, _ = io.Copy(rwa, rwa)
|
_, _ = io.Copy(rwa, rwa)
|
||||||
return nil
|
return nil
|
||||||
|
@ -597,6 +602,59 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTCPProxy_FlowRateLimited tests if the pogs.ConnectResponse returns the expected error and metadata, when a
|
||||||
|
// new flow is rate limited.
|
||||||
|
func TestTCPProxy_FlowRateLimited(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Start a UDP Listener for QUIC.
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer udpListener.Close()
|
||||||
|
|
||||||
|
quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16}
|
||||||
|
quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(serverDone)
|
||||||
|
|
||||||
|
session, err := quicListener.Accept(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
quicStream, err := session.OpenStreamSync(context.Background())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
stream := cfdquic.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
|
||||||
|
|
||||||
|
reqClientStream := rpcquic.RequestClientStream{ReadWriteCloser: stream}
|
||||||
|
err = reqClientStream.WriteConnectRequestData("rate-limit-me", pogs.ConnectionTypeTCP)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
response, err := reqClientStream.ReadConnectResponseData()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Got Rate Limited
|
||||||
|
assert.NotEmpty(t, response.Error)
|
||||||
|
assert.Contains(t, response.Metadata, pogs.ErrorFlowConnectRateLimitedKey)
|
||||||
|
}()
|
||||||
|
|
||||||
|
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(0))
|
||||||
|
|
||||||
|
connDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(connDone)
|
||||||
|
_ = tunnelConn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-serverDone
|
||||||
|
cancel()
|
||||||
|
<-connDone
|
||||||
|
}
|
||||||
|
|
||||||
func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) {
|
func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) {
|
||||||
logger := zerolog.Nop()
|
logger := zerolog.Nop()
|
||||||
conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger)
|
conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger)
|
||||||
|
|
|
@ -141,7 +141,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
|
||||||
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
|
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
|
||||||
return errors.Wrap(err, "failed to start origin")
|
return errors.Wrap(err, "failed to start origin")
|
||||||
}
|
}
|
||||||
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log)
|
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.sessionLimiter, o.config.WriteTimeout, o.log)
|
||||||
o.proxy.Store(proxy)
|
o.proxy.Store(proxy)
|
||||||
o.config.Ingress = &ingressRules
|
o.config.Ingress = &ingressRules
|
||||||
o.config.WarpRouting = warpRouting
|
o.config.WarpRouting = warpRouting
|
||||||
|
|
|
@ -9,10 +9,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
pkgerrors "github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/management"
|
||||||
|
cfdsession "github.com/cloudflare/cloudflared/session"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/carrier"
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
"github.com/cloudflare/cloudflared/cfio"
|
"github.com/cloudflare/cloudflared/cfio"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
@ -32,8 +36,8 @@ const (
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
ingressRules ingress.Ingress
|
ingressRules ingress.Ingress
|
||||||
warpRouting *ingress.WarpRoutingService
|
warpRouting *ingress.WarpRoutingService
|
||||||
management *ingress.ManagementService
|
|
||||||
tags []pogs.Tag
|
tags []pogs.Tag
|
||||||
|
sessionLimiter cfdsession.Limiter
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,12 +46,14 @@ func NewOriginProxy(
|
||||||
ingressRules ingress.Ingress,
|
ingressRules ingress.Ingress,
|
||||||
warpRouting ingress.WarpRoutingConfig,
|
warpRouting ingress.WarpRoutingConfig,
|
||||||
tags []pogs.Tag,
|
tags []pogs.Tag,
|
||||||
|
sessionLimiter cfdsession.Limiter,
|
||||||
writeTimeout time.Duration,
|
writeTimeout time.Duration,
|
||||||
log *zerolog.Logger,
|
log *zerolog.Logger,
|
||||||
) *Proxy {
|
) *Proxy {
|
||||||
proxy := &Proxy{
|
proxy := &Proxy{
|
||||||
ingressRules: ingressRules,
|
ingressRules: ingressRules,
|
||||||
tags: tags,
|
tags: tags,
|
||||||
|
sessionLimiter: sessionLimiter,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +70,7 @@ func (p *Proxy) applyIngressMiddleware(rule *ingress.Rule, r *http.Request, w co
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.ShouldFilterRequest {
|
if result.ShouldFilterRequest {
|
||||||
w.WriteRespHeaders(result.StatusCode, nil)
|
_ = w.WriteRespHeaders(result.StatusCode, nil)
|
||||||
return fmt.Errorf("request filtered by middleware handler (%s) due to: %s", handler.Name(), result.Reason), true
|
return fmt.Errorf("request filtered by middleware handler (%s) due to: %s", handler.Name(), result.Reason), true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -152,10 +158,18 @@ func (p *Proxy) ProxyTCP(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger := newTCPLogger(p.log, req)
|
||||||
|
|
||||||
|
// Try to start a new session
|
||||||
|
if err := p.sessionLimiter.Acquire(management.TCP.String()); err != nil {
|
||||||
|
logger.Warn().Msg("Too many concurrent sessions being handled, rejecting tcp proxy")
|
||||||
|
return pkgerrors.Wrap(err, "failed to start tcp session due to rate limiting")
|
||||||
|
}
|
||||||
|
defer p.sessionLimiter.Release()
|
||||||
|
|
||||||
serveCtx, cancel := context.WithCancel(ctx)
|
serveCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
logger := newTCPLogger(p.log, req)
|
|
||||||
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger)
|
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger)
|
||||||
logger.Debug().Msg("tcp proxy stream started")
|
logger.Debug().Msg("tcp proxy stream started")
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,13 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
"go.uber.org/mock/gomock"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/mocks"
|
||||||
|
|
||||||
|
cfdsession "github.com/cloudflare/cloudflared/session"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cfio"
|
"github.com/cloudflare/cloudflared/cfio"
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
@ -71,11 +76,6 @@ func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
||||||
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
||||||
}
|
}
|
||||||
|
|
||||||
// respHeaders is a test function to read respHeaders
|
|
||||||
func (w *mockHTTPRespWriter) headers() http.Header {
|
|
||||||
return w.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
panic("Hijack not implemented")
|
panic("Hijack not implemented")
|
||||||
}
|
}
|
||||||
|
@ -113,7 +113,7 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
||||||
return w.reader.Read(data)
|
return w.reader.Read(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
func (w *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
panic("Hijack not implemented")
|
panic("Hijack not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
|
||||||
|
|
||||||
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
|
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
|
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
||||||
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
||||||
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
||||||
t.Run("testProxySSE", testProxySSE(proxy))
|
t.Run("testProxySSE", testProxySSE(proxy))
|
||||||
|
@ -246,7 +246,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
_ = responseWriter.Close()
|
_ = responseWriter.Close()
|
||||||
|
|
||||||
close(finished)
|
close(finished)
|
||||||
errGroup.Wait()
|
_ = errGroup.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,7 +267,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
|
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
|
||||||
require.Equal(t, err.Error(), "context canceled")
|
require.Equal(t, "context canceled", err.Error())
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, responseWriter.Code)
|
require.Equal(t, http.StatusOK, responseWriter.Code)
|
||||||
}()
|
}()
|
||||||
|
@ -275,7 +275,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
for i := 0; i < pushCount; i++ {
|
for i := 0; i < pushCount; i++ {
|
||||||
line := responseWriter.ReadBytes()
|
line := responseWriter.ReadBytes()
|
||||||
expect := fmt.Sprintf("%d\n\n", i)
|
expect := fmt.Sprintf("%d\n\n", i)
|
||||||
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
|
require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line)
|
||||||
}
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
|
@ -290,7 +290,9 @@ func TestProxySSEAllData(t *testing.T) {
|
||||||
responseWriter := newMockSSERespWriter()
|
responseWriter := newMockSSERespWriter()
|
||||||
|
|
||||||
// responseWriter uses an unbuffered channel, so we call in a different go-routine
|
// responseWriter uses an unbuffered channel, so we call in a different go-routine
|
||||||
go cfio.Copy(responseWriter, eyeballReader)
|
go func() {
|
||||||
|
_, _ = cfio.Copy(responseWriter, eyeballReader)
|
||||||
|
}()
|
||||||
|
|
||||||
result := string(<-responseWriter.writeNotification)
|
result := string(<-responseWriter.writeNotification)
|
||||||
require.Equal(t, "data\r\r", result)
|
require.Equal(t, "data\r\r", result)
|
||||||
|
@ -366,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
|
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
|
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
responseWriter := newMockHTTPRespWriter()
|
responseWriter := newMockHTTPRespWriter()
|
||||||
|
@ -414,25 +416,20 @@ func TestProxyError(t *testing.T) {
|
||||||
|
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
|
|
||||||
proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
|
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
||||||
|
|
||||||
responseWriter := newMockHTTPRespWriter()
|
responseWriter := newMockHTTPRespWriter()
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
|
require.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
|
||||||
}
|
}
|
||||||
|
|
||||||
type replayer struct {
|
type replayer struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
writeDone chan struct{}
|
|
||||||
rw *bytes.Buffer
|
rw *bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
func newReplayer(buffer *bytes.Buffer) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *replayer) Read(p []byte) (int, error) {
|
func (r *replayer) Read(p []byte) (int, error) {
|
||||||
r.RLock()
|
r.RLock()
|
||||||
defer r.RUnlock()
|
defer r.RUnlock()
|
||||||
|
@ -471,7 +468,7 @@ func (r *replayer) Bytes() []byte {
|
||||||
// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
|
// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
|
||||||
func TestConnections(t *testing.T) {
|
func TestConnections(t *testing.T) {
|
||||||
logger := logger.Create(nil)
|
logger := logger.Create(nil)
|
||||||
replayer := &replayer{rw: &bytes.Buffer{}}
|
replayer := &replayer{rw: bytes.NewBuffer([]byte{})}
|
||||||
type args struct {
|
type args struct {
|
||||||
ingressServiceScheme string
|
ingressServiceScheme string
|
||||||
originService func(*testing.T, net.Listener)
|
originService func(*testing.T, net.Listener)
|
||||||
|
@ -486,6 +483,9 @@ func TestConnections(t *testing.T) {
|
||||||
|
|
||||||
// requestheaders to be sent in the call to proxy.Proxy
|
// requestheaders to be sent in the call to proxy.Proxy
|
||||||
requestHeaders http.Header
|
requestHeaders http.Header
|
||||||
|
|
||||||
|
// sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call
|
||||||
|
sessionLimiterResponse error
|
||||||
}
|
}
|
||||||
|
|
||||||
type want struct {
|
type want struct {
|
||||||
|
@ -663,6 +663,25 @@ func TestConnections(t *testing.T) {
|
||||||
err: true,
|
err: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "tcp-* proxy rate limited flow",
|
||||||
|
args: args{
|
||||||
|
ingressServiceScheme: "tcp://",
|
||||||
|
originService: runEchoTCPService,
|
||||||
|
eyeballResponseWriter: newTCPRespWriter(replayer),
|
||||||
|
eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")),
|
||||||
|
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
|
||||||
|
connectionType: connection.TypeTCP,
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
||||||
|
},
|
||||||
|
sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions,
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
message: []byte{},
|
||||||
|
err: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
@ -674,8 +693,16 @@ func TestConnections(t *testing.T) {
|
||||||
test.args.originService(t, ln)
|
test.args.originService(t, ln)
|
||||||
|
|
||||||
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
|
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
|
||||||
ingressRule.StartOrigins(logger, ctx.Done())
|
_ = ingressRule.StartOrigins(logger, ctx.Done())
|
||||||
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
|
|
||||||
|
// Mock session limiter
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
sessionLimiter := mocks.NewMockLimiter(ctrl)
|
||||||
|
sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse)
|
||||||
|
sessionLimiter.EXPECT().Release().AnyTimes()
|
||||||
|
|
||||||
|
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger)
|
||||||
proxy.warpRouting = test.args.warpRoutingService
|
proxy.warpRouting = test.args.warpRoutingService
|
||||||
|
|
||||||
dest := ln.Addr().String()
|
dest := ln.Addr().String()
|
||||||
|
@ -693,7 +720,7 @@ func TestConnections(t *testing.T) {
|
||||||
respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
|
respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
|
||||||
go func() {
|
go func() {
|
||||||
resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
|
resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
|
||||||
replayer.Write(resp)
|
_, _ = replayer.Write(resp)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
if test.args.connectionType == connection.TypeTCP {
|
if test.args.connectionType == connection.TypeTCP {
|
||||||
|
@ -705,9 +732,9 @@ func TestConnections(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
assert.Equal(t, test.want.err, err != nil)
|
require.Equal(t, test.want.err, err != nil)
|
||||||
assert.Equal(t, test.want.message, replayer.Bytes())
|
require.Equal(t, test.want.message, replayer.Bytes())
|
||||||
assert.Equal(t, test.want.headers, respWriter.Header())
|
require.Equal(t, test.want.headers, respWriter.Header())
|
||||||
replayer.rw.Reset()
|
replayer.rw.Reset()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -720,7 +747,9 @@ type requestBody struct {
|
||||||
|
|
||||||
func newWSRequestBody(data []byte) *requestBody {
|
func newWSRequestBody(data []byte) *requestBody {
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
go wsutil.WriteClientBinary(pw, data)
|
go func() {
|
||||||
|
_ = wsutil.WriteClientBinary(pw, data)
|
||||||
|
}()
|
||||||
return &requestBody{
|
return &requestBody{
|
||||||
pr: pr,
|
pr: pr,
|
||||||
pw: pw,
|
pw: pw,
|
||||||
|
@ -728,7 +757,9 @@ func newWSRequestBody(data []byte) *requestBody {
|
||||||
}
|
}
|
||||||
func newTCPRequestBody(data []byte) *requestBody {
|
func newTCPRequestBody(data []byte) *requestBody {
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
go pw.Write(data)
|
go func() {
|
||||||
|
_, _ = pw.Write(data)
|
||||||
|
}()
|
||||||
return &requestBody{
|
return &requestBody{
|
||||||
pr: pr,
|
pr: pr,
|
||||||
pw: pw,
|
pw: pw,
|
||||||
|
@ -740,8 +771,8 @@ func (r *requestBody) Read(p []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *requestBody) Close() error {
|
func (r *requestBody) Close() error {
|
||||||
r.pw.Close()
|
_ = r.pw.Close()
|
||||||
r.pr.Close()
|
_ = r.pr.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -774,6 +805,7 @@ func (p *pipedRequestBody) roundtrip(addr string) []byte {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
|
panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
|
||||||
|
@ -917,7 +949,9 @@ func runEchoTCPService(t *testing.T, l net.Listener) {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
@ -971,12 +1005,15 @@ func runEchoWSService(t *testing.T, l net.Listener) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint: gosec
|
||||||
server := http.Server{
|
server := http.Server{
|
||||||
Handler: http.HandlerFunc(ws),
|
Handler: http.HandlerFunc(ws),
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := server.Serve(l)
|
err := server.Serve(l)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,11 @@ const (
|
||||||
ConnectionTypeTCP
|
ConnectionTypeTCP
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrorFlowConnectRateLimitedKey is the Metadata entry that allows to know if a request was rate limited on connect.
|
||||||
|
ErrorFlowConnectRateLimitedKey = Metadata{Key: "FlowConnectRateLimited", Val: "true"}
|
||||||
|
)
|
||||||
|
|
||||||
func (c ConnectionType) String() string {
|
func (c ConnectionType) String() string {
|
||||||
switch c {
|
switch c {
|
||||||
case ConnectionTypeHTTP:
|
case ConnectionTypeHTTP:
|
||||||
|
|
|
@ -38,6 +38,7 @@ func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
connectResponse = &pogs.ConnectResponse{
|
connectResponse = &pogs.ConnectResponse{
|
||||||
Error: respErr.Error(),
|
Error: respErr.Error(),
|
||||||
|
Metadata: metadata,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
connectResponse = &pogs.ConnectResponse{
|
connectResponse = &pogs.ConnectResponse{
|
||||||
|
|
|
@ -98,12 +98,7 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||||
reqClientStream := RequestClientStream{noopCloser{b}}
|
reqClientStream := RequestClientStream{noopCloser{b}}
|
||||||
respMeta, err := reqClientStream.ReadConnectResponseData()
|
respMeta, err := reqClientStream.ReadConnectResponseData()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, test.metadata, respMeta.Metadata)
|
||||||
if respMeta.Error == "" {
|
|
||||||
assert.Equal(t, test.metadata, respMeta.Metadata)
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, 0, len(respMeta.Metadata))
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -153,21 +148,21 @@ func TestRegisterUdpSession(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
rpcClientStream, err := NewCloudflaredClient(context.Background(), clientStream, 5*time.Second)
|
rpcClientStream, err := NewCloudflaredClient(context.Background(), clientStream, 5*time.Second)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NoError(t, reg.Err)
|
require.NoError(t, reg.Err)
|
||||||
|
|
||||||
// Different sessionID, the RPC server should reject the registraion
|
// Different sessionID, the RPC server should reject the registration
|
||||||
reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Error(t, reg.Err)
|
require.Error(t, reg.Err)
|
||||||
|
|
||||||
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
|
require.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
|
||||||
|
|
||||||
// Different sessionID, the RPC server should reject the unregistraion
|
// Different sessionID, the RPC server should reject the unregistration
|
||||||
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
|
require.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
|
||||||
|
|
||||||
rpcClientStream.Close()
|
rpcClientStream.Close()
|
||||||
<-sessionRegisteredChan
|
<-sessionRegisteredChan
|
||||||
|
@ -200,10 +195,10 @@ func TestManageConfiguration(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
rpcClientStream, err := NewCloudflaredClient(ctx, clientStream, 5*time.Second)
|
rpcClientStream, err := NewCloudflaredClient(ctx, clientStream, 5*time.Second)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
|
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, version, result.LastAppliedVersion)
|
require.Equal(t, version, result.LastAppliedVersion)
|
||||||
require.NoError(t, result.Err)
|
require.NoError(t, result.Err)
|
||||||
|
|
Loading…
Reference in New Issue