TUN-3462: Refactor cloudflared to separate origin from connection

This commit is contained in:
cthuang 2020-10-08 11:12:26 +01:00
parent a5a5b93b64
commit 9ac40dcf04
32 changed files with 2006 additions and 1339 deletions

View File

@ -2,7 +2,6 @@ package access
import ( import (
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/carrier"
@ -17,16 +16,11 @@ import (
// StartForwarder starts a client side websocket forward // StartForwarder starts a client side websocket forward
func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, logger logger.Service) error { func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, logger logger.Service) error {
validURLString, err := validation.ValidateUrl(forwarder.Listener) validURL, err := validation.ValidateUrl(forwarder.Listener)
if err != nil { if err != nil {
return errors.Wrap(err, "error validating origin URL") return errors.Wrap(err, "error validating origin URL")
} }
validURL, err := url.Parse(validURLString)
if err != nil {
return errors.Wrap(err, "error parsing origin URL")
}
// get the headers from the config file and add to the request // get the headers from the config file and add to the request
headers := make(http.Header) headers := make(http.Header)
if forwarder.TokenClientID != "" { if forwarder.TokenClientID != "" {
@ -106,12 +100,7 @@ func ssh(c *cli.Context) error {
wsConn := carrier.NewWSConnection(logger, false) wsConn := carrier.NewWSConnection(logger, false)
if c.NArg() > 0 || c.IsSet(sshURLFlag) { if c.NArg() > 0 || c.IsSet(sshURLFlag) {
localForwarder, err := config.ValidateUrl(c, true) forwarder, err := config.ValidateUrl(c, true)
if err != nil {
logger.Errorf("Error validating origin URL: %s", err)
return errors.Wrap(err, "error validating origin URL")
}
forwarder, err := url.Parse(localForwarder)
if err != nil { if err != nil {
logger.Errorf("Error validating origin URL: %s", err) logger.Errorf("Error validating origin URL: %s", err)
return errors.Wrap(err, "error validating origin URL") return errors.Wrap(err, "error validating origin URL")

View File

@ -2,6 +2,7 @@ package config
import ( import (
"fmt" "fmt"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -189,11 +190,11 @@ func ValidateUnixSocket(c *cli.Context) (string, error) {
// ValidateUrl will validate url flag correctness. It can be either from --url or argument // ValidateUrl will validate url flag correctness. It can be either from --url or argument
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument // Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) { func ValidateUrl(c *cli.Context, allowFromArgs bool) (*url.URL, error) {
var url = c.String("url") var url = c.String("url")
if allowFromArgs && c.NArg() > 0 { if allowFromArgs && c.NArg() > 0 {
if c.IsSet("url") { if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
} }
url = c.Args().Get(0) url = c.Args().Get(0)
} }

View File

@ -18,6 +18,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/dbconnect" "github.com/cloudflare/cloudflared/dbconnect"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
@ -247,7 +248,7 @@ func StartServer(
version string, version string,
shutdownC, shutdownC,
graceShutdownC chan struct{}, graceShutdownC chan struct{},
namedTunnel *origin.NamedTunnelConfig, namedTunnel *connection.NamedTunnelConfig,
log logger.Service, log logger.Service,
isUIEnabled bool, isUIEnabled bool,
) error { ) error {
@ -366,7 +367,7 @@ func StartServer(
return errors.Wrap(err, "error setting up transport logger") return errors.Wrap(err, "error setting up transport logger")
} }
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel) tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel, isUIEnabled)
if err != nil { if err != nil {
return err return err
} }
@ -386,10 +387,6 @@ func StartServer(
}() }()
if isUIEnabled { if isUIEnabled {
const tunnelEventChanBufferSize = 16
tunnelEventChan := make(chan ui.TunnelEvent, tunnelEventChanBufferSize)
tunnelConfig.TunnelEventChan = tunnelEventChan
tunnelInfo := ui.NewUIModel( tunnelInfo := ui.NewUIModel(
version, version,
hostname, hostname,
@ -402,7 +399,7 @@ func StartServer(
if err != nil { if err != nil {
return err return err
} }
tunnelInfo.LaunchUI(ctx, log, logLevels, tunnelEventChan) tunnelInfo.LaunchUI(ctx, log, logLevels, tunnelConfig.TunnelEventChan)
} }
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log) return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log)
@ -986,7 +983,7 @@ func configureLoggingFlags(shouldHide bool) []cli.Flag {
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "transport-loglevel", Name: "transport-loglevel",
Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel
Value: "fatal", Value: "info",
Usage: "Transport logging level(previously called protocol logging level) {fatal, error, info, debug}", Usage: "Transport logging level(previously called protocol logging level) {fatal, error, info, debug}",
EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL", "TUNNEL_TRANSPORT_LOGLEVEL"}, EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL", "TUNNEL_TRANSPORT_LOGLEVEL"},
Hidden: shouldHide, Hidden: shouldHide,

View File

@ -9,6 +9,9 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/origin"
@ -154,10 +157,10 @@ func prepareTunnelConfig(
version string, version string,
logger logger.Service, logger logger.Service,
transportLogger logger.Service, transportLogger logger.Service,
namedTunnel *origin.NamedTunnelConfig, namedTunnel *connection.NamedTunnelConfig,
uiIsEnabled bool,
) (*origin.TunnelConfig, error) { ) (*origin.TunnelConfig, error) {
isNamedTunnel := namedTunnel != nil isNamedTunnel := namedTunnel != nil
compatibilityMode := !isNamedTunnel
hostname, err := validation.ValidateHostname(c.String("hostname")) hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil { if err != nil {
@ -189,10 +192,11 @@ func prepareTunnelConfig(
} }
} }
tunnelMetrics := origin.NewTunnelMetrics() var (
ingressRules ingress.Ingress
var ingressRules ingress.Ingress classicTunnel *connection.ClassicTunnelConfig
if namedTunnel != nil { )
if isNamedTunnel {
clientUUID, err := uuid.NewRandom() clientUUID, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "can't generate clientUUID") return nil, errors.Wrap(err, "can't generate clientUUID")
@ -210,6 +214,13 @@ func prepareTunnelConfig(
if !ingressRules.IsEmpty() && c.IsSet("url") { if !ingressRules.IsEmpty() && c.IsSet("url") {
return nil, ingress.ErrURLIncompatibleWithIngress return nil, ingress.ErrURLIncompatibleWithIngress
} }
} else {
classicTunnel = &connection.ClassicTunnelConfig{
Hostname: hostname,
OriginCert: originCert,
// turn off use of reconnect token and auth refresh when using named tunnels
UseReconnectToken: !isNamedTunnel && c.Bool("use-reconnect-token"),
}
} }
// Convert single-origin configuration into multi-origin configuration. // Convert single-origin configuration into multi-origin configuration.
@ -220,43 +231,71 @@ func prepareTunnelConfig(
} }
} }
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, isNamedTunnel) protocol := determineProtocol(namedTunnel)
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName())
if err != nil { if err != nil {
logger.Errorf("unable to create TLS config to connect with edge: %s", err) logger.Errorf("unable to create TLS config to connect with edge: %s", err)
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
} }
proxyConfig := &origin.ProxyConfig{
Client: httpTransport,
URL: originURL,
TLSConfig: httpTransport.TLSClientConfig,
HostHeader: c.String("http-host-header"),
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
Tags: tags,
}
originClient := origin.NewClient(proxyConfig, logger)
transportConfig := &connection.Config{
OriginClient: originClient,
GracePeriod: c.Duration("grace-period"),
ReplaceExisting: c.Bool("force"),
}
muxerConfig := &connection.MuxerConfig{
HeartbeatInterval: c.Duration("heartbeat-interval"),
MaxHeartbeats: c.Uint64("heartbeat-count"),
CompressionSetting: h2mux.CompressionSetting(c.Uint64("compression-quality")),
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
}
var tunnelEventChan chan ui.TunnelEvent
if uiIsEnabled {
tunnelEventChan = make(chan ui.TunnelEvent, 16)
}
return &origin.TunnelConfig{ return &origin.TunnelConfig{
ConnectionConfig: transportConfig,
ProxyConfig: proxyConfig,
BuildInfo: buildInfo, BuildInfo: buildInfo,
ClientID: clientID, ClientID: clientID,
CompressionQuality: c.Uint64("compression-quality"),
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
GracePeriod: c.Duration("grace-period"),
HAConnections: c.Int("ha-connections"), HAConnections: c.Int("ha-connections"),
HeartbeatInterval: c.Duration("heartbeat-interval"),
Hostname: hostname,
IncidentLookup: origin.NewIncidentLookup(), IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"), IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel, IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"), LBPool: c.String("lb-pool"),
Logger: logger, Logger: logger,
TransportLogger: transportLogger, Observer: connection.NewObserver(transportLogger, tunnelEventChan, protocol),
MaxHeartbeats: c.Uint64("heartbeat-count"),
Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
OriginCert: originCert,
ReportedVersion: version, ReportedVersion: version,
Retries: c.Uint("retries"), Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(), RunFromTerminal: isRunningFromTerminal(),
Tags: tags, TLSConfig: toEdgeTLSConfig,
TlsConfig: toEdgeTLSConfig,
NamedTunnel: namedTunnel, NamedTunnel: namedTunnel,
ReplaceExisting: c.Bool("force"), ClassicTunnel: classicTunnel,
MuxerConfig: muxerConfig,
TunnelEventChan: tunnelEventChan,
IngressRules: ingressRules, IngressRules: ingressRules,
// turn off use of reconnect token and auth refresh when using named tunnels
UseReconnectToken: compatibilityMode && c.Bool("use-reconnect-token"),
}, nil }, nil
} }
func isRunningFromTerminal() bool { func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd())) return terminal.IsTerminal(int(os.Stdout.Fd()))
} }
func determineProtocol(namedTunnel *connection.NamedTunnelConfig) connection.Protocol {
if namedTunnel != nil {
return namedTunnel.Protocol
}
return connection.H2mux
}

View File

@ -14,8 +14,8 @@ import (
"github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/certutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstore" "github.com/cloudflare/cloudflared/tunnelstore"
) )
@ -260,7 +260,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
return err return err
} }
protocol, ok := origin.ParseProtocol(sc.c.String("protocol")) protocol, ok := connection.ParseProtocol(sc.c.String("protocol"))
if !ok { if !ok {
return fmt.Errorf("%s is not valid protocol. %s", sc.c.String("protocol"), availableProtocol) return fmt.Errorf("%s is not valid protocol. %s", sc.c.String("protocol"), availableProtocol)
} }
@ -269,7 +269,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
version, version,
shutdownC, shutdownC,
graceShutdownC, graceShutdownC,
&origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol}, &connection.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol},
sc.logger, sc.logger,
sc.isUIEnabled, sc.isUIEnabled,
) )

View File

@ -78,17 +78,20 @@ var (
Name: "credentials-file", Name: "credentials-file",
Aliases: []string{credFileFlagAlias}, Aliases: []string{credFileFlagAlias},
Usage: "File path of tunnel credentials", Usage: "File path of tunnel credentials",
EnvVars: []string{"TUNNEL_CRED_FILE"},
}) })
forceDeleteFlag = &cli.BoolFlag{ forceDeleteFlag = &cli.BoolFlag{
Name: "force", Name: "force",
Aliases: []string{"f"}, Aliases: []string{"f"},
Usage: "Allows you to delete a tunnel, even if it has active connections.", Usage: "Allows you to delete a tunnel, even if it has active connections.",
EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"},
} }
selectProtocolFlag = &cli.StringFlag{ selectProtocolFlag = &cli.StringFlag{
Name: "protocol", Name: "protocol",
Value: "h2mux", Value: "h2mux",
Aliases: []string{"p"}, Aliases: []string{"p"},
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol), Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol),
EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"},
Hidden: true, Hidden: true,
} }
) )

91
connection/connection.go Normal file
View File

