TUN-8236: Add write timeout to quic and tcp connections

## Summary
To prevent bad eyeballs and severs to be able to exhaust the quic
control flows we are adding the possibility of having a timeout
for a write operation to be acknowledged. This will prevent hanging
connections from exhausting the quic control flows, creating a DDoS.
This commit is contained in:
João "Pisco" Fernandes 2024-02-12 18:58:55 +00:00
parent 56aeb6be65
commit 76badfa01b
18 changed files with 146 additions and 54 deletions

View File

@ -81,6 +81,9 @@ const (
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge // udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout" udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
writeStreamTimeout = "write-stream-timeout"
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size. // quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. // Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets. // Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
@ -696,6 +699,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: 5 * time.Second, Value: 5 * time.Second,
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: writeStreamTimeout,
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
Value: 0 * time.Second,
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: quicDisablePathMTUDiscovery, Name: quicDisablePathMTUDiscovery,
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"}, EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},

View File

@ -247,6 +247,7 @@ func prepareTunnelConfig(
FeatureSelector: featureSelector, FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")), MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag), UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
WriteStreamTimeout: c.Duration(writeStreamTimeout),
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery), DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
} }
packetConfig, err := newPacketConfig(c, log) packetConfig, err := newPacketConfig(c, log)
@ -259,6 +260,7 @@ func prepareTunnelConfig(
Ingress: &ingressRules, Ingress: &ingressRules,
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting), WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
ConfigurationFlags: parseConfigFlags(c), ConfigurationFlags: parseConfigFlags(c),
WriteTimeout: c.Duration(writeStreamTimeout),
} }
return tunnelConfig, orchestratorConfig, nil return tunnelConfig, orchestratorConfig, nil
} }

View File

@ -66,6 +66,7 @@ type QUICConnection struct {
connIndex uint8 connIndex uint8
udpUnregisterTimeout time.Duration udpUnregisterTimeout time.Duration
streamWriteTimeout time.Duration
} }
// NewQUICConnection returns a new instance of QUICConnection. // NewQUICConnection returns a new instance of QUICConnection.
@ -82,6 +83,7 @@ func NewQUICConnection(
logger *zerolog.Logger, logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig, packetRouterConfig *ingress.GlobalRouterConfig,
udpUnregisterTimeout time.Duration, udpUnregisterTimeout time.Duration,
streamWriteTimeout time.Duration,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger) udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil { if err != nil {
@ -117,6 +119,7 @@ func NewQUICConnection(
connOptions: connOptions, connOptions: connOptions,
connIndex: connIndex, connIndex: connIndex,
udpUnregisterTimeout: udpUnregisterTimeout, udpUnregisterTimeout: udpUnregisterTimeout,
streamWriteTimeout: streamWriteTimeout,
}, nil }, nil
} }
@ -195,7 +198,7 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
func (q *QUICConnection) runStream(quicStream quic.Stream) { func (q *QUICConnection) runStream(quicStream quic.Stream) {
ctx := quicStream.Context() ctx := quicStream.Context()
stream := quicpogs.NewSafeStreamCloser(quicStream) stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close() defer stream.Close()
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
@ -373,7 +376,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
return return
} }
stream := quicpogs.NewSafeStreamCloser(quicStream) stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close() defer stream.Close()
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger) rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
if err != nil { if err != nil {

View File

@ -35,6 +35,7 @@ var (
KeepAlivePeriod: 5 * time.Second, KeepAlivePeriod: 5 * time.Second,
EnableDatagrams: true, EnableDatagrams: true,
} }
defaultQUICTimeout = 30 * time.Second
) )
var _ ReadWriteAcker = (*streamReadWriteAcker)(nil) var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
@ -197,7 +198,7 @@ func quicServer(
quicStream, err := session.OpenStreamSync(context.Background()) quicStream, err := session.OpenStreamSync(context.Background())
require.NoError(t, err) require.NoError(t, err)
stream := quicpogs.NewSafeStreamCloser(quicStream) stream := quicpogs.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream} reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...) err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
@ -726,6 +727,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
&log, &log,
nil, nil,
5*time.Second, 5*time.Second,
0*time.Second,
) )
require.NoError(t, err) require.NoError(t, err)
return qc return qc

