diff --git a/connection/connection.go b/connection/connection.go index 629d37e7..7e032745 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -4,15 +4,12 @@ import ( "context" "time" - "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/tunnelrpc" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/google/uuid" "github.com/pkg/errors" "github.com/sirupsen/logrus" - rpc "zombiezen.com/go/capnproto2/rpc" + "github.com/cloudflare/cloudflared/h2mux" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) const ( @@ -41,32 +38,15 @@ func (c *Connection) Serve(ctx context.Context) error { } // Connect is used to establish connections with cloudflare's edge network -func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (pogs.ConnectResult, error) { - openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout) - defer cancel() - - rpcConn, err := c.newRPConn(openStreamCtx, logger) +func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (tunnelpogs.ConnectResult, error) { + tsClient, err := NewRPCClient(ctx, c.muxer, logger.WithField("rpc", "connect"), openStreamTimeout) if err != nil { return nil, errors.Wrap(err, "cannot create new RPC connection") } - defer rpcConn.Close() - - tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)} - + defer tsClient.Close() return tsClient.Connect(ctx, parameters) } func (c *Connection) Shutdown() { c.muxer.Shutdown() } - -func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) { - stream, err := c.muxer.OpenRPCStream(ctx) - if err != nil { - return nil, err - } - return rpc.NewConn( - tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)), - tunnelrpc.ConnLog(logger.WithField("rpc", "connect")), - ), nil -} diff --git a/connection/manager.go b/connection/manager.go index ee4b4870..9ffc50ec 100644 --- a/connection/manager.go +++ b/connection/manager.go @@ -7,15 +7,15 @@ import ( "sync" "time" + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/streamhandler" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/prometheus/client_golang/prometheus" - - "github.com/google/uuid" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) const ( @@ -58,12 +58,12 @@ func newMetrics(namespace, subsystem string) *metrics { // EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager type EdgeManagerConfigurable struct { TunnelHostnames []h2mux.TunnelHostname - *pogs.EdgeConnectionConfig + *tunnelpogs.EdgeConnectionConfig } type CloudflaredConfig struct { CloudflaredID uuid.UUID - Tags []pogs.Tag + Tags []tunnelpogs.Tag BuildInfo *buildinfo.BuildInfo IntentLabel string } @@ -126,7 +126,7 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab em.state.updateConfigurable(newConfigurable) } -func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError { +func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError { edgeTCPAddr := em.serviceDiscoverer.Addr() configurable := em.state.getConfigurable() edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr) @@ -154,7 +154,7 @@ func (em *EdgeManager) newConnection(ctx context.Context) *pogs.ConnectError { go em.serveConn(ctx, h2muxConn) - connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{ + connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{ CloudflaredID: em.cloudflaredConfig.CloudflaredID, CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion, NumPreviousAttempts: 0, @@ -285,8 +285,8 @@ func (ems *edgeManagerState) getUserCredential() []byte { return ems.userCredential } -func retryConnection(cause string) *pogs.ConnectError { - return &pogs.ConnectError{ +func retryConnection(cause string) *tunnelpogs.ConnectError { + return &tunnelpogs.ConnectError{ Cause: cause, RetryAfter: defaultRetryAfter, ShouldRetry: true, diff --git a/connection/manager_test.go b/connection/manager_test.go index 6732ce24..a8dd0e58 100644 --- a/connection/manager_test.go +++ b/connection/manager_test.go @@ -4,15 +4,15 @@ import ( "testing" "time" - "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" + "github.com/google/uuid" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + + "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - - "github.com/google/uuid" - "github.com/sirupsen/logrus" ) var ( diff --git a/connection/rpc.go b/connection/rpc.go new file mode 100644 index 00000000..9c10c334 --- /dev/null +++ b/connection/rpc.go @@ -0,0 +1,49 @@ +package connection + +import ( + "context" + "fmt" + "time" + + "github.com/sirupsen/logrus" + rpc "zombiezen.com/go/capnproto2/rpc" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/tunnelrpc" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +// NewRPCClient creates and returns a new RPC client, which will communicate +// using a stream on the given muxer +func NewRPCClient( + ctx context.Context, + muxer *h2mux.Muxer, + logger *logrus.Entry, + openStreamTimeout time.Duration, +) (client tunnelpogs.TunnelServer_PogsClient, err error) { + openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout) + defer openStreamCancel() + stream, err := muxer.OpenRPCStream(openStreamCtx) + if err != nil { + return + } + + if !isRPCStreamResponse(stream.Headers) { + stream.Close() + err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers) + return + } + + conn := rpc.NewConn( + tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), + tunnelrpc.ConnLog(logger), + ) + client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn} + return client, nil +} + +func isRPCStreamResponse(headers []h2mux.Header) bool { + return len(headers) == 1 && + headers[0].Name == ":status" && + headers[0].Value == "200" +} diff --git a/origin/supervisor.go b/origin/supervisor.go index 6c58674f..cea9f9b3 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -6,12 +6,11 @@ import ( "net" "time" + "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/signal" - - "github.com/google/uuid" ) const ( diff --git a/origin/tunnel.go b/origin/tunnel.go index b6c2ea68..4c931c10 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -14,23 +14,21 @@ import ( "sync" "time" + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/tunnelrpc" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/websocket" - - "github.com/google/uuid" - "github.com/pkg/errors" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" - rpc "zombiezen.com/go/capnproto2/rpc" ) const ( @@ -288,16 +286,6 @@ func ServeTunnel( return nil, true } -func IsRPCStreamResponse(headers []h2mux.Header) bool { - if len(headers) != 1 { - return false - } - if headers[0].Name != ":status" || headers[0].Value != "200" { - return false - } - return true -} - func RegisterTunnel( ctx context.Context, muxer *h2mux.Muxer, @@ -308,28 +296,18 @@ func RegisterTunnel( uuid uuid.UUID, ) error { config.TransportLogger.Debug("initiating RPC stream to register") - stream, err := openStream(ctx, muxer) + tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-register"), openStreamTimeout) if err != nil { // RPC stream open error return newClientRegisterTunnelError(err, config.Metrics.rpcFail) } - if !IsRPCStreamResponse(stream.Headers) { - // stream response error - return newClientRegisterTunnelError(err, config.Metrics.rpcFail) - } - conn := rpc.NewConn( - tunnelrpc.NewTransportLogger(config.TransportLogger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)), - tunnelrpc.ConnLog(config.TransportLogger.WithField("subsystem", "rpc-transport")), - ) - defer conn.Close() - ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)} + defer tunnelServer.Close() // Request server info without blocking tunnel registration; must use capnp library directly. - tsClient := tunnelrpc.TunnelServer{Client: ts.Client} - serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { + serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { return nil }) LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) - registration := ts.RegisterTunnel( + registration := tunnelServer.RegisterTunnel( ctx, config.OriginCert, config.Hostname, @@ -369,7 +347,7 @@ func RegisterTunnel( return nil } -func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *TunnelMetrics) error { +func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics) error { if err.Error() == DuplicateConnectionError { metrics.regFail.WithLabelValues("dup_edge_conn").Inc() return dupConnRegisterTunnelError{} @@ -384,35 +362,15 @@ func processRegisterTunnelError(err pogs.TunnelRegistrationError, metrics *Tunne func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error { logger.Debug("initiating RPC stream to unregister") ctx := context.Background() - stream, err := openStream(ctx, muxer) + ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout) if err != nil { // RPC stream open error return err } - if !IsRPCStreamResponse(stream.Headers) { - // stream response error - return err - } - conn := rpc.NewConn( - tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)), - tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), - ) - defer conn.Close() - ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)} // gracePeriod is encoded in int64 using capnproto return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds()) } -func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) { - openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout) - defer cancel() - return muxer.OpenStream(openStreamCtx, []h2mux.Header{ - {Name: ":method", Value: "RPC"}, - {Name: ":scheme", Value: "capnp"}, - {Name: ":path", Value: "*"}, - }, nil) -} - func LogServerInfo( promise tunnelrpc.ServerInfo_Promise, connectionID uint8,