TUN-2607: add RPC stream helpers
This commit is contained in:
parent
8f4fd70783
commit
bbf31377c2
|
@ -4,15 +4,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -41,32 +38,15 @@ func (c *Connection) Serve(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect is used to establish connections with cloudflare's edge network
|
// Connect is used to establish connections with cloudflare's edge network
|
||||||
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (pogs.ConnectResult, error) {
|
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (tunnelpogs.ConnectResult, error) {
|
||||||
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
|
tsClient, err := NewRPCClient(ctx, c.muxer, logger.WithField("rpc", "connect"), openStreamTimeout)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
rpcConn, err := c.newRPConn(openStreamCtx, logger)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot create new RPC connection")
|
return nil, errors.Wrap(err, "cannot create new RPC connection")
|
||||||
}
|
}
|
||||||
defer rpcConn.Close()
|
defer tsClient.Close()
|
||||||
|
|
||||||
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
|
|
||||||
|
|
||||||
return tsClient.Connect(ctx, parameters)
|
return tsClient.Connect(ctx, parameters)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) Shutdown() {
|
func (c *Connection) Shutdown() {
|
||||||
c.muxer.Shutdown()
|
c.muxer.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
|
|
||||||
stream, err := c.muxer.OpenRPCStream(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return rpc.NewConn(
|
|
||||||
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
|
|
||||||
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
|
|
||||||
), nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,15 +7,15 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/streamhandler"
|
"github.com/cloudflare/cloudflared/streamhandler"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -58,12 +58,12 @@ func newMetrics(namespace, subsystem string) *metrics {
|
||||||
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
|
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
|
||||||
type EdgeManagerConfigurable struct {
|
type EdgeManagerConfigurable struct {
|
||||||
TunnelHostnames []h2mux.TunnelHostname
|
TunnelHostnames []h2mux.TunnelHostname
|
||||||
*pogs.EdgeConnectionConfig
|
*tunnelpogs.EdgeConnectionConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloudflaredConfig struct {
|
type CloudflaredConfig struct {
|
||||||
CloudflaredID uuid.UUID
|
CloudflaredID uuid.UUID
|
||||||
Tags []pogs.Tag
|
Tags []tunnelpogs.Tag
|
||||||
BuildInfo *buildinfo.BuildInfo
|
BuildInfo *buildinfo.BuildInfo
|
||||||
IntentLabel string
|
IntentLabel string
|
||||||
}
|
}
|
||||||
|
@ -126,7 +126,7 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab
|
||||||
em.state.updateConfigurable(newConfigurable)
|
em.state.updateConfigurable(newConfigurable)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
|
func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError {
|
||||||
edgeTCPAddr := em.serviceDiscoverer.Addr()
|
edgeTCPAddr := em.serviceDiscoverer.Addr()
|
||||||
configurable := em.state.getConfigurable()
|
configurable := em.state.getConfigurable()
|
||||||
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
|
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
|
||||||
|
@ -154,7 +154,7 @@ func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
|
||||||
|
|
||||||
go em.serveConn(ctx, h2muxConn)
|
go em.serveConn(ctx, h2muxConn)
|
||||||
|
|
||||||
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{
|
connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
|
||||||
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
|
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
|
||||||
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
|
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
|
||||||
NumPreviousAttempts: 0,
|
NumPreviousAttempts: 0,
|
||||||
|
@ -285,8 +285,8 @@ func (ems *edgeManagerState) getUserCredential() []byte {
|
||||||
return ems.userCredential
|
return ems.userCredential
|
||||||
}
|
}
|
||||||
|
|
||||||
func retryConnection(cause string) *pogs.ConnectError {
|
func retryConnection(cause string) *tunnelpogs.ConnectError {
|
||||||
return &pogs.ConnectError{
|
return &tunnelpogs.ConnectError{
|
||||||
Cause: cause,
|
Cause: cause,
|
||||||
RetryAfter: defaultRetryAfter,
|
RetryAfter: defaultRetryAfter,
|
||||||
ShouldRetry: true,
|
ShouldRetry: true,
|
||||||
|
|
|
@ -4,15 +4,15 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
"github.com/google/uuid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/streamhandler"
|
"github.com/cloudflare/cloudflared/streamhandler"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRPCClient creates and returns a new RPC client, which will communicate
|
||||||
|
// using a stream on the given muxer
|
||||||
|
func NewRPCClient(
|
||||||
|
ctx context.Context,
|
||||||
|
muxer *h2mux.Muxer,
|
||||||
|
logger *logrus.Entry,
|
||||||
|
openStreamTimeout time.Duration,
|
||||||
|
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
|
||||||
|
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
|
||||||
|
defer openStreamCancel()
|
||||||
|
stream, err := muxer.OpenRPCStream(openStreamCtx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isRPCStreamResponse(stream.Headers) {
|
||||||
|
stream.Close()
|
||||||
|
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := rpc.NewConn(
|
||||||
|
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
||||||
|
tunnelrpc.ConnLog(logger),
|
||||||
|
)
|
||||||
|
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRPCStreamResponse(headers []h2mux.Header) bool {
|
||||||
|
return len(headers) == 1 &&
|
||||||
|
headers[0].Name == ":status" &&
|
||||||
|
headers[0].Value == "200"
|
||||||
|
}
|
|
@ -6,12 +6,11 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -14,23 +14,21 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
"github.com/cloudflare/cloudflared/streamhandler"
|
"github.com/cloudflare/cloudflared/streamhandler"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflared/validation"
|
"github.com/cloudflare/cloudflared/validation"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -288,16 +286,6 @@ func ServeTunnel(
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
|
||||||
if len(headers) != 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if headers[0].Name != ":status" || headers[0].Value != "200" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterTunnel(
|
func RegisterTunnel(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
muxer *h2mux.Muxer,
|
muxer *h2mux.Muxer,
|
||||||
|
@ -308,28 +296,18 @@ func RegisterTunnel(
|
||||||
uuid uuid.UUID,
|
uuid uuid.UUID,
|
||||||
) error {
|
) error {
|
||||||
config.TransportLogger.Debug("initiating RPC stream to register")
|
config.TransportLogger.Debug("initiating RPC stream to register")
|
||||||
stream, err := openStream(ctx, muxer)
|
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-register"), openStreamTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// RPC stream open error
|
// RPC stream open error
|
||||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
|
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
|
||||||
}
|
}
|
||||||
if !IsRPCStreamResponse(stream.Headers) {
|
defer tunnelServer.Close()
|
||||||
// stream response error
|
|
||||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
|
|
||||||
}
|
|
||||||
conn := rpc.NewConn(
|
|
||||||
tunnelrpc.NewTransportLogger(config.TransportLogger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
|
|
||||||
tunnelrpc.ConnLog(config.TransportLogger.WithField("subsystem", "rpc-transport")),
|
|
||||||
)
|
|
||||||
defer conn.Close()
|
|
||||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
|
||||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||||
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
|
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||||
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
|
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
|
||||||
registration := ts.RegisterTunnel(
|
registration := tunnelServer.RegisterTunnel(
|
||||||
ctx,
|
ctx,
|
||||||
config.OriginCert,
|
config.OriginCert,
|
||||||
config.Hostname,
|
config.Hostname,
|
||||||
|
@ -369,7 +347,7 @@ func RegisterTunnel(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
|
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
|
||||||
if err.Error() == DuplicateConnectionError {
|
if err.Error() == DuplicateConnectionError {
|
||||||
metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
|
metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
|
||||||
return dupConnRegisterTunnelError{}
|
return dupConnRegisterTunnelError{}
|
||||||
|
@ -384,35 +362,15 @@ func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *Tunne
|
||||||
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
|
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
|
||||||
logger.Debug("initiating RPC stream to unregister")
|
logger.Debug("initiating RPC stream to unregister")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
stream, err := openStream(ctx, muxer)
|
ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// RPC stream open error
|
// RPC stream open error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !IsRPCStreamResponse(stream.Headers) {
|
|
||||||
// stream response error
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn := rpc.NewConn(
|
|
||||||
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
|
|
||||||
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
|
||||||
)
|
|
||||||
defer conn.Close()
|
|
||||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
|
||||||
// gracePeriod is encoded in int64 using capnproto
|
// gracePeriod is encoded in int64 using capnproto
|
||||||
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
|
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
|
||||||
}
|
}
|
||||||
|
|
||||||
func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) {
|
|
||||||
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
|
|
||||||
defer cancel()
|
|
||||||
return muxer.OpenStream(openStreamCtx, []h2mux.Header{
|
|
||||||
{Name: ":method", Value: "RPC"},
|
|
||||||
{Name: ":scheme", Value: "capnp"},
|
|
||||||
{Name: ":path", Value: "*"},
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LogServerInfo(
|
func LogServerInfo(
|
||||||
promise tunnelrpc.ServerInfo_Promise,
|
promise tunnelrpc.ServerInfo_Promise,
|
||||||
connectionID uint8,
|
connectionID uint8,
|
||||||
|
|
Loading…
Reference in New Issue