TUN-3015: Add a new cap'n'proto RPC interface for connection registration as well as matching client and server implementations. The old interface extends the new one for backward compatibility.

This commit is contained in:
Igor Postelnik 2020-06-02 13:19:19 -05:00 committed by Adam Chalmers
parent dc3a228d51
commit 448a7798f7
8 changed files with 1743 additions and 128 deletions

1
go.mod
View File

@ -54,7 +54,6 @@ require (
github.com/prometheus/common v0.7.0 // indirect
github.com/prometheus/procfs v0.0.5 // indirect
github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 // indirect
github.com/sirupsen/logrus v1.4.2 // indirect
github.com/stretchr/testify v1.3.0
github.com/tinylib/msgp v1.1.0 // indirect
github.com/xo/dburl v0.0.0-20191005012637-293c3298d6c0

View File

@ -0,0 +1,208 @@
package pogs
import (
"context"
"errors"
"net"
"time"
"github.com/google/uuid"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/server"
"github.com/cloudflare/cloudflared/tunnelrpc"
)
type RegistrationServer interface {
RegisterConnection(ctx context.Context, auth []byte, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error)
UnregisterConnection(ctx context.Context)
}
type ClientInfo struct {
ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility
Features []string
Version string
Arch string
}
type ConnectionOptions struct {
Client ClientInfo
OriginLocalIP net.IP `capnp:"originLocalIp"`
ReplaceExisting bool
CompressionQuality uint8
}
func (p *ConnectionOptions) MarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
return pogs.Insert(tunnelrpc.ConnectionOptions_TypeID, s.Struct, p)
}
func (p *ConnectionOptions) UnmarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
return pogs.Extract(p, tunnelrpc.ConnectionOptions_TypeID, s.Struct)
}
type ConnectionDetails struct {
UUID uuid.UUID
Location string
}
func (details *ConnectionDetails) MarshalCapnproto(s tunnelrpc.ConnectionDetails) error {
if err := s.SetUuid(details.UUID[:]); err != nil {
return err
}
if err := s.SetLocationName(details.Location); err != nil {
return err
}
return nil
}
func (details *ConnectionDetails) UnmarshalCapnproto(s tunnelrpc.ConnectionDetails) error {
uuidBytes, err := s.Uuid()
if err != nil {
return err
}
details.UUID, err = uuid.FromBytes(uuidBytes)
if err != nil {
return err
}
details.Location, err = s.LocationName()
if err != nil {
return err
}
return err
}
func MarshalError(s tunnelrpc.ConnectionError, err error) error {
if err := s.SetCause(err.Error()); err != nil {
return err
}
if retryableErr, ok := err.(*RetryableError); ok {
s.SetShouldRetry(true)
s.SetRetryAfter(int64(retryableErr.Delay))
}
return nil
}
func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
server.Ack(p.Options)
auth, err := p.Params.Auth()
if err != nil {
return err
}
uuidBytes, err := p.Params.TunnelId()
if err != nil {
return err
}
tunnelID, err := uuid.FromBytes(uuidBytes)
if err != nil {
return err
}
connIndex := p.Params.ConnIndex()
options, err := p.Params.Options()
if err != nil {
return err
}
var pogsOptions ConnectionOptions
err = pogsOptions.UnmarshalCapnproto(options)
if err != nil {
return err
}
connDetails, callError := i.impl.RegisterConnection(p.Ctx, auth, tunnelID, connIndex, &pogsOptions)
resp, err := p.Results.NewResult()
if err != nil {
return err
}
if callError != nil {
if connError, err := resp.Result().NewError(); err != nil {
return err
} else {
return MarshalError(connError, callError)
}
}
if details, err := resp.Result().NewConnectionDetails(); err != nil {
return err
} else {
return connDetails.MarshalCapnproto(details)
}
}
func (i TunnelServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
server.Ack(p.Options)
i.impl.UnregisterConnection(p.Ctx)
return nil
}
func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth []byte, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
err := p.SetAuth(auth)
if err != nil {
return err
}
err = p.SetTunnelId(tunnelID[:])
if err != nil {
return err
}
p.SetConnIndex(connIndex)
connectionOptions, err := p.NewOptions()
if err != nil {
return err
}
err = options.MarshalCapnproto(connectionOptions)
if err != nil {
return err
}
return nil
})
response, err := promise.Result().Struct()
if err != nil {
return nil, wrapRPCError(err)
}
result := response.Result()
switch result.Which() {
case tunnelrpc.ConnectionResponse_result_Which_error:
resultError, err := result.Error()
if err != nil {
return nil, wrapRPCError(err)
}
cause, err := resultError.Cause()
if err != nil {
return nil, wrapRPCError(err)
}
err = errors.New(cause)
if resultError.ShouldRetry() {
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
}
return nil, err
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
connDetails, err := result.ConnectionDetails()
if err != nil {
return nil, wrapRPCError(err)
}
details := new(ConnectionDetails)
if err = details.UnmarshalCapnproto(connDetails); err != nil {
return nil, wrapRPCError(err)
}
return details, nil
}
return nil, newRPCError("unknown result which %d", result.Which())
}
func (c TunnelServer_PogsClient) Unregister(ctx context.Context) error {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
return nil
})
_, err := promise.Struct()
return wrapRPCError(err)
}

