From 9ac40dcf04a5b05572051ec2f4c276d848bc8457 Mon Sep 17 00:00:00 2001 From: cthuang Date: Thu, 8 Oct 2020 11:12:26 +0100 Subject: [PATCH] TUN-3462: Refactor cloudflared to separate origin from connection --- cmd/cloudflared/access/carrier.go | 15 +- cmd/cloudflared/config/configuration.go | 5 +- cmd/cloudflared/tunnel/cmd.go | 13 +- cmd/cloudflared/tunnel/configuration.go | 105 +++- cmd/cloudflared/tunnel/subcommand_context.go | 6 +- cmd/cloudflared/tunnel/subcommands.go | 13 +- connection/connection.go | 91 +++ connection/errors.go | 76 +++ connection/h2mux.go | 216 +++++++ connection/http2.go | 253 ++++++++ connection/metrics.go | 409 +++++++++++++ connection/observer.go | 99 +++ connection/observer_test.go | 45 ++ connection/rpc.go | 264 +++++++- connection/tunnelsforha.go | 50 ++ {connection => edgediscovery}/dial.go | 2 +- .../mocks_for_test.go | 2 +- h2mux/activestreammap.go | 13 + h2mux/activestreammap_test.go | 6 +- h2mux/h2mux_test.go | 4 +- h2mux/muxmetrics.go | 12 - origin/connection.go | 14 - origin/metrics.go | 537 +--------------- origin/metrics_test.go | 121 ---- origin/proxy.go | 208 +++++++ origin/reconnect.go | 53 -- origin/supervisor.go | 51 +- origin/tunnel.go | 579 +++--------------- tlsconfig/certreloader.go | 11 +- tunnelrpc/pogs/connectionrpc.go | 5 + validation/validation.go | 38 +- validation/validation_test.go | 29 +- 32 files changed, 2006 insertions(+), 1339 deletions(-) create mode 100644 connection/connection.go create mode 100644 connection/errors.go create mode 100644 connection/h2mux.go create mode 100644 connection/http2.go create mode 100644 connection/metrics.go create mode 100644 connection/observer.go create mode 100644 connection/observer_test.go create mode 100644 connection/tunnelsforha.go rename {connection => edgediscovery}/dial.go (98%) rename {connection => edgediscovery}/mocks_for_test.go (99%) delete mode 100644 origin/connection.go delete mode 100644 origin/metrics_test.go create mode 100644 origin/proxy.go diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index 66a92c0a..435ef303 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -2,7 +2,6 @@ package access import ( "net/http" - "net/url" "strings" "github.com/cloudflare/cloudflared/carrier" @@ -17,16 +16,11 @@ import ( // StartForwarder starts a client side websocket forward func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, logger logger.Service) error { - validURLString, err := validation.ValidateUrl(forwarder.Listener) + validURL, err := validation.ValidateUrl(forwarder.Listener) if err != nil { return errors.Wrap(err, "error validating origin URL") } - validURL, err := url.Parse(validURLString) - if err != nil { - return errors.Wrap(err, "error parsing origin URL") - } - // get the headers from the config file and add to the request headers := make(http.Header) if forwarder.TokenClientID != "" { @@ -106,12 +100,7 @@ func ssh(c *cli.Context) error { wsConn := carrier.NewWSConnection(logger, false) if c.NArg() > 0 || c.IsSet(sshURLFlag) { - localForwarder, err := config.ValidateUrl(c, true) - if err != nil { - logger.Errorf("Error validating origin URL: %s", err) - return errors.Wrap(err, "error validating origin URL") - } - forwarder, err := url.Parse(localForwarder) + forwarder, err := config.ValidateUrl(c, true) if err != nil { logger.Errorf("Error validating origin URL: %s", err) return errors.Wrap(err, "error validating origin URL") diff --git a/cmd/cloudflared/config/configuration.go b/cmd/cloudflared/config/configuration.go index 3715ec6c..e8917857 100644 --- a/cmd/cloudflared/config/configuration.go +++ b/cmd/cloudflared/config/configuration.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "net/url" "os" "path/filepath" "runtime" @@ -189,11 +190,11 @@ func ValidateUnixSocket(c *cli.Context) (string, error) { // ValidateUrl will validate url flag correctness. It can be either from --url or argument // Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument -func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) { +func ValidateUrl(c *cli.Context, allowFromArgs bool) (*url.URL, error) { var url = c.String("url") if allowFromArgs && c.NArg() > 0 { if c.IsSet("url") { - return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") + return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") } url = c.Args().Get(0) } diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 5b74a236..2850451a 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -18,6 +18,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/dbconnect" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" @@ -247,7 +248,7 @@ func StartServer( version string, shutdownC, graceShutdownC chan struct{}, - namedTunnel *origin.NamedTunnelConfig, + namedTunnel *connection.NamedTunnelConfig, log logger.Service, isUIEnabled bool, ) error { @@ -366,7 +367,7 @@ func StartServer( return errors.Wrap(err, "error setting up transport logger") } - tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel) + tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel, isUIEnabled) if err != nil { return err } @@ -386,10 +387,6 @@ func StartServer( }() if isUIEnabled { - const tunnelEventChanBufferSize = 16 - tunnelEventChan := make(chan ui.TunnelEvent, tunnelEventChanBufferSize) - tunnelConfig.TunnelEventChan = tunnelEventChan - tunnelInfo := ui.NewUIModel( version, hostname, @@ -402,7 +399,7 @@ func StartServer( if err != nil { return err } - tunnelInfo.LaunchUI(ctx, log, logLevels, tunnelEventChan) + tunnelInfo.LaunchUI(ctx, log, logLevels, tunnelConfig.TunnelEventChan) } return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log) @@ -986,7 +983,7 @@ func configureLoggingFlags(shouldHide bool) []cli.Flag { altsrc.NewStringFlag(&cli.StringFlag{ Name: "transport-loglevel", Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel - Value: "fatal", + Value: "info", Usage: "Transport logging level(previously called protocol logging level) {fatal, error, info, debug}", EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL", "TUNNEL_TRANSPORT_LOGLEVEL"}, Hidden: shouldHide, diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index c5e4978c..14318713 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -9,6 +9,9 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/origin" @@ -154,10 +157,10 @@ func prepareTunnelConfig( version string, logger logger.Service, transportLogger logger.Service, - namedTunnel *origin.NamedTunnelConfig, + namedTunnel *connection.NamedTunnelConfig, + uiIsEnabled bool, ) (*origin.TunnelConfig, error) { isNamedTunnel := namedTunnel != nil - compatibilityMode := !isNamedTunnel hostname, err := validation.ValidateHostname(c.String("hostname")) if err != nil { @@ -189,10 +192,11 @@ func prepareTunnelConfig( } } - tunnelMetrics := origin.NewTunnelMetrics() - - var ingressRules ingress.Ingress - if namedTunnel != nil { + var ( + ingressRules ingress.Ingress + classicTunnel *connection.ClassicTunnelConfig + ) + if isNamedTunnel { clientUUID, err := uuid.NewRandom() if err != nil { return nil, errors.Wrap(err, "can't generate clientUUID") @@ -210,6 +214,13 @@ func prepareTunnelConfig( if !ingressRules.IsEmpty() && c.IsSet("url") { return nil, ingress.ErrURLIncompatibleWithIngress } + } else { + classicTunnel = &connection.ClassicTunnelConfig{ + Hostname: hostname, + OriginCert: originCert, + // turn off use of reconnect token and auth refresh when using named tunnels + UseReconnectToken: !isNamedTunnel && c.Bool("use-reconnect-token"), + } } // Convert single-origin configuration into multi-origin configuration. @@ -220,43 +231,71 @@ func prepareTunnelConfig( } } - toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, isNamedTunnel) + protocol := determineProtocol(namedTunnel) + toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName()) if err != nil { logger.Errorf("unable to create TLS config to connect with edge: %s", err) return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") } - return &origin.TunnelConfig{ - BuildInfo: buildInfo, - ClientID: clientID, - CompressionQuality: c.Uint64("compression-quality"), - EdgeAddrs: c.StringSlice("edge"), - GracePeriod: c.Duration("grace-period"), - HAConnections: c.Int("ha-connections"), + + proxyConfig := &origin.ProxyConfig{ + Client: httpTransport, + URL: originURL, + TLSConfig: httpTransport.TLSClientConfig, + HostHeader: c.String("http-host-header"), + NoChunkedEncoding: c.Bool("no-chunked-encoding"), + Tags: tags, + } + originClient := origin.NewClient(proxyConfig, logger) + transportConfig := &connection.Config{ + OriginClient: originClient, + GracePeriod: c.Duration("grace-period"), + ReplaceExisting: c.Bool("force"), + } + muxerConfig := &connection.MuxerConfig{ HeartbeatInterval: c.Duration("heartbeat-interval"), - Hostname: hostname, - IncidentLookup: origin.NewIncidentLookup(), - IsAutoupdated: c.Bool("is-autoupdated"), - IsFreeTunnel: isFreeTunnel, - LBPool: c.String("lb-pool"), - Logger: logger, - TransportLogger: transportLogger, MaxHeartbeats: c.Uint64("heartbeat-count"), - Metrics: tunnelMetrics, + CompressionSetting: h2mux.CompressionSetting(c.Uint64("compression-quality")), MetricsUpdateFreq: c.Duration("metrics-update-freq"), - OriginCert: originCert, - ReportedVersion: version, - Retries: c.Uint("retries"), - RunFromTerminal: isRunningFromTerminal(), - Tags: tags, - TlsConfig: toEdgeTLSConfig, - NamedTunnel: namedTunnel, - ReplaceExisting: c.Bool("force"), - IngressRules: ingressRules, - // turn off use of reconnect token and auth refresh when using named tunnels - UseReconnectToken: compatibilityMode && c.Bool("use-reconnect-token"), + } + + var tunnelEventChan chan ui.TunnelEvent + if uiIsEnabled { + tunnelEventChan = make(chan ui.TunnelEvent, 16) + } + + return &origin.TunnelConfig{ + ConnectionConfig: transportConfig, + ProxyConfig: proxyConfig, + BuildInfo: buildInfo, + ClientID: clientID, + EdgeAddrs: c.StringSlice("edge"), + HAConnections: c.Int("ha-connections"), + IncidentLookup: origin.NewIncidentLookup(), + IsAutoupdated: c.Bool("is-autoupdated"), + IsFreeTunnel: isFreeTunnel, + LBPool: c.String("lb-pool"), + Logger: logger, + Observer: connection.NewObserver(transportLogger, tunnelEventChan, protocol), + ReportedVersion: version, + Retries: c.Uint("retries"), + RunFromTerminal: isRunningFromTerminal(), + TLSConfig: toEdgeTLSConfig, + NamedTunnel: namedTunnel, + ClassicTunnel: classicTunnel, + MuxerConfig: muxerConfig, + TunnelEventChan: tunnelEventChan, + IngressRules: ingressRules, }, nil } func isRunningFromTerminal() bool { return terminal.IsTerminal(int(os.Stdout.Fd())) } + +func determineProtocol(namedTunnel *connection.NamedTunnelConfig) connection.Protocol { + if namedTunnel != nil { + return namedTunnel.Protocol + } + return connection.H2mux +} diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index 3183786b..ce04ab5a 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -14,8 +14,8 @@ import ( "github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/logger" - "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelstore" ) @@ -260,7 +260,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { return err } - protocol, ok := origin.ParseProtocol(sc.c.String("protocol")) + protocol, ok := connection.ParseProtocol(sc.c.String("protocol")) if !ok { return fmt.Errorf("%s is not valid protocol. %s", sc.c.String("protocol"), availableProtocol) } @@ -269,7 +269,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { version, shutdownC, graceShutdownC, - &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol}, + &connection.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol}, sc.logger, sc.isUIEnabled, ) diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index a5b84e96..959a022a 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -78,28 +78,31 @@ var ( Name: "credentials-file", Aliases: []string{credFileFlagAlias}, Usage: "File path of tunnel credentials", + EnvVars: []string{"TUNNEL_CRED_FILE"}, }) forceDeleteFlag = &cli.BoolFlag{ Name: "force", Aliases: []string{"f"}, Usage: "Allows you to delete a tunnel, even if it has active connections.", + EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"}, } selectProtocolFlag = &cli.StringFlag{ Name: "protocol", Value: "h2mux", Aliases: []string{"p"}, Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol), + EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"}, Hidden: true, } ) func buildCreateCommand() *cli.Command { return &cli.Command{ - Name: "create", - Action: cliutil.ErrorHandler(createCommand), - Usage: "Create a new tunnel with given name", - UsageText: "cloudflared tunnel [tunnel command options] create [subcommand options] NAME", - Description: `Creates a tunnel, registers it with Cloudflare edge and generates credential file used to run this tunnel. + Name: "create", + Action: cliutil.ErrorHandler(createCommand), + Usage: "Create a new tunnel with given name", + UsageText: "cloudflared tunnel [tunnel command options] create [subcommand options] NAME", + Description: `Creates a tunnel, registers it with Cloudflare edge and generates credential file used to run this tunnel. Use "cloudflared tunnel route" subcommand to map a DNS name to this tunnel and "cloudflared tunnel run" to start the connection. For example, to create a tunnel named 'my-tunnel' run: diff --git a/connection/connection.go b/connection/connection.go new file mode 100644 index 00000000..5f7103b9 --- /dev/null +++ b/connection/connection.go @@ -0,0 +1,91 @@ +package connection + +import ( + "io" + "net/http" + "strconv" + "time" + + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/google/uuid" +) + +const ( + // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge + edgeH2muxTLSServerName = "cftunnel.com" + // edgeH2TLSServerName is the server name to establish http2 connection with edge + edgeH2TLSServerName = "h2.cftunnel.com" + lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" +) + +type Config struct { + OriginClient OriginClient + GracePeriod time.Duration + ReplaceExisting bool +} + +type NamedTunnelConfig struct { + Auth pogs.TunnelAuth + ID uuid.UUID + Client pogs.ClientInfo + Protocol Protocol +} + +type ClassicTunnelConfig struct { + Hostname string + OriginCert []byte + // feature-flag to use new edge reconnect tokens + UseReconnectToken bool +} + +func (c *ClassicTunnelConfig) IsTrialZone() bool { + return c.Hostname == "" +} + +type Protocol int64 + +const ( + H2mux Protocol = iota + HTTP2 +) + +func ParseProtocol(s string) (Protocol, bool) { + switch s { + case "h2mux": + return H2mux, true + case "http2": + return HTTP2, true + default: + return 0, false + } +} + +func (p Protocol) ServerName() string { + switch p { + case H2mux: + return edgeH2muxTLSServerName + case HTTP2: + return edgeH2TLSServerName + default: + return "" + } +} + +type OriginClient interface { + Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error +} + +type ResponseWriter interface { + WriteRespHeaders(*http.Response) error + WriteErrorResponse(error) + io.ReadWriter +} + +type ConnectedFuse interface { + Connected() + IsConnected() bool +} + +func uint8ToString(input uint8) string { + return strconv.FormatUint(uint64(input), 10) +} diff --git a/connection/errors.go b/connection/errors.go new file mode 100644 index 00000000..521a0349 --- /dev/null +++ b/connection/errors.go @@ -0,0 +1,76 @@ +package connection + +import ( + "github.com/cloudflare/cloudflared/edgediscovery" + "github.com/cloudflare/cloudflared/h2mux" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/prometheus/client_golang/prometheus" +) + +const ( + DuplicateConnectionError = "EDUPCONN" +) + +// RegisterTunnel error from client +type clientRegisterTunnelError struct { + cause error +} + +func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError { + counter.WithLabelValues(cause.Error(), string(name)).Inc() + return clientRegisterTunnelError{cause: cause} +} + +func (e clientRegisterTunnelError) Error() string { + return e.cause.Error() +} + +type DupConnRegisterTunnelError struct{} + +var errDuplicationConnection = &DupConnRegisterTunnelError{} + +func (e DupConnRegisterTunnelError) Error() string { + return "already connected to this server, trying another address" +} + +// RegisterTunnel error from server +type serverRegisterTunnelError struct { + cause error + permanent bool +} + +func (e serverRegisterTunnelError) Error() string { + return e.cause.Error() +} + +func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError { + if retryable, ok := err.(*tunnelpogs.RetryableError); ok { + return &serverRegisterTunnelError{ + cause: retryable.Unwrap(), + permanent: false, + } + } + return &serverRegisterTunnelError{ + cause: err, + permanent: true, + } +} + +type muxerShutdownError struct{} + +func (e muxerShutdownError) Error() string { + return "muxer shutdown" +} + +func isHandshakeErrRecoverable(err error, connIndex uint8, observer *Observer) bool { + switch err.(type) { + case edgediscovery.DialError: + observer.Errorf("Connection %d unable to dial edge: %s", connIndex, err) + case h2mux.MuxerHandshakeError: + observer.Errorf("Connection %d handshake with edge server failed: %s", connIndex, err) + default: + observer.Errorf("Connection %d failed: %s", connIndex, err) + return false + } + return true +} diff --git a/connection/h2mux.go b/connection/h2mux.go new file mode 100644 index 00000000..6a51f698 --- /dev/null +++ b/connection/h2mux.go @@ -0,0 +1,216 @@ +package connection + +import ( + "context" + "net" + "net/http" + "time" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/logger" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/websocket" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" +) + +const ( + muxerTimeout = 5 * time.Second + openStreamTimeout = 30 * time.Second +) + +type h2muxConnection struct { + config *Config + muxerConfig *MuxerConfig + originURL string + muxer *h2mux.Muxer + // connectionID is only used by metrics, and prometheus requires labels to be string + connIndexStr string + connIndex uint8 + + observer *Observer +} + +type MuxerConfig struct { + HeartbeatInterval time.Duration + MaxHeartbeats uint64 + CompressionSetting h2mux.CompressionSetting + MetricsUpdateFreq time.Duration +} + +func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, logger logger.Service) *h2mux.MuxerConfig { + return &h2mux.MuxerConfig{ + Timeout: muxerTimeout, + Handler: h, + IsClient: true, + HeartbeatInterval: mc.HeartbeatInterval, + MaxHeartbeats: mc.MaxHeartbeats, + Logger: logger, + CompressionQuality: mc.CompressionSetting, + } +} + +// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error +func NewH2muxConnection(ctx context.Context, + config *Config, + muxerConfig *MuxerConfig, + originURL string, + edgeConn net.Conn, + connIndex uint8, + observer *Observer, +) (*h2muxConnection, error, bool) { + h := &h2muxConnection{ + config: config, + muxerConfig: muxerConfig, + originURL: originURL, + connIndexStr: uint8ToString(connIndex), + connIndex: connIndex, + observer: observer, + } + + // Establish a muxed connection with the edge + // Client mux handshake with agent server + muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer), h2mux.ActiveStreams) + if err != nil { + recoverable := isHandshakeErrRecoverable(err, connIndex, observer) + return nil, err, recoverable + } + h.muxer = muxer + return h, nil, false +} + +func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error { + errGroup, serveCtx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + return h.serveMuxer(serveCtx) + }) + + errGroup.Go(func() error { + stream, err := h.newRPCStream(serveCtx, register) + if err != nil { + return err + } + rpcClient := newRegistrationRPCClient(ctx, stream, h.observer) + defer rpcClient.close() + + if err = registerConnection(serveCtx, rpcClient, namedTunnel, connOptions, h.connIndex, h.observer); err != nil { + return err + } + connectedFuse.Connected() + return nil + }) + + errGroup.Go(func() error { + h.controlLoop(serveCtx, connectedFuse, true) + return nil + }) + return errGroup.Wait() +} + +func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error { + errGroup, serveCtx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + return h.serveMuxer(serveCtx) + }) + + errGroup.Go(func() (err error) { + defer func() { + if err == nil { + connectedFuse.Connected() + } + }() + if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() { + err := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions) + if err == nil { + return nil + } + // log errors and proceed to RegisterTunnel + h.observer.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", h.connIndex, err) + } + return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions) + }) + + errGroup.Go(func() error { + h.controlLoop(serveCtx, connectedFuse, false) + return nil + }) + return errGroup.Wait() +} + +func (h *h2muxConnection) serveMuxer(ctx context.Context) error { + // All routines should stop when muxer finish serving. When muxer is shutdown + // gracefully, it doesn't return an error, so we need to return errMuxerShutdown + // here to notify other routines to stop + err := h.muxer.Serve(ctx) + if err == nil { + return muxerShutdownError{} + } + return err +} + +func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) { + updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq) + for { + select { + case <-ctx.Done(): + // UnregisterTunnel blocks until the RPC call returns + if connectedFuse.IsConnected() { + h.unregister(isNamedTunnel) + } + h.muxer.Shutdown() + return + case <-updateMetricsTickC: + h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics()) + } + } +} + +func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) { + openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout) + defer openStreamCancel() + stream, err := h.muxer.OpenRPCStream(openStreamCtx) + if err != nil { + return nil, err + } + return stream, nil +} + +func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { + respWriter := &h2muxRespWriter{stream} + + req, reqErr := h.newRequest(stream) + if reqErr != nil { + respWriter.WriteErrorResponse(reqErr) + return reqErr + } + + return h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req)) +} + +func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) { + req, err := http.NewRequest("GET", h.originURL, h2mux.MuxedStreamReader{MuxedStream: stream}) + if err != nil { + return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") + } + err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) + if err != nil { + return nil, errors.Wrap(err, "invalid request received") + } + return req, nil +} + +type h2muxRespWriter struct { + *h2mux.MuxedStream +} + +func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error { + return rp.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp)) +} + +func (rp *h2muxRespWriter) WriteErrorResponse(err error) { + rp.WriteHeaders([]h2mux.Header{ + {Name: ":status", Value: "502"}, + h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared), + }) + rp.Write([]byte("502 Bad Gateway")) +} diff --git a/connection/http2.go b/connection/http2.go new file mode 100644 index 00000000..62726dc8 --- /dev/null +++ b/connection/http2.go @@ -0,0 +1,253 @@ +package connection + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + + "github.com/cloudflare/cloudflared/h2mux" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + + "golang.org/x/net/http2" +) + +const ( + internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade" + websocketUpgrade = "websocket" + controlStreamUpgrade = "control-stream" +) + +type HTTP2Connection struct { + conn net.Conn + server *http2.Server + config *Config + originURL *url.URL + namedTunnel *NamedTunnelConfig + connOptions *tunnelpogs.ConnectionOptions + observer *Observer + connIndexStr string + connIndex uint8 + shutdownChan chan struct{} + connectedFuse ConnectedFuse +} + +func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) { + return &HTTP2Connection{ + conn: conn, + server: &http2.Server{}, + config: config, + originURL: originURL, + namedTunnel: namedTunnelConfig, + connOptions: connOptions, + observer: observer, + connIndexStr: uint8ToString(connIndex), + connIndex: connIndex, + shutdownChan: make(chan struct{}), + connectedFuse: connectedFuse, + }, nil +} + +func (c *HTTP2Connection) Serve(ctx context.Context) { + go func() { + <-ctx.Done() + c.close() + }() + c.server.ServeConn(c.conn, &http2.ServeConnOpts{ + Context: ctx, + Handler: c, + }) +} + +func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { + r.URL.Scheme = c.originURL.Scheme + r.URL.Host = c.originURL.Host + + respWriter := &http2RespWriter{ + r: r.Body, + w: w, + } + if isControlStreamUpgrade(r) { + err := c.serveControlStream(r.Context(), respWriter) + if err != nil { + respWriter.WriteErrorResponse(err) + } + } else if isWebsocketUpgrade(r) { + wsRespWriter, err := newWSRespWriter(respWriter) + if err != nil { + respWriter.WriteErrorResponse(err) + return + } + stripWebsocketUpgradeHeader(r) + c.config.OriginClient.Proxy(wsRespWriter, r, true) + } else { + c.config.OriginClient.Proxy(respWriter, r, false) + } +} + +func (c *HTTP2Connection) serveControlStream(ctx context.Context, h2RespWriter *http2RespWriter) error { + stream, err := newWSRespWriter(h2RespWriter) + if err != nil { + return err + } + + rpcClient := newRegistrationRPCClient(ctx, stream, c.observer) + defer rpcClient.close() + + if err = registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil { + return err + } + c.connectedFuse.Connected() + + <-c.shutdownChan + c.gracefulShutdown(ctx, rpcClient) + close(c.shutdownChan) + return nil +} + +func (c *HTTP2Connection) registerConnection( + ctx context.Context, + rpcClient tunnelpogs.RegistrationServer_PogsClient, +) error { + connDetail, err := rpcClient.RegisterConnection( + ctx, + c.namedTunnel.Auth, + c.namedTunnel.ID, + c.connIndex, + c.connOptions, + ) + if err != nil { + c.observer.Errorf("Cannot register connection, err: %v", err) + return err + } + c.observer.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID) + return nil +} + +func (c *HTTP2Connection) gracefulShutdown(ctx context.Context, rpcClient *registrationServerClient) { + ctx, cancel := context.WithTimeout(ctx, c.config.GracePeriod) + defer cancel() + rpcClient.client.UnregisterConnection(ctx) +} + +func (c *HTTP2Connection) close() { + // Send signal to control loop to start graceful shutdown + c.shutdownChan <- struct{}{} + // Wait for control loop to close channel + <-c.shutdownChan + c.conn.Close() +} + +type http2RespWriter struct { + r io.Reader + w http.ResponseWriter +} + +func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error { + dest := rp.w.Header() + userHeaders := make(http.Header, len(resp.Header)) + for header, values := range resp.Header { + // Since these are http2 headers, they're required to be lowercase + h2name := strings.ToLower(header) + for _, v := range values { + if h2name == "content-length" { + // This header has meaning in HTTP/2 and will be used by the edge, + // so it should be sent as an HTTP/2 response header. + dest.Add(h2name, v) + // Since these are http2 headers, they're required to be lowercase + } else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) { + // User headers, on the other hand, must all be serialized so that + // HTTP/2 header validation won't be applied to HTTP/1 header values + userHeaders.Add(h2name, v) + } + } + } + + // Perform user header serialization and set them in the single header + dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) + status := resp.StatusCode + // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 + if status == http.StatusSwitchingProtocols { + status = http.StatusOK + } + rp.w.WriteHeader(status) + return nil +} + +func (rp *http2RespWriter) WriteErrorResponse(err error) { + jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared}) + if err == nil { + rp.w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader)) + } + rp.w.WriteHeader(http.StatusBadGateway) +} + +func (rp *http2RespWriter) Read(p []byte) (n int, err error) { + return rp.r.Read(p) +} + +func (wr *http2RespWriter) Write(p []byte) (n int, err error) { + return wr.w.Write(p) +} + +type wsRespWriter struct { + h2 *http2RespWriter + flusher http.Flusher +} + +func newWSRespWriter(h2 *http2RespWriter) (*wsRespWriter, error) { + flusher, ok := h2.w.(http.Flusher) + if !ok { + return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher") + } + return &wsRespWriter{ + h2: h2, + flusher: flusher, + }, nil +} + +func (rw *wsRespWriter) WriteRespHeaders(resp *http.Response) error { + err := rw.h2.WriteRespHeaders(resp) + if err != nil { + return err + } + rw.flusher.Flush() + return nil +} + +func (rw *wsRespWriter) WriteErrorResponse(err error) { + rw.h2.WriteErrorResponse(err) +} + +func (rw *wsRespWriter) Read(p []byte) (n int, err error) { + return rw.h2.Read(p) +} + +func (rw *wsRespWriter) Write(p []byte) (n int, err error) { + n, err = rw.h2.Write(p) + if err != nil { + return + } + rw.flusher.Flush() + return +} + +func (rw *wsRespWriter) Close() error { + return nil +} + +func isControlStreamUpgrade(r *http.Request) bool { + return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlStreamUpgrade +} + +func isWebsocketUpgrade(r *http.Request) bool { + return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade +} + +func stripWebsocketUpgradeHeader(r *http.Request) { + r.Header.Del(internalUpgradeHeader) +} diff --git a/connection/metrics.go b/connection/metrics.go new file mode 100644 index 00000000..405611bc --- /dev/null +++ b/connection/metrics.go @@ -0,0 +1,409 @@ +package connection + +import ( + "sync" + "time" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/prometheus/client_golang/prometheus" +) + +const ( + MetricsNamespace = "cloudflared" + TunnelSubsystem = "tunnel" + muxerSubsystem = "muxer" +) + +type muxerMetrics struct { + rtt *prometheus.GaugeVec + rttMin *prometheus.GaugeVec + rttMax *prometheus.GaugeVec + receiveWindowAve *prometheus.GaugeVec + sendWindowAve *prometheus.GaugeVec + receiveWindowMin *prometheus.GaugeVec + receiveWindowMax *prometheus.GaugeVec + sendWindowMin *prometheus.GaugeVec + sendWindowMax *prometheus.GaugeVec + inBoundRateCurr *prometheus.GaugeVec + inBoundRateMin *prometheus.GaugeVec + inBoundRateMax *prometheus.GaugeVec + outBoundRateCurr *prometheus.GaugeVec + outBoundRateMin *prometheus.GaugeVec + outBoundRateMax *prometheus.GaugeVec + compBytesBefore *prometheus.GaugeVec + compBytesAfter *prometheus.GaugeVec + compRateAve *prometheus.GaugeVec +} + +type tunnelMetrics struct { + timerRetries prometheus.Gauge + serverLocations *prometheus.GaugeVec + // locationLock is a mutex for oldServerLocations + locationLock sync.Mutex + // oldServerLocations stores the last server the tunnel was connected to + oldServerLocations map[string]string + + regSuccess *prometheus.CounterVec + regFail *prometheus.CounterVec + rpcFail *prometheus.CounterVec + + muxerMetrics *muxerMetrics + tunnelsHA tunnelsForHA + userHostnamesCounts *prometheus.CounterVec +} + +func newMuxerMetrics() *muxerMetrics { + rtt := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "rtt", + Help: "Round-trip time in millisecond", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(rtt) + + rttMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "rtt_min", + Help: "Shortest round-trip time in millisecond", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(rttMin) + + rttMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "rtt_max", + Help: "Longest round-trip time in millisecond", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(rttMax) + + receiveWindowAve := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "receive_window_ave", + Help: "Average receive window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(receiveWindowAve) + + sendWindowAve := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "send_window_ave", + Help: "Average send window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(sendWindowAve) + + receiveWindowMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "receive_window_min", + Help: "Smallest receive window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(receiveWindowMin) + + receiveWindowMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "receive_window_max", + Help: "Largest receive window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(receiveWindowMax) + + sendWindowMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "send_window_min", + Help: "Smallest send window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(sendWindowMin) + + sendWindowMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "send_window_max", + Help: "Largest send window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(sendWindowMax) + + inBoundRateCurr := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "inbound_bytes_per_sec_curr", + Help: "Current inbounding bytes per second, 0 if there is no incoming connection", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(inBoundRateCurr) + + inBoundRateMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "inbound_bytes_per_sec_min", + Help: "Minimum non-zero inbounding bytes per second", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(inBoundRateMin) + + inBoundRateMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "inbound_bytes_per_sec_max", + Help: "Maximum inbounding bytes per second", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(inBoundRateMax) + + outBoundRateCurr := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "outbound_bytes_per_sec_curr", + Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(outBoundRateCurr) + + outBoundRateMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "outbound_bytes_per_sec_min", + Help: "Minimum non-zero outbounding bytes per second", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(outBoundRateMin) + + outBoundRateMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "outbound_bytes_per_sec_max", + Help: "Maximum outbounding bytes per second", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(outBoundRateMax) + + compBytesBefore := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "comp_bytes_before", + Help: "Bytes sent via cross-stream compression, pre compression", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(compBytesBefore) + + compBytesAfter := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "comp_bytes_after", + Help: "Bytes sent via cross-stream compression, post compression", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(compBytesAfter) + + compRateAve := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: muxerSubsystem, + Name: "comp_rate_ave", + Help: "Average outbound cross-stream compression ratio", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(compRateAve) + + return &muxerMetrics{ + rtt: rtt, + rttMin: rttMin, + rttMax: rttMax, + receiveWindowAve: receiveWindowAve, + sendWindowAve: sendWindowAve, + receiveWindowMin: receiveWindowMin, + receiveWindowMax: receiveWindowMax, + sendWindowMin: sendWindowMin, + sendWindowMax: sendWindowMax, + inBoundRateCurr: inBoundRateCurr, + inBoundRateMin: inBoundRateMin, + inBoundRateMax: inBoundRateMax, + outBoundRateCurr: outBoundRateCurr, + outBoundRateMin: outBoundRateMin, + outBoundRateMax: outBoundRateMax, + compBytesBefore: compBytesBefore, + compBytesAfter: compBytesAfter, + compRateAve: compRateAve, + } +} + +func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) { + m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT)) + m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin)) + m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax)) + m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve) + m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve) + m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin)) + m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax)) + m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin)) + m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax)) + m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr)) + m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin)) + m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax)) + m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr)) + m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin)) + m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax)) + m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value())) + m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value())) + m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve())) +} + +func convertRTTMilliSec(t time.Duration) float64 { + return float64(t / time.Millisecond) +} + +// Metrics that can be collected without asking the edge +func newTunnelMetrics(protocol Protocol) *tunnelMetrics { + maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: TunnelSubsystem, + Name: "max_concurrent_requests_per_tunnel", + Help: "Largest number of concurrent requests proxied through each tunnel so far", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(maxConcurrentRequestsPerTunnel) + + timerRetries := prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: TunnelSubsystem, + Name: "timer_retries", + Help: "Unacknowledged heart beats count", + }) + prometheus.MustRegister(timerRetries) + + serverLocations := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: MetricsNamespace, + Subsystem: TunnelSubsystem, + Name: "server_locations", + Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.", + }, + []string{"connection_id", "location"}, + ) + prometheus.MustRegister(serverLocations) + + rpcFail := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: MetricsNamespace, + Subsystem: TunnelSubsystem, + Name: "tunnel_rpc_fail", + Help: "Count of RPC connection errors by type", + }, + []string{"error", "rpcName"}, + ) + prometheus.MustRegister(rpcFail) + + registerFail := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: MetricsNamespace, + Subsystem: TunnelSubsystem, + Name: "tunnel_register_fail", + Help: "Count of tunnel registration errors by type", + }, + []string{"error", "rpcName"}, + ) + prometheus.MustRegister(registerFail) + + userHostnamesCounts := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: MetricsNamespace, + Subsystem: TunnelSubsystem, + Name: "user_hostnames_counts", + Help: "Which user hostnames cloudflared is serving", + }, + []string{"userHostname"}, + ) + prometheus.MustRegister(userHostnamesCounts) + + registerSuccess := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: MetricsNamespace, + Subsystem: TunnelSubsystem, + Name: "tunnel_register_success", + Help: "Count of successful tunnel registrations", + }, + []string{"rpcName"}, + ) + prometheus.MustRegister(registerSuccess) + var muxerMetrics *muxerMetrics + if protocol == H2mux { + muxerMetrics = newMuxerMetrics() + } + + return &tunnelMetrics{ + timerRetries: timerRetries, + serverLocations: serverLocations, + oldServerLocations: make(map[string]string), + muxerMetrics: muxerMetrics, + tunnelsHA: NewTunnelsForHA(), + regSuccess: registerSuccess, + regFail: registerFail, + rpcFail: rpcFail, + userHostnamesCounts: userHostnamesCounts, + } +} + +func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) { + t.muxerMetrics.update(connectionID, metrics) +} + +func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) { + t.locationLock.Lock() + defer t.locationLock.Unlock() + if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc { + return + } else if ok { + t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec() + } + t.serverLocations.WithLabelValues(connectionID, loc).Inc() + t.oldServerLocations[connectionID] = loc +} diff --git a/connection/observer.go b/connection/observer.go new file mode 100644 index 00000000..9dfb15aa --- /dev/null +++ b/connection/observer.go @@ -0,0 +1,99 @@ +package connection + +import ( + "fmt" + "net/url" + "strings" + + "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" + "github.com/cloudflare/cloudflared/logger" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +type Observer struct { + logger.Service + metrics *tunnelMetrics + tunnelEventChan chan<- ui.TunnelEvent +} + +func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer { + return &Observer{ + logger, + newTunnelMetrics(protocol), + tunnelEventChan, + } +} + +func (o *Observer) logServerInfo(connectionID uint8, location, msg string) { + // If launch-ui flag is set, send connect msg + if o.tunnelEventChan != nil { + o.tunnelEventChan <- ui.TunnelEvent{Index: connectionID, EventType: ui.Connected, Location: location} + } + o.Infof(msg) + o.metrics.registerServerLocation(uint8ToString(connectionID), location) +} + +func (o *Observer) logTrialHostname(registration *tunnelpogs.TunnelRegistration) error { + // Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set) + if o.tunnelEventChan == nil { + if registrationURL, err := url.Parse(registration.Url); err == nil { + for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) { + o.Info(line) + } + } else { + o.Error("Failed to connect tunnel, please try again.") + return fmt.Errorf("empty URL in response from Cloudflare edge") + } + } + return nil +} + +// Print out the given lines in a nice ASCII box. +func asciiBox(lines []string, padding int) (box []string) { + maxLen := maxLen(lines) + spacer := strings.Repeat(" ", padding) + + border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+" + + box = append(box, border) + for _, line := range lines { + box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|") + } + box = append(box, border) + return +} + +func maxLen(lines []string) int { + max := 0 + for _, line := range lines { + if len(line) > max { + max = len(line) + } + } + return max +} + +func trialZoneMsg(url string) []string { + return []string{ + "Your free tunnel has started! Visit it:", + " " + url, + } +} + +func (o *Observer) sendRegisteringEvent() { + if o.tunnelEventChan != nil { + o.tunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel} + } +} + +func (o *Observer) sendConnectedEvent(connIndex uint8, location string) { + if o.tunnelEventChan != nil { + o.tunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Connected, Location: location} + } +} + +func (o *Observer) sendURL(url string) { + if o.tunnelEventChan != nil { + o.tunnelEventChan <- ui.TunnelEvent{EventType: ui.SetUrl, Url: url} + } +} diff --git a/connection/observer_test.go b/connection/observer_test.go new file mode 100644 index 00000000..aa47430e --- /dev/null +++ b/connection/observer_test.go @@ -0,0 +1,45 @@ +package connection + +import ( + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +// can only be called once +var m = newTunnelMetrics(H2mux) + +func TestRegisterServerLocation(t *testing.T) { + tunnels := 20 + var wg sync.WaitGroup + wg.Add(tunnels) + for i := 0; i < tunnels; i++ { + go func(i int) { + id := strconv.Itoa(i) + m.registerServerLocation(id, "LHR") + wg.Done() + }(i) + } + wg.Wait() + for i := 0; i < tunnels; i++ { + id := strconv.Itoa(i) + assert.Equal(t, "LHR", m.oldServerLocations[id]) + } + + wg.Add(tunnels) + for i := 0; i < tunnels; i++ { + go func(i int) { + id := strconv.Itoa(i) + m.registerServerLocation(id, "AUS") + wg.Done() + }(i) + } + wg.Wait() + for i := 0; i < tunnels; i++ { + id := strconv.Itoa(i) + assert.Equal(t, "AUS", m.oldServerLocations[id]) + } + +} diff --git a/connection/rpc.go b/connection/rpc.go index ee24e250..9da7630d 100644 --- a/connection/rpc.go +++ b/connection/rpc.go @@ -2,40 +2,276 @@ package connection import ( "context" + "fmt" "io" - rpc "zombiezen.com/go/capnproto2/rpc" - "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "zombiezen.com/go/capnproto2/rpc" ) -// NewTunnelRPCClient creates and returns a new RPC client, which will communicate -// using a stream on the given muxer -func NewTunnelRPCClient( +type tunnelServerClient struct { + client tunnelpogs.TunnelServer_PogsClient + transport rpc.Transport +} + +// NewTunnelRPCClient creates and returns a new RPC client, which will communicate using a stream on the given muxer. +// This method is exported for supervisor to call Authenticate RPC +func NewTunnelServerClient( ctx context.Context, stream io.ReadWriteCloser, logger logger.Service, -) (client tunnelpogs.TunnelServer_PogsClient, err error) { +) *tunnelServerClient { + transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)) conn := rpc.NewConn( - tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), + transport, tunnelrpc.ConnLog(logger), ) registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn} - client = tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn} - return client, nil + return &tunnelServerClient{ + client: tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn}, + transport: transport, + } } -func NewRegistrationRPCClient( +func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) { + authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions) + if err != nil { + return nil, err + } + return authResp.Outcome(), nil +} + +func (tsc *tunnelServerClient) Close() { + // Closing the client will also close the connection + tsc.client.Close() + tsc.transport.Close() +} + +type registrationServerClient struct { + client tunnelpogs.RegistrationServer_PogsClient + transport rpc.Transport +} + +func newRegistrationRPCClient( ctx context.Context, stream io.ReadWriteCloser, logger logger.Service, -) (client tunnelpogs.RegistrationServer_PogsClient, err error) { +) *registrationServerClient { + transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)) conn := rpc.NewConn( - tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), + transport, tunnelrpc.ConnLog(logger), ) - client = tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn} - return client, nil + return ®istrationServerClient{ + client: tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}, + transport: transport, + } +} + +func (rsc *registrationServerClient) close() { + // Closing the client will also close the connection + rsc.client.Close() + // Closing the transport also closes the stream + rsc.transport.Close() +} + +type rpcName string + +const ( + register rpcName = "register" + reconnect rpcName = "reconnect" + unregister rpcName = "unregister" + authenticate rpcName = " authenticate" +) + +func registerConnection( + ctx context.Context, + rpcClient *registrationServerClient, + config *NamedTunnelConfig, + options *tunnelpogs.ConnectionOptions, + connIndex uint8, + observer *Observer, +) error { + conn, err := rpcClient.client.RegisterConnection( + ctx, + config.Auth, + config.ID, + connIndex, + options, + ) + if err != nil { + if err.Error() == DuplicateConnectionError { + observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc() + return errDuplicationConnection + } + observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc() + return serverRegistrationErrorFromRPC(err) + } + + observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() + + observer.logServerInfo(connIndex, conn.Location, fmt.Sprintf("Connection %d registered with %s using ID %s", connIndex, conn.Location, conn.UUID)) + observer.sendConnectedEvent(connIndex, conn.Location) + + return nil +} + +func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { + h.observer.sendRegisteringEvent() + + stream, err := h.newRPCStream(ctx, register) + if err != nil { + return err + } + rpcClient := NewTunnelServerClient(ctx, stream, h.observer) + defer rpcClient.Close() + + h.logServerInfo(ctx, rpcClient) + registration := rpcClient.client.RegisterTunnel( + ctx, + classicTunnel.OriginCert, + classicTunnel.Hostname, + registrationOptions, + ) + if registrationErr := registration.DeserializeError(); registrationErr != nil { + // RegisterTunnel RPC failure + return h.processRegisterTunnelError(registrationErr, register) + } + + // Send free tunnel URL to UI + h.observer.sendURL(registration.Url) + credentialSetter.SetEventDigest(h.connIndex, registration.EventDigest) + return h.processRegistrationSuccess(registration, register, credentialSetter, classicTunnel) +} + +type CredentialManager interface { + ReconnectToken() ([]byte, error) + EventDigest(connID uint8) ([]byte, error) + SetEventDigest(connID uint8, digest []byte) + ConnDigest(connID uint8) ([]byte, error) + SetConnDigest(connID uint8, digest []byte) +} + +func (h *h2muxConnection) processRegistrationSuccess( + registration *tunnelpogs.TunnelRegistration, + name rpcName, + credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, +) error { + for _, logLine := range registration.LogLines { + h.observer.Info(logLine) + } + + if registration.TunnelID != "" { + h.observer.metrics.tunnelsHA.AddTunnelID(h.connIndex, registration.TunnelID) + h.observer.Infof("Each HA connection's tunnel IDs: %v", h.observer.metrics.tunnelsHA.String()) + } + + // Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set) + if classicTunnel.IsTrialZone() { + err := h.observer.logTrialHostname(registration) + if err != nil { + return err + } + } + + credentialManager.SetConnDigest(h.connIndex, registration.ConnDigest) + h.observer.metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc() + + h.observer.Infof("Route propagating, it may take up to 1 minute for your new route to become functional") + h.observer.metrics.regSuccess.WithLabelValues(string(name)).Inc() + return nil +} + +func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, name rpcName) error { + if err.Error() == DuplicateConnectionError { + h.observer.metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc() + return errDuplicationConnection + } + h.observer.metrics.regFail.WithLabelValues("server_error", string(name)).Inc() + return serverRegisterTunnelError{ + cause: err, + permanent: err.IsPermanent(), + } +} + +func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { + token, err := credentialManager.ReconnectToken() + if err != nil { + return err + } + eventDigest, err := credentialManager.EventDigest(h.connIndex) + if err != nil { + return err + } + connDigest, err := credentialManager.ConnDigest(h.connIndex) + if err != nil { + return err + } + + h.observer.Debug("initiating RPC stream to reconnect") + stream, err := h.newRPCStream(ctx, register) + if err != nil { + return err + } + rpcClient := NewTunnelServerClient(ctx, stream, h.observer) + defer rpcClient.Close() + + h.logServerInfo(ctx, rpcClient) + registration := rpcClient.client.ReconnectTunnel( + ctx, + token, + eventDigest, + connDigest, + classicTunnel.Hostname, + registrationOptions, + ) + if registrationErr := registration.DeserializeError(); registrationErr != nil { + // ReconnectTunnel RPC failure + return h.processRegisterTunnelError(registrationErr, reconnect) + } + return h.processRegistrationSuccess(registration, reconnect, credentialManager, classicTunnel) +} + +func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelServerClient) error { + // Request server info without blocking tunnel registration; must use capnp library directly. + serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.client.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { + return nil + }) + serverInfoMessage, err := serverInfoPromise.Result().Struct() + if err != nil { + h.observer.Errorf("Failed to retrieve server information: %s", err) + return err + } + serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage) + if err != nil { + h.observer.Errorf("Failed to retrieve server information: %s", err) + return err + } + h.observer.logServerInfo(h.connIndex, serverInfo.LocationName, fmt.Sprintf("Connnection %d connected to %s", h.connIndex, serverInfo.LocationName)) + return nil +} + +func (h *h2muxConnection) unregister(isNamedTunnel bool) { + unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod) + defer cancel() + + stream, err := h.newRPCStream(unregisterCtx, register) + if err != nil { + return + } + + if isNamedTunnel { + rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer) + defer rpcClient.close() + + rpcClient.client.UnregisterConnection(unregisterCtx) + } else { + rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer) + defer rpcClient.Close() + + // gracePeriod is encoded in int64 using capnproto + rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds()) + } } diff --git a/connection/tunnelsforha.go b/connection/tunnelsforha.go new file mode 100644 index 00000000..9f7ab309 --- /dev/null +++ b/connection/tunnelsforha.go @@ -0,0 +1,50 @@ +package connection + +import ( + "fmt" + "sync" + + "github.com/prometheus/client_golang/prometheus" +) + +// tunnelsForHA maps this cloudflared instance's HA connections to the tunnel IDs they serve. +type tunnelsForHA struct { + sync.Mutex + metrics *prometheus.GaugeVec + entries map[uint8]string +} + +// NewTunnelsForHA initializes the Prometheus metrics etc for a tunnelsForHA. +func NewTunnelsForHA() tunnelsForHA { + metrics := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "tunnel_ids", + Help: "The ID of all tunnels (and their corresponding HA connection ID) running in this instance of cloudflared.", + }, + []string{"tunnel_id", "ha_conn_id"}, + ) + prometheus.MustRegister(metrics) + + return tunnelsForHA{ + metrics: metrics, + entries: make(map[uint8]string), + } +} + +// Track a new tunnel ID, removing the disconnected tunnel (if any) and update metrics. +func (t *tunnelsForHA) AddTunnelID(haConn uint8, tunnelID string) { + t.Lock() + defer t.Unlock() + haStr := fmt.Sprintf("%v", haConn) + if oldTunnelID, ok := t.entries[haConn]; ok { + t.metrics.WithLabelValues(oldTunnelID, haStr).Dec() + } + t.entries[haConn] = tunnelID + t.metrics.WithLabelValues(tunnelID, haStr).Inc() +} + +func (t *tunnelsForHA) String() string { + t.Lock() + defer t.Unlock() + return fmt.Sprintf("%v", t.entries) +} diff --git a/connection/dial.go b/edgediscovery/dial.go similarity index 98% rename from connection/dial.go rename to edgediscovery/dial.go index 4651afd9..c8ae632e 100644 --- a/connection/dial.go +++ b/edgediscovery/dial.go @@ -1,4 +1,4 @@ -package connection +package edgediscovery import ( "context" diff --git a/connection/mocks_for_test.go b/edgediscovery/mocks_for_test.go similarity index 99% rename from connection/mocks_for_test.go rename to edgediscovery/mocks_for_test.go index 03c08d8b..2db110ae 100644 --- a/connection/mocks_for_test.go +++ b/edgediscovery/mocks_for_test.go @@ -1,4 +1,4 @@ -package connection +package edgediscovery import ( "fmt" diff --git a/h2mux/activestreammap.go b/h2mux/activestreammap.go index 15203423..d8079d07 100644 --- a/h2mux/activestreammap.go +++ b/h2mux/activestreammap.go @@ -7,6 +7,19 @@ import ( "golang.org/x/net/http2" ) +var ( + ActiveStreams = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "cloudflared", + Subsystem: "tunnel", + Name: "active_streams", + Help: "Number of active streams created by all muxers.", + }) +) + +func init() { + prometheus.MustRegister(ActiveStreams) +} + // activeStreamMap is used to moderate access to active streams between the read and write // threads, and deny access to new peer streams while shutting down. type activeStreamMap struct { diff --git a/h2mux/activestreammap_test.go b/h2mux/activestreammap_test.go index f961bcaf..0395b79b 100644 --- a/h2mux/activestreammap_test.go +++ b/h2mux/activestreammap_test.go @@ -9,7 +9,7 @@ import ( func TestShutdown(t *testing.T) { const numStreams = 1000 - m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) + m := newActiveStreamMap(true, ActiveStreams) // Add all the streams { @@ -62,7 +62,7 @@ func TestShutdown(t *testing.T) { func TestEmptyBeforeShutdown(t *testing.T) { const numStreams = 1000 - m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) + m := newActiveStreamMap(true, ActiveStreams) // Add all the streams { @@ -138,7 +138,7 @@ func (_ *noopReadyList) Signal(streamID uint32) {} func TestAbort(t *testing.T) { const numStreams = 1000 - m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) + m := newActiveStreamMap(true, ActiveStreams) var openedStreams sync.Map diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 97312047..3bef2d6c 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -113,11 +113,11 @@ func (p *DefaultMuxerPair) Handshake(testName string) error { defer cancel() errGroup, _ := errgroup.WithContext(ctx) errGroup.Go(func() (err error) { - p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, NewActiveStreamsMetrics(testName, "edge")) + p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, ActiveStreams) return errors.Wrap(err, "edge handshake failure") }) errGroup.Go(func() (err error) { - p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, NewActiveStreamsMetrics(testName, "origin")) + p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, ActiveStreams) return errors.Wrap(err, "origin handshake failure") }) diff --git a/h2mux/muxmetrics.go b/h2mux/muxmetrics.go index 2dcb6179..e125da87 100644 --- a/h2mux/muxmetrics.go +++ b/h2mux/muxmetrics.go @@ -6,7 +6,6 @@ import ( "github.com/cloudflare/cloudflared/logger" "github.com/golang-collections/collections/queue" - "github.com/prometheus/client_golang/prometheus" ) // data points used to compute average receive window and send window size @@ -295,14 +294,3 @@ func (r *rate) get() (curr, min, max uint64) { defer r.lock.RUnlock() return r.curr, r.min, r.max } - -func NewActiveStreamsMetrics(namespace, subsystem string) prometheus.Gauge { - activeStreams := prometheus.NewGauge(prometheus.GaugeOpts{ - Namespace: namespace, - Subsystem: subsystem, - Name: "active_streams", - Help: "Number of active streams created by all muxers.", - }) - prometheus.MustRegister(activeStreams) - return activeStreams -} diff --git a/origin/connection.go b/origin/connection.go deleted file mode 100644 index 11a930d7..00000000 --- a/origin/connection.go +++ /dev/null @@ -1,14 +0,0 @@ -package origin - -import ( - "net" -) - -// persistentTCPConn is a wrapper around net.Conn that is noop when Close is called -type persistentConn struct { - net.Conn -} - -func (pc *persistentConn) Close() error { - return nil -} diff --git a/origin/metrics.go b/origin/metrics.go index 4021041d..edf2cab6 100644 --- a/origin/metrics.go +++ b/origin/metrics.go @@ -1,540 +1,63 @@ package origin import ( - "sync" - "time" - - "github.com/cloudflare/cloudflared/h2mux" - + "github.com/cloudflare/cloudflared/connection" "github.com/prometheus/client_golang/prometheus" ) -const ( - metricsNamespace = "cloudflared" - tunnelSubsystem = "tunnel" - muxerSubsystem = "muxer" -) +// Metrics uses connection.MetricsNamespace(aka cloudflared) as namespace and connection.TunnelSubsystem +// (tunnel) as subsystem to keep them consistent with the previous qualifier. -type muxerMetrics struct { - rtt *prometheus.GaugeVec - rttMin *prometheus.GaugeVec - rttMax *prometheus.GaugeVec - receiveWindowAve *prometheus.GaugeVec - sendWindowAve *prometheus.GaugeVec - receiveWindowMin *prometheus.GaugeVec - receiveWindowMax *prometheus.GaugeVec - sendWindowMin *prometheus.GaugeVec - sendWindowMax *prometheus.GaugeVec - inBoundRateCurr *prometheus.GaugeVec - inBoundRateMin *prometheus.GaugeVec - inBoundRateMax *prometheus.GaugeVec - outBoundRateCurr *prometheus.GaugeVec - outBoundRateMin *prometheus.GaugeVec - outBoundRateMax *prometheus.GaugeVec - compBytesBefore *prometheus.GaugeVec - compBytesAfter *prometheus.GaugeVec - compRateAve *prometheus.GaugeVec -} - -type TunnelMetrics struct { - haConnections prometheus.Gauge - activeStreams prometheus.Gauge - totalRequests prometheus.Counter - requestsPerTunnel *prometheus.CounterVec - // concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests - concurrentRequestsLock sync.Mutex - concurrentRequestsPerTunnel *prometheus.GaugeVec - // concurrentRequests records count of concurrent requests for each tunnel - concurrentRequests map[string]uint64 - maxConcurrentRequestsPerTunnel *prometheus.GaugeVec - // concurrentRequests records max count of concurrent requests for each tunnel - maxConcurrentRequests map[string]uint64 - timerRetries prometheus.Gauge - responseByCode *prometheus.CounterVec - responseCodePerTunnel *prometheus.CounterVec - serverLocations *prometheus.GaugeVec - // locationLock is a mutex for oldServerLocations - locationLock sync.Mutex - // oldServerLocations stores the last server the tunnel was connected to - oldServerLocations map[string]string - - regSuccess *prometheus.CounterVec - regFail *prometheus.CounterVec - rpcFail *prometheus.CounterVec - - muxerMetrics *muxerMetrics - tunnelsHA tunnelsForHA - userHostnamesCounts *prometheus.CounterVec -} - -func newMuxerMetrics() *muxerMetrics { - rtt := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "rtt", - Help: "Round-trip time in millisecond", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(rtt) - - rttMin := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "rtt_min", - Help: "Shortest round-trip time in millisecond", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(rttMin) - - rttMax := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "rtt_max", - Help: "Longest round-trip time in millisecond", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(rttMax) - - receiveWindowAve := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "receive_window_ave", - Help: "Average receive window size in bytes", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(receiveWindowAve) - - sendWindowAve := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "send_window_ave", - Help: "Average send window size in bytes", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(sendWindowAve) - - receiveWindowMin := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "receive_window_min", - Help: "Smallest receive window size in bytes", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(receiveWindowMin) - - receiveWindowMax := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "receive_window_max", - Help: "Largest receive window size in bytes", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(receiveWindowMax) - - sendWindowMin := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "send_window_min", - Help: "Smallest send window size in bytes", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(sendWindowMin) - - sendWindowMax := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "send_window_max", - Help: "Largest send window size in bytes", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(sendWindowMax) - - inBoundRateCurr := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "inbound_bytes_per_sec_curr", - Help: "Current inbounding bytes per second, 0 if there is no incoming connection", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(inBoundRateCurr) - - inBoundRateMin := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "inbound_bytes_per_sec_min", - Help: "Minimum non-zero inbounding bytes per second", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(inBoundRateMin) - - inBoundRateMax := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "inbound_bytes_per_sec_max", - Help: "Maximum inbounding bytes per second", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(inBoundRateMax) - - outBoundRateCurr := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "outbound_bytes_per_sec_curr", - Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(outBoundRateCurr) - - outBoundRateMin := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "outbound_bytes_per_sec_min", - Help: "Minimum non-zero outbounding bytes per second", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(outBoundRateMin) - - outBoundRateMax := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "outbound_bytes_per_sec_max", - Help: "Maximum outbounding bytes per second", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(outBoundRateMax) - - compBytesBefore := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "comp_bytes_before", - Help: "Bytes sent via cross-stream compression, pre compression", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(compBytesBefore) - - compBytesAfter := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "comp_bytes_after", - Help: "Bytes sent via cross-stream compression, post compression", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(compBytesAfter) - - compRateAve := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: muxerSubsystem, - Name: "comp_rate_ave", - Help: "Average outbound cross-stream compression ratio", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(compRateAve) - - return &muxerMetrics{ - rtt: rtt, - rttMin: rttMin, - rttMax: rttMax, - receiveWindowAve: receiveWindowAve, - sendWindowAve: sendWindowAve, - receiveWindowMin: receiveWindowMin, - receiveWindowMax: receiveWindowMax, - sendWindowMin: sendWindowMin, - sendWindowMax: sendWindowMax, - inBoundRateCurr: inBoundRateCurr, - inBoundRateMin: inBoundRateMin, - inBoundRateMax: inBoundRateMax, - outBoundRateCurr: outBoundRateCurr, - outBoundRateMin: outBoundRateMin, - outBoundRateMax: outBoundRateMax, - compBytesBefore: compBytesBefore, - compBytesAfter: compBytesAfter, - compRateAve: compRateAve, - } -} - -func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) { - m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT)) - m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin)) - m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax)) - m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve) - m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve) - m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin)) - m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax)) - m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin)) - m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax)) - m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr)) - m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin)) - m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax)) - m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr)) - m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin)) - m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax)) - m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value())) - m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value())) - m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve())) -} - -func convertRTTMilliSec(t time.Duration) float64 { - return float64(t / time.Millisecond) -} - -// Metrics that can be collected without asking the edge -func NewTunnelMetrics() *TunnelMetrics { - haConnections := prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "ha_connections", - Help: "Number of active ha connections", - }) - prometheus.MustRegister(haConnections) - - activeStreams := h2mux.NewActiveStreamsMetrics(metricsNamespace, tunnelSubsystem) - - totalRequests := prometheus.NewCounter( +var ( + totalRequests = prometheus.NewCounter( prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, + Namespace: connection.MetricsNamespace, + Subsystem: connection.TunnelSubsystem, Name: "total_requests", Help: "Amount of requests proxied through all the tunnels", - }) - prometheus.MustRegister(totalRequests) - - requestsPerTunnel := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "requests_per_tunnel", - Help: "Amount of requests proxied through each tunnel", }, - []string{"connection_id"}, ) - prometheus.MustRegister(requestsPerTunnel) - - concurrentRequestsPerTunnel := prometheus.NewGaugeVec( + concurrentRequests = prometheus.NewGauge( prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, + Namespace: connection.MetricsNamespace, + Subsystem: connection.TunnelSubsystem, Name: "concurrent_requests_per_tunnel", Help: "Concurrent requests proxied through each tunnel", }, - []string{"connection_id"}, ) - prometheus.MustRegister(concurrentRequestsPerTunnel) - - maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "max_concurrent_requests_per_tunnel", - Help: "Largest number of concurrent requests proxied through each tunnel so far", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(maxConcurrentRequestsPerTunnel) - - timerRetries := prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "timer_retries", - Help: "Unacknowledged heart beats count", - }) - prometheus.MustRegister(timerRetries) - - responseByCode := prometheus.NewCounterVec( + responseByCode = prometheus.NewCounterVec( prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, + Namespace: connection.MetricsNamespace, + Subsystem: connection.TunnelSubsystem, Name: "response_by_code", Help: "Count of responses by HTTP status code", }, []string{"status_code"}, ) - prometheus.MustRegister(responseByCode) - - responseCodePerTunnel := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "response_code_per_tunnel", - Help: "Count of responses by HTTP status code fore each tunnel", - }, - []string{"connection_id", "status_code"}, - ) - prometheus.MustRegister(responseCodePerTunnel) - - serverLocations := prometheus.NewGaugeVec( + haConnections = prometheus.NewGauge( prometheus.GaugeOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "server_locations", - Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.", + Namespace: connection.MetricsNamespace, + Subsystem: connection.TunnelSubsystem, + Name: "ha_connections", + Help: "Number of active ha connections", }, - []string{"connection_id", "location"}, ) - prometheus.MustRegister(serverLocations) +) - rpcFail := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "tunnel_rpc_fail", - Help: "Count of RPC connection errors by type", - }, - []string{"error", "rpcName"}, +func init() { + prometheus.MustRegister( + totalRequests, + concurrentRequests, + responseByCode, + haConnections, ) - prometheus.MustRegister(rpcFail) - - registerFail := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "tunnel_register_fail", - Help: "Count of tunnel registration errors by type", - }, - []string{"error", "rpcName"}, - ) - prometheus.MustRegister(registerFail) - - userHostnamesCounts := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "user_hostnames_counts", - Help: "Which user hostnames cloudflared is serving", - }, - []string{"userHostname"}, - ) - prometheus.MustRegister(userHostnamesCounts) - - registerSuccess := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "tunnel_register_success", - Help: "Count of successful tunnel registrations", - }, - []string{"rpcName"}, - ) - prometheus.MustRegister(registerSuccess) - - return &TunnelMetrics{ - haConnections: haConnections, - activeStreams: activeStreams, - totalRequests: totalRequests, - requestsPerTunnel: requestsPerTunnel, - concurrentRequestsPerTunnel: concurrentRequestsPerTunnel, - concurrentRequests: make(map[string]uint64), - maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel, - maxConcurrentRequests: make(map[string]uint64), - timerRetries: timerRetries, - responseByCode: responseByCode, - responseCodePerTunnel: responseCodePerTunnel, - serverLocations: serverLocations, - oldServerLocations: make(map[string]string), - muxerMetrics: newMuxerMetrics(), - tunnelsHA: NewTunnelsForHA(), - regSuccess: registerSuccess, - regFail: registerFail, - rpcFail: rpcFail, - userHostnamesCounts: userHostnamesCounts, - } } -func (t *TunnelMetrics) incrementHaConnections() { - t.haConnections.Inc() +func incrementRequests() { + totalRequests.Inc() + concurrentRequests.Inc() } -func (t *TunnelMetrics) decrementHaConnections() { - t.haConnections.Dec() -} - -func (t *TunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) { - t.muxerMetrics.update(connectionID, metrics) -} - -func (t *TunnelMetrics) incrementRequests(connectionID string) { - t.concurrentRequestsLock.Lock() - var concurrentRequests uint64 - var ok bool - if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok { - t.concurrentRequests[connectionID]++ - concurrentRequests++ - } else { - t.concurrentRequests[connectionID] = 1 - concurrentRequests = 1 - } - if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok { - t.maxConcurrentRequests[connectionID] = concurrentRequests - t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests)) - } - t.concurrentRequestsLock.Unlock() - - t.totalRequests.Inc() - t.requestsPerTunnel.WithLabelValues(connectionID).Inc() - t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc() -} - -func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) { - t.concurrentRequestsLock.Lock() - if _, ok := t.concurrentRequests[connectionID]; ok { - t.concurrentRequests[connectionID]-- - } - t.concurrentRequestsLock.Unlock() - - t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec() -} - -func (t *TunnelMetrics) incrementResponses(connectionID, code string) { - t.responseByCode.WithLabelValues(code).Inc() - t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc() - -} - -func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) { - t.locationLock.Lock() - defer t.locationLock.Unlock() - if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc { - return - } else if ok { - t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec() - } - t.serverLocations.WithLabelValues(connectionID, loc).Inc() - t.oldServerLocations[connectionID] = loc +func decrementConcurrentRequests() { + concurrentRequests.Dec() } diff --git a/origin/metrics_test.go b/origin/metrics_test.go deleted file mode 100644 index b6cc8206..00000000 --- a/origin/metrics_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package origin - -import ( - "strconv" - "sync" - "testing" - - "github.com/stretchr/testify/assert" -) - -// can only be called once -var m = NewTunnelMetrics() - -func TestConcurrentRequestsSingleTunnel(t *testing.T) { - routines := 20 - var wg sync.WaitGroup - wg.Add(routines) - for i := 0; i < routines; i++ { - go func() { - m.incrementRequests("0") - wg.Done() - }() - } - wg.Wait() - assert.Len(t, m.concurrentRequests, 1) - assert.Equal(t, uint64(routines), m.concurrentRequests["0"]) - assert.Len(t, m.maxConcurrentRequests, 1) - assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"]) - - wg.Add(routines / 2) - for i := 0; i < routines/2; i++ { - go func() { - m.decrementConcurrentRequests("0") - wg.Done() - }() - } - wg.Wait() - assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"]) - assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"]) -} - -func TestConcurrentRequestsMultiTunnel(t *testing.T) { - m.concurrentRequests = make(map[string]uint64) - m.maxConcurrentRequests = make(map[string]uint64) - tunnels := 20 - var wg sync.WaitGroup - wg.Add(tunnels) - for i := 0; i < tunnels; i++ { - go func(i int) { - // if we have j < i, then tunnel 0 won't have a chance to call incrementRequests - for j := 0; j < i+1; j++ { - id := strconv.Itoa(i) - m.incrementRequests(id) - } - wg.Done() - }(i) - } - wg.Wait() - - assert.Len(t, m.concurrentRequests, tunnels) - assert.Len(t, m.maxConcurrentRequests, tunnels) - for i := 0; i < tunnels; i++ { - id := strconv.Itoa(i) - assert.Equal(t, uint64(i+1), m.concurrentRequests[id]) - assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id]) - } - - wg.Add(tunnels) - for i := 0; i < tunnels; i++ { - go func(i int) { - for j := 0; j < i+1; j++ { - id := strconv.Itoa(i) - m.decrementConcurrentRequests(id) - } - wg.Done() - }(i) - } - wg.Wait() - - assert.Len(t, m.concurrentRequests, tunnels) - assert.Len(t, m.maxConcurrentRequests, tunnels) - for i := 0; i < tunnels; i++ { - id := strconv.Itoa(i) - assert.Equal(t, uint64(0), m.concurrentRequests[id]) - assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id]) - } - -} - -func TestRegisterServerLocation(t *testing.T) { - tunnels := 20 - var wg sync.WaitGroup - wg.Add(tunnels) - for i := 0; i < tunnels; i++ { - go func(i int) { - id := strconv.Itoa(i) - m.registerServerLocation(id, "LHR") - wg.Done() - }(i) - } - wg.Wait() - for i := 0; i < tunnels; i++ { - id := strconv.Itoa(i) - assert.Equal(t, "LHR", m.oldServerLocations[id]) - } - - wg.Add(tunnels) - for i := 0; i < tunnels; i++ { - go func(i int) { - id := strconv.Itoa(i) - m.registerServerLocation(id, "AUS") - wg.Done() - }(i) - } - wg.Wait() - for i := 0; i < tunnels; i++ { - id := strconv.Itoa(i) - assert.Equal(t, "AUS", m.oldServerLocations[id]) - } - -} diff --git a/origin/proxy.go b/origin/proxy.go new file mode 100644 index 00000000..638aee92 --- /dev/null +++ b/origin/proxy.go @@ -0,0 +1,208 @@ +package origin + +import ( + "bufio" + "crypto/tls" + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/cloudflare/cloudflared/buffer" + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/logger" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/websocket" + "github.com/pkg/errors" +) + +const ( + TagHeaderNamePrefix = "Cf-Warp-Tag-" +) + +type client struct { + config *ProxyConfig + logger logger.Service + bufferPool *buffer.Pool +} + +func NewClient(config *ProxyConfig, logger logger.Service) connection.OriginClient { + return &client{ + config: config, + logger: logger, + bufferPool: buffer.NewPool(512 * 1024), + } +} + +type ProxyConfig struct { + Client http.RoundTripper + URL *url.URL + TLSConfig *tls.Config + HostHeader string + NoChunkedEncoding bool + Tags []tunnelpogs.Tag +} + +func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error { + incrementRequests() + defer decrementConcurrentRequests() + + cfRay := findCfRayHeader(req) + lbProbe := isLBProbeRequest(req) + + c.appendTagHeaders(req) + c.logRequest(req, cfRay, lbProbe) + var ( + resp *http.Response + err error + ) + if isWebsocket { + resp, err = c.proxyWebsocket(w, req) + } else { + resp, err = c.proxyHTTP(w, req) + } + if err != nil { + c.logger.Errorf("HTTP request error: %s", err) + responseByCode.WithLabelValues("502").Inc() + w.WriteErrorResponse(err) + return err + } + c.logResponseOk(resp, cfRay, lbProbe) + return nil +} + +func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*http.Response, error) { + // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate + if c.config.NoChunkedEncoding { + req.TransferEncoding = []string{"gzip", "deflate"} + cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) + if err == nil { + req.ContentLength = int64(cLength) + } + } + + // Request origin to keep connection alive to improve performance + req.Header.Set("Connection", "keep-alive") + + c.setHostHeader(req) + + resp, err := c.config.Client.RoundTrip(req) + if err != nil { + return nil, errors.Wrap(err, "Error proxying request to origin") + } + defer resp.Body.Close() + + err = w.WriteRespHeaders(resp) + if err != nil { + return nil, errors.Wrap(err, "Error writing response header") + } + if isEventStream(resp) { + //h.observer.Debug("Detected Server-Side Events from Origin") + c.writeEventStream(w, resp.Body) + } else { + // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream + // compression generates dictionary on first write + buf := c.bufferPool.Get() + defer c.bufferPool.Put(buf) + io.CopyBuffer(w, resp.Body, buf) + } + return resp, nil +} + +func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) { + c.setHostHeader(req) + + conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig) + if err != nil { + return nil, err + } + defer conn.Close() + err = w.WriteRespHeaders(resp) + if err != nil { + return nil, errors.Wrap(err, "Error writing response header") + } + // Copy to/from stream to the undelying connection. Use the underlying + // connection because cloudflared doesn't operate on the message themselves + websocket.Stream(conn.UnderlyingConn(), w) + + return resp, nil +} + +func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { + reader := bufio.NewReader(respBody) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + break + } + w.Write(line) + } +} + +func (c *client) setHostHeader(req *http.Request) { + if c.config.HostHeader != "" { + req.Header.Set("Host", c.config.HostHeader) + req.Host = c.config.HostHeader + } +} + +func (c *client) appendTagHeaders(r *http.Request) { + for _, tag := range c.config.Tags { + r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) + } +} + +func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool) { + if cfRay != "" { + c.logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) + } else if lbProbe { + c.logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) + } else { + c.logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto) + } + c.logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) + + if contentLen := r.ContentLength; contentLen == -1 { + c.logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) + } else { + c.logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen) + } +} + +func (c *client) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { + responseByCode.WithLabelValues("200").Inc() + if cfRay != "" { + c.logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) + } else if lbProbe { + c.logger.Debugf("Response to Load Balancer health check %s", r.Status) + } else { + c.logger.Infof("%s", r.Status) + } + c.logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) + + if contentLen := r.ContentLength; contentLen == -1 { + c.logger.Debugf("CF-RAY: %s Response content length unknown", cfRay) + } else { + c.logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen) + } +} + +func findCfRayHeader(req *http.Request) string { + return req.Header.Get("Cf-Ray") +} + +func isLBProbeRequest(req *http.Request) bool { + return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) +} + +func uint8ToString(input uint8) string { + return strconv.FormatUint(uint64(input), 10) +} + +func isEventStream(response *http.Response) bool { + if response.Header.Get("content-type") == "text/event-stream" { + return true + } + return false +} diff --git a/origin/reconnect.go b/origin/reconnect.go index eaae44bc..5b64f38c 100644 --- a/origin/reconnect.go +++ b/origin/reconnect.go @@ -7,11 +7,7 @@ import ( "sync" "time" - "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/logger" - "github.com/cloudflare/cloudflared/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" ) @@ -138,52 +134,3 @@ func (cm *reconnectCredentialManager) RefreshAuth( return nil, err } } - -func ReconnectTunnel( - ctx context.Context, - muxer *h2mux.Muxer, - config *TunnelConfig, - logger logger.Service, - connectionID uint8, - originLocalAddr string, - uuid uuid.UUID, - credentialManager *reconnectCredentialManager, -) error { - token, err := credentialManager.ReconnectToken() - if err != nil { - return err - } - eventDigest, err := credentialManager.EventDigest(connectionID) - if err != nil { - return err - } - connDigest, err := credentialManager.ConnDigest(connectionID) - if err != nil { - return err - } - - config.TransportLogger.Debug("initiating RPC stream to reconnect") - rpcClient, err := newTunnelRPCClient(ctx, muxer, config, reconnect) - if err != nil { - return err - } - defer rpcClient.Close() - // Request server info without blocking tunnel registration; must use capnp library directly. - serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { - return nil - }) - LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan) - registration := rpcClient.ReconnectTunnel( - ctx, - token, - eventDigest, - connDigest, - config.Hostname, - config.RegistrationOptions(connectionID, originLocalAddr, uuid), - ) - if registrationErr := registration.DeserializeError(); registrationErr != nil { - // ReconnectTunnel RPC failure - return processRegisterTunnelError(registrationErr, config.Metrics, reconnect) - } - return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager) -} diff --git a/origin/supervisor.go b/origin/supervisor.go index 1569aecf..af81378a 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -8,7 +8,6 @@ import ( "github.com/google/uuid" - "github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" @@ -56,8 +55,7 @@ type Supervisor struct { logger logger.Service reconnectCredentialManager *reconnectCredentialManager - - bufferPool *buffer.Pool + useReconnectToken bool } type resolveResult struct { @@ -76,28 +74,33 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor err error ) if len(config.EdgeAddrs) > 0 { - edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs) + edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs) } else { - edgeIPs, err = edgediscovery.ResolveEdge(config.Logger) + edgeIPs, err = edgediscovery.ResolveEdge(config.Observer) } if err != nil { return nil, err } + useReconnectToken := false + if config.ClassicTunnel != nil { + useReconnectToken = config.ClassicTunnel.UseReconnectToken + } + return &Supervisor{ cloudflaredUUID: cloudflaredUUID, config: config, edgeIPs: edgeIPs, tunnelErrors: make(chan tunnelError), tunnelsConnecting: map[int]chan struct{}{}, - logger: config.Logger, - reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections), - bufferPool: buffer.NewPool(512 * 1024), + logger: config.Observer, + reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections), + useReconnectToken: useReconnectToken, }, nil } func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { - logger := s.config.Logger + logger := s.config.Observer if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { return err } @@ -110,7 +113,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true} var refreshAuthBackoffTimer <-chan time.Time - if s.config.UseReconnectToken { + if s.useReconnectToken { if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil { refreshAuthBackoffTimer = timer } else { @@ -227,7 +230,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign return } - err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) + err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, reconnectCh) // If the first tunnel disconnects, keep restarting it. edgeErrors := 0 for s.unusedIPs() { @@ -239,7 +242,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign return // try the next address if it was a dialError(network problem) or // dupConnRegisterTunnelError - case connection.DialError, dupConnRegisterTunnelError: + case edgediscovery.DialError, connection.DupConnRegisterTunnelError: edgeErrors++ default: return @@ -250,7 +253,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign return } } - err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) + err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, reconnectCh) } } @@ -269,7 +272,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal if err != nil { return } - err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) + err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, reconnectCh) } func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { @@ -301,7 +304,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) return nil, err } - edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP) + edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP) if err != nil { return nil, err } @@ -311,8 +314,8 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) // This callback is invoked by h2mux when the edge initiates a stream. return nil // noop }) - muxerConfig := s.config.muxerConfig(handler) - muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams) + muxerConfig := s.config.MuxerConfig.H2MuxerConfig(handler, s.logger) + muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig, h2mux.ActiveStreams) if err != nil { return nil, err } @@ -323,23 +326,15 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) <-muxer.Shutdown() }() - rpcClient, err := newTunnelRPCClient(ctx, muxer, s.config, authenticate) + stream, err := muxer.OpenRPCStream(ctx) if err != nil { return nil, err } + rpcClient := connection.NewTunnelServerClient(ctx, stream, s.logger) defer rpcClient.Close() const arbitraryConnectionID = uint8(0) registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) - authResponse, err := rpcClient.Authenticate( - ctx, - s.config.OriginCert, - s.config.Hostname, - registrationOptions, - ) - if err != nil { - return nil, err - } - return authResponse.Outcome(), nil + return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions) } diff --git a/origin/tunnel.go b/origin/tunnel.go index 7bc36263..b98d2a47 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -5,9 +5,7 @@ import ( "crypto/tls" "fmt" "net" - "net/http" - "net/url" - "strconv" + "runtime/debug" "strings" "sync" "time" @@ -17,26 +15,22 @@ import ( "github.com/prometheus/client_golang/prometheus" "golang.org/x/sync/errgroup" - "github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelrpc" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/cloudflare/cloudflared/websocket" ) const ( dialTimeout = 15 * time.Second - openStreamTimeout = 30 * time.Second muxerTimeout = 5 * time.Second lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" - TagHeaderNamePrefix = "Cf-Warp-Tag-" DuplicateConnectionError = "EDUPCONN" FeatureSerializedHeaders = "serialized_headers" FeatureQuickReconnects = "quick_reconnects" @@ -52,49 +46,31 @@ const ( ) type TunnelConfig struct { - BuildInfo *buildinfo.BuildInfo - ClientID string - CloseConnOnce *sync.Once // Used to close connectedSignal no more than once - CompressionQuality uint64 - EdgeAddrs []string - GracePeriod time.Duration - HAConnections int - HeartbeatInterval time.Duration - Hostname string - IncidentLookup IncidentLookup - IsAutoupdated bool - IsFreeTunnel bool - LBPool string - Logger logger.Service - TransportLogger logger.Service - MaxHeartbeats uint64 - Metrics *TunnelMetrics - MetricsUpdateFreq time.Duration - OriginCert []byte - ReportedVersion string - Retries uint - RunFromTerminal bool - Tags []tunnelpogs.Tag - TlsConfig *tls.Config - WSGI bool + ConnectionConfig *connection.Config + ProxyConfig *ProxyConfig + BuildInfo *buildinfo.BuildInfo + ClientID string + CloseConnOnce *sync.Once // Used to close connectedSignal no more than once + EdgeAddrs []string + HAConnections int + IncidentLookup IncidentLookup + IsAutoupdated bool + IsFreeTunnel bool + LBPool string + Logger logger.Service + Observer *connection.Observer + ReportedVersion string + Retries uint + RunFromTerminal bool + TLSConfig *tls.Config - // feature-flag to use new edge reconnect tokens - UseReconnectToken bool - - NamedTunnel *NamedTunnelConfig - ReplaceExisting bool - TunnelEventChan chan<- ui.TunnelEvent + NamedTunnel *connection.NamedTunnelConfig + ClassicTunnel *connection.ClassicTunnelConfig + MuxerConfig *connection.MuxerConfig + TunnelEventChan chan ui.TunnelEvent IngressRules ingress.Ingress } -type dupConnRegisterTunnelError struct{} - -var errDuplicationConnection = &dupConnRegisterTunnelError{} - -func (e dupConnRegisterTunnelError) Error() string { - return "already connected to this server, trying another address" -} - type muxerShutdownError struct{} func (e muxerShutdownError) Error() string { @@ -125,18 +101,6 @@ func (e clientRegisterTunnelError) Error() string { return e.cause.Error() } -func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig { - return h2mux.MuxerConfig{ - Timeout: muxerTimeout, - Handler: handler, - IsClient: true, - HeartbeatInterval: c.HeartbeatInterval, - MaxHeartbeats: c.MaxHeartbeats, - Logger: c.TransportLogger, - CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality), - } -} - func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { policy := tunnelrpc.ExistingTunnelPolicy_balance if c.HAConnections <= 1 && c.LBPool == "" { @@ -148,12 +112,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch), ExistingTunnelPolicy: policy, PoolName: c.LBPool, - Tags: c.Tags, + Tags: c.ProxyConfig.Tags, ConnectionID: connectionID, OriginLocalIP: OriginLocalIP, IsAutoupdated: c.IsAutoupdated, RunFromTerminal: c.RunFromTerminal, - CompressionQuality: c.CompressionQuality, + CompressionQuality: uint64(c.MuxerConfig.CompressionSetting), UUID: uuid.String(), Features: c.SupportedFeatures(), } @@ -167,8 +131,8 @@ func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAtte return &tunnelpogs.ConnectionOptions{ Client: c.NamedTunnel.Client, OriginLocalIP: originIP, - ReplaceExisting: c.ReplaceExisting, - CompressionQuality: uint8(c.CompressionQuality), + ReplaceExisting: c.ConnectionConfig.ReplaceExisting, + CompressionQuality: uint8(c.MuxerConfig.CompressionSetting), NumPreviousAttempts: numPreviousAttempts, } } @@ -181,35 +145,6 @@ func (c *TunnelConfig) SupportedFeatures() []string { return features } -func (c *TunnelConfig) IsTrialTunnel() bool { - return c.Hostname == "" -} - -type NamedTunnelConfig struct { - Auth pogs.TunnelAuth - ID uuid.UUID - Client pogs.ClientInfo - Protocol Protocol -} - -type Protocol int64 - -const ( - h2muxProtocol Protocol = iota - http2Protocol -) - -func ParseProtocol(s string) (Protocol, bool) { - switch s { - case "h2mux": - return h2muxProtocol, true - case "http2": - return http2Protocol, true - default: - return 0, false - } -} - func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error { s, err := NewSupervisor(config, cloudflaredID) if err != nil { @@ -225,11 +160,11 @@ func ServeTunnelLoop(ctx context.Context, connectionIndex uint8, connectedSignal *signal.Signal, cloudflaredUUID uuid.UUID, - bufferPool *buffer.Pool, reconnectCh chan ReconnectSignal, ) error { - config.Metrics.incrementHaConnections() - defer config.Metrics.decrementHaConnections() + haConnections.Inc() + defer haConnections.Dec() + backoff := BackoffHandler{MaxRetries: config.Retries} connectedFuse := h2mux.NewBooleanFuse() go func() { @@ -244,12 +179,10 @@ func ServeTunnelLoop(ctx context.Context, ctx, credentialManager, config, - config.Logger, addr, connectionIndex, connectedFuse, &backoff, cloudflaredUUID, - bufferPool, reconnectCh, ) if recoverable { @@ -257,7 +190,7 @@ func ServeTunnelLoop(ctx context.Context, if config.TunnelEventChan != nil { config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting} } - config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration) + config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err) backoff.Backoff(ctx) continue } @@ -270,13 +203,11 @@ func ServeTunnel( ctx context.Context, credentialManager *reconnectCredentialManager, config *TunnelConfig, - logger logger.Service, addr *net.TCPAddr, connectionIndex uint8, - connectedFuse *h2mux.BooleanFuse, + fuse *h2mux.BooleanFuse, backoff *BackoffHandler, cloudflaredUUID uuid.UUID, - bufferPool *buffer.Pool, reconnectCh chan ReconnectSignal, ) (err error, recoverable bool) { // Treat panics as recoverable errors @@ -287,6 +218,7 @@ func ServeTunnel( if !ok { err = fmt.Errorf("ServeTunnel: %v", r) } + err = errors.Wrapf(err, "stack trace: %s", string(debug.Stack())) recoverable = true } }() @@ -298,203 +230,107 @@ func ServeTunnel( }() } - connectionTag := uint8ToString(connectionIndex) - - if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol { - return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh) - } - - // Returns error from parsing the origin URL or handshake errors - handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool) + edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr) if err != nil { - switch err.(type) { - case connection.DialError: - logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err) - case h2mux.MuxerHandshakeError: - logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err) - default: - logger.Errorf("Connection %d failed: %s", connectionIndex, err) - return err, false - } return err, true } + connectedFuse := &connectedFuse{ + fuse: fuse, + backoff: backoff, + } + if config.NamedTunnel != nil && config.NamedTunnel.Protocol == connection.HTTP2 { + connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries)) + return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh) + } + return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh) +} + +func ServeH2mux( + ctx context.Context, + credentialManager *reconnectCredentialManager, + config *TunnelConfig, + edgeConn net.Conn, + connectionIndex uint8, + connectedFuse *connectedFuse, + cloudflaredUUID uuid.UUID, + reconnectCh chan ReconnectSignal, +) (err error, recoverable bool) { + // Returns error from parsing the origin URL or handshake errors + handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer) + if err != nil { + return err, recoverable + } errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() (err error) { - defer func() { - if err == nil { - connectedFuse.Fuse(true) - backoff.SetGracePeriod() - } - }() - - if config.UseReconnectToken && connectedFuse.Value() { - err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager) - if err == nil { - return nil - } - // log errors and proceed to RegisterTunnel - logger.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", connectionIndex, err) + if config.NamedTunnel != nil { + connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries)) + return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse) } - return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID) + registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) + return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) }) - errGroup.Go(func() error { - updateMetricsTickC := time.Tick(config.MetricsUpdateFreq) - for { - select { - case <-serveCtx.Done(): - // UnregisterTunnel blocks until the RPC call returns - if connectedFuse.Value() { - if config.NamedTunnel != nil { - _ = UnregisterConnection(ctx, handler.muxer, config) - } else { - _ = UnregisterTunnel(handler.muxer, config) - } - } - handler.muxer.Shutdown() - return nil - case <-updateMetricsTickC: - handler.UpdateMetrics(connectionTag) - } - } - }) - - errGroup.Go(func() error { - for { - select { - case reconnect := <-reconnectCh: - return &reconnect - case <-serveCtx.Done(): - return nil - } - } - }) - - errGroup.Go(func() error { - // All routines should stop when muxer finish serving. When muxer is shutdown - // gracefully, it doesn't return an error, so we need to return errMuxerShutdown - // here to notify other routines to stop - err := handler.muxer.Serve(serveCtx) - if err == nil { - return muxerShutdownError{} - } - return err - }) + errGroup.Go(listenReconnect(serveCtx, reconnectCh)) err = errGroup.Wait() if err != nil { switch err := err.(type) { - case *dupConnRegisterTunnelError: + case *connection.DupConnRegisterTunnelError: // don't retry this connection anymore, let supervisor pick new a address return err, false case *serverRegisterTunnelError: - logger.Errorf("Register tunnel error from server side: %s", err.cause) + config.Logger.Errorf("Register tunnel error from server side: %s", err.cause) // Don't send registration error return from server to Sentry. They are // logged on server side if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 { - logger.Error(activeIncidentsMsg(incidents)) + config.Logger.Error(activeIncidentsMsg(incidents)) } return err.cause, !err.permanent case *clientRegisterTunnelError: - logger.Errorf("Register tunnel error on client side: %s", err.cause) + config.Logger.Errorf("Register tunnel error on client side: %s", err.cause) return err, true case *muxerShutdownError: - logger.Info("Muxer shutdown") + config.Logger.Info("Muxer shutdown") return err, true case *ReconnectSignal: - logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay) + config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay) err.DelayBeforeReconnect() return err, true default: if err == context.Canceled { - logger.Debugf("Serve tunnel error: %s", err) + config.Logger.Debugf("Serve tunnel error: %s", err) return err, false } - logger.Errorf("Serve tunnel error: %s", err) + config.Logger.Errorf("Serve tunnel error: %s", err) return err, true } } return nil, true } -func RegisterConnectionWithH2Mux( - ctx context.Context, - muxer *h2mux.Muxer, - config *TunnelConfig, - connectionIndex uint8, - originLocalAddr string, - numPreviousAttempts uint8, -) error { - const registerConnection = "registerConnection" - - config.TransportLogger.Debug("initiating RPC stream for RegisterConnection") - rpcClient, err := newTunnelRPCClient(ctx, muxer, config, registerConnection) - if err != nil { - return err - } - defer rpcClient.Close() - - conn, err := rpcClient.RegisterConnection( - ctx, - config.NamedTunnel.Auth, - config.NamedTunnel.ID, - connectionIndex, - config.ConnectionOptions(originLocalAddr, numPreviousAttempts), - ) - if err != nil { - if err.Error() == DuplicateConnectionError { - config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc() - return errDuplicationConnection - } - config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc() - return serverRegistrationErrorFromRPC(err) - } - - config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc() - config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID) - - // If launch-ui flag is set, send connect msg - if config.TunnelEventChan != nil { - config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Connected, Location: conn.Location} - } - - return nil -} - -func ServeNamedTunnel( +func ServeHTTP2( ctx context.Context, config *TunnelConfig, + tlsServerConn net.Conn, + connOptions *tunnelpogs.ConnectionOptions, connIndex uint8, - addr *net.TCPAddr, - connectedFuse *h2mux.BooleanFuse, + connectedFuse connection.ConnectedFuse, reconnectCh chan ReconnectSignal, ) (err error, recoverable bool) { - tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) - if err != nil { - return err, true - } - - cfdServer, err := newHTTP2Server(config, connIndex, tlsServerConn.LocalAddr(), connectedFuse) + server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse) if err != nil { return err, false } errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { - cfdServer.serve(serveCtx, tlsServerConn) + server.Serve(serveCtx) return fmt.Errorf("Connection with edge closed") }) - errGroup.Go(func() error { - select { - case reconnect := <-reconnectCh: - return &reconnect - case <-serveCtx.Done(): - return nil - } - }) + errGroup.Go(listenReconnect(serveCtx, reconnectCh)) err = errGroup.Wait() if err != nil { @@ -503,229 +339,29 @@ func ServeNamedTunnel( return nil, false } -func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError { - if retryable, ok := err.(*tunnelpogs.RetryableError); ok { - return &serverRegisterTunnelError{ - cause: retryable.Unwrap(), - permanent: false, +func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error { + return func() error { + select { + case reconnect := <-reconnectCh: + return &reconnect + case <-ctx.Done(): + return nil } } - return &serverRegisterTunnelError{ - cause: err, - permanent: true, - } } -func UnregisterConnection( - ctx context.Context, - muxer *h2mux.Muxer, - config *TunnelConfig, -) error { - config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection") - rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register) - if err != nil { - // RPC stream open error - return err - } - defer rpcClient.Close() - - return rpcClient.UnregisterConnection(ctx) +type connectedFuse struct { + fuse *h2mux.BooleanFuse + backoff *BackoffHandler } -func RegisterTunnel( - ctx context.Context, - credentialManager *reconnectCredentialManager, - muxer *h2mux.Muxer, - config *TunnelConfig, - logger logger.Service, - connectionID uint8, - originLocalIP string, - uuid uuid.UUID, -) error { - config.TransportLogger.Debug("initiating RPC stream to register") - if config.TunnelEventChan != nil { - config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel} - } - - rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register) - if err != nil { - return err - } - defer rpcClient.Close() - // Request server info without blocking tunnel registration; must use capnp library directly. - serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { - return nil - }) - LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan) - registration := rpcClient.RegisterTunnel( - ctx, - config.OriginCert, - config.Hostname, - config.RegistrationOptions(connectionID, originLocalIP, uuid), - ) - if registrationErr := registration.DeserializeError(); registrationErr != nil { - // RegisterTunnel RPC failure - return processRegisterTunnelError(registrationErr, config.Metrics, register) - } - - // Send free tunnel URL to UI - if config.TunnelEventChan != nil { - config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.SetUrl, Url: registration.Url} - } - credentialManager.SetEventDigest(connectionID, registration.EventDigest) - return processRegistrationSuccess(config, logger, connectionID, registration, register, credentialManager) +func (cf *connectedFuse) Connected() { + cf.fuse.Fuse(true) + cf.backoff.SetGracePeriod() } -func processRegistrationSuccess( - config *TunnelConfig, - logger logger.Service, - connectionID uint8, - registration *tunnelpogs.TunnelRegistration, - name rpcName, - credentialManager *reconnectCredentialManager, -) error { - for _, logLine := range registration.LogLines { - logger.Info(logLine) - } - - if registration.TunnelID != "" { - config.Metrics.tunnelsHA.AddTunnelID(connectionID, registration.TunnelID) - logger.Infof("Each HA connection's tunnel IDs: %v", config.Metrics.tunnelsHA.String()) - } - - // Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set) - if config.TunnelEventChan == nil { - if config.IsTrialTunnel() { - if registrationURL, err := url.Parse(registration.Url); err == nil { - for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) { - logger.Info(line) - } - } else { - logger.Error("Failed to connect tunnel, please try again.") - return fmt.Errorf("empty URL in response from Cloudflare edge") - } - } - } - - credentialManager.SetConnDigest(connectionID, registration.ConnDigest) - config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc() - - logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional") - config.Metrics.regSuccess.WithLabelValues(string(name)).Inc() - return nil -} - -func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name rpcName) error { - if err.Error() == DuplicateConnectionError { - metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc() - return errDuplicationConnection - } - metrics.regFail.WithLabelValues("server_error", string(name)).Inc() - return serverRegisterTunnelError{ - cause: err, - permanent: err.IsPermanent(), - } -} - -func UnregisterTunnel(muxer *h2mux.Muxer, config *TunnelConfig) error { - config.TransportLogger.Debug("initiating RPC stream to unregister") - ctx := context.Background() - rpcClient, err := newTunnelRPCClient(ctx, muxer, config, unregister) - if err != nil { - // RPC stream open error - return err - } - defer rpcClient.Close() - - // gracePeriod is encoded in int64 using capnproto - return rpcClient.UnregisterTunnel(ctx, config.GracePeriod.Nanoseconds()) -} - -func LogServerInfo( - promise tunnelrpc.ServerInfo_Promise, - connectionID uint8, - metrics *TunnelMetrics, - logger logger.Service, - tunnelEventChan chan<- ui.TunnelEvent, -) { - serverInfoMessage, err := promise.Struct() - if err != nil { - logger.Errorf("Failed to retrieve server information: %s", err) - return - } - serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage) - if err != nil { - logger.Errorf("Failed to retrieve server information: %s", err) - return - } - // If launch-ui flag is set, send connect msg - if tunnelEventChan != nil { - tunnelEventChan <- ui.TunnelEvent{Index: connectionID, EventType: ui.Connected, Location: serverInfo.LocationName} - } - logger.Infof("Connected to %s", serverInfo.LocationName) - metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) -} - -func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) { - if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { - req.Header.Set("Host", hostHeader) - req.Host = hostHeader - } - - dialler, ok := rule.Service.(websocket.Dialler) - if !ok { - return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service) - } - conn, response, err := websocket.ClientConnect(req, dialler) - if err != nil { - return nil, err - } - defer conn.Close() - err = wsResp.WriteRespHeaders(response) - if err != nil { - return nil, errors.Wrap(err, "Error writing response header") - } - // Copy to/from stream to the undelying connection. Use the underlying - // connection because cloudflared doesn't operate on the message themselves - websocket.Stream(conn.UnderlyingConn(), wsResp) - - return response, nil -} - -func uint8ToString(input uint8) string { - return strconv.FormatUint(uint64(input), 10) -} - -// Print out the given lines in a nice ASCII box. -func asciiBox(lines []string, padding int) (box []string) { - maxLen := maxLen(lines) - spacer := strings.Repeat(" ", padding) - - border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+" - - box = append(box, border) - for _, line := range lines { - box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|") - } - box = append(box, border) - return -} - -func maxLen(lines []string) int { - max := 0 - for _, line := range lines { - if len(line) > max { - max = len(line) - } - } - return max -} - -func trialZoneMsg(url string) []string { - return []string{ - "Your free tunnel has started! Visit it:", - " " + url, - } +func (cf *connectedFuse) IsConnected() bool { + return cf.fuse.Value() } func activeIncidentsMsg(incidents []Incident) string { @@ -741,26 +377,3 @@ func activeIncidentsMsg(incidents []Incident) string { return preamble + " " + strings.Join(incidentStrings, "; ") } - -func findCfRayHeader(h1 *http.Request) string { - return h1.Header.Get("Cf-Ray") -} - -func isLBProbeRequest(req *http.Request) bool { - return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) -} - -func newTunnelRPCClient(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, rpcName rpcName) (tunnelpogs.TunnelServer_PogsClient, error) { - openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout) - defer openStreamCancel() - stream, err := muxer.OpenRPCStream(openStreamCtx) - if err != nil { - return tunnelpogs.TunnelServer_PogsClient{}, err - } - rpcClient, err := connection.NewTunnelRPCClient(ctx, stream, config.TransportLogger) - if err != nil { - // RPC stream open error - return tunnelpogs.TunnelServer_PogsClient{}, newRPCError(err, config.Metrics.rpcFail, rpcName) - } - return rpcClient, nil -} diff --git a/tlsconfig/certreloader.go b/tlsconfig/certreloader.go index 357d009d..1a43298f 100644 --- a/tlsconfig/certreloader.go +++ b/tlsconfig/certreloader.go @@ -17,11 +17,6 @@ import ( const ( OriginCAPoolFlag = "origin-ca-pool" CaCertFlag = "cacert" - - // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge - edgeH2muxTLSServerName = "cftunnel.com" - // edgeH2TLSServerName is the server name to establish http2 connection with edge - edgeH2TLSServerName = "h2.cftunnel.com" ) // CertReloader can load and reload a TLS certificate from a particular filepath. @@ -123,16 +118,12 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) { return certPool, nil } -func CreateTunnelConfig(c *cli.Context, isNamedTunnel bool) (*tls.Config, error) { +func CreateTunnelConfig(c *cli.Context, serverName string) (*tls.Config, error) { var rootCAs []string if c.String(CaCertFlag) != "" { rootCAs = append(rootCAs, c.String(CaCertFlag)) } - serverName := edgeH2muxTLSServerName - if isNamedTunnel { - serverName = edgeH2TLSServerName - } userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName} tlsConfig, err := GetConfig(userConfig) if err != nil { diff --git a/tunnelrpc/pogs/connectionrpc.go b/tunnelrpc/pogs/connectionrpc.go index 2cd927e8..997a9bc3 100644 --- a/tunnelrpc/pogs/connectionrpc.go +++ b/tunnelrpc/pogs/connectionrpc.go @@ -93,6 +93,11 @@ type RegistrationServer_PogsClient struct { Conn *rpc.Conn } +func (c RegistrationServer_PogsClient) Close() error { + c.Client.Close() + return c.Conn.Close() +} + func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) { client := tunnelrpc.TunnelServer{Client: c.Client} promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error { diff --git a/validation/validation.go b/validation/validation.go index 3007aea1..e22d206c 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -66,7 +66,15 @@ func ValidateHostname(hostname string) (string, error) { // but when it does not, the path is preserved: // ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/" // This is arguably a bug, but changing it might break some cloudflared users. -func ValidateUrl(originUrl string) (string, error) { +func ValidateUrl(originUrl string) (*url.URL, error) { + urlStr, err := validateUrlString(originUrl) + if err != nil { + return nil, err + } + return url.Parse(urlStr) +} + +func validateUrlString(originUrl string) (string, error) { if originUrl == "" { return "", fmt.Errorf("URL should not be empty") } @@ -157,12 +165,8 @@ func validateIP(scheme, host, port string) (string, error) { return fmt.Sprintf("%s://%s", scheme, host), nil } -func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error { - parsedURL, err := url.Parse(originURL) - if err != nil { - return err - } - +// originURL shouldn't be a pointer, because this function might change the scheme +func ValidateHTTPService(originURL url.URL, hostname string, transport http.RoundTripper) error { client := &http.Client{ Transport: transport, CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -171,7 +175,7 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round Timeout: validationTimeout, } - initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil) + initialRequest, err := http.NewRequest("GET", originURL.String(), nil) if err != nil { return err } @@ -183,10 +187,10 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round } // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? - oldScheme := parsedURL.Scheme - parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) + oldScheme := originURL.Scheme + originURL.Scheme = toggleProtocol(originURL.Scheme) - secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil) + secondRequest, err := http.NewRequest("GET", originURL.String(), nil) if err != nil { return err } @@ -195,12 +199,12 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round if secondErr == nil { // Worked this time--advise the user to switch protocols resp.Body.Close() return errors.Errorf( - "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s", - parsedURL.Host, + "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v", + originURL.Host, oldScheme, - parsedURL.Scheme, + originURL.Scheme, initialErr, - parsedURL, + originURL, ) } @@ -224,12 +228,12 @@ type Access struct { } func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) { - domainURL, err := ValidateUrl(domain) + domainURL, err := validateUrlString(domain) if err != nil { return nil, err } - issuerURL, err := ValidateUrl(issuer) + issuerURL, err := validateUrlString(issuer) if err != nil { return nil, err } diff --git a/validation/validation_test.go b/validation/validation_test.go index b6ae8bf4..0745b085 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -101,7 +101,7 @@ func TestValidateUrl(t *testing.T) { for i, testCase := range testCases { validUrl, err := ValidateUrl(testCase.input) assert.NoError(t, err, "test case %v", i) - assert.Equal(t, testCase.expectedOutput, validUrl, "test case %v", i) + assert.Equal(t, testCase.expectedOutput, validUrl.String(), "test case %v", i) } validUrl, err := ValidateUrl("") @@ -123,7 +123,7 @@ func TestToggleProtocol(t *testing.T) { // Happy path 1: originURL is HTTP, and HTTP connections work func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { - originURL := "http://127.0.0.1/" + originURL := mustParse(t, "http://127.0.0.1/") hostname := "example.com" assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { @@ -151,7 +151,7 @@ func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { // Happy path 2: originURL is HTTPS, and HTTPS connections work func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { - originURL := "https://127.0.0.1/" + originURL := mustParse(t, "https://127.0.0.1:1234/") hostname := "example.com" assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { @@ -179,7 +179,7 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { // Error path 1: originURL is HTTPS, but HTTP connections work func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { - originURL := "https://127.0.0.1:1234/" + originURL := mustParse(t, "https://127.0.0.1:1234/") hostname := "example.com" assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { @@ -207,10 +207,13 @@ func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { // Error path 2: originURL is HTTP, but HTTPS connections work func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { - originURL := "http://127.0.0.1:1234/" + originURLWithPort := url.URL{ + Scheme: "http", + Host: "127.0.0.1:1234", + } hostname := "example.com" - assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Equal(t, req.Host, hostname) if req.URL.Scheme == "http" { return nil, assert.AnError @@ -221,7 +224,7 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { panic("Shouldn't reach here") }))) - assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Equal(t, req.Host, hostname) if req.URL.Scheme == "http" { return nil, assert.AnError @@ -250,12 +253,14 @@ func TestValidateHTTPService_NoFollowRedirects(t *testing.T) { })) assert.NoError(t, err) defer redirectServer.Close() - assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport)) + redirectServerURL, err := url.Parse(redirectServer.URL) + assert.NoError(t, err) + assert.NoError(t, ValidateHTTPService(*redirectServerURL, hostname, redirectClient.Transport)) } // Ensure validation times out when origin URL is nonresponsive func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) { - originURL := "http://127.0.0.1/" + originURL := mustParse(t, "http://127.0.0.1/") hostname := "example.com" oldValidationTimeout := validationTimeout defer func() { @@ -371,3 +376,9 @@ func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *h return server, client, nil } + +func mustParse(t *testing.T, originURL string) url.URL { + parsedURL, err := url.Parse(originURL) + assert.NoError(t, err) + return *parsedURL +}