@ -0,0 +1,91 @@
package connection
import (
"io"
"net/http"
"strconv"
"time"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
)
const (
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
)
type Config struct {
OriginClient OriginClient
GracePeriod time.Duration
ReplaceExisting bool
}
type NamedTunnelConfig struct {
Auth pogs.TunnelAuth
ID uuid.UUID
Client pogs.ClientInfo
Protocol Protocol
}
type ClassicTunnelConfig struct {
Hostname string
OriginCert []byte
// feature-flag to use new edge reconnect tokens
UseReconnectToken bool
}
func (c *ClassicTunnelConfig) IsTrialZone() bool {
return c.Hostname == ""
}
type Protocol int64
const (
H2mux Protocol = iota
HTTP2
)
func ParseProtocol(s string) (Protocol, bool) {
switch s {
case "h2mux":
return H2mux, true
case "http2":
return HTTP2, true
default:
return 0, false
}
}
func (p Protocol) ServerName() string {
switch p {
case H2mux:
return edgeH2muxTLSServerName
case HTTP2:
return edgeH2TLSServerName
default:
return ""
}
}
type OriginClient interface {
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
}
type ResponseWriter interface {
WriteRespHeaders(*http.Response) error
WriteErrorResponse(error)
io.ReadWriter
}
type ConnectedFuse interface {
Connected()
IsConnected() bool
}
func uint8ToString(input uint8) string {
return strconv.FormatUint(uint64(input), 10)
}

76
connection/errors.go Normal file
View File

@ -0,0 +1,76 @@
package connection
import (
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/prometheus/client_golang/prometheus"
)
const (
DuplicateConnectionError = "EDUPCONN"
)
// RegisterTunnel error from client
type clientRegisterTunnelError struct {
cause error
}
func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError {
counter.WithLabelValues(cause.Error(), string(name)).Inc()
return clientRegisterTunnelError{cause: cause}
}
func (e clientRegisterTunnelError) Error() string {
return e.cause.Error()
}
type DupConnRegisterTunnelError struct{}
var errDuplicationConnection = &DupConnRegisterTunnelError{}
func (e DupConnRegisterTunnelError) Error() string {
return "already connected to this server, trying another address"
}
// RegisterTunnel error from server
type serverRegisterTunnelError struct {
cause error
permanent bool
}
func (e serverRegisterTunnelError) Error() string {
return e.cause.Error()
}
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
return &serverRegisterTunnelError{
cause: retryable.Unwrap(),
permanent: false,
}
}
return &serverRegisterTunnelError{
cause: err,
permanent: true,
}
}
type muxerShutdownError struct{}
func (e muxerShutdownError) Error() string {
return "muxer shutdown"
}
func isHandshakeErrRecoverable(err error, connIndex uint8, observer *Observer) bool {
switch err.(type) {
case edgediscovery.DialError:
observer.Errorf("Connection %d unable to dial edge: %s", connIndex, err)
case h2mux.MuxerHandshakeError:
observer.Errorf("Connection %d handshake with edge server failed: %s", connIndex, err)
default:
observer.Errorf("Connection %d failed: %s", connIndex, err)
return false
}
return true
}

216
connection/h2mux.go Normal file
View File

@ -0,0 +1,216 @@
package connection
import (
"context"
"net"
"net/http"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
)
const (
muxerTimeout = 5 * time.Second
openStreamTimeout = 30 * time.Second
)
type h2muxConnection struct {
config *Config
muxerConfig *MuxerConfig
originURL string
muxer *h2mux.Muxer
// connectionID is only used by metrics, and prometheus requires labels to be string
connIndexStr string
connIndex uint8
observer *Observer
}
type MuxerConfig struct {
HeartbeatInterval time.Duration
MaxHeartbeats uint64
CompressionSetting h2mux.CompressionSetting
MetricsUpdateFreq time.Duration
}
func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, logger logger.Service) *h2mux.MuxerConfig {
return &h2mux.MuxerConfig{
Timeout: muxerTimeout,
Handler: h,
IsClient: true,
HeartbeatInterval: mc.HeartbeatInterval,
MaxHeartbeats: mc.MaxHeartbeats,
Logger: logger,
CompressionQuality: mc.CompressionSetting,
}
}
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewH2muxConnection(ctx context.Context,
config *Config,
muxerConfig *MuxerConfig,
originURL string,
edgeConn net.Conn,
connIndex uint8,
observer *Observer,
) (*h2muxConnection, error, bool) {
h := &h2muxConnection{
config: config,
muxerConfig: muxerConfig,
originURL: originURL,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
observer: observer,
}
// Establish a muxed connection with the edge
// Client mux handshake with agent server
muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer), h2mux.ActiveStreams)
if err != nil {
recoverable := isHandshakeErrRecoverable(err, connIndex, observer)
return nil, err, recoverable
}
h.muxer = muxer
return h, nil, false
}
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return h.serveMuxer(serveCtx)
})
errGroup.Go(func() error {
stream, err := h.newRPCStream(serveCtx, register)
if err != nil {
return err
}
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer)
defer rpcClient.close()
if err = registerConnection(serveCtx, rpcClient, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
return err
}
connectedFuse.Connected()
return nil
})
errGroup.Go(func() error {
h.controlLoop(serveCtx, connectedFuse, true)
return nil
})
return errGroup.Wait()
}
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return h.serveMuxer(serveCtx)
})
errGroup.Go(func() (err error) {
defer func() {
if err == nil {
connectedFuse.Connected()
}
}()
if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() {
err := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
if err == nil {
return nil
}
// log errors and proceed to RegisterTunnel
h.observer.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", h.connIndex, err)
}
return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
})
errGroup.Go(func() error {
h.controlLoop(serveCtx, connectedFuse, false)
return nil
})
return errGroup.Wait()
}
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
// All routines should stop when muxer finish serving. When muxer is shutdown
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
// here to notify other routines to stop
err := h.muxer.Serve(ctx)
if err == nil {
return muxerShutdownError{}
}
return err
}
func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
for {
select {
case <-ctx.Done():
// UnregisterTunnel blocks until the RPC call returns
if connectedFuse.IsConnected() {
h.unregister(isNamedTunnel)
}
h.muxer.Shutdown()
return
case <-updateMetricsTickC:
h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
}
}
}
func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := h.muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return nil, err
}
return stream, nil
}
func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
respWriter := &h2muxRespWriter{stream}
req, reqErr := h.newRequest(stream)
if reqErr != nil {
respWriter.WriteErrorResponse(reqErr)
return reqErr
}
return h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
}
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
req, err := http.NewRequest("GET", h.originURL, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
}
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
return nil, errors.Wrap(err, "invalid request received")
}
return req, nil
}
type h2muxRespWriter struct {
*h2mux.MuxedStream
}
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
return rp.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp))
}
func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
rp.WriteHeaders([]h2mux.Header{
{Name: ":status", Value: "502"},
h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared),
})
rp.Write([]byte("502 Bad Gateway"))
}

253
connection/http2.go Normal file
View File

@ -0,0 +1,253 @@
package connection
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"github.com/cloudflare/cloudflared/h2mux"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"golang.org/x/net/http2"
)
const (
internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
websocketUpgrade = "websocket"
controlStreamUpgrade = "control-stream"
)
type HTTP2Connection struct {
conn net.Conn
server *http2.Server
config *Config
originURL *url.URL
namedTunnel *NamedTunnelConfig
connOptions *tunnelpogs.ConnectionOptions
observer *Observer
connIndexStr string
connIndex uint8
shutdownChan chan struct{}
connectedFuse ConnectedFuse
}
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) {
return &HTTP2Connection{
conn: conn,
server: &http2.Server{},
config: config,
originURL: originURL,
namedTunnel: namedTunnelConfig,
connOptions: connOptions,
observer: observer,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
shutdownChan: make(chan struct{}),
connectedFuse: connectedFuse,
}, nil
}
func (c *HTTP2Connection) Serve(ctx context.Context) {
go func() {
<-ctx.Done()
c.close()
}()
c.server.ServeConn(c.conn, &http2.ServeConnOpts{
Context: ctx,
Handler: c,
})
}
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = c.originURL.Scheme
r.URL.Host = c.originURL.Host
respWriter := &http2RespWriter{
r: r.Body,
w: w,
}
if isControlStreamUpgrade(r) {
err := c.serveControlStream(r.Context(), respWriter)
if err != nil {
respWriter.WriteErrorResponse(err)
}
} else if isWebsocketUpgrade(r) {
wsRespWriter, err := newWSRespWriter(respWriter)
if err != nil {
respWriter.WriteErrorResponse(err)
return
}
stripWebsocketUpgradeHeader(r)
c.config.OriginClient.Proxy(wsRespWriter, r, true)
} else {
c.config.OriginClient.Proxy(respWriter, r, false)
}
}
func (c *HTTP2Connection) serveControlStream(ctx context.Context, h2RespWriter *http2RespWriter) error {
stream, err := newWSRespWriter(h2RespWriter)
if err != nil {
return err
}
rpcClient := newRegistrationRPCClient(ctx, stream, c.observer)
defer rpcClient.close()
if err = registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
return err
}
c.connectedFuse.Connected()
<-c.shutdownChan
c.gracefulShutdown(ctx, rpcClient)
close(c.shutdownChan)
return nil
}
func (c *HTTP2Connection) registerConnection(
ctx context.Context,
rpcClient tunnelpogs.RegistrationServer_PogsClient,
) error {
connDetail, err := rpcClient.RegisterConnection(
ctx,
c.namedTunnel.Auth,
c.namedTunnel.ID,
c.connIndex,
c.connOptions,
)
if err != nil {
c.observer.Errorf("Cannot register connection, err: %v", err)
return err
}
c.observer.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID)
return nil
}
func (c *HTTP2Connection) gracefulShutdown(ctx context.Context, rpcClient *registrationServerClient) {
ctx, cancel := context.WithTimeout(ctx, c.config.GracePeriod)
defer cancel()
rpcClient.client.UnregisterConnection(ctx)
}
func (c *HTTP2Connection) close() {
// Send signal to control loop to start graceful shutdown
c.shutdownChan <- struct{}{}
// Wait for control loop to close channel
<-c.shutdownChan
c.conn.Close()
}
type http2RespWriter struct {
r io.Reader
w http.ResponseWriter
}
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
dest := rp.w.Header()
userHeaders := make(http.Header, len(resp.Header))
for header, values := range resp.Header {
// Since these are http2 headers, they're required to be lowercase
h2name := strings.ToLower(header)
for _, v := range values {
if h2name == "content-length" {
// This header has meaning in HTTP/2 and will be used by the edge,
// so it should be sent as an HTTP/2 response header.
dest.Add(h2name, v)
// Since these are http2 headers, they're required to be lowercase
} else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) {
// User headers, on the other hand, must all be serialized so that
// HTTP/2 header validation won't be applied to HTTP/1 header values
userHeaders.Add(h2name, v)
}
}
}
// Perform user header serialization and set them in the single header
dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
status := resp.StatusCode
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
if status == http.StatusSwitchingProtocols {
status = http.StatusOK
}
rp.w.WriteHeader(status)
return nil
}
func (rp *http2RespWriter) WriteErrorResponse(err error) {
jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared})
if err == nil {
rp.w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader))
}
rp.w.WriteHeader(http.StatusBadGateway)
}
func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
return rp.r.Read(p)
}
func (wr *http2RespWriter) Write(p []byte) (n int, err error) {
return wr.w.Write(p)
}
type wsRespWriter struct {
h2 *http2RespWriter
flusher http.Flusher
}
func newWSRespWriter(h2 *http2RespWriter) (*wsRespWriter, error) {
flusher, ok := h2.w.(http.Flusher)
if !ok {
return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher")
}
return &wsRespWriter{
h2: h2,
flusher: flusher,
}, nil
}
func (rw *wsRespWriter) WriteRespHeaders(resp *http.Response) error {
err := rw.h2.WriteRespHeaders(resp)
if err != nil {
return err
}
rw.flusher.Flush()
return nil
}
func (rw *wsRespWriter) WriteErrorResponse(err error) {
rw.h2.WriteErrorResponse(err)
}
func (rw *wsRespWriter) Read(p []byte) (n int, err error) {
return rw.h2.Read(p)
}
func (rw *wsRespWriter) Write(p []byte) (n int, err error) {
n, err = rw.h2.Write(p)
if err != nil {
return
}
rw.flusher.Flush()
return
}
func (rw *wsRespWriter) Close() error {
return nil
}
func isControlStreamUpgrade(r *http.Request) bool {
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlStreamUpgrade
}
func isWebsocketUpgrade(r *http.Request) bool {
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade
}
func stripWebsocketUpgradeHeader(r *http.Request) {
r.Header.Del(internalUpgradeHeader)
}