View File

@ -0,0 +1,129 @@
package pogs
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/tunnelrpc"
)
func TestMarshalConnectionOptions(t *testing.T) {
clientID := uuid.New()
orig := ConnectionOptions{
Client: ClientInfo{
ClientID: clientID[:],
Features: []string{"a", "b"},
Version: "1.2.3",
Arch: "macos",
},
OriginLocalIP: []byte{10, 2, 3, 4},
ReplaceExisting: false,
CompressionQuality: 1,
}
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
require.NoError(t, err)
capnpOpts, err := tunnelrpc.NewConnectionOptions(seg)
require.NoError(t, err)
err = orig.MarshalCapnproto(capnpOpts)
assert.NoError(t, err)
var pogsOpts ConnectionOptions
err = pogsOpts.UnmarshalCapnproto(capnpOpts)
assert.NoError(t, err)
assert.Equal(t, orig, pogsOpts)
}
func TestConnectionRegistrationRPC(t *testing.T) {
p1, p2 := net.Pipe()
t1, t2 := rpc.StreamTransport(p1), rpc.StreamTransport(p2)
// Server-side
testImpl := testConnectionRegistrationServer{}
srv := TunnelServer_ServerToClient(&testImpl)
serverConn := rpc.NewConn(t1, rpc.MainInterface(srv.Client))
defer serverConn.Wait()
ctx := context.Background()
clientConn := rpc.NewConn(t2)
defer clientConn.Close()
client := TunnelServer_PogsClient{
Client: clientConn.Bootstrap(ctx),
Conn: clientConn,
}
defer client.Close()
clientID := uuid.New()
options := &ConnectionOptions{
Client: ClientInfo{
ClientID: clientID[:],
Features: []string{"foo"},
Version: "1.2.3",
Arch: "macos",
},
OriginLocalIP: net.IP{10, 20, 30, 40},
ReplaceExisting: true,
CompressionQuality: 0,
}
expectedDetails := ConnectionDetails{
UUID: uuid.New(),
Location: "TEST",
}
testImpl.details = &expectedDetails
testImpl.err = nil
// success
tunnelID := uuid.New()
details, err := client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
assert.NoError(t, err)
assert.Equal(t, expectedDetails, *details)
// regular error
testImpl.details = nil
testImpl.err = errors.New("internal")
_, err = client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
assert.EqualError(t, err, "internal")
// retriable error
testImpl.details = nil
const delay = 27*time.Second
testImpl.err = RetryErrorAfter(errors.New("retryable"), delay)
_, err = client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
assert.EqualError(t, err, "retryable")
re, ok := err.(*RetryableError)
assert.True(t, ok)
assert.Equal(t, delay, re.Delay)
}
type testConnectionRegistrationServer struct {
mockTunnelServerBase
details *ConnectionDetails
err error
}
func (t testConnectionRegistrationServer) Register(ctx context.Context, auth []byte, tunnelUUID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
if t.err != nil {
return nil, t.err
}
if t.details != nil {
return t.details, nil
}
panic("either details or err mush be set")
}

