TUN-2607: add RPC stream helpers

This commit is contained in:
Nick Vollmar 2019-11-21 11:03:13 -06:00
parent 8f4fd70783
commit bbf31377c2
6 changed files with 83 additions and 97 deletions

View File

@ -4,15 +4,12 @@ import (
"context" "context"
"time" "time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc" "github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
const ( const (
@ -41,32 +38,15 @@ func (c *Connection) Serve(ctx context.Context) error {
} }
// Connect is used to establish connections with cloudflare's edge network // Connect is used to establish connections with cloudflare's edge network
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (pogs.ConnectResult, error) { func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (tunnelpogs.ConnectResult, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout) tsClient, err := NewRPCClient(ctx, c.muxer, logger.WithField("rpc", "connect"), openStreamTimeout)
defer cancel()
rpcConn, err := c.newRPConn(openStreamCtx, logger)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cannot create new RPC connection") return nil, errors.Wrap(err, "cannot create new RPC connection")
} }
defer rpcConn.Close() defer tsClient.Close()
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
return tsClient.Connect(ctx, parameters) return tsClient.Connect(ctx, parameters)
} }
func (c *Connection) Shutdown() { func (c *Connection) Shutdown() {
c.muxer.Shutdown() c.muxer.Shutdown()
} }
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
stream, err := c.muxer.OpenRPCStream(ctx)
if err != nil {
return nil, err
}
return rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
), nil
}

View File

@ -7,15 +7,15 @@ import (
"sync" "sync"
"time" "time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/prometheus/client_golang/prometheus"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
) )
const ( const (
@ -58,12 +58,12 @@ func newMetrics(namespace, subsystem string) *metrics {
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager // EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
type EdgeManagerConfigurable struct { type EdgeManagerConfigurable struct {
TunnelHostnames []h2mux.TunnelHostname TunnelHostnames []h2mux.TunnelHostname
*pogs.EdgeConnectionConfig *tunnelpogs.EdgeConnectionConfig
} }
type CloudflaredConfig struct { type CloudflaredConfig struct {
CloudflaredID uuid.UUID CloudflaredID uuid.UUID
Tags []pogs.Tag Tags []tunnelpogs.Tag
BuildInfo *buildinfo.BuildInfo BuildInfo *buildinfo.BuildInfo
IntentLabel string IntentLabel string
} }
@ -126,7 +126,7 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab
em.state.updateConfigurable(newConfigurable) em.state.updateConfigurable(newConfigurable)
} }
func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError { func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError {
edgeTCPAddr := em.serviceDiscoverer.Addr() edgeTCPAddr := em.serviceDiscoverer.Addr()
configurable := em.state.getConfigurable() configurable := em.state.getConfigurable()
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr) edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
@ -154,7 +154,7 @@ func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
go em.serveConn(ctx, h2muxConn) go em.serveConn(ctx, h2muxConn)
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{ connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
CloudflaredID: em.cloudflaredConfig.CloudflaredID, CloudflaredID: em.cloudflaredConfig.CloudflaredID,
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion, CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
NumPreviousAttempts: 0, NumPreviousAttempts: 0,
@ -285,8 +285,8 @@ func (ems *edgeManagerState) getUserCredential() []byte {
return ems.userCredential return ems.userCredential
} }
func retryConnection(cause string) *pogs.ConnectError { func retryConnection(cause string) *tunnelpogs.ConnectError {
return &pogs.ConnectError{ return &tunnelpogs.ConnectError{
Cause: cause, Cause: cause,
RetryAfter: defaultRetryAfter, RetryAfter: defaultRetryAfter,
ShouldRetry: true, ShouldRetry: true,

View File

@ -4,15 +4,15 @@ import (
"testing" "testing"
"time" "time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
) )
var ( var (

49
connection/rpc.go Normal file
View File

@ -0,0 +1,49 @@
package connection
import (
"context"
"fmt"
"time"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// NewRPCClient creates and returns a new RPC client, which will communicate
// using a stream on the given muxer
func NewRPCClient(
ctx context.Context,
muxer *h2mux.Muxer,
logger *logrus.Entry,
openStreamTimeout time.Duration,
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return
}
if !isRPCStreamResponse(stream.Headers) {
stream.Close()
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
return
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger),
)
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
return client, nil
}
func isRPCStreamResponse(headers []h2mux.Header) bool {
return len(headers) == 1 &&
headers[0].Name == ":status" &&
headers[0].Value == "200"
}

View File

@ -6,12 +6,11 @@ import (
"net" "net"
"time" "time"
"github.com/google/uuid"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/google/uuid"
) )
const ( const (

View File

@ -14,23 +14,21 @@ import (
"sync" "sync"
"time" "time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/validation"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
rpc "zombiezen.com/go/capnproto2/rpc"
) )
const ( const (
@ -288,16 +286,6 @@ func ServeTunnel(
return nil, true return nil, true
} }
func IsRPCStreamResponse(headers []h2mux.Header) bool {
if len(headers) != 1 {
return false
}
if headers[0].Name != ":status" || headers[0].Value != "200" {
return false
}
return true
}
func RegisterTunnel( func RegisterTunnel(
ctx context.Context, ctx context.Context,
muxer *h2mux.Muxer, muxer *h2mux.Muxer,
@ -308,28 +296,18 @@ func RegisterTunnel(
uuid uuid.UUID, uuid uuid.UUID,
) error { ) error {
config.TransportLogger.Debug("initiating RPC stream to register") config.TransportLogger.Debug("initiating RPC stream to register")
stream, err := openStream(ctx, muxer) tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-register"), openStreamTimeout)
if err != nil { if err != nil {
// RPC stream open error // RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail) return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
} }
if !IsRPCStreamResponse(stream.Headers) { defer tunnelServer.Close()
// stream response error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(config.TransportLogger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(config.TransportLogger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
// Request server info without blocking tunnel registration; must use capnp library directly. // Request server info without blocking tunnel registration; must use capnp library directly.
tsClient := tunnelrpc.TunnelServer{Client: ts.Client} serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil return nil
}) })
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
registration := ts.RegisterTunnel( registration := tunnelServer.RegisterTunnel(
ctx, ctx,
config.OriginCert, config.OriginCert,
config.Hostname, config.Hostname,
@ -369,7 +347,7 @@ func RegisterTunnel(
return nil return nil
} }
func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *TunnelMetrics) error { func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
if err.Error() == DuplicateConnectionError { if err.Error() == DuplicateConnectionError {
metrics.regFail.WithLabelValues("dup_edge_conn").Inc() metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
return dupConnRegisterTunnelError{} return dupConnRegisterTunnelError{}
@ -384,35 +362,15 @@ func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *Tunne
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error { func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
logger.Debug("initiating RPC stream to unregister") logger.Debug("initiating RPC stream to unregister")
ctx := context.Background() ctx := context.Background()
stream, err := openStream(ctx, muxer) ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
if err != nil { if err != nil {
// RPC stream open error // RPC stream open error
return err return err
} }
if !IsRPCStreamResponse(stream.Headers) {
// stream response error
return err
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
// gracePeriod is encoded in int64 using capnproto // gracePeriod is encoded in int64 using capnproto
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds()) return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
} }
func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
return muxer.OpenStream(openStreamCtx, []h2mux.Header{
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
}
func LogServerInfo( func LogServerInfo(
promise tunnelrpc.ServerInfo_Promise, promise tunnelrpc.ServerInfo_Promise,
connectionID uint8, connectionID uint8,