409
connection/metrics.go Normal file
View File

@ -0,0 +1,409 @@
package connection
import (
"sync"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/prometheus/client_golang/prometheus"
)
const (
MetricsNamespace = "cloudflared"
TunnelSubsystem = "tunnel"
muxerSubsystem = "muxer"
)
type muxerMetrics struct {
rtt *prometheus.GaugeVec
rttMin *prometheus.GaugeVec
rttMax *prometheus.GaugeVec
receiveWindowAve *prometheus.GaugeVec
sendWindowAve *prometheus.GaugeVec
receiveWindowMin *prometheus.GaugeVec
receiveWindowMax *prometheus.GaugeVec
sendWindowMin *prometheus.GaugeVec
sendWindowMax *prometheus.GaugeVec
inBoundRateCurr *prometheus.GaugeVec
inBoundRateMin *prometheus.GaugeVec
inBoundRateMax *prometheus.GaugeVec
outBoundRateCurr *prometheus.GaugeVec
outBoundRateMin *prometheus.GaugeVec
outBoundRateMax *prometheus.GaugeVec
compBytesBefore *prometheus.GaugeVec
compBytesAfter *prometheus.GaugeVec
compRateAve *prometheus.GaugeVec
}
type tunnelMetrics struct {
timerRetries prometheus.Gauge
serverLocations *prometheus.GaugeVec
// locationLock is a mutex for oldServerLocations
locationLock sync.Mutex
// oldServerLocations stores the last server the tunnel was connected to
oldServerLocations map[string]string
regSuccess *prometheus.CounterVec
regFail *prometheus.CounterVec
rpcFail *prometheus.CounterVec
muxerMetrics *muxerMetrics
tunnelsHA tunnelsForHA
userHostnamesCounts *prometheus.CounterVec
}
func newMuxerMetrics() *muxerMetrics {
rtt := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt",
Help: "Round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rtt)
rttMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt_min",
Help: "Shortest round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rttMin)
rttMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt_max",
Help: "Longest round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rttMax)
receiveWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_ave",
Help: "Average receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowAve)
sendWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_ave",
Help: "Average send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowAve)
receiveWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_min",
Help: "Smallest receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowMin)
receiveWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_max",
Help: "Largest receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowMax)
sendWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_min",
Help: "Smallest send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowMin)
sendWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_max",
Help: "Largest send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowMax)
inBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_curr",
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateCurr)
inBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_min",
Help: "Minimum non-zero inbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateMin)
inBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_max",
Help: "Maximum inbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateMax)
outBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_curr",
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateCurr)
outBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_min",
Help: "Minimum non-zero outbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateMin)
outBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_max",
Help: "Maximum outbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateMax)
compBytesBefore := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_bytes_before",
Help: "Bytes sent via cross-stream compression, pre compression",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compBytesBefore)
compBytesAfter := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_bytes_after",
Help: "Bytes sent via cross-stream compression, post compression",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compBytesAfter)
compRateAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_rate_ave",
Help: "Average outbound cross-stream compression ratio",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compRateAve)
return &muxerMetrics{
rtt: rtt,
rttMin: rttMin,
rttMax: rttMax,
receiveWindowAve: receiveWindowAve,
sendWindowAve: sendWindowAve,
receiveWindowMin: receiveWindowMin,
receiveWindowMax: receiveWindowMax,
sendWindowMin: sendWindowMin,
sendWindowMax: sendWindowMax,
inBoundRateCurr: inBoundRateCurr,
inBoundRateMin: inBoundRateMin,
inBoundRateMax: inBoundRateMax,
outBoundRateCurr: outBoundRateCurr,
outBoundRateMin: outBoundRateMin,
outBoundRateMax: outBoundRateMax,
compBytesBefore: compBytesBefore,
compBytesAfter: compBytesAfter,
compRateAve: compRateAve,
}
}
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value()))
m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value()))
m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve()))
}
func convertRTTMilliSec(t time.Duration) float64 {
return float64(t / time.Millisecond)
}
// Metrics that can be collected without asking the edge
func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "max_concurrent_requests_per_tunnel",
Help: "Largest number of concurrent requests proxied through each tunnel so far",
},
[]string{"connection_id"},
)
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
timerRetries := prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "timer_retries",
Help: "Unacknowledged heart beats count",
})
prometheus.MustRegister(timerRetries)
serverLocations := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "server_locations",
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
},
[]string{"connection_id", "location"},
)
prometheus.MustRegister(serverLocations)
rpcFail := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "tunnel_rpc_fail",
Help: "Count of RPC connection errors by type",
},
[]string{"error", "rpcName"},
)
prometheus.MustRegister(rpcFail)
registerFail := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "tunnel_register_fail",
Help: "Count of tunnel registration errors by type",
},
[]string{"error", "rpcName"},
)
prometheus.MustRegister(registerFail)
userHostnamesCounts := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "user_hostnames_counts",
Help: "Which user hostnames cloudflared is serving",
},
[]string{"userHostname"},
)
prometheus.MustRegister(userHostnamesCounts)
registerSuccess := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: MetricsNamespace,
Subsystem: TunnelSubsystem,
Name: "tunnel_register_success",
Help: "Count of successful tunnel registrations",
},
[]string{"rpcName"},
)
prometheus.MustRegister(registerSuccess)
var muxerMetrics *muxerMetrics
if protocol == H2mux {
muxerMetrics = newMuxerMetrics()
}
return &tunnelMetrics{
timerRetries: timerRetries,
serverLocations: serverLocations,
oldServerLocations: make(map[string]string),
muxerMetrics: muxerMetrics,
tunnelsHA: NewTunnelsForHA(),
regSuccess: registerSuccess,
regFail: registerFail,
rpcFail: rpcFail,
userHostnamesCounts: userHostnamesCounts,
}
}
func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
t.muxerMetrics.update(connectionID, metrics)
}
func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
t.locationLock.Lock()
defer t.locationLock.Unlock()
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
return
} else if ok {
t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
}
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
t.oldServerLocations[connectionID] = loc
}

99
connection/observer.go Normal file
View File

@ -0,0 +1,99 @@
package connection
import (
"fmt"
"net/url"
"strings"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
type Observer struct {
logger.Service
metrics *tunnelMetrics
tunnelEventChan chan<- ui.TunnelEvent
}
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer {
return &Observer{
logger,
newTunnelMetrics(protocol),
tunnelEventChan,
}
}
func (o *Observer) logServerInfo(connectionID uint8, location, msg string) {
// If launch-ui flag is set, send connect msg
if o.tunnelEventChan != nil {
o.tunnelEventChan <- ui.TunnelEvent{Index: connectionID, EventType: ui.Connected, Location: location}
}
o.Infof(msg)
o.metrics.registerServerLocation(uint8ToString(connectionID), location)
}
func (o *Observer) logTrialHostname(registration *tunnelpogs.TunnelRegistration) error {
// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
if o.tunnelEventChan == nil {
if registrationURL, err := url.Parse(registration.Url); err == nil {
for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) {
o.Info(line)
}
} else {
o.Error("Failed to connect tunnel, please try again.")
return fmt.Errorf("empty URL in response from Cloudflare edge")
}
}
return nil
}
// Print out the given lines in a nice ASCII box.
func asciiBox(lines []string, padding int) (box []string) {
maxLen := maxLen(lines)
spacer := strings.Repeat(" ", padding)
border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+"
box = append(box, border)
for _, line := range lines {
box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|")
}
box = append(box, border)
return
}
func maxLen(lines []string) int {
max := 0
for _, line := range lines {
if len(line) > max {
max = len(line)
}
}
return max
}
func trialZoneMsg(url string) []string {
return []string{
"Your free tunnel has started! Visit it:",
" " + url,
}
}
func (o *Observer) sendRegisteringEvent() {
if o.tunnelEventChan != nil {
o.tunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel}
}
}
func (o *Observer) sendConnectedEvent(connIndex uint8, location string) {
if o.tunnelEventChan != nil {
o.tunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Connected, Location: location}
}
}
func (o *Observer) sendURL(url string) {
if o.tunnelEventChan != nil {
o.tunnelEventChan <- ui.TunnelEvent{EventType: ui.SetUrl, Url: url}
}
}

View File

@ -0,0 +1,45 @@
package connection
import (
"strconv"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
// can only be called once
var m = newTunnelMetrics(H2mux)
func TestRegisterServerLocation(t *testing.T) {
tunnels := 20
var wg sync.WaitGroup
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
id := strconv.Itoa(i)
m.registerServerLocation(id, "LHR")
wg.Done()
}(i)
}
wg.Wait()
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, "LHR", m.oldServerLocations[id])
}
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
id := strconv.Itoa(i)
m.registerServerLocation(id, "AUS")
wg.Done()
}(i)
}
wg.Wait()
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, "AUS", m.oldServerLocations[id])
}
}

View File