53
tunnelrpc/pogs/errors.go Normal file
View File

@ -0,0 +1,53 @@
package pogs
import (
"fmt"
"time"
)
type RetryableError struct {
err error
Delay time.Duration
}
func (re *RetryableError) Error() string {
return re.err.Error()
}
// RetryErrorAfter wraps err to indicate that client should retry after delay
func RetryErrorAfter(err error, delay time.Duration) *RetryableError {
return &RetryableError{
err: err,
Delay: delay,
}
}
func (re *RetryableError) Unwrap() error {
return re.err
}
// RPCError is used to indicate errors returned by the RPC subsystem rather
// than failure of a remote operation
type RPCError struct {
err error
}
func (re *RPCError) Error() string {
return re.err.Error()
}
func wrapRPCError(err error) *RPCError {
return &RPCError{
err: err,
}
}
func newRPCError(format string, args ...interface{}) *RPCError {
return &RPCError{
fmt.Errorf(format, args...),
}
}
func (re *RPCError) Unwrap() error {
return re.err
}

View File

@ -0,0 +1,41 @@
package pogs
import (
"context"
"github.com/google/uuid"
)
// mockTunnelServerBase provides a placeholder implementation
// for TunnelServer interface that can be used to build
// mocks for specific unit tests without having to implement every method
type mockTunnelServerBase struct{}
func (mockTunnelServerBase) Register(ctx context.Context, auth []byte, tunnelUUID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
panic("unexpected call to Register")
}
func (mockTunnelServerBase) Unregister(ctx context.Context) {
panic("unexpected call to Unregister")
}
func (mockTunnelServerBase) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
panic("unexpected call to RegisterTunnel")
}
func (mockTunnelServerBase) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
panic("unexpected call to GetServerInfo")
}
func (mockTunnelServerBase) UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error {
panic("unexpected call to UnregisterTunnel")
}
func (mockTunnelServerBase) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
panic("unexpected call to Authenticate")
}
func (mockTunnelServerBase) ReconnectTunnel(ctx context.Context, jwt, eventDigest, connDigest []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
panic("unexpected call to ReconnectTunnel")
}

View File

@ -201,6 +201,7 @@ func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) {
}
type TunnelServer interface {
RegistrationServer
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration
GetServerInfo(ctx context.Context) (*ServerInfo, error)
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error

View File

@ -78,7 +78,55 @@ struct AuthenticateResponse {
hoursUntilRefresh @3 :UInt8;
}
interface TunnelServer {
struct ClientInfo {
# The tunnel client's unique identifier, used to verify a reconnection.
clientId @0 :Data;
# Set of features this cloudflared knows it supports
features @1 :List(Text);
# Information about the running binary.
version @2 :Text;
# Client OS and CPU info
arch @3 :Text;
}
struct ConnectionOptions {
# client details
client @0 :ClientInfo;
# origin LAN IP
originLocalIp @1 :Data;
# What to do if connection already exists
replaceExisting @2 :Bool;
# cross stream compression setting, 0 - off, 3 - high
compressionQuality @3 :UInt8;
}
struct ConnectionResponse {
result :union {
error @0 :ConnectionError;
connectionDetails @1 :ConnectionDetails;
}
}
struct ConnectionError {
cause @0 :Text;
# How long should this connection wait to retry in ns
retryAfter @1 :Int64;
shouldRetry @2 :Bool;
}
struct ConnectionDetails {
# identifier of this connection
uuid @0 :Data;
# airport code of the colo where this connection landed
locationName @1 :Text;
}
interface RegistrationServer {
registerConnection @0 (auth :Data, tunnelId :Data, connIndex :UInt8, options :ConnectionOptions) -> (result :ConnectionResponse);
unregisterConnection @1 () -> ();
}
interface TunnelServer extends (RegistrationServer) {
registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
getServerInfo @1 () -> (result :ServerInfo);
unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> ();

File diff suppressed because it is too large Load Diff