diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index e5b0ebb5..8c98032a 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -81,6 +81,9 @@ const ( // udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout" + // writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin). + writeStreamTimeout = "write-stream-timeout" + // quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size. // Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. // Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets. @@ -696,6 +699,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag { Value: 5 * time.Second, Hidden: true, }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: writeStreamTimeout, + EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"}, + Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.", + Value: 0 * time.Second, + Hidden: true, + }), altsrc.NewBoolFlag(&cli.BoolFlag{ Name: quicDisablePathMTUDiscovery, EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"}, diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index e6c05d67..2f410a2c 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -247,6 +247,7 @@ func prepareTunnelConfig( FeatureSelector: featureSelector, MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")), UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag), + WriteStreamTimeout: c.Duration(writeStreamTimeout), DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery), } packetConfig, err := newPacketConfig(c, log) @@ -259,6 +260,7 @@ func prepareTunnelConfig( Ingress: &ingressRules, WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting), ConfigurationFlags: parseConfigFlags(c), + WriteTimeout: c.Duration(writeStreamTimeout), } return tunnelConfig, orchestratorConfig, nil } diff --git a/connection/quic.go b/connection/quic.go index 3d61d93a..28186dc7 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -66,6 +66,7 @@ type QUICConnection struct { connIndex uint8 udpUnregisterTimeout time.Duration + streamWriteTimeout time.Duration } // NewQUICConnection returns a new instance of QUICConnection. @@ -82,6 +83,7 @@ func NewQUICConnection( logger *zerolog.Logger, packetRouterConfig *ingress.GlobalRouterConfig, udpUnregisterTimeout time.Duration, + streamWriteTimeout time.Duration, ) (*QUICConnection, error) { udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger) if err != nil { @@ -117,6 +119,7 @@ func NewQUICConnection( connOptions: connOptions, connIndex: connIndex, udpUnregisterTimeout: udpUnregisterTimeout, + streamWriteTimeout: streamWriteTimeout, }, nil } @@ -195,7 +198,7 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error { func (q *QUICConnection) runStream(quicStream quic.Stream) { ctx := quicStream.Context() - stream := quicpogs.NewSafeStreamCloser(quicStream) + stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) defer stream.Close() // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that @@ -373,7 +376,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI return } - stream := quicpogs.NewSafeStreamCloser(quicStream) + stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) defer stream.Close() rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger) if err != nil { diff --git a/connection/quic_test.go b/connection/quic_test.go index 44a6034d..5d06e8ee 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -35,6 +35,7 @@ var ( KeepAlivePeriod: 5 * time.Second, EnableDatagrams: true, } + defaultQUICTimeout = 30 * time.Second ) var _ ReadWriteAcker = (*streamReadWriteAcker)(nil) @@ -197,7 +198,7 @@ func quicServer( quicStream, err := session.OpenStreamSync(context.Background()) require.NoError(t, err) - stream := quicpogs.NewSafeStreamCloser(quicStream) + stream := quicpogs.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log) reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream} err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...) @@ -726,6 +727,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU &log, nil, 5*time.Second, + 0*time.Second, ) require.NoError(t, err) return qc diff --git a/ingress/constants_test.go b/ingress/constants_test.go new file mode 100644 index 00000000..e2f62f0e --- /dev/null +++ b/ingress/constants_test.go @@ -0,0 +1,7 @@ +package ingress + +import "github.com/cloudflare/cloudflared/logger" + +var ( + TestLogger = logger.Create(nil) +) diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index fbc4df39..f7e08004 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "time" "github.com/rs/zerolog" @@ -31,15 +32,32 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze // tcpConnection is an OriginConnection that directly streams to raw TCP. type tcpConnection struct { - conn net.Conn + net.Conn + writeTimeout time.Duration + logger *zerolog.Logger } -func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { - stream.Pipe(tunnelConn, tc.conn, log) +func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) { + stream.Pipe(tunnelConn, tc, tc.logger) +} + +func (tc *tcpConnection) Write(b []byte) (int, error) { + if tc.writeTimeout > 0 { + if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil { + tc.logger.Err(err).Msg("Error setting write deadline for TCP connection") + } + } + + nBytes, err := tc.Conn.Write(b) + if err != nil { + tc.logger.Err(err).Msg("Error writing to the TCP connection") + } + + return nBytes, err } func (tc *tcpConnection) Close() { - tc.conn.Close() + tc.Conn.Close() } // tcpOverWSConnection is an OriginConnection that streams to TCP over WS. diff --git a/ingress/origin_connection_test.go b/ingress/origin_connection_test.go index e47d38ce..40611de7 100644 --- a/ingress/origin_connection_test.go +++ b/ingress/origin_connection_test.go @@ -19,7 +19,6 @@ import ( "golang.org/x/net/proxy" "golang.org/x/sync/errgroup" - "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/stream" "github.com/cloudflare/cloudflared/websocket" @@ -31,7 +30,6 @@ const ( ) var ( - testLogger = logger.Create(nil) testMessage = []byte("TestStreamOriginConnection") testResponse = []byte(fmt.Sprintf("echo-%s", testMessage)) ) @@ -39,7 +37,8 @@ var ( func TestStreamTCPConnection(t *testing.T) { cfdConn, originConn := net.Pipe() tcpConn := tcpConnection{ - conn: cfdConn, + Conn: cfdConn, + writeTimeout: 30 * time.Second, } eyeballConn, edgeConn := net.Pipe() @@ -66,7 +65,7 @@ func TestStreamTCPConnection(t *testing.T) { return nil }) - tcpConn.Stream(ctx, edgeConn, testLogger) + tcpConn.Stream(ctx, edgeConn, TestLogger) require.NoError(t, errGroup.Wait()) } @@ -93,7 +92,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) { return nil }) - tcpOverWSConn.Stream(ctx, edgeConn, testLogger) + tcpOverWSConn.Stream(ctx, edgeConn, TestLogger) require.NoError(t, errGroup.Wait()) } @@ -147,7 +146,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) { errGroup, ctx := errgroup.WithContext(ctx) errGroup.Go(func() error { - tcpOverWSConn.Stream(ctx, edgeConn, testLogger) + tcpOverWSConn.Stream(ctx, edgeConn, TestLogger) return nil }) @@ -159,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) { require.NoError(t, err) defer wsForwarderInConn.Close() - stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger) + stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger) return nil }) @@ -209,7 +208,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) { originConn.Close() }() ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond) - tcpOverWSConn.Stream(ctx, eyeballConn, testLogger) + tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger) }) server := httptest.NewServer(handler) defer server.Close() diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 40a87811..186ddff1 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net/http" + + "github.com/rs/zerolog" ) // HTTPOriginProxy can be implemented by origin services that want to proxy http requests. @@ -14,7 +16,7 @@ type HTTPOriginProxy interface { // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP. type StreamBasedOriginProxy interface { - EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) + EstablishConnection(ctx context.Context, dest string, log *zerolog.Logger) (OriginConnection, error) } // HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests. @@ -62,19 +64,21 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { return resp, nil } -func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) { +func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) { conn, err := o.dialer.DialContext(ctx, "tcp", dest) if err != nil { return nil, err } originConn := &tcpConnection{ - conn: conn, + Conn: conn, + writeTimeout: o.writeTimeout, + logger: logger, } return originConn, nil } -func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) { +func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, _ *zerolog.Logger) (OriginConnection, error) { var err error if !o.isBastion { dest = o.dest @@ -92,6 +96,6 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) } -func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) { +func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) { return o.conn, nil } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index fb84c837..7a6170a2 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) { require.NoError(t, err) // Origin not listening for new connection, should return an error - _, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String()) + _, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger) require.Error(t, err) } @@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { t.Run(test.testCase, func(t *testing.T) { if test.expectErr { bastionHost, _ := carrier.ResolveBastionDest(test.req) - _, err := test.service.EstablishConnection(context.Background(), bastionHost) + _, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger) assert.Error(t, err) } }) @@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { // Origin not listening for new connection, should return an error bastionHost, _ := carrier.ResolveBastionDest(bastionReq) - _, err := service.EstablishConnection(context.Background(), bastionHost) + _, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger) assert.Error(t, err) } } @@ -132,7 +132,7 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { url: originURL, } shutdownC := make(chan struct{}) - require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) + require.NoError(t, httpService.start(TestLogger, shutdownC, cfg)) req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) require.NoError(t, err) @@ -167,7 +167,7 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) { url: originURL, } shutdownC := make(chan struct{}) - require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) + require.NoError(t, httpService.start(TestLogger, shutdownC, cfg)) // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header protos := []string{"https", "http", "dne"} diff --git a/ingress/origin_service.go b/ingress/origin_service.go index f7bcc297..ec7ccb48 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -94,15 +94,17 @@ func (o httpService) MarshalJSON() ([]byte, error) { // rawTCPService dials TCP to the destination specified by the client // It's used by warp routing type rawTCPService struct { - name string - dialer net.Dialer + name string + dialer net.Dialer + writeTimeout time.Duration + logger *zerolog.Logger } func (o *rawTCPService) String() string { return o.name } -func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { +func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error { return nil } @@ -285,13 +287,14 @@ type WarpRoutingService struct { Proxy StreamBasedOriginProxy } -func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService { +func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService { svc := &rawTCPService{ name: ServiceWarpRouting, dialer: net.Dialer{ Timeout: config.ConnectTimeout.Duration, KeepAlive: config.TCPKeepAlive.Duration, }, + writeTimeout: writeTimeout, } return &WarpRoutingService{Proxy: svc} diff --git a/orchestration/config.go b/orchestration/config.go index 26904b57..04c7a0ab 100644 --- a/orchestration/config.go +++ b/orchestration/config.go @@ -2,6 +2,7 @@ package orchestration import ( "encoding/json" + "time" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/ingress" @@ -19,8 +20,9 @@ type newLocalConfig struct { // Config is the original config as read and parsed by cloudflared. type Config struct { - Ingress *ingress.Ingress - WarpRouting ingress.WarpRoutingConfig + Ingress *ingress.Ingress + WarpRouting ingress.WarpRoutingConfig + WriteTimeout time.Duration // Extra settings used to configure this instance but that are not eligible for remotely management // ie. (--protocol, --loglevel, ...) diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go index 3ff763a2..93afef31 100644 --- a/orchestration/orchestrator.go +++ b/orchestration/orchestrator.go @@ -17,10 +17,10 @@ import ( tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) -// Orchestrator manages configurations so they can be updatable during runtime +// Orchestrator manages configurations, so they can be updatable during runtime // properties are static, so it can be read without lock // currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex -// access to proxy is synchronized with atmoic.Value, because it uses copy-on-write to provide scalable frequently +// access to proxy is synchronized with atomic.Value, because it uses copy-on-write to provide scalable frequently // read when update is infrequent type Orchestrator struct { currentVersion int32 @@ -30,9 +30,10 @@ type Orchestrator struct { proxy atomic.Value // Set of internal ingress rules defined at cloudflared startup (separate from user-defined ingress rules) internalRules []ingress.Rule - config *Config - tags []tunnelpogs.Tag - log *zerolog.Logger + // cloudflared Configuration + config *Config + tags []tunnelpogs.Tag + log *zerolog.Logger // orchestrator must not handle any more updates after shutdownC is closed shutdownC <-chan struct{} @@ -40,7 +41,11 @@ type Orchestrator struct { proxyShutdownC chan<- struct{} } -func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, internalRules []ingress.Rule, log *zerolog.Logger) (*Orchestrator, error) { +func NewOrchestrator(ctx context.Context, + config *Config, + tags []tunnelpogs.Tag, + internalRules []ingress.Rule, + log *zerolog.Logger) (*Orchestrator, error) { o := &Orchestrator{ // Lowest possible version, any remote configuration will have version higher than this // Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it @@ -131,7 +136,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil { return errors.Wrap(err, "failed to start origin") } - proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.log) + proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log) o.proxy.Store(proxy) o.config.Ingress = &ingressRules o.config.WarpRouting = warpRouting diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index 0cfd3176..233367c9 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -27,7 +27,7 @@ import ( ) var ( - testLogger = zerolog.Logger{} + testLogger = zerolog.Nop() testTags = []tunnelpogs.Tag{ { Name: "package", diff --git a/proxy/proxy.go b/proxy/proxy.go index 73f4376c..fea2e3ee 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -51,6 +51,7 @@ func NewOriginProxy( ingressRules ingress.Ingress, warpRouting ingress.WarpRoutingConfig, tags []tunnelpogs.Tag, + writeTimeout time.Duration, log *zerolog.Logger, ) *Proxy { proxy := &Proxy{ @@ -59,7 +60,7 @@ func NewOriginProxy( log: log, } - proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting) + proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout) return proxy } @@ -309,7 +310,7 @@ func (p *Proxy) proxyStream( _, connectSpan := tr.Tracer().Start(ctx, "stream-connect") start := time.Now() - originConn, err := connectionProxy.EstablishConnection(ctx, dest) + originConn, err := connectionProxy.EstablishConnection(ctx, dest, &logger) if err != nil { connectStreamErrors.Inc() tracing.EndWithErrorStatus(connectSpan, err) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 49a48f16..ed7a1439 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) { require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, &log) + proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log) t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) @@ -366,7 +366,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat ctx, cancel := context.WithCancel(context.Background()) require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingress, noWarpRouting, testTags, &log) + proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log) for _, test := range tests { responseWriter := newMockHTTPRespWriter() @@ -414,7 +414,7 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ing, noWarpRouting, testTags, &log) + proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) @@ -530,7 +530,7 @@ func TestConnections(t *testing.T) { originService: runEchoTCPService, eyeballResponseWriter: newTCPRespWriter(replayer), eyeballRequestBody: newTCPRequestBody([]byte("test2")), - warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting), + warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), connectionType: connection.TypeTCP, requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, @@ -548,7 +548,7 @@ func TestConnections(t *testing.T) { originService: runEchoWSService, // eyeballResponseWriter gets set after roundtrip dial. eyeballRequestBody: newPipedWSRequestBody([]byte("test3")), - warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting), + warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, }, @@ -675,7 +675,7 @@ func TestConnections(t *testing.T) { ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) ingressRule.StartOrigins(logger, ctx.Done()) - proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, logger) + proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger) proxy.warpRouting = test.args.warpRoutingService dest := ln.Addr().String() diff --git a/quic/safe_stream.go b/quic/safe_stream.go index 6fa98259..1351a0c3 100644 --- a/quic/safe_stream.go +++ b/quic/safe_stream.go @@ -1,20 +1,28 @@ package quic import ( + "errors" + "net" "sync" "time" "github.com/quic-go/quic-go" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" ) type SafeStreamCloser struct { - lock sync.Mutex - stream quic.Stream + lock sync.Mutex + stream quic.Stream + writeTimeout time.Duration + log *zerolog.Logger } -func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser { +func NewSafeStreamCloser(stream quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser { return &SafeStreamCloser{ - stream: stream, + stream: stream, + writeTimeout: writeTimeout, + log: log, } } @@ -25,7 +33,29 @@ func (s *SafeStreamCloser) Read(p []byte) (n int, err error) { func (s *SafeStreamCloser) Write(p []byte) (n int, err error) { s.lock.Lock() defer s.lock.Unlock() - return s.stream.Write(p) + if s.writeTimeout > 0 { + err = s.stream.SetWriteDeadline(time.Now().Add(s.writeTimeout)) + if err != nil { + log.Err(err).Msg("Error setting write deadline for QUIC stream") + } + } + nBytes, err := s.stream.Write(p) + if err != nil { + s.handleTimeout(err) + } + + return nBytes, err +} + +// Handles the timeout error in case it happened, by canceling the stream write. +func (s *SafeStreamCloser) handleTimeout(err error) { + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + s.log.Error().Err(netErr).Msg("Closing quic stream due to timeout while writing") + s.stream.CancelWrite(0) + } + } } func (s *SafeStreamCloser) Close() error { diff --git a/quic/safe_stream_test.go b/quic/safe_stream_test.go index 7cfecc84..bae708f3 100644 --- a/quic/safe_stream_test.go +++ b/quic/safe_stream_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/rs/zerolog" + "github.com/quic-go/quic-go" "github.com/stretchr/testify/require" ) @@ -70,7 +72,8 @@ func quicClient(t *testing.T, addr net.Addr) { go func(iter int) { defer wg.Done() - stream := NewSafeStreamCloser(quicStream) + log := zerolog.Nop() + stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log) defer stream.Close() // Do a bunch of round trips over this stream that should work. @@ -107,7 +110,8 @@ func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn) go func(iter int) { defer wg.Done() - stream := NewSafeStreamCloser(quicStream) + log := zerolog.Nop() + stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log) defer stream.Close() // Do a bunch of round trips over this stream that should work. diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 8e9afbd3..a64486cf 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -66,6 +66,7 @@ type TunnelConfig struct { PacketConfig *ingress.GlobalRouterConfig UDPUnregisterSessionTimeout time.Duration + WriteStreamTimeout time.Duration DisableQUICPathMTUDiscovery bool @@ -614,6 +615,7 @@ func (e *EdgeTunnelServer) serveQUIC( connLogger.Logger(), e.config.PacketConfig, e.config.UDPUnregisterSessionTimeout, + e.config.WriteStreamTimeout, ) if err != nil { connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")