@ -2,40 +2,276 @@ package connection
import ( import (
"context" "context"
"fmt"
"io" "io"
rpc "zombiezen.com/go/capnproto2/rpc"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"zombiezen.com/go/capnproto2/rpc"
) )
// NewTunnelRPCClient creates and returns a new RPC client, which will communicate type tunnelServerClient struct {
// using a stream on the given muxer client tunnelpogs.TunnelServer_PogsClient
func NewTunnelRPCClient( transport rpc.Transport
}
// NewTunnelRPCClient creates and returns a new RPC client, which will communicate using a stream on the given muxer.
// This method is exported for supervisor to call Authenticate RPC
func NewTunnelServerClient(
ctx context.Context, ctx context.Context,
stream io.ReadWriteCloser, stream io.ReadWriteCloser,
logger logger.Service, logger logger.Service,
) (client tunnelpogs.TunnelServer_PogsClient, err error) { ) *tunnelServerClient {
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
conn := rpc.NewConn( conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), transport,
tunnelrpc.ConnLog(logger), tunnelrpc.ConnLog(logger),
) )
registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn} registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
client = tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn} return &tunnelServerClient{
return client, nil client: tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn},
transport: transport,
}
} }
func NewRegistrationRPCClient( func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) {
authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions)
if err != nil {
return nil, err
}
return authResp.Outcome(), nil
}
func (tsc *tunnelServerClient) Close() {
// Closing the client will also close the connection
tsc.client.Close()
tsc.transport.Close()
}
type registrationServerClient struct {
client tunnelpogs.RegistrationServer_PogsClient
transport rpc.Transport
}
func newRegistrationRPCClient(
ctx context.Context, ctx context.Context,
stream io.ReadWriteCloser, stream io.ReadWriteCloser,
logger logger.Service, logger logger.Service,
) (client tunnelpogs.RegistrationServer_PogsClient, err error) { ) *registrationServerClient {
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
conn := rpc.NewConn( conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), transport,
tunnelrpc.ConnLog(logger), tunnelrpc.ConnLog(logger),
) )
client = tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn} return &registrationServerClient{
return client, nil client: tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
transport: transport,
}
}
func (rsc *registrationServerClient) close() {
// Closing the client will also close the connection
rsc.client.Close()
// Closing the transport also closes the stream
rsc.transport.Close()
}
type rpcName string
const (
register rpcName = "register"
reconnect rpcName = "reconnect"
unregister rpcName = "unregister"
authenticate rpcName = " authenticate"
)
func registerConnection(
ctx context.Context,
rpcClient *registrationServerClient,
config *NamedTunnelConfig,
options *tunnelpogs.ConnectionOptions,
connIndex uint8,
observer *Observer,
) error {
conn, err := rpcClient.client.RegisterConnection(
ctx,
config.Auth,
config.ID,
connIndex,
options,
)
if err != nil {
if err.Error() == DuplicateConnectionError {
observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
return errDuplicationConnection
}
observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
return serverRegistrationErrorFromRPC(err)
}
observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
observer.logServerInfo(connIndex, conn.Location, fmt.Sprintf("Connection %d registered with %s using ID %s", connIndex, conn.Location, conn.UUID))
observer.sendConnectedEvent(connIndex, conn.Location)
return nil
}
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
h.observer.sendRegisteringEvent()
stream, err := h.newRPCStream(ctx, register)
if err != nil {
return err
}
rpcClient := NewTunnelServerClient(ctx, stream, h.observer)
defer rpcClient.Close()
h.logServerInfo(ctx, rpcClient)
registration := rpcClient.client.RegisterTunnel(
ctx,
classicTunnel.OriginCert,
classicTunnel.Hostname,
registrationOptions,
)
if registrationErr := registration.DeserializeError(); registrationErr != nil {
// RegisterTunnel RPC failure
return h.processRegisterTunnelError(registrationErr, register)
}
// Send free tunnel URL to UI
h.observer.sendURL(registration.Url)
credentialSetter.SetEventDigest(h.connIndex, registration.EventDigest)
return h.processRegistrationSuccess(registration, register, credentialSetter, classicTunnel)
}
type CredentialManager interface {
ReconnectToken() ([]byte, error)
EventDigest(connID uint8) ([]byte, error)
SetEventDigest(connID uint8, digest []byte)
ConnDigest(connID uint8) ([]byte, error)
SetConnDigest(connID uint8, digest []byte)
}
func (h *h2muxConnection) processRegistrationSuccess(
registration *tunnelpogs.TunnelRegistration,
name rpcName,
credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig,
) error {
for _, logLine := range registration.LogLines {
h.observer.Info(logLine)
}
if registration.TunnelID != "" {
h.observer.metrics.tunnelsHA.AddTunnelID(h.connIndex, registration.TunnelID)
h.observer.Infof("Each HA connection's tunnel IDs: %v", h.observer.metrics.tunnelsHA.String())
}
// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
if classicTunnel.IsTrialZone() {
err := h.observer.logTrialHostname(registration)
if err != nil {
return err
}
}
credentialManager.SetConnDigest(h.connIndex, registration.ConnDigest)
h.observer.metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
h.observer.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
h.observer.metrics.regSuccess.WithLabelValues(string(name)).Inc()
return nil
}
func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, name rpcName) error {
if err.Error() == DuplicateConnectionError {
h.observer.metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
return errDuplicationConnection
}
h.observer.metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
return serverRegisterTunnelError{
cause: err,
permanent: err.IsPermanent(),
}
}
func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
token, err := credentialManager.ReconnectToken()
if err != nil {
return err
}
eventDigest, err := credentialManager.EventDigest(h.connIndex)
if err != nil {
return err
}
connDigest, err := credentialManager.ConnDigest(h.connIndex)
if err != nil {
return err
}
h.observer.Debug("initiating RPC stream to reconnect")
stream, err := h.newRPCStream(ctx, register)
if err != nil {
return err
}
rpcClient := NewTunnelServerClient(ctx, stream, h.observer)
defer rpcClient.Close()
h.logServerInfo(ctx, rpcClient)
registration := rpcClient.client.ReconnectTunnel(
ctx,
token,
eventDigest,
connDigest,
classicTunnel.Hostname,
registrationOptions,
)
if registrationErr := registration.DeserializeError(); registrationErr != nil {
// ReconnectTunnel RPC failure
return h.processRegisterTunnelError(registrationErr, reconnect)
}
return h.processRegistrationSuccess(registration, reconnect, credentialManager, classicTunnel)
}
func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelServerClient) error {
// Request server info without blocking tunnel registration; must use capnp library directly.
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.client.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
serverInfoMessage, err := serverInfoPromise.Result().Struct()
if err != nil {
h.observer.Errorf("Failed to retrieve server information: %s", err)
return err
}
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
if err != nil {
h.observer.Errorf("Failed to retrieve server information: %s", err)
return err
}
h.observer.logServerInfo(h.connIndex, serverInfo.LocationName, fmt.Sprintf("Connnection %d connected to %s", h.connIndex, serverInfo.LocationName))
return nil
}
func (h *h2muxConnection) unregister(isNamedTunnel bool) {
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
defer cancel()
stream, err := h.newRPCStream(unregisterCtx, register)
if err != nil {
return
}
if isNamedTunnel {
rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer)
defer rpcClient.close()
rpcClient.client.UnregisterConnection(unregisterCtx)
} else {
rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer)
defer rpcClient.Close()
// gracePeriod is encoded in int64 using capnproto
rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
}
} }

View File

@ -0,0 +1,50 @@
package connection
import (
"fmt"
"sync"
"github.com/prometheus/client_golang/prometheus"
)
// tunnelsForHA maps this cloudflared instance's HA connections to the tunnel IDs they serve.
type tunnelsForHA struct {
sync.Mutex
metrics *prometheus.GaugeVec
entries map[uint8]string
}
// NewTunnelsForHA initializes the Prometheus metrics etc for a tunnelsForHA.
func NewTunnelsForHA() tunnelsForHA {
metrics := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "tunnel_ids",
Help: "The ID of all tunnels (and their corresponding HA connection ID) running in this instance of cloudflared.",
},
[]string{"tunnel_id", "ha_conn_id"},
)
prometheus.MustRegister(metrics)
return tunnelsForHA{
metrics: metrics,
entries: make(map[uint8]string),
}
}
// Track a new tunnel ID, removing the disconnected tunnel (if any) and update metrics.
func (t *tunnelsForHA) AddTunnelID(haConn uint8, tunnelID string) {
t.Lock()
defer t.Unlock()
haStr := fmt.Sprintf("%v", haConn)
if oldTunnelID, ok := t.entries[haConn]; ok {
t.metrics.WithLabelValues(oldTunnelID, haStr).Dec()
}
t.entries[haConn] = tunnelID
t.metrics.WithLabelValues(tunnelID, haStr).Inc()
}
func (t *tunnelsForHA) String() string {
t.Lock()
defer t.Unlock()
return fmt.Sprintf("%v", t.entries)
}

View File

@ -1,4 +1,4 @@
package connection package edgediscovery
import ( import (
"context" "context"

View File

@ -1,4 +1,4 @@
package connection package edgediscovery
import ( import (
"fmt" "fmt"

View File

@ -7,6 +7,19 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
var (
ActiveStreams = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "cloudflared",
Subsystem: "tunnel",
Name: "active_streams",
Help: "Number of active streams created by all muxers.",
})
)
func init() {
prometheus.MustRegister(ActiveStreams)
}
// activeStreamMap is used to moderate access to active streams between the read and write // activeStreamMap is used to moderate access to active streams between the read and write
// threads, and deny access to new peer streams while shutting down. // threads, and deny access to new peer streams while shutting down.
type activeStreamMap struct { type activeStreamMap struct {

View File

@ -9,7 +9,7 @@ import (
func TestShutdown(t *testing.T) { func TestShutdown(t *testing.T) {
const numStreams = 1000 const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) m := newActiveStreamMap(true, ActiveStreams)
// Add all the streams // Add all the streams
{ {
@ -62,7 +62,7 @@ func TestShutdown(t *testing.T) {
func TestEmptyBeforeShutdown(t *testing.T) { func TestEmptyBeforeShutdown(t *testing.T) {
const numStreams = 1000 const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) m := newActiveStreamMap(true, ActiveStreams)
// Add all the streams // Add all the streams
{ {
@ -138,7 +138,7 @@ func (_ *noopReadyList) Signal(streamID uint32) {}
func TestAbort(t *testing.T) { func TestAbort(t *testing.T) {
const numStreams = 1000 const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) m := newActiveStreamMap(true, ActiveStreams)
var openedStreams sync.Map var openedStreams sync.Map

View File

@ -113,11 +113,11 @@ func (p *DefaultMuxerPair) Handshake(testName string) error {
defer cancel() defer cancel()
errGroup, _ := errgroup.WithContext(ctx) errGroup, _ := errgroup.WithContext(ctx)
errGroup.Go(func() (err error) { errGroup.Go(func() (err error) {
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, NewActiveStreamsMetrics(testName, "edge")) p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, ActiveStreams)
return errors.Wrap(err, "edge handshake failure") return errors.Wrap(err, "edge handshake failure")
}) })
errGroup.Go(func() (err error) { errGroup.Go(func() (err error) {
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, NewActiveStreamsMetrics(testName, "origin")) p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, ActiveStreams)
return errors.Wrap(err, "origin handshake failure") return errors.Wrap(err, "origin handshake failure")
}) })

View File

@ -6,7 +6,6 @@ import (
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/golang-collections/collections/queue" "github.com/golang-collections/collections/queue"
"github.com/prometheus/client_golang/prometheus"
) )
// data points used to compute average receive window and send window size // data points used to compute average receive window and send window size
@ -295,14 +294,3 @@ func (r *rate) get() (curr, min, max uint64) {
defer r.lock.RUnlock() defer r.lock.RUnlock()
return r.curr, r.min, r.max return r.curr, r.min, r.max
} }
func NewActiveStreamsMetrics(namespace, subsystem string) prometheus.Gauge {
activeStreams := prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "active_streams",
Help: "Number of active streams created by all muxers.",
})
prometheus.MustRegister(activeStreams)
return activeStreams
}

View File

@ -1,14 +0,0 @@
package origin
import (
"net"
)
// persistentTCPConn is a wrapper around net.Conn that is noop when Close is called
type persistentConn struct {
net.Conn
}
func (pc *persistentConn) Close() error {
return nil
}

View File

@ -1,540 +1,63 @@
package origin package origin
import ( import (
"sync" "github.com/cloudflare/cloudflared/connection"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
const ( // Metrics uses connection.MetricsNamespace(aka cloudflared) as namespace and connection.TunnelSubsystem
metricsNamespace = "cloudflared" // (tunnel) as subsystem to keep them consistent with the previous qualifier.
tunnelSubsystem = "tunnel"
muxerSubsystem = "muxer"
)
type muxerMetrics struct { var (
rtt *prometheus.GaugeVec totalRequests = prometheus.NewCounter(
rttMin *prometheus.GaugeVec
rttMax *prometheus.GaugeVec
receiveWindowAve *prometheus.GaugeVec
sendWindowAve *prometheus.GaugeVec
receiveWindowMin *prometheus.GaugeVec
receiveWindowMax *prometheus.GaugeVec
sendWindowMin *prometheus.GaugeVec
sendWindowMax *prometheus.GaugeVec
inBoundRateCurr *prometheus.GaugeVec
inBoundRateMin *prometheus.GaugeVec
inBoundRateMax *prometheus.GaugeVec
outBoundRateCurr *prometheus.GaugeVec
outBoundRateMin *prometheus.GaugeVec
outBoundRateMax *prometheus.GaugeVec
compBytesBefore *prometheus.GaugeVec
compBytesAfter *prometheus.GaugeVec
compRateAve *prometheus.GaugeVec
}
type TunnelMetrics struct {
haConnections prometheus.Gauge
activeStreams prometheus.Gauge
totalRequests prometheus.Counter
requestsPerTunnel *prometheus.CounterVec
// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
concurrentRequestsLock sync.Mutex
concurrentRequestsPerTunnel *prometheus.GaugeVec
// concurrentRequests records count of concurrent requests for each tunnel
concurrentRequests map[string]uint64
maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
// concurrentRequests records max count of concurrent requests for each tunnel
maxConcurrentRequests map[string]uint64
timerRetries prometheus.Gauge
responseByCode *prometheus.CounterVec
responseCodePerTunnel *prometheus.CounterVec
serverLocations *prometheus.GaugeVec
// locationLock is a mutex for oldServerLocations
locationLock sync.Mutex
// oldServerLocations stores the last server the tunnel was connected to
oldServerLocations map[string]string
regSuccess *prometheus.CounterVec
regFail *prometheus.CounterVec
rpcFail *prometheus.CounterVec
muxerMetrics *muxerMetrics
tunnelsHA tunnelsForHA
userHostnamesCounts *prometheus.CounterVec
}
func newMuxerMetrics() *muxerMetrics {
rtt := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt",
Help: "Round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rtt)
rttMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt_min",
Help: "Shortest round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rttMin)
rttMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "rtt_max",
Help: "Longest round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rttMax)
receiveWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_ave",
Help: "Average receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowAve)
sendWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_ave",
Help: "Average send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowAve)
receiveWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_min",
Help: "Smallest receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowMin)
receiveWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "receive_window_max",
Help: "Largest receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowMax)
sendWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_min",
Help: "Smallest send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowMin)
sendWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "send_window_max",
Help: "Largest send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowMax)
inBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_curr",
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateCurr)
inBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_min",
Help: "Minimum non-zero inbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateMin)
inBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "inbound_bytes_per_sec_max",
Help: "Maximum inbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateMax)
outBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_curr",
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateCurr)
outBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_min",
Help: "Minimum non-zero outbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateMin)
outBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "outbound_bytes_per_sec_max",
Help: "Maximum outbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateMax)
compBytesBefore := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_bytes_before",
Help: "Bytes sent via cross-stream compression, pre compression",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compBytesBefore)
compBytesAfter := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_bytes_after",
Help: "Bytes sent via cross-stream compression, post compression",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compBytesAfter)
compRateAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: muxerSubsystem,
Name: "comp_rate_ave",
Help: "Average outbound cross-stream compression ratio",
},
[]string{"connection_id"},
)
prometheus.MustRegister(compRateAve)
return &muxerMetrics{
rtt: rtt,
rttMin: rttMin,
rttMax: rttMax,
receiveWindowAve: receiveWindowAve,
sendWindowAve: sendWindowAve,
receiveWindowMin: receiveWindowMin,
receiveWindowMax: receiveWindowMax,
sendWindowMin: sendWindowMin,
sendWindowMax: sendWindowMax,
inBoundRateCurr: inBoundRateCurr,
inBoundRateMin: inBoundRateMin,
inBoundRateMax: inBoundRateMax,
outBoundRateCurr: outBoundRateCurr,
outBoundRateMin: outBoundRateMin,
outBoundRateMax: outBoundRateMax,
compBytesBefore: compBytesBefore,
compBytesAfter: compBytesAfter,
compRateAve: compRateAve,
}
}
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value()))
m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value()))
m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve()))
}
func convertRTTMilliSec(t time.Duration) float64 {
return float64(t / time.Millisecond)
}
// Metrics that can be collected without asking the edge
func NewTunnelMetrics() *TunnelMetrics {
haConnections := prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "ha_connections",
Help: "Number of active ha connections",
})
prometheus.MustRegister(haConnections)
activeStreams := h2mux.NewActiveStreamsMetrics(metricsNamespace, tunnelSubsystem)
totalRequests := prometheus.NewCounter(
prometheus.CounterOpts{ prometheus.CounterOpts{
Namespace: metricsNamespace, Namespace: connection.MetricsNamespace,
Subsystem: tunnelSubsystem, Subsystem: connection.TunnelSubsystem,
Name: "total_requests", Name: "total_requests",
Help: "Amount of requests proxied through all the tunnels", Help: "Amount of requests proxied through all the tunnels",
})
prometheus.MustRegister(totalRequests)
requestsPerTunnel := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "requests_per_tunnel",
Help: "Amount of requests proxied through each tunnel",
}, },
[]string{"connection_id"},
) )
prometheus.MustRegister(requestsPerTunnel) concurrentRequests = prometheus.NewGauge(
concurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: metricsNamespace, Namespace: connection.MetricsNamespace,
Subsystem: tunnelSubsystem, Subsystem: connection.TunnelSubsystem,
Name: "concurrent_requests_per_tunnel", Name: "concurrent_requests_per_tunnel",
Help: "Concurrent requests proxied through each tunnel", Help: "Concurrent requests proxied through each tunnel",
}, },
[]string{"connection_id"},
) )
prometheus.MustRegister(concurrentRequestsPerTunnel) responseByCode = prometheus.NewCounterVec(
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "max_concurrent_requests_per_tunnel",
Help: "Largest number of concurrent requests proxied through each tunnel so far",
},
[]string{"connection_id"},
)
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
timerRetries := prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "timer_retries",
Help: "Unacknowledged heart beats count",
})
prometheus.MustRegister(timerRetries)
responseByCode := prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Namespace: metricsNamespace, Namespace: connection.MetricsNamespace,
Subsystem: tunnelSubsystem, Subsystem: connection.TunnelSubsystem,
Name: "response_by_code", Name: "response_by_code",
Help: "Count of responses by HTTP status code", Help: "Count of responses by HTTP status code",
}, },
[]string{"status_code"}, []string{"status_code"},
) )
prometheus.MustRegister(responseByCode) haConnections = prometheus.NewGauge(
responseCodePerTunnel := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "response_code_per_tunnel",
Help: "Count of responses by HTTP status code fore each tunnel",
},
[]string{"connection_id", "status_code"},
)
prometheus.MustRegister(responseCodePerTunnel)
serverLocations := prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: metricsNamespace, Namespace: connection.MetricsNamespace,
Subsystem: tunnelSubsystem, Subsystem: connection.TunnelSubsystem,
Name: "server_locations", Name: "ha_connections",
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.", Help: "Number of active ha connections",
}, },
[]string{"connection_id", "location"},
) )
prometheus.MustRegister(serverLocations)
rpcFail := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "tunnel_rpc_fail",
Help: "Count of RPC connection errors by type",
},
[]string{"error", "rpcName"},
) )
prometheus.MustRegister(rpcFail)
registerFail := prometheus.NewCounterVec( func init() {
prometheus.CounterOpts{ prometheus.MustRegister(
Namespace: metricsNamespace, totalRequests,
Subsystem: tunnelSubsystem, concurrentRequests,
Name: "tunnel_register_fail", responseByCode,
Help: "Count of tunnel registration errors by type", haConnections,
},
[]string{"error", "rpcName"},
) )
prometheus.MustRegister(registerFail)
userHostnamesCounts := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "user_hostnames_counts",
Help: "Which user hostnames cloudflared is serving",
},
[]string{"userHostname"},
)
prometheus.MustRegister(userHostnamesCounts)
registerSuccess := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: metricsNamespace,
Subsystem: tunnelSubsystem,
Name: "tunnel_register_success",
Help: "Count of successful tunnel registrations",
},
[]string{"rpcName"},
)
prometheus.MustRegister(registerSuccess)
return &TunnelMetrics{
haConnections: haConnections,
activeStreams: activeStreams,
totalRequests: totalRequests,
requestsPerTunnel: requestsPerTunnel,
concurrentRequestsPerTunnel: concurrentRequestsPerTunnel,
concurrentRequests: make(map[string]uint64),
maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
maxConcurrentRequests: make(map[string]uint64),
timerRetries: timerRetries,
responseByCode: responseByCode,
responseCodePerTunnel: responseCodePerTunnel,
serverLocations: serverLocations,
oldServerLocations: make(map[string]string),
muxerMetrics: newMuxerMetrics(),
tunnelsHA: NewTunnelsForHA(),
regSuccess: registerSuccess,
regFail: registerFail,
rpcFail: rpcFail,
userHostnamesCounts: userHostnamesCounts,
}
} }
func (t *TunnelMetrics) incrementHaConnections() { func incrementRequests() {
t.haConnections.Inc() totalRequests.Inc()
concurrentRequests.Inc()
} }
func (t *TunnelMetrics) decrementHaConnections() { func decrementConcurrentRequests() {
t.haConnections.Dec() concurrentRequests.Dec()
}
func (t *TunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
t.muxerMetrics.update(connectionID, metrics)
}
func (t *TunnelMetrics) incrementRequests(connectionID string) {
t.concurrentRequestsLock.Lock()
var concurrentRequests uint64
var ok bool
if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID]++
concurrentRequests++
} else {
t.concurrentRequests[connectionID] = 1
concurrentRequests = 1
}
if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok {
t.maxConcurrentRequests[connectionID] = concurrentRequests
t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests))
}
t.concurrentRequestsLock.Unlock()
t.totalRequests.Inc()
t.requestsPerTunnel.WithLabelValues(connectionID).Inc()
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
}
func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
t.concurrentRequestsLock.Lock()
if _, ok := t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID]--
}
t.concurrentRequestsLock.Unlock()
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
}
func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
t.responseByCode.WithLabelValues(code).Inc()
t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
}
func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
t.locationLock.Lock()
defer t.locationLock.Unlock()
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
return
} else if ok {
t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
}
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
t.oldServerLocations[connectionID] = loc
} }

View File

@ -1,121 +0,0 @@
package origin
import (
"strconv"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
// can only be called once
var m = NewTunnelMetrics()
func TestConcurrentRequestsSingleTunnel(t *testing.T) {
routines := 20
var wg sync.WaitGroup
wg.Add(routines)
for i := 0; i < routines; i++ {
go func() {
m.incrementRequests("0")
wg.Done()
}()
}
wg.Wait()
assert.Len(t, m.concurrentRequests, 1)
assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
assert.Len(t, m.maxConcurrentRequests, 1)
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
wg.Add(routines / 2)
for i := 0; i < routines/2; i++ {
go func() {
m.decrementConcurrentRequests("0")
wg.Done()
}()
}
wg.Wait()
assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
}
func TestConcurrentRequestsMultiTunnel(t *testing.T) {
m.concurrentRequests = make(map[string]uint64)
m.maxConcurrentRequests = make(map[string]uint64)
tunnels := 20
var wg sync.WaitGroup
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
// if we have j < i, then tunnel 0 won't have a chance to call incrementRequests
for j := 0; j < i+1; j++ {
id := strconv.Itoa(i)
m.incrementRequests(id)
}
wg.Done()
}(i)
}
wg.Wait()
assert.Len(t, m.concurrentRequests, tunnels)
assert.Len(t, m.maxConcurrentRequests, tunnels)
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
}
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
for j := 0; j < i+1; j++ {
id := strconv.Itoa(i)
m.decrementConcurrentRequests(id)
}
wg.Done()
}(i)
}
wg.Wait()
assert.Len(t, m.concurrentRequests, tunnels)
assert.Len(t, m.maxConcurrentRequests, tunnels)
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, uint64(0), m.concurrentRequests[id])
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
}
}
func TestRegisterServerLocation(t *testing.T) {
tunnels := 20
var wg sync.WaitGroup
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
id := strconv.Itoa(i)
m.registerServerLocation(id, "LHR")
wg.Done()
}(i)
}
wg.Wait()
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, "LHR", m.oldServerLocations[id])
}
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
id := strconv.Itoa(i)
m.registerServerLocation(id, "AUS")
wg.Done()
}(i)
}
wg.Wait()
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, "AUS", m.oldServerLocations[id])
}
}

208
origin/proxy.go Normal file
View File