View File

@ -0,0 +1,7 @@
package ingress
import "github.com/cloudflare/cloudflared/logger"
var (
TestLogger = logger.Create(nil)
)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"io" "io"
"net" "net"
"time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -31,15 +32,32 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
// tcpConnection is an OriginConnection that directly streams to raw TCP. // tcpConnection is an OriginConnection that directly streams to raw TCP.
type tcpConnection struct { type tcpConnection struct {
conn net.Conn net.Conn
writeTimeout time.Duration
logger *zerolog.Logger
} }
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) {
stream.Pipe(tunnelConn, tc.conn, log) stream.Pipe(tunnelConn, tc, tc.logger)
}
func (tc *tcpConnection) Write(b []byte) (int, error) {
if tc.writeTimeout > 0 {
if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
}
}
nBytes, err := tc.Conn.Write(b)
if err != nil {
tc.logger.Err(err).Msg("Error writing to the TCP connection")
}
return nBytes, err
} }
func (tc *tcpConnection) Close() { func (tc *tcpConnection) Close() {
tc.conn.Close() tc.Conn.Close()
} }
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS. // tcpOverWSConnection is an OriginConnection that streams to TCP over WS.

View File

@ -19,7 +19,6 @@ import (
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/stream" "github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
@ -31,7 +30,6 @@ const (
) )
var ( var (
testLogger = logger.Create(nil)
testMessage = []byte("TestStreamOriginConnection") testMessage = []byte("TestStreamOriginConnection")
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage)) testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
) )
@ -39,7 +37,8 @@ var (
func TestStreamTCPConnection(t *testing.T) { func TestStreamTCPConnection(t *testing.T) {
cfdConn, originConn := net.Pipe() cfdConn, originConn := net.Pipe()
tcpConn := tcpConnection{ tcpConn := tcpConnection{
conn: cfdConn, Conn: cfdConn,
writeTimeout: 30 * time.Second,
} }
eyeballConn, edgeConn := net.Pipe() eyeballConn, edgeConn := net.Pipe()
@ -66,7 +65,7 @@ func TestStreamTCPConnection(t *testing.T) {
return nil return nil
}) })
tcpConn.Stream(ctx, edgeConn, testLogger) tcpConn.Stream(ctx, edgeConn, TestLogger)
require.NoError(t, errGroup.Wait()) require.NoError(t, errGroup.Wait())
} }
@ -93,7 +92,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
return nil return nil
}) })
tcpOverWSConn.Stream(ctx, edgeConn, testLogger) tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
require.NoError(t, errGroup.Wait()) require.NoError(t, errGroup.Wait())
} }
@ -147,7 +146,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
errGroup, ctx := errgroup.WithContext(ctx) errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
tcpOverWSConn.Stream(ctx, edgeConn, testLogger) tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
return nil return nil
}) })
@ -159,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer wsForwarderInConn.Close() defer wsForwarderInConn.Close()
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger) stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger)
return nil return nil
}) })
@ -209,7 +208,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
originConn.Close() originConn.Close()
}() }()
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond) ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger) tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger)
}) })
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"github.com/rs/zerolog"
) )
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests. // HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@ -14,7 +16,7 @@ type HTTPOriginProxy interface {
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP. // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface { type StreamBasedOriginProxy interface {
EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) EstablishConnection(ctx context.Context, dest string, log *zerolog.Logger) (OriginConnection, error)
} }
// HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests. // HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests.
@ -62,19 +64,21 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return resp, nil return resp, nil
} }
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) { func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
conn, err := o.dialer.DialContext(ctx, "tcp", dest) conn, err := o.dialer.DialContext(ctx, "tcp", dest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
originConn := &tcpConnection{ originConn := &tcpConnection{
conn: conn, Conn: conn,
writeTimeout: o.writeTimeout,
logger: logger,
} }
return originConn, nil return originConn, nil
} }
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) { func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, _ *zerolog.Logger) (OriginConnection, error) {
var err error var err error
if !o.isBastion { if !o.isBastion {
dest = o.dest dest = o.dest
@ -92,6 +96,6 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string)
} }
func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) { func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
return o.conn, nil return o.conn, nil
} }

