TUN-8236: Add write timeout to quic and tcp connections
## Summary To prevent bad eyeballs and severs to be able to exhaust the quic control flows we are adding the possibility of having a timeout for a write operation to be acknowledged. This will prevent hanging connections from exhausting the quic control flows, creating a DDoS.
This commit is contained in:
parent
56aeb6be65
commit
76badfa01b
|
@ -81,6 +81,9 @@ const (
|
||||||
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
|
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
|
||||||
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
|
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
|
||||||
|
|
||||||
|
// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
|
||||||
|
writeStreamTimeout = "write-stream-timeout"
|
||||||
|
|
||||||
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
|
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
|
||||||
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
|
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
|
||||||
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
|
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
|
||||||
|
@ -696,6 +699,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
Value: 5 * time.Second,
|
Value: 5 * time.Second,
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: writeStreamTimeout,
|
||||||
|
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
|
||||||
|
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
|
||||||
|
Value: 0 * time.Second,
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
Name: quicDisablePathMTUDiscovery,
|
Name: quicDisablePathMTUDiscovery,
|
||||||
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},
|
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},
|
||||||
|
|
|
@ -247,6 +247,7 @@ func prepareTunnelConfig(
|
||||||
FeatureSelector: featureSelector,
|
FeatureSelector: featureSelector,
|
||||||
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
|
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
|
||||||
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
|
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
|
||||||
|
WriteStreamTimeout: c.Duration(writeStreamTimeout),
|
||||||
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
|
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
|
||||||
}
|
}
|
||||||
packetConfig, err := newPacketConfig(c, log)
|
packetConfig, err := newPacketConfig(c, log)
|
||||||
|
@ -259,6 +260,7 @@ func prepareTunnelConfig(
|
||||||
Ingress: &ingressRules,
|
Ingress: &ingressRules,
|
||||||
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
|
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
|
||||||
ConfigurationFlags: parseConfigFlags(c),
|
ConfigurationFlags: parseConfigFlags(c),
|
||||||
|
WriteTimeout: c.Duration(writeStreamTimeout),
|
||||||
}
|
}
|
||||||
return tunnelConfig, orchestratorConfig, nil
|
return tunnelConfig, orchestratorConfig, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,6 +66,7 @@ type QUICConnection struct {
|
||||||
connIndex uint8
|
connIndex uint8
|
||||||
|
|
||||||
udpUnregisterTimeout time.Duration
|
udpUnregisterTimeout time.Duration
|
||||||
|
streamWriteTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQUICConnection returns a new instance of QUICConnection.
|
// NewQUICConnection returns a new instance of QUICConnection.
|
||||||
|
@ -82,6 +83,7 @@ func NewQUICConnection(
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
packetRouterConfig *ingress.GlobalRouterConfig,
|
packetRouterConfig *ingress.GlobalRouterConfig,
|
||||||
udpUnregisterTimeout time.Duration,
|
udpUnregisterTimeout time.Duration,
|
||||||
|
streamWriteTimeout time.Duration,
|
||||||
) (*QUICConnection, error) {
|
) (*QUICConnection, error) {
|
||||||
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
|
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -117,6 +119,7 @@ func NewQUICConnection(
|
||||||
connOptions: connOptions,
|
connOptions: connOptions,
|
||||||
connIndex: connIndex,
|
connIndex: connIndex,
|
||||||
udpUnregisterTimeout: udpUnregisterTimeout,
|
udpUnregisterTimeout: udpUnregisterTimeout,
|
||||||
|
streamWriteTimeout: streamWriteTimeout,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,7 +198,7 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
||||||
|
|
||||||
func (q *QUICConnection) runStream(quicStream quic.Stream) {
|
func (q *QUICConnection) runStream(quicStream quic.Stream) {
|
||||||
ctx := quicStream.Context()
|
ctx := quicStream.Context()
|
||||||
stream := quicpogs.NewSafeStreamCloser(quicStream)
|
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
|
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
|
||||||
|
@ -373,7 +376,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stream := quicpogs.NewSafeStreamCloser(quicStream)
|
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
|
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -35,6 +35,7 @@ var (
|
||||||
KeepAlivePeriod: 5 * time.Second,
|
KeepAlivePeriod: 5 * time.Second,
|
||||||
EnableDatagrams: true,
|
EnableDatagrams: true,
|
||||||
}
|
}
|
||||||
|
defaultQUICTimeout = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
|
var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
|
||||||
|
@ -197,7 +198,7 @@ func quicServer(
|
||||||
|
|
||||||
quicStream, err := session.OpenStreamSync(context.Background())
|
quicStream, err := session.OpenStreamSync(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
stream := quicpogs.NewSafeStreamCloser(quicStream)
|
stream := quicpogs.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
|
||||||
|
|
||||||
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
|
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
|
||||||
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
|
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
|
||||||
|
@ -726,6 +727,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
|
||||||
&log,
|
&log,
|
||||||
nil,
|
nil,
|
||||||
5*time.Second,
|
5*time.Second,
|
||||||
|
0*time.Second,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return qc
|
return qc
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
package ingress
|
||||||
|
|
||||||
|
import "github.com/cloudflare/cloudflared/logger"
|
||||||
|
|
||||||
|
var (
|
||||||
|
TestLogger = logger.Create(nil)
|
||||||
|
)
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
@ -31,15 +32,32 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
|
||||||
|
|
||||||
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
||||||
type tcpConnection struct {
|
type tcpConnection struct {
|
||||||
conn net.Conn
|
net.Conn
|
||||||
|
writeTimeout time.Duration
|
||||||
|
logger *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) {
|
||||||
stream.Pipe(tunnelConn, tc.conn, log)
|
stream.Pipe(tunnelConn, tc, tc.logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc *tcpConnection) Write(b []byte) (int, error) {
|
||||||
|
if tc.writeTimeout > 0 {
|
||||||
|
if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
|
||||||
|
tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nBytes, err := tc.Conn.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
tc.logger.Err(err).Msg("Error writing to the TCP connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nBytes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *tcpConnection) Close() {
|
func (tc *tcpConnection) Close() {
|
||||||
tc.conn.Close()
|
tc.Conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
|
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
|
||||||
|
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
|
||||||
"github.com/cloudflare/cloudflared/socks"
|
"github.com/cloudflare/cloudflared/socks"
|
||||||
"github.com/cloudflare/cloudflared/stream"
|
"github.com/cloudflare/cloudflared/stream"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
@ -31,7 +30,6 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testLogger = logger.Create(nil)
|
|
||||||
testMessage = []byte("TestStreamOriginConnection")
|
testMessage = []byte("TestStreamOriginConnection")
|
||||||
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
|
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
|
||||||
)
|
)
|
||||||
|
@ -39,7 +37,8 @@ var (
|
||||||
func TestStreamTCPConnection(t *testing.T) {
|
func TestStreamTCPConnection(t *testing.T) {
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
tcpConn := tcpConnection{
|
tcpConn := tcpConnection{
|
||||||
conn: cfdConn,
|
Conn: cfdConn,
|
||||||
|
writeTimeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
eyeballConn, edgeConn := net.Pipe()
|
eyeballConn, edgeConn := net.Pipe()
|
||||||
|
@ -66,7 +65,7 @@ func TestStreamTCPConnection(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
tcpConn.Stream(ctx, edgeConn, testLogger)
|
tcpConn.Stream(ctx, edgeConn, TestLogger)
|
||||||
require.NoError(t, errGroup.Wait())
|
require.NoError(t, errGroup.Wait())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,7 +92,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
|
||||||
require.NoError(t, errGroup.Wait())
|
require.NoError(t, errGroup.Wait())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +146,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||||
|
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -159,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer wsForwarderInConn.Close()
|
defer wsForwarderInConn.Close()
|
||||||
|
|
||||||
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -209,7 +208,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
|
||||||
originConn.Close()
|
originConn.Close()
|
||||||
}()
|
}()
|
||||||
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
|
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
|
||||||
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
|
tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger)
|
||||||
})
|
})
|
||||||
server := httptest.NewServer(handler)
|
server := httptest.NewServer(handler)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||||
|
@ -14,7 +16,7 @@ type HTTPOriginProxy interface {
|
||||||
|
|
||||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
||||||
type StreamBasedOriginProxy interface {
|
type StreamBasedOriginProxy interface {
|
||||||
EstablishConnection(ctx context.Context, dest string) (OriginConnection, error)
|
EstablishConnection(ctx context.Context, dest string, log *zerolog.Logger) (OriginConnection, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests.
|
// HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests.
|
||||||
|
@ -62,19 +64,21 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
|
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
|
||||||
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
|
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
originConn := &tcpConnection{
|
originConn := &tcpConnection{
|
||||||
conn: conn,
|
Conn: conn,
|
||||||
|
writeTimeout: o.writeTimeout,
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
return originConn, nil
|
return originConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
|
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, _ *zerolog.Logger) (OriginConnection, error) {
|
||||||
var err error
|
var err error
|
||||||
if !o.isBastion {
|
if !o.isBastion {
|
||||||
dest = o.dest
|
dest = o.dest
|
||||||
|
@ -92,6 +96,6 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) {
|
func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
|
||||||
return o.conn, nil
|
return o.conn, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Origin not listening for new connection, should return an error
|
// Origin not listening for new connection, should return an error
|
||||||
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String())
|
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
t.Run(test.testCase, func(t *testing.T) {
|
t.Run(test.testCase, func(t *testing.T) {
|
||||||
if test.expectErr {
|
if test.expectErr {
|
||||||
bastionHost, _ := carrier.ResolveBastionDest(test.req)
|
bastionHost, _ := carrier.ResolveBastionDest(test.req)
|
||||||
_, err := test.service.EstablishConnection(context.Background(), bastionHost)
|
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
||||||
// Origin not listening for new connection, should return an error
|
// Origin not listening for new connection, should return an error
|
||||||
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
|
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
|
||||||
_, err := service.EstablishConnection(context.Background(), bastionHost)
|
_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -132,7 +132,7 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||||
url: originURL,
|
url: originURL,
|
||||||
}
|
}
|
||||||
shutdownC := make(chan struct{})
|
shutdownC := make(chan struct{})
|
||||||
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
|
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -167,7 +167,7 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
|
||||||
url: originURL,
|
url: originURL,
|
||||||
}
|
}
|
||||||
shutdownC := make(chan struct{})
|
shutdownC := make(chan struct{})
|
||||||
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
|
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
|
||||||
|
|
||||||
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
|
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
|
||||||
protos := []string{"https", "http", "dne"}
|
protos := []string{"https", "http", "dne"}
|
||||||
|
|
|
@ -96,13 +96,15 @@ func (o httpService) MarshalJSON() ([]byte, error) {
|
||||||
type rawTCPService struct {
|
type rawTCPService struct {
|
||||||
name string
|
name string
|
||||||
dialer net.Dialer
|
dialer net.Dialer
|
||||||
|
writeTimeout time.Duration
|
||||||
|
logger *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawTCPService) String() string {
|
func (o *rawTCPService) String() string {
|
||||||
return o.name
|
return o.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
|
func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,13 +287,14 @@ type WarpRoutingService struct {
|
||||||
Proxy StreamBasedOriginProxy
|
Proxy StreamBasedOriginProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService {
|
func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
|
||||||
svc := &rawTCPService{
|
svc := &rawTCPService{
|
||||||
name: ServiceWarpRouting,
|
name: ServiceWarpRouting,
|
||||||
dialer: net.Dialer{
|
dialer: net.Dialer{
|
||||||
Timeout: config.ConnectTimeout.Duration,
|
Timeout: config.ConnectTimeout.Duration,
|
||||||
KeepAlive: config.TCPKeepAlive.Duration,
|
KeepAlive: config.TCPKeepAlive.Duration,
|
||||||
},
|
},
|
||||||
|
writeTimeout: writeTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
return &WarpRoutingService{Proxy: svc}
|
return &WarpRoutingService{Proxy: svc}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package orchestration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
|
@ -21,6 +22,7 @@ type newLocalConfig struct {
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Ingress *ingress.Ingress
|
Ingress *ingress.Ingress
|
||||||
WarpRouting ingress.WarpRoutingConfig
|
WarpRouting ingress.WarpRoutingConfig
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
|
||||||
// Extra settings used to configure this instance but that are not eligible for remotely management
|
// Extra settings used to configure this instance but that are not eligible for remotely management
|
||||||
// ie. (--protocol, --loglevel, ...)
|
// ie. (--protocol, --loglevel, ...)
|
||||||
|
|
|
@ -17,10 +17,10 @@ import (
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Orchestrator manages configurations so they can be updatable during runtime
|
// Orchestrator manages configurations, so they can be updatable during runtime
|
||||||
// properties are static, so it can be read without lock
|
// properties are static, so it can be read without lock
|
||||||
// currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex
|
// currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex
|
||||||
// access to proxy is synchronized with atmoic.Value, because it uses copy-on-write to provide scalable frequently
|
// access to proxy is synchronized with atomic.Value, because it uses copy-on-write to provide scalable frequently
|
||||||
// read when update is infrequent
|
// read when update is infrequent
|
||||||
type Orchestrator struct {
|
type Orchestrator struct {
|
||||||
currentVersion int32
|
currentVersion int32
|
||||||
|
@ -30,6 +30,7 @@ type Orchestrator struct {
|
||||||
proxy atomic.Value
|
proxy atomic.Value
|
||||||
// Set of internal ingress rules defined at cloudflared startup (separate from user-defined ingress rules)
|
// Set of internal ingress rules defined at cloudflared startup (separate from user-defined ingress rules)
|
||||||
internalRules []ingress.Rule
|
internalRules []ingress.Rule
|
||||||
|
// cloudflared Configuration
|
||||||
config *Config
|
config *Config
|
||||||
tags []tunnelpogs.Tag
|
tags []tunnelpogs.Tag
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
|
@ -40,7 +41,11 @@ type Orchestrator struct {
|
||||||
proxyShutdownC chan<- struct{}
|
proxyShutdownC chan<- struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, internalRules []ingress.Rule, log *zerolog.Logger) (*Orchestrator, error) {
|
func NewOrchestrator(ctx context.Context,
|
||||||
|
config *Config,
|
||||||
|
tags []tunnelpogs.Tag,
|
||||||
|
internalRules []ingress.Rule,
|
||||||
|
log *zerolog.Logger) (*Orchestrator, error) {
|
||||||
o := &Orchestrator{
|
o := &Orchestrator{
|
||||||
// Lowest possible version, any remote configuration will have version higher than this
|
// Lowest possible version, any remote configuration will have version higher than this
|
||||||
// Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it
|
// Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it
|
||||||
|
@ -131,7 +136,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
|
||||||
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
|
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
|
||||||
return errors.Wrap(err, "failed to start origin")
|
return errors.Wrap(err, "failed to start origin")
|
||||||
}
|
}
|
||||||
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.log)
|
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log)
|
||||||
o.proxy.Store(proxy)
|
o.proxy.Store(proxy)
|
||||||
o.config.Ingress = &ingressRules
|
o.config.Ingress = &ingressRules
|
||||||
o.config.WarpRouting = warpRouting
|
o.config.WarpRouting = warpRouting
|
||||||
|
|
|
@ -27,7 +27,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testLogger = zerolog.Logger{}
|
testLogger = zerolog.Nop()
|
||||||
testTags = []tunnelpogs.Tag{
|
testTags = []tunnelpogs.Tag{
|
||||||
{
|
{
|
||||||
Name: "package",
|
Name: "package",
|
||||||
|
|
|
@ -51,6 +51,7 @@ func NewOriginProxy(
|
||||||
ingressRules ingress.Ingress,
|
ingressRules ingress.Ingress,
|
||||||
warpRouting ingress.WarpRoutingConfig,
|
warpRouting ingress.WarpRoutingConfig,
|
||||||
tags []tunnelpogs.Tag,
|
tags []tunnelpogs.Tag,
|
||||||
|
writeTimeout time.Duration,
|
||||||
log *zerolog.Logger,
|
log *zerolog.Logger,
|
||||||
) *Proxy {
|
) *Proxy {
|
||||||
proxy := &Proxy{
|
proxy := &Proxy{
|
||||||
|
@ -59,7 +60,7 @@ func NewOriginProxy(
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting)
|
proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout)
|
||||||
|
|
||||||
return proxy
|
return proxy
|
||||||
}
|
}
|
||||||
|
@ -309,7 +310,7 @@ func (p *Proxy) proxyStream(
|
||||||
_, connectSpan := tr.Tracer().Start(ctx, "stream-connect")
|
_, connectSpan := tr.Tracer().Start(ctx, "stream-connect")
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
originConn, err := connectionProxy.EstablishConnection(ctx, dest)
|
originConn, err := connectionProxy.EstablishConnection(ctx, dest, &logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connectStreamErrors.Inc()
|
connectStreamErrors.Inc()
|
||||||
tracing.EndWithErrorStatus(connectSpan, err)
|
tracing.EndWithErrorStatus(connectSpan, err)
|
||||||
|
|
|
@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
|
||||||
|
|
||||||
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
|
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, &log)
|
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
|
||||||
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
||||||
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
||||||
t.Run("testProxySSE", testProxySSE(proxy))
|
t.Run("testProxySSE", testProxySSE(proxy))
|
||||||
|
@ -366,7 +366,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
|
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, &log)
|
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
responseWriter := newMockHTTPRespWriter()
|
responseWriter := newMockHTTPRespWriter()
|
||||||
|
@ -414,7 +414,7 @@ func TestProxyError(t *testing.T) {
|
||||||
|
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
|
|
||||||
proxy := NewOriginProxy(ing, noWarpRouting, testTags, &log)
|
proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
|
||||||
|
|
||||||
responseWriter := newMockHTTPRespWriter()
|
responseWriter := newMockHTTPRespWriter()
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||||
|
@ -530,7 +530,7 @@ func TestConnections(t *testing.T) {
|
||||||
originService: runEchoTCPService,
|
originService: runEchoTCPService,
|
||||||
eyeballResponseWriter: newTCPRespWriter(replayer),
|
eyeballResponseWriter: newTCPRespWriter(replayer),
|
||||||
eyeballRequestBody: newTCPRequestBody([]byte("test2")),
|
eyeballRequestBody: newTCPRequestBody([]byte("test2")),
|
||||||
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting),
|
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
|
||||||
connectionType: connection.TypeTCP,
|
connectionType: connection.TypeTCP,
|
||||||
requestHeaders: map[string][]string{
|
requestHeaders: map[string][]string{
|
||||||
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
||||||
|
@ -548,7 +548,7 @@ func TestConnections(t *testing.T) {
|
||||||
originService: runEchoWSService,
|
originService: runEchoWSService,
|
||||||
// eyeballResponseWriter gets set after roundtrip dial.
|
// eyeballResponseWriter gets set after roundtrip dial.
|
||||||
eyeballRequestBody: newPipedWSRequestBody([]byte("test3")),
|
eyeballRequestBody: newPipedWSRequestBody([]byte("test3")),
|
||||||
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting),
|
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
|
||||||
requestHeaders: map[string][]string{
|
requestHeaders: map[string][]string{
|
||||||
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
||||||
},
|
},
|
||||||
|
@ -675,7 +675,7 @@ func TestConnections(t *testing.T) {
|
||||||
|
|
||||||
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
|
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
|
||||||
ingressRule.StartOrigins(logger, ctx.Done())
|
ingressRule.StartOrigins(logger, ctx.Done())
|
||||||
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, logger)
|
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
|
||||||
proxy.warpRouting = test.args.warpRoutingService
|
proxy.warpRouting = test.args.warpRoutingService
|
||||||
|
|
||||||
dest := ln.Addr().String()
|
dest := ln.Addr().String()
|
||||||
|
|
|
@ -1,20 +1,28 @@
|
||||||
package quic
|
package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SafeStreamCloser struct {
|
type SafeStreamCloser struct {
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
stream quic.Stream
|
stream quic.Stream
|
||||||
|
writeTimeout time.Duration
|
||||||
|
log *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser {
|
func NewSafeStreamCloser(stream quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser {
|
||||||
return &SafeStreamCloser{
|
return &SafeStreamCloser{
|
||||||
stream: stream,
|
stream: stream,
|
||||||
|
writeTimeout: writeTimeout,
|
||||||
|
log: log,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,7 +33,29 @@ func (s *SafeStreamCloser) Read(p []byte) (n int, err error) {
|
||||||
func (s *SafeStreamCloser) Write(p []byte) (n int, err error) {
|
func (s *SafeStreamCloser) Write(p []byte) (n int, err error) {
|
||||||
s.lock.Lock()
|
s.lock.Lock()
|
||||||
defer s.lock.Unlock()
|
defer s.lock.Unlock()
|
||||||
return s.stream.Write(p)
|
if s.writeTimeout > 0 {
|
||||||
|
err = s.stream.SetWriteDeadline(time.Now().Add(s.writeTimeout))
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).Msg("Error setting write deadline for QUIC stream")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nBytes, err := s.stream.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
s.handleTimeout(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nBytes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles the timeout error in case it happened, by canceling the stream write.
|
||||||
|
func (s *SafeStreamCloser) handleTimeout(err error) {
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
if netErr.Timeout() {
|
||||||
|
s.log.Error().Err(netErr).Msg("Closing quic stream due to timeout while writing")
|
||||||
|
s.stream.CancelWrite(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SafeStreamCloser) Close() error {
|
func (s *SafeStreamCloser) Close() error {
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
"github.com/quic-go/quic-go"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -70,7 +72,8 @@ func quicClient(t *testing.T, addr net.Addr) {
|
||||||
go func(iter int) {
|
go func(iter int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
stream := NewSafeStreamCloser(quicStream)
|
log := zerolog.Nop()
|
||||||
|
stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log)
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
// Do a bunch of round trips over this stream that should work.
|
// Do a bunch of round trips over this stream that should work.
|
||||||
|
@ -107,7 +110,8 @@ func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn)
|
||||||
go func(iter int) {
|
go func(iter int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
stream := NewSafeStreamCloser(quicStream)
|
log := zerolog.Nop()
|
||||||
|
stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log)
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
// Do a bunch of round trips over this stream that should work.
|
// Do a bunch of round trips over this stream that should work.
|
||||||
|
|
|
@ -66,6 +66,7 @@ type TunnelConfig struct {
|
||||||
PacketConfig *ingress.GlobalRouterConfig
|
PacketConfig *ingress.GlobalRouterConfig
|
||||||
|
|
||||||
UDPUnregisterSessionTimeout time.Duration
|
UDPUnregisterSessionTimeout time.Duration
|
||||||
|
WriteStreamTimeout time.Duration
|
||||||
|
|
||||||
DisableQUICPathMTUDiscovery bool
|
DisableQUICPathMTUDiscovery bool
|
||||||
|
|
||||||
|
@ -614,6 +615,7 @@ func (e *EdgeTunnelServer) serveQUIC(
|
||||||
connLogger.Logger(),
|
connLogger.Logger(),
|
||||||
e.config.PacketConfig,
|
e.config.PacketConfig,
|
||||||
e.config.UDPUnregisterSessionTimeout,
|
e.config.UDPUnregisterSessionTimeout,
|
||||||
|
e.config.WriteStreamTimeout,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")
|
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")
|
||||||
|
|
Loading…
Reference in New Issue