@ -0,0 +1,208 @@
package origin
import (
"bufio"
"crypto/tls"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
const (
TagHeaderNamePrefix = "Cf-Warp-Tag-"
)
type client struct {
config *ProxyConfig
logger logger.Service
bufferPool *buffer.Pool
}
func NewClient(config *ProxyConfig, logger logger.Service) connection.OriginClient {
return &client{
config: config,
logger: logger,
bufferPool: buffer.NewPool(512 * 1024),
}
}
type ProxyConfig struct {
Client http.RoundTripper
URL *url.URL
TLSConfig *tls.Config
HostHeader string
NoChunkedEncoding bool
Tags []tunnelpogs.Tag
}
func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
incrementRequests()
defer decrementConcurrentRequests()
cfRay := findCfRayHeader(req)
lbProbe := isLBProbeRequest(req)
c.appendTagHeaders(req)
c.logRequest(req, cfRay, lbProbe)
var (
resp *http.Response
err error
)
if isWebsocket {
resp, err = c.proxyWebsocket(w, req)
} else {
resp, err = c.proxyHTTP(w, req)
}
if err != nil {
c.logger.Errorf("HTTP request error: %s", err)
responseByCode.WithLabelValues("502").Inc()
w.WriteErrorResponse(err)
return err
}
c.logResponseOk(resp, cfRay, lbProbe)
return nil
}
func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if c.config.NoChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
req.ContentLength = int64(cLength)
}
}
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
c.setHostHeader(req)
resp, err := c.config.Client.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin")
}
defer resp.Body.Close()
err = w.WriteRespHeaders(resp)
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
if isEventStream(resp) {
//h.observer.Debug("Detected Server-Side Events from Origin")
c.writeEventStream(w, resp.Body)
} else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
buf := c.bufferPool.Get()
defer c.bufferPool.Put(buf)
io.CopyBuffer(w, resp.Body, buf)
}
return resp, nil
}
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
c.setHostHeader(req)
conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig)
if err != nil {
return nil, err
}
defer conn.Close()
err = w.WriteRespHeaders(resp)
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), w)
return resp, nil
}
func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
reader := bufio.NewReader(respBody)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
break
}
w.Write(line)
}
}
func (c *client) setHostHeader(req *http.Request) {
if c.config.HostHeader != "" {
req.Header.Set("Host", c.config.HostHeader)
req.Host = c.config.HostHeader
}
}
func (c *client) appendTagHeaders(r *http.Request) {
for _, tag := range c.config.Tags {
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
}
}
func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool) {
if cfRay != "" {
c.logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else if lbProbe {
c.logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else {
c.logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto)
}
c.logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
if contentLen := r.ContentLength; contentLen == -1 {
c.logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
} else {
c.logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen)
}
}
func (c *client) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
responseByCode.WithLabelValues("200").Inc()
if cfRay != "" {
c.logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
} else if lbProbe {
c.logger.Debugf("Response to Load Balancer health check %s", r.Status)
} else {
c.logger.Infof("%s", r.Status)
}
c.logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
if contentLen := r.ContentLength; contentLen == -1 {
c.logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
} else {
c.logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
}
}
func findCfRayHeader(req *http.Request) string {
return req.Header.Get("Cf-Ray")
}
func isLBProbeRequest(req *http.Request) bool {
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
}
func uint8ToString(input uint8) string {
return strconv.FormatUint(uint64(input), 10)
}
func isEventStream(response *http.Response) bool {
if response.Header.Get("content-type") == "text/event-stream" {
return true
}
return false
}

View File

@ -7,11 +7,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
@ -138,52 +134,3 @@ func (cm *reconnectCredentialManager) RefreshAuth(
return nil, err return nil, err
} }
} }
func ReconnectTunnel(
ctx context.Context,
muxer *h2mux.Muxer,
config *TunnelConfig,
logger logger.Service,
connectionID uint8,
originLocalAddr string,
uuid uuid.UUID,
credentialManager *reconnectCredentialManager,
) error {
token, err := credentialManager.ReconnectToken()
if err != nil {
return err
}
eventDigest, err := credentialManager.EventDigest(connectionID)
if err != nil {
return err
}
connDigest, err := credentialManager.ConnDigest(connectionID)
if err != nil {
return err
}
config.TransportLogger.Debug("initiating RPC stream to reconnect")
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, reconnect)
if err != nil {
return err
}
defer rpcClient.Close()
// Request server info without blocking tunnel registration; must use capnp library directly.
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
registration := rpcClient.ReconnectTunnel(
ctx,
token,
eventDigest,
connDigest,
config.Hostname,
config.RegistrationOptions(connectionID, originLocalAddr, uuid),
)
if registrationErr := registration.DeserializeError(); registrationErr != nil {
// ReconnectTunnel RPC failure
return processRegisterTunnelError(registrationErr, config.Metrics, reconnect)
}
return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager)
}

View File

@ -8,7 +8,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
@ -56,8 +55,7 @@ type Supervisor struct {
logger logger.Service logger logger.Service
reconnectCredentialManager *reconnectCredentialManager reconnectCredentialManager *reconnectCredentialManager
useReconnectToken bool
bufferPool *buffer.Pool
} }
type resolveResult struct { type resolveResult struct {
@ -76,28 +74,33 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
err error err error
) )
if len(config.EdgeAddrs) > 0 { if len(config.EdgeAddrs) > 0 {
edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs) edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs)
} else { } else {
edgeIPs, err = edgediscovery.ResolveEdge(config.Logger) edgeIPs, err = edgediscovery.ResolveEdge(config.Observer)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
useReconnectToken := false
if config.ClassicTunnel != nil {
useReconnectToken = config.ClassicTunnel.UseReconnectToken
}
return &Supervisor{ return &Supervisor{
cloudflaredUUID: cloudflaredUUID, cloudflaredUUID: cloudflaredUUID,
config: config, config: config,
edgeIPs: edgeIPs, edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError), tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{}, tunnelsConnecting: map[int]chan struct{}{},
logger: config.Logger, logger: config.Observer,
reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections), reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
bufferPool: buffer.NewPool(512 * 1024), useReconnectToken: useReconnectToken,
}, nil }, nil
} }
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error { func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.config.Logger logger := s.config.Observer
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil { if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err return err
} }
@ -110,7 +113,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true} refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
var refreshAuthBackoffTimer <-chan time.Time var refreshAuthBackoffTimer <-chan time.Time
if s.config.UseReconnectToken { if s.useReconnectToken {
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil { if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
refreshAuthBackoffTimer = timer refreshAuthBackoffTimer = timer
} else { } else {
@ -227,7 +230,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
} }
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, reconnectCh)
// If the first tunnel disconnects, keep restarting it. // If the first tunnel disconnects, keep restarting it.
edgeErrors := 0 edgeErrors := 0
for s.unusedIPs() { for s.unusedIPs() {
@ -239,7 +242,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
// try the next address if it was a dialError(network problem) or // try the next address if it was a dialError(network problem) or
// dupConnRegisterTunnelError // dupConnRegisterTunnelError
case connection.DialError, dupConnRegisterTunnelError: case edgediscovery.DialError, connection.DupConnRegisterTunnelError:
edgeErrors++ edgeErrors++
default: default:
return return
@ -250,7 +253,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
} }
} }
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, reconnectCh)
} }
} }
@ -269,7 +272,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
if err != nil { if err != nil {
return return
} }
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh) err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, reconnectCh)
} }
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
@ -301,7 +304,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
return nil, err return nil, err
} }
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -311,8 +314,8 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
// This callback is invoked by h2mux when the edge initiates a stream. // This callback is invoked by h2mux when the edge initiates a stream.
return nil // noop return nil // noop
}) })
muxerConfig := s.config.muxerConfig(handler) muxerConfig := s.config.MuxerConfig.H2MuxerConfig(handler, s.logger)
muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams) muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig, h2mux.ActiveStreams)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -323,23 +326,15 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
<-muxer.Shutdown() <-muxer.Shutdown()
}() }()
rpcClient, err := newTunnelRPCClient(ctx, muxer, s.config, authenticate) stream, err := muxer.OpenRPCStream(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rpcClient := connection.NewTunnelServerClient(ctx, stream, s.logger)
defer rpcClient.Close() defer rpcClient.Close()
const arbitraryConnectionID = uint8(0) const arbitraryConnectionID = uint8(0)
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
authResponse, err := rpcClient.Authenticate( return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
ctx,
s.config.OriginCert,
s.config.Hostname,
registrationOptions,
)
if err != nil {
return nil, err
}
return authResponse.Outcome(), nil
} }

View File

@ -5,9 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"net/http" "runtime/debug"
"net/url"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -17,26 +15,22 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
) )
const ( const (
dialTimeout = 15 * time.Second dialTimeout = 15 * time.Second
openStreamTimeout = 30 * time.Second
muxerTimeout = 5 * time.Second muxerTimeout = 5 * time.Second
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
TagHeaderNamePrefix = "Cf-Warp-Tag-"
DuplicateConnectionError = "EDUPCONN" DuplicateConnectionError = "EDUPCONN"
FeatureSerializedHeaders = "serialized_headers" FeatureSerializedHeaders = "serialized_headers"
FeatureQuickReconnects = "quick_reconnects" FeatureQuickReconnects = "quick_reconnects"
@ -52,49 +46,31 @@ const (
) )
type TunnelConfig struct { type TunnelConfig struct {
ConnectionConfig *connection.Config
ProxyConfig *ProxyConfig
BuildInfo *buildinfo.BuildInfo BuildInfo *buildinfo.BuildInfo
ClientID string ClientID string
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
CompressionQuality uint64
EdgeAddrs []string EdgeAddrs []string
GracePeriod time.Duration
HAConnections int HAConnections int
HeartbeatInterval time.Duration
Hostname string
IncidentLookup IncidentLookup IncidentLookup IncidentLookup
IsAutoupdated bool IsAutoupdated bool
IsFreeTunnel bool IsFreeTunnel bool
LBPool string LBPool string
Logger logger.Service Logger logger.Service
TransportLogger logger.Service Observer *connection.Observer
MaxHeartbeats uint64
Metrics *TunnelMetrics
MetricsUpdateFreq time.Duration
OriginCert []byte
ReportedVersion string ReportedVersion string
Retries uint Retries uint
RunFromTerminal bool RunFromTerminal bool
Tags []tunnelpogs.Tag TLSConfig *tls.Config
TlsConfig *tls.Config
WSGI bool
// feature-flag to use new edge reconnect tokens NamedTunnel *connection.NamedTunnelConfig
UseReconnectToken bool ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig
NamedTunnel *NamedTunnelConfig TunnelEventChan chan ui.TunnelEvent
ReplaceExisting bool
TunnelEventChan chan<- ui.TunnelEvent
IngressRules ingress.Ingress IngressRules ingress.Ingress
} }
type dupConnRegisterTunnelError struct{}
var errDuplicationConnection = &dupConnRegisterTunnelError{}
func (e dupConnRegisterTunnelError) Error() string {
return "already connected to this server, trying another address"
}
type muxerShutdownError struct{} type muxerShutdownError struct{}
func (e muxerShutdownError) Error() string { func (e muxerShutdownError) Error() string {
@ -125,18 +101,6 @@ func (e clientRegisterTunnelError) Error() string {
return e.cause.Error() return e.cause.Error()
} }
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
return h2mux.MuxerConfig{
Timeout: muxerTimeout,
Handler: handler,
IsClient: true,
HeartbeatInterval: c.HeartbeatInterval,
MaxHeartbeats: c.MaxHeartbeats,
Logger: c.TransportLogger,
CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality),
}
}
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
policy := tunnelrpc.ExistingTunnelPolicy_balance policy := tunnelrpc.ExistingTunnelPolicy_balance
if c.HAConnections <= 1 && c.LBPool == "" { if c.HAConnections <= 1 && c.LBPool == "" {
@ -148,12 +112,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch), OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch),
ExistingTunnelPolicy: policy, ExistingTunnelPolicy: policy,
PoolName: c.LBPool, PoolName: c.LBPool,
Tags: c.Tags, Tags: c.ProxyConfig.Tags,
ConnectionID: connectionID, ConnectionID: connectionID,
OriginLocalIP: OriginLocalIP, OriginLocalIP: OriginLocalIP,
IsAutoupdated: c.IsAutoupdated, IsAutoupdated: c.IsAutoupdated,
RunFromTerminal: c.RunFromTerminal, RunFromTerminal: c.RunFromTerminal,
CompressionQuality: c.CompressionQuality, CompressionQuality: uint64(c.MuxerConfig.CompressionSetting),
UUID: uuid.String(), UUID: uuid.String(),
Features: c.SupportedFeatures(), Features: c.SupportedFeatures(),
} }
@ -167,8 +131,8 @@ func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAtte
return &tunnelpogs.ConnectionOptions{ return &tunnelpogs.ConnectionOptions{
Client: c.NamedTunnel.Client, Client: c.NamedTunnel.Client,
OriginLocalIP: originIP, OriginLocalIP: originIP,
ReplaceExisting: c.ReplaceExisting, ReplaceExisting: c.ConnectionConfig.ReplaceExisting,
CompressionQuality: uint8(c.CompressionQuality), CompressionQuality: uint8(c.MuxerConfig.CompressionSetting),
NumPreviousAttempts: numPreviousAttempts, NumPreviousAttempts: numPreviousAttempts,
} }
} }
@ -181,35 +145,6 @@ func (c *TunnelConfig) SupportedFeatures() []string {
return features return features
} }
func (c *TunnelConfig) IsTrialTunnel() bool {
return c.Hostname == ""
}
type NamedTunnelConfig struct {
Auth pogs.TunnelAuth
ID uuid.UUID
Client pogs.ClientInfo
Protocol Protocol
}
type Protocol int64
const (
h2muxProtocol Protocol = iota
http2Protocol
)
func ParseProtocol(s string) (Protocol, bool) {
switch s {
case "h2mux":
return h2muxProtocol, true
case "http2":
return http2Protocol, true
default:
return 0, false
}
}
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error { func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
s, err := NewSupervisor(config, cloudflaredID) s, err := NewSupervisor(config, cloudflaredID)
if err != nil { if err != nil {
@ -225,11 +160,11 @@ func ServeTunnelLoop(ctx context.Context,
connectionIndex uint8, connectionIndex uint8,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
bufferPool *buffer.Pool,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
) error { ) error {
config.Metrics.incrementHaConnections() haConnections.Inc()
defer config.Metrics.decrementHaConnections() defer haConnections.Dec()
backoff := BackoffHandler{MaxRetries: config.Retries} backoff := BackoffHandler{MaxRetries: config.Retries}
connectedFuse := h2mux.NewBooleanFuse() connectedFuse := h2mux.NewBooleanFuse()
go func() { go func() {
@ -244,12 +179,10 @@ func ServeTunnelLoop(ctx context.Context,
ctx, ctx,
credentialManager, credentialManager,
config, config,
config.Logger,
addr, connectionIndex, addr, connectionIndex,
connectedFuse, connectedFuse,
&backoff, &backoff,
cloudflaredUUID, cloudflaredUUID,
bufferPool,
reconnectCh, reconnectCh,
) )
if recoverable { if recoverable {
@ -257,7 +190,7 @@ func ServeTunnelLoop(ctx context.Context,
if config.TunnelEventChan != nil { if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting} config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
} }
config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration) config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
backoff.Backoff(ctx) backoff.Backoff(ctx)
continue continue
} }
@ -270,13 +203,11 @@ func ServeTunnel(
ctx context.Context, ctx context.Context,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
config *TunnelConfig, config *TunnelConfig,
logger logger.Service,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionIndex uint8, connectionIndex uint8,
connectedFuse *h2mux.BooleanFuse, fuse *h2mux.BooleanFuse,
backoff *BackoffHandler, backoff *BackoffHandler,
cloudflaredUUID uuid.UUID, cloudflaredUUID uuid.UUID,
bufferPool *buffer.Pool,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
// Treat panics as recoverable errors // Treat panics as recoverable errors
@ -287,6 +218,7 @@ func ServeTunnel(
if !ok { if !ok {
err = fmt.Errorf("ServeTunnel: %v", r) err = fmt.Errorf("ServeTunnel: %v", r)
} }
err = errors.Wrapf(err, "stack trace: %s", string(debug.Stack()))
recoverable = true recoverable = true
} }
}() }()
@ -298,203 +230,107 @@ func ServeTunnel(
}() }()
} }
connectionTag := uint8ToString(connectionIndex) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr)
if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol {
return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh)
}
// Returns error from parsing the origin URL or handshake errors
handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool)
if err != nil { if err != nil {
switch err.(type) {
case connection.DialError:
logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err)
case h2mux.MuxerHandshakeError:
logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err)
default:
logger.Errorf("Connection %d failed: %s", connectionIndex, err)
return err, false
}
return err, true return err, true
} }
connectedFuse := &connectedFuse{
fuse: fuse,
backoff: backoff,
}
if config.NamedTunnel != nil && config.NamedTunnel.Protocol == connection.HTTP2 {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh)
}
return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh)
}
func ServeH2mux(
ctx context.Context,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
edgeConn net.Conn,
connectionIndex uint8,
connectedFuse *connectedFuse,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
// Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
if err != nil {
return err, recoverable
}
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() (err error) { errGroup.Go(func() (err error) {
defer func() {
if err == nil {
connectedFuse.Fuse(true)
backoff.SetGracePeriod()
}
}()
if config.UseReconnectToken && connectedFuse.Value() {
err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
if err == nil {
return nil
}
// log errors and proceed to RegisterTunnel
logger.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", connectionIndex, err)
}
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
})
errGroup.Go(func() error {
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
for {
select {
case <-serveCtx.Done():
// UnregisterTunnel blocks until the RPC call returns
if connectedFuse.Value() {
if config.NamedTunnel != nil { if config.NamedTunnel != nil {
_ = UnregisterConnection(ctx, handler.muxer, config) connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
} else { return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
_ = UnregisterTunnel(handler.muxer, config)
}
}
handler.muxer.Shutdown()
return nil
case <-updateMetricsTickC:
handler.UpdateMetrics(connectionTag)
}
} }
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
}) })
errGroup.Go(func() error { errGroup.Go(listenReconnect(serveCtx, reconnectCh))
for {
select {
case reconnect := <-reconnectCh:
return &reconnect
case <-serveCtx.Done():
return nil
}
}
})
errGroup.Go(func() error {
// All routines should stop when muxer finish serving. When muxer is shutdown
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
// here to notify other routines to stop
err := handler.muxer.Serve(serveCtx)
if err == nil {
return muxerShutdownError{}
}
return err
})
err = errGroup.Wait() err = errGroup.Wait()
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case *dupConnRegisterTunnelError: case *connection.DupConnRegisterTunnelError:
// don't retry this connection anymore, let supervisor pick new a address // don't retry this connection anymore, let supervisor pick new a address
return err, false return err, false
case *serverRegisterTunnelError: case *serverRegisterTunnelError:
logger.Errorf("Register tunnel error from server side: %s", err.cause) config.Logger.Errorf("Register tunnel error from server side: %s", err.cause)
// Don't send registration error return from server to Sentry. They are // Don't send registration error return from server to Sentry. They are
// logged on server side // logged on server side
if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 { if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
logger.Error(activeIncidentsMsg(incidents)) config.Logger.Error(activeIncidentsMsg(incidents))
} }
return err.cause, !err.permanent return err.cause, !err.permanent
case *clientRegisterTunnelError: case *clientRegisterTunnelError:
logger.Errorf("Register tunnel error on client side: %s", err.cause) config.Logger.Errorf("Register tunnel error on client side: %s", err.cause)
return err, true return err, true
case *muxerShutdownError: case *muxerShutdownError:
logger.Info("Muxer shutdown") config.Logger.Info("Muxer shutdown")
return err, true return err, true
case *ReconnectSignal: case *ReconnectSignal:
logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay) config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
err.DelayBeforeReconnect() err.DelayBeforeReconnect()
return err, true return err, true
default: default:
if err == context.Canceled { if err == context.Canceled {
logger.Debugf("Serve tunnel error: %s", err) config.Logger.Debugf("Serve tunnel error: %s", err)
return err, false return err, false
} }
logger.Errorf("Serve tunnel error: %s", err) config.Logger.Errorf("Serve tunnel error: %s", err)
return err, true return err, true
} }
} }
return nil, true return nil, true
} }
func RegisterConnectionWithH2Mux( func ServeHTTP2(
ctx context.Context,
muxer *h2mux.Muxer,
config *TunnelConfig,
connectionIndex uint8,
originLocalAddr string,
numPreviousAttempts uint8,
) error {
const registerConnection = "registerConnection"
config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, registerConnection)
if err != nil {
return err
}
defer rpcClient.Close()
conn, err := rpcClient.RegisterConnection(
ctx,
config.NamedTunnel.Auth,
config.NamedTunnel.ID,
connectionIndex,
config.ConnectionOptions(originLocalAddr, numPreviousAttempts),
)
if err != nil {
if err.Error() == DuplicateConnectionError {
config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc()
return errDuplicationConnection
}
config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc()
return serverRegistrationErrorFromRPC(err)
}
config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc()
config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID)
// If launch-ui flag is set, send connect msg
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Connected, Location: conn.Location}
}
return nil
}
func ServeNamedTunnel(
ctx context.Context, ctx context.Context,
config *TunnelConfig, config *TunnelConfig,
tlsServerConn net.Conn,
connOptions *tunnelpogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
addr *net.TCPAddr, connectedFuse connection.ConnectedFuse,
connectedFuse *h2mux.BooleanFuse,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
if err != nil {
return err, true
}
cfdServer, err := newHTTP2Server(config, connIndex, tlsServerConn.LocalAddr(), connectedFuse)
if err != nil { if err != nil {
return err, false return err, false
} }
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
cfdServer.serve(serveCtx, tlsServerConn) server.Serve(serveCtx)
return fmt.Errorf("Connection with edge closed") return fmt.Errorf("Connection with edge closed")
}) })
errGroup.Go(func() error { errGroup.Go(listenReconnect(serveCtx, reconnectCh))
select {
case reconnect := <-reconnectCh:
return &reconnect
case <-serveCtx.Done():
return nil
}
})
err = errGroup.Wait() err = errGroup.Wait()
if err != nil { if err != nil {
@ -503,229 +339,29 @@ func ServeNamedTunnel(
return nil, false return nil, false
} }
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError { func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error {
if retryable, ok := err.(*tunnelpogs.RetryableError); ok { return func() error {
return &serverRegisterTunnelError{ select {
cause: retryable.Unwrap(), case reconnect := <-reconnectCh:
permanent: false, return &reconnect
} case <-ctx.Done():
}
return &serverRegisterTunnelError{
cause: err,
permanent: true,
}
}
func UnregisterConnection(
ctx context.Context,
muxer *h2mux.Muxer,
config *TunnelConfig,
) error {
config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
if err != nil {
// RPC stream open error
return err
}
defer rpcClient.Close()
return rpcClient.UnregisterConnection(ctx)
}
func RegisterTunnel(
ctx context.Context,
credentialManager *reconnectCredentialManager,
muxer *h2mux.Muxer,
config *TunnelConfig,
logger logger.Service,
connectionID uint8,
originLocalIP string,
uuid uuid.UUID,
) error {
config.TransportLogger.Debug("initiating RPC stream to register")
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel}
}
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
if err != nil {
return err
}
defer rpcClient.Close()
// Request server info without blocking tunnel registration; must use capnp library directly.
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
registration := rpcClient.RegisterTunnel(
ctx,
config.OriginCert,
config.Hostname,
config.RegistrationOptions(connectionID, originLocalIP, uuid),
)
if registrationErr := registration.DeserializeError(); registrationErr != nil {
// RegisterTunnel RPC failure
return processRegisterTunnelError(registrationErr, config.Metrics, register)
}
// Send free tunnel URL to UI
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.SetUrl, Url: registration.Url}
}
credentialManager.SetEventDigest(connectionID, registration.EventDigest)
return processRegistrationSuccess(config, logger, connectionID, registration, register, credentialManager)
}
func processRegistrationSuccess(
config *TunnelConfig,
logger logger.Service,
connectionID uint8,
registration *tunnelpogs.TunnelRegistration,
name rpcName,
credentialManager *reconnectCredentialManager,
) error {
for _, logLine := range registration.LogLines {
logger.Info(logLine)
}
if registration.TunnelID != "" {
config.Metrics.tunnelsHA.AddTunnelID(connectionID, registration.TunnelID)
logger.Infof("Each HA connection's tunnel IDs: %v", config.Metrics.tunnelsHA.String())
}
// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
if config.TunnelEventChan == nil {
if config.IsTrialTunnel() {
if registrationURL, err := url.Parse(registration.Url); err == nil {
for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) {
logger.Info(line)
}
} else {
logger.Error("Failed to connect tunnel, please try again.")
return fmt.Errorf("empty URL in response from Cloudflare edge")
}
}
}
credentialManager.SetConnDigest(connectionID, registration.ConnDigest)
config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
config.Metrics.regSuccess.WithLabelValues(string(name)).Inc()
return nil return nil
} }
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name rpcName) error {
if err.Error() == DuplicateConnectionError {
metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
return errDuplicationConnection
}
metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
return serverRegisterTunnelError{
cause: err,
permanent: err.IsPermanent(),
} }
} }
func UnregisterTunnel(muxer *h2mux.Muxer, config *TunnelConfig) error { type connectedFuse struct {
config.TransportLogger.Debug("initiating RPC stream to unregister") fuse *h2mux.BooleanFuse
ctx := context.Background() backoff *BackoffHandler
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, unregister)
if err != nil {
// RPC stream open error
return err
}
defer rpcClient.Close()
// gracePeriod is encoded in int64 using capnproto
return rpcClient.UnregisterTunnel(ctx, config.GracePeriod.Nanoseconds())
} }
func LogServerInfo( func (cf *connectedFuse) Connected() {
promise tunnelrpc.ServerInfo_Promise, cf.fuse.Fuse(true)
connectionID uint8, cf.backoff.SetGracePeriod()
metrics *TunnelMetrics,
logger logger.Service,
tunnelEventChan chan<- ui.TunnelEvent,
) {
serverInfoMessage, err := promise.Struct()
if err != nil {
logger.Errorf("Failed to retrieve server information: %s", err)
return
}
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
if err != nil {
logger.Errorf("Failed to retrieve server information: %s", err)
return
}
// If launch-ui flag is set, send connect msg
if tunnelEventChan != nil {
tunnelEventChan <- ui.TunnelEvent{Index: connectionID, EventType: ui.Connected, Location: serverInfo.LocationName}
}
logger.Infof("Connected to %s", serverInfo.LocationName)
metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
} }
func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) { func (cf *connectedFuse) IsConnected() bool {
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { return cf.fuse.Value()
req.Header.Set("Host", hostHeader)
req.Host = hostHeader
}
dialler, ok := rule.Service.(websocket.Dialler)
if !ok {
return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service)
}
conn, response, err := websocket.ClientConnect(req, dialler)
if err != nil {
return nil, err
}
defer conn.Close()
err = wsResp.WriteRespHeaders(response)
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), wsResp)
return response, nil
}
func uint8ToString(input uint8) string {
return strconv.FormatUint(uint64(input), 10)
}
// Print out the given lines in a nice ASCII box.
func asciiBox(lines []string, padding int) (box []string) {
maxLen := maxLen(lines)
spacer := strings.Repeat(" ", padding)
border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+"
box = append(box, border)
for _, line := range lines {
box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|")
}
box = append(box, border)
return
}
func maxLen(lines []string) int {
max := 0
for _, line := range lines {
if len(line) > max {
max = len(line)
}
}
return max
}
func trialZoneMsg(url string) []string {
return []string{
"Your free tunnel has started! Visit it:",
" " + url,
}
} }
func activeIncidentsMsg(incidents []Incident) string { func activeIncidentsMsg(incidents []Incident) string {
@ -741,26 +377,3 @@ func activeIncidentsMsg(incidents []Incident) string {
return preamble + " " + strings.Join(incidentStrings, "; ") return preamble + " " + strings.Join(incidentStrings, "; ")
} }
func findCfRayHeader(h1 *http.Request) string {
return h1.Header.Get("Cf-Ray")
}
func isLBProbeRequest(req *http.Request) bool {
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
}
func newTunnelRPCClient(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, rpcName rpcName) (tunnelpogs.TunnelServer_PogsClient, error) {
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
defer openStreamCancel()
stream, err := muxer.OpenRPCStream(openStreamCtx)
if err != nil {
return tunnelpogs.TunnelServer_PogsClient{}, err
}
rpcClient, err := connection.NewTunnelRPCClient(ctx, stream, config.TransportLogger)
if err != nil {
// RPC stream open error
return tunnelpogs.TunnelServer_PogsClient{}, newRPCError(err, config.Metrics.rpcFail, rpcName)
}
return rpcClient, nil
}

View File

@ -17,11 +17,6 @@ import (
const ( const (
OriginCAPoolFlag = "origin-ca-pool" OriginCAPoolFlag = "origin-ca-pool"
CaCertFlag = "cacert" CaCertFlag = "cacert"
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
) )
// CertReloader can load and reload a TLS certificate from a particular filepath. // CertReloader can load and reload a TLS certificate from a particular filepath.
@ -123,16 +118,12 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
return certPool, nil return certPool, nil
} }
func CreateTunnelConfig(c *cli.Context, isNamedTunnel bool) (*tls.Config, error) { func CreateTunnelConfig(c *cli.Context, serverName string) (*tls.Config, error) {
var rootCAs []string var rootCAs []string
if c.String(CaCertFlag) != "" { if c.String(CaCertFlag) != "" {
rootCAs = append(rootCAs, c.String(CaCertFlag)) rootCAs = append(rootCAs, c.String(CaCertFlag))
} }
serverName := edgeH2muxTLSServerName
if isNamedTunnel {
serverName = edgeH2TLSServerName
}
userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName} userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}
tlsConfig, err := GetConfig(userConfig) tlsConfig, err := GetConfig(userConfig)
if err != nil { if err != nil {

View File

@ -93,6 +93,11 @@ type RegistrationServer_PogsClient struct {
Conn *rpc.Conn Conn *rpc.Conn
} }
func (c RegistrationServer_PogsClient) Close() error {
c.Client.Close()
return c.Conn.Close()
}
func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) { func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
client := tunnelrpc.TunnelServer{Client: c.Client} client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error { promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {

View File

@ -66,7 +66,15 @@ func ValidateHostname(hostname string) (string, error) {
// but when it does not, the path is preserved: // but when it does not, the path is preserved:
// ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/" // ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
// This is arguably a bug, but changing it might break some cloudflared users. // This is arguably a bug, but changing it might break some cloudflared users.
func ValidateUrl(originUrl string) (string, error) { func ValidateUrl(originUrl string) (*url.URL, error) {
urlStr, err := validateUrlString(originUrl)
if err != nil {
return nil, err
}
return url.Parse(urlStr)
}
func validateUrlString(originUrl string) (string, error) {
if originUrl == "" { if originUrl == "" {
return "", fmt.Errorf("URL should not be empty") return "", fmt.Errorf("URL should not be empty")
} }
@ -157,12 +165,8 @@ func validateIP(scheme, host, port string) (string, error) {
return fmt.Sprintf("%s://%s", scheme, host), nil return fmt.Sprintf("%s://%s", scheme, host), nil
} }
func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error { // originURL shouldn't be a pointer, because this function might change the scheme
parsedURL, err := url.Parse(originURL) func ValidateHTTPService(originURL url.URL, hostname string, transport http.RoundTripper) error {
if err != nil {
return err
}
client := &http.Client{ client := &http.Client{
Transport: transport, Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
@ -171,7 +175,7 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
Timeout: validationTimeout, Timeout: validationTimeout,
} }
initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil) initialRequest, err := http.NewRequest("GET", originURL.String(), nil)
if err != nil { if err != nil {
return err return err
} }
@ -183,10 +187,10 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
} }
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
oldScheme := parsedURL.Scheme oldScheme := originURL.Scheme
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) originURL.Scheme = toggleProtocol(originURL.Scheme)
secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil) secondRequest, err := http.NewRequest("GET", originURL.String(), nil)
if err != nil { if err != nil {
return err return err
} }
@ -195,12 +199,12 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
if secondErr == nil { // Worked this time--advise the user to switch protocols if secondErr == nil { // Worked this time--advise the user to switch protocols
resp.Body.Close() resp.Body.Close()
return errors.Errorf( return errors.Errorf(
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s", "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v",
parsedURL.Host, originURL.Host,
oldScheme, oldScheme,
parsedURL.Scheme, originURL.Scheme,
initialErr, initialErr,
parsedURL, originURL,
) )
} }
@ -224,12 +228,12 @@ type Access struct {
} }
func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) { func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
domainURL, err := ValidateUrl(domain) domainURL, err := validateUrlString(domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
issuerURL, err := ValidateUrl(issuer) issuerURL, err := validateUrlString(issuer)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -101,7 +101,7 @@ func TestValidateUrl(t *testing.T) {
for i, testCase := range testCases { for i, testCase := range testCases {
validUrl, err := ValidateUrl(testCase.input) validUrl, err := ValidateUrl(testCase.input)
assert.NoError(t, err, "test case %v", i) assert.NoError(t, err, "test case %v", i)
assert.Equal(t, testCase.expectedOutput, validUrl, "test case %v", i) assert.Equal(t, testCase.expectedOutput, validUrl.String(), "test case %v", i)
} }
validUrl, err := ValidateUrl("") validUrl, err := ValidateUrl("")
@ -123,7 +123,7 @@ func TestToggleProtocol(t *testing.T) {
// Happy path 1: originURL is HTTP, and HTTP connections work // Happy path 1: originURL is HTTP, and HTTP connections work
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
originURL := "http://127.0.0.1/" originURL := mustParse(t, "http://127.0.0.1/")
hostname := "example.com" hostname := "example.com"
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@ -151,7 +151,7 @@ func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
// Happy path 2: originURL is HTTPS, and HTTPS connections work // Happy path 2: originURL is HTTPS, and HTTPS connections work
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
originURL := "https://127.0.0.1/" originURL := mustParse(t, "https://127.0.0.1:1234/")
hostname := "example.com" hostname := "example.com"
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@ -179,7 +179,7 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
// Error path 1: originURL is HTTPS, but HTTP connections work // Error path 1: originURL is HTTPS, but HTTP connections work
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
originURL := "https://127.0.0.1:1234/" originURL := mustParse(t, "https://127.0.0.1:1234/")
hostname := "example.com" hostname := "example.com"
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@ -207,10 +207,13 @@ func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
// Error path 2: originURL is HTTP, but HTTPS connections work // Error path 2: originURL is HTTP, but HTTPS connections work
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
originURL := "http://127.0.0.1:1234/" originURLWithPort := url.URL{
Scheme: "http",
Host: "127.0.0.1:1234",
}
hostname := "example.com" hostname := "example.com"
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname) assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" { if req.URL.Scheme == "http" {
return nil, assert.AnError return nil, assert.AnError
@ -221,7 +224,7 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
panic("Shouldn't reach here") panic("Shouldn't reach here")
}))) })))
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname) assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" { if req.URL.Scheme == "http" {
return nil, assert.AnError return nil, assert.AnError
@ -250,12 +253,14 @@ func TestValidateHTTPService_NoFollowRedirects(t *testing.T) {
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer redirectServer.Close() defer redirectServer.Close()
assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport)) redirectServerURL, err := url.Parse(redirectServer.URL)
assert.NoError(t, err)
assert.NoError(t, ValidateHTTPService(*redirectServerURL, hostname, redirectClient.Transport))
} }
// Ensure validation times out when origin URL is nonresponsive // Ensure validation times out when origin URL is nonresponsive
func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) { func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
originURL := "http://127.0.0.1/" originURL := mustParse(t, "http://127.0.0.1/")
hostname := "example.com" hostname := "example.com"
oldValidationTimeout := validationTimeout oldValidationTimeout := validationTimeout
defer func() { defer func() {
@ -371,3 +376,9 @@ func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *h
return server, client, nil return server, client, nil
} }
func mustParse(t *testing.T, originURL string) url.URL {
parsedURL, err := url.Parse(originURL)
assert.NoError(t, err)
return *parsedURL
}