TUN-7480: Added a timeout for unregisterUDP.

I deliberately kept this as an unregistertimeout because that was the
intent. In the future we could change this to a UDPConnConfig if we want
to pass multiple values here.

The idea of this PR is simply to add a configurable unregister UDP
timeout.
This commit is contained in:
Sudarsan Reddy 2023-06-19 17:03:11 +01:00
parent a3bcf25fae
commit 1abd22ef0a
7 changed files with 96 additions and 15 deletions

View File

@ -79,6 +79,9 @@ const (
// hostKeyPath is the path of the dir to save SSH host keys too // hostKeyPath is the path of the dir to save SSH host keys too
hostKeyPath = "host-key-path" hostKeyPath = "host-key-path"
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
// uiFlag is to enable launching cloudflared in interactive UI mode // uiFlag is to enable launching cloudflared in interactive UI mode
uiFlag = "ui" uiFlag = "ui"
@ -683,6 +686,11 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: 4, Value: 4,
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: udpUnregisterSessionTimeoutFlag,
Value: 5 * time.Second,
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: connectorLabelFlag, Name: connectorLabelFlag,
Usage: "Use this option to give a meaningful label to a specific connector. When a tunnel starts up, a connector id unique to the tunnel is generated. This is a uuid. To make it easier to identify a connector, we will use the hostname of the machine the tunnel is running on along with the connector ID. This option exists if one wants to have more control over what their individual connectors are called.", Usage: "Use this option to give a meaningful label to a specific connector. When a tunnel starts up, a connector id unique to the tunnel is generated. This is a uuid. To make it easier to identify a connector, we will use the hostname of the machine the tunnel is running on along with the connector ID. This option exists if one wants to have more control over what their individual connectors are called.",

View File

@ -231,14 +231,15 @@ func prepareTunnelConfig(
Observer: observer, Observer: observer,
ReportedVersion: info.Version(), ReportedVersion: info.Version(),
// Note TUN-3758 , we use Int because UInt is not supported with altsrc // Note TUN-3758 , we use Int because UInt is not supported with altsrc
Retries: uint(c.Int("retries")), Retries: uint(c.Int("retries")),
RunFromTerminal: isRunningFromTerminal(), RunFromTerminal: isRunningFromTerminal(),
NamedTunnel: namedTunnel, NamedTunnel: namedTunnel,
ProtocolSelector: protocolSelector, ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs, EdgeTLSConfigs: edgeTLSConfigs,
NeedPQ: needPQ, NeedPQ: needPQ,
PQKexIdx: pqKexIdx, PQKexIdx: pqKexIdx,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")), MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
} }
packetConfig, err := newPacketConfig(c, log) packetConfig, err := newPacketConfig(c, log)
if err != nil { if err != nil {

View File

@ -63,6 +63,8 @@ type QUICConnection struct {
controlStreamHandler ControlStreamHandler controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
connIndex uint8 connIndex uint8
udpUnregisterTimeout time.Duration
} }
// NewQUICConnection returns a new instance of QUICConnection. // NewQUICConnection returns a new instance of QUICConnection.
@ -78,6 +80,7 @@ func NewQUICConnection(
controlStreamHandler ControlStreamHandler, controlStreamHandler ControlStreamHandler,
logger *zerolog.Logger, logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig, packetRouterConfig *ingress.GlobalRouterConfig,
udpUnregisterTimeout time.Duration,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger) udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil { if err != nil {
@ -112,6 +115,7 @@ func NewQUICConnection(
controlStreamHandler: controlStreamHandler, controlStreamHandler: controlStreamHandler,
connOptions: connOptions, connOptions: connOptions,
connIndex: connIndex, connIndex: connIndex,
udpUnregisterTimeout: udpUnregisterTimeout,
}, nil }, nil
} }
@ -370,7 +374,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
stream := quicpogs.NewSafeStreamCloser(quicStream) stream := quicpogs.NewSafeStreamCloser(quicStream)
defer stream.Close() defer stream.Close()
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.logger) rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
if err != nil { if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection // Log this at debug because this is not an error if session was closed due to lost connection
// with edge // with edge

View File

@ -725,6 +725,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
fakeControlStream{}, fakeControlStream{},
&log, &log,
nil, nil,
5*time.Second,
) )
require.NoError(t, err) require.NoError(t, err)
return qc return qc

View File

@ -229,9 +229,12 @@ func writeSignature(stream io.Writer, signature ProtocolSignature) error {
type RPCClientStream struct { type RPCClientStream struct {
client tunnelpogs.CloudflaredServer_PogsClient client tunnelpogs.CloudflaredServer_PogsClient
transport rpc.Transport transport rpc.Transport
// Time we wait for the server to respond to a request before we close the connection.
rpcUnregisterUDPSessionDeadline time.Duration
} }
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *zerolog.Logger) (*RPCClientStream, error) { func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, rpcUnregisterUDPSessionDeadline time.Duration, logger *zerolog.Logger) (*RPCClientStream, error) {
n, err := stream.Write(RPCStreamProtocolSignature[:]) n, err := stream.Write(RPCStreamProtocolSignature[:])
if err != nil { if err != nil {
return nil, err return nil, err
@ -245,8 +248,9 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
tunnelrpc.ConnLog(logger), tunnelrpc.ConnLog(logger),
) )
return &RPCClientStream{ return &RPCClientStream{
client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn), client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn),
transport: transport, transport: transport,
rpcUnregisterUDPSessionDeadline: rpcUnregisterUDPSessionDeadline,
}, nil }, nil
} }
@ -255,6 +259,8 @@ func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uu
} }
func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
ctx, cancel := context.WithTimeout(ctx, rcs.rpcUnregisterUDPSessionDeadline)
defer cancel()
return rcs.client.UnregisterUdpSession(ctx, sessionID, message) return rcs.client.UnregisterUdpSession(ctx, sessionID, message)
} }

View File

@ -109,6 +109,63 @@ func TestConnectResponseMeta(t *testing.T) {
} }
} }
func TestUnregisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball"
var tests = []struct {
name string
sessionRPCServer mockSessionRPCServer
timeout time.Duration
}{
{
name: "UnregisterUdpSessionTimesout if the RPC server does not respond",
sessionRPCServer: mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
unregisterMessage: unregisterMessage,
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
},
// very very low value so we trigger the timeout every time.
timeout: time.Nanosecond * 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
logger := zerolog.Nop()
clientStream, serverStream := newMockRPCStreams()
sessionRegisteredChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger)
assert.NoError(t, err)
serverStream.Close()
close(sessionRegisteredChan)
}()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, test.timeout, &logger)
assert.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.NoError(t, reg.Err)
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
})
}
}
func TestRegisterUdpSession(t *testing.T) { func TestRegisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball" unregisterMessage := "closed by eyeball"
@ -157,7 +214,7 @@ func TestRegisterUdpSession(t *testing.T) {
close(sessionRegisteredChan) close(sessionRegisteredChan)
}() }()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger) rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, 5*time.Second, &logger)
assert.NoError(t, err) assert.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext) reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
@ -208,7 +265,7 @@ func TestManageConfiguration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
rpcClientStream, err := NewRPCClientStream(ctx, clientStream, &logger) rpcClientStream, err := NewRPCClientStream(ctx, clientStream, 5*time.Second, &logger)
assert.NoError(t, err) assert.NoError(t, err)
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config) result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)

View File

@ -68,6 +68,8 @@ type TunnelConfig struct {
ProtocolSelector connection.ProtocolSelector ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config EdgeTLSConfigs map[connection.Protocol]*tls.Config
PacketConfig *ingress.GlobalRouterConfig PacketConfig *ingress.GlobalRouterConfig
UDPUnregisterSessionTimeout time.Duration
} }
func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
@ -615,7 +617,9 @@ func (e *EdgeTunnelServer) serveQUIC(
connOptions, connOptions,
controlStreamHandler, controlStreamHandler,
connLogger.Logger(), connLogger.Logger(),
e.config.PacketConfig) e.config.PacketConfig,
e.config.UDPUnregisterSessionTimeout,
)
if err != nil { if err != nil {
if e.config.NeedPQ { if e.config.NeedPQ {
handlePQTunnelError(err, e.config) handlePQTunnelError(err, e.config)