From c6b4bac76f0f4f14a2a11b0d92518b6f9abe968f Mon Sep 17 00:00:00 2001 From: Sudarsan Reddy Date: Mon, 19 Jun 2023 17:03:11 +0100 Subject: [PATCH] TUN-7480: Added a timeout for unregisterUDP. I deliberately kept this as an unregistertimeout because that was the intent. In the future we could change this to a UDPConnConfig if we want to pass multiple values here. The idea of this PR is simply to add a configurable unregister UDP timeout. --- cmd/cloudflared/tunnel/cmd.go | 8 ++++ cmd/cloudflared/tunnel/configuration.go | 17 +++---- connection/quic.go | 6 ++- connection/quic_test.go | 1 + quic/quic_protocol.go | 12 +++-- quic/quic_protocol_test.go | 61 ++++++++++++++++++++++++- supervisor/tunnel.go | 6 ++- 7 files changed, 96 insertions(+), 15 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 263283ef..2a32dbac 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -79,6 +79,9 @@ const ( // hostKeyPath is the path of the dir to save SSH host keys too hostKeyPath = "host-key-path" + // udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge + udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout" + // uiFlag is to enable launching cloudflared in interactive UI mode uiFlag = "ui" @@ -683,6 +686,11 @@ func tunnelFlags(shouldHide bool) []cli.Flag { Value: 4, Hidden: true, }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: udpUnregisterSessionTimeoutFlag, + Value: 5 * time.Second, + Hidden: true, + }), altsrc.NewStringFlag(&cli.StringFlag{ Name: connectorLabelFlag, Usage: "Use this option to give a meaningful label to a specific connector. When a tunnel starts up, a connector id unique to the tunnel is generated. This is a uuid. To make it easier to identify a connector, we will use the hostname of the machine the tunnel is running on along with the connector ID. This option exists if one wants to have more control over what their individual connectors are called.", diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 89b76392..56e5c12c 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -231,14 +231,15 @@ func prepareTunnelConfig( Observer: observer, ReportedVersion: info.Version(), // Note TUN-3758 , we use Int because UInt is not supported with altsrc - Retries: uint(c.Int("retries")), - RunFromTerminal: isRunningFromTerminal(), - NamedTunnel: namedTunnel, - ProtocolSelector: protocolSelector, - EdgeTLSConfigs: edgeTLSConfigs, - NeedPQ: needPQ, - PQKexIdx: pqKexIdx, - MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")), + Retries: uint(c.Int("retries")), + RunFromTerminal: isRunningFromTerminal(), + NamedTunnel: namedTunnel, + ProtocolSelector: protocolSelector, + EdgeTLSConfigs: edgeTLSConfigs, + NeedPQ: needPQ, + PQKexIdx: pqKexIdx, + MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")), + UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag), } packetConfig, err := newPacketConfig(c, log) if err != nil { diff --git a/connection/quic.go b/connection/quic.go index 3861fced..c1ccf2f0 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -63,6 +63,8 @@ type QUICConnection struct { controlStreamHandler ControlStreamHandler connOptions *tunnelpogs.ConnectionOptions connIndex uint8 + + udpUnregisterTimeout time.Duration } // NewQUICConnection returns a new instance of QUICConnection. @@ -78,6 +80,7 @@ func NewQUICConnection( controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, packetRouterConfig *ingress.GlobalRouterConfig, + udpUnregisterTimeout time.Duration, ) (*QUICConnection, error) { udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger) if err != nil { @@ -112,6 +115,7 @@ func NewQUICConnection( controlStreamHandler: controlStreamHandler, connOptions: connOptions, connIndex: connIndex, + udpUnregisterTimeout: udpUnregisterTimeout, }, nil } @@ -370,7 +374,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI stream := quicpogs.NewSafeStreamCloser(quicStream) defer stream.Close() - rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.logger) + rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger) if err != nil { // Log this at debug because this is not an error if session was closed due to lost connection // with edge diff --git a/connection/quic_test.go b/connection/quic_test.go index ce8bd371..79f3cd73 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -725,6 +725,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU fakeControlStream{}, &log, nil, + 5*time.Second, ) require.NoError(t, err) return qc diff --git a/quic/quic_protocol.go b/quic/quic_protocol.go index 7b23fcbb..e9cf47cd 100644 --- a/quic/quic_protocol.go +++ b/quic/quic_protocol.go @@ -229,9 +229,12 @@ func writeSignature(stream io.Writer, signature ProtocolSignature) error { type RPCClientStream struct { client tunnelpogs.CloudflaredServer_PogsClient transport rpc.Transport + + // Time we wait for the server to respond to a request before we close the connection. + rpcUnregisterUDPSessionDeadline time.Duration } -func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *zerolog.Logger) (*RPCClientStream, error) { +func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, rpcUnregisterUDPSessionDeadline time.Duration, logger *zerolog.Logger) (*RPCClientStream, error) { n, err := stream.Write(RPCStreamProtocolSignature[:]) if err != nil { return nil, err @@ -245,8 +248,9 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger * tunnelrpc.ConnLog(logger), ) return &RPCClientStream{ - client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn), - transport: transport, + client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn), + transport: transport, + rpcUnregisterUDPSessionDeadline: rpcUnregisterUDPSessionDeadline, }, nil } @@ -255,6 +259,8 @@ func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uu } func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { + ctx, cancel := context.WithTimeout(ctx, rcs.rpcUnregisterUDPSessionDeadline) + defer cancel() return rcs.client.UnregisterUdpSession(ctx, sessionID, message) } diff --git a/quic/quic_protocol_test.go b/quic/quic_protocol_test.go index 64943298..ab6d7c1e 100644 --- a/quic/quic_protocol_test.go +++ b/quic/quic_protocol_test.go @@ -109,6 +109,63 @@ func TestConnectResponseMeta(t *testing.T) { } } +func TestUnregisterUdpSession(t *testing.T) { + unregisterMessage := "closed by eyeball" + + var tests = []struct { + name string + sessionRPCServer mockSessionRPCServer + timeout time.Duration + }{ + + { + name: "UnregisterUdpSessionTimesout if the RPC server does not respond", + sessionRPCServer: mockSessionRPCServer{ + sessionID: uuid.New(), + dstIP: net.IP{172, 16, 0, 1}, + dstPort: 8000, + closeIdleAfter: testCloseIdleAfterHint, + unregisterMessage: unregisterMessage, + traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1", + }, + // very very low value so we trigger the timeout every time. + timeout: time.Nanosecond * 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + logger := zerolog.Nop() + clientStream, serverStream := newMockRPCStreams() + sessionRegisteredChan := make(chan struct{}) + go func() { + protocol, err := DetermineProtocol(serverStream) + assert.NoError(t, err) + rpcServerStream, err := NewRPCServerStream(serverStream, protocol) + assert.NoError(t, err) + err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger) + assert.NoError(t, err) + + serverStream.Close() + close(sessionRegisteredChan) + }() + + rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, test.timeout, &logger) + assert.NoError(t, err) + + reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext) + assert.NoError(t, err) + assert.NoError(t, reg.Err) + + assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage)) + + rpcClientStream.Close() + <-sessionRegisteredChan + }) + } + +} + func TestRegisterUdpSession(t *testing.T) { unregisterMessage := "closed by eyeball" @@ -157,7 +214,7 @@ func TestRegisterUdpSession(t *testing.T) { close(sessionRegisteredChan) }() - rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger) + rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, 5*time.Second, &logger) assert.NoError(t, err) reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext) @@ -208,7 +265,7 @@ func TestManageConfiguration(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - rpcClientStream, err := NewRPCClientStream(ctx, clientStream, &logger) + rpcClientStream, err := NewRPCClientStream(ctx, clientStream, 5*time.Second, &logger) assert.NoError(t, err) result, err := rpcClientStream.UpdateConfiguration(ctx, version, config) diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 24f909ef..9052398d 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -68,6 +68,8 @@ type TunnelConfig struct { ProtocolSelector connection.ProtocolSelector EdgeTLSConfigs map[connection.Protocol]*tls.Config PacketConfig *ingress.GlobalRouterConfig + + UDPUnregisterSessionTimeout time.Duration } func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { @@ -615,7 +617,9 @@ func (e *EdgeTunnelServer) serveQUIC( connOptions, controlStreamHandler, connLogger.Logger(), - e.config.PacketConfig) + e.config.PacketConfig, + e.config.UDPUnregisterSessionTimeout, + ) if err != nil { if e.config.NeedPQ { handlePQTunnelError(err, e.config)