View File

@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Origin not listening for new connection, should return an error // Origin not listening for new connection, should return an error
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String()) _, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
require.Error(t, err) require.Error(t, err)
} }
@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
t.Run(test.testCase, func(t *testing.T) { t.Run(test.testCase, func(t *testing.T) {
if test.expectErr { if test.expectErr {
bastionHost, _ := carrier.ResolveBastionDest(test.req) bastionHost, _ := carrier.ResolveBastionDest(test.req)
_, err := test.service.EstablishConnection(context.Background(), bastionHost) _, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err) assert.Error(t, err)
} }
}) })
@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error // Origin not listening for new connection, should return an error
bastionHost, _ := carrier.ResolveBastionDest(bastionReq) bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
_, err := service.EstablishConnection(context.Background(), bastionHost) _, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err) assert.Error(t, err)
} }
} }
@ -132,7 +132,7 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
url: originURL, url: originURL,
} }
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err) require.NoError(t, err)
@ -167,7 +167,7 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
url: originURL, url: originURL,
} }
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
protos := []string{"https", "http", "dne"} protos := []string{"https", "http", "dne"}

View File

@ -96,13 +96,15 @@ func (o httpService) MarshalJSON() ([]byte, error) {
type rawTCPService struct { type rawTCPService struct {
name string name string
dialer net.Dialer dialer net.Dialer
writeTimeout time.Duration
logger *zerolog.Logger
} }
func (o *rawTCPService) String() string { func (o *rawTCPService) String() string {
return o.name return o.name
} }
func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error {
return nil return nil
} }
@ -285,13 +287,14 @@ type WarpRoutingService struct {
Proxy StreamBasedOriginProxy Proxy StreamBasedOriginProxy
} }
func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService { func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
svc := &rawTCPService{ svc := &rawTCPService{
name: ServiceWarpRouting, name: ServiceWarpRouting,
dialer: net.Dialer{ dialer: net.Dialer{
Timeout: config.ConnectTimeout.Duration, Timeout: config.ConnectTimeout.Duration,
KeepAlive: config.TCPKeepAlive.Duration, KeepAlive: config.TCPKeepAlive.Duration,
}, },
writeTimeout: writeTimeout,
} }
return &WarpRoutingService{Proxy: svc} return &WarpRoutingService{Proxy: svc}

View File

@ -2,6 +2,7 @@ package orchestration
import ( import (
"encoding/json" "encoding/json"
"time"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
@ -21,6 +22,7 @@ type newLocalConfig struct {
type Config struct { type Config struct {
Ingress *ingress.Ingress Ingress *ingress.Ingress
WarpRouting ingress.WarpRoutingConfig WarpRouting ingress.WarpRoutingConfig
WriteTimeout time.Duration
// Extra settings used to configure this instance but that are not eligible for remotely management // Extra settings used to configure this instance but that are not eligible for remotely management
// ie. (--protocol, --loglevel, ...) // ie. (--protocol, --loglevel, ...)

View File

@ -17,10 +17,10 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
// Orchestrator manages configurations so they can be updatable during runtime // Orchestrator manages configurations, so they can be updatable during runtime
// properties are static, so it can be read without lock // properties are static, so it can be read without lock
// currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex // currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex
// access to proxy is synchronized with atmoic.Value, because it uses copy-on-write to provide scalable frequently // access to proxy is synchronized with atomic.Value, because it uses copy-on-write to provide scalable frequently
// read when update is infrequent // read when update is infrequent
type Orchestrator struct { type Orchestrator struct {
currentVersion int32 currentVersion int32
@ -30,6 +30,7 @@ type Orchestrator struct {
proxy atomic.Value proxy atomic.Value
// Set of internal ingress rules defined at cloudflared startup (separate from user-defined ingress rules) // Set of internal ingress rules defined at cloudflared startup (separate from user-defined ingress rules)
internalRules []ingress.Rule internalRules []ingress.Rule
// cloudflared Configuration
config *Config config *Config
tags []tunnelpogs.Tag tags []tunnelpogs.Tag
log *zerolog.Logger log *zerolog.Logger
@ -40,7 +41,11 @@ type Orchestrator struct {
proxyShutdownC chan<- struct{} proxyShutdownC chan<- struct{}
} }
func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, internalRules []ingress.Rule, log *zerolog.Logger) (*Orchestrator, error) { func NewOrchestrator(ctx context.Context,
config *Config,
tags []tunnelpogs.Tag,
internalRules []ingress.Rule,
log *zerolog.Logger) (*Orchestrator, error) {
o := &Orchestrator{ o := &Orchestrator{
// Lowest possible version, any remote configuration will have version higher than this // Lowest possible version, any remote configuration will have version higher than this
// Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it // Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it
@ -131,7 +136,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil { if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
return errors.Wrap(err, "failed to start origin") return errors.Wrap(err, "failed to start origin")
} }
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.log) proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log)
o.proxy.Store(proxy) o.proxy.Store(proxy)
o.config.Ingress = &ingressRules o.config.Ingress = &ingressRules
o.config.WarpRouting = warpRouting o.config.WarpRouting = warpRouting

View File

@ -27,7 +27,7 @@ import (
) )
var ( var (
testLogger = zerolog.Logger{} testLogger = zerolog.Nop()
testTags = []tunnelpogs.Tag{ testTags = []tunnelpogs.Tag{
{ {
Name: "package", Name: "package",

View File

@ -51,6 +51,7 @@ func NewOriginProxy(
ingressRules ingress.Ingress, ingressRules ingress.Ingress,
warpRouting ingress.WarpRoutingConfig, warpRouting ingress.WarpRoutingConfig,
tags []tunnelpogs.Tag, tags []tunnelpogs.Tag,
writeTimeout time.Duration,
log *zerolog.Logger, log *zerolog.Logger,
) *Proxy { ) *Proxy {
proxy := &Proxy{ proxy := &Proxy{
@ -59,7 +60,7 @@ func NewOriginProxy(
log: log, log: log,
} }
proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting) proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout)
return proxy return proxy
} }
@ -309,7 +310,7 @@ func (p *Proxy) proxyStream(
_, connectSpan := tr.Tracer().Start(ctx, "stream-connect") _, connectSpan := tr.Tracer().Start(ctx, "stream-connect")
start := time.Now() start := time.Now()
originConn, err := connectionProxy.EstablishConnection(ctx, dest) originConn, err := connectionProxy.EstablishConnection(ctx, dest, &logger)
if err != nil { if err != nil {
connectStreamErrors.Inc() connectStreamErrors.Inc()
tracing.EndWithErrorStatus(connectSpan, err) tracing.EndWithErrorStatus(connectSpan, err)

View File

@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, &log) proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy)) t.Run("testProxySSE", testProxySSE(proxy))
@ -366,7 +366,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, &log) proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
for _, test := range tests { for _, test := range tests {
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
@ -414,7 +414,7 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
proxy := NewOriginProxy(ing, noWarpRouting, testTags, &log) proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
@ -530,7 +530,7 @@ func TestConnections(t *testing.T) {
originService: runEchoTCPService, originService: runEchoTCPService,
eyeballResponseWriter: newTCPRespWriter(replayer), eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("test2")), eyeballRequestBody: newTCPRequestBody([]byte("test2")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting), warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
connectionType: connection.TypeTCP, connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{ requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
@ -548,7 +548,7 @@ func TestConnections(t *testing.T) {
originService: runEchoWSService, originService: runEchoWSService,
// eyeballResponseWriter gets set after roundtrip dial. // eyeballResponseWriter gets set after roundtrip dial.
eyeballRequestBody: newPipedWSRequestBody([]byte("test3")), eyeballRequestBody: newPipedWSRequestBody([]byte("test3")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting), warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
requestHeaders: map[string][]string{ requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
}, },
@ -675,7 +675,7 @@ func TestConnections(t *testing.T) {
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
ingressRule.StartOrigins(logger, ctx.Done()) ingressRule.StartOrigins(logger, ctx.Done())
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, logger) proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
proxy.warpRouting = test.args.warpRoutingService proxy.warpRouting = test.args.warpRoutingService
dest := ln.Addr().String() dest := ln.Addr().String()

View File

@ -1,20 +1,28 @@
package quic package quic
import ( import (
"errors"
"net"
"sync" "sync"
"time" "time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
) )
type SafeStreamCloser struct { type SafeStreamCloser struct {
lock sync.Mutex lock sync.Mutex
stream quic.Stream stream quic.Stream
writeTimeout time.Duration
log *zerolog.Logger
} }
func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser { func NewSafeStreamCloser(stream quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser {
return &SafeStreamCloser{ return &SafeStreamCloser{
stream: stream, stream: stream,
writeTimeout: writeTimeout,
log: log,
} }
} }
@ -25,7 +33,29 @@ func (s *SafeStreamCloser) Read(p []byte) (n int, err error) {
func (s *SafeStreamCloser) Write(p []byte) (n int, err error) { func (s *SafeStreamCloser) Write(p []byte) (n int, err error) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
return s.stream.Write(p) if s.writeTimeout > 0 {
err = s.stream.SetWriteDeadline(time.Now().Add(s.writeTimeout))
if err != nil {
log.Err(err).Msg("Error setting write deadline for QUIC stream")
}
}
nBytes, err := s.stream.Write(p)
if err != nil {
s.handleTimeout(err)
}
return nBytes, err
}
// Handles the timeout error in case it happened, by canceling the stream write.
func (s *SafeStreamCloser) handleTimeout(err error) {
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
s.log.Error().Err(netErr).Msg("Closing quic stream due to timeout while writing")
s.stream.CancelWrite(0)
}
}
} }
func (s *SafeStreamCloser) Close() error { func (s *SafeStreamCloser) Close() error {

View File

@ -9,6 +9,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/rs/zerolog"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -70,7 +72,8 @@ func quicClient(t *testing.T, addr net.Addr) {
go func(iter int) { go func(iter int) {
defer wg.Done() defer wg.Done()
stream := NewSafeStreamCloser(quicStream) log := zerolog.Nop()
stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log)
defer stream.Close() defer stream.Close()
// Do a bunch of round trips over this stream that should work. // Do a bunch of round trips over this stream that should work.
@ -107,7 +110,8 @@ func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn)
go func(iter int) { go func(iter int) {
defer wg.Done() defer wg.Done()
stream := NewSafeStreamCloser(quicStream) log := zerolog.Nop()
stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log)
defer stream.Close() defer stream.Close()
// Do a bunch of round trips over this stream that should work. // Do a bunch of round trips over this stream that should work.

View File

@ -66,6 +66,7 @@ type TunnelConfig struct {
PacketConfig *ingress.GlobalRouterConfig PacketConfig *ingress.GlobalRouterConfig
UDPUnregisterSessionTimeout time.Duration UDPUnregisterSessionTimeout time.Duration
WriteStreamTimeout time.Duration
DisableQUICPathMTUDiscovery bool DisableQUICPathMTUDiscovery bool
@ -614,6 +615,7 @@ func (e *EdgeTunnelServer) serveQUIC(
connLogger.Logger(), connLogger.Logger(),
e.config.PacketConfig, e.config.PacketConfig,
e.config.UDPUnregisterSessionTimeout, e.config.UDPUnregisterSessionTimeout,
e.config.WriteStreamTimeout,
) )
if err != nil { if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")