TUN-3462: Refactor cloudflared to separate origin from connection
This commit is contained in:
parent
a5a5b93b64
commit
9ac40dcf04
|
@ -2,7 +2,6 @@ package access
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/carrier"
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
|
@ -17,16 +16,11 @@ import (
|
||||||
|
|
||||||
// StartForwarder starts a client side websocket forward
|
// StartForwarder starts a client side websocket forward
|
||||||
func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, logger logger.Service) error {
|
func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, logger logger.Service) error {
|
||||||
validURLString, err := validation.ValidateUrl(forwarder.Listener)
|
validURL, err := validation.ValidateUrl(forwarder.Listener)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "error validating origin URL")
|
return errors.Wrap(err, "error validating origin URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
validURL, err := url.Parse(validURLString)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "error parsing origin URL")
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the headers from the config file and add to the request
|
// get the headers from the config file and add to the request
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
if forwarder.TokenClientID != "" {
|
if forwarder.TokenClientID != "" {
|
||||||
|
@ -106,12 +100,7 @@ func ssh(c *cli.Context) error {
|
||||||
wsConn := carrier.NewWSConnection(logger, false)
|
wsConn := carrier.NewWSConnection(logger, false)
|
||||||
|
|
||||||
if c.NArg() > 0 || c.IsSet(sshURLFlag) {
|
if c.NArg() > 0 || c.IsSet(sshURLFlag) {
|
||||||
localForwarder, err := config.ValidateUrl(c, true)
|
forwarder, err := config.ValidateUrl(c, true)
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error validating origin URL: %s", err)
|
|
||||||
return errors.Wrap(err, "error validating origin URL")
|
|
||||||
}
|
|
||||||
forwarder, err := url.Parse(localForwarder)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error validating origin URL: %s", err)
|
logger.Errorf("Error validating origin URL: %s", err)
|
||||||
return errors.Wrap(err, "error validating origin URL")
|
return errors.Wrap(err, "error validating origin URL")
|
||||||
|
|
|
@ -2,6 +2,7 @@ package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -189,11 +190,11 @@ func ValidateUnixSocket(c *cli.Context) (string, error) {
|
||||||
|
|
||||||
// ValidateUrl will validate url flag correctness. It can be either from --url or argument
|
// ValidateUrl will validate url flag correctness. It can be either from --url or argument
|
||||||
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
|
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
|
||||||
func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) {
|
func ValidateUrl(c *cli.Context, allowFromArgs bool) (*url.URL, error) {
|
||||||
var url = c.String("url")
|
var url = c.String("url")
|
||||||
if allowFromArgs && c.NArg() > 0 {
|
if allowFromArgs && c.NArg() > 0 {
|
||||||
if c.IsSet("url") {
|
if c.IsSet("url") {
|
||||||
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
|
return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
|
||||||
}
|
}
|
||||||
url = c.Args().Get(0)
|
url = c.Args().Get(0)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/dbconnect"
|
"github.com/cloudflare/cloudflared/dbconnect"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
@ -247,7 +248,7 @@ func StartServer(
|
||||||
version string,
|
version string,
|
||||||
shutdownC,
|
shutdownC,
|
||||||
graceShutdownC chan struct{},
|
graceShutdownC chan struct{},
|
||||||
namedTunnel *origin.NamedTunnelConfig,
|
namedTunnel *connection.NamedTunnelConfig,
|
||||||
log logger.Service,
|
log logger.Service,
|
||||||
isUIEnabled bool,
|
isUIEnabled bool,
|
||||||
) error {
|
) error {
|
||||||
|
@ -366,7 +367,7 @@ func StartServer(
|
||||||
return errors.Wrap(err, "error setting up transport logger")
|
return errors.Wrap(err, "error setting up transport logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel)
|
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel, isUIEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -386,10 +387,6 @@ func StartServer(
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if isUIEnabled {
|
if isUIEnabled {
|
||||||
const tunnelEventChanBufferSize = 16
|
|
||||||
tunnelEventChan := make(chan ui.TunnelEvent, tunnelEventChanBufferSize)
|
|
||||||
tunnelConfig.TunnelEventChan = tunnelEventChan
|
|
||||||
|
|
||||||
tunnelInfo := ui.NewUIModel(
|
tunnelInfo := ui.NewUIModel(
|
||||||
version,
|
version,
|
||||||
hostname,
|
hostname,
|
||||||
|
@ -402,7 +399,7 @@ func StartServer(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tunnelInfo.LaunchUI(ctx, log, logLevels, tunnelEventChan)
|
tunnelInfo.LaunchUI(ctx, log, logLevels, tunnelConfig.TunnelEventChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log)
|
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), log)
|
||||||
|
@ -986,7 +983,7 @@ func configureLoggingFlags(shouldHide bool) []cli.Flag {
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "transport-loglevel",
|
Name: "transport-loglevel",
|
||||||
Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel
|
Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel
|
||||||
Value: "fatal",
|
Value: "info",
|
||||||
Usage: "Transport logging level(previously called protocol logging level) {fatal, error, info, debug}",
|
Usage: "Transport logging level(previously called protocol logging level) {fatal, error, info, debug}",
|
||||||
EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL", "TUNNEL_TRANSPORT_LOGLEVEL"},
|
EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL", "TUNNEL_TRANSPORT_LOGLEVEL"},
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
|
|
|
@ -9,6 +9,9 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/origin"
|
"github.com/cloudflare/cloudflared/origin"
|
||||||
|
@ -154,10 +157,10 @@ func prepareTunnelConfig(
|
||||||
version string,
|
version string,
|
||||||
logger logger.Service,
|
logger logger.Service,
|
||||||
transportLogger logger.Service,
|
transportLogger logger.Service,
|
||||||
namedTunnel *origin.NamedTunnelConfig,
|
namedTunnel *connection.NamedTunnelConfig,
|
||||||
|
uiIsEnabled bool,
|
||||||
) (*origin.TunnelConfig, error) {
|
) (*origin.TunnelConfig, error) {
|
||||||
isNamedTunnel := namedTunnel != nil
|
isNamedTunnel := namedTunnel != nil
|
||||||
compatibilityMode := !isNamedTunnel
|
|
||||||
|
|
||||||
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -189,10 +192,11 @@ func prepareTunnelConfig(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelMetrics := origin.NewTunnelMetrics()
|
var (
|
||||||
|
ingressRules ingress.Ingress
|
||||||
var ingressRules ingress.Ingress
|
classicTunnel *connection.ClassicTunnelConfig
|
||||||
if namedTunnel != nil {
|
)
|
||||||
|
if isNamedTunnel {
|
||||||
clientUUID, err := uuid.NewRandom()
|
clientUUID, err := uuid.NewRandom()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "can't generate clientUUID")
|
return nil, errors.Wrap(err, "can't generate clientUUID")
|
||||||
|
@ -210,6 +214,13 @@ func prepareTunnelConfig(
|
||||||
if !ingressRules.IsEmpty() && c.IsSet("url") {
|
if !ingressRules.IsEmpty() && c.IsSet("url") {
|
||||||
return nil, ingress.ErrURLIncompatibleWithIngress
|
return nil, ingress.ErrURLIncompatibleWithIngress
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
classicTunnel = &connection.ClassicTunnelConfig{
|
||||||
|
Hostname: hostname,
|
||||||
|
OriginCert: originCert,
|
||||||
|
// turn off use of reconnect token and auth refresh when using named tunnels
|
||||||
|
UseReconnectToken: !isNamedTunnel && c.Bool("use-reconnect-token"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert single-origin configuration into multi-origin configuration.
|
// Convert single-origin configuration into multi-origin configuration.
|
||||||
|
@ -220,43 +231,71 @@ func prepareTunnelConfig(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, isNamedTunnel)
|
protocol := determineProtocol(namedTunnel)
|
||||||
|
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, protocol.ServerName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("unable to create TLS config to connect with edge: %s", err)
|
logger.Errorf("unable to create TLS config to connect with edge: %s", err)
|
||||||
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
proxyConfig := &origin.ProxyConfig{
|
||||||
|
Client: httpTransport,
|
||||||
|
URL: originURL,
|
||||||
|
TLSConfig: httpTransport.TLSClientConfig,
|
||||||
|
HostHeader: c.String("http-host-header"),
|
||||||
|
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
|
||||||
|
Tags: tags,
|
||||||
|
}
|
||||||
|
originClient := origin.NewClient(proxyConfig, logger)
|
||||||
|
transportConfig := &connection.Config{
|
||||||
|
OriginClient: originClient,
|
||||||
|
GracePeriod: c.Duration("grace-period"),
|
||||||
|
ReplaceExisting: c.Bool("force"),
|
||||||
|
}
|
||||||
|
muxerConfig := &connection.MuxerConfig{
|
||||||
|
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
||||||
|
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
||||||
|
CompressionSetting: h2mux.CompressionSetting(c.Uint64("compression-quality")),
|
||||||
|
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
||||||
|
}
|
||||||
|
|
||||||
|
var tunnelEventChan chan ui.TunnelEvent
|
||||||
|
if uiIsEnabled {
|
||||||
|
tunnelEventChan = make(chan ui.TunnelEvent, 16)
|
||||||
|
}
|
||||||
|
|
||||||
return &origin.TunnelConfig{
|
return &origin.TunnelConfig{
|
||||||
|
ConnectionConfig: transportConfig,
|
||||||
|
ProxyConfig: proxyConfig,
|
||||||
BuildInfo: buildInfo,
|
BuildInfo: buildInfo,
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
CompressionQuality: c.Uint64("compression-quality"),
|
|
||||||
EdgeAddrs: c.StringSlice("edge"),
|
EdgeAddrs: c.StringSlice("edge"),
|
||||||
GracePeriod: c.Duration("grace-period"),
|
|
||||||
HAConnections: c.Int("ha-connections"),
|
HAConnections: c.Int("ha-connections"),
|
||||||
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
|
||||||
Hostname: hostname,
|
|
||||||
IncidentLookup: origin.NewIncidentLookup(),
|
IncidentLookup: origin.NewIncidentLookup(),
|
||||||
IsAutoupdated: c.Bool("is-autoupdated"),
|
IsAutoupdated: c.Bool("is-autoupdated"),
|
||||||
IsFreeTunnel: isFreeTunnel,
|
IsFreeTunnel: isFreeTunnel,
|
||||||
LBPool: c.String("lb-pool"),
|
LBPool: c.String("lb-pool"),
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
TransportLogger: transportLogger,
|
Observer: connection.NewObserver(transportLogger, tunnelEventChan, protocol),
|
||||||
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
|
||||||
Metrics: tunnelMetrics,
|
|
||||||
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
|
||||||
OriginCert: originCert,
|
|
||||||
ReportedVersion: version,
|
ReportedVersion: version,
|
||||||
Retries: c.Uint("retries"),
|
Retries: c.Uint("retries"),
|
||||||
RunFromTerminal: isRunningFromTerminal(),
|
RunFromTerminal: isRunningFromTerminal(),
|
||||||
Tags: tags,
|
TLSConfig: toEdgeTLSConfig,
|
||||||
TlsConfig: toEdgeTLSConfig,
|
|
||||||
NamedTunnel: namedTunnel,
|
NamedTunnel: namedTunnel,
|
||||||
ReplaceExisting: c.Bool("force"),
|
ClassicTunnel: classicTunnel,
|
||||||
|
MuxerConfig: muxerConfig,
|
||||||
|
TunnelEventChan: tunnelEventChan,
|
||||||
IngressRules: ingressRules,
|
IngressRules: ingressRules,
|
||||||
// turn off use of reconnect token and auth refresh when using named tunnels
|
|
||||||
UseReconnectToken: compatibilityMode && c.Bool("use-reconnect-token"),
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRunningFromTerminal() bool {
|
func isRunningFromTerminal() bool {
|
||||||
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func determineProtocol(namedTunnel *connection.NamedTunnelConfig) connection.Protocol {
|
||||||
|
if namedTunnel != nil {
|
||||||
|
return namedTunnel.Protocol
|
||||||
|
}
|
||||||
|
return connection.H2mux
|
||||||
|
}
|
||||||
|
|
|
@ -14,8 +14,8 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/certutil"
|
"github.com/cloudflare/cloudflared/certutil"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/origin"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflared/tunnelstore"
|
"github.com/cloudflare/cloudflared/tunnelstore"
|
||||||
)
|
)
|
||||||
|
@ -260,7 +260,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, ok := origin.ParseProtocol(sc.c.String("protocol"))
|
protocol, ok := connection.ParseProtocol(sc.c.String("protocol"))
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("%s is not valid protocol. %s", sc.c.String("protocol"), availableProtocol)
|
return fmt.Errorf("%s is not valid protocol. %s", sc.c.String("protocol"), availableProtocol)
|
||||||
}
|
}
|
||||||
|
@ -269,7 +269,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
|
||||||
version,
|
version,
|
||||||
shutdownC,
|
shutdownC,
|
||||||
graceShutdownC,
|
graceShutdownC,
|
||||||
&origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol},
|
&connection.NamedTunnelConfig{Auth: *credentials, ID: tunnelID, Protocol: protocol},
|
||||||
sc.logger,
|
sc.logger,
|
||||||
sc.isUIEnabled,
|
sc.isUIEnabled,
|
||||||
)
|
)
|
||||||
|
|
|
@ -78,17 +78,20 @@ var (
|
||||||
Name: "credentials-file",
|
Name: "credentials-file",
|
||||||
Aliases: []string{credFileFlagAlias},
|
Aliases: []string{credFileFlagAlias},
|
||||||
Usage: "File path of tunnel credentials",
|
Usage: "File path of tunnel credentials",
|
||||||
|
EnvVars: []string{"TUNNEL_CRED_FILE"},
|
||||||
})
|
})
|
||||||
forceDeleteFlag = &cli.BoolFlag{
|
forceDeleteFlag = &cli.BoolFlag{
|
||||||
Name: "force",
|
Name: "force",
|
||||||
Aliases: []string{"f"},
|
Aliases: []string{"f"},
|
||||||
Usage: "Allows you to delete a tunnel, even if it has active connections.",
|
Usage: "Allows you to delete a tunnel, even if it has active connections.",
|
||||||
|
EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"},
|
||||||
}
|
}
|
||||||
selectProtocolFlag = &cli.StringFlag{
|
selectProtocolFlag = &cli.StringFlag{
|
||||||
Name: "protocol",
|
Name: "protocol",
|
||||||
Value: "h2mux",
|
Value: "h2mux",
|
||||||
Aliases: []string{"p"},
|
Aliases: []string{"p"},
|
||||||
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol),
|
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", availableProtocol),
|
||||||
|
EnvVars: []string{"TUNNEL_TRANSPORT_PROTOCOL"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,91 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
|
||||||
|
edgeH2muxTLSServerName = "cftunnel.com"
|
||||||
|
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
||||||
|
edgeH2TLSServerName = "h2.cftunnel.com"
|
||||||
|
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
OriginClient OriginClient
|
||||||
|
GracePeriod time.Duration
|
||||||
|
ReplaceExisting bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type NamedTunnelConfig struct {
|
||||||
|
Auth pogs.TunnelAuth
|
||||||
|
ID uuid.UUID
|
||||||
|
Client pogs.ClientInfo
|
||||||
|
Protocol Protocol
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClassicTunnelConfig struct {
|
||||||
|
Hostname string
|
||||||
|
OriginCert []byte
|
||||||
|
// feature-flag to use new edge reconnect tokens
|
||||||
|
UseReconnectToken bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClassicTunnelConfig) IsTrialZone() bool {
|
||||||
|
return c.Hostname == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type Protocol int64
|
||||||
|
|
||||||
|
const (
|
||||||
|
H2mux Protocol = iota
|
||||||
|
HTTP2
|
||||||
|
)
|
||||||
|
|
||||||
|
func ParseProtocol(s string) (Protocol, bool) {
|
||||||
|
switch s {
|
||||||
|
case "h2mux":
|
||||||
|
return H2mux, true
|
||||||
|
case "http2":
|
||||||
|
return HTTP2, true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Protocol) ServerName() string {
|
||||||
|
switch p {
|
||||||
|
case H2mux:
|
||||||
|
return edgeH2muxTLSServerName
|
||||||
|
case HTTP2:
|
||||||
|
return edgeH2TLSServerName
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type OriginClient interface {
|
||||||
|
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseWriter interface {
|
||||||
|
WriteRespHeaders(*http.Response) error
|
||||||
|
WriteErrorResponse(error)
|
||||||
|
io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectedFuse interface {
|
||||||
|
Connected()
|
||||||
|
IsConnected() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func uint8ToString(input uint8) string {
|
||||||
|
return strconv.FormatUint(uint64(input), 10)
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DuplicateConnectionError = "EDUPCONN"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterTunnel error from client
|
||||||
|
type clientRegisterTunnelError struct {
|
||||||
|
cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError {
|
||||||
|
counter.WithLabelValues(cause.Error(), string(name)).Inc()
|
||||||
|
return clientRegisterTunnelError{cause: cause}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e clientRegisterTunnelError) Error() string {
|
||||||
|
return e.cause.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
type DupConnRegisterTunnelError struct{}
|
||||||
|
|
||||||
|
var errDuplicationConnection = &DupConnRegisterTunnelError{}
|
||||||
|
|
||||||
|
func (e DupConnRegisterTunnelError) Error() string {
|
||||||
|
return "already connected to this server, trying another address"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTunnel error from server
|
||||||
|
type serverRegisterTunnelError struct {
|
||||||
|
cause error
|
||||||
|
permanent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e serverRegisterTunnelError) Error() string {
|
||||||
|
return e.cause.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
|
||||||
|
if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
|
||||||
|
return &serverRegisterTunnelError{
|
||||||
|
cause: retryable.Unwrap(),
|
||||||
|
permanent: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &serverRegisterTunnelError{
|
||||||
|
cause: err,
|
||||||
|
permanent: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type muxerShutdownError struct{}
|
||||||
|
|
||||||
|
func (e muxerShutdownError) Error() string {
|
||||||
|
return "muxer shutdown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHandshakeErrRecoverable(err error, connIndex uint8, observer *Observer) bool {
|
||||||
|
switch err.(type) {
|
||||||
|
case edgediscovery.DialError:
|
||||||
|
observer.Errorf("Connection %d unable to dial edge: %s", connIndex, err)
|
||||||
|
case h2mux.MuxerHandshakeError:
|
||||||
|
observer.Errorf("Connection %d handshake with edge server failed: %s", connIndex, err)
|
||||||
|
default:
|
||||||
|
observer.Errorf("Connection %d failed: %s", connIndex, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
|
@ -0,0 +1,216 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
muxerTimeout = 5 * time.Second
|
||||||
|
openStreamTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type h2muxConnection struct {
|
||||||
|
config *Config
|
||||||
|
muxerConfig *MuxerConfig
|
||||||
|
originURL string
|
||||||
|
muxer *h2mux.Muxer
|
||||||
|
// connectionID is only used by metrics, and prometheus requires labels to be string
|
||||||
|
connIndexStr string
|
||||||
|
connIndex uint8
|
||||||
|
|
||||||
|
observer *Observer
|
||||||
|
}
|
||||||
|
|
||||||
|
type MuxerConfig struct {
|
||||||
|
HeartbeatInterval time.Duration
|
||||||
|
MaxHeartbeats uint64
|
||||||
|
CompressionSetting h2mux.CompressionSetting
|
||||||
|
MetricsUpdateFreq time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, logger logger.Service) *h2mux.MuxerConfig {
|
||||||
|
return &h2mux.MuxerConfig{
|
||||||
|
Timeout: muxerTimeout,
|
||||||
|
Handler: h,
|
||||||
|
IsClient: true,
|
||||||
|
HeartbeatInterval: mc.HeartbeatInterval,
|
||||||
|
MaxHeartbeats: mc.MaxHeartbeats,
|
||||||
|
Logger: logger,
|
||||||
|
CompressionQuality: mc.CompressionSetting,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
|
||||||
|
func NewH2muxConnection(ctx context.Context,
|
||||||
|
config *Config,
|
||||||
|
muxerConfig *MuxerConfig,
|
||||||
|
originURL string,
|
||||||
|
edgeConn net.Conn,
|
||||||
|
connIndex uint8,
|
||||||
|
observer *Observer,
|
||||||
|
) (*h2muxConnection, error, bool) {
|
||||||
|
h := &h2muxConnection{
|
||||||
|
config: config,
|
||||||
|
muxerConfig: muxerConfig,
|
||||||
|
originURL: originURL,
|
||||||
|
connIndexStr: uint8ToString(connIndex),
|
||||||
|
connIndex: connIndex,
|
||||||
|
observer: observer,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Establish a muxed connection with the edge
|
||||||
|
// Client mux handshake with agent server
|
||||||
|
muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer), h2mux.ActiveStreams)
|
||||||
|
if err != nil {
|
||||||
|
recoverable := isHandshakeErrRecoverable(err, connIndex, observer)
|
||||||
|
return nil, err, recoverable
|
||||||
|
}
|
||||||
|
h.muxer = muxer
|
||||||
|
return h, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, credentialManager CredentialManager, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
|
||||||
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
return h.serveMuxer(serveCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
stream, err := h.newRPCStream(serveCtx, register)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer)
|
||||||
|
defer rpcClient.close()
|
||||||
|
|
||||||
|
if err = registerConnection(serveCtx, rpcClient, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
connectedFuse.Connected()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
h.controlLoop(serveCtx, connectedFuse, true)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return errGroup.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
|
||||||
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
return h.serveMuxer(serveCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() (err error) {
|
||||||
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
|
connectedFuse.Connected()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() {
|
||||||
|
err := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// log errors and proceed to RegisterTunnel
|
||||||
|
h.observer.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", h.connIndex, err)
|
||||||
|
}
|
||||||
|
return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
|
||||||
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
h.controlLoop(serveCtx, connectedFuse, false)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return errGroup.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
|
||||||
|
// All routines should stop when muxer finish serving. When muxer is shutdown
|
||||||
|
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
||||||
|
// here to notify other routines to stop
|
||||||
|
err := h.muxer.Serve(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return muxerShutdownError{}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
|
||||||
|
updateMetricsTickC := time.Tick(h.muxerConfig.MetricsUpdateFreq)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// UnregisterTunnel blocks until the RPC call returns
|
||||||
|
if connectedFuse.IsConnected() {
|
||||||
|
h.unregister(isNamedTunnel)
|
||||||
|
}
|
||||||
|
h.muxer.Shutdown()
|
||||||
|
return
|
||||||
|
case <-updateMetricsTickC:
|
||||||
|
h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
|
||||||
|
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
|
||||||
|
defer openStreamCancel()
|
||||||
|
stream, err := h.muxer.OpenRPCStream(openStreamCtx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
|
respWriter := &h2muxRespWriter{stream}
|
||||||
|
|
||||||
|
req, reqErr := h.newRequest(stream)
|
||||||
|
if reqErr != nil {
|
||||||
|
respWriter.WriteErrorResponse(reqErr)
|
||||||
|
return reqErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
||||||
|
req, err := http.NewRequest("GET", h.originURL, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||||
|
}
|
||||||
|
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "invalid request received")
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type h2muxRespWriter struct {
|
||||||
|
*h2mux.MuxedStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
|
return rp.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
|
||||||
|
rp.WriteHeaders([]h2mux.Header{
|
||||||
|
{Name: ":status", Value: "502"},
|
||||||
|
h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared),
|
||||||
|
})
|
||||||
|
rp.Write([]byte("502 Bad Gateway"))
|
||||||
|
}
|
|
@ -0,0 +1,253 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// internalUpgradeHeader is a cloudflared-internal header the edge
	// sets to mark special stream types; it is stripped before the
	// request reaches the origin.
	internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
	websocketUpgrade      = "websocket"
	controlStreamUpgrade  = "control-stream"
)
|
||||||
|
|
||||||
|
// HTTP2Connection is a connection to the Cloudflare edge that carries
// proxied requests as individual HTTP/2 streams over one net.Conn.
type HTTP2Connection struct {
	conn         net.Conn
	server       *http2.Server
	config       *Config
	originURL    *url.URL
	namedTunnel  *NamedTunnelConfig
	connOptions  *tunnelpogs.ConnectionOptions
	observer     *Observer
	connIndexStr string
	connIndex    uint8
	// shutdownChan coordinates graceful shutdown between close() and
	// the control stream loop in serveControlStream.
	shutdownChan  chan struct{}
	connectedFuse ConnectedFuse
}
|
||||||
|
|
||||||
|
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) {
|
||||||
|
return &HTTP2Connection{
|
||||||
|
conn: conn,
|
||||||
|
server: &http2.Server{},
|
||||||
|
config: config,
|
||||||
|
originURL: originURL,
|
||||||
|
namedTunnel: namedTunnelConfig,
|
||||||
|
connOptions: connOptions,
|
||||||
|
observer: observer,
|
||||||
|
connIndexStr: uint8ToString(connIndex),
|
||||||
|
connIndex: connIndex,
|
||||||
|
shutdownChan: make(chan struct{}),
|
||||||
|
connectedFuse: connectedFuse,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HTTP2Connection) Serve(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
c.close()
|
||||||
|
}()
|
||||||
|
c.server.ServeConn(c.conn, &http2.ServeConnOpts{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: c,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.URL.Scheme = c.originURL.Scheme
|
||||||
|
r.URL.Host = c.originURL.Host
|
||||||
|
|
||||||
|
respWriter := &http2RespWriter{
|
||||||
|
r: r.Body,
|
||||||
|
w: w,
|
||||||
|
}
|
||||||
|
if isControlStreamUpgrade(r) {
|
||||||
|
err := c.serveControlStream(r.Context(), respWriter)
|
||||||
|
if err != nil {
|
||||||
|
respWriter.WriteErrorResponse(err)
|
||||||
|
}
|
||||||
|
} else if isWebsocketUpgrade(r) {
|
||||||
|
wsRespWriter, err := newWSRespWriter(respWriter)
|
||||||
|
if err != nil {
|
||||||
|
respWriter.WriteErrorResponse(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stripWebsocketUpgradeHeader(r)
|
||||||
|
c.config.OriginClient.Proxy(wsRespWriter, r, true)
|
||||||
|
} else {
|
||||||
|
c.config.OriginClient.Proxy(respWriter, r, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HTTP2Connection) serveControlStream(ctx context.Context, h2RespWriter *http2RespWriter) error {
|
||||||
|
stream, err := newWSRespWriter(h2RespWriter)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rpcClient := newRegistrationRPCClient(ctx, stream, c.observer)
|
||||||
|
defer rpcClient.close()
|
||||||
|
|
||||||
|
if err = registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.connectedFuse.Connected()
|
||||||
|
|
||||||
|
<-c.shutdownChan
|
||||||
|
c.gracefulShutdown(ctx, rpcClient)
|
||||||
|
close(c.shutdownChan)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerConnection registers this connection with the edge over the
// given RPC client and logs the assigned location and connection UUID.
// NOTE(review): serveControlStream calls a free function named
// registerConnection with a different signature — confirm which of the
// two is meant to survive this refactor.
func (c *HTTP2Connection) registerConnection(
	ctx context.Context,
	rpcClient tunnelpogs.RegistrationServer_PogsClient,
) error {
	connDetail, err := rpcClient.RegisterConnection(
		ctx,
		c.namedTunnel.Auth,
		c.namedTunnel.ID,
		c.connIndex,
		c.connOptions,
	)
	if err != nil {
		c.observer.Errorf("Cannot register connection, err: %v", err)
		return err
	}
	c.observer.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID)
	return nil
}
|
||||||
|
|
||||||
|
// gracefulShutdown unregisters this connection from the edge, bounded
// by the configured grace period. The unregister error is ignored
// because the connection is being torn down regardless.
func (c *HTTP2Connection) gracefulShutdown(ctx context.Context, rpcClient *registrationServerClient) {
	ctx, cancel := context.WithTimeout(ctx, c.config.GracePeriod)
	defer cancel()
	rpcClient.client.UnregisterConnection(ctx)
}
|
||||||
|
|
||||||
|
// close coordinates a graceful shutdown with the control stream loop
// before closing the underlying network connection.
// NOTE(review): if no control stream is being served, the send below
// blocks forever — confirm close() is only reached after the control
// stream loop has started receiving on shutdownChan.
func (c *HTTP2Connection) close() {
	// Send signal to control loop to start graceful shutdown
	c.shutdownChan <- struct{}{}
	// Wait for control loop to close channel
	<-c.shutdownChan
	c.conn.Close()
}
|
||||||
|
|
||||||
|
// http2RespWriter pairs the request body reader with the HTTP/2
// response writer for a single proxied stream.
type http2RespWriter struct {
	r io.Reader // request body from the edge
	w http.ResponseWriter
}
|
||||||
|
|
||||||
|
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
|
dest := rp.w.Header()
|
||||||
|
userHeaders := make(http.Header, len(resp.Header))
|
||||||
|
for header, values := range resp.Header {
|
||||||
|
// Since these are http2 headers, they're required to be lowercase
|
||||||
|
h2name := strings.ToLower(header)
|
||||||
|
for _, v := range values {
|
||||||
|
if h2name == "content-length" {
|
||||||
|
// This header has meaning in HTTP/2 and will be used by the edge,
|
||||||
|
// so it should be sent as an HTTP/2 response header.
|
||||||
|
dest.Add(h2name, v)
|
||||||
|
// Since these are http2 headers, they're required to be lowercase
|
||||||
|
} else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) {
|
||||||
|
// User headers, on the other hand, must all be serialized so that
|
||||||
|
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
||||||
|
userHeaders.Add(h2name, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform user header serialization and set them in the single header
|
||||||
|
dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
|
||||||
|
status := resp.StatusCode
|
||||||
|
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
||||||
|
if status == http.StatusSwitchingProtocols {
|
||||||
|
status = http.StatusOK
|
||||||
|
}
|
||||||
|
rp.w.WriteHeader(status)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *http2RespWriter) WriteErrorResponse(err error) {
|
||||||
|
jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared})
|
||||||
|
if err == nil {
|
||||||
|
rp.w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader))
|
||||||
|
}
|
||||||
|
rp.w.WriteHeader(http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads request body bytes coming from the edge.
func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
	return rp.r.Read(p)
}
|
||||||
|
|
||||||
|
func (wr *http2RespWriter) Write(p []byte) (n int, err error) {
|
||||||
|
return wr.w.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsRespWriter wraps an http2RespWriter for websocket traffic, caching
// the http.Flusher so every write can be pushed to the edge
// immediately.
type wsRespWriter struct {
	h2      *http2RespWriter
	flusher http.Flusher
}
|
||||||
|
|
||||||
|
func newWSRespWriter(h2 *http2RespWriter) (*wsRespWriter, error) {
|
||||||
|
flusher, ok := h2.w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher")
|
||||||
|
}
|
||||||
|
return &wsRespWriter{
|
||||||
|
h2: h2,
|
||||||
|
flusher: flusher,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *wsRespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
|
err := rw.h2.WriteRespHeaders(resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rw.flusher.Flush()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteErrorResponse delegates to the underlying HTTP/2 writer, which
// replies with 502 Bad Gateway.
func (rw *wsRespWriter) WriteErrorResponse(err error) {
	rw.h2.WriteErrorResponse(err)
}
|
||||||
|
|
||||||
|
// Read reads websocket bytes sent by the edge.
func (rw *wsRespWriter) Read(p []byte) (n int, err error) {
	return rw.h2.Read(p)
}
|
||||||
|
|
||||||
|
func (rw *wsRespWriter) Write(p []byte) (n int, err error) {
|
||||||
|
n, err = rw.h2.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.flusher.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close is a no-op: the stream's lifetime is managed by the HTTP/2
// server, not by this wrapper.
func (rw *wsRespWriter) Close() error {
	return nil
}
|
||||||
|
|
||||||
|
// isControlStreamUpgrade reports whether the request asks to turn this
// stream into the tunnel control stream (case-insensitive match on the
// internal upgrade header).
func isControlStreamUpgrade(r *http.Request) bool {
	return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlStreamUpgrade
}
|
||||||
|
|
||||||
|
// isWebsocketUpgrade reports whether the edge marked this stream as a
// websocket upgrade via the internal upgrade header.
func isWebsocketUpgrade(r *http.Request) bool {
	return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade
}
|
||||||
|
|
||||||
|
// stripWebsocketUpgradeHeader removes the cloudflared-internal upgrade
// header so it never reaches the origin.
func stripWebsocketUpgradeHeader(r *http.Request) {
	r.Header.Del(internalUpgradeHeader)
}
|
|
@ -0,0 +1,409 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// MetricsNamespace prefixes every metric exported by cloudflared.
	MetricsNamespace = "cloudflared"
	TunnelSubsystem  = "tunnel"
	muxerSubsystem   = "muxer"
)
|
||||||
|
|
||||||
|
// muxerMetrics holds per-connection gauges describing the h2mux
// session: round-trip times, flow-control window sizes, transfer
// rates, and cross-stream compression stats. All gauges are labeled
// by connection_id.
type muxerMetrics struct {
	rtt              *prometheus.GaugeVec
	rttMin           *prometheus.GaugeVec
	rttMax           *prometheus.GaugeVec
	receiveWindowAve *prometheus.GaugeVec
	sendWindowAve    *prometheus.GaugeVec
	receiveWindowMin *prometheus.GaugeVec
	receiveWindowMax *prometheus.GaugeVec
	sendWindowMin    *prometheus.GaugeVec
	sendWindowMax    *prometheus.GaugeVec
	inBoundRateCurr  *prometheus.GaugeVec
	inBoundRateMin   *prometheus.GaugeVec
	inBoundRateMax   *prometheus.GaugeVec
	outBoundRateCurr *prometheus.GaugeVec
	outBoundRateMin  *prometheus.GaugeVec
	outBoundRateMax  *prometheus.GaugeVec
	compBytesBefore  *prometheus.GaugeVec
	compBytesAfter   *prometheus.GaugeVec
	compRateAve      *prometheus.GaugeVec
}
|
||||||
|
|
||||||
|
// tunnelMetrics aggregates the tunnel-level Prometheus metrics:
// registration outcomes, RPC failures, heartbeats, and per-connection
// edge locations. muxerMetrics is only populated for the h2mux
// protocol.
type tunnelMetrics struct {
	timerRetries    prometheus.Gauge
	serverLocations *prometheus.GaugeVec
	// locationLock is a mutex for oldServerLocations
	locationLock sync.Mutex
	// oldServerLocations stores the last server the tunnel was connected to
	oldServerLocations map[string]string

	regSuccess *prometheus.CounterVec
	regFail    *prometheus.CounterVec
	rpcFail    *prometheus.CounterVec

	muxerMetrics        *muxerMetrics
	tunnelsHA           tunnelsForHA
	userHostnamesCounts *prometheus.CounterVec
}
|
||||||
|
|
||||||
|
func newMuxerMetrics() *muxerMetrics {
|
||||||
|
rtt := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "rtt",
|
||||||
|
Help: "Round-trip time in millisecond",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(rtt)
|
||||||
|
|
||||||
|
rttMin := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "rtt_min",
|
||||||
|
Help: "Shortest round-trip time in millisecond",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(rttMin)
|
||||||
|
|
||||||
|
rttMax := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "rtt_max",
|
||||||
|
Help: "Longest round-trip time in millisecond",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(rttMax)
|
||||||
|
|
||||||
|
receiveWindowAve := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "receive_window_ave",
|
||||||
|
Help: "Average receive window size in bytes",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(receiveWindowAve)
|
||||||
|
|
||||||
|
sendWindowAve := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "send_window_ave",
|
||||||
|
Help: "Average send window size in bytes",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(sendWindowAve)
|
||||||
|
|
||||||
|
receiveWindowMin := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "receive_window_min",
|
||||||
|
Help: "Smallest receive window size in bytes",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(receiveWindowMin)
|
||||||
|
|
||||||
|
receiveWindowMax := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "receive_window_max",
|
||||||
|
Help: "Largest receive window size in bytes",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(receiveWindowMax)
|
||||||
|
|
||||||
|
sendWindowMin := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "send_window_min",
|
||||||
|
Help: "Smallest send window size in bytes",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(sendWindowMin)
|
||||||
|
|
||||||
|
sendWindowMax := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "send_window_max",
|
||||||
|
Help: "Largest send window size in bytes",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(sendWindowMax)
|
||||||
|
|
||||||
|
inBoundRateCurr := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "inbound_bytes_per_sec_curr",
|
||||||
|
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(inBoundRateCurr)
|
||||||
|
|
||||||
|
inBoundRateMin := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "inbound_bytes_per_sec_min",
|
||||||
|
Help: "Minimum non-zero inbounding bytes per second",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(inBoundRateMin)
|
||||||
|
|
||||||
|
inBoundRateMax := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "inbound_bytes_per_sec_max",
|
||||||
|
Help: "Maximum inbounding bytes per second",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(inBoundRateMax)
|
||||||
|
|
||||||
|
outBoundRateCurr := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "outbound_bytes_per_sec_curr",
|
||||||
|
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(outBoundRateCurr)
|
||||||
|
|
||||||
|
outBoundRateMin := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "outbound_bytes_per_sec_min",
|
||||||
|
Help: "Minimum non-zero outbounding bytes per second",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(outBoundRateMin)
|
||||||
|
|
||||||
|
outBoundRateMax := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "outbound_bytes_per_sec_max",
|
||||||
|
Help: "Maximum outbounding bytes per second",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(outBoundRateMax)
|
||||||
|
|
||||||
|
compBytesBefore := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "comp_bytes_before",
|
||||||
|
Help: "Bytes sent via cross-stream compression, pre compression",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(compBytesBefore)
|
||||||
|
|
||||||
|
compBytesAfter := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "comp_bytes_after",
|
||||||
|
Help: "Bytes sent via cross-stream compression, post compression",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(compBytesAfter)
|
||||||
|
|
||||||
|
compRateAve := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: muxerSubsystem,
|
||||||
|
Name: "comp_rate_ave",
|
||||||
|
Help: "Average outbound cross-stream compression ratio",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(compRateAve)
|
||||||
|
|
||||||
|
return &muxerMetrics{
|
||||||
|
rtt: rtt,
|
||||||
|
rttMin: rttMin,
|
||||||
|
rttMax: rttMax,
|
||||||
|
receiveWindowAve: receiveWindowAve,
|
||||||
|
sendWindowAve: sendWindowAve,
|
||||||
|
receiveWindowMin: receiveWindowMin,
|
||||||
|
receiveWindowMax: receiveWindowMax,
|
||||||
|
sendWindowMin: sendWindowMin,
|
||||||
|
sendWindowMax: sendWindowMax,
|
||||||
|
inBoundRateCurr: inBoundRateCurr,
|
||||||
|
inBoundRateMin: inBoundRateMin,
|
||||||
|
inBoundRateMax: inBoundRateMax,
|
||||||
|
outBoundRateCurr: outBoundRateCurr,
|
||||||
|
outBoundRateMin: outBoundRateMin,
|
||||||
|
outBoundRateMax: outBoundRateMax,
|
||||||
|
compBytesBefore: compBytesBefore,
|
||||||
|
compBytesAfter: compBytesAfter,
|
||||||
|
compRateAve: compRateAve,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
|
||||||
|
m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
|
||||||
|
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
|
||||||
|
m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
|
||||||
|
m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
|
||||||
|
m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
|
||||||
|
m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
|
||||||
|
m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
|
||||||
|
m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
|
||||||
|
m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
|
||||||
|
m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
|
||||||
|
m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
|
||||||
|
m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
|
||||||
|
m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
|
||||||
|
m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
|
||||||
|
m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
|
||||||
|
m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value()))
|
||||||
|
m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value()))
|
||||||
|
m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertRTTMilliSec(t time.Duration) float64 {
|
||||||
|
return float64(t / time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metrics that can be collected without asking the edge
|
||||||
|
func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
|
||||||
|
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: TunnelSubsystem,
|
||||||
|
Name: "max_concurrent_requests_per_tunnel",
|
||||||
|
Help: "Largest number of concurrent requests proxied through each tunnel so far",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
|
||||||
|
|
||||||
|
timerRetries := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: TunnelSubsystem,
|
||||||
|
Name: "timer_retries",
|
||||||
|
Help: "Unacknowledged heart beats count",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(timerRetries)
|
||||||
|
|
||||||
|
serverLocations := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: TunnelSubsystem,
|
||||||
|
Name: "server_locations",
|
||||||
|
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
|
||||||
|
},
|
||||||
|
[]string{"connection_id", "location"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(serverLocations)
|
||||||
|
|
||||||
|
rpcFail := prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: TunnelSubsystem,
|
||||||
|
Name: "tunnel_rpc_fail",
|
||||||
|
Help: "Count of RPC connection errors by type",
|
||||||
|
},
|
||||||
|
[]string{"error", "rpcName"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(rpcFail)
|
||||||
|
|
||||||
|
registerFail := prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: TunnelSubsystem,
|
||||||
|
Name: "tunnel_register_fail",
|
||||||
|
Help: "Count of tunnel registration errors by type",
|
||||||
|
},
|
||||||
|
[]string{"error", "rpcName"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(registerFail)
|
||||||
|
|
||||||
|
userHostnamesCounts := prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: TunnelSubsystem,
|
||||||
|
Name: "user_hostnames_counts",
|
||||||
|
Help: "Which user hostnames cloudflared is serving",
|
||||||
|
},
|
||||||
|
[]string{"userHostname"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(userHostnamesCounts)
|
||||||
|
|
||||||
|
registerSuccess := prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Namespace: MetricsNamespace,
|
||||||
|
Subsystem: TunnelSubsystem,
|
||||||
|
Name: "tunnel_register_success",
|
||||||
|
Help: "Count of successful tunnel registrations",
|
||||||
|
},
|
||||||
|
[]string{"rpcName"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(registerSuccess)
|
||||||
|
var muxerMetrics *muxerMetrics
|
||||||
|
if protocol == H2mux {
|
||||||
|
muxerMetrics = newMuxerMetrics()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tunnelMetrics{
|
||||||
|
timerRetries: timerRetries,
|
||||||
|
serverLocations: serverLocations,
|
||||||
|
oldServerLocations: make(map[string]string),
|
||||||
|
muxerMetrics: muxerMetrics,
|
||||||
|
tunnelsHA: NewTunnelsForHA(),
|
||||||
|
regSuccess: registerSuccess,
|
||||||
|
regFail: registerFail,
|
||||||
|
rpcFail: rpcFail,
|
||||||
|
userHostnamesCounts: userHostnamesCounts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateMuxerMetrics refreshes all muxer gauges for the given
// connection. Only valid when the metrics were created for the H2mux
// protocol (muxerMetrics is nil otherwise).
func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
	t.muxerMetrics.update(connectionID, metrics)
}
|
||||||
|
|
||||||
|
func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
|
||||||
|
t.locationLock.Lock()
|
||||||
|
defer t.locationLock.Unlock()
|
||||||
|
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
|
||||||
|
return
|
||||||
|
} else if ok {
|
||||||
|
t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
|
||||||
|
}
|
||||||
|
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
|
||||||
|
t.oldServerLocations[connectionID] = loc
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Observer reports tunnel lifecycle events: it logs them, exports
// Prometheus metrics, and (when the UI is enabled) forwards them to
// the UI event channel.
type Observer struct {
	logger.Service
	metrics         *tunnelMetrics
	tunnelEventChan chan<- ui.TunnelEvent // nil unless launch-ui is set
}
|
||||||
|
|
||||||
|
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer {
|
||||||
|
return &Observer{
|
||||||
|
logger,
|
||||||
|
newTunnelMetrics(protocol),
|
||||||
|
tunnelEventChan,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Observer) logServerInfo(connectionID uint8, location, msg string) {
|
||||||
|
// If launch-ui flag is set, send connect msg
|
||||||
|
if o.tunnelEventChan != nil {
|
||||||
|
o.tunnelEventChan <- ui.TunnelEvent{Index: connectionID, EventType: ui.Connected, Location: location}
|
||||||
|
}
|
||||||
|
o.Infof(msg)
|
||||||
|
o.metrics.registerServerLocation(uint8ToString(connectionID), location)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Observer) logTrialHostname(registration *tunnelpogs.TunnelRegistration) error {
|
||||||
|
// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
|
||||||
|
if o.tunnelEventChan == nil {
|
||||||
|
if registrationURL, err := url.Parse(registration.Url); err == nil {
|
||||||
|
for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) {
|
||||||
|
o.Info(line)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
o.Error("Failed to connect tunnel, please try again.")
|
||||||
|
return fmt.Errorf("empty URL in response from Cloudflare edge")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out the given lines in a nice ASCII box.
|
||||||
|
func asciiBox(lines []string, padding int) (box []string) {
|
||||||
|
maxLen := maxLen(lines)
|
||||||
|
spacer := strings.Repeat(" ", padding)
|
||||||
|
|
||||||
|
border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+"
|
||||||
|
|
||||||
|
box = append(box, border)
|
||||||
|
for _, line := range lines {
|
||||||
|
box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|")
|
||||||
|
}
|
||||||
|
box = append(box, border)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// maxLen returns the byte length of the longest string in lines, or 0
// for an empty slice.
func maxLen(lines []string) int {
	longest := 0
	for _, line := range lines {
		if l := len(line); l > longest {
			longest = l
		}
	}
	return longest
}
|
||||||
|
|
||||||
|
func trialZoneMsg(url string) []string {
|
||||||
|
return []string{
|
||||||
|
"Your free tunnel has started! Visit it:",
|
||||||
|
" " + url,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendRegisteringEvent tells the UI (when enabled) that tunnel
// registration has started.
func (o *Observer) sendRegisteringEvent() {
	if o.tunnelEventChan != nil {
		o.tunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel}
	}
}
|
||||||
|
|
||||||
|
// sendConnectedEvent tells the UI (when enabled) that the connection
// with the given index is established at an edge location.
func (o *Observer) sendConnectedEvent(connIndex uint8, location string) {
	if o.tunnelEventChan != nil {
		o.tunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Connected, Location: location}
	}
}
|
||||||
|
|
||||||
|
// sendURL tells the UI (when enabled) which URL the tunnel serves.
func (o *Observer) sendURL(url string) {
	if o.tunnelEventChan != nil {
		o.tunnelEventChan <- ui.TunnelEvent{EventType: ui.SetUrl, Url: url}
	}
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// can only be called once
|
||||||
|
var m = newTunnelMetrics(H2mux)
|
||||||
|
|
||||||
|
// TestRegisterServerLocation checks that registerServerLocation is
// safe under concurrent use and that the recorded location for each
// connection reflects the most recent registration.
func TestRegisterServerLocation(t *testing.T) {
	tunnels := 20
	var wg sync.WaitGroup
	wg.Add(tunnels)
	// Register every connection at LHR concurrently.
	for i := 0; i < tunnels; i++ {
		go func(i int) {
			id := strconv.Itoa(i)
			m.registerServerLocation(id, "LHR")
			wg.Done()
		}(i)
	}
	wg.Wait()
	for i := 0; i < tunnels; i++ {
		id := strconv.Itoa(i)
		assert.Equal(t, "LHR", m.oldServerLocations[id])
	}

	// Re-register every connection at AUS and verify the cached
	// location was replaced.
	wg.Add(tunnels)
	for i := 0; i < tunnels; i++ {
		go func(i int) {
			id := strconv.Itoa(i)
			m.registerServerLocation(id, "AUS")
			wg.Done()
		}(i)
	}
	wg.Wait()
	for i := 0; i < tunnels; i++ {
		id := strconv.Itoa(i)
		assert.Equal(t, "AUS", m.oldServerLocations[id])
	}

}
|
|
@ -2,40 +2,276 @@ package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"zombiezen.com/go/capnproto2/rpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewTunnelRPCClient creates and returns a new RPC client, which will communicate
|
type tunnelServerClient struct {
|
||||||
// using a stream on the given muxer
|
client tunnelpogs.TunnelServer_PogsClient
|
||||||
func NewTunnelRPCClient(
|
transport rpc.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTunnelRPCClient creates and returns a new RPC client, which will communicate using a stream on the given muxer.
|
||||||
|
// This method is exported for supervisor to call Authenticate RPC
|
||||||
|
func NewTunnelServerClient(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
stream io.ReadWriteCloser,
|
stream io.ReadWriteCloser,
|
||||||
logger logger.Service,
|
logger logger.Service,
|
||||||
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
|
) *tunnelServerClient {
|
||||||
|
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
|
||||||
conn := rpc.NewConn(
|
conn := rpc.NewConn(
|
||||||
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
transport,
|
||||||
tunnelrpc.ConnLog(logger),
|
tunnelrpc.ConnLog(logger),
|
||||||
)
|
)
|
||||||
registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
|
registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
|
||||||
client = tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn}
|
return &tunnelServerClient{
|
||||||
return client, nil
|
client: tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn},
|
||||||
|
transport: transport,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRegistrationRPCClient(
|
func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return authResp.Outcome(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tsc *tunnelServerClient) Close() {
|
||||||
|
// Closing the client will also close the connection
|
||||||
|
tsc.client.Close()
|
||||||
|
tsc.transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type registrationServerClient struct {
|
||||||
|
client tunnelpogs.RegistrationServer_PogsClient
|
||||||
|
transport rpc.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRegistrationRPCClient(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
stream io.ReadWriteCloser,
|
stream io.ReadWriteCloser,
|
||||||
logger logger.Service,
|
logger logger.Service,
|
||||||
) (client tunnelpogs.RegistrationServer_PogsClient, err error) {
|
) *registrationServerClient {
|
||||||
|
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
|
||||||
conn := rpc.NewConn(
|
conn := rpc.NewConn(
|
||||||
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
transport,
|
||||||
tunnelrpc.ConnLog(logger),
|
tunnelrpc.ConnLog(logger),
|
||||||
)
|
)
|
||||||
client = tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
|
return ®istrationServerClient{
|
||||||
return client, nil
|
client: tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
|
||||||
|
transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rsc *registrationServerClient) close() {
|
||||||
|
// Closing the client will also close the connection
|
||||||
|
rsc.client.Close()
|
||||||
|
// Closing the transport also closes the stream
|
||||||
|
rsc.transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcName string
|
||||||
|
|
||||||
|
const (
|
||||||
|
register rpcName = "register"
|
||||||
|
reconnect rpcName = "reconnect"
|
||||||
|
unregister rpcName = "unregister"
|
||||||
|
authenticate rpcName = " authenticate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func registerConnection(
|
||||||
|
ctx context.Context,
|
||||||
|
rpcClient *registrationServerClient,
|
||||||
|
config *NamedTunnelConfig,
|
||||||
|
options *tunnelpogs.ConnectionOptions,
|
||||||
|
connIndex uint8,
|
||||||
|
observer *Observer,
|
||||||
|
) error {
|
||||||
|
conn, err := rpcClient.client.RegisterConnection(
|
||||||
|
ctx,
|
||||||
|
config.Auth,
|
||||||
|
config.ID,
|
||||||
|
connIndex,
|
||||||
|
options,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == DuplicateConnectionError {
|
||||||
|
observer.metrics.regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
|
||||||
|
return errDuplicationConnection
|
||||||
|
}
|
||||||
|
observer.metrics.regFail.WithLabelValues("server_error", "registerConnection").Inc()
|
||||||
|
return serverRegistrationErrorFromRPC(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
|
||||||
|
|
||||||
|
observer.logServerInfo(connIndex, conn.Location, fmt.Sprintf("Connection %d registered with %s using ID %s", connIndex, conn.Location, conn.UUID))
|
||||||
|
observer.sendConnectedEvent(connIndex, conn.Location)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
|
||||||
|
h.observer.sendRegisteringEvent()
|
||||||
|
|
||||||
|
stream, err := h.newRPCStream(ctx, register)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rpcClient := NewTunnelServerClient(ctx, stream, h.observer)
|
||||||
|
defer rpcClient.Close()
|
||||||
|
|
||||||
|
h.logServerInfo(ctx, rpcClient)
|
||||||
|
registration := rpcClient.client.RegisterTunnel(
|
||||||
|
ctx,
|
||||||
|
classicTunnel.OriginCert,
|
||||||
|
classicTunnel.Hostname,
|
||||||
|
registrationOptions,
|
||||||
|
)
|
||||||
|
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
||||||
|
// RegisterTunnel RPC failure
|
||||||
|
return h.processRegisterTunnelError(registrationErr, register)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send free tunnel URL to UI
|
||||||
|
h.observer.sendURL(registration.Url)
|
||||||
|
credentialSetter.SetEventDigest(h.connIndex, registration.EventDigest)
|
||||||
|
return h.processRegistrationSuccess(registration, register, credentialSetter, classicTunnel)
|
||||||
|
}
|
||||||
|
|
||||||
|
type CredentialManager interface {
|
||||||
|
ReconnectToken() ([]byte, error)
|
||||||
|
EventDigest(connID uint8) ([]byte, error)
|
||||||
|
SetEventDigest(connID uint8, digest []byte)
|
||||||
|
ConnDigest(connID uint8) ([]byte, error)
|
||||||
|
SetConnDigest(connID uint8, digest []byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) processRegistrationSuccess(
|
||||||
|
registration *tunnelpogs.TunnelRegistration,
|
||||||
|
name rpcName,
|
||||||
|
credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig,
|
||||||
|
) error {
|
||||||
|
for _, logLine := range registration.LogLines {
|
||||||
|
h.observer.Info(logLine)
|
||||||
|
}
|
||||||
|
|
||||||
|
if registration.TunnelID != "" {
|
||||||
|
h.observer.metrics.tunnelsHA.AddTunnelID(h.connIndex, registration.TunnelID)
|
||||||
|
h.observer.Infof("Each HA connection's tunnel IDs: %v", h.observer.metrics.tunnelsHA.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
|
||||||
|
if classicTunnel.IsTrialZone() {
|
||||||
|
err := h.observer.logTrialHostname(registration)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
credentialManager.SetConnDigest(h.connIndex, registration.ConnDigest)
|
||||||
|
h.observer.metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
|
||||||
|
|
||||||
|
h.observer.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
|
||||||
|
h.observer.metrics.regSuccess.WithLabelValues(string(name)).Inc()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, name rpcName) error {
|
||||||
|
if err.Error() == DuplicateConnectionError {
|
||||||
|
h.observer.metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
|
||||||
|
return errDuplicationConnection
|
||||||
|
}
|
||||||
|
h.observer.metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
|
||||||
|
return serverRegisterTunnelError{
|
||||||
|
cause: err,
|
||||||
|
permanent: err.IsPermanent(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
|
||||||
|
token, err := credentialManager.ReconnectToken()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
eventDigest, err := credentialManager.EventDigest(h.connIndex)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
connDigest, err := credentialManager.ConnDigest(h.connIndex)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
h.observer.Debug("initiating RPC stream to reconnect")
|
||||||
|
stream, err := h.newRPCStream(ctx, register)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rpcClient := NewTunnelServerClient(ctx, stream, h.observer)
|
||||||
|
defer rpcClient.Close()
|
||||||
|
|
||||||
|
h.logServerInfo(ctx, rpcClient)
|
||||||
|
registration := rpcClient.client.ReconnectTunnel(
|
||||||
|
ctx,
|
||||||
|
token,
|
||||||
|
eventDigest,
|
||||||
|
connDigest,
|
||||||
|
classicTunnel.Hostname,
|
||||||
|
registrationOptions,
|
||||||
|
)
|
||||||
|
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
||||||
|
// ReconnectTunnel RPC failure
|
||||||
|
return h.processRegisterTunnelError(registrationErr, reconnect)
|
||||||
|
}
|
||||||
|
return h.processRegistrationSuccess(registration, reconnect, credentialManager, classicTunnel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelServerClient) error {
|
||||||
|
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||||
|
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.client.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
serverInfoMessage, err := serverInfoPromise.Result().Struct()
|
||||||
|
if err != nil {
|
||||||
|
h.observer.Errorf("Failed to retrieve server information: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
|
||||||
|
if err != nil {
|
||||||
|
h.observer.Errorf("Failed to retrieve server information: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.observer.logServerInfo(h.connIndex, serverInfo.LocationName, fmt.Sprintf("Connnection %d connected to %s", h.connIndex, serverInfo.LocationName))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *h2muxConnection) unregister(isNamedTunnel bool) {
|
||||||
|
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stream, err := h.newRPCStream(unregisterCtx, register)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNamedTunnel {
|
||||||
|
rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer)
|
||||||
|
defer rpcClient.close()
|
||||||
|
|
||||||
|
rpcClient.client.UnregisterConnection(unregisterCtx)
|
||||||
|
} else {
|
||||||
|
rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer)
|
||||||
|
defer rpcClient.Close()
|
||||||
|
|
||||||
|
// gracePeriod is encoded in int64 using capnproto
|
||||||
|
rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tunnelsForHA maps this cloudflared instance's HA connections to the tunnel IDs they serve.
|
||||||
|
type tunnelsForHA struct {
|
||||||
|
sync.Mutex
|
||||||
|
metrics *prometheus.GaugeVec
|
||||||
|
entries map[uint8]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTunnelsForHA initializes the Prometheus metrics etc for a tunnelsForHA.
|
||||||
|
func NewTunnelsForHA() tunnelsForHA {
|
||||||
|
metrics := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "tunnel_ids",
|
||||||
|
Help: "The ID of all tunnels (and their corresponding HA connection ID) running in this instance of cloudflared.",
|
||||||
|
},
|
||||||
|
[]string{"tunnel_id", "ha_conn_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(metrics)
|
||||||
|
|
||||||
|
return tunnelsForHA{
|
||||||
|
metrics: metrics,
|
||||||
|
entries: make(map[uint8]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track a new tunnel ID, removing the disconnected tunnel (if any) and update metrics.
|
||||||
|
func (t *tunnelsForHA) AddTunnelID(haConn uint8, tunnelID string) {
|
||||||
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
haStr := fmt.Sprintf("%v", haConn)
|
||||||
|
if oldTunnelID, ok := t.entries[haConn]; ok {
|
||||||
|
t.metrics.WithLabelValues(oldTunnelID, haStr).Dec()
|
||||||
|
}
|
||||||
|
t.entries[haConn] = tunnelID
|
||||||
|
t.metrics.WithLabelValues(tunnelID, haStr).Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tunnelsForHA) String() string {
|
||||||
|
t.Lock()
|
||||||
|
defer t.Unlock()
|
||||||
|
return fmt.Sprintf("%v", t.entries)
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package connection
|
package edgediscovery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
|
@ -1,4 +1,4 @@
|
||||||
package connection
|
package edgediscovery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
|
@ -7,6 +7,19 @@ import (
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ActiveStreams = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||||
|
Namespace: "cloudflared",
|
||||||
|
Subsystem: "tunnel",
|
||||||
|
Name: "active_streams",
|
||||||
|
Help: "Number of active streams created by all muxers.",
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
prometheus.MustRegister(ActiveStreams)
|
||||||
|
}
|
||||||
|
|
||||||
// activeStreamMap is used to moderate access to active streams between the read and write
|
// activeStreamMap is used to moderate access to active streams between the read and write
|
||||||
// threads, and deny access to new peer streams while shutting down.
|
// threads, and deny access to new peer streams while shutting down.
|
||||||
type activeStreamMap struct {
|
type activeStreamMap struct {
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
func TestShutdown(t *testing.T) {
|
func TestShutdown(t *testing.T) {
|
||||||
const numStreams = 1000
|
const numStreams = 1000
|
||||||
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
m := newActiveStreamMap(true, ActiveStreams)
|
||||||
|
|
||||||
// Add all the streams
|
// Add all the streams
|
||||||
{
|
{
|
||||||
|
@ -62,7 +62,7 @@ func TestShutdown(t *testing.T) {
|
||||||
|
|
||||||
func TestEmptyBeforeShutdown(t *testing.T) {
|
func TestEmptyBeforeShutdown(t *testing.T) {
|
||||||
const numStreams = 1000
|
const numStreams = 1000
|
||||||
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
m := newActiveStreamMap(true, ActiveStreams)
|
||||||
|
|
||||||
// Add all the streams
|
// Add all the streams
|
||||||
{
|
{
|
||||||
|
@ -138,7 +138,7 @@ func (_ *noopReadyList) Signal(streamID uint32) {}
|
||||||
|
|
||||||
func TestAbort(t *testing.T) {
|
func TestAbort(t *testing.T) {
|
||||||
const numStreams = 1000
|
const numStreams = 1000
|
||||||
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
m := newActiveStreamMap(true, ActiveStreams)
|
||||||
|
|
||||||
var openedStreams sync.Map
|
var openedStreams sync.Map
|
||||||
|
|
||||||
|
|
|
@ -113,11 +113,11 @@ func (p *DefaultMuxerPair) Handshake(testName string) error {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
errGroup, _ := errgroup.WithContext(ctx)
|
errGroup, _ := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() (err error) {
|
errGroup.Go(func() (err error) {
|
||||||
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, NewActiveStreamsMetrics(testName, "edge"))
|
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, ActiveStreams)
|
||||||
return errors.Wrap(err, "edge handshake failure")
|
return errors.Wrap(err, "edge handshake failure")
|
||||||
})
|
})
|
||||||
errGroup.Go(func() (err error) {
|
errGroup.Go(func() (err error) {
|
||||||
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, NewActiveStreamsMetrics(testName, "origin"))
|
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, ActiveStreams)
|
||||||
return errors.Wrap(err, "origin handshake failure")
|
return errors.Wrap(err, "origin handshake failure")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/golang-collections/collections/queue"
|
"github.com/golang-collections/collections/queue"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// data points used to compute average receive window and send window size
|
// data points used to compute average receive window and send window size
|
||||||
|
@ -295,14 +294,3 @@ func (r *rate) get() (curr, min, max uint64) {
|
||||||
defer r.lock.RUnlock()
|
defer r.lock.RUnlock()
|
||||||
return r.curr, r.min, r.max
|
return r.curr, r.min, r.max
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewActiveStreamsMetrics(namespace, subsystem string) prometheus.Gauge {
|
|
||||||
activeStreams := prometheus.NewGauge(prometheus.GaugeOpts{
|
|
||||||
Namespace: namespace,
|
|
||||||
Subsystem: subsystem,
|
|
||||||
Name: "active_streams",
|
|
||||||
Help: "Number of active streams created by all muxers.",
|
|
||||||
})
|
|
||||||
prometheus.MustRegister(activeStreams)
|
|
||||||
return activeStreams
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,14 +0,0 @@
|
||||||
package origin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// persistentTCPConn is a wrapper around net.Conn that is noop when Close is called
|
|
||||||
type persistentConn struct {
|
|
||||||
net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pc *persistentConn) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,540 +1,63 @@
|
||||||
package origin
|
package origin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// Metrics uses connection.MetricsNamespace(aka cloudflared) as namespace and connection.TunnelSubsystem
|
||||||
metricsNamespace = "cloudflared"
|
// (tunnel) as subsystem to keep them consistent with the previous qualifier.
|
||||||
tunnelSubsystem = "tunnel"
|
|
||||||
muxerSubsystem = "muxer"
|
|
||||||
)
|
|
||||||
|
|
||||||
type muxerMetrics struct {
|
var (
|
||||||
rtt *prometheus.GaugeVec
|
totalRequests = prometheus.NewCounter(
|
||||||
rttMin *prometheus.GaugeVec
|
|
||||||
rttMax *prometheus.GaugeVec
|
|
||||||
receiveWindowAve *prometheus.GaugeVec
|
|
||||||
sendWindowAve *prometheus.GaugeVec
|
|
||||||
receiveWindowMin *prometheus.GaugeVec
|
|
||||||
receiveWindowMax *prometheus.GaugeVec
|
|
||||||
sendWindowMin *prometheus.GaugeVec
|
|
||||||
sendWindowMax *prometheus.GaugeVec
|
|
||||||
inBoundRateCurr *prometheus.GaugeVec
|
|
||||||
inBoundRateMin *prometheus.GaugeVec
|
|
||||||
inBoundRateMax *prometheus.GaugeVec
|
|
||||||
outBoundRateCurr *prometheus.GaugeVec
|
|
||||||
outBoundRateMin *prometheus.GaugeVec
|
|
||||||
outBoundRateMax *prometheus.GaugeVec
|
|
||||||
compBytesBefore *prometheus.GaugeVec
|
|
||||||
compBytesAfter *prometheus.GaugeVec
|
|
||||||
compRateAve *prometheus.GaugeVec
|
|
||||||
}
|
|
||||||
|
|
||||||
type TunnelMetrics struct {
|
|
||||||
haConnections prometheus.Gauge
|
|
||||||
activeStreams prometheus.Gauge
|
|
||||||
totalRequests prometheus.Counter
|
|
||||||
requestsPerTunnel *prometheus.CounterVec
|
|
||||||
// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
|
|
||||||
concurrentRequestsLock sync.Mutex
|
|
||||||
concurrentRequestsPerTunnel *prometheus.GaugeVec
|
|
||||||
// concurrentRequests records count of concurrent requests for each tunnel
|
|
||||||
concurrentRequests map[string]uint64
|
|
||||||
maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
|
|
||||||
// concurrentRequests records max count of concurrent requests for each tunnel
|
|
||||||
maxConcurrentRequests map[string]uint64
|
|
||||||
timerRetries prometheus.Gauge
|
|
||||||
responseByCode *prometheus.CounterVec
|
|
||||||
responseCodePerTunnel *prometheus.CounterVec
|
|
||||||
serverLocations *prometheus.GaugeVec
|
|
||||||
// locationLock is a mutex for oldServerLocations
|
|
||||||
locationLock sync.Mutex
|
|
||||||
// oldServerLocations stores the last server the tunnel was connected to
|
|
||||||
oldServerLocations map[string]string
|
|
||||||
|
|
||||||
regSuccess *prometheus.CounterVec
|
|
||||||
regFail *prometheus.CounterVec
|
|
||||||
rpcFail *prometheus.CounterVec
|
|
||||||
|
|
||||||
muxerMetrics *muxerMetrics
|
|
||||||
tunnelsHA tunnelsForHA
|
|
||||||
userHostnamesCounts *prometheus.CounterVec
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMuxerMetrics() *muxerMetrics {
|
|
||||||
rtt := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "rtt",
|
|
||||||
Help: "Round-trip time in millisecond",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(rtt)
|
|
||||||
|
|
||||||
rttMin := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "rtt_min",
|
|
||||||
Help: "Shortest round-trip time in millisecond",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(rttMin)
|
|
||||||
|
|
||||||
rttMax := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "rtt_max",
|
|
||||||
Help: "Longest round-trip time in millisecond",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(rttMax)
|
|
||||||
|
|
||||||
receiveWindowAve := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "receive_window_ave",
|
|
||||||
Help: "Average receive window size in bytes",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(receiveWindowAve)
|
|
||||||
|
|
||||||
sendWindowAve := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "send_window_ave",
|
|
||||||
Help: "Average send window size in bytes",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(sendWindowAve)
|
|
||||||
|
|
||||||
receiveWindowMin := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "receive_window_min",
|
|
||||||
Help: "Smallest receive window size in bytes",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(receiveWindowMin)
|
|
||||||
|
|
||||||
receiveWindowMax := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "receive_window_max",
|
|
||||||
Help: "Largest receive window size in bytes",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(receiveWindowMax)
|
|
||||||
|
|
||||||
sendWindowMin := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "send_window_min",
|
|
||||||
Help: "Smallest send window size in bytes",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(sendWindowMin)
|
|
||||||
|
|
||||||
sendWindowMax := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "send_window_max",
|
|
||||||
Help: "Largest send window size in bytes",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(sendWindowMax)
|
|
||||||
|
|
||||||
inBoundRateCurr := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "inbound_bytes_per_sec_curr",
|
|
||||||
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(inBoundRateCurr)
|
|
||||||
|
|
||||||
inBoundRateMin := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "inbound_bytes_per_sec_min",
|
|
||||||
Help: "Minimum non-zero inbounding bytes per second",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(inBoundRateMin)
|
|
||||||
|
|
||||||
inBoundRateMax := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "inbound_bytes_per_sec_max",
|
|
||||||
Help: "Maximum inbounding bytes per second",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(inBoundRateMax)
|
|
||||||
|
|
||||||
outBoundRateCurr := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "outbound_bytes_per_sec_curr",
|
|
||||||
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(outBoundRateCurr)
|
|
||||||
|
|
||||||
outBoundRateMin := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "outbound_bytes_per_sec_min",
|
|
||||||
Help: "Minimum non-zero outbounding bytes per second",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(outBoundRateMin)
|
|
||||||
|
|
||||||
outBoundRateMax := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "outbound_bytes_per_sec_max",
|
|
||||||
Help: "Maximum outbounding bytes per second",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(outBoundRateMax)
|
|
||||||
|
|
||||||
compBytesBefore := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "comp_bytes_before",
|
|
||||||
Help: "Bytes sent via cross-stream compression, pre compression",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(compBytesBefore)
|
|
||||||
|
|
||||||
compBytesAfter := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "comp_bytes_after",
|
|
||||||
Help: "Bytes sent via cross-stream compression, post compression",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(compBytesAfter)
|
|
||||||
|
|
||||||
compRateAve := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: muxerSubsystem,
|
|
||||||
Name: "comp_rate_ave",
|
|
||||||
Help: "Average outbound cross-stream compression ratio",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(compRateAve)
|
|
||||||
|
|
||||||
return &muxerMetrics{
|
|
||||||
rtt: rtt,
|
|
||||||
rttMin: rttMin,
|
|
||||||
rttMax: rttMax,
|
|
||||||
receiveWindowAve: receiveWindowAve,
|
|
||||||
sendWindowAve: sendWindowAve,
|
|
||||||
receiveWindowMin: receiveWindowMin,
|
|
||||||
receiveWindowMax: receiveWindowMax,
|
|
||||||
sendWindowMin: sendWindowMin,
|
|
||||||
sendWindowMax: sendWindowMax,
|
|
||||||
inBoundRateCurr: inBoundRateCurr,
|
|
||||||
inBoundRateMin: inBoundRateMin,
|
|
||||||
inBoundRateMax: inBoundRateMax,
|
|
||||||
outBoundRateCurr: outBoundRateCurr,
|
|
||||||
outBoundRateMin: outBoundRateMin,
|
|
||||||
outBoundRateMax: outBoundRateMax,
|
|
||||||
compBytesBefore: compBytesBefore,
|
|
||||||
compBytesAfter: compBytesAfter,
|
|
||||||
compRateAve: compRateAve,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
|
|
||||||
m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
|
|
||||||
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
|
|
||||||
m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
|
|
||||||
m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
|
|
||||||
m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
|
|
||||||
m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
|
|
||||||
m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
|
|
||||||
m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
|
|
||||||
m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
|
|
||||||
m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
|
|
||||||
m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
|
|
||||||
m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
|
|
||||||
m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
|
|
||||||
m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
|
|
||||||
m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
|
|
||||||
m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value()))
|
|
||||||
m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value()))
|
|
||||||
m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertRTTMilliSec(t time.Duration) float64 {
|
|
||||||
return float64(t / time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Metrics that can be collected without asking the edge
|
|
||||||
func NewTunnelMetrics() *TunnelMetrics {
|
|
||||||
haConnections := prometheus.NewGauge(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: tunnelSubsystem,
|
|
||||||
Name: "ha_connections",
|
|
||||||
Help: "Number of active ha connections",
|
|
||||||
})
|
|
||||||
prometheus.MustRegister(haConnections)
|
|
||||||
|
|
||||||
activeStreams := h2mux.NewActiveStreamsMetrics(metricsNamespace, tunnelSubsystem)
|
|
||||||
|
|
||||||
totalRequests := prometheus.NewCounter(
|
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Namespace: metricsNamespace,
|
Namespace: connection.MetricsNamespace,
|
||||||
Subsystem: tunnelSubsystem,
|
Subsystem: connection.TunnelSubsystem,
|
||||||
Name: "total_requests",
|
Name: "total_requests",
|
||||||
Help: "Amount of requests proxied through all the tunnels",
|
Help: "Amount of requests proxied through all the tunnels",
|
||||||
})
|
|
||||||
prometheus.MustRegister(totalRequests)
|
|
||||||
|
|
||||||
requestsPerTunnel := prometheus.NewCounterVec(
|
|
||||||
prometheus.CounterOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: tunnelSubsystem,
|
|
||||||
Name: "requests_per_tunnel",
|
|
||||||
Help: "Amount of requests proxied through each tunnel",
|
|
||||||
},
|
},
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
)
|
||||||
prometheus.MustRegister(requestsPerTunnel)
|
concurrentRequests = prometheus.NewGauge(
|
||||||
|
|
||||||
concurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
prometheus.GaugeOpts{
|
||||||
Namespace: metricsNamespace,
|
Namespace: connection.MetricsNamespace,
|
||||||
Subsystem: tunnelSubsystem,
|
Subsystem: connection.TunnelSubsystem,
|
||||||
Name: "concurrent_requests_per_tunnel",
|
Name: "concurrent_requests_per_tunnel",
|
||||||
Help: "Concurrent requests proxied through each tunnel",
|
Help: "Concurrent requests proxied through each tunnel",
|
||||||
},
|
},
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
)
|
||||||
prometheus.MustRegister(concurrentRequestsPerTunnel)
|
responseByCode = prometheus.NewCounterVec(
|
||||||
|
|
||||||
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: tunnelSubsystem,
|
|
||||||
Name: "max_concurrent_requests_per_tunnel",
|
|
||||||
Help: "Largest number of concurrent requests proxied through each tunnel so far",
|
|
||||||
},
|
|
||||||
[]string{"connection_id"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
|
|
||||||
|
|
||||||
timerRetries := prometheus.NewGauge(
|
|
||||||
prometheus.GaugeOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: tunnelSubsystem,
|
|
||||||
Name: "timer_retries",
|
|
||||||
Help: "Unacknowledged heart beats count",
|
|
||||||
})
|
|
||||||
prometheus.MustRegister(timerRetries)
|
|
||||||
|
|
||||||
responseByCode := prometheus.NewCounterVec(
|
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Namespace: metricsNamespace,
|
Namespace: connection.MetricsNamespace,
|
||||||
Subsystem: tunnelSubsystem,
|
Subsystem: connection.TunnelSubsystem,
|
||||||
Name: "response_by_code",
|
Name: "response_by_code",
|
||||||
Help: "Count of responses by HTTP status code",
|
Help: "Count of responses by HTTP status code",
|
||||||
},
|
},
|
||||||
[]string{"status_code"},
|
[]string{"status_code"},
|
||||||
)
|
)
|
||||||
prometheus.MustRegister(responseByCode)
|
haConnections = prometheus.NewGauge(
|
||||||
|
|
||||||
responseCodePerTunnel := prometheus.NewCounterVec(
|
|
||||||
prometheus.CounterOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: tunnelSubsystem,
|
|
||||||
Name: "response_code_per_tunnel",
|
|
||||||
Help: "Count of responses by HTTP status code fore each tunnel",
|
|
||||||
},
|
|
||||||
[]string{"connection_id", "status_code"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(responseCodePerTunnel)
|
|
||||||
|
|
||||||
serverLocations := prometheus.NewGaugeVec(
|
|
||||||
prometheus.GaugeOpts{
|
prometheus.GaugeOpts{
|
||||||
Namespace: metricsNamespace,
|
Namespace: connection.MetricsNamespace,
|
||||||
Subsystem: tunnelSubsystem,
|
Subsystem: connection.TunnelSubsystem,
|
||||||
Name: "server_locations",
|
Name: "ha_connections",
|
||||||
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
|
Help: "Number of active ha connections",
|
||||||
},
|
},
|
||||||
[]string{"connection_id", "location"},
|
|
||||||
)
|
)
|
||||||
prometheus.MustRegister(serverLocations)
|
|
||||||
|
|
||||||
rpcFail := prometheus.NewCounterVec(
|
|
||||||
prometheus.CounterOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: tunnelSubsystem,
|
|
||||||
Name: "tunnel_rpc_fail",
|
|
||||||
Help: "Count of RPC connection errors by type",
|
|
||||||
},
|
|
||||||
[]string{"error", "rpcName"},
|
|
||||||
)
|
)
|
||||||
prometheus.MustRegister(rpcFail)
|
|
||||||
|
|
||||||
registerFail := prometheus.NewCounterVec(
|
func init() {
|
||||||
prometheus.CounterOpts{
|
prometheus.MustRegister(
|
||||||
Namespace: metricsNamespace,
|
totalRequests,
|
||||||
Subsystem: tunnelSubsystem,
|
concurrentRequests,
|
||||||
Name: "tunnel_register_fail",
|
responseByCode,
|
||||||
Help: "Count of tunnel registration errors by type",
|
haConnections,
|
||||||
},
|
|
||||||
[]string{"error", "rpcName"},
|
|
||||||
)
|
)
|
||||||
prometheus.MustRegister(registerFail)
|
|
||||||
|
|
||||||
userHostnamesCounts := prometheus.NewCounterVec(
|
|
||||||
prometheus.CounterOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: tunnelSubsystem,
|
|
||||||
Name: "user_hostnames_counts",
|
|
||||||
Help: "Which user hostnames cloudflared is serving",
|
|
||||||
},
|
|
||||||
[]string{"userHostname"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(userHostnamesCounts)
|
|
||||||
|
|
||||||
registerSuccess := prometheus.NewCounterVec(
|
|
||||||
prometheus.CounterOpts{
|
|
||||||
Namespace: metricsNamespace,
|
|
||||||
Subsystem: tunnelSubsystem,
|
|
||||||
Name: "tunnel_register_success",
|
|
||||||
Help: "Count of successful tunnel registrations",
|
|
||||||
},
|
|
||||||
[]string{"rpcName"},
|
|
||||||
)
|
|
||||||
prometheus.MustRegister(registerSuccess)
|
|
||||||
|
|
||||||
return &TunnelMetrics{
|
|
||||||
haConnections: haConnections,
|
|
||||||
activeStreams: activeStreams,
|
|
||||||
totalRequests: totalRequests,
|
|
||||||
requestsPerTunnel: requestsPerTunnel,
|
|
||||||
concurrentRequestsPerTunnel: concurrentRequestsPerTunnel,
|
|
||||||
concurrentRequests: make(map[string]uint64),
|
|
||||||
maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
|
|
||||||
maxConcurrentRequests: make(map[string]uint64),
|
|
||||||
timerRetries: timerRetries,
|
|
||||||
responseByCode: responseByCode,
|
|
||||||
responseCodePerTunnel: responseCodePerTunnel,
|
|
||||||
serverLocations: serverLocations,
|
|
||||||
oldServerLocations: make(map[string]string),
|
|
||||||
muxerMetrics: newMuxerMetrics(),
|
|
||||||
tunnelsHA: NewTunnelsForHA(),
|
|
||||||
regSuccess: registerSuccess,
|
|
||||||
regFail: registerFail,
|
|
||||||
rpcFail: rpcFail,
|
|
||||||
userHostnamesCounts: userHostnamesCounts,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunnelMetrics) incrementHaConnections() {
|
func incrementRequests() {
|
||||||
t.haConnections.Inc()
|
totalRequests.Inc()
|
||||||
|
concurrentRequests.Inc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunnelMetrics) decrementHaConnections() {
|
func decrementConcurrentRequests() {
|
||||||
t.haConnections.Dec()
|
concurrentRequests.Dec()
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
|
|
||||||
t.muxerMetrics.update(connectionID, metrics)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TunnelMetrics) incrementRequests(connectionID string) {
|
|
||||||
t.concurrentRequestsLock.Lock()
|
|
||||||
var concurrentRequests uint64
|
|
||||||
var ok bool
|
|
||||||
if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok {
|
|
||||||
t.concurrentRequests[connectionID]++
|
|
||||||
concurrentRequests++
|
|
||||||
} else {
|
|
||||||
t.concurrentRequests[connectionID] = 1
|
|
||||||
concurrentRequests = 1
|
|
||||||
}
|
|
||||||
if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok {
|
|
||||||
t.maxConcurrentRequests[connectionID] = concurrentRequests
|
|
||||||
t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests))
|
|
||||||
}
|
|
||||||
t.concurrentRequestsLock.Unlock()
|
|
||||||
|
|
||||||
t.totalRequests.Inc()
|
|
||||||
t.requestsPerTunnel.WithLabelValues(connectionID).Inc()
|
|
||||||
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
|
|
||||||
t.concurrentRequestsLock.Lock()
|
|
||||||
if _, ok := t.concurrentRequests[connectionID]; ok {
|
|
||||||
t.concurrentRequests[connectionID]--
|
|
||||||
}
|
|
||||||
t.concurrentRequestsLock.Unlock()
|
|
||||||
|
|
||||||
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
|
|
||||||
t.responseByCode.WithLabelValues(code).Inc()
|
|
||||||
t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
|
|
||||||
t.locationLock.Lock()
|
|
||||||
defer t.locationLock.Unlock()
|
|
||||||
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
|
|
||||||
return
|
|
||||||
} else if ok {
|
|
||||||
t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
|
|
||||||
}
|
|
||||||
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
|
|
||||||
t.oldServerLocations[connectionID] = loc
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,121 +0,0 @@
|
||||||
package origin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
// can only be called once
|
|
||||||
var m = NewTunnelMetrics()
|
|
||||||
|
|
||||||
func TestConcurrentRequestsSingleTunnel(t *testing.T) {
|
|
||||||
routines := 20
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(routines)
|
|
||||||
for i := 0; i < routines; i++ {
|
|
||||||
go func() {
|
|
||||||
m.incrementRequests("0")
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
assert.Len(t, m.concurrentRequests, 1)
|
|
||||||
assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
|
|
||||||
assert.Len(t, m.maxConcurrentRequests, 1)
|
|
||||||
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
|
|
||||||
|
|
||||||
wg.Add(routines / 2)
|
|
||||||
for i := 0; i < routines/2; i++ {
|
|
||||||
go func() {
|
|
||||||
m.decrementConcurrentRequests("0")
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
|
|
||||||
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConcurrentRequestsMultiTunnel(t *testing.T) {
|
|
||||||
m.concurrentRequests = make(map[string]uint64)
|
|
||||||
m.maxConcurrentRequests = make(map[string]uint64)
|
|
||||||
tunnels := 20
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(tunnels)
|
|
||||||
for i := 0; i < tunnels; i++ {
|
|
||||||
go func(i int) {
|
|
||||||
// if we have j < i, then tunnel 0 won't have a chance to call incrementRequests
|
|
||||||
for j := 0; j < i+1; j++ {
|
|
||||||
id := strconv.Itoa(i)
|
|
||||||
m.incrementRequests(id)
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
assert.Len(t, m.concurrentRequests, tunnels)
|
|
||||||
assert.Len(t, m.maxConcurrentRequests, tunnels)
|
|
||||||
for i := 0; i < tunnels; i++ {
|
|
||||||
id := strconv.Itoa(i)
|
|
||||||
assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
|
|
||||||
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(tunnels)
|
|
||||||
for i := 0; i < tunnels; i++ {
|
|
||||||
go func(i int) {
|
|
||||||
for j := 0; j < i+1; j++ {
|
|
||||||
id := strconv.Itoa(i)
|
|
||||||
m.decrementConcurrentRequests(id)
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
assert.Len(t, m.concurrentRequests, tunnels)
|
|
||||||
assert.Len(t, m.maxConcurrentRequests, tunnels)
|
|
||||||
for i := 0; i < tunnels; i++ {
|
|
||||||
id := strconv.Itoa(i)
|
|
||||||
assert.Equal(t, uint64(0), m.concurrentRequests[id])
|
|
||||||
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegisterServerLocation(t *testing.T) {
|
|
||||||
tunnels := 20
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(tunnels)
|
|
||||||
for i := 0; i < tunnels; i++ {
|
|
||||||
go func(i int) {
|
|
||||||
id := strconv.Itoa(i)
|
|
||||||
m.registerServerLocation(id, "LHR")
|
|
||||||
wg.Done()
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
for i := 0; i < tunnels; i++ {
|
|
||||||
id := strconv.Itoa(i)
|
|
||||||
assert.Equal(t, "LHR", m.oldServerLocations[id])
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(tunnels)
|
|
||||||
for i := 0; i < tunnels; i++ {
|
|
||||||
go func(i int) {
|
|
||||||
id := strconv.Itoa(i)
|
|
||||||
m.registerServerLocation(id, "AUS")
|
|
||||||
wg.Done()
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
for i := 0; i < tunnels; i++ {
|
|
||||||
id := strconv.Itoa(i)
|
|
||||||
assert.Equal(t, "AUS", m.oldServerLocations[id])
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,208 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/buffer"
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||||
|
)
|
||||||
|
|
||||||
|
type client struct {
|
||||||
|
config *ProxyConfig
|
||||||
|
logger logger.Service
|
||||||
|
bufferPool *buffer.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(config *ProxyConfig, logger logger.Service) connection.OriginClient {
|
||||||
|
return &client{
|
||||||
|
config: config,
|
||||||
|
logger: logger,
|
||||||
|
bufferPool: buffer.NewPool(512 * 1024),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProxyConfig struct {
|
||||||
|
Client http.RoundTripper
|
||||||
|
URL *url.URL
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
HostHeader string
|
||||||
|
NoChunkedEncoding bool
|
||||||
|
Tags []tunnelpogs.Tag
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
|
||||||
|
incrementRequests()
|
||||||
|
defer decrementConcurrentRequests()
|
||||||
|
|
||||||
|
cfRay := findCfRayHeader(req)
|
||||||
|
lbProbe := isLBProbeRequest(req)
|
||||||
|
|
||||||
|
c.appendTagHeaders(req)
|
||||||
|
c.logRequest(req, cfRay, lbProbe)
|
||||||
|
var (
|
||||||
|
resp *http.Response
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if isWebsocket {
|
||||||
|
resp, err = c.proxyWebsocket(w, req)
|
||||||
|
} else {
|
||||||
|
resp, err = c.proxyHTTP(w, req)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Errorf("HTTP request error: %s", err)
|
||||||
|
responseByCode.WithLabelValues("502").Inc()
|
||||||
|
w.WriteErrorResponse(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.logResponseOk(resp, cfRay, lbProbe)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
|
||||||
|
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||||
|
if c.config.NoChunkedEncoding {
|
||||||
|
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||||
|
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||||
|
if err == nil {
|
||||||
|
req.ContentLength = int64(cLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request origin to keep connection alive to improve performance
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
c.setHostHeader(req)
|
||||||
|
|
||||||
|
resp, err := c.config.Client.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error proxying request to origin")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
err = w.WriteRespHeaders(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
|
}
|
||||||
|
if isEventStream(resp) {
|
||||||
|
//h.observer.Debug("Detected Server-Side Events from Origin")
|
||||||
|
c.writeEventStream(w, resp.Body)
|
||||||
|
} else {
|
||||||
|
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||||
|
// compression generates dictionary on first write
|
||||||
|
buf := c.bufferPool.Get()
|
||||||
|
defer c.bufferPool.Put(buf)
|
||||||
|
io.CopyBuffer(w, resp.Body, buf)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
|
||||||
|
c.setHostHeader(req)
|
||||||
|
|
||||||
|
conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
err = w.WriteRespHeaders(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
|
}
|
||||||
|
// Copy to/from stream to the undelying connection. Use the underlying
|
||||||
|
// connection because cloudflared doesn't operate on the message themselves
|
||||||
|
websocket.Stream(conn.UnderlyingConn(), w)
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
||||||
|
reader := bufio.NewReader(respBody)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
w.Write(line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) setHostHeader(req *http.Request) {
|
||||||
|
if c.config.HostHeader != "" {
|
||||||
|
req.Header.Set("Host", c.config.HostHeader)
|
||||||
|
req.Host = c.config.HostHeader
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) appendTagHeaders(r *http.Request) {
|
||||||
|
for _, tag := range c.config.Tags {
|
||||||
|
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool) {
|
||||||
|
if cfRay != "" {
|
||||||
|
c.logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
|
||||||
|
} else if lbProbe {
|
||||||
|
c.logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
|
||||||
|
} else {
|
||||||
|
c.logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto)
|
||||||
|
}
|
||||||
|
c.logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
|
||||||
|
|
||||||
|
if contentLen := r.ContentLength; contentLen == -1 {
|
||||||
|
c.logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
|
||||||
|
} else {
|
||||||
|
c.logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
|
||||||
|
responseByCode.WithLabelValues("200").Inc()
|
||||||
|
if cfRay != "" {
|
||||||
|
c.logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
|
||||||
|
} else if lbProbe {
|
||||||
|
c.logger.Debugf("Response to Load Balancer health check %s", r.Status)
|
||||||
|
} else {
|
||||||
|
c.logger.Infof("%s", r.Status)
|
||||||
|
}
|
||||||
|
c.logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
|
||||||
|
|
||||||
|
if contentLen := r.ContentLength; contentLen == -1 {
|
||||||
|
c.logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
|
||||||
|
} else {
|
||||||
|
c.logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findCfRayHeader(req *http.Request) string {
|
||||||
|
return req.Header.Get("Cf-Ray")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLBProbeRequest(req *http.Request) bool {
|
||||||
|
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
func uint8ToString(input uint8) string {
|
||||||
|
return strconv.FormatUint(uint64(input), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isEventStream(response *http.Response) bool {
|
||||||
|
if response.Header.Get("content-type") == "text/event-stream" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
|
@ -7,11 +7,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -138,52 +134,3 @@ func (cm *reconnectCredentialManager) RefreshAuth(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReconnectTunnel(
|
|
||||||
ctx context.Context,
|
|
||||||
muxer *h2mux.Muxer,
|
|
||||||
config *TunnelConfig,
|
|
||||||
logger logger.Service,
|
|
||||||
connectionID uint8,
|
|
||||||
originLocalAddr string,
|
|
||||||
uuid uuid.UUID,
|
|
||||||
credentialManager *reconnectCredentialManager,
|
|
||||||
) error {
|
|
||||||
token, err := credentialManager.ReconnectToken()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
eventDigest, err := credentialManager.EventDigest(connectionID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
connDigest, err := credentialManager.ConnDigest(connectionID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
config.TransportLogger.Debug("initiating RPC stream to reconnect")
|
|
||||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, reconnect)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rpcClient.Close()
|
|
||||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
|
||||||
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
|
|
||||||
registration := rpcClient.ReconnectTunnel(
|
|
||||||
ctx,
|
|
||||||
token,
|
|
||||||
eventDigest,
|
|
||||||
connDigest,
|
|
||||||
config.Hostname,
|
|
||||||
config.RegistrationOptions(connectionID, originLocalAddr, uuid),
|
|
||||||
)
|
|
||||||
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
|
||||||
// ReconnectTunnel RPC failure
|
|
||||||
return processRegisterTunnelError(registrationErr, config.Metrics, reconnect)
|
|
||||||
}
|
|
||||||
return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager)
|
|
||||||
}
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/buffer"
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
@ -56,8 +55,7 @@ type Supervisor struct {
|
||||||
logger logger.Service
|
logger logger.Service
|
||||||
|
|
||||||
reconnectCredentialManager *reconnectCredentialManager
|
reconnectCredentialManager *reconnectCredentialManager
|
||||||
|
useReconnectToken bool
|
||||||
bufferPool *buffer.Pool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type resolveResult struct {
|
type resolveResult struct {
|
||||||
|
@ -76,28 +74,33 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if len(config.EdgeAddrs) > 0 {
|
if len(config.EdgeAddrs) > 0 {
|
||||||
edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs)
|
edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs)
|
||||||
} else {
|
} else {
|
||||||
edgeIPs, err = edgediscovery.ResolveEdge(config.Logger)
|
edgeIPs, err = edgediscovery.ResolveEdge(config.Observer)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
useReconnectToken := false
|
||||||
|
if config.ClassicTunnel != nil {
|
||||||
|
useReconnectToken = config.ClassicTunnel.UseReconnectToken
|
||||||
|
}
|
||||||
|
|
||||||
return &Supervisor{
|
return &Supervisor{
|
||||||
cloudflaredUUID: cloudflaredUUID,
|
cloudflaredUUID: cloudflaredUUID,
|
||||||
config: config,
|
config: config,
|
||||||
edgeIPs: edgeIPs,
|
edgeIPs: edgeIPs,
|
||||||
tunnelErrors: make(chan tunnelError),
|
tunnelErrors: make(chan tunnelError),
|
||||||
tunnelsConnecting: map[int]chan struct{}{},
|
tunnelsConnecting: map[int]chan struct{}{},
|
||||||
logger: config.Logger,
|
logger: config.Observer,
|
||||||
reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections),
|
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
|
||||||
bufferPool: buffer.NewPool(512 * 1024),
|
useReconnectToken: useReconnectToken,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
||||||
logger := s.config.Logger
|
logger := s.config.Observer
|
||||||
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -110,7 +113,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||||
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
|
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
|
||||||
var refreshAuthBackoffTimer <-chan time.Time
|
var refreshAuthBackoffTimer <-chan time.Time
|
||||||
|
|
||||||
if s.config.UseReconnectToken {
|
if s.useReconnectToken {
|
||||||
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
|
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
|
||||||
refreshAuthBackoffTimer = timer
|
refreshAuthBackoffTimer = timer
|
||||||
} else {
|
} else {
|
||||||
|
@ -227,7 +230,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, reconnectCh)
|
||||||
// If the first tunnel disconnects, keep restarting it.
|
// If the first tunnel disconnects, keep restarting it.
|
||||||
edgeErrors := 0
|
edgeErrors := 0
|
||||||
for s.unusedIPs() {
|
for s.unusedIPs() {
|
||||||
|
@ -239,7 +242,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
return
|
return
|
||||||
// try the next address if it was a dialError(network problem) or
|
// try the next address if it was a dialError(network problem) or
|
||||||
// dupConnRegisterTunnelError
|
// dupConnRegisterTunnelError
|
||||||
case connection.DialError, dupConnRegisterTunnelError:
|
case edgediscovery.DialError, connection.DupConnRegisterTunnelError:
|
||||||
edgeErrors++
|
edgeErrors++
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
|
@ -250,7 +253,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, reconnectCh)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -269,7 +272,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, reconnectCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
||||||
|
@ -301,7 +304,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
|
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -311,8 +314,8 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
|
||||||
// This callback is invoked by h2mux when the edge initiates a stream.
|
// This callback is invoked by h2mux when the edge initiates a stream.
|
||||||
return nil // noop
|
return nil // noop
|
||||||
})
|
})
|
||||||
muxerConfig := s.config.muxerConfig(handler)
|
muxerConfig := s.config.MuxerConfig.H2MuxerConfig(handler, s.logger)
|
||||||
muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
|
muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig, h2mux.ActiveStreams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -323,23 +326,15 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
|
||||||
<-muxer.Shutdown()
|
<-muxer.Shutdown()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, s.config, authenticate)
|
stream, err := muxer.OpenRPCStream(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
rpcClient := connection.NewTunnelServerClient(ctx, stream, s.logger)
|
||||||
defer rpcClient.Close()
|
defer rpcClient.Close()
|
||||||
|
|
||||||
const arbitraryConnectionID = uint8(0)
|
const arbitraryConnectionID = uint8(0)
|
||||||
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
|
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
|
||||||
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
||||||
authResponse, err := rpcClient.Authenticate(
|
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
|
||||||
ctx,
|
|
||||||
s.config.OriginCert,
|
|
||||||
s.config.Hostname,
|
|
||||||
registrationOptions,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return authResponse.Outcome(), nil
|
|
||||||
}
|
}
|
||||||
|
|
549
origin/tunnel.go
549
origin/tunnel.go
|
@ -5,9 +5,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"runtime/debug"
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -17,26 +15,22 @@ import (
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/buffer"
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dialTimeout = 15 * time.Second
|
dialTimeout = 15 * time.Second
|
||||||
openStreamTimeout = 30 * time.Second
|
|
||||||
muxerTimeout = 5 * time.Second
|
muxerTimeout = 5 * time.Second
|
||||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
|
||||||
DuplicateConnectionError = "EDUPCONN"
|
DuplicateConnectionError = "EDUPCONN"
|
||||||
FeatureSerializedHeaders = "serialized_headers"
|
FeatureSerializedHeaders = "serialized_headers"
|
||||||
FeatureQuickReconnects = "quick_reconnects"
|
FeatureQuickReconnects = "quick_reconnects"
|
||||||
|
@ -52,49 +46,31 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelConfig struct {
|
type TunnelConfig struct {
|
||||||
|
ConnectionConfig *connection.Config
|
||||||
|
ProxyConfig *ProxyConfig
|
||||||
BuildInfo *buildinfo.BuildInfo
|
BuildInfo *buildinfo.BuildInfo
|
||||||
ClientID string
|
ClientID string
|
||||||
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
|
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
|
||||||
CompressionQuality uint64
|
|
||||||
EdgeAddrs []string
|
EdgeAddrs []string
|
||||||
GracePeriod time.Duration
|
|
||||||
HAConnections int
|
HAConnections int
|
||||||
HeartbeatInterval time.Duration
|
|
||||||
Hostname string
|
|
||||||
IncidentLookup IncidentLookup
|
IncidentLookup IncidentLookup
|
||||||
IsAutoupdated bool
|
IsAutoupdated bool
|
||||||
IsFreeTunnel bool
|
IsFreeTunnel bool
|
||||||
LBPool string
|
LBPool string
|
||||||
Logger logger.Service
|
Logger logger.Service
|
||||||
TransportLogger logger.Service
|
Observer *connection.Observer
|
||||||
MaxHeartbeats uint64
|
|
||||||
Metrics *TunnelMetrics
|
|
||||||
MetricsUpdateFreq time.Duration
|
|
||||||
OriginCert []byte
|
|
||||||
ReportedVersion string
|
ReportedVersion string
|
||||||
Retries uint
|
Retries uint
|
||||||
RunFromTerminal bool
|
RunFromTerminal bool
|
||||||
Tags []tunnelpogs.Tag
|
TLSConfig *tls.Config
|
||||||
TlsConfig *tls.Config
|
|
||||||
WSGI bool
|
|
||||||
|
|
||||||
// feature-flag to use new edge reconnect tokens
|
NamedTunnel *connection.NamedTunnelConfig
|
||||||
UseReconnectToken bool
|
ClassicTunnel *connection.ClassicTunnelConfig
|
||||||
|
MuxerConfig *connection.MuxerConfig
|
||||||
NamedTunnel *NamedTunnelConfig
|
TunnelEventChan chan ui.TunnelEvent
|
||||||
ReplaceExisting bool
|
|
||||||
TunnelEventChan chan<- ui.TunnelEvent
|
|
||||||
IngressRules ingress.Ingress
|
IngressRules ingress.Ingress
|
||||||
}
|
}
|
||||||
|
|
||||||
type dupConnRegisterTunnelError struct{}
|
|
||||||
|
|
||||||
var errDuplicationConnection = &dupConnRegisterTunnelError{}
|
|
||||||
|
|
||||||
func (e dupConnRegisterTunnelError) Error() string {
|
|
||||||
return "already connected to this server, trying another address"
|
|
||||||
}
|
|
||||||
|
|
||||||
type muxerShutdownError struct{}
|
type muxerShutdownError struct{}
|
||||||
|
|
||||||
func (e muxerShutdownError) Error() string {
|
func (e muxerShutdownError) Error() string {
|
||||||
|
@ -125,18 +101,6 @@ func (e clientRegisterTunnelError) Error() string {
|
||||||
return e.cause.Error()
|
return e.cause.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
|
|
||||||
return h2mux.MuxerConfig{
|
|
||||||
Timeout: muxerTimeout,
|
|
||||||
Handler: handler,
|
|
||||||
IsClient: true,
|
|
||||||
HeartbeatInterval: c.HeartbeatInterval,
|
|
||||||
MaxHeartbeats: c.MaxHeartbeats,
|
|
||||||
Logger: c.TransportLogger,
|
|
||||||
CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
|
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
|
||||||
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
||||||
if c.HAConnections <= 1 && c.LBPool == "" {
|
if c.HAConnections <= 1 && c.LBPool == "" {
|
||||||
|
@ -148,12 +112,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||||
OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch),
|
OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch),
|
||||||
ExistingTunnelPolicy: policy,
|
ExistingTunnelPolicy: policy,
|
||||||
PoolName: c.LBPool,
|
PoolName: c.LBPool,
|
||||||
Tags: c.Tags,
|
Tags: c.ProxyConfig.Tags,
|
||||||
ConnectionID: connectionID,
|
ConnectionID: connectionID,
|
||||||
OriginLocalIP: OriginLocalIP,
|
OriginLocalIP: OriginLocalIP,
|
||||||
IsAutoupdated: c.IsAutoupdated,
|
IsAutoupdated: c.IsAutoupdated,
|
||||||
RunFromTerminal: c.RunFromTerminal,
|
RunFromTerminal: c.RunFromTerminal,
|
||||||
CompressionQuality: c.CompressionQuality,
|
CompressionQuality: uint64(c.MuxerConfig.CompressionSetting),
|
||||||
UUID: uuid.String(),
|
UUID: uuid.String(),
|
||||||
Features: c.SupportedFeatures(),
|
Features: c.SupportedFeatures(),
|
||||||
}
|
}
|
||||||
|
@ -167,8 +131,8 @@ func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAtte
|
||||||
return &tunnelpogs.ConnectionOptions{
|
return &tunnelpogs.ConnectionOptions{
|
||||||
Client: c.NamedTunnel.Client,
|
Client: c.NamedTunnel.Client,
|
||||||
OriginLocalIP: originIP,
|
OriginLocalIP: originIP,
|
||||||
ReplaceExisting: c.ReplaceExisting,
|
ReplaceExisting: c.ConnectionConfig.ReplaceExisting,
|
||||||
CompressionQuality: uint8(c.CompressionQuality),
|
CompressionQuality: uint8(c.MuxerConfig.CompressionSetting),
|
||||||
NumPreviousAttempts: numPreviousAttempts,
|
NumPreviousAttempts: numPreviousAttempts,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -181,35 +145,6 @@ func (c *TunnelConfig) SupportedFeatures() []string {
|
||||||
return features
|
return features
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TunnelConfig) IsTrialTunnel() bool {
|
|
||||||
return c.Hostname == ""
|
|
||||||
}
|
|
||||||
|
|
||||||
type NamedTunnelConfig struct {
|
|
||||||
Auth pogs.TunnelAuth
|
|
||||||
ID uuid.UUID
|
|
||||||
Client pogs.ClientInfo
|
|
||||||
Protocol Protocol
|
|
||||||
}
|
|
||||||
|
|
||||||
type Protocol int64
|
|
||||||
|
|
||||||
const (
|
|
||||||
h2muxProtocol Protocol = iota
|
|
||||||
http2Protocol
|
|
||||||
)
|
|
||||||
|
|
||||||
func ParseProtocol(s string) (Protocol, bool) {
|
|
||||||
switch s {
|
|
||||||
case "h2mux":
|
|
||||||
return h2muxProtocol, true
|
|
||||||
case "http2":
|
|
||||||
return http2Protocol, true
|
|
||||||
default:
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
|
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
|
||||||
s, err := NewSupervisor(config, cloudflaredID)
|
s, err := NewSupervisor(config, cloudflaredID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -225,11 +160,11 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
connectionIndex uint8,
|
connectionIndex uint8,
|
||||||
connectedSignal *signal.Signal,
|
connectedSignal *signal.Signal,
|
||||||
cloudflaredUUID uuid.UUID,
|
cloudflaredUUID uuid.UUID,
|
||||||
bufferPool *buffer.Pool,
|
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
) error {
|
) error {
|
||||||
config.Metrics.incrementHaConnections()
|
haConnections.Inc()
|
||||||
defer config.Metrics.decrementHaConnections()
|
defer haConnections.Dec()
|
||||||
|
|
||||||
backoff := BackoffHandler{MaxRetries: config.Retries}
|
backoff := BackoffHandler{MaxRetries: config.Retries}
|
||||||
connectedFuse := h2mux.NewBooleanFuse()
|
connectedFuse := h2mux.NewBooleanFuse()
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -244,12 +179,10 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
ctx,
|
ctx,
|
||||||
credentialManager,
|
credentialManager,
|
||||||
config,
|
config,
|
||||||
config.Logger,
|
|
||||||
addr, connectionIndex,
|
addr, connectionIndex,
|
||||||
connectedFuse,
|
connectedFuse,
|
||||||
&backoff,
|
&backoff,
|
||||||
cloudflaredUUID,
|
cloudflaredUUID,
|
||||||
bufferPool,
|
|
||||||
reconnectCh,
|
reconnectCh,
|
||||||
)
|
)
|
||||||
if recoverable {
|
if recoverable {
|
||||||
|
@ -257,7 +190,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||||
if config.TunnelEventChan != nil {
|
if config.TunnelEventChan != nil {
|
||||||
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
|
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
|
||||||
}
|
}
|
||||||
config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration)
|
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
|
||||||
backoff.Backoff(ctx)
|
backoff.Backoff(ctx)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -270,13 +203,11 @@ func ServeTunnel(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
credentialManager *reconnectCredentialManager,
|
credentialManager *reconnectCredentialManager,
|
||||||
config *TunnelConfig,
|
config *TunnelConfig,
|
||||||
logger logger.Service,
|
|
||||||
addr *net.TCPAddr,
|
addr *net.TCPAddr,
|
||||||
connectionIndex uint8,
|
connectionIndex uint8,
|
||||||
connectedFuse *h2mux.BooleanFuse,
|
fuse *h2mux.BooleanFuse,
|
||||||
backoff *BackoffHandler,
|
backoff *BackoffHandler,
|
||||||
cloudflaredUUID uuid.UUID,
|
cloudflaredUUID uuid.UUID,
|
||||||
bufferPool *buffer.Pool,
|
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
// Treat panics as recoverable errors
|
// Treat panics as recoverable errors
|
||||||
|
@ -287,6 +218,7 @@ func ServeTunnel(
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("ServeTunnel: %v", r)
|
err = fmt.Errorf("ServeTunnel: %v", r)
|
||||||
}
|
}
|
||||||
|
err = errors.Wrapf(err, "stack trace: %s", string(debug.Stack()))
|
||||||
recoverable = true
|
recoverable = true
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -298,203 +230,107 @@ func ServeTunnel(
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
connectionTag := uint8ToString(connectionIndex)
|
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr)
|
||||||
|
|
||||||
if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol {
|
|
||||||
return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns error from parsing the origin URL or handshake errors
|
|
||||||
handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err.(type) {
|
|
||||||
case connection.DialError:
|
|
||||||
logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err)
|
|
||||||
case h2mux.MuxerHandshakeError:
|
|
||||||
logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err)
|
|
||||||
default:
|
|
||||||
logger.Errorf("Connection %d failed: %s", connectionIndex, err)
|
|
||||||
return err, false
|
|
||||||
}
|
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
|
connectedFuse := &connectedFuse{
|
||||||
|
fuse: fuse,
|
||||||
|
backoff: backoff,
|
||||||
|
}
|
||||||
|
if config.NamedTunnel != nil && config.NamedTunnel.Protocol == connection.HTTP2 {
|
||||||
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
|
||||||
|
return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh)
|
||||||
|
}
|
||||||
|
return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ServeH2mux(
|
||||||
|
ctx context.Context,
|
||||||
|
credentialManager *reconnectCredentialManager,
|
||||||
|
config *TunnelConfig,
|
||||||
|
edgeConn net.Conn,
|
||||||
|
connectionIndex uint8,
|
||||||
|
connectedFuse *connectedFuse,
|
||||||
|
cloudflaredUUID uuid.UUID,
|
||||||
|
reconnectCh chan ReconnectSignal,
|
||||||
|
) (err error, recoverable bool) {
|
||||||
|
// Returns error from parsing the origin URL or handshake errors
|
||||||
|
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
|
||||||
|
if err != nil {
|
||||||
|
return err, recoverable
|
||||||
|
}
|
||||||
|
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
errGroup.Go(func() (err error) {
|
errGroup.Go(func() (err error) {
|
||||||
defer func() {
|
|
||||||
if err == nil {
|
|
||||||
connectedFuse.Fuse(true)
|
|
||||||
backoff.SetGracePeriod()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if config.UseReconnectToken && connectedFuse.Value() {
|
|
||||||
err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// log errors and proceed to RegisterTunnel
|
|
||||||
logger.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", connectionIndex, err)
|
|
||||||
}
|
|
||||||
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
|
|
||||||
})
|
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
|
||||||
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-serveCtx.Done():
|
|
||||||
// UnregisterTunnel blocks until the RPC call returns
|
|
||||||
if connectedFuse.Value() {
|
|
||||||
if config.NamedTunnel != nil {
|
if config.NamedTunnel != nil {
|
||||||
_ = UnregisterConnection(ctx, handler.muxer, config)
|
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
|
||||||
} else {
|
return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
|
||||||
_ = UnregisterTunnel(handler.muxer, config)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
handler.muxer.Shutdown()
|
|
||||||
return nil
|
|
||||||
case <-updateMetricsTickC:
|
|
||||||
handler.UpdateMetrics(connectionTag)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
||||||
|
return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case reconnect := <-reconnectCh:
|
|
||||||
return &reconnect
|
|
||||||
case <-serveCtx.Done():
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
|
||||||
// All routines should stop when muxer finish serving. When muxer is shutdown
|
|
||||||
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
|
||||||
// here to notify other routines to stop
|
|
||||||
err := handler.muxer.Serve(serveCtx)
|
|
||||||
if err == nil {
|
|
||||||
return muxerShutdownError{}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
err = errGroup.Wait()
|
err = errGroup.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err := err.(type) {
|
switch err := err.(type) {
|
||||||
case *dupConnRegisterTunnelError:
|
case *connection.DupConnRegisterTunnelError:
|
||||||
// don't retry this connection anymore, let supervisor pick new a address
|
// don't retry this connection anymore, let supervisor pick new a address
|
||||||
return err, false
|
return err, false
|
||||||
case *serverRegisterTunnelError:
|
case *serverRegisterTunnelError:
|
||||||
logger.Errorf("Register tunnel error from server side: %s", err.cause)
|
config.Logger.Errorf("Register tunnel error from server side: %s", err.cause)
|
||||||
// Don't send registration error return from server to Sentry. They are
|
// Don't send registration error return from server to Sentry. They are
|
||||||
// logged on server side
|
// logged on server side
|
||||||
if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
|
if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
|
||||||
logger.Error(activeIncidentsMsg(incidents))
|
config.Logger.Error(activeIncidentsMsg(incidents))
|
||||||
}
|
}
|
||||||
return err.cause, !err.permanent
|
return err.cause, !err.permanent
|
||||||
case *clientRegisterTunnelError:
|
case *clientRegisterTunnelError:
|
||||||
logger.Errorf("Register tunnel error on client side: %s", err.cause)
|
config.Logger.Errorf("Register tunnel error on client side: %s", err.cause)
|
||||||
return err, true
|
return err, true
|
||||||
case *muxerShutdownError:
|
case *muxerShutdownError:
|
||||||
logger.Info("Muxer shutdown")
|
config.Logger.Info("Muxer shutdown")
|
||||||
return err, true
|
return err, true
|
||||||
case *ReconnectSignal:
|
case *ReconnectSignal:
|
||||||
logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
|
config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
|
||||||
err.DelayBeforeReconnect()
|
err.DelayBeforeReconnect()
|
||||||
return err, true
|
return err, true
|
||||||
default:
|
default:
|
||||||
if err == context.Canceled {
|
if err == context.Canceled {
|
||||||
logger.Debugf("Serve tunnel error: %s", err)
|
config.Logger.Debugf("Serve tunnel error: %s", err)
|
||||||
return err, false
|
return err, false
|
||||||
}
|
}
|
||||||
logger.Errorf("Serve tunnel error: %s", err)
|
config.Logger.Errorf("Serve tunnel error: %s", err)
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterConnectionWithH2Mux(
|
func ServeHTTP2(
|
||||||
ctx context.Context,
|
|
||||||
muxer *h2mux.Muxer,
|
|
||||||
config *TunnelConfig,
|
|
||||||
connectionIndex uint8,
|
|
||||||
originLocalAddr string,
|
|
||||||
numPreviousAttempts uint8,
|
|
||||||
) error {
|
|
||||||
const registerConnection = "registerConnection"
|
|
||||||
|
|
||||||
config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
|
|
||||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, registerConnection)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rpcClient.Close()
|
|
||||||
|
|
||||||
conn, err := rpcClient.RegisterConnection(
|
|
||||||
ctx,
|
|
||||||
config.NamedTunnel.Auth,
|
|
||||||
config.NamedTunnel.ID,
|
|
||||||
connectionIndex,
|
|
||||||
config.ConnectionOptions(originLocalAddr, numPreviousAttempts),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
if err.Error() == DuplicateConnectionError {
|
|
||||||
config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc()
|
|
||||||
return errDuplicationConnection
|
|
||||||
}
|
|
||||||
config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc()
|
|
||||||
return serverRegistrationErrorFromRPC(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc()
|
|
||||||
config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID)
|
|
||||||
|
|
||||||
// If launch-ui flag is set, send connect msg
|
|
||||||
if config.TunnelEventChan != nil {
|
|
||||||
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Connected, Location: conn.Location}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ServeNamedTunnel(
|
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *TunnelConfig,
|
config *TunnelConfig,
|
||||||
|
tlsServerConn net.Conn,
|
||||||
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
addr *net.TCPAddr,
|
connectedFuse connection.ConnectedFuse,
|
||||||
connectedFuse *h2mux.BooleanFuse,
|
|
||||||
reconnectCh chan ReconnectSignal,
|
reconnectCh chan ReconnectSignal,
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
|
server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
|
||||||
if err != nil {
|
|
||||||
return err, true
|
|
||||||
}
|
|
||||||
|
|
||||||
cfdServer, err := newHTTP2Server(config, connIndex, tlsServerConn.LocalAddr(), connectedFuse)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, false
|
return err, false
|
||||||
}
|
}
|
||||||
|
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
cfdServer.serve(serveCtx, tlsServerConn)
|
server.Serve(serveCtx)
|
||||||
return fmt.Errorf("Connection with edge closed")
|
return fmt.Errorf("Connection with edge closed")
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
||||||
select {
|
|
||||||
case reconnect := <-reconnectCh:
|
|
||||||
return &reconnect
|
|
||||||
case <-serveCtx.Done():
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
err = errGroup.Wait()
|
err = errGroup.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -503,229 +339,29 @@ func ServeNamedTunnel(
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
|
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error {
|
||||||
if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
|
return func() error {
|
||||||
return &serverRegisterTunnelError{
|
select {
|
||||||
cause: retryable.Unwrap(),
|
case reconnect := <-reconnectCh:
|
||||||
permanent: false,
|
return &reconnect
|
||||||
}
|
case <-ctx.Done():
|
||||||
}
|
|
||||||
return &serverRegisterTunnelError{
|
|
||||||
cause: err,
|
|
||||||
permanent: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UnregisterConnection(
|
|
||||||
ctx context.Context,
|
|
||||||
muxer *h2mux.Muxer,
|
|
||||||
config *TunnelConfig,
|
|
||||||
) error {
|
|
||||||
config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
|
|
||||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
|
|
||||||
if err != nil {
|
|
||||||
// RPC stream open error
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rpcClient.Close()
|
|
||||||
|
|
||||||
return rpcClient.UnregisterConnection(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterTunnel(
|
|
||||||
ctx context.Context,
|
|
||||||
credentialManager *reconnectCredentialManager,
|
|
||||||
muxer *h2mux.Muxer,
|
|
||||||
config *TunnelConfig,
|
|
||||||
logger logger.Service,
|
|
||||||
connectionID uint8,
|
|
||||||
originLocalIP string,
|
|
||||||
uuid uuid.UUID,
|
|
||||||
) error {
|
|
||||||
config.TransportLogger.Debug("initiating RPC stream to register")
|
|
||||||
if config.TunnelEventChan != nil {
|
|
||||||
config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel}
|
|
||||||
}
|
|
||||||
|
|
||||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rpcClient.Close()
|
|
||||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
|
||||||
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
|
|
||||||
registration := rpcClient.RegisterTunnel(
|
|
||||||
ctx,
|
|
||||||
config.OriginCert,
|
|
||||||
config.Hostname,
|
|
||||||
config.RegistrationOptions(connectionID, originLocalIP, uuid),
|
|
||||||
)
|
|
||||||
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
|
||||||
// RegisterTunnel RPC failure
|
|
||||||
return processRegisterTunnelError(registrationErr, config.Metrics, register)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send free tunnel URL to UI
|
|
||||||
if config.TunnelEventChan != nil {
|
|
||||||
config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.SetUrl, Url: registration.Url}
|
|
||||||
}
|
|
||||||
credentialManager.SetEventDigest(connectionID, registration.EventDigest)
|
|
||||||
return processRegistrationSuccess(config, logger, connectionID, registration, register, credentialManager)
|
|
||||||
}
|
|
||||||
|
|
||||||
func processRegistrationSuccess(
|
|
||||||
config *TunnelConfig,
|
|
||||||
logger logger.Service,
|
|
||||||
connectionID uint8,
|
|
||||||
registration *tunnelpogs.TunnelRegistration,
|
|
||||||
name rpcName,
|
|
||||||
credentialManager *reconnectCredentialManager,
|
|
||||||
) error {
|
|
||||||
for _, logLine := range registration.LogLines {
|
|
||||||
logger.Info(logLine)
|
|
||||||
}
|
|
||||||
|
|
||||||
if registration.TunnelID != "" {
|
|
||||||
config.Metrics.tunnelsHA.AddTunnelID(connectionID, registration.TunnelID)
|
|
||||||
logger.Infof("Each HA connection's tunnel IDs: %v", config.Metrics.tunnelsHA.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
|
|
||||||
if config.TunnelEventChan == nil {
|
|
||||||
if config.IsTrialTunnel() {
|
|
||||||
if registrationURL, err := url.Parse(registration.Url); err == nil {
|
|
||||||
for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) {
|
|
||||||
logger.Info(line)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Error("Failed to connect tunnel, please try again.")
|
|
||||||
return fmt.Errorf("empty URL in response from Cloudflare edge")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
credentialManager.SetConnDigest(connectionID, registration.ConnDigest)
|
|
||||||
config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()
|
|
||||||
|
|
||||||
logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
|
|
||||||
config.Metrics.regSuccess.WithLabelValues(string(name)).Inc()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name rpcName) error {
|
|
||||||
if err.Error() == DuplicateConnectionError {
|
|
||||||
metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
|
|
||||||
return errDuplicationConnection
|
|
||||||
}
|
|
||||||
metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
|
|
||||||
return serverRegisterTunnelError{
|
|
||||||
cause: err,
|
|
||||||
permanent: err.IsPermanent(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnregisterTunnel(muxer *h2mux.Muxer, config *TunnelConfig) error {
|
type connectedFuse struct {
|
||||||
config.TransportLogger.Debug("initiating RPC stream to unregister")
|
fuse *h2mux.BooleanFuse
|
||||||
ctx := context.Background()
|
backoff *BackoffHandler
|
||||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, unregister)
|
|
||||||
if err != nil {
|
|
||||||
// RPC stream open error
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rpcClient.Close()
|
|
||||||
|
|
||||||
// gracePeriod is encoded in int64 using capnproto
|
|
||||||
return rpcClient.UnregisterTunnel(ctx, config.GracePeriod.Nanoseconds())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogServerInfo(
|
func (cf *connectedFuse) Connected() {
|
||||||
promise tunnelrpc.ServerInfo_Promise,
|
cf.fuse.Fuse(true)
|
||||||
connectionID uint8,
|
cf.backoff.SetGracePeriod()
|
||||||
metrics *TunnelMetrics,
|
|
||||||
logger logger.Service,
|
|
||||||
tunnelEventChan chan<- ui.TunnelEvent,
|
|
||||||
) {
|
|
||||||
serverInfoMessage, err := promise.Struct()
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Failed to retrieve server information: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Failed to retrieve server information: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// If launch-ui flag is set, send connect msg
|
|
||||||
if tunnelEventChan != nil {
|
|
||||||
tunnelEventChan <- ui.TunnelEvent{Index: connectionID, EventType: ui.Connected, Location: serverInfo.LocationName}
|
|
||||||
}
|
|
||||||
logger.Infof("Connected to %s", serverInfo.LocationName)
|
|
||||||
metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
func (cf *connectedFuse) IsConnected() bool {
|
||||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
return cf.fuse.Value()
|
||||||
req.Header.Set("Host", hostHeader)
|
|
||||||
req.Host = hostHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
dialler, ok := rule.Service.(websocket.Dialler)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service)
|
|
||||||
}
|
|
||||||
conn, response, err := websocket.ClientConnect(req, dialler)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
err = wsResp.WriteRespHeaders(response)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "Error writing response header")
|
|
||||||
}
|
|
||||||
// Copy to/from stream to the undelying connection. Use the underlying
|
|
||||||
// connection because cloudflared doesn't operate on the message themselves
|
|
||||||
websocket.Stream(conn.UnderlyingConn(), wsResp)
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// uint8ToString renders a connection index as its base-10 string form.
func uint8ToString(input uint8) string {
	return strconv.Itoa(int(input))
}
||||||
|
|
||||||
// Print out the given lines in a nice ASCII box.
|
|
||||||
func asciiBox(lines []string, padding int) (box []string) {
|
|
||||||
maxLen := maxLen(lines)
|
|
||||||
spacer := strings.Repeat(" ", padding)
|
|
||||||
|
|
||||||
border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+"
|
|
||||||
|
|
||||||
box = append(box, border)
|
|
||||||
for _, line := range lines {
|
|
||||||
box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|")
|
|
||||||
}
|
|
||||||
box = append(box, border)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// maxLen reports the length in bytes of the longest string in lines,
// or 0 when lines is empty.
func maxLen(lines []string) int {
	longest := 0
	for _, l := range lines {
		if n := len(l); n > longest {
			longest = n
		}
	}
	return longest
}
||||||
|
|
||||||
func trialZoneMsg(url string) []string {
|
|
||||||
return []string{
|
|
||||||
"Your free tunnel has started! Visit it:",
|
|
||||||
" " + url,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func activeIncidentsMsg(incidents []Incident) string {
|
func activeIncidentsMsg(incidents []Incident) string {
|
||||||
|
@ -741,26 +377,3 @@ func activeIncidentsMsg(incidents []Incident) string {
|
||||||
return preamble + " " + strings.Join(incidentStrings, "; ")
|
return preamble + " " + strings.Join(incidentStrings, "; ")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func findCfRayHeader(h1 *http.Request) string {
|
|
||||||
return h1.Header.Get("Cf-Ray")
|
|
||||||
}
|
|
||||||
|
|
||||||
func isLBProbeRequest(req *http.Request) bool {
|
|
||||||
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTunnelRPCClient(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, rpcName rpcName) (tunnelpogs.TunnelServer_PogsClient, error) {
|
|
||||||
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
|
|
||||||
defer openStreamCancel()
|
|
||||||
stream, err := muxer.OpenRPCStream(openStreamCtx)
|
|
||||||
if err != nil {
|
|
||||||
return tunnelpogs.TunnelServer_PogsClient{}, err
|
|
||||||
}
|
|
||||||
rpcClient, err := connection.NewTunnelRPCClient(ctx, stream, config.TransportLogger)
|
|
||||||
if err != nil {
|
|
||||||
// RPC stream open error
|
|
||||||
return tunnelpogs.TunnelServer_PogsClient{}, newRPCError(err, config.Metrics.rpcFail, rpcName)
|
|
||||||
}
|
|
||||||
return rpcClient, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -17,11 +17,6 @@ import (
|
||||||
const (
|
const (
|
||||||
OriginCAPoolFlag = "origin-ca-pool"
|
OriginCAPoolFlag = "origin-ca-pool"
|
||||||
CaCertFlag = "cacert"
|
CaCertFlag = "cacert"
|
||||||
|
|
||||||
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
|
|
||||||
edgeH2muxTLSServerName = "cftunnel.com"
|
|
||||||
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
|
||||||
edgeH2TLSServerName = "h2.cftunnel.com"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CertReloader can load and reload a TLS certificate from a particular filepath.
|
// CertReloader can load and reload a TLS certificate from a particular filepath.
|
||||||
|
@ -123,16 +118,12 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) {
|
||||||
return certPool, nil
|
return certPool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTunnelConfig(c *cli.Context, isNamedTunnel bool) (*tls.Config, error) {
|
func CreateTunnelConfig(c *cli.Context, serverName string) (*tls.Config, error) {
|
||||||
var rootCAs []string
|
var rootCAs []string
|
||||||
if c.String(CaCertFlag) != "" {
|
if c.String(CaCertFlag) != "" {
|
||||||
rootCAs = append(rootCAs, c.String(CaCertFlag))
|
rootCAs = append(rootCAs, c.String(CaCertFlag))
|
||||||
}
|
}
|
||||||
|
|
||||||
serverName := edgeH2muxTLSServerName
|
|
||||||
if isNamedTunnel {
|
|
||||||
serverName = edgeH2TLSServerName
|
|
||||||
}
|
|
||||||
userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}
|
userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName}
|
||||||
tlsConfig, err := GetConfig(userConfig)
|
tlsConfig, err := GetConfig(userConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -93,6 +93,11 @@ type RegistrationServer_PogsClient struct {
|
||||||
Conn *rpc.Conn
|
Conn *rpc.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c RegistrationServer_PogsClient) Close() error {
|
||||||
|
c.Client.Close()
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||||
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
|
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
|
||||||
|
|
|
@ -66,7 +66,15 @@ func ValidateHostname(hostname string) (string, error) {
|
||||||
// but when it does not, the path is preserved:
|
// but when it does not, the path is preserved:
|
||||||
// ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
|
// ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
|
||||||
// This is arguably a bug, but changing it might break some cloudflared users.
|
// This is arguably a bug, but changing it might break some cloudflared users.
|
||||||
func ValidateUrl(originUrl string) (string, error) {
|
func ValidateUrl(originUrl string) (*url.URL, error) {
|
||||||
|
urlStr, err := validateUrlString(originUrl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return url.Parse(urlStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateUrlString(originUrl string) (string, error) {
|
||||||
if originUrl == "" {
|
if originUrl == "" {
|
||||||
return "", fmt.Errorf("URL should not be empty")
|
return "", fmt.Errorf("URL should not be empty")
|
||||||
}
|
}
|
||||||
|
@ -157,12 +165,8 @@ func validateIP(scheme, host, port string) (string, error) {
|
||||||
return fmt.Sprintf("%s://%s", scheme, host), nil
|
return fmt.Sprintf("%s://%s", scheme, host), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
|
// originURL shouldn't be a pointer, because this function might change the scheme
|
||||||
parsedURL, err := url.Parse(originURL)
|
func ValidateHTTPService(originURL url.URL, hostname string, transport http.RoundTripper) error {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
@ -171,7 +175,7 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
|
||||||
Timeout: validationTimeout,
|
Timeout: validationTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
|
initialRequest, err := http.NewRequest("GET", originURL.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -183,10 +187,10 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
|
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
|
||||||
oldScheme := parsedURL.Scheme
|
oldScheme := originURL.Scheme
|
||||||
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
|
originURL.Scheme = toggleProtocol(originURL.Scheme)
|
||||||
|
|
||||||
secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
|
secondRequest, err := http.NewRequest("GET", originURL.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -195,12 +199,12 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
|
||||||
if secondErr == nil { // Worked this time--advise the user to switch protocols
|
if secondErr == nil { // Worked this time--advise the user to switch protocols
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return errors.Errorf(
|
return errors.Errorf(
|
||||||
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
|
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v",
|
||||||
parsedURL.Host,
|
originURL.Host,
|
||||||
oldScheme,
|
oldScheme,
|
||||||
parsedURL.Scheme,
|
originURL.Scheme,
|
||||||
initialErr,
|
initialErr,
|
||||||
parsedURL,
|
originURL,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,12 +228,12 @@ type Access struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
|
func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
|
||||||
domainURL, err := ValidateUrl(domain)
|
domainURL, err := validateUrlString(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
issuerURL, err := ValidateUrl(issuer)
|
issuerURL, err := validateUrlString(issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -101,7 +101,7 @@ func TestValidateUrl(t *testing.T) {
|
||||||
for i, testCase := range testCases {
|
for i, testCase := range testCases {
|
||||||
validUrl, err := ValidateUrl(testCase.input)
|
validUrl, err := ValidateUrl(testCase.input)
|
||||||
assert.NoError(t, err, "test case %v", i)
|
assert.NoError(t, err, "test case %v", i)
|
||||||
assert.Equal(t, testCase.expectedOutput, validUrl, "test case %v", i)
|
assert.Equal(t, testCase.expectedOutput, validUrl.String(), "test case %v", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
validUrl, err := ValidateUrl("")
|
validUrl, err := ValidateUrl("")
|
||||||
|
@ -123,7 +123,7 @@ func TestToggleProtocol(t *testing.T) {
|
||||||
|
|
||||||
// Happy path 1: originURL is HTTP, and HTTP connections work
|
// Happy path 1: originURL is HTTP, and HTTP connections work
|
||||||
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
||||||
originURL := "http://127.0.0.1/"
|
originURL := mustParse(t, "http://127.0.0.1/")
|
||||||
hostname := "example.com"
|
hostname := "example.com"
|
||||||
|
|
||||||
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -151,7 +151,7 @@ func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
||||||
|
|
||||||
// Happy path 2: originURL is HTTPS, and HTTPS connections work
|
// Happy path 2: originURL is HTTPS, and HTTPS connections work
|
||||||
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
||||||
originURL := "https://127.0.0.1/"
|
originURL := mustParse(t, "https://127.0.0.1:1234/")
|
||||||
hostname := "example.com"
|
hostname := "example.com"
|
||||||
|
|
||||||
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -179,7 +179,7 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
||||||
|
|
||||||
// Error path 1: originURL is HTTPS, but HTTP connections work
|
// Error path 1: originURL is HTTPS, but HTTP connections work
|
||||||
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
||||||
originURL := "https://127.0.0.1:1234/"
|
originURL := mustParse(t, "https://127.0.0.1:1234/")
|
||||||
hostname := "example.com"
|
hostname := "example.com"
|
||||||
|
|
||||||
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -207,10 +207,13 @@ func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
||||||
|
|
||||||
// Error path 2: originURL is HTTP, but HTTPS connections work
|
// Error path 2: originURL is HTTP, but HTTPS connections work
|
||||||
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
||||||
originURL := "http://127.0.0.1:1234/"
|
originURLWithPort := url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "127.0.0.1:1234",
|
||||||
|
}
|
||||||
hostname := "example.com"
|
hostname := "example.com"
|
||||||
|
|
||||||
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
assert.Equal(t, req.Host, hostname)
|
assert.Equal(t, req.Host, hostname)
|
||||||
if req.URL.Scheme == "http" {
|
if req.URL.Scheme == "http" {
|
||||||
return nil, assert.AnError
|
return nil, assert.AnError
|
||||||
|
@ -221,7 +224,7 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
||||||
panic("Shouldn't reach here")
|
panic("Shouldn't reach here")
|
||||||
})))
|
})))
|
||||||
|
|
||||||
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
assert.Equal(t, req.Host, hostname)
|
assert.Equal(t, req.Host, hostname)
|
||||||
if req.URL.Scheme == "http" {
|
if req.URL.Scheme == "http" {
|
||||||
return nil, assert.AnError
|
return nil, assert.AnError
|
||||||
|
@ -250,12 +253,14 @@ func TestValidateHTTPService_NoFollowRedirects(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer redirectServer.Close()
|
defer redirectServer.Close()
|
||||||
assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport))
|
redirectServerURL, err := url.Parse(redirectServer.URL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NoError(t, ValidateHTTPService(*redirectServerURL, hostname, redirectClient.Transport))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure validation times out when origin URL is nonresponsive
|
// Ensure validation times out when origin URL is nonresponsive
|
||||||
func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
|
func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
|
||||||
originURL := "http://127.0.0.1/"
|
originURL := mustParse(t, "http://127.0.0.1/")
|
||||||
hostname := "example.com"
|
hostname := "example.com"
|
||||||
oldValidationTimeout := validationTimeout
|
oldValidationTimeout := validationTimeout
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -371,3 +376,9 @@ func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *h
|
||||||
|
|
||||||
return server, client, nil
|
return server, client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustParse(t *testing.T, originURL string) url.URL {
|
||||||
|
parsedURL, err := url.Parse(originURL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
return *parsedURL
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue