TUN-3462: Refactor cloudflared to separate origin from connection

This commit is contained in:
cthuang 2020-10-08 11:12:26 +01:00
parent a5a5b93b64
commit 9ac40dcf04
32 changed files with 2006 additions and 1339 deletions

View File

@ -2,7 +2,6 @@ package access
import (
"net/http"
"net/url"
"strings"
"github.com/cloudflare/cloudflared/carrier"
@ -17,16 +16,11 @@ import (
// StartForwarder starts a client side websocket forward
func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, logger logger.Service) error {
validURLString, err := validation.ValidateUrl(forwarder.Listener)
validURL, err := validation.ValidateUrl(forwarder.Listener)
if err != nil {
return errors.Wrap(err, "error validating origin URL")
}
validURL, err := url.Parse(validURLString)
if err != nil {
return errors.Wrap(err, "error parsing origin URL")
}
// get the headers from the config file and add to the request
headers := make(http.Header)
if forwarder.TokenClientID != "" {
@ -106,12 +100,7 @@ func ssh(c *cli.Context) error {
wsConn := carrier.NewWSConnection(logger, false)
if c.NArg() > 0 || c.IsSet(sshURLFlag) {
localForwarder, err := config.ValidateUrl(c, true)
if err != nil {
logger.Errorf("Error validating origin URL: %s", err)
return errors.Wrap(err, "error validating origin URL")
}
forwarder, err := url.Parse(localForwarder)
forwarder, err := config.ValidateUrl(c, true)
if err != nil {
logger.Errorf("Error validating origin URL: %s", err)
return errors.Wrap(err, "error validating origin URL")

View File

@ -2,6 +2,7 @@ package config
import (
"fmt"
"net/url"
"os"
"path/filepath"
"runtime"
@ -189,11 +190,11 @@ func ValidateUnixSocket(c *cli.Context) (string, error) {
// ValidateUrl will validate url flag correctness. It can be either from --url or argument
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) {
func ValidateUrl(c *cli.Context, allowFromArgs bool) (*url.URL, error) {
var url = c.String("url")
if allowFromArgs && c.NArg() > 0 {
if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
}
url = c.Args().Get(0)
}

View File

@ -18,6 +18,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/dbconnect"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger"
@ -247,7 +248,7 @@ func StartServer(
version string,
shutdownC,
graceShutdownC chan struct{},
namedTunnel *origin.NamedTunnelConfig,
namedTunnel *connection.NamedTunnelConfig,
log logger.Service,
isUIEnabled bool,
) error {
@ -366,7 +367,7 @@ func StartServer(
return errors.Wrap(err, "error setting up transport logger")
}
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel)
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel, isUIEnabled)
if err != nil {
return err
}
@ -386,10 +387,6 @@ func StartServer(
}()
if isUIEnabled {
const tunnelEventChanBufferSize = 16
tunnelEventChan := make(chan ui.TunnelEvent, tunnelEventChanBufferSize)
tunnelConfig.TunnelEventChan = tunnelEventChan
tunnelInfo := ui.NewUIModel(
version,
hostname,
@ -402,7 +399,7 @@ func StartServer(
if err != nil {
return err
}
tunnelInfo.LaunchUI(ctx, log, logLevels, tunnelEventChan)
tunnelInfo.LaunchUI(ctx, log, logLevels, tunnelConfig.TunnelEventChan)
}
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log)
@ -986,7 +983,7 @@ func configureLoggingFlags(shouldHide bool) []cli.Flag {
altsrc.NewStringFlag(&cli.StringFlag{
Name: "transport-loglevel",
Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel
Value: "fatal",
Value: "info",
Usage: "Transport logging level(previously called protocol logging level) {fatal, error, info, debug}",
EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL", "TUNNEL_TRANSPORT_LOGLEVEL"},
Hidden: shouldHide,

View File

@ -9,6 +9,9 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/origin"
@ -154,10 +157,10 @@ func prepareTunnelConfig(
version string,
logger logger.Service,
transportLogger logger.Service,
namedTunnel *origin.NamedTunnelConfig,
namedTunnel *connection.NamedTunnelConfig,
uiIsEnabled bool,
) (*origin.TunnelConfig, error) {
isNamedTunnel := namedTunnel != nil
compatibilityMode := !isNamedTunnel
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
@ -189,10 +192,11 @@ func prepareTunnelConfig(
}
}
tunnelMetrics := origin.NewTunnelMetrics()
var ingressRules ingress.Ingress
if namedTunnel != nil {
var (
ingressRules ingress.Ingress
classicTunnel *connection.ClassicTunnelConfig
)
if isNamedTunnel {
clientUUID, err := uuid.NewRandom()
if err != nil {
return nil, errors.Wrap(err, "can't generate clientUUID")
@ -210,6 +214,13 @@ func prepareTunnelConfig(
if !ingressRules.IsEmpty() && c.IsSet("url") {
return nil, ingress.ErrURLIncompatibleWithIngress
}
} else {
classicTunnel = &connection.ClassicTunnelConfig{
Hostname: hostname,
OriginCert: originCert,
// turn off use of reconnect token and auth refresh when using named tunnels
UseReconnectToken: !isNamedTunnel && c.Bool("use-reconnect-token"),
}
}
// Convert single-origin configuration into multi-origin configuration.
@ -220,43 +231,71 @@ func prepareTunnelConfig(
}
}
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, isNamedTunnel)
protocol := determineProtocol(namedTunnel)
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName())
if err != nil {
logger.Errorf("unable to create TLS config to connect with edge: %s", err)
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
}
return &origin.TunnelConfig{
BuildInfo: buildInfo,
ClientID: clientID,
CompressionQuality: c.Uint64("compression-quality"),
EdgeAddrs: c.StringSlice("edge"),
GracePeriod: c.Duration("grace-period"),
HAConnections: c.Int("ha-connections"),
proxyConfig := &origin.ProxyConfig{
Client: httpTransport,
URL: originURL,
TLSConfig: httpTransport.TLSClientConfig,
HostHeader: c.String("http-host-header"),
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
Tags: tags,
}
originClient := origin.NewClient(proxyConfig, logger)
transportConfig := &connection.Config{
OriginClient: originClient,
GracePeriod: c.Duration("grace-period"),
ReplaceExisting: c.Bool("force"),
}
muxerConfig := &connection.MuxerConfig{
HeartbeatInterval: c.Duration("heartbeat-interval"),
Hostname: hostname,
IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"),
Logger: logger,
TransportLogger: transportLogger,
MaxHeartbeats: c.Uint64("heartbeat-count"),
Metrics: tunnelMetrics,
CompressionSetting: h2mux.CompressionSetting(c.Uint64("compression-quality")),
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
OriginCert: originCert,
ReportedVersion: version,
Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(),
Tags: tags,
TlsConfig: toEdgeTLSConfig,
NamedTunnel: namedTunnel,
ReplaceExisting: c.Bool("force"),
IngressRules: ingressRules,
// turn off use of reconnect token and auth refresh when using named tunnels
UseReconnectToken: compatibilityMode && c.Bool("use-reconnect-token"),
}
var tunnelEventChan chan ui.TunnelEvent
if uiIsEnabled {
tunnelEventChan = make(chan ui.TunnelEvent, 16)
}
return &origin.TunnelConfig{
ConnectionConfig: transportConfig,
ProxyConfig: proxyConfig,
BuildInfo: buildInfo,
ClientID: clientID,
EdgeAddrs: c.StringSlice("edge"),
HAConnections: c.Int("ha-connections"),
IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"),
Logger: logger,
Observer: connection.NewObserver(transportLogger, tunnelEventChan, protocol),
ReportedVersion: version,
Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(),
TLSConfig: toEdgeTLSConfig,
NamedTunnel: namedTunnel,
ClassicTunnel: classicTunnel,
MuxerConfig: muxerConfig,
TunnelEventChan: tunnelEventChan,
IngressRules: ingressRules,
}, nil
}
func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd()))
}
func determineProtocol(namedTunnel *connection.NamedTunnelConfig) connection.Protocol {
if namedTunnel != nil {
return namedTunnel.Protocol
}
return connection.H2mux
}

View File

@ -14,8 +14,8 @@ import (
"github.com/cloudflare/cloudflared/certutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstore"
)
@ -260,7 +260,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
return err
}
protocol, ok := origin.ParseProtocol(sc.c.String("protocol"))
protocol, ok := connection.ParseProtocol(sc.c.String("protocol"))
if !ok {
return fmt.Errorf("%s is not valid protocol. %s", sc.c.String("protocol"), availableProtocol)
}
@ -269,7 +269,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
version,
shutdownC,
graceShutdownC,
&origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol},
&connection.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol},
sc.logger,
sc.isUIEnabled,
)

View File

@ -78,28 +78,31 @@ var (
Name: "credentials-file",
Aliases: []string{credFileFlagAlias},
Usage: "File path of tunnel credentials",
EnvVars: []string{"TUNNEL_CRED_FILE"},
})
forceDeleteFlag = &cli.BoolFlag{
Name: "force",
Aliases: []string{"f"},
Usage: "Allows you to delete a tunnel, even if it has active connections.",
EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"},
}
selectProtocolFlag = &cli.StringFlag{
Name: "protocol",
Value: "h2mux",
Aliases: []string{"p"},
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol),
EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"},
Hidden: true,
}
)
func buildCreateCommand() *cli.Command {
return &cli.Command{
Name: "create",
Action: cliutil.ErrorHandler(createCommand),
Usage: "Create a new tunnel with given name",
UsageText: "cloudflared tunnel [tunnel command options] create [subcommand options] NAME",
Description: `Creates a tunnel, registers it with Cloudflare edge and generates credential file used to run this tunnel.
Name: "create",
Action: cliutil.ErrorHandler(createCommand),
Usage: "Create a new tunnel with given name",
UsageText: "cloudflared tunnel [tunnel command options] create [subcommand options] NAME",
Description: `Creates a tunnel, registers it with Cloudflare edge and generates credential file used to run this tunnel.
Use "cloudflared tunnel route" subcommand to map a DNS name to this tunnel and "cloudflared tunnel run" to start the connection.
For example, to create a tunnel named 'my-tunnel' run:

91
connection/connection.go Normal file
View File

@ -0,0 +1,91 @@
package connection
import (
"io"
"net/http"
"strconv"
"time"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
)
const (
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
)
type Config struct {
OriginClient OriginClient
GracePeriod time.Duration
ReplaceExisting bool
}
type NamedTunnelConfig struct {
Auth pogs.TunnelAuth
ID uuid.UUID
Client pogs.ClientInfo
Protocol Protocol
}
type ClassicTunnelConfig struct {
Hostname string
OriginCert []byte
// feature-flag to use new edge reconnect tokens
UseReconnectToken bool
}
func (c *ClassicTunnelConfig) IsTrialZone() bool {
return c.Hostname == ""
}
type Protocol int64
const (
H2mux Protocol = iota
HTTP2
)
func ParseProtocol(s string) (Protocol, bool) {
switch s {
case "h2mux":
return H2mux, true
case "http2":
return HTTP2, true
default:
return 0, false
}
}
func (p Protocol) ServerName() string {
switch p {
case H2mux:
return edgeH2muxTLSServerName
case HTTP2:
return edgeH2TLSServerName
default:
return ""
}
}
type OriginClient interface {
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
}
type ResponseWriter interface {
WriteRespHeaders(*http.Response) error
WriteErrorResponse(error)
io.ReadWriter
}
type ConnectedFuse interface {
Connected()
IsConnected() bool
}
func uint8ToString(input uint8) string {
return strconv.FormatUint(uint64(input), 10)
}

76
connection/errors.go Normal file
View File

@ -0,0 +1,76 @@
package connection
import (
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/prometheus/client_golang/prometheus"
)
const (
DuplicateConnectionError = "EDUPCONN"
)
// RegisterTunnel error from client
type clientRegisterTunnelError struct {
cause error
}
func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError {
counter.WithLabelValues(cause.Error(), string(name)).Inc()
return clientRegisterTunnelError{cause: cause}
}
func (e clientRegisterTunnelError) Error() string {
return e.cause.Error()
}
type DupConnRegisterTunnelError struct{}
var errDuplicationConnection = &DupConnRegisterTunnelError{}
func (e DupConnRegisterTunnelError) Error() string {
return "already connected to this server, trying another address"
}
// RegisterTunnel error from server
type serverRegisterTunnelError struct {
cause error
permanent bool
}
func (e serverRegisterTunnelError) Error() string {
return e.cause.Error()
}
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
return &serverRegisterTunnelError{
cause: retryable.Unwrap(),
permanent: false,
}
}
return &serverRegisterTunnelError{
cause: err,
permanent: true,
}
}
type muxerShutdownError struct{}
func (e muxerShutdownError) Error() string {
return "muxer shutdown"
}
func isHandshakeErrRecoverable(err error, connIndex uint8, observer *Observer) bool {
switch err.(type) {
case edgediscovery.DialError:
observer.Errorf("Connection %d unable to dial edge: %s", connIndex, err)
case h2mux.MuxerHandshakeError:
observer.Errorf("Connection %d handshake with edge server failed: %s", connIndex, err)
default:
observer.Errorf("Connection %d failed: %s", connIndex, err)
return false
}
return true
}

216
connection/h2mux.go Normal file
View File

@ -0,0 +1,216 @@
package connection
import (
"context"
"net"
"net/http"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
)
const (
muxerTimeout = 5 * time.Second
openStreamTimeout = 30 * time.Second
)
type h2muxConnection struct {
config *Config
muxerConfig *MuxerConfig
originURL string
muxer *h2mux.Muxer
// connectionID is only used by metrics, and prometheus requires labels to be string
connIndexStr string
connIndex uint8
observer *Observer
}
type MuxerConfig struct {
HeartbeatInterval time.Duration
MaxHeartbeats uint64
CompressionSetting h2mux.CompressionSetting
MetricsUpdateFreq time.Duration
}
func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, logger logger.Service) *h2mux.MuxerConfig {
return &h2mux.MuxerConfig{
Timeout: muxerTimeout,
Handler: h,
IsClient: true,
HeartbeatInterval: mc.HeartbeatInterval,
MaxHeartbeats: mc.MaxHeartbeats,
Logger: logger,
CompressionQuality: mc.CompressionSetting,
}
}
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewH2muxConnection(ctx context.Context,
config *Config,
muxerConfig *MuxerConfig,
originURL string,
edgeConn net.Conn,
connIndex uint8,
observer *Observer,
) (*h2muxConnection, error, bool) {
h := &h2muxConnection{
config: config,
muxerConfig: muxerConfig,
originURL: originURL,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
observer: observer,
}
// Establish a muxed connection with the edge
// Client mux handshake with agent server
muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer), h2mux.ActiveStreams)
if err != nil {
recoverable := isHandshakeErrRecoverable(err, connIndex, observer)
return nil, err, recoverable
}
h.muxer = muxer
return h, nil, false
}
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return h.serveMuxer(serveCtx)
})
errGroup.Go(func() error {
stream, err := h.newRPCStream(serveCtx, register)
if err != nil {
return err
}
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer)
defer rpcClient.close()
if err = registerConnection(serveCtx, rpcClient, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
return err
}
connectedFuse.Connected()
return nil
})
errGroup.Go(func() error {
h.controlLoop(serveCtx, connectedFuse, true)
return nil
})
return errGroup.Wait()
}
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return h.serveMuxer(serveCtx)
})
errGroup.Go(func() (err error) {
defer func() {
if err == nil {
connectedFuse.Connected()
}
}()
if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() {
err := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
if err == nil {
return nil
}
// log errors and proceed to RegisterTunnel
h.observer.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", h.connIndex, err)
}
return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
})
errGroup.Go(func() error {
h.controlLoop(serveCtx, connectedFuse, false)
return nil
})
return errGroup.Wait()
}
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
// All routines should stop when muxer finish serving. When muxer is shutdown
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
// here to notify other routines to stop
err := h.muxer.Serve(ctx)
if err == nil {
return muxerShutdownError{}
}
return err
}
func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
for {
select {
case <-ctx.Done():
// UnregisterTunnel blocks until the RPC call returns
if connectedFuse.IsConnected() {
h.unregister(isNamedTunnel)
}
h.muxer.Shutdown()
return
case <-updateMetricsTickC:
h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
}
}
}
func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := h.muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return nil, err
}
return stream, nil
}
func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
respWriter := &h2muxRespWriter{stream}
req, reqErr := h.newRequest(stream)
if reqErr != nil {
respWriter.WriteErrorResponse(reqErr)
return reqErr
}
return h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
}
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
req, err := http.NewRequest("GET", h.originURL, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
}
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
return nil, errors.Wrap(err, "invalid request received")
}
return req, nil
}
type h2muxRespWriter struct {
*h2mux.MuxedStream
}
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
return rp.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp))
}
func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
rp.WriteHeaders([]h2mux.Header{
{Name: ":status", Value: "502"},
h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared),
})
rp.Write([]byte("502 Bad Gateway"))
}

253
connection/http2.go Normal file
View File

@ -0,0 +1,253 @@
package connection
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"golang.org/x/net/http2"
)
const (
internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
websocketUpgrade = "websocket"
controlStreamUpgrade = "control-stream"
)
type HTTP2Connection struct {
conn net.Conn
server *http2.Server
config *Config
originURL *url.URL
namedTunnel *NamedTunnelConfig
connOptions *tunnelpogs.ConnectionOptions
observer *Observer
connIndexStr string
connIndex uint8
shutdownChan chan struct{}
connectedFuse ConnectedFuse
}
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) {
return &HTTP2Connection{
conn: conn,
server: &http2.Server{},
config: config,
originURL: originURL,
namedTunnel: namedTunnelConfig,
connOptions: connOptions,
observer: observer,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
shutdownChan: make(chan struct{}),
connectedFuse: connectedFuse,
}, nil
}
func (c *HTTP2Connection) Serve(ctx context.Context) {
go func() {
<-ctx.Done()
c.close()
}()
c.server.ServeConn(c.conn, &http2.ServeConnOpts{
Context: ctx,
Handler: c,
})
}
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = c.originURL.Scheme
r.URL.Host = c.originURL.Host
respWriter := &http2RespWriter{
r: r.Body,
w: w,
}
if isControlStreamUpgrade(r) {
err := c.serveControlStream(r.Context(), respWriter)
if err != nil {
respWriter.WriteErrorResponse(err)
}
} else if isWebsocketUpgrade(r) {
wsRespWriter, err := newWSRespWriter(respWriter)
if err != nil {
respWriter.WriteErrorResponse(err)
return
}
stripWebsocketUpgradeHeader(r)
c.config.OriginClient.Proxy(wsRespWriter, r, true)
} else {
c.config.OriginClient.Proxy(respWriter, r, false)
}
}
func (c *HTTP2Connection) serveControlStream(ctx context.Context, h2RespWriter *http2RespWriter) error {
stream, err := newWSRespWriter(h2RespWriter)
if err != nil {
return err
}
rpcClient := newRegistrationRPCClient(ctx, stream, c.observer)
defer rpcClient.close()
if err = registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
return err
}
c.connectedFuse.Connected()
<-c.shutdownChan
c.gracefulShutdown(ctx, rpcClient)
close(c.shutdownChan)
return nil
}
func (c *HTTP2Connection) registerConnection(
ctx context.Context,
rpcClient tunnelpogs.RegistrationServer_PogsClient,
) error {
connDetail, err := rpcClient.RegisterConnection(
ctx,
c.namedTunnel.Auth,
c.namedTunnel.ID,
c.connIndex,
c.connOptions,
)
if err != nil {
c.observer.Errorf("Cannot register connection, err: %v", err)
return err
}
c.observer.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID)
return nil
}
func (c *HTTP2Connection) gracefulShutdown(ctx context.Context, rpcClient *registrationServerClient) {
ctx, cancel := context.WithTimeout(ctx, c.config.GracePeriod)
defer cancel()
rpcClient.client.UnregisterConnection(ctx)
}
func (c *HTTP2Connection) close() {
// Send signal to control loop to start graceful shutdown
c.shutdownChan <- struct{}{}
// Wait for control loop to close channel
<-c.shutdownChan
c.conn.Close()
}
type http2RespWriter struct {
r io.Reader
w http.ResponseWriter
}
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
dest := rp.w.Header()
userHeaders := make(http.Header, len(resp.Header))
for header, values := range resp.Header {
// Since these are http2 headers, they're required to be lowercase
h2name := strings.ToLower(header)
for _, v := range values {
if h2name == "content-length" {
// This header has meaning in HTTP/2 and will be used by the edge,
// so it should be sent as an HTTP/2 response header.
dest.Add(h2name, v)
// Since these are http2 headers, they're required to be lowercase
} else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) {
// User headers, on the other hand, must all be serialized so that
// HTTP/2 header validation won't be applied to HTTP/1 header values
userHeaders.Add(h2name, v)
}
}
}
// Perform user header serialization and set them in the single header
dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
status := resp.StatusCode
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
if status == http.StatusSwitchingProtocols {
status = http.StatusOK
}
rp.w.WriteHeader(status)
return nil
}
func (rp *http2RespWriter) WriteErrorResponse(err error) {
jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared})
if err == nil {
rp.w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader))
}
rp.w.WriteHeader(http.StatusBadGateway)
}
func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
return rp.r.Read(p)
}
func (wr *http2RespWriter) Write(p []byte) (n int, err error) {
return wr.w.Write(p)
}
type wsRespWriter struct {
h2 *http2RespWriter
flusher http.Flusher
}
func newWSRespWriter(h2 *http2RespWriter) (*wsRespWriter, error) {
flusher, ok := h2.w.(http.Flusher)
if !ok {
return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher")
}
return &wsRespWriter{
h2: h2,
flusher: flusher,
}, nil
}
func (rw *wsRespWriter) WriteRespHeaders(resp *http.Response) error {
err := rw.h2.WriteRespHeaders(resp)
if err != nil {
return err
}
rw.flusher.Flush()
return nil
}
func (rw *wsRespWriter) WriteErrorResponse(err error) {
rw.h2.WriteErrorResponse(err)
}
func (rw *wsRespWriter) Read(p []byte) (n int, err error) {
return rw.h2.Read(p)
}
func (rw *wsRespWriter) Write(p []byte) (n int, err error) {
n, err = rw.h2.Write(p)
if err != nil {
return
}
rw.flusher.Flush()
return
}
func (rw *wsRespWriter) Close() error {
return nil
}
func isControlStreamUpgrade(r *http.Request) bool {
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlStreamUpgrade
}
func isWebsocketUpgrade(r *http.Request) bool {
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade
}
func stripWebsocketUpgradeHeader(r *http.Request) {
r.Header.Del(internalUpgradeHeader)
}

409
connection/metrics.go Normal file
View File

@ -0,0 +1,409 @@
package connection
import (
"sync"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/prometheus/client_golang/prometheus"
)
const (
MetricsNamespace = "cloudflared"
TunnelSubsystem = "tunnel"
muxerSubsystem = "muxer"
)
type muxerMetrics struct {
rtt *prometheus.GaugeVec
rttMin *prometheus.GaugeVec
rttMax *prometheus.GaugeVec
receiveWindowAve *prometheus.GaugeVec
sendWindowAve *prometheus.GaugeVec
receiveWindowMin *prometheus.GaugeVec
receiveWindowMax *prometheus.GaugeVec
sendWindowMin *prometheus.GaugeVec
sendWindowMax *prometheus.GaugeVec
inBoundRateCurr *prometheus.GaugeVec
inBoundRateMin *prometheus.GaugeVec
inBoundRateMax *prometheus.GaugeVec
outBoundRateCurr *prometheus.GaugeVec
outBoundRateMin *prometheus.GaugeVec
outBoundRateMax *prometheus.GaugeVec
compBytesBefore *prometheus.GaugeVec
compBytesAfter *prometheus.GaugeVec
compRateAve *prometheus.GaugeVec
}
type tunnelMetrics struct {
timerRetries prometheus.Gauge
serverLocations *prometheus.GaugeVec
// locationLock is a mutex for oldServerLocations
locationLock sync.Mutex
// oldServerLocations stores the last server the tunnel was connected to
oldServerLocations map[string]string
regSuccess *prometheus.CounterVec
regFail *prometheus.CounterVec
rpcFail *prometheus.CounterVec
muxerMetrics *muxerMetrics
tunnelsHA tunnelsForHA
userHostnamesCounts *prometheus.CounterVec
}
func newMuxerMetrics() *muxerMetrics {
rtt := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt",
Help: "Round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rtt)
rttMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt_min",
Help: "Shortest round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rttMin)
rttMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt_max",
Help: "Longest round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rttMax)
receiveWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_ave",
Help: "Average receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowAve)
sendWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_ave",
Help: "Average send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowAve)
receiveWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_min",
Help: "Smallest receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowMin)
receiveWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_max",
Help: "Largest receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowMax)
sendWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_min",
Help: "Smallest send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowMin)
sendWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_max",
Help: "Largest send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowMax)
inBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_curr",
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateCurr)
inBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_min",
Help: "Minimum non-zero inbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateMin)
inBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_max",
Help: "Maximum inbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateMax)
outBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_curr",
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateCurr)
outBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_min",
Help: "Minimum non-zero outbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateMin)
outBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_max",
Help: "Maximum outbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateMax)
compBytesBefore := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_bytes_before",
Help: "Bytes sent via cross-stream compression, pre compression",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compBytesBefore)
compBytesAfter := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_bytes_after",
Help: "Bytes sent via cross-stream compression, post compression",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compBytesAfter)
compRateAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_rate_ave",
Help: "Average outbound cross-stream compression ratio",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compRateAve)
return &muxerMetrics{
rtt: rtt,
rttMin: rttMin,
rttMax: rttMax,
receiveWindowAve: receiveWindowAve,
sendWindowAve: sendWindowAve,
receiveWindowMin: receiveWindowMin,
receiveWindowMax: receiveWindowMax,
sendWindowMin: sendWindowMin,
sendWindowMax: sendWindowMax,
inBoundRateCurr: inBoundRateCurr,
inBoundRateMin: inBoundRateMin,
inBoundRateMax: inBoundRateMax,
outBoundRateCurr: outBoundRateCurr,
outBoundRateMin: outBoundRateMin,
outBoundRateMax: outBoundRateMax,
compBytesBefore: compBytesBefore,
compBytesAfter: compBytesAfter,
compRateAve: compRateAve,
}
}
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value()))
m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value()))
m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve()))
}
func convertRTTMilliSec(t time.Duration) float64 {
return float64(t / time.Millisecond)
}
// Metrics that can be collected without asking the edge
func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "max_concurrent_requests_per_tunnel",
Help: "Largest number of concurrent requests proxied through each tunnel so far",
},
[]string{"connection_id"},
)
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
timerRetries := prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "timer_retries",
Help: "Unacknowledged heart beats count",
})
prometheus.MustRegister(timerRetries)
serverLocations := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "server_locations",
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
},
[]string{"connection_id", "location"},
)
prometheus.MustRegister(serverLocations)
rpcFail := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "tunnel_rpc_fail",
Help: "Count of RPC connection errors by type",
},
[]string{"error", "rpcName"},
)
prometheus.MustRegister(rpcFail)
registerFail := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "tunnel_register_fail",
Help: "Count of tunnel registration errors by type",
},
[]string{"error", "rpcName"},
)
prometheus.MustRegister(registerFail)
userHostnamesCounts := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "user_hostnames_counts",
Help: "Which user hostnames cloudflared is serving",
},
[]string{"userHostname"},
)
prometheus.MustRegister(userHostnamesCounts)
registerSuccess := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "tunnel_register_success",
Help: "Count of successful tunnel registrations",
},
[]string{"rpcName"},
)
prometheus.MustRegister(registerSuccess)
var muxerMetrics *muxerMetrics
if protocol == H2mux {
muxerMetrics = newMuxerMetrics()
}
return &tunnelMetrics{
timerRetries: timerRetries,
serverLocations: serverLocations,
oldServerLocations: make(map[string]string),
muxerMetrics: muxerMetrics,
tunnelsHA: NewTunnelsForHA(),
regSuccess: registerSuccess,
regFail: registerFail,
rpcFail: rpcFail,
userHostnamesCounts: userHostnamesCounts,
}
}
func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
t.muxerMetrics.update(connectionID, metrics)
}
func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
t.locationLock.Lock()
defer t.locationLock.Unlock()
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
return
} else if ok {
t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
}
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
t.oldServerLocations[connectionID] = loc
}

99
connection/observer.go Normal file
View File

@ -0,0 +1,99 @@
package connection
import (
"fmt"
"net/url"
"strings"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
type Observer struct {
logger.Service
metrics *tunnelMetrics
tunnelEventChan chan<- ui.TunnelEvent
}
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer {
return &Observer{
logger,
newTunnelMetrics(protocol),
tunnelEventChan,
}
}
func (o *Observer) logServerInfo(connectionID uint8, location, msg string) {
// If launch-ui flag is set, send connect msg
if o.tunnelEventChan != nil {
o.tunnelEventChan <- ui.TunnelEvent{Index: connectionID, EventType: ui.Connected, Location: location}
}
o.Infof(msg)
o.metrics.registerServerLocation(uint8ToString(connectionID), location)
}
func (o *Observer) logTrialHostname(registration *tunnelpogs.TunnelRegistration) error {
// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
if o.tunnelEventChan == nil {
if registrationURL, err := url.Parse(registration.Url); err == nil {
for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) {
o.Info(line)
}
} else {
o.Error("Failed to connect tunnel, please try again.")
return fmt.Errorf("empty URL in response from Cloudflare edge")
}
}
return nil
}
// Print out the given lines in a nice ASCII box.
func asciiBox(lines []string, padding int) (box []string) {
maxLen := maxLen(lines)
spacer := strings.Repeat(" ", padding)
border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+"
box = append(box, border)
for _, line := range lines {
box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|")
}
box = append(box, border)
return
}
func maxLen(lines []string) int {
max := 0
for _, line := range lines {
if len(line) > max {
max = len(line)