TUN-3427: Define a struct that only implements RegistrationServer in tunnelpogs

This commit is contained in:
cthuang 2020-09-28 10:10:30 +01:00
parent 8e8513e325
commit 2c9b7361b7
9 changed files with 242 additions and 201 deletions

View File

@ -2,48 +2,40 @@ package connection
import ( import (
"context" "context"
"fmt" "io"
"time"
rpc "zombiezen.com/go/capnproto2/rpc" rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
// NewRPCClient creates and returns a new RPC client, which will communicate // NewTunnelRPCClient creates and returns a new RPC client, which will communicate
// using a stream on the given muxer // using a stream on the given muxer
func NewRPCClient( func NewTunnelRPCClient(
ctx context.Context, ctx context.Context,
muxer *h2mux.Muxer, stream io.ReadWriteCloser,
logger logger.Service, logger logger.Service,
openStreamTimeout time.Duration,
) (client tunnelpogs.TunnelServer_PogsClient, err error) { ) (client tunnelpogs.TunnelServer_PogsClient, err error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return
}
if !isRPCStreamResponse(stream.Headers) {
stream.Close()
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
return
}
conn := rpc.NewConn( conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger), tunnelrpc.ConnLog(logger),
) )
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn} registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
client = tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn}
return client, nil return client, nil
} }
func isRPCStreamResponse(headers []h2mux.Header) bool { func NewRegistrationRPCClient(
return len(headers) == 1 && ctx context.Context,
headers[0].Name == ":status" && stream io.ReadWriteCloser,
headers[0].Value == "200" logger logger.Service,
) (client tunnelpogs.RegistrationServer_PogsClient, err error) {
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger),
)
client = tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
return client, nil
} }

View File

@ -19,6 +19,7 @@ var (
ErrUnexpectedFrameType = MuxerProtocolError{"2001 unexpected frame type", http2.ErrCodeProtocol} ErrUnexpectedFrameType = MuxerProtocolError{"2001 unexpected frame type", http2.ErrCodeProtocol}
ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol} ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol}
ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol} ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol}
ErrNotRPCStream = MuxerProtocolError{"2004 not RPC stream", http2.ErrCodeProtocol}
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"} ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
ErrStreamRequestConnectionClosed = MuxerApplicationError{"3001 connection closed while opening stream"} ErrStreamRequestConnectionClosed = MuxerApplicationError{"3001 connection closed while opening stream"}

View File

@ -22,7 +22,7 @@ const (
defaultTimeout time.Duration = 5 * time.Second defaultTimeout time.Duration = 5 * time.Second
defaultRetries uint64 = 5 defaultRetries uint64 = 5
defaultWriteBufferMaxLen int = 1024 * 1024 // 1mb defaultWriteBufferMaxLen int = 1024 * 1024 // 1mb
writeBufferInitialSize int = 16 * 1024 // 16KB writeBufferInitialSize int = 16 * 1024 // 16KB
SettingMuxerMagic http2.SettingID = 0x42db SettingMuxerMagic http2.SettingID = 0x42db
MuxerMagicOrigin uint32 = 0xa2e43c8b MuxerMagicOrigin uint32 = 0xa2e43c8b
@ -441,11 +441,17 @@ func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) { func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
stream := m.NewStream(RPCHeaders()) stream := m.NewStream(RPCHeaders())
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil { if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
stream.Close()
return nil, err return nil, err
} }
if err := m.AwaitResponseHeaders(ctx, stream); err != nil { if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
stream.Close()
return nil, err return nil, err
} }
if !IsRPCStreamResponse(stream) {
stream.Close()
return nil, ErrNotRPCStream
}
return stream, nil return stream, nil
} }
@ -499,3 +505,10 @@ func (m *Muxer) abort() {
func (m *Muxer) TimerRetries() uint64 { func (m *Muxer) TimerRetries() uint64 {
return m.muxWriter.idleTimer.RetryCount() return m.muxWriter.idleTimer.RetryCount()
} }
func IsRPCStreamResponse(stream *MuxedStream) bool {
headers := stream.Headers
return len(headers) == 1 &&
headers[0].Name == ":status" &&
headers[0].Value == "200"
}

View File

@ -7,7 +7,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
@ -164,18 +163,17 @@ func ReconnectTunnel(
} }
config.TransportLogger.Debug("initiating RPC stream to reconnect") config.TransportLogger.Debug("initiating RPC stream to reconnect")
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout) rpcClient, err := newTunnelRPCClient(ctx, muxer, config, reconnect)
if err != nil { if err != nil {
// RPC stream open error return err
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect)
} }
defer tunnelServer.Close() defer rpcClient.Close()
// Request server info without blocking tunnel registration; must use capnp library directly. // Request server info without blocking tunnel registration; must use capnp library directly.
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil return nil
}) })
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan) LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
registration := tunnelServer.ReconnectTunnel( registration := rpcClient.ReconnectTunnel(
ctx, ctx,
token, token,
eventDigest, eventDigest,

View File

@ -323,16 +323,16 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
<-muxer.Shutdown() <-muxer.Shutdown()
}() }()
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger, openStreamTimeout) rpcClient, err := newTunnelRPCClient(ctx, muxer, s.config, authenticate)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tunnelServer.Close() defer rpcClient.Close()
const arbitraryConnectionID = uint8(0) const arbitraryConnectionID = uint8(0)
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
authResponse, err := tunnelServer.Authenticate( authResponse, err := rpcClient.Authenticate(
ctx, ctx,
s.config.OriginCert, s.config.OriginCert,
s.config.Hostname, s.config.Hostname,

View File

@ -44,11 +44,13 @@ const (
FeatureQuickReconnects = "quick_reconnects" FeatureQuickReconnects = "quick_reconnects"
) )
type registerRPCName string type rpcName string
const ( const (
register registerRPCName = "register" register rpcName = "register"
reconnect registerRPCName = "reconnect" reconnect rpcName = "reconnect"
unregister rpcName = "unregister"
authenticate rpcName = " authenticate"
) )
type TunnelConfig struct { type TunnelConfig struct {
@ -121,7 +123,7 @@ type clientRegisterTunnelError struct {
cause error cause error
} }
func newClientRegisterTunnelError(cause error, counter *prometheus.CounterVec, name registerRPCName) clientRegisterTunnelError { func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError {
counter.WithLabelValues(cause.Error(), string(name)).Inc() counter.WithLabelValues(cause.Error(), string(name)).Inc()
return clientRegisterTunnelError{cause: cause} return clientRegisterTunnelError{cause: cause}
} }
@ -337,7 +339,7 @@ func ServeTunnel(
if config.NamedTunnel != nil { if config.NamedTunnel != nil {
_ = UnregisterConnection(ctx, handler.muxer, config) _ = UnregisterConnection(ctx, handler.muxer, config)
} else { } else {
_ = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger) _ = UnregisterTunnel(handler.muxer, config)
} }
} }
handler.muxer.Shutdown() handler.muxer.Shutdown()
@ -417,14 +419,13 @@ func RegisterConnection(
const registerConnection = "registerConnection" const registerConnection = "registerConnection"
config.TransportLogger.Debug("initiating RPC stream for RegisterConnection") config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout) rpcClient, err := newTunnelRPCClient(ctx, muxer, config, registerConnection)
if err != nil { if err != nil {
// RPC stream open error return err
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, registerConnection)
} }
defer rpc.Close() defer rpcClient.Close()
conn, err := rpc.RegisterConnection( conn, err := rpcClient.RegisterConnection(
ctx, ctx,
config.NamedTunnel.Auth, config.NamedTunnel.Auth,
config.NamedTunnel.ID, config.NamedTunnel.ID,
@ -470,14 +471,14 @@ func UnregisterConnection(
config *TunnelConfig, config *TunnelConfig,
) error { ) error {
config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection") config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout) rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
if err != nil { if err != nil {
// RPC stream open error // RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register) return err
} }
defer rpc.Close() defer rpcClient.Close()
return rpc.UnregisterConnection(ctx) return rpcClient.UnregisterConnection(ctx)
} }
func RegisterTunnel( func RegisterTunnel(
@ -494,18 +495,18 @@ func RegisterTunnel(
if config.TunnelEventChan != nil { if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel} config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel}
} }
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
if err != nil { if err != nil {
// RPC stream open error return err
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
} }
defer tunnelServer.Close() defer rpcClient.Close()
// Request server info without blocking tunnel registration; must use capnp library directly. // Request server info without blocking tunnel registration; must use capnp library directly.
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil return nil
}) })
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan) LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
registration := tunnelServer.RegisterTunnel( registration := rpcClient.RegisterTunnel(
ctx, ctx,
config.OriginCert, config.OriginCert,
config.Hostname, config.Hostname,
@ -529,7 +530,7 @@ func processRegistrationSuccess(
logger logger.Service, logger logger.Service,
connectionID uint8, connectionID uint8,
registration *tunnelpogs.TunnelRegistration, registration *tunnelpogs.TunnelRegistration,
name registerRPCName, name rpcName,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
) error { ) error {
for _, logLine := range registration.LogLines { for _, logLine := range registration.LogLines {
@ -563,7 +564,7 @@ func processRegistrationSuccess(
return nil return nil
} }
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error { func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name rpcName) error {
if err.Error() == DuplicateConnectionError { if err.Error() == DuplicateConnectionError {
metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc() metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
return errDuplicationConnection return errDuplicationConnection
@ -575,18 +576,18 @@ func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics
} }
} }
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger logger.Service) error { func UnregisterTunnel(muxer *h2mux.Muxer, config *TunnelConfig) error {
logger.Debug("initiating RPC stream to unregister") config.TransportLogger.Debug("initiating RPC stream to unregister")
ctx := context.Background() ctx := context.Background()
tunnelServer, err := connection.NewRPCClient(ctx, muxer, logger, openStreamTimeout) rpcClient, err := newTunnelRPCClient(ctx, muxer, config, unregister)
if err != nil { if err != nil {
// RPC stream open error // RPC stream open error
return err return err
} }
defer tunnelServer.Close() defer rpcClient.Close()
// gracePeriod is encoded in int64 using capnproto // gracePeriod is encoded in int64 using capnproto
return tunnelServer.UnregisterTunnel(ctx, gracePeriod.Nanoseconds()) return rpcClient.UnregisterTunnel(ctx, config.GracePeriod.Nanoseconds())
} }
func LogServerInfo( func LogServerInfo(
@ -909,3 +910,18 @@ func findCfRayHeader(h1 *http.Request) string {
func isLBProbeRequest(req *http.Request) bool { func isLBProbeRequest(req *http.Request) bool {
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
} }
func newTunnelRPCClient(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, rpcName rpcName) (tunnelpogs.TunnelServer_PogsClient, error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return tunnelpogs.TunnelServer_PogsClient{}, err
}
rpcClient, err := connection.NewTunnelRPCClient(ctx, stream, config.TransportLogger)
if err != nil {
// RPC stream open error
return tunnelpogs.TunnelServer_PogsClient{}, newRPCError(err, config.Metrics.rpcFail, rpcName)
}
return rpcClient, nil
}

View File

@ -7,7 +7,9 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs" "zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server" "zombiezen.com/go/capnproto2/server"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
@ -18,6 +20,156 @@ type RegistrationServer interface {
UnregisterConnection(ctx context.Context) UnregisterConnection(ctx context.Context)
} }
type RegistrationServer_PogsImpl struct {
impl RegistrationServer
}
func RegistrationServer_ServerToClient(s RegistrationServer) tunnelrpc.RegistrationServer {
return tunnelrpc.RegistrationServer_ServerToClient(RegistrationServer_PogsImpl{s})
}
func (i RegistrationServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
server.Ack(p.Options)
auth, err := p.Params.Auth()
if err != nil {
return err
}
var pogsAuth TunnelAuth
err = pogsAuth.UnmarshalCapnproto(auth)
if err != nil {
return err
}
uuidBytes, err := p.Params.TunnelId()
if err != nil {
return err
}
tunnelID, err := uuid.FromBytes(uuidBytes)
if err != nil {
return err
}
connIndex := p.Params.ConnIndex()
options, err := p.Params.Options()
if err != nil {
return err
}
var pogsOptions ConnectionOptions
err = pogsOptions.UnmarshalCapnproto(options)
if err != nil {
return err
}
connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
resp, err := p.Results.NewResult()
if err != nil {
return err
}
if callError != nil {
if connError, err := resp.Result().NewError(); err != nil {
return err
} else {
return MarshalError(connError, callError)
}
}
if details, err := resp.Result().NewConnectionDetails(); err != nil {
return err
} else {
return connDetails.MarshalCapnproto(details)
}
}
func (i RegistrationServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
server.Ack(p.Options)
i.impl.UnregisterConnection(p.Ctx)
return nil
}
type RegistrationServer_PogsClient struct {
Client capnp.Client
Conn *rpc.Conn
}
func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
tunnelAuth, err := p.NewAuth()
if err != nil {
return err
}
if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
return err
}
err = p.SetAuth(tunnelAuth)
if err != nil {
return err
}
err = p.SetTunnelId(tunnelID[:])
if err != nil {
return err
}
p.SetConnIndex(connIndex)
connectionOptions, err := p.NewOptions()
if err != nil {
return err
}
err = options.MarshalCapnproto(connectionOptions)
if err != nil {
return err
}
return nil
})
response, err := promise.Result().Struct()
if err != nil {
return nil, wrapRPCError(err)
}
result := response.Result()
switch result.Which() {
case tunnelrpc.ConnectionResponse_result_Which_error:
resultError, err := result.Error()
if err != nil {
return nil, wrapRPCError(err)
}
cause, err := resultError.Cause()
if err != nil {
return nil, wrapRPCError(err)
}
err = errors.New(cause)
if resultError.ShouldRetry() {
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
}
return nil, err
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
connDetails, err := result.ConnectionDetails()
if err != nil {
return nil, wrapRPCError(err)
}
details := new(ConnectionDetails)
if err = details.UnmarshalCapnproto(connDetails); err != nil {
return nil, wrapRPCError(err)
}
return details, nil
}
return nil, newRPCError("unknown result which %d", result.Which())
}
func (c RegistrationServer_PogsClient) UnregisterConnection(ctx context.Context) error {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
return nil
})
_, err := promise.Struct()
if err != nil {
return wrapRPCError(err)
}
return nil
}
type ClientInfo struct { type ClientInfo struct {
ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility
Features []string Features []string
@ -98,140 +250,3 @@ func MarshalError(s tunnelrpc.ConnectionError, err error) error {
return nil return nil
} }
func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
server.Ack(p.Options)
auth, err := p.Params.Auth()
if err != nil {
return err
}
var pogsAuth TunnelAuth
err = pogsAuth.UnmarshalCapnproto(auth)
if err != nil {
return err
}
uuidBytes, err := p.Params.TunnelId()
if err != nil {
return err
}
tunnelID, err := uuid.FromBytes(uuidBytes)
if err != nil {
return err
}
connIndex := p.Params.ConnIndex()
options, err := p.Params.Options()
if err != nil {
return err
}
var pogsOptions ConnectionOptions
err = pogsOptions.UnmarshalCapnproto(options)
if err != nil {
return err
}
connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
resp, err := p.Results.NewResult()
if err != nil {
return err
}
if callError != nil {
if connError, err := resp.Result().NewError(); err != nil {
return err
} else {
return MarshalError(connError, callError)
}
}
if details, err := resp.Result().NewConnectionDetails(); err != nil {
return err
} else {
return connDetails.MarshalCapnproto(details)
}
}
func (i TunnelServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
server.Ack(p.Options)
i.impl.UnregisterConnection(p.Ctx)
return nil
}
func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
tunnelAuth, err := p.NewAuth()
if err != nil {
return err
}
if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
return err
}
err = p.SetAuth(tunnelAuth)
if err != nil {
return err
}
err = p.SetTunnelId(tunnelID[:])
if err != nil {
return err
}
p.SetConnIndex(connIndex)
connectionOptions, err := p.NewOptions()
if err != nil {
return err
}
err = options.MarshalCapnproto(connectionOptions)
if err != nil {
return err
}
return nil
})
response, err := promise.Result().Struct()
if err != nil {
return nil, wrapRPCError(err)
}
result := response.Result()
switch result.Which() {
case tunnelrpc.ConnectionResponse_result_Which_error:
resultError, err := result.Error()
if err != nil {
return nil, wrapRPCError(err)
}
cause, err := resultError.Cause()
if err != nil {
return nil, wrapRPCError(err)
}
err = errors.New(cause)
if resultError.ShouldRetry() {
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
}
return nil, err
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
connDetails, err := result.ConnectionDetails()
if err != nil {
return nil, wrapRPCError(err)
}
details := new(ConnectionDetails)
if err = details.UnmarshalCapnproto(connDetails); err != nil {
return nil, wrapRPCError(err)
}
return details, nil
}
return nil, newRPCError("unknown result which %d", result.Which())
}
func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
return nil
})
_, err := promise.Struct()
if err != nil {
return wrapRPCError(err)
}
return nil
}

View File

@ -62,6 +62,10 @@ func TestConnectionRegistrationRPC(t *testing.T) {
clientConn := rpc.NewConn(t2) clientConn := rpc.NewConn(t2)
defer clientConn.Close() defer clientConn.Close()
client := TunnelServer_PogsClient{ client := TunnelServer_PogsClient{
RegistrationServer_PogsClient: RegistrationServer_PogsClient{
Client: clientConn.Bootstrap(ctx),
Conn: clientConn,
},
Client: clientConn.Bootstrap(ctx), Client: clientConn.Bootstrap(ctx),
Conn: clientConn, Conn: clientConn,
} }

View File

@ -210,10 +210,11 @@ type TunnelServer interface {
} }
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer { func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{s}) return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{RegistrationServer_PogsImpl{s}, s})
} }
type TunnelServer_PogsImpl struct { type TunnelServer_PogsImpl struct {
RegistrationServer_PogsImpl
impl TunnelServer impl TunnelServer
} }
@ -268,6 +269,7 @@ func (i TunnelServer_PogsImpl) ObsoleteDeclarativeTunnelConnect(p tunnelrpc.Tunn
} }
type TunnelServer_PogsClient struct { type TunnelServer_PogsClient struct {
RegistrationServer_PogsClient
Client capnp.Client Client capnp.Client
Conn *rpc.Conn Conn *rpc.Conn
} }