TUN-2607: add RPC stream helpers

This commit is contained in:
Nick Vollmar 2019-11-21 11:03:13 -06:00
parent 8f4fd70783
commit bbf31377c2
6 changed files with 83 additions and 97 deletions

View File

@ -4,15 +4,12 @@ import (
"context"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
@ -41,32 +38,15 @@ func (c *Connection) Serve(ctx context.Context) error {
}
// Connect is used to establish connections with cloudflare's edge network
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (pogs.ConnectResult, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
rpcConn, err := c.newRPConn(openStreamCtx, logger)
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (tunnelpogs.ConnectResult, error) {
tsClient, err := NewRPCClient(ctx, c.muxer, logger.WithField("rpc", "connect"), openStreamTimeout)
if err != nil {
return nil, errors.Wrap(err, "cannot create new RPC connection")
}
defer rpcConn.Close()
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
defer tsClient.Close()
return tsClient.Connect(ctx, parameters)
}
func (c *Connection) Shutdown() {
c.muxer.Shutdown()
}
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
stream, err := c.muxer.OpenRPCStream(ctx)
if err != nil {
return nil, err
}
return rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
), nil
}

View File

@ -7,15 +7,15 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/prometheus/client_golang/prometheus"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
@ -58,12 +58,12 @@ func newMetrics(namespace, subsystem string) *metrics {
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
type EdgeManagerConfigurable struct {
TunnelHostnames []h2mux.TunnelHostname
*pogs.EdgeConnectionConfig
*tunnelpogs.EdgeConnectionConfig
}
type CloudflaredConfig struct {
CloudflaredID uuid.UUID
Tags []pogs.Tag
Tags []tunnelpogs.Tag
BuildInfo *buildinfo.BuildInfo
IntentLabel string
}
@ -126,7 +126,7 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab
em.state.updateConfigurable(newConfigurable)
}
func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError {
edgeTCPAddr := em.serviceDiscoverer.Addr()
configurable := em.state.getConfigurable()
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
@ -154,7 +154,7 @@ func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError {
go em.serveConn(ctx, h2muxConn)
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{
connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
NumPreviousAttempts: 0,
@ -285,8 +285,8 @@ func (ems *edgeManagerState) getUserCredential() []byte {
return ems.userCredential
}
func retryConnection(cause string) *pogs.ConnectError {
return &pogs.ConnectError{
func retryConnection(cause string) *tunnelpogs.ConnectError {
return &tunnelpogs.ConnectError{
Cause: cause,
RetryAfter: defaultRetryAfter,
ShouldRetry: true,

View File

@ -4,15 +4,15 @@ import (
"testing"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
var (

49
connection/rpc.go Normal file
View File

@ -0,0 +1,49 @@
package connection
import (
"context"
"fmt"
"time"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// NewRPCClient creates and returns a new RPC client, which will communicate
// using a stream on the given muxer
func NewRPCClient(
ctx context.Context,
muxer *h2mux.Muxer,
logger *logrus.Entry,
openStreamTimeout time.Duration,
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return
}
if !isRPCStreamResponse(stream.Headers) {
stream.Close()
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
return
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger),
)
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
return client, nil
}
func isRPCStreamResponse(headers []h2mux.Header) bool {
return len(headers) == 1 &&
headers[0].Name == ":status" &&
headers[0].Value == "200"
}

View File

@ -6,12 +6,11 @@ import (
"net"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/signal"
"github.com/google/uuid"
)
const (

View File

@ -14,23 +14,21 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/cloudflare/cloudflared/websocket"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
rpc "zombiezen.com/go/capnproto2/rpc"
)
const (
@ -288,16 +286,6 @@ func ServeTunnel(
return nil, true
}
func IsRPCStreamResponse(headers []h2mux.Header) bool {
if len(headers) != 1 {
return false
}
if headers[0].Name != ":status" || headers[0].Value != "200" {
return false
}
return true
}
func RegisterTunnel(
ctx context.Context,
muxer *h2mux.Muxer,
@ -308,28 +296,18 @@ func RegisterTunnel(
uuid uuid.UUID,
) error {
config.TransportLogger.Debug("initiating RPC stream to register")
stream, err := openStream(ctx, muxer)
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-register"), openStreamTimeout)
if err != nil {
// RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
}
if !IsRPCStreamResponse(stream.Headers) {
// stream response error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(config.TransportLogger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(config.TransportLogger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
defer tunnelServer.Close()
// Request server info without blocking tunnel registration; must use capnp library directly.
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
registration := ts.RegisterTunnel(
registration := tunnelServer.RegisterTunnel(
ctx,
config.OriginCert,
config.Hostname,
@ -369,7 +347,7 @@ func RegisterTunnel(
return nil
}
func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics) error {
if err.Error() == DuplicateConnectionError {
metrics.regFail.WithLabelValues("dup_edge_conn").Inc()
return dupConnRegisterTunnelError{}
@ -384,35 +362,15 @@ func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *Tunne
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
logger.Debug("initiating RPC stream to unregister")
ctx := context.Background()
stream, err := openStream(ctx, muxer)
ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout)
if err != nil {
// RPC stream open error
return err
}
if !IsRPCStreamResponse(stream.Headers) {
// stream response error
return err
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
// gracePeriod is encoded in int64 using capnproto
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
}
func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
return muxer.OpenStream(openStreamCtx, []h2mux.Header{
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
}
func LogServerInfo(
promise tunnelrpc.ServerInfo_Promise,
connectionID uint8,