TUN-3015: Add a new cap'n'proto RPC interface for connection registration as well as matching client and server implementations. The old interface extends the new one for backward compatibility.
This commit is contained in:
parent
dc3a228d51
commit
448a7798f7
1
go.mod
1
go.mod
|
@ -54,7 +54,6 @@ require (
|
|||
github.com/prometheus/common v0.7.0 // indirect
|
||||
github.com/prometheus/procfs v0.0.5 // indirect
|
||||
github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 // indirect
|
||||
github.com/sirupsen/logrus v1.4.2 // indirect
|
||||
github.com/stretchr/testify v1.3.0
|
||||
github.com/tinylib/msgp v1.1.0 // indirect
|
||||
github.com/xo/dburl v0.0.0-20191005012637-293c3298d6c0
|
||||
|
|
|
@ -0,0 +1,208 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
)
|
||||
|
||||
type RegistrationServer interface {
|
||||
RegisterConnection(ctx context.Context, auth []byte, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error)
|
||||
UnregisterConnection(ctx context.Context)
|
||||
}
|
||||
|
||||
type ClientInfo struct {
|
||||
ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility
|
||||
Features []string
|
||||
Version string
|
||||
Arch string
|
||||
}
|
||||
|
||||
type ConnectionOptions struct {
|
||||
Client ClientInfo
|
||||
OriginLocalIP net.IP `capnp:"originLocalIp"`
|
||||
ReplaceExisting bool
|
||||
CompressionQuality uint8
|
||||
}
|
||||
|
||||
func (p *ConnectionOptions) MarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
|
||||
return pogs.Insert(tunnelrpc.ConnectionOptions_TypeID, s.Struct, p)
|
||||
}
|
||||
|
||||
func (p *ConnectionOptions) UnmarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
|
||||
return pogs.Extract(p, tunnelrpc.ConnectionOptions_TypeID, s.Struct)
|
||||
}
|
||||
|
||||
type ConnectionDetails struct {
|
||||
UUID uuid.UUID
|
||||
Location string
|
||||
}
|
||||
|
||||
func (details *ConnectionDetails) MarshalCapnproto(s tunnelrpc.ConnectionDetails) error {
|
||||
if err := s.SetUuid(details.UUID[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.SetLocationName(details.Location); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (details *ConnectionDetails) UnmarshalCapnproto(s tunnelrpc.ConnectionDetails) error {
|
||||
uuidBytes, err := s.Uuid()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
details.UUID, err = uuid.FromBytes(uuidBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
details.Location, err = s.LocationName()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func MarshalError(s tunnelrpc.ConnectionError, err error) error {
|
||||
if err := s.SetCause(err.Error()); err != nil {
|
||||
return err
|
||||
}
|
||||
if retryableErr, ok := err.(*RetryableError); ok {
|
||||
s.SetShouldRetry(true)
|
||||
s.SetRetryAfter(int64(retryableErr.Delay))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
|
||||
server.Ack(p.Options)
|
||||
|
||||
auth, err := p.Params.Auth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uuidBytes, err := p.Params.TunnelId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnelID, err := uuid.FromBytes(uuidBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
connIndex := p.Params.ConnIndex()
|
||||
options, err := p.Params.Options()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pogsOptions ConnectionOptions
|
||||
err = pogsOptions.UnmarshalCapnproto(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connDetails, callError := i.impl.RegisterConnection(p.Ctx, auth, tunnelID, connIndex, &pogsOptions)
|
||||
|
||||
resp, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if callError != nil {
|
||||
if connError, err := resp.Result().NewError(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return MarshalError(connError, callError)
|
||||
}
|
||||
}
|
||||
|
||||
if details, err := resp.Result().NewConnectionDetails(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return connDetails.MarshalCapnproto(details)
|
||||
}
|
||||
}
|
||||
|
||||
func (i TunnelServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
|
||||
server.Ack(p.Options)
|
||||
|
||||
i.impl.UnregisterConnection(p.Ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth []byte, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
|
||||
err := p.SetAuth(auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetTunnelId(tunnelID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.SetConnIndex(connIndex)
|
||||
connectionOptions, err := p.NewOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = options.MarshalCapnproto(connectionOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
response, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
result := response.Result()
|
||||
switch result.Which() {
|
||||
case tunnelrpc.ConnectionResponse_result_Which_error:
|
||||
resultError, err := result.Error()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
cause, err := resultError.Cause()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
err = errors.New(cause)
|
||||
if resultError.ShouldRetry() {
|
||||
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
|
||||
}
|
||||
return nil, err
|
||||
|
||||
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
|
||||
connDetails, err := result.ConnectionDetails()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
details := new(ConnectionDetails)
|
||||
if err = details.UnmarshalCapnproto(connDetails); err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
return details, nil
|
||||
}
|
||||
|
||||
return nil, newRPCError("unknown result which %d", result.Which())
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) Unregister(ctx context.Context) error {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
|
||||
return nil
|
||||
})
|
||||
_, err := promise.Struct()
|
||||
return wrapRPCError(err)
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
)
|
||||
|
||||
func TestMarshalConnectionOptions(t *testing.T) {
|
||||
clientID := uuid.New()
|
||||
orig := ConnectionOptions{
|
||||
Client: ClientInfo{
|
||||
ClientID: clientID[:],
|
||||
Features: []string{"a", "b"},
|
||||
Version: "1.2.3",
|
||||
Arch: "macos",
|
||||
},
|
||||
OriginLocalIP: []byte{10, 2, 3, 4},
|
||||
ReplaceExisting: false,
|
||||
CompressionQuality: 1,
|
||||
}
|
||||
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
require.NoError(t, err)
|
||||
capnpOpts, err := tunnelrpc.NewConnectionOptions(seg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = orig.MarshalCapnproto(capnpOpts)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var pogsOpts ConnectionOptions
|
||||
err = pogsOpts.UnmarshalCapnproto(capnpOpts)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, orig, pogsOpts)
|
||||
}
|
||||
|
||||
func TestConnectionRegistrationRPC(t *testing.T) {
|
||||
p1, p2 := net.Pipe()
|
||||
t1, t2 := rpc.StreamTransport(p1), rpc.StreamTransport(p2)
|
||||
|
||||
// Server-side
|
||||
testImpl := testConnectionRegistrationServer{}
|
||||
srv := TunnelServer_ServerToClient(&testImpl)
|
||||
serverConn := rpc.NewConn(t1, rpc.MainInterface(srv.Client))
|
||||
defer serverConn.Wait()
|
||||
|
||||
ctx := context.Background()
|
||||
clientConn := rpc.NewConn(t2)
|
||||
defer clientConn.Close()
|
||||
client := TunnelServer_PogsClient{
|
||||
Client: clientConn.Bootstrap(ctx),
|
||||
Conn: clientConn,
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
clientID := uuid.New()
|
||||
options := &ConnectionOptions{
|
||||
Client: ClientInfo{
|
||||
ClientID: clientID[:],
|
||||
Features: []string{"foo"},
|
||||
Version: "1.2.3",
|
||||
Arch: "macos",
|
||||
},
|
||||
OriginLocalIP: net.IP{10, 20, 30, 40},
|
||||
ReplaceExisting: true,
|
||||
CompressionQuality: 0,
|
||||
}
|
||||
|
||||
expectedDetails := ConnectionDetails{
|
||||
UUID: uuid.New(),
|
||||
Location: "TEST",
|
||||
}
|
||||
testImpl.details = &expectedDetails
|
||||
testImpl.err = nil
|
||||
|
||||
// success
|
||||
tunnelID := uuid.New()
|
||||
details, err := client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedDetails, *details)
|
||||
|
||||
// regular error
|
||||
testImpl.details = nil
|
||||
testImpl.err = errors.New("internal")
|
||||
|
||||
_, err = client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
||||
assert.EqualError(t, err, "internal")
|
||||
|
||||
// retriable error
|
||||
testImpl.details = nil
|
||||
const delay = 27*time.Second
|
||||
testImpl.err = RetryErrorAfter(errors.New("retryable"), delay)
|
||||
|
||||
_, err = client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
||||
assert.EqualError(t, err, "retryable")
|
||||
|
||||
re, ok := err.(*RetryableError)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, delay, re.Delay)
|
||||
}
|
||||
|
||||
type testConnectionRegistrationServer struct {
|
||||
mockTunnelServerBase
|
||||
|
||||
details *ConnectionDetails
|
||||
err error
|
||||
}
|
||||
|
||||
func (t testConnectionRegistrationServer) Register(ctx context.Context, auth []byte, tunnelUUID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
if t.err != nil {
|
||||
return nil, t.err
|
||||
}
|
||||
if t.details != nil {
|
||||
return t.details, nil
|
||||
}
|
||||
|
||||
panic("either details or err mush be set")
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RetryableError struct {
|
||||
err error
|
||||
Delay time.Duration
|
||||
}
|
||||
|
||||
func (re *RetryableError) Error() string {
|
||||
return re.err.Error()
|
||||
}
|
||||
|
||||
// RetryErrorAfter wraps err to indicate that client should retry after delay
|
||||
func RetryErrorAfter(err error, delay time.Duration) *RetryableError {
|
||||
return &RetryableError{
|
||||
err: err,
|
||||
Delay: delay,
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RetryableError) Unwrap() error {
|
||||
return re.err
|
||||
}
|
||||
|
||||
// RPCError is used to indicate errors returned by the RPC subsystem rather
|
||||
// than failure of a remote operation
|
||||
type RPCError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (re *RPCError) Error() string {
|
||||
return re.err.Error()
|
||||
}
|
||||
|
||||
func wrapRPCError(err error) *RPCError {
|
||||
return &RPCError{
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func newRPCError(format string, args ...interface{}) *RPCError {
|
||||
return &RPCError{
|
||||
fmt.Errorf(format, args...),
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RPCError) Unwrap() error {
|
||||
return re.err
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// mockTunnelServerBase provides a placeholder implementation
|
||||
// for TunnelServer interface that can be used to build
|
||||
// mocks for specific unit tests without having to implement every method
|
||||
type mockTunnelServerBase struct{}
|
||||
|
||||
func (mockTunnelServerBase) Register(ctx context.Context, auth []byte, tunnelUUID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
panic("unexpected call to Register")
|
||||
}
|
||||
|
||||
func (mockTunnelServerBase) Unregister(ctx context.Context) {
|
||||
panic("unexpected call to Unregister")
|
||||
}
|
||||
|
||||
func (mockTunnelServerBase) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
|
||||
panic("unexpected call to RegisterTunnel")
|
||||
}
|
||||
|
||||
func (mockTunnelServerBase) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
|
||||
panic("unexpected call to GetServerInfo")
|
||||
}
|
||||
|
||||
func (mockTunnelServerBase) UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error {
|
||||
panic("unexpected call to UnregisterTunnel")
|
||||
}
|
||||
|
||||
func (mockTunnelServerBase) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
|
||||
panic("unexpected call to Authenticate")
|
||||
}
|
||||
|
||||
func (mockTunnelServerBase) ReconnectTunnel(ctx context.Context, jwt, eventDigest, connDigest []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
|
||||
panic("unexpected call to ReconnectTunnel")
|
||||
}
|
||||
|
|
@ -201,6 +201,7 @@ func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) {
|
|||
}
|
||||
|
||||
type TunnelServer interface {
|
||||
RegistrationServer
|
||||
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration
|
||||
GetServerInfo(ctx context.Context) (*ServerInfo, error)
|
||||
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
|
||||
|
|
|
@ -78,7 +78,55 @@ struct AuthenticateResponse {
|
|||
hoursUntilRefresh @3 :UInt8;
|
||||
}
|
||||
|
||||
interface TunnelServer {
|
||||
struct ClientInfo {
|
||||
# The tunnel client's unique identifier, used to verify a reconnection.
|
||||
clientId @0 :Data;
|
||||
# Set of features this cloudflared knows it supports
|
||||
features @1 :List(Text);
|
||||
# Information about the running binary.
|
||||
version @2 :Text;
|
||||
# Client OS and CPU info
|
||||
arch @3 :Text;
|
||||
}
|
||||
|
||||
struct ConnectionOptions {
|
||||
# client details
|
||||
client @0 :ClientInfo;
|
||||
# origin LAN IP
|
||||
originLocalIp @1 :Data;
|
||||
# What to do if connection already exists
|
||||
replaceExisting @2 :Bool;
|
||||
# cross stream compression setting, 0 - off, 3 - high
|
||||
compressionQuality @3 :UInt8;
|
||||
}
|
||||
|
||||
struct ConnectionResponse {
|
||||
result :union {
|
||||
error @0 :ConnectionError;
|
||||
connectionDetails @1 :ConnectionDetails;
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectionError {
|
||||
cause @0 :Text;
|
||||
# How long should this connection wait to retry in ns
|
||||
retryAfter @1 :Int64;
|
||||
shouldRetry @2 :Bool;
|
||||
}
|
||||
|
||||
struct ConnectionDetails {
|
||||
# identifier of this connection
|
||||
uuid @0 :Data;
|
||||
# airport code of the colo where this connection landed
|
||||
locationName @1 :Text;
|
||||
}
|
||||
|
||||
interface RegistrationServer {
|
||||
registerConnection @0 (auth :Data, tunnelId :Data, connIndex :UInt8, options :ConnectionOptions) -> (result :ConnectionResponse);
|
||||
unregisterConnection @1 () -> ();
|
||||
}
|
||||
|
||||
interface TunnelServer extends (RegistrationServer) {
|
||||
registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||
getServerInfo @1 () -> (result :ServerInfo);
|
||||
unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> ();
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue