From 654a326098290a707ff9bb1ca37270b45a242a38 Mon Sep 17 00:00:00 2001 From: Devin Carr Date: Fri, 24 May 2024 11:40:10 -0700 Subject: [PATCH] TUN-8424: Refactor capnp registration server Move RegistrationServer and RegistrationClient into tunnelrpc module to properly abstract out the capnp aspects internal to the module only. --- cmd/cloudflared/tunnel/cmd.go | 2 +- cmd/cloudflared/tunnel/configuration.go | 4 +- cmd/cloudflared/tunnel/quick_tunnel.go | 2 +- cmd/cloudflared/tunnel/subcommand_context.go | 2 +- connection/connection.go | 2 +- connection/control.go | 80 +++++++------ connection/http2.go | 3 - connection/http2_test.go | 32 +++--- connection/rpc.go | 111 ------------------- supervisor/tunnel.go | 3 +- tunnelrpc/pogs/registration_server.go | 7 ++ tunnelrpc/registration_client.go | 76 +++++++++++++ tunnelrpc/registration_server.go | 40 +++++++ 13 files changed, 197 insertions(+), 167 deletions(-) delete mode 100644 connection/rpc.go create mode 100644 tunnelrpc/registration_client.go create mode 100644 tunnelrpc/registration_server.go diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 9b2cd834..07c171b6 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -287,7 +287,7 @@ func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) { func StartServer( c *cli.Context, info *cliutil.BuildInfo, - namedTunnel *connection.NamedTunnelProperties, + namedTunnel *connection.TunnelProperties, log *zerolog.Logger, ) error { err := sentry.Init(sentry.ClientOptions{ diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index db9558e3..01833e01 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -108,7 +108,7 @@ func isSecretEnvVar(key string) bool { return false } -func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool { +func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.TunnelProperties) bool { return c.IsSet("proxy-dns") && !(c.IsSet("name") || // adhoc-named tunnel c.IsSet(ingress.HelloWorldFlag) || // quick or named tunnel @@ -121,7 +121,7 @@ func prepareTunnelConfig( info *cliutil.BuildInfo, log, logTransport *zerolog.Logger, observer *connection.Observer, - namedTunnel *connection.NamedTunnelProperties, + namedTunnel *connection.TunnelProperties, ) (*supervisor.TunnelConfig, *orchestration.Config, error) { clientID, err := uuid.NewRandom() if err != nil { diff --git a/cmd/cloudflared/tunnel/quick_tunnel.go b/cmd/cloudflared/tunnel/quick_tunnel.go index da7d0a63..dc8e8707 100644 --- a/cmd/cloudflared/tunnel/quick_tunnel.go +++ b/cmd/cloudflared/tunnel/quick_tunnel.go @@ -79,7 +79,7 @@ func RunQuickTunnel(sc *subcommandContext) error { return StartServer( sc.c, buildInfo, - &connection.NamedTunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, + &connection.TunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, sc.log, ) } diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index c0bddd9d..83332b51 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -261,7 +261,7 @@ func (sc *subcommandContext) runWithCredentials(credentials connection.Credentia return StartServer( sc.c, buildInfo, - &connection.NamedTunnelProperties{Credentials: credentials}, + &connection.TunnelProperties{Credentials: credentials}, sc.log, ) } diff --git a/connection/connection.go b/connection/connection.go index b24ef4ea..50464e4a 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -42,7 +42,7 @@ type Orchestrator interface { GetOriginProxy() (OriginProxy, error) } -type NamedTunnelProperties struct { +type TunnelProperties struct { Credentials Credentials Client pogs.ClientInfo QuickTunnelUrl string diff --git a/connection/control.go b/connection/control.go index 5cde0204..e0bfeae9 100644 --- a/connection/control.go +++ b/connection/control.go @@ -6,25 +6,25 @@ import ( "net" "time" - "github.com/rs/zerolog" - "github.com/cloudflare/cloudflared/management" + "github.com/cloudflare/cloudflared/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) -// RPCClientFunc derives a named tunnel rpc client that can then be used to register and unregister connections. -type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient +// registerClient derives a named tunnel rpc client that can then be used to register and unregister connections. +type registerClientFunc func(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient type controlStream struct { observer *Observer - connectedFuse ConnectedFuse - namedTunnelProperties *NamedTunnelProperties - connIndex uint8 - edgeAddress net.IP - protocol Protocol + connectedFuse ConnectedFuse + tunnelProperties *TunnelProperties + connIndex uint8 + edgeAddress net.IP + protocol Protocol - newRPCClientFunc RPCClientFunc + registerClientFunc registerClientFunc + registerTimeout time.Duration gracefulShutdownC <-chan struct{} gracePeriod time.Duration @@ -47,27 +47,29 @@ type TunnelConfigJSONGetter interface { func NewControlStream( observer *Observer, connectedFuse ConnectedFuse, - namedTunnelConfig *NamedTunnelProperties, + tunnelProperties *TunnelProperties, connIndex uint8, edgeAddress net.IP, - newRPCClientFunc RPCClientFunc, + registerClientFunc registerClientFunc, + registerTimeout time.Duration, gracefulShutdownC <-chan struct{}, gracePeriod time.Duration, protocol Protocol, ) ControlStreamHandler { - if newRPCClientFunc == nil { - newRPCClientFunc = newRegistrationRPCClient + if registerClientFunc == nil { + registerClientFunc = tunnelrpc.NewRegistrationClient } return &controlStream{ - observer: observer, - connectedFuse: connectedFuse, - namedTunnelProperties: namedTunnelConfig, - newRPCClientFunc: newRPCClientFunc, - connIndex: connIndex, - edgeAddress: edgeAddress, - gracefulShutdownC: gracefulShutdownC, - gracePeriod: gracePeriod, - protocol: protocol, + observer: observer, + connectedFuse: connectedFuse, + tunnelProperties: tunnelProperties, + registerClientFunc: registerClientFunc, + registerTimeout: registerTimeout, + connIndex: connIndex, + edgeAddress: edgeAddress, + gracefulShutdownC: gracefulShutdownC, + gracePeriod: gracePeriod, + protocol: protocol, } } @@ -77,13 +79,25 @@ func (c *controlStream) ServeControlStream( connOptions *tunnelpogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter, ) error { - rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) + registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout) - registrationDetails, err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.edgeAddress, c.observer) + registrationDetails, err := registrationClient.RegisterConnection( + ctx, + c.tunnelProperties.Credentials.Auth(), + c.tunnelProperties.Credentials.TunnelID, + connOptions, + c.connIndex, + c.edgeAddress) if err != nil { - rpcClient.Close() - return err + defer registrationClient.Close() + if err.Error() == DuplicateConnectionError { + c.observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc() + return errDuplicationConnection + } + c.observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc() + return serverRegistrationErrorFromRPC(err) } + c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol) c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location) @@ -92,21 +106,23 @@ func (c *controlStream) ServeControlStream( // if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration if c.connIndex == 0 && !registrationDetails.TunnelIsRemotelyManaged { if tunnelConfig, err := tunnelConfigGetter.GetConfigJSON(); err == nil { - if err := rpcClient.SendLocalConfiguration(ctx, tunnelConfig, c.observer); err != nil { + if err := registrationClient.SendLocalConfiguration(ctx, tunnelConfig); err != nil { + c.observer.metrics.localConfigMetrics.pushesErrors.Inc() c.observer.log.Err(err).Msg("unable to send local configuration") } + c.observer.metrics.localConfigMetrics.pushes.Inc() } else { c.observer.log.Err(err).Msg("failed to obtain current configuration") } } - c.waitForUnregister(ctx, rpcClient) + c.waitForUnregister(ctx, registrationClient) return nil } -func (c *controlStream) waitForUnregister(ctx context.Context, rpcClient NamedTunnelRPCClient) { +func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) { // wait for connection termination or start of graceful shutdown - defer rpcClient.Close() + defer registrationClient.Close() select { case <-ctx.Done(): break @@ -115,7 +131,7 @@ func (c *controlStream) waitForUnregister(ctx context.Context, rpcClient NamedTu } c.observer.sendUnregisteringEvent(c.connIndex) - rpcClient.GracefulShutdown(ctx, c.gracePeriod) + registrationClient.GracefulShutdown(ctx, c.gracePeriod) c.observer.log.Info(). Int(management.EventTypeKey, int(management.Cloudflared)). Uint8(LogFieldConnIndex, c.connIndex). diff --git a/connection/http2.go b/connection/http2.go index 124746cb..f5e4d873 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -40,8 +40,6 @@ type HTTP2Connection struct { connOptions *tunnelpogs.ConnectionOptions observer *Observer connIndex uint8 - // newRPCClientFunc allows us to mock RPCs during testing - newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient log *zerolog.Logger activeRequestsWG sync.WaitGroup @@ -69,7 +67,6 @@ func NewHTTP2Connection( connOptions: connOptions, observer: observer, connIndex: connIndex, - newRPCClientFunc: newRegistrationRPCClient, controlStreamHandler: controlStreamHandler, log: log, } diff --git a/connection/http2_test.go b/connection/http2_test.go index 1cb39646..a0ec8b45 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -20,8 +20,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/http2" + "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) var ( @@ -36,10 +36,11 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelProperties{}, + &TunnelProperties{}, connIndex, nil, nil, + 1*time.Second, nil, 1*time.Second, HTTP2, @@ -168,23 +169,23 @@ type mockNamedTunnelRPCClient struct { unregistered chan struct{} } -func (mc mockNamedTunnelRPCClient) SendLocalConfiguration(c context.Context, config []byte, observer *Observer) error { +func (mc mockNamedTunnelRPCClient) SendLocalConfiguration(c context.Context, config []byte) error { return nil } func (mc mockNamedTunnelRPCClient) RegisterConnection( - c context.Context, - properties *NamedTunnelProperties, - options *tunnelpogs.ConnectionOptions, + ctx context.Context, + auth pogs.TunnelAuth, + tunnelID uuid.UUID, + options *pogs.ConnectionOptions, connIndex uint8, edgeAddress net.IP, - observer *Observer, -) (*tunnelpogs.ConnectionDetails, error) { +) (*pogs.ConnectionDetails, error) { if mc.shouldFail != nil { return nil, mc.shouldFail } close(mc.registered) - return &tunnelpogs.ConnectionDetails{ + return &pogs.ConnectionDetails{ Location: "LIS", UUID: uuid.New(), TunnelIsRemotelyManaged: false, @@ -203,8 +204,8 @@ type mockRPCClientFactory struct { unregistered chan struct{} } -func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient { - return mockNamedTunnelRPCClient{ +func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient { + return &mockNamedTunnelRPCClient{ shouldFail: mf.shouldFail, registered: mf.registered, unregistered: mf.unregistered, @@ -360,10 +361,11 @@ func TestServeControlStream(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelProperties{}, + &TunnelProperties{}, 1, nil, rpcClientFactory.newMockRPCClient, + 1*time.Second, nil, 1*time.Second, HTTP2, @@ -412,10 +414,11 @@ func TestFailRegistration(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelProperties{}, + &TunnelProperties{}, http2Conn.connIndex, nil, rpcClientFactory.newMockRPCClient, + 1*time.Second, nil, 1*time.Second, HTTP2, @@ -460,10 +463,11 @@ func TestGracefulShutdownHTTP2(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelProperties{}, + &TunnelProperties{}, http2Conn.connIndex, nil, rpcClientFactory.newMockRPCClient, + 1*time.Second, shutdownC, 1*time.Second, HTTP2, diff --git a/connection/rpc.go b/connection/rpc.go deleted file mode 100644 index f30b7b93..00000000 --- a/connection/rpc.go +++ /dev/null @@ -1,111 +0,0 @@ -package connection - -import ( - "context" - "io" - "net" - "time" - - "github.com/rs/zerolog" - "zombiezen.com/go/capnproto2/rpc" - - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" -) - -type NamedTunnelRPCClient interface { - RegisterConnection( - c context.Context, - config *NamedTunnelProperties, - options *tunnelpogs.ConnectionOptions, - connIndex uint8, - edgeAddress net.IP, - observer *Observer, - ) (*tunnelpogs.ConnectionDetails, error) - SendLocalConfiguration( - c context.Context, - config []byte, - observer *Observer, - ) error - GracefulShutdown(ctx context.Context, gracePeriod time.Duration) - Close() -} - -type registrationServerClient struct { - client tunnelpogs.RegistrationServer_PogsClient - transport rpc.Transport -} - -func newRegistrationRPCClient( - ctx context.Context, - stream io.ReadWriteCloser, - log *zerolog.Logger, -) NamedTunnelRPCClient { - transport := rpc.StreamTransport(stream) - conn := rpc.NewConn(transport) - return ®istrationServerClient{ - client: tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}, - transport: transport, - } -} - -func (rsc *registrationServerClient) RegisterConnection( - ctx context.Context, - properties *NamedTunnelProperties, - options *tunnelpogs.ConnectionOptions, - connIndex uint8, - edgeAddress net.IP, - observer *Observer, -) (*tunnelpogs.ConnectionDetails, error) { - conn, err := rsc.client.RegisterConnection( - ctx, - properties.Credentials.Auth(), - properties.Credentials.TunnelID, - connIndex, - options, - ) - if err != nil { - if err.Error() == DuplicateConnectionError { - observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc() - return nil, errDuplicationConnection - } - observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc() - return nil, serverRegistrationErrorFromRPC(err) - } - - observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() - - return conn, nil -} - -func (rsc *registrationServerClient) SendLocalConfiguration(ctx context.Context, config []byte, observer *Observer) (err error) { - observer.metrics.localConfigMetrics.pushes.Inc() - defer func() { - if err != nil { - observer.metrics.localConfigMetrics.pushesErrors.Inc() - } - }() - - return rsc.client.SendLocalConfiguration(ctx, config) -} - -func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) { - ctx, cancel := context.WithTimeout(ctx, gracePeriod) - defer cancel() - _ = rsc.client.UnregisterConnection(ctx) -} - -func (rsc *registrationServerClient) Close() { - // Closing the client will also close the connection - _ = rsc.client.Close() - // Closing the transport also closes the stream - _ = rsc.transport.Close() -} - -type rpcName string - -const ( - register rpcName = "register" - reconnect rpcName = "reconnect" - unregister rpcName = "unregister" - authenticate rpcName = " authenticate" -) diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index e7c25baf..687acce6 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -58,7 +58,7 @@ type TunnelConfig struct { NeedPQ bool - NamedTunnel *connection.NamedTunnelProperties + NamedTunnel *connection.TunnelProperties ProtocolSelector connection.ProtocolSelector EdgeTLSConfigs map[connection.Protocol]*tls.Config PacketConfig *ingress.GlobalRouterConfig @@ -454,6 +454,7 @@ func (e *EdgeTunnelServer) serveConnection( connIndex, addr.UDP.IP, nil, + e.config.RPCTimeout, e.gracefulShutdownC, e.config.GracePeriod, protocol, diff --git a/tunnelrpc/pogs/registration_server.go b/tunnelrpc/pogs/registration_server.go index 995ad2e3..e0f6b1e8 100644 --- a/tunnelrpc/pogs/registration_server.go +++ b/tunnelrpc/pogs/registration_server.go @@ -105,6 +105,13 @@ type RegistrationServer_PogsClient struct { Conn *rpc.Conn } +func NewRegistrationServer_PogsClient(client capnp.Client, conn *rpc.Conn) RegistrationServer_PogsClient { + return RegistrationServer_PogsClient{ + Client: client, + Conn: conn, + } +} + func (c RegistrationServer_PogsClient) Close() error { c.Client.Close() return c.Conn.Close() diff --git a/tunnelrpc/registration_client.go b/tunnelrpc/registration_client.go new file mode 100644 index 00000000..96aef963 --- /dev/null +++ b/tunnelrpc/registration_client.go @@ -0,0 +1,76 @@ +package tunnelrpc + +import ( + "context" + "io" + "net" + "time" + + "github.com/google/uuid" + "zombiezen.com/go/capnproto2/rpc" + + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +type RegistrationClient interface { + RegisterConnection( + ctx context.Context, + auth pogs.TunnelAuth, + tunnelID uuid.UUID, + options *pogs.ConnectionOptions, + connIndex uint8, + edgeAddress net.IP, + ) (*pogs.ConnectionDetails, error) + SendLocalConfiguration(ctx context.Context, config []byte) error + GracefulShutdown(ctx context.Context, gracePeriod time.Duration) + Close() +} + +type registrationClient struct { + client pogs.RegistrationServer_PogsClient + transport rpc.Transport + requestTimeout time.Duration +} + +func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser, requestTimeout time.Duration) RegistrationClient { + transport := SafeTransport(stream) + conn := rpc.NewConn(transport) + client := pogs.NewRegistrationServer_PogsClient(conn.Bootstrap(ctx), conn) + return ®istrationClient{ + client: client, + transport: transport, + requestTimeout: requestTimeout, + } +} + +func (r *registrationClient) RegisterConnection( + ctx context.Context, + auth pogs.TunnelAuth, + tunnelID uuid.UUID, + options *pogs.ConnectionOptions, + connIndex uint8, + edgeAddress net.IP, +) (*pogs.ConnectionDetails, error) { + ctx, cancel := context.WithTimeout(ctx, r.requestTimeout) + defer cancel() + return r.client.RegisterConnection(ctx, auth, tunnelID, connIndex, options) +} + +func (r *registrationClient) SendLocalConfiguration(ctx context.Context, config []byte) error { + ctx, cancel := context.WithTimeout(ctx, r.requestTimeout) + defer cancel() + return r.client.SendLocalConfiguration(ctx, config) +} + +func (r *registrationClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) { + ctx, cancel := context.WithTimeout(ctx, gracePeriod) + defer cancel() + _ = r.client.UnregisterConnection(ctx) +} + +func (r *registrationClient) Close() { + // Closing the client will also close the connection + _ = r.client.Close() + // Closing the transport also closes the stream + _ = r.transport.Close() +} diff --git a/tunnelrpc/registration_server.go b/tunnelrpc/registration_server.go new file mode 100644 index 00000000..84044e84 --- /dev/null +++ b/tunnelrpc/registration_server.go @@ -0,0 +1,40 @@ +package tunnelrpc + +import ( + "context" + "io" + + "zombiezen.com/go/capnproto2/rpc" + + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +// RegistrationServer provides a handler interface for a client to provide methods to handle the different types of +// requests that can be communicated by the stream. +type RegistrationServer struct { + registrationServer pogs.RegistrationServer +} + +func NewRegistrationServer(registrationServer pogs.RegistrationServer) *RegistrationServer { + return &RegistrationServer{ + registrationServer: registrationServer, + } +} + +// Serve listens for all RegistrationServer RPCs, including UnregisterConnection until the underlying connection +// is terminated. +func (s *RegistrationServer) Serve(ctx context.Context, stream io.ReadWriteCloser) error { + transport := SafeTransport(stream) + defer transport.Close() + + main := pogs.RegistrationServer_ServerToClient(s.registrationServer) + rpcConn := rpc.NewConn(transport, rpc.MainInterface(main.Client)) + defer rpcConn.Close() + + select { + case <-rpcConn.Done(): + return rpcConn.Wait() + case <-ctx.Done(): + return ctx.Err() + } +}