TUN-7480: Added a timeout for unregisterUDP.

I deliberately kept this as an unregistertimeout because that was the
intent. In the future we could change this to a UDPConnConfig if we want
to pass multiple values here.

The idea of this PR is simply to add a configurable unregister UDP
timeout.
This commit is contained in:
Sudarsan Reddy 2023-06-19 17:03:11 +01:00 committed by Jean Khawand
parent 136f232c00
commit c6b4bac76f
7 changed files with 96 additions and 15 deletions

View File

@ -79,6 +79,9 @@ const (
// hostKeyPath is the path of the dir to save SSH host keys too
hostKeyPath = "host-key-path"
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
// uiFlag is to enable launching cloudflared in interactive UI mode
uiFlag = "ui"
@ -683,6 +686,11 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: 4,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: udpUnregisterSessionTimeoutFlag,
Value: 5 * time.Second,
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: connectorLabelFlag,
Usage: "Use this option to give a meaningful label to a specific connector. When a tunnel starts up, a connector id unique to the tunnel is generated. This is a uuid. To make it easier to identify a connector, we will use the hostname of the machine the tunnel is running on along with the connector ID. This option exists if one wants to have more control over what their individual connectors are called.",

View File

@ -231,14 +231,15 @@ func prepareTunnelConfig(
Observer: observer,
ReportedVersion: info.Version(),
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
Retries: uint(c.Int("retries")),
RunFromTerminal: isRunningFromTerminal(),
NamedTunnel: namedTunnel,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
NeedPQ: needPQ,
PQKexIdx: pqKexIdx,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
Retries: uint(c.Int("retries")),
RunFromTerminal: isRunningFromTerminal(),
NamedTunnel: namedTunnel,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
NeedPQ: needPQ,
PQKexIdx: pqKexIdx,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
}
packetConfig, err := newPacketConfig(c, log)
if err != nil {

View File

@ -63,6 +63,8 @@ type QUICConnection struct {
controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions
connIndex uint8
udpUnregisterTimeout time.Duration
}
// NewQUICConnection returns a new instance of QUICConnection.
@ -78,6 +80,7 @@ func NewQUICConnection(
controlStreamHandler ControlStreamHandler,
logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig,
udpUnregisterTimeout time.Duration,
) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil {
@ -112,6 +115,7 @@ func NewQUICConnection(
controlStreamHandler: controlStreamHandler,
connOptions: connOptions,
connIndex: connIndex,
udpUnregisterTimeout: udpUnregisterTimeout,
}, nil
}
@ -370,7 +374,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
stream := quicpogs.NewSafeStreamCloser(quicStream)
defer stream.Close()
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.logger)
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection
// with edge

View File

@ -725,6 +725,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
fakeControlStream{},
&log,
nil,
5*time.Second,
)
require.NoError(t, err)
return qc

View File

@ -229,9 +229,12 @@ func writeSignature(stream io.Writer, signature ProtocolSignature) error {
type RPCClientStream struct {
client tunnelpogs.CloudflaredServer_PogsClient
transport rpc.Transport
// Time we wait for the server to respond to a request before we close the connection.
rpcUnregisterUDPSessionDeadline time.Duration
}
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *zerolog.Logger) (*RPCClientStream, error) {
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, rpcUnregisterUDPSessionDeadline time.Duration, logger *zerolog.Logger) (*RPCClientStream, error) {
n, err := stream.Write(RPCStreamProtocolSignature[:])
if err != nil {
return nil, err
@ -245,8 +248,9 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
tunnelrpc.ConnLog(logger),
)
return &RPCClientStream{
client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn),
transport: transport,
client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn),
transport: transport,
rpcUnregisterUDPSessionDeadline: rpcUnregisterUDPSessionDeadline,
}, nil
}
@ -255,6 +259,8 @@ func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uu
}
func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
ctx, cancel := context.WithTimeout(ctx, rcs.rpcUnregisterUDPSessionDeadline)
defer cancel()
return rcs.client.UnregisterUdpSession(ctx, sessionID, message)
}

View File

@ -109,6 +109,63 @@ func TestConnectResponseMeta(t *testing.T) {
}
}
func TestUnregisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball"
var tests = []struct {
name string
sessionRPCServer mockSessionRPCServer
timeout time.Duration
}{
{
name: "UnregisterUdpSessionTimesout if the RPC server does not respond",
sessionRPCServer: mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
unregisterMessage: unregisterMessage,
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
},
// very very low value so we trigger the timeout every time.
timeout: time.Nanosecond * 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
logger := zerolog.Nop()
clientStream, serverStream := newMockRPCStreams()
sessionRegisteredChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger)
assert.NoError(t, err)
serverStream.Close()
close(sessionRegisteredChan)
}()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, test.timeout, &logger)
assert.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.NoError(t, reg.Err)
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
})
}
}
func TestRegisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball"
@ -157,7 +214,7 @@ func TestRegisterUdpSession(t *testing.T) {
close(sessionRegisteredChan)
}()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, 5*time.Second, &logger)
assert.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
@ -208,7 +265,7 @@ func TestManageConfiguration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
rpcClientStream, err := NewRPCClientStream(ctx, clientStream, &logger)
rpcClientStream, err := NewRPCClientStream(ctx, clientStream, 5*time.Second, &logger)
assert.NoError(t, err)
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)

View File

@ -68,6 +68,8 @@ type TunnelConfig struct {
ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config
PacketConfig *ingress.GlobalRouterConfig
UDPUnregisterSessionTimeout time.Duration
}
func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
@ -615,7 +617,9 @@ func (e *EdgeTunnelServer) serveQUIC(
connOptions,
controlStreamHandler,
connLogger.Logger(),
e.config.PacketConfig)
e.config.PacketConfig,
e.config.UDPUnregisterSessionTimeout,
)
if err != nil {
if e.config.NeedPQ {
handlePQTunnelError(err, e.config)