TUN-8424: Refactor capnp registration server

Move RegistrationServer and RegistrationClient into tunnelrpc module
to properly abstract out the capnp aspects internal to the module only.
This commit is contained in:
Devin Carr 2024-05-24 11:40:10 -07:00
parent 43446bc692
commit 654a326098
13 changed files with 197 additions and 167 deletions

View File

@ -287,7 +287,7 @@ func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
func StartServer( func StartServer(
c *cli.Context, c *cli.Context,
info *cliutil.BuildInfo, info *cliutil.BuildInfo,
namedTunnel *connection.NamedTunnelProperties, namedTunnel *connection.TunnelProperties,
log *zerolog.Logger, log *zerolog.Logger,
) error { ) error {
err := sentry.Init(sentry.ClientOptions{ err := sentry.Init(sentry.ClientOptions{

View File

@ -108,7 +108,7 @@ func isSecretEnvVar(key string) bool {
return false return false
} }
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool { func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.TunnelProperties) bool {
return c.IsSet("proxy-dns") && return c.IsSet("proxy-dns") &&
!(c.IsSet("name") || // adhoc-named tunnel !(c.IsSet("name") || // adhoc-named tunnel
c.IsSet(ingress.HelloWorldFlag) || // quick or named tunnel c.IsSet(ingress.HelloWorldFlag) || // quick or named tunnel
@ -121,7 +121,7 @@ func prepareTunnelConfig(
info *cliutil.BuildInfo, info *cliutil.BuildInfo,
log, logTransport *zerolog.Logger, log, logTransport *zerolog.Logger,
observer *connection.Observer, observer *connection.Observer,
namedTunnel *connection.NamedTunnelProperties, namedTunnel *connection.TunnelProperties,
) (*supervisor.TunnelConfig, *orchestration.Config, error) { ) (*supervisor.TunnelConfig, *orchestration.Config, error) {
clientID, err := uuid.NewRandom() clientID, err := uuid.NewRandom()
if err != nil { if err != nil {

View File

@ -79,7 +79,7 @@ func RunQuickTunnel(sc *subcommandContext) error {
return StartServer( return StartServer(
sc.c, sc.c,
buildInfo, buildInfo,
&connection.NamedTunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, &connection.TunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname},
sc.log, sc.log,
) )
} }

View File

@ -261,7 +261,7 @@ func (sc *subcommandContext) runWithCredentials(credentials connection.Credentia
return StartServer( return StartServer(
sc.c, sc.c,
buildInfo, buildInfo,
&connection.NamedTunnelProperties{Credentials: credentials}, &connection.TunnelProperties{Credentials: credentials},
sc.log, sc.log,
) )
} }

View File

@ -42,7 +42,7 @@ type Orchestrator interface {
GetOriginProxy() (OriginProxy, error) GetOriginProxy() (OriginProxy, error)
} }
type NamedTunnelProperties struct { type TunnelProperties struct {
Credentials Credentials Credentials Credentials
Client pogs.ClientInfo Client pogs.ClientInfo
QuickTunnelUrl string QuickTunnelUrl string

View File

@ -6,25 +6,25 @@ import (
"net" "net"
"time" "time"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
// RPCClientFunc derives a named tunnel rpc client that can then be used to register and unregister connections. // registerClient derives a named tunnel rpc client that can then be used to register and unregister connections.
type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient type registerClientFunc func(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient
type controlStream struct { type controlStream struct {
observer *Observer observer *Observer
connectedFuse ConnectedFuse connectedFuse ConnectedFuse
namedTunnelProperties *NamedTunnelProperties tunnelProperties *TunnelProperties
connIndex uint8 connIndex uint8
edgeAddress net.IP edgeAddress net.IP
protocol Protocol protocol Protocol
newRPCClientFunc RPCClientFunc registerClientFunc registerClientFunc
registerTimeout time.Duration
gracefulShutdownC <-chan struct{} gracefulShutdownC <-chan struct{}
gracePeriod time.Duration gracePeriod time.Duration
@ -47,22 +47,24 @@ type TunnelConfigJSONGetter interface {
func NewControlStream( func NewControlStream(
observer *Observer, observer *Observer,
connectedFuse ConnectedFuse, connectedFuse ConnectedFuse,
namedTunnelConfig *NamedTunnelProperties, tunnelProperties *TunnelProperties,
connIndex uint8, connIndex uint8,
edgeAddress net.IP, edgeAddress net.IP,
newRPCClientFunc RPCClientFunc, registerClientFunc registerClientFunc,
registerTimeout time.Duration,
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
gracePeriod time.Duration, gracePeriod time.Duration,
protocol Protocol, protocol Protocol,
) ControlStreamHandler { ) ControlStreamHandler {
if newRPCClientFunc == nil { if registerClientFunc == nil {
newRPCClientFunc = newRegistrationRPCClient registerClientFunc = tunnelrpc.NewRegistrationClient
} }
return &controlStream{ return &controlStream{
observer: observer, observer: observer,
connectedFuse: connectedFuse, connectedFuse: connectedFuse,
namedTunnelProperties: namedTunnelConfig, tunnelProperties: tunnelProperties,
newRPCClientFunc: newRPCClientFunc, registerClientFunc: registerClientFunc,
registerTimeout: registerTimeout,
connIndex: connIndex, connIndex: connIndex,
edgeAddress: edgeAddress, edgeAddress: edgeAddress,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
@ -77,13 +79,25 @@ func (c *controlStream) ServeControlStream(
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
tunnelConfigGetter TunnelConfigJSONGetter, tunnelConfigGetter TunnelConfigJSONGetter,
) error { ) error {
rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout)
registrationDetails, err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.edgeAddress, c.observer) registrationDetails, err := registrationClient.RegisterConnection(
ctx,
c.tunnelProperties.Credentials.Auth(),
c.tunnelProperties.Credentials.TunnelID,
connOptions,
c.connIndex,
c.edgeAddress)
if err != nil { if err != nil {
rpcClient.Close() defer registrationClient.Close()
return err if err.Error() == DuplicateConnectionError {
c.observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
return errDuplicationConnection
} }
c.observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
return serverRegistrationErrorFromRPC(err)
}
c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol) c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol)
c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location) c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location)
@ -92,21 +106,23 @@ func (c *controlStream) ServeControlStream(
// if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration // if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration
if c.connIndex == 0 && !registrationDetails.TunnelIsRemotelyManaged { if c.connIndex == 0 && !registrationDetails.TunnelIsRemotelyManaged {
if tunnelConfig, err := tunnelConfigGetter.GetConfigJSON(); err == nil { if tunnelConfig, err := tunnelConfigGetter.GetConfigJSON(); err == nil {
if err := rpcClient.SendLocalConfiguration(ctx, tunnelConfig, c.observer); err != nil { if err := registrationClient.SendLocalConfiguration(ctx, tunnelConfig); err != nil {
c.observer.metrics.localConfigMetrics.pushesErrors.Inc()
c.observer.log.Err(err).Msg("unable to send local configuration") c.observer.log.Err(err).Msg("unable to send local configuration")
} }
c.observer.metrics.localConfigMetrics.pushes.Inc()
} else { } else {
c.observer.log.Err(err).Msg("failed to obtain current configuration") c.observer.log.Err(err).Msg("failed to obtain current configuration")
} }
} }
c.waitForUnregister(ctx, rpcClient) c.waitForUnregister(ctx, registrationClient)
return nil return nil
} }
func (c *controlStream) waitForUnregister(ctx context.Context, rpcClient NamedTunnelRPCClient) { func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) {
// wait for connection termination or start of graceful shutdown // wait for connection termination or start of graceful shutdown
defer rpcClient.Close() defer registrationClient.Close()
select { select {
case <-ctx.Done(): case <-ctx.Done():
break break
@ -115,7 +131,7 @@ func (c *controlStream) waitForUnregister(ctx context.Context, rpcClient NamedTu
} }
c.observer.sendUnregisteringEvent(c.connIndex) c.observer.sendUnregisteringEvent(c.connIndex)
rpcClient.GracefulShutdown(ctx, c.gracePeriod) registrationClient.GracefulShutdown(ctx, c.gracePeriod)
c.observer.log.Info(). c.observer.log.Info().
Int(management.EventTypeKey, int(management.Cloudflared)). Int(management.EventTypeKey, int(management.Cloudflared)).
Uint8(LogFieldConnIndex, c.connIndex). Uint8(LogFieldConnIndex, c.connIndex).

View File

@ -40,8 +40,6 @@ type HTTP2Connection struct {
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
observer *Observer observer *Observer
connIndex uint8 connIndex uint8
// newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
log *zerolog.Logger log *zerolog.Logger
activeRequestsWG sync.WaitGroup activeRequestsWG sync.WaitGroup
@ -69,7 +67,6 @@ func NewHTTP2Connection(
connOptions: connOptions, connOptions: connOptions,
observer: observer, observer: observer,
connIndex: connIndex, connIndex: connIndex,
newRPCClientFunc: newRegistrationRPCClient,
controlStreamHandler: controlStreamHandler, controlStreamHandler: controlStreamHandler,
log: log, log: log,
} }

View File

@ -20,8 +20,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
var ( var (
@ -36,10 +36,11 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
controlStream := NewControlStream( controlStream := NewControlStream(
obs, obs,
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelProperties{}, &TunnelProperties{},
connIndex, connIndex,
nil, nil,
nil, nil,
1*time.Second,
nil, nil,
1*time.Second, 1*time.Second,
HTTP2, HTTP2,
@ -168,23 +169,23 @@ type mockNamedTunnelRPCClient struct {
unregistered chan struct{} unregistered chan struct{}
} }
func (mc mockNamedTunnelRPCClient) SendLocalConfiguration(c context.Context, config []byte, observer *Observer) error { func (mc mockNamedTunnelRPCClient) SendLocalConfiguration(c context.Context, config []byte) error {
return nil return nil
} }
func (mc mockNamedTunnelRPCClient) RegisterConnection( func (mc mockNamedTunnelRPCClient) RegisterConnection(
c context.Context, ctx context.Context,
properties *NamedTunnelProperties, auth pogs.TunnelAuth,
options *tunnelpogs.ConnectionOptions, tunnelID uuid.UUID,
options *pogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
edgeAddress net.IP, edgeAddress net.IP,
observer *Observer, ) (*pogs.ConnectionDetails, error) {
) (*tunnelpogs.ConnectionDetails, error) {
if mc.shouldFail != nil { if mc.shouldFail != nil {
return nil, mc.shouldFail return nil, mc.shouldFail
} }
close(mc.registered) close(mc.registered)
return &tunnelpogs.ConnectionDetails{ return &pogs.ConnectionDetails{
Location: "LIS", Location: "LIS",
UUID: uuid.New(), UUID: uuid.New(),
TunnelIsRemotelyManaged: false, TunnelIsRemotelyManaged: false,
@ -203,8 +204,8 @@ type mockRPCClientFactory struct {
unregistered chan struct{} unregistered chan struct{}
} }
func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient { func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient {
return mockNamedTunnelRPCClient{ return &mockNamedTunnelRPCClient{
shouldFail: mf.shouldFail, shouldFail: mf.shouldFail,
registered: mf.registered, registered: mf.registered,
unregistered: mf.unregistered, unregistered: mf.unregistered,
@ -360,10 +361,11 @@ func TestServeControlStream(t *testing.T) {
controlStream := NewControlStream( controlStream := NewControlStream(
obs, obs,
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelProperties{}, &TunnelProperties{},
1, 1,
nil, nil,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
1*time.Second,
nil, nil,
1*time.Second, 1*time.Second,
HTTP2, HTTP2,
@ -412,10 +414,11 @@ func TestFailRegistration(t *testing.T) {
controlStream := NewControlStream( controlStream := NewControlStream(
obs, obs,
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelProperties{}, &TunnelProperties{},
http2Conn.connIndex, http2Conn.connIndex,
nil, nil,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
1*time.Second,
nil, nil,
1*time.Second, 1*time.Second,
HTTP2, HTTP2,
@ -460,10 +463,11 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
controlStream := NewControlStream( controlStream := NewControlStream(
obs, obs,
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelProperties{}, &TunnelProperties{},
http2Conn.connIndex, http2Conn.connIndex,
nil, nil,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
1*time.Second,
shutdownC, shutdownC,
1*time.Second, 1*time.Second,
HTTP2, HTTP2,

View File

@ -1,111 +0,0 @@
package connection
import (
"context"
"io"
"net"
"time"
"github.com/rs/zerolog"
"zombiezen.com/go/capnproto2/rpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
type NamedTunnelRPCClient interface {
RegisterConnection(
c context.Context,
config *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions,
connIndex uint8,
edgeAddress net.IP,
observer *Observer,
) (*tunnelpogs.ConnectionDetails, error)
SendLocalConfiguration(
c context.Context,
config []byte,
observer *Observer,
) error
GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
Close()
}
type registrationServerClient struct {
client tunnelpogs.RegistrationServer_PogsClient
transport rpc.Transport
}
func newRegistrationRPCClient(
ctx context.Context,
stream io.ReadWriteCloser,
log *zerolog.Logger,
) NamedTunnelRPCClient {
transport := rpc.StreamTransport(stream)
conn := rpc.NewConn(transport)
return &registrationServerClient{
client: tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
transport: transport,
}
}
func (rsc *registrationServerClient) RegisterConnection(
ctx context.Context,
properties *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions,
connIndex uint8,
edgeAddress net.IP,
observer *Observer,
) (*tunnelpogs.ConnectionDetails, error) {
conn, err := rsc.client.RegisterConnection(
ctx,
properties.Credentials.Auth(),
properties.Credentials.TunnelID,
connIndex,
options,
)
if err != nil {
if err.Error() == DuplicateConnectionError {
observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
return nil, errDuplicationConnection
}
observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
return nil, serverRegistrationErrorFromRPC(err)
}
observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
return conn, nil
}
func (rsc *registrationServerClient) SendLocalConfiguration(ctx context.Context, config []byte, observer *Observer) (err error) {
observer.metrics.localConfigMetrics.pushes.Inc()
defer func() {
if err != nil {
observer.metrics.localConfigMetrics.pushesErrors.Inc()
}
}()
return rsc.client.SendLocalConfiguration(ctx, config)
}
func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
ctx, cancel := context.WithTimeout(ctx, gracePeriod)
defer cancel()
_ = rsc.client.UnregisterConnection(ctx)
}
func (rsc *registrationServerClient) Close() {
// Closing the client will also close the connection
_ = rsc.client.Close()
// Closing the transport also closes the stream
_ = rsc.transport.Close()
}
type rpcName string
const (
register rpcName = "register"
reconnect rpcName = "reconnect"
unregister rpcName = "unregister"
authenticate rpcName = " authenticate"
)

View File

@ -58,7 +58,7 @@ type TunnelConfig struct {
NeedPQ bool NeedPQ bool
NamedTunnel *connection.NamedTunnelProperties NamedTunnel *connection.TunnelProperties
ProtocolSelector connection.ProtocolSelector ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config EdgeTLSConfigs map[connection.Protocol]*tls.Config
PacketConfig *ingress.GlobalRouterConfig PacketConfig *ingress.GlobalRouterConfig
@ -454,6 +454,7 @@ func (e *EdgeTunnelServer) serveConnection(
connIndex, connIndex,
addr.UDP.IP, addr.UDP.IP,
nil, nil,
e.config.RPCTimeout,
e.gracefulShutdownC, e.gracefulShutdownC,
e.config.GracePeriod, e.config.GracePeriod,
protocol, protocol,

View File

@ -105,6 +105,13 @@ type RegistrationServer_PogsClient struct {
Conn *rpc.Conn Conn *rpc.Conn
} }
func NewRegistrationServer_PogsClient(client capnp.Client, conn *rpc.Conn) RegistrationServer_PogsClient {
return RegistrationServer_PogsClient{
Client: client,
Conn: conn,
}
}
func (c RegistrationServer_PogsClient) Close() error { func (c RegistrationServer_PogsClient) Close() error {
c.Client.Close() c.Client.Close()
return c.Conn.Close() return c.Conn.Close()

View File

@ -0,0 +1,76 @@
package tunnelrpc
import (
"context"
"io"
"net"
"time"
"github.com/google/uuid"
"zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
type RegistrationClient interface {
RegisterConnection(
ctx context.Context,
auth pogs.TunnelAuth,
tunnelID uuid.UUID,
options *pogs.ConnectionOptions,
connIndex uint8,
edgeAddress net.IP,
) (*pogs.ConnectionDetails, error)
SendLocalConfiguration(ctx context.Context, config []byte) error
GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
Close()
}
type registrationClient struct {
client pogs.RegistrationServer_PogsClient
transport rpc.Transport
requestTimeout time.Duration
}
func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser, requestTimeout time.Duration) RegistrationClient {
transport := SafeTransport(stream)
conn := rpc.NewConn(transport)
client := pogs.NewRegistrationServer_PogsClient(conn.Bootstrap(ctx), conn)
return &registrationClient{
client: client,
transport: transport,
requestTimeout: requestTimeout,
}
}
func (r *registrationClient) RegisterConnection(
ctx context.Context,
auth pogs.TunnelAuth,
tunnelID uuid.UUID,
options *pogs.ConnectionOptions,
connIndex uint8,
edgeAddress net.IP,
) (*pogs.ConnectionDetails, error) {
ctx, cancel := context.WithTimeout(ctx, r.requestTimeout)
defer cancel()
return r.client.RegisterConnection(ctx, auth, tunnelID, connIndex, options)
}
func (r *registrationClient) SendLocalConfiguration(ctx context.Context, config []byte) error {
ctx, cancel := context.WithTimeout(ctx, r.requestTimeout)
defer cancel()
return r.client.SendLocalConfiguration(ctx, config)
}
func (r *registrationClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
ctx, cancel := context.WithTimeout(ctx, gracePeriod)
defer cancel()
_ = r.client.UnregisterConnection(ctx)
}
func (r *registrationClient) Close() {
// Closing the client will also close the connection
_ = r.client.Close()
// Closing the transport also closes the stream
_ = r.transport.Close()
}

View File

@ -0,0 +1,40 @@
package tunnelrpc
import (
"context"
"io"
"zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// RegistrationServer provides a handler interface for a client to provide methods to handle the different types of
// requests that can be communicated by the stream.
type RegistrationServer struct {
registrationServer pogs.RegistrationServer
}
func NewRegistrationServer(registrationServer pogs.RegistrationServer) *RegistrationServer {
return &RegistrationServer{
registrationServer: registrationServer,
}
}
// Serve listens for all RegistrationServer RPCs, including UnregisterConnection until the underlying connection
// is terminated.
func (s *RegistrationServer) Serve(ctx context.Context, stream io.ReadWriteCloser) error {
transport := SafeTransport(stream)
defer transport.Close()
main := pogs.RegistrationServer_ServerToClient(s.registrationServer)
rpcConn := rpc.NewConn(transport, rpc.MainInterface(main.Client))
defer rpcConn.Close()
select {
case <-rpcConn.Done():
return rpcConn.Wait()
case <-ctx.Done():
return ctx.Err()
}
}