TUN-8424: Refactor capnp registration server
Move RegistrationServer and RegistrationClient into tunnelrpc module to properly abstract out the capnp aspects internal to the module only.
This commit is contained in:
		
							parent
							
								
									43446bc692
								
							
						
					
					
						commit
						654a326098
					
				| 
						 | 
				
			
			@ -287,7 +287,7 @@ func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
 | 
			
		|||
func StartServer(
 | 
			
		||||
	c *cli.Context,
 | 
			
		||||
	info *cliutil.BuildInfo,
 | 
			
		||||
	namedTunnel *connection.NamedTunnelProperties,
 | 
			
		||||
	namedTunnel *connection.TunnelProperties,
 | 
			
		||||
	log *zerolog.Logger,
 | 
			
		||||
) error {
 | 
			
		||||
	err := sentry.Init(sentry.ClientOptions{
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -108,7 +108,7 @@ func isSecretEnvVar(key string) bool {
 | 
			
		|||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool {
 | 
			
		||||
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.TunnelProperties) bool {
 | 
			
		||||
	return c.IsSet("proxy-dns") &&
 | 
			
		||||
		!(c.IsSet("name") || // adhoc-named tunnel
 | 
			
		||||
			c.IsSet(ingress.HelloWorldFlag) || // quick or named tunnel
 | 
			
		||||
| 
						 | 
				
			
			@ -121,7 +121,7 @@ func prepareTunnelConfig(
 | 
			
		|||
	info *cliutil.BuildInfo,
 | 
			
		||||
	log, logTransport *zerolog.Logger,
 | 
			
		||||
	observer *connection.Observer,
 | 
			
		||||
	namedTunnel *connection.NamedTunnelProperties,
 | 
			
		||||
	namedTunnel *connection.TunnelProperties,
 | 
			
		||||
) (*supervisor.TunnelConfig, *orchestration.Config, error) {
 | 
			
		||||
	clientID, err := uuid.NewRandom()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -79,7 +79,7 @@ func RunQuickTunnel(sc *subcommandContext) error {
 | 
			
		|||
	return StartServer(
 | 
			
		||||
		sc.c,
 | 
			
		||||
		buildInfo,
 | 
			
		||||
		&connection.NamedTunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname},
 | 
			
		||||
		&connection.TunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname},
 | 
			
		||||
		sc.log,
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -261,7 +261,7 @@ func (sc *subcommandContext) runWithCredentials(credentials connection.Credentia
 | 
			
		|||
	return StartServer(
 | 
			
		||||
		sc.c,
 | 
			
		||||
		buildInfo,
 | 
			
		||||
		&connection.NamedTunnelProperties{Credentials: credentials},
 | 
			
		||||
		&connection.TunnelProperties{Credentials: credentials},
 | 
			
		||||
		sc.log,
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,7 +42,7 @@ type Orchestrator interface {
 | 
			
		|||
	GetOriginProxy() (OriginProxy, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NamedTunnelProperties struct {
 | 
			
		||||
type TunnelProperties struct {
 | 
			
		||||
	Credentials    Credentials
 | 
			
		||||
	Client         pogs.ClientInfo
 | 
			
		||||
	QuickTunnelUrl string
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,25 +6,25 @@ import (
 | 
			
		|||
	"net"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/rs/zerolog"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/management"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// RPCClientFunc derives a named tunnel rpc client that can then be used to register and unregister connections.
 | 
			
		||||
type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
 | 
			
		||||
// registerClient derives a named tunnel rpc client that can then be used to register and unregister connections.
 | 
			
		||||
type registerClientFunc func(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient
 | 
			
		||||
 | 
			
		||||
type controlStream struct {
 | 
			
		||||
	observer *Observer
 | 
			
		||||
 | 
			
		||||
	connectedFuse         ConnectedFuse
 | 
			
		||||
	namedTunnelProperties *NamedTunnelProperties
 | 
			
		||||
	connIndex             uint8
 | 
			
		||||
	edgeAddress           net.IP
 | 
			
		||||
	protocol              Protocol
 | 
			
		||||
	connectedFuse    ConnectedFuse
 | 
			
		||||
	tunnelProperties *TunnelProperties
 | 
			
		||||
	connIndex        uint8
 | 
			
		||||
	edgeAddress      net.IP
 | 
			
		||||
	protocol         Protocol
 | 
			
		||||
 | 
			
		||||
	newRPCClientFunc RPCClientFunc
 | 
			
		||||
	registerClientFunc registerClientFunc
 | 
			
		||||
	registerTimeout    time.Duration
 | 
			
		||||
 | 
			
		||||
	gracefulShutdownC <-chan struct{}
 | 
			
		||||
	gracePeriod       time.Duration
 | 
			
		||||
| 
						 | 
				
			
			@ -47,27 +47,29 @@ type TunnelConfigJSONGetter interface {
 | 
			
		|||
func NewControlStream(
 | 
			
		||||
	observer *Observer,
 | 
			
		||||
	connectedFuse ConnectedFuse,
 | 
			
		||||
	namedTunnelConfig *NamedTunnelProperties,
 | 
			
		||||
	tunnelProperties *TunnelProperties,
 | 
			
		||||
	connIndex uint8,
 | 
			
		||||
	edgeAddress net.IP,
 | 
			
		||||
	newRPCClientFunc RPCClientFunc,
 | 
			
		||||
	registerClientFunc registerClientFunc,
 | 
			
		||||
	registerTimeout time.Duration,
 | 
			
		||||
	gracefulShutdownC <-chan struct{},
 | 
			
		||||
	gracePeriod time.Duration,
 | 
			
		||||
	protocol Protocol,
 | 
			
		||||
) ControlStreamHandler {
 | 
			
		||||
	if newRPCClientFunc == nil {
 | 
			
		||||
		newRPCClientFunc = newRegistrationRPCClient
 | 
			
		||||
	if registerClientFunc == nil {
 | 
			
		||||
		registerClientFunc = tunnelrpc.NewRegistrationClient
 | 
			
		||||
	}
 | 
			
		||||
	return &controlStream{
 | 
			
		||||
		observer:              observer,
 | 
			
		||||
		connectedFuse:         connectedFuse,
 | 
			
		||||
		namedTunnelProperties: namedTunnelConfig,
 | 
			
		||||
		newRPCClientFunc:      newRPCClientFunc,
 | 
			
		||||
		connIndex:             connIndex,
 | 
			
		||||
		edgeAddress:           edgeAddress,
 | 
			
		||||
		gracefulShutdownC:     gracefulShutdownC,
 | 
			
		||||
		gracePeriod:           gracePeriod,
 | 
			
		||||
		protocol:              protocol,
 | 
			
		||||
		observer:           observer,
 | 
			
		||||
		connectedFuse:      connectedFuse,
 | 
			
		||||
		tunnelProperties:   tunnelProperties,
 | 
			
		||||
		registerClientFunc: registerClientFunc,
 | 
			
		||||
		registerTimeout:    registerTimeout,
 | 
			
		||||
		connIndex:          connIndex,
 | 
			
		||||
		edgeAddress:        edgeAddress,
 | 
			
		||||
		gracefulShutdownC:  gracefulShutdownC,
 | 
			
		||||
		gracePeriod:        gracePeriod,
 | 
			
		||||
		protocol:           protocol,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -77,13 +79,25 @@ func (c *controlStream) ServeControlStream(
 | 
			
		|||
	connOptions *tunnelpogs.ConnectionOptions,
 | 
			
		||||
	tunnelConfigGetter TunnelConfigJSONGetter,
 | 
			
		||||
) error {
 | 
			
		||||
	rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log)
 | 
			
		||||
	registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout)
 | 
			
		||||
 | 
			
		||||
	registrationDetails, err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.edgeAddress, c.observer)
 | 
			
		||||
	registrationDetails, err := registrationClient.RegisterConnection(
 | 
			
		||||
		ctx,
 | 
			
		||||
		c.tunnelProperties.Credentials.Auth(),
 | 
			
		||||
		c.tunnelProperties.Credentials.TunnelID,
 | 
			
		||||
		connOptions,
 | 
			
		||||
		c.connIndex,
 | 
			
		||||
		c.edgeAddress)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		rpcClient.Close()
 | 
			
		||||
		return err
 | 
			
		||||
		defer registrationClient.Close()
 | 
			
		||||
		if err.Error() == DuplicateConnectionError {
 | 
			
		||||
			c.observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
 | 
			
		||||
			return errDuplicationConnection
 | 
			
		||||
		}
 | 
			
		||||
		c.observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
 | 
			
		||||
		return serverRegistrationErrorFromRPC(err)
 | 
			
		||||
	}
 | 
			
		||||
	c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
 | 
			
		||||
 | 
			
		||||
	c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol)
 | 
			
		||||
	c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location)
 | 
			
		||||
| 
						 | 
				
			
			@ -92,21 +106,23 @@ func (c *controlStream) ServeControlStream(
 | 
			
		|||
	// if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration
 | 
			
		||||
	if c.connIndex == 0 && !registrationDetails.TunnelIsRemotelyManaged {
 | 
			
		||||
		if tunnelConfig, err := tunnelConfigGetter.GetConfigJSON(); err == nil {
 | 
			
		||||
			if err := rpcClient.SendLocalConfiguration(ctx, tunnelConfig, c.observer); err != nil {
 | 
			
		||||
			if err := registrationClient.SendLocalConfiguration(ctx, tunnelConfig); err != nil {
 | 
			
		||||
				c.observer.metrics.localConfigMetrics.pushesErrors.Inc()
 | 
			
		||||
				c.observer.log.Err(err).Msg("unable to send local configuration")
 | 
			
		||||
			}
 | 
			
		||||
			c.observer.metrics.localConfigMetrics.pushes.Inc()
 | 
			
		||||
		} else {
 | 
			
		||||
			c.observer.log.Err(err).Msg("failed to obtain current configuration")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.waitForUnregister(ctx, rpcClient)
 | 
			
		||||
	c.waitForUnregister(ctx, registrationClient)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *controlStream) waitForUnregister(ctx context.Context, rpcClient NamedTunnelRPCClient) {
 | 
			
		||||
func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) {
 | 
			
		||||
	// wait for connection termination or start of graceful shutdown
 | 
			
		||||
	defer rpcClient.Close()
 | 
			
		||||
	defer registrationClient.Close()
 | 
			
		||||
	select {
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		break
 | 
			
		||||
| 
						 | 
				
			
			@ -115,7 +131,7 @@ func (c *controlStream) waitForUnregister(ctx context.Context, rpcClient NamedTu
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	c.observer.sendUnregisteringEvent(c.connIndex)
 | 
			
		||||
	rpcClient.GracefulShutdown(ctx, c.gracePeriod)
 | 
			
		||||
	registrationClient.GracefulShutdown(ctx, c.gracePeriod)
 | 
			
		||||
	c.observer.log.Info().
 | 
			
		||||
		Int(management.EventTypeKey, int(management.Cloudflared)).
 | 
			
		||||
		Uint8(LogFieldConnIndex, c.connIndex).
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -40,8 +40,6 @@ type HTTP2Connection struct {
 | 
			
		|||
	connOptions  *tunnelpogs.ConnectionOptions
 | 
			
		||||
	observer     *Observer
 | 
			
		||||
	connIndex    uint8
 | 
			
		||||
	// newRPCClientFunc allows us to mock RPCs during testing
 | 
			
		||||
	newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
 | 
			
		||||
 | 
			
		||||
	log                  *zerolog.Logger
 | 
			
		||||
	activeRequestsWG     sync.WaitGroup
 | 
			
		||||
| 
						 | 
				
			
			@ -69,7 +67,6 @@ func NewHTTP2Connection(
 | 
			
		|||
		connOptions:          connOptions,
 | 
			
		||||
		observer:             observer,
 | 
			
		||||
		connIndex:            connIndex,
 | 
			
		||||
		newRPCClientFunc:     newRegistrationRPCClient,
 | 
			
		||||
		controlStreamHandler: controlStreamHandler,
 | 
			
		||||
		log:                  log,
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,8 +20,8 @@ import (
 | 
			
		|||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	"golang.org/x/net/http2"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
| 
						 | 
				
			
			@ -36,10 +36,11 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
 | 
			
		|||
	controlStream := NewControlStream(
 | 
			
		||||
		obs,
 | 
			
		||||
		mockConnectedFuse{},
 | 
			
		||||
		&NamedTunnelProperties{},
 | 
			
		||||
		&TunnelProperties{},
 | 
			
		||||
		connIndex,
 | 
			
		||||
		nil,
 | 
			
		||||
		nil,
 | 
			
		||||
		1*time.Second,
 | 
			
		||||
		nil,
 | 
			
		||||
		1*time.Second,
 | 
			
		||||
		HTTP2,
 | 
			
		||||
| 
						 | 
				
			
			@ -168,23 +169,23 @@ type mockNamedTunnelRPCClient struct {
 | 
			
		|||
	unregistered chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mc mockNamedTunnelRPCClient) SendLocalConfiguration(c context.Context, config []byte, observer *Observer) error {
 | 
			
		||||
func (mc mockNamedTunnelRPCClient) SendLocalConfiguration(c context.Context, config []byte) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mc mockNamedTunnelRPCClient) RegisterConnection(
 | 
			
		||||
	c context.Context,
 | 
			
		||||
	properties *NamedTunnelProperties,
 | 
			
		||||
	options *tunnelpogs.ConnectionOptions,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	auth pogs.TunnelAuth,
 | 
			
		||||
	tunnelID uuid.UUID,
 | 
			
		||||
	options *pogs.ConnectionOptions,
 | 
			
		||||
	connIndex uint8,
 | 
			
		||||
	edgeAddress net.IP,
 | 
			
		||||
	observer *Observer,
 | 
			
		||||
) (*tunnelpogs.ConnectionDetails, error) {
 | 
			
		||||
) (*pogs.ConnectionDetails, error) {
 | 
			
		||||
	if mc.shouldFail != nil {
 | 
			
		||||
		return nil, mc.shouldFail
 | 
			
		||||
	}
 | 
			
		||||
	close(mc.registered)
 | 
			
		||||
	return &tunnelpogs.ConnectionDetails{
 | 
			
		||||
	return &pogs.ConnectionDetails{
 | 
			
		||||
		Location:                "LIS",
 | 
			
		||||
		UUID:                    uuid.New(),
 | 
			
		||||
		TunnelIsRemotelyManaged: false,
 | 
			
		||||
| 
						 | 
				
			
			@ -203,8 +204,8 @@ type mockRPCClientFactory struct {
 | 
			
		|||
	unregistered chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient {
 | 
			
		||||
	return mockNamedTunnelRPCClient{
 | 
			
		||||
func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient {
 | 
			
		||||
	return &mockNamedTunnelRPCClient{
 | 
			
		||||
		shouldFail:   mf.shouldFail,
 | 
			
		||||
		registered:   mf.registered,
 | 
			
		||||
		unregistered: mf.unregistered,
 | 
			
		||||
| 
						 | 
				
			
			@ -360,10 +361,11 @@ func TestServeControlStream(t *testing.T) {
 | 
			
		|||
	controlStream := NewControlStream(
 | 
			
		||||
		obs,
 | 
			
		||||
		mockConnectedFuse{},
 | 
			
		||||
		&NamedTunnelProperties{},
 | 
			
		||||
		&TunnelProperties{},
 | 
			
		||||
		1,
 | 
			
		||||
		nil,
 | 
			
		||||
		rpcClientFactory.newMockRPCClient,
 | 
			
		||||
		1*time.Second,
 | 
			
		||||
		nil,
 | 
			
		||||
		1*time.Second,
 | 
			
		||||
		HTTP2,
 | 
			
		||||
| 
						 | 
				
			
			@ -412,10 +414,11 @@ func TestFailRegistration(t *testing.T) {
 | 
			
		|||
	controlStream := NewControlStream(
 | 
			
		||||
		obs,
 | 
			
		||||
		mockConnectedFuse{},
 | 
			
		||||
		&NamedTunnelProperties{},
 | 
			
		||||
		&TunnelProperties{},
 | 
			
		||||
		http2Conn.connIndex,
 | 
			
		||||
		nil,
 | 
			
		||||
		rpcClientFactory.newMockRPCClient,
 | 
			
		||||
		1*time.Second,
 | 
			
		||||
		nil,
 | 
			
		||||
		1*time.Second,
 | 
			
		||||
		HTTP2,
 | 
			
		||||
| 
						 | 
				
			
			@ -460,10 +463,11 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
 | 
			
		|||
	controlStream := NewControlStream(
 | 
			
		||||
		obs,
 | 
			
		||||
		mockConnectedFuse{},
 | 
			
		||||
		&NamedTunnelProperties{},
 | 
			
		||||
		&TunnelProperties{},
 | 
			
		||||
		http2Conn.connIndex,
 | 
			
		||||
		nil,
 | 
			
		||||
		rpcClientFactory.newMockRPCClient,
 | 
			
		||||
		1*time.Second,
 | 
			
		||||
		shutdownC,
 | 
			
		||||
		1*time.Second,
 | 
			
		||||
		HTTP2,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,111 +0,0 @@
 | 
			
		|||
package connection
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/rs/zerolog"
 | 
			
		||||
	"zombiezen.com/go/capnproto2/rpc"
 | 
			
		||||
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type NamedTunnelRPCClient interface {
 | 
			
		||||
	RegisterConnection(
 | 
			
		||||
		c context.Context,
 | 
			
		||||
		config *NamedTunnelProperties,
 | 
			
		||||
		options *tunnelpogs.ConnectionOptions,
 | 
			
		||||
		connIndex uint8,
 | 
			
		||||
		edgeAddress net.IP,
 | 
			
		||||
		observer *Observer,
 | 
			
		||||
	) (*tunnelpogs.ConnectionDetails, error)
 | 
			
		||||
	SendLocalConfiguration(
 | 
			
		||||
		c context.Context,
 | 
			
		||||
		config []byte,
 | 
			
		||||
		observer *Observer,
 | 
			
		||||
	) error
 | 
			
		||||
	GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
 | 
			
		||||
	Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type registrationServerClient struct {
 | 
			
		||||
	client    tunnelpogs.RegistrationServer_PogsClient
 | 
			
		||||
	transport rpc.Transport
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newRegistrationRPCClient(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	stream io.ReadWriteCloser,
 | 
			
		||||
	log *zerolog.Logger,
 | 
			
		||||
) NamedTunnelRPCClient {
 | 
			
		||||
	transport := rpc.StreamTransport(stream)
 | 
			
		||||
	conn := rpc.NewConn(transport)
 | 
			
		||||
	return ®istrationServerClient{
 | 
			
		||||
		client:    tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
 | 
			
		||||
		transport: transport,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rsc *registrationServerClient) RegisterConnection(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	properties *NamedTunnelProperties,
 | 
			
		||||
	options *tunnelpogs.ConnectionOptions,
 | 
			
		||||
	connIndex uint8,
 | 
			
		||||
	edgeAddress net.IP,
 | 
			
		||||
	observer *Observer,
 | 
			
		||||
) (*tunnelpogs.ConnectionDetails, error) {
 | 
			
		||||
	conn, err := rsc.client.RegisterConnection(
 | 
			
		||||
		ctx,
 | 
			
		||||
		properties.Credentials.Auth(),
 | 
			
		||||
		properties.Credentials.TunnelID,
 | 
			
		||||
		connIndex,
 | 
			
		||||
		options,
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err.Error() == DuplicateConnectionError {
 | 
			
		||||
			observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
 | 
			
		||||
			return nil, errDuplicationConnection
 | 
			
		||||
		}
 | 
			
		||||
		observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
 | 
			
		||||
		return nil, serverRegistrationErrorFromRPC(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
 | 
			
		||||
 | 
			
		||||
	return conn, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rsc *registrationServerClient) SendLocalConfiguration(ctx context.Context, config []byte, observer *Observer) (err error) {
 | 
			
		||||
	observer.metrics.localConfigMetrics.pushes.Inc()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			observer.metrics.localConfigMetrics.pushesErrors.Inc()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return rsc.client.SendLocalConfiguration(ctx, config)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
 | 
			
		||||
	ctx, cancel := context.WithTimeout(ctx, gracePeriod)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
	_ = rsc.client.UnregisterConnection(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rsc *registrationServerClient) Close() {
 | 
			
		||||
	// Closing the client will also close the connection
 | 
			
		||||
	_ = rsc.client.Close()
 | 
			
		||||
	// Closing the transport also closes the stream
 | 
			
		||||
	_ = rsc.transport.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rpcName string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	register     rpcName = "register"
 | 
			
		||||
	reconnect    rpcName = "reconnect"
 | 
			
		||||
	unregister   rpcName = "unregister"
 | 
			
		||||
	authenticate rpcName = " authenticate"
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -58,7 +58,7 @@ type TunnelConfig struct {
 | 
			
		|||
 | 
			
		||||
	NeedPQ bool
 | 
			
		||||
 | 
			
		||||
	NamedTunnel      *connection.NamedTunnelProperties
 | 
			
		||||
	NamedTunnel      *connection.TunnelProperties
 | 
			
		||||
	ProtocolSelector connection.ProtocolSelector
 | 
			
		||||
	EdgeTLSConfigs   map[connection.Protocol]*tls.Config
 | 
			
		||||
	PacketConfig     *ingress.GlobalRouterConfig
 | 
			
		||||
| 
						 | 
				
			
			@ -454,6 +454,7 @@ func (e *EdgeTunnelServer) serveConnection(
 | 
			
		|||
		connIndex,
 | 
			
		||||
		addr.UDP.IP,
 | 
			
		||||
		nil,
 | 
			
		||||
		e.config.RPCTimeout,
 | 
			
		||||
		e.gracefulShutdownC,
 | 
			
		||||
		e.config.GracePeriod,
 | 
			
		||||
		protocol,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -105,6 +105,13 @@ type RegistrationServer_PogsClient struct {
 | 
			
		|||
	Conn   *rpc.Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRegistrationServer_PogsClient(client capnp.Client, conn *rpc.Conn) RegistrationServer_PogsClient {
 | 
			
		||||
	return RegistrationServer_PogsClient{
 | 
			
		||||
		Client: client,
 | 
			
		||||
		Conn:   conn,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c RegistrationServer_PogsClient) Close() error {
 | 
			
		||||
	c.Client.Close()
 | 
			
		||||
	return c.Conn.Close()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,76 @@
 | 
			
		|||
package tunnelrpc
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"zombiezen.com/go/capnproto2/rpc"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RegistrationClient interface {
 | 
			
		||||
	RegisterConnection(
 | 
			
		||||
		ctx context.Context,
 | 
			
		||||
		auth pogs.TunnelAuth,
 | 
			
		||||
		tunnelID uuid.UUID,
 | 
			
		||||
		options *pogs.ConnectionOptions,
 | 
			
		||||
		connIndex uint8,
 | 
			
		||||
		edgeAddress net.IP,
 | 
			
		||||
	) (*pogs.ConnectionDetails, error)
 | 
			
		||||
	SendLocalConfiguration(ctx context.Context, config []byte) error
 | 
			
		||||
	GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
 | 
			
		||||
	Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type registrationClient struct {
 | 
			
		||||
	client         pogs.RegistrationServer_PogsClient
 | 
			
		||||
	transport      rpc.Transport
 | 
			
		||||
	requestTimeout time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser, requestTimeout time.Duration) RegistrationClient {
 | 
			
		||||
	transport := SafeTransport(stream)
 | 
			
		||||
	conn := rpc.NewConn(transport)
 | 
			
		||||
	client := pogs.NewRegistrationServer_PogsClient(conn.Bootstrap(ctx), conn)
 | 
			
		||||
	return ®istrationClient{
 | 
			
		||||
		client:         client,
 | 
			
		||||
		transport:      transport,
 | 
			
		||||
		requestTimeout: requestTimeout,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *registrationClient) RegisterConnection(
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	auth pogs.TunnelAuth,
 | 
			
		||||
	tunnelID uuid.UUID,
 | 
			
		||||
	options *pogs.ConnectionOptions,
 | 
			
		||||
	connIndex uint8,
 | 
			
		||||
	edgeAddress net.IP,
 | 
			
		||||
) (*pogs.ConnectionDetails, error) {
 | 
			
		||||
	ctx, cancel := context.WithTimeout(ctx, r.requestTimeout)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
	return r.client.RegisterConnection(ctx, auth, tunnelID, connIndex, options)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *registrationClient) SendLocalConfiguration(ctx context.Context, config []byte) error {
 | 
			
		||||
	ctx, cancel := context.WithTimeout(ctx, r.requestTimeout)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
	return r.client.SendLocalConfiguration(ctx, config)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *registrationClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
 | 
			
		||||
	ctx, cancel := context.WithTimeout(ctx, gracePeriod)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
	_ = r.client.UnregisterConnection(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *registrationClient) Close() {
 | 
			
		||||
	// Closing the client will also close the connection
 | 
			
		||||
	_ = r.client.Close()
 | 
			
		||||
	// Closing the transport also closes the stream
 | 
			
		||||
	_ = r.transport.Close()
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,40 @@
 | 
			
		|||
package tunnelrpc
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"io"
 | 
			
		||||
 | 
			
		||||
	"zombiezen.com/go/capnproto2/rpc"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// RegistrationServer provides a handler interface for a client to provide methods to handle the different types of
 | 
			
		||||
// requests that can be communicated by the stream.
 | 
			
		||||
type RegistrationServer struct {
 | 
			
		||||
	registrationServer pogs.RegistrationServer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRegistrationServer(registrationServer pogs.RegistrationServer) *RegistrationServer {
 | 
			
		||||
	return &RegistrationServer{
 | 
			
		||||
		registrationServer: registrationServer,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Serve listens for all RegistrationServer RPCs, including UnregisterConnection until the underlying connection
 | 
			
		||||
// is terminated.
 | 
			
		||||
func (s *RegistrationServer) Serve(ctx context.Context, stream io.ReadWriteCloser) error {
 | 
			
		||||
	transport := SafeTransport(stream)
 | 
			
		||||
	defer transport.Close()
 | 
			
		||||
 | 
			
		||||
	main := pogs.RegistrationServer_ServerToClient(s.registrationServer)
 | 
			
		||||
	rpcConn := rpc.NewConn(transport, rpc.MainInterface(main.Client))
 | 
			
		||||
	defer rpcConn.Close()
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case <-rpcConn.Done():
 | 
			
		||||
		return rpcConn.Wait()
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		return ctx.Err()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue