TUN-2640: Users can configure per-origin config. Unify single-rule CLI

flow with multi-rule config file code.
This commit is contained in:
Adam Chalmers 2020-10-15 16:41:03 -05:00
parent ea71b78e6d
commit e933ef9e1a
13 changed files with 1210 additions and 481 deletions

View File

@ -34,7 +34,12 @@ var (
ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories()) ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories())
) )
const DefaultCredentialFile = "cert.pem" const (
DefaultCredentialFile = "cert.pem"
// BastionFlag is to enable bastion, or jump host, operation
BastionFlag = "bastion"
)
// DefaultConfigDirectory returns the default directory of the config file // DefaultConfigDirectory returns the default directory of the config file
func DefaultConfigDirectory() string { func DefaultConfigDirectory() string {
@ -200,11 +205,55 @@ type UnvalidatedIngressRule struct {
Hostname string Hostname string
Path string Path string
Service string Service string
OriginRequest OriginRequestConfig `yaml:"originRequest"`
}
// OriginRequestConfig is a set of optional fields that users may set to
// customize how cloudflared sends requests to origin services. It is used to set
// up general config that apply to all rules, and also, specific per-rule
// config.
// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h".
type OriginRequestConfig struct {
// HTTP proxy timeout for establishing a new connection
ConnectTimeout *time.Duration `yaml:"connectTimeout"`
// HTTP proxy timeout for completing a TLS handshake
TLSTimeout *time.Duration `yaml:"tlsTimeout"`
// HTTP proxy TCP keepalive duration
TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive"`
// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
NoHappyEyeballs *bool `yaml:"noHappyEyeballs"`
// HTTP proxy maximum keepalive connection pool size
KeepAliveConnections *int `yaml:"keepAliveConnections"`
// HTTP proxy timeout for closing an idle connection
KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout"`
// Sets the HTTP Host header for the local webserver.
HTTPHostHeader *string `yaml:"httpHostHeader"`
// Hostname on the origin server certificate.
OriginServerName *string `yaml:"originServerName"`
// Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare.
CAPool *string `yaml:"caPool"`
// Disables TLS verification of the certificate presented by your origin.
// Will allow any certificate from the origin to be accepted.
// Note: The connection from your machine to Cloudflare's Edge is still encrypted.
NoTLSVerify *bool `yaml:"noTLSVerify"`
// Disables chunked transfer encoding.
// Useful if you are running a WSGI server.
DisableChunkedEncoding *bool `yaml:"disableChunkedEncoding"`
// Runs as jump host
BastionMode *bool `yaml:"bastionMode"`
// Listen address for the proxy.
ProxyAddress *string `yaml:"proxyAddress"`
// Listen port for the proxy.
ProxyPort *uint `yaml:"proxyPort"`
// Valid options are 'socks', 'ssh' or empty.
ProxyType *string `yaml:"proxyType"`
} }
type Configuration struct { type Configuration struct {
TunnelID string `yaml:"tunnel"` TunnelID string `yaml:"tunnel"`
Ingress []UnvalidatedIngressRule Ingress []UnvalidatedIngressRule
OriginRequest OriginRequestConfig `yaml:"originRequest"`
sourceFile string sourceFile string
} }

View File

@ -5,43 +5,32 @@ import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http"
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
"runtime"
"runtime/trace" "runtime/trace"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/cloudflare/cloudflared/awsuploader"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/dbconnect" "github.com/cloudflare/cloudflared/dbconnect"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/metrics" "github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/sshlog"
"github.com/cloudflare/cloudflared/sshserver"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns" "github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/tunnelstore" "github.com/cloudflare/cloudflared/tunnelstore"
"github.com/cloudflare/cloudflared/websocket"
"github.com/coreos/go-systemd/daemon" "github.com/coreos/go-systemd/daemon"
"github.com/facebookgo/grace/gracenet" "github.com/facebookgo/grace/gracenet"
"github.com/getsentry/raven-go" "github.com/getsentry/raven-go"
"github.com/gliderlabs/ssh"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mitchellh/go-homedir" "github.com/mitchellh/go-homedir"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -84,15 +73,6 @@ const (
// hostKeyPath is the path of the dir to save SSH host keys too // hostKeyPath is the path of the dir to save SSH host keys too
hostKeyPath = "host-key-path" hostKeyPath = "host-key-path"
//sshServerFlag enables cloudflared ssh proxy server
sshServerFlag = "ssh-server"
// socks5Flag is to enable the socks server to deframe
socks5Flag = "socks5"
// bastionFlag is to enable bastion, or jump host, operation
bastionFlag = "bastion"
// uiFlag is to enable launching cloudflared in interactive UI mode // uiFlag is to enable launching cloudflared in interactive UI mode
uiFlag = "ui" uiFlag = "ui"
@ -373,72 +353,6 @@ func StartServer(
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0, log) return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0, log)
} }
if c.IsSet("hello-world") {
log.Infof("hello-world set")
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
log.Errorf("Cannot start Hello World Server: %s", err)
return errors.Wrap(err, "Cannot start Hello World Server")
}
defer helloListener.Close()
wg.Add(1)
go func() {
defer wg.Done()
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
}()
forceSetFlag(c, "url", "https://"+helloListener.Addr().String())
}
if c.IsSet(sshServerFlag) {
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
msg := fmt.Sprintf("--ssh-server is not supported on %s", runtime.GOOS)
log.Error(msg)
return errors.New(msg)
}
log.Infof("ssh-server set")
logManager := sshlog.NewEmptyManager()
if c.IsSet(bucketNameFlag) && c.IsSet(regionNameFlag) && c.IsSet(accessKeyIDFlag) && c.IsSet(secretIDFlag) {
uploader, err := awsuploader.NewFileUploader(c.String(bucketNameFlag), c.String(regionNameFlag),
c.String(accessKeyIDFlag), c.String(secretIDFlag), c.String(sessionTokenIDFlag), c.String(s3URLFlag))
if err != nil {
msg := "Cannot create uploader for SSH Server"
log.Errorf("%s: %s", msg, err)
return errors.Wrap(err, msg)
}
if err := os.MkdirAll(sshLogFileDirectory, 0700); err != nil {
msg := fmt.Sprintf("Cannot create SSH log file directory %s", sshLogFileDirectory)
log.Errorf("%s: %s", msg, err)
return errors.Wrap(err, msg)
}
logManager = sshlog.New(sshLogFileDirectory)
uploadManager := awsuploader.NewDirectoryUploadManager(log, uploader, sshLogFileDirectory, 30*time.Minute, shutdownC)
uploadManager.Start()
}
localServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
server, err := sshserver.New(logManager, log, version, localServerAddress, c.String("hostname"), c.Path(hostKeyPath), shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
if err != nil {
msg := "Cannot create new SSH Server"
log.Errorf("%s: %s", msg, err)
return errors.Wrap(err, msg)
}
wg.Add(1)
go func() {
defer wg.Done()
if err = server.Start(); err != nil && err != ssh.ErrServerClosed {
log.Errorf("SSH server error: %s", err)
// TODO: remove when declarative tunnels are implemented.
close(shutdownC)
}
}()
forceSetFlag(c, "url", "ssh://"+localServerAddress)
}
url := c.String("url") url := c.String("url")
hostname := c.String("hostname") hostname := c.String("hostname")
if url == hostname && url != "" && hostname != "" { if url == hostname && url != "" && hostname != "" {
@ -447,42 +361,6 @@ func StartServer(
return fmt.Errorf(errText) return fmt.Errorf(errText)
} }
if staticHost := hostnameFromURI(c.String("url")); isProxyDestinationConfigured(staticHost, c) {
listener, err := net.Listen("tcp", net.JoinHostPort(c.String("proxy-address"), strconv.Itoa(c.Int("proxy-port"))))
if err != nil {
log.Errorf("Cannot start Websocket Proxy Server: %s", err)
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
}
wg.Add(1)
go func() {
defer wg.Done()
streamHandler := websocket.DefaultStreamHandler
if c.IsSet(socks5Flag) {
log.Info("SOCKS5 server started")
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) {
dialer := socks.NewConnDialer(remoteConn)
requestHandler := socks.NewRequestHandler(dialer)
socksServer := socks.NewConnectionHandler(requestHandler)
socksServer.Serve(wsConn)
}
} else if c.IsSet(sshServerFlag) {
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, requestHeaders http.Header) {
if finalDestination := requestHeaders.Get(h2mux.CFJumpDestinationHeader); finalDestination != "" {
token := requestHeaders.Get(h2mux.CFAccessTokenHeader)
if err := websocket.SendSSHPreamble(remoteConn, finalDestination, token); err != nil {
log.Errorf("Failed to send SSH preamble: %s", err)
return
}
}
websocket.DefaultStreamHandler(wsConn, remoteConn, requestHeaders)
}
}
errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler)
}()
forceSetFlag(c, "url", "http://"+listener.Addr().String())
}
transportLogger, err := createLogger(c, true, false) transportLogger, err := createLogger(c, true, false)
if err != nil { if err != nil {
return errors.Wrap(err, "error setting up transport logger") return errors.Wrap(err, "error setting up transport logger")
@ -493,6 +371,8 @@ func StartServer(
return err return err
} }
tunnelConfig.IngressRules.StartOrigins(&wg, log, shutdownC, errC)
reconnectCh := make(chan origin.ReconnectSignal, 1) reconnectCh := make(chan origin.ReconnectSignal, 1)
if c.IsSet("stdin-control") { if c.IsSet("stdin-control") {
log.Info("Enabling control through stdin") log.Info("Enabling control through stdin")
@ -514,7 +394,8 @@ func StartServer(
version, version,
hostname, hostname,
metricsListener.Addr().String(), metricsListener.Addr().String(),
tunnelConfig.OriginUrl, // TODO (TUN-3461): Update UI to show multiple origin URLs
tunnelConfig.IngressRules.CatchAll().Service.Address(),
tunnelConfig.HAConnections, tunnelConfig.HAConnections,
) )
logLevels, err := logger.ParseLevelString(c.String("loglevel")) logLevels, err := logger.ParseLevelString(c.String("loglevel"))
@ -559,11 +440,6 @@ func SetFlagsFromConfigFile(c *cli.Context) error {
return nil return nil
} }
// isProxyDestinationConfigured returns true if there is a static host set or if bastion mode is set.
func isProxyDestinationConfigured(staticHost string, c *cli.Context) bool {
return staticHost != "" || c.IsSet(bastionFlag)
}
func waitToShutdown(wg *sync.WaitGroup, func waitToShutdown(wg *sync.WaitGroup,
errC chan error, errC chan error,
shutdownC, graceShutdownC chan struct{}, shutdownC, graceShutdownC chan struct{},
@ -910,67 +786,67 @@ func configureProxyFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: socks5Flag, Name: ingress.Socks5Flag,
Usage: "specify if this tunnel is running as a SOCK5 Server", Usage: "specify if this tunnel is running as a SOCK5 Server",
EnvVars: []string{"TUNNEL_SOCKS"}, EnvVars: []string{"TUNNEL_SOCKS"},
Value: false, Value: false,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-connect-timeout", Name: ingress.ProxyConnectTimeoutFlag,
Usage: "HTTP proxy timeout for establishing a new connection", Usage: "HTTP proxy timeout for establishing a new connection",
Value: time.Second * 30, Value: time.Second * 30,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-tls-timeout", Name: ingress.ProxyTLSTimeoutFlag,
Usage: "HTTP proxy timeout for completing a TLS handshake", Usage: "HTTP proxy timeout for completing a TLS handshake",
Value: time.Second * 10, Value: time.Second * 10,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-tcp-keepalive", Name: ingress.ProxyTCPKeepAlive,
Usage: "HTTP proxy TCP keepalive duration", Usage: "HTTP proxy TCP keepalive duration",
Value: time.Second * 30, Value: time.Second * 30,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "proxy-no-happy-eyeballs", Name: ingress.ProxyNoHappyEyeballsFlag,
Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback", Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback",
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewIntFlag(&cli.IntFlag{ altsrc.NewIntFlag(&cli.IntFlag{
Name: "proxy-keepalive-connections", Name: ingress.ProxyKeepAliveConnectionsFlag,
Usage: "HTTP proxy maximum keepalive connection pool size", Usage: "HTTP proxy maximum keepalive connection pool size",
Value: 100, Value: 100,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-keepalive-timeout", Name: ingress.ProxyKeepAliveTimeoutFlag,
Usage: "HTTP proxy timeout for closing an idle connection", Usage: "HTTP proxy timeout for closing an idle connection",
Value: time.Second * 90, Value: time.Second * 90,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-connection-timeout", Name: "proxy-connection-timeout",
Usage: "HTTP proxy timeout for closing an idle connection", Usage: "DEPRECATED. No longer has any effect.",
Value: time.Second * 90, Value: time.Second * 90,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-expect-continue-timeout", Name: "proxy-expect-continue-timeout",
Usage: "HTTP proxy timeout for closing an idle connection", Usage: "DEPRECATED. No longer has any effect.",
Value: time.Second * 90, Value: time.Second * 90,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "http-host-header", Name: ingress.HTTPHostHeaderFlag,
Usage: "Sets the HTTP Host header for the local webserver.", Usage: "Sets the HTTP Host header for the local webserver.",
EnvVars: []string{"TUNNEL_HTTP_HOST_HEADER"}, EnvVars: []string{"TUNNEL_HTTP_HOST_HEADER"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "origin-server-name", Name: ingress.OriginServerNameFlag,
Usage: "Hostname on the origin server certificate.", Usage: "Hostname on the origin server certificate.",
EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"}, EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"},
Hidden: shouldHide, Hidden: shouldHide,
@ -988,13 +864,13 @@ func configureProxyFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "no-tls-verify", Name: ingress.NoTLSVerifyFlag,
Usage: "Disables TLS verification of the certificate presented by your origin. Will allow any certificate from the origin to be accepted. Note: The connection from your machine to Cloudflare's Edge is still encrypted.", Usage: "Disables TLS verification of the certificate presented by your origin. Will allow any certificate from the origin to be accepted. Note: The connection from your machine to Cloudflare's Edge is still encrypted.",
EnvVars: []string{"NO_TLS_VERIFY"}, EnvVars: []string{"NO_TLS_VERIFY"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "no-chunked-encoding", Name: ingress.NoChunkedEncodingFlag,
Usage: "Disables chunked transfer encoding; useful if you are running a WSGI server.", Usage: "Disables chunked transfer encoding; useful if you are running a WSGI server.",
EnvVars: []string{"TUNNEL_NO_CHUNKED_ENCODING"}, EnvVars: []string{"TUNNEL_NO_CHUNKED_ENCODING"},
Hidden: shouldHide, Hidden: shouldHide,
@ -1067,28 +943,28 @@ func sshFlags(shouldHide bool) []cli.Flag {
Hidden: true, Hidden: true,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: sshServerFlag, Name: ingress.SSHServerFlag,
Value: false, Value: false,
Usage: "Run an SSH Server", Usage: "Run an SSH Server",
EnvVars: []string{"TUNNEL_SSH_SERVER"}, EnvVars: []string{"TUNNEL_SSH_SERVER"},
Hidden: true, // TODO: remove when feature is complete Hidden: true, // TODO: remove when feature is complete
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: bastionFlag, Name: config.BastionFlag,
Value: false, Value: false,
Usage: "Runs as jump host", Usage: "Runs as jump host",
EnvVars: []string{"TUNNEL_BASTION"}, EnvVars: []string{"TUNNEL_BASTION"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "proxy-address", Name: ingress.ProxyAddressFlag,
Usage: "Listen address for the proxy.", Usage: "Listen address for the proxy.",
Value: "127.0.0.1", Value: "127.0.0.1",
EnvVars: []string{"TUNNEL_PROXY_ADDRESS"}, EnvVars: []string{"TUNNEL_PROXY_ADDRESS"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewIntFlag(&cli.IntFlag{ altsrc.NewIntFlag(&cli.IntFlag{
Name: "proxy-port", Name: ingress.ProxyPortFlag,
Usage: "Listen port for the proxy.", Usage: "Listen port for the proxy.",
Value: 0, Value: 0,
EnvVars: []string{"TUNNEL_PROXY_PORT"}, EnvVars: []string{"TUNNEL_PROXY_PORT"},

View File

@ -1,16 +1,11 @@
package tunnel package tunnel
import ( import (
"context"
"crypto/tls"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
@ -193,31 +188,7 @@ func prepareTunnelConfig(
} }
} }
originCertPool, err := tlsconfig.LoadOriginCA(c, logger)
if err != nil {
logger.Errorf("Error loading cert pool: %s", err)
return nil, errors.Wrap(err, "Error loading cert pool")
}
tunnelMetrics := origin.NewTunnelMetrics() tunnelMetrics := origin.NewTunnelMetrics()
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: c.Int("proxy-keepalive-connections"),
MaxIdleConnsPerHost: c.Int("proxy-keepalive-connections"),
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")},
}
dialer := &net.Dialer{
Timeout: c.Duration("proxy-connect-timeout"),
KeepAlive: c.Duration("proxy-tcp-keepalive"),
}
if c.Bool("proxy-no-happy-eyeballs") {
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
}
dialContext := dialer.DialContext
var ingressRules ingress.Ingress var ingressRules ingress.Ingress
if namedTunnel != nil { if namedTunnel != nil {
@ -231,7 +202,7 @@ func prepareTunnelConfig(
Version: version, Version: version,
Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch), Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch),
} }
ingressRules, err = ingress.ParseIngress(config.GetConfiguration()) ingressRules, err = ingress.ParseIngress(config.GetConfiguration(), logger)
if err != nil && err != ingress.ErrNoIngressRules { if err != nil && err != ingress.ErrNoIngressRules {
return nil, err return nil, err
} }
@ -240,53 +211,11 @@ func prepareTunnelConfig(
} }
} }
var originURL string // Convert single-origin configuration into multi-origin configuration.
if ingressRules.IsEmpty() { if ingressRules.IsEmpty() {
originURL, err = config.ValidateUrl(c, compatibilityMode) ingressRules, err = ingress.NewSingleOrigin(c, compatibilityMode, logger)
if err != nil { if err != nil {
logger.Errorf("Error validating origin URL: %s", err) return nil, err
return nil, errors.Wrap(err, "Error validating origin URL")
}
}
if c.IsSet("unix-socket") {
unixSocket, err := config.ValidateUnixSocket(c)
if err != nil {
logger.Errorf("Error validating --unix-socket: %s", err)
return nil, errors.Wrap(err, "Error validating --unix-socket")
}
logger.Infof("Proxying tunnel requests to unix:%s", unixSocket)
httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
// if --unix-socket specified, enforce network type "unix"
return dialContext(ctx, "unix", unixSocket)
}
} else {
logger.Infof("Proxying tunnel requests to %s", originURL)
httpTransport.DialContext = dialContext
}
if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
}
// If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client.
if !c.IsSet(bastionFlag) {
// List all origin URLs that require validation
var originURLs []string
if ingressRules.IsEmpty() {
originURLs = append(originURLs, originURL)
} else {
for _, rule := range ingressRules.Rules {
originURLs = append(originURLs, rule.Service.String())
}
}
// Validate each origin URL
for _, u := range originURLs {
if err = validation.ValidateHTTPService(u, hostname, httpTransport); err != nil {
logger.Errorf("unable to connect to the origin: %s", err)
}
} }
} }
@ -298,15 +227,12 @@ func prepareTunnelConfig(
return &origin.TunnelConfig{ return &origin.TunnelConfig{
BuildInfo: buildInfo, BuildInfo: buildInfo,
ClientID: clientID, ClientID: clientID,
ClientTlsConfig: httpTransport.TLSClientConfig,
CompressionQuality: c.Uint64("compression-quality"), CompressionQuality: c.Uint64("compression-quality"),
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
GracePeriod: c.Duration("grace-period"), GracePeriod: c.Duration("grace-period"),
HAConnections: c.Int("ha-connections"), HAConnections: c.Int("ha-connections"),
HTTPTransport: httpTransport,
HeartbeatInterval: c.Duration("heartbeat-interval"), HeartbeatInterval: c.Duration("heartbeat-interval"),
Hostname: hostname, Hostname: hostname,
HTTPHostHeader: c.String("http-host-header"),
IncidentLookup: origin.NewIncidentLookup(), IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"), IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel, IsFreeTunnel: isFreeTunnel,
@ -316,9 +242,7 @@ func prepareTunnelConfig(
MaxHeartbeats: c.Uint64("heartbeat-count"), MaxHeartbeats: c.Uint64("heartbeat-count"),
Metrics: tunnelMetrics, Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"), MetricsUpdateFreq: c.Duration("metrics-update-freq"),
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
OriginCert: originCert, OriginCert: originCert,
OriginUrl: originURL,
ReportedVersion: version, ReportedVersion: version,
Retries: c.Uint("retries"), Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(), RunFromTerminal: isRunningFromTerminal(),

View File

@ -71,7 +71,7 @@ func buildTestURLCommand() *cli.Command {
func validateIngressCommand(c *cli.Context) error { func validateIngressCommand(c *cli.Context) error {
conf := config.GetConfiguration() conf := config.GetConfiguration()
fmt.Println("Validating rules from", conf.Source()) fmt.Println("Validating rules from", conf.Source())
if _, err := ingress.ParseIngress(conf); err != nil { if _, err := ingress.ParseIngressDryRun(conf); err != nil {
return errors.Wrap(err, "Validation failed") return errors.Wrap(err, "Validation failed")
} }
if c.IsSet("url") { if c.IsSet("url") {
@ -98,12 +98,12 @@ func testURLCommand(c *cli.Context) error {
conf := config.GetConfiguration() conf := config.GetConfiguration()
fmt.Println("Using rules from", conf.Source()) fmt.Println("Using rules from", conf.Source())
ing, err := ingress.ParseIngress(conf) ing, err := ingress.ParseIngressDryRun(conf)
if err != nil { if err != nil {
return errors.Wrap(err, "Validation failed") return errors.Wrap(err, "Validation failed")
} }
i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path) _, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path)
fmt.Printf("Matched rule #%d\n", i+1) fmt.Printf("Matched rule #%d\n", i+1)
fmt.Println(ing.Rules[i].MultiLineString()) fmt.Println(ing.Rules[i].MultiLineString())
return nil return nil

View File

@ -1,14 +1,24 @@
package ingress package ingress
import ( import (
"context"
"crypto/tls"
"fmt" "fmt"
"net"
"net/http"
"net/url" "net/url"
"regexp" "regexp"
"strings" "strings"
"sync"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/validation"
) )
var ( var (
@ -18,54 +28,93 @@ var (
ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules") ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules")
) )
// Each rule route traffic from a hostname/path on the public // Finalize the rules by adding missing struct fields and validating each origin.
// internet to the service running on the given URL. func (ing *Ingress) setHTTPTransport(logger logger.Service) error {
type Rule struct { for ruleNumber, rule := range ing.Rules {
// Requests for this hostname will be proxied to this rule's service. cfg := rule.Config
Hostname string originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil)
if err != nil {
// Path is an optional regex that can specify path-driven ingress rules. return errors.Wrap(err, "Error loading cert pool")
Path *regexp.Regexp
// A (probably local) address. Requests for a hostname which matches this
// rule's hostname pattern will be proxied to the service running on this
// address.
Service *url.URL
}
func (r Rule) MultiLineString() string {
var out strings.Builder
if r.Hostname != "" {
out.WriteString("\thostname: ")
out.WriteString(r.Hostname)
out.WriteRune('\n')
} }
if r.Path != nil {
out.WriteString("\tpath: ")
out.WriteString(r.Path.String())
out.WriteRune('\n')
}
out.WriteString("\tservice: ")
out.WriteString(r.Service.String())
return out.String()
}
func (r *Rule) Matches(hostname, path string) bool { httpTransport := &http.Transport{
hostMatch := r.Hostname == "" || r.Hostname == "*" || matchHost(r.Hostname, hostname) Proxy: http.ProxyFromEnvironment,
pathMatch := r.Path == nil || r.Path.MatchString(path) MaxIdleConns: cfg.KeepAliveConnections,
return hostMatch && pathMatch MaxIdleConnsPerHost: cfg.KeepAliveConnections,
IdleConnTimeout: cfg.KeepAliveTimeout,
TLSHandshakeTimeout: cfg.TLSTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify},
}
if _, isHelloWorld := rule.Service.(*HelloWorld); !isHelloWorld && cfg.OriginServerName != "" {
httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName
}
dialer := &net.Dialer{
Timeout: cfg.ConnectTimeout,
KeepAlive: cfg.TCPKeepAlive,
}
if cfg.NoHappyEyeballs {
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
}
// DialContext depends on which kind of origin is being used.
dialContext := dialer.DialContext
switch service := rule.Service.(type) {
// If this origin is a unix socket, enforce network type "unix".
case UnixSocketPath:
httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialContext(ctx, "unix", service.Address())
}
// Otherwise, use the regular network config.
default:
httpTransport.DialContext = dialContext
}
ing.Rules[ruleNumber].HTTPTransport = httpTransport
ing.Rules[ruleNumber].ClientTLSConfig = httpTransport.TLSClientConfig
}
// Validate each origin
for _, rule := range ing.Rules {
// If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client.
if rule.Config.BastionMode {
continue
}
// Unix sockets don't have validation
if _, ok := rule.Service.(UnixSocketPath); ok {
continue
}
switch service := rule.Service.(type) {
case UnixSocketPath:
continue
case *HelloWorld:
continue
default:
if err := validation.ValidateHTTPService(service.Address(), rule.Hostname, rule.HTTPTransport); err != nil {
logger.Errorf("unable to connect to the origin: %s", err)
}
}
}
return nil
} }
// FindMatchingRule returns the index of the Ingress Rule which matches the given // FindMatchingRule returns the index of the Ingress Rule which matches the given
// hostname and path. This function assumes the last rule matches everything, // hostname and path. This function assumes the last rule matches everything,
// which is the case if the rules were instantiated via the ingress#Validate method // which is the case if the rules were instantiated via the ingress#Validate method
func (ing Ingress) FindMatchingRule(hostname, path string) int { func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) {
for i, rule := range ing.Rules { for i, rule := range ing.Rules {
if rule.Matches(hostname, path) { if rule.Matches(hostname, path) {
return i return &rule, i
} }
} }
return len(ing.Rules) - 1 i := len(ing.Rules) - 1
return &ing.Rules[i], i
} }
func matchHost(ruleHost, reqHost string) bool { func matchHost(ruleHost, reqHost string) bool {
@ -84,6 +133,55 @@ func matchHost(ruleHost, reqHost string) bool {
// Ingress maps eyeball requests to origins. // Ingress maps eyeball requests to origins.
type Ingress struct { type Ingress struct {
Rules []Rule Rules []Rule
defaults OriginRequestConfig
}
// NewSingleOrigin constructs an Ingress set with only one rule, constructed from
// legacy CLI parameters like --url or --no-chunked-encoding.
func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Service) (Ingress, error) {
service, err := parseSingleOriginService(c, compatibilityMode)
if err != nil {
return Ingress{}, err
}
// Construct an Ingress with the single rule.
ing := Ingress{
Rules: []Rule{
{
Service: service,
},
},
defaults: originRequestFromSingeRule(c),
}
err = ing.setHTTPTransport(logger)
return ing, err
}
// Get a single origin service from the CLI/config.
func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) {
if c.IsSet("hello-world") {
return new(HelloWorld), nil
}
if c.IsSet("url") {
originURLStr, err := config.ValidateUrl(c, compatibilityMode)
if err != nil {
return nil, errors.Wrap(err, "Error validating origin URL")
}
originURL, err := url.Parse(originURLStr)
if err != nil {
return nil, errors.Wrap(err, "couldn't parse origin URL")
}
return &URL{URL: originURL, RootURL: originURL}, nil
}
if c.IsSet("unix-socket") {
unixSocket, err := config.ValidateUnixSocket(c)
if err != nil {
return nil, errors.Wrap(err, "Error validating --unix-socket")
}
return UnixSocketPath(unixSocket), nil
}
return nil, errors.New("You must either set ingress rules in your config file, or use --url or use --unix-socket")
} }
// IsEmpty checks if there are any ingress rules. // IsEmpty checks if there are any ingress rules.
@ -91,19 +189,47 @@ func (ing Ingress) IsEmpty() bool {
return len(ing.Rules) == 0 return len(ing.Rules) == 0
} }
func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) { // StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World.
func (ing Ingress) StartOrigins(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error) error {
for _, rule := range ing.Rules {
if err := rule.Service.Start(wg, log, shutdownC, errC, rule.Config); err != nil {
return err
}
}
return nil
}
// CatchAll returns the catch-all rule (i.e. the last rule)
func (ing Ingress) CatchAll() *Rule {
return &ing.Rules[len(ing.Rules)-1]
}
func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestConfig) (Ingress, error) {
rules := make([]Rule, len(ingress)) rules := make([]Rule, len(ingress))
for i, r := range ingress { for i, r := range ingress {
service, err := url.Parse(r.Service) var service OriginService
if strings.HasPrefix(r.Service, "unix:") {
// No validation necessary for unix socket filepath services
service = UnixSocketPath(strings.TrimPrefix(r.Service, "unix:"))
} else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" {
service = new(HelloWorld)
} else {
// Validate URL services
u, err := url.Parse(r.Service)
if err != nil { if err != nil {
return Ingress{}, err return Ingress{}, err
} }
if service.Scheme == "" || service.Hostname() == "" {
if u.Scheme == "" || u.Hostname() == "" {
return Ingress{}, fmt.Errorf("The service %s must have a scheme and a hostname", r.Service) return Ingress{}, fmt.Errorf("The service %s must have a scheme and a hostname", r.Service)
} }
if service.Path != "" { if u.Path != "" {
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path.", r.Service) return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service)
}
serviceURL := URL{URL: u}
service = &serviceURL
} }
// Ensure that there are no wildcards anywhere except the first character // Ensure that there are no wildcards anywhere except the first character
@ -125,6 +251,7 @@ func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) {
var pathRegex *regexp.Regexp var pathRegex *regexp.Regexp
if r.Path != "" { if r.Path != "" {
var err error
pathRegex, err = regexp.Compile(r.Path) pathRegex, err = regexp.Compile(r.Path)
if err != nil { if err != nil {
return Ingress{}, errors.Wrapf(err, "Rule #%d has an invalid regex", i+1) return Ingress{}, errors.Wrapf(err, "Rule #%d has an invalid regex", i+1)
@ -135,9 +262,10 @@ func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) {
Hostname: r.Hostname, Hostname: r.Hostname,
Service: service, Service: service,
Path: pathRegex, Path: pathRegex,
Config: SetConfig(defaults, r.OriginRequest),
} }
} }
return Ingress{Rules: rules}, nil return Ingress{Rules: rules, defaults: defaults}, nil
} }
type errRuleShouldNotBeCatchAll struct { type errRuleShouldNotBeCatchAll struct {
@ -151,9 +279,20 @@ func (e errRuleShouldNotBeCatchAll) Error() string {
"will never be triggered.", e.i+1, e.hostname) "will never be triggered.", e.i+1, e.hostname)
} }
func ParseIngress(conf *config.Configuration) (Ingress, error) { // ParseIngress parses, validates and initializes HTTP transports to each origin.
func ParseIngress(conf *config.Configuration, logger logger.Service) (Ingress, error) {
ing, err := ParseIngressDryRun(conf)
if err != nil {
return Ingress{}, err
}
err = ing.setHTTPTransport(logger)
return ing, err
}
// ParseIngressDryRun parses ingress rules, but does not send HTTP requests to the origins.
func ParseIngressDryRun(conf *config.Configuration) (Ingress, error) {
if len(conf.Ingress) == 0 { if len(conf.Ingress) == 0 {
return Ingress{}, ErrNoIngressRules return Ingress{}, ErrNoIngressRules
} }
return validate(conf.Ingress) return validate(conf.Ingress, OriginRequestFromYAML(conf.OriginRequest))
} }

View File

@ -2,7 +2,6 @@ package ingress
import ( import (
"net/url" "net/url"
"regexp"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -12,16 +11,29 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/config"
) )
func TestParseUnixSocket(t *testing.T) {
rawYAML := `
ingress:
- service: unix:/tmp/echo.sock
`
ing, err := ParseIngressDryRun(MustReadIngress(rawYAML))
require.NoError(t, err)
_, ok := ing.Rules[0].Service.(UnixSocketPath)
require.True(t, ok)
}
func Test_parseIngress(t *testing.T) { func Test_parseIngress(t *testing.T) {
localhost8000 := MustParseURL(t, "https://localhost:8000") localhost8000 := MustParseURL(t, "https://localhost:8000")
localhost8001 := MustParseURL(t, "https://localhost:8001") localhost8001 := MustParseURL(t, "https://localhost:8001")
defaultConfig := SetConfig(OriginRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{})
require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections)
type args struct { type args struct {
rawYAML string rawYAML string
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want Ingress want []Rule
wantErr bool wantErr bool
}{ }{
{ {
@ -38,16 +50,18 @@ ingress:
- hostname: "*" - hostname: "*"
service: https://localhost:8001 service: https://localhost:8001
`}, `},
want: Ingress{Rules: []Rule{ want: []Rule{
{ {
Hostname: "tunnel1.example.com", Hostname: "tunnel1.example.com",
Service: localhost8000, Service: &URL{URL: localhost8000},
Config: defaultConfig,
}, },
{ {
Hostname: "*", Hostname: "*",
Service: localhost8001, Service: &URL{URL: localhost8001},
Config: defaultConfig,
},
}, },
}},
}, },
{ {
name: "Extra keys", name: "Extra keys",
@ -57,12 +71,13 @@ ingress:
service: https://localhost:8000 service: https://localhost:8000
extraKey: extraValue extraKey: extraValue
`}, `},
want: Ingress{Rules: []Rule{ want: []Rule{
{ {
Hostname: "*", Hostname: "*",
Service: localhost8000, Service: &URL{URL: localhost8000},
Config: defaultConfig,
},
}, },
}},
}, },
{ {
name: "Hostname can be omitted", name: "Hostname can be omitted",
@ -70,11 +85,12 @@ extraKey: extraValue
ingress: ingress:
- service: https://localhost:8000 - service: https://localhost:8000
`}, `},
want: Ingress{Rules: []Rule{ want: []Rule{
{ {
Service: localhost8000, Service: &URL{URL: localhost8000},
Config: defaultConfig,
},
}, },
}},
}, },
{ {
name: "Invalid service", name: "Invalid service",
@ -152,12 +168,12 @@ ingress:
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := ParseIngress(MustReadIngress(tt.args.rawYAML)) got, err := ParseIngressDryRun(MustReadIngress(tt.args.rawYAML))
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ParseIngress() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ParseIngressDryRun() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got.Rules)
}) })
} }
} }
@ -168,118 +184,6 @@ func MustParseURL(t *testing.T, rawURL string) *url.URL {
return u return u
} }
func Test_rule_matches(t *testing.T) {
type fields struct {
Hostname string
Path *regexp.Regexp
Service *url.URL
}
type args struct {
requestURL *url.URL
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{
name: "Just hostname, pass",
fields: fields{
Hostname: "example.com",
},
args: args{
requestURL: MustParseURL(t, "https://example.com"),
},
want: true,
},
{
name: "Entire hostname is wildcard, should match everything",
fields: fields{
Hostname: "*",
},
args: args{
requestURL: MustParseURL(t, "https://example.com"),
},
want: true,
},
{
name: "Just hostname, fail",
fields: fields{
Hostname: "example.com",
},
args: args{
requestURL: MustParseURL(t, "https://foo.bar"),
},
want: false,
},
{
name: "Just wildcard hostname, pass",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://adam.example.com"),
},
want: true,
},
{
name: "Just wildcard hostname, fail",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://tunnel.com"),
},
want: false,
},
{
name: "Just wildcard outside of subdomain in hostname, fail",
fields: fields{
Hostname: "*example.com",
},
args: args{
requestURL: MustParseURL(t, "https://www.example.com"),
},
want: false,
},
{
name: "Wildcard over multiple subdomains",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://adam.chalmers.example.com"),
},
want: true,
},
{
name: "Hostname and path",
fields: fields{
Hostname: "*.example.com",
Path: regexp.MustCompile("/static/.*\\.html"),
},
args: args{
requestURL: MustParseURL(t, "https://www.example.com/static/index.html"),
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := Rule{
Hostname: tt.fields.Hostname,
Path: tt.fields.Path,
Service: tt.fields.Service,
}
u := tt.args.requestURL
if got := r.Matches(u.Hostname(), u.Path); got != tt.want {
t.Errorf("rule.matches() = %v, want %v", got, tt.want)
}
})
}
}
func BenchmarkFindMatch(b *testing.B) { func BenchmarkFindMatch(b *testing.B) {
rulesYAML := ` rulesYAML := `
ingress: ingress:
@ -291,7 +195,7 @@ ingress:
service: https://localhost:8002 service: https://localhost:8002
` `
ing, err := ParseIngress(MustReadIngress(rulesYAML)) ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }

View File

@ -0,0 +1,331 @@
package ingress
import (
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/urfave/cli/v2"
)
const (
defaultConnectTimeout = 30 * time.Second
defaultTLSTimeout = 10 * time.Second
defaultTCPKeepAlive = 30 * time.Second
defaultKeepAliveConnections = 100
defaultKeepAliveTimeout = 90 * time.Second
defaultProxyAddress = "127.0.0.1"
SSHServerFlag = "ssh-server"
Socks5Flag = "socks5"
ProxyConnectTimeoutFlag = "proxy-connect-timeout"
ProxyTLSTimeoutFlag = "proxy-tls-timeout"
ProxyTCPKeepAlive = "proxy-tcp-keepalive"
ProxyNoHappyEyeballsFlag = "proxy-no-happy-eyeballs"
ProxyKeepAliveConnectionsFlag = "proxy-keepalive-connections"
ProxyKeepAliveTimeoutFlag = "proxy-keepalive-timeout"
HTTPHostHeaderFlag = "http-host-header"
OriginServerNameFlag = "origin-server-name"
NoTLSVerifyFlag = "no-tls-verify"
NoChunkedEncodingFlag = "no-chunked-encoding"
ProxyAddressFlag = "proxy-address"
ProxyPortFlag = "proxy-port"
)
const (
socksProxy = "socks"
)
func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig {
var connectTimeout time.Duration = defaultConnectTimeout
var tlsTimeout time.Duration = defaultTLSTimeout
var tcpKeepAlive time.Duration = defaultTCPKeepAlive
var noHappyEyeballs bool
var keepAliveConnections int = defaultKeepAliveConnections
var keepAliveTimeout time.Duration = defaultKeepAliveTimeout
var httpHostHeader string
var originServerName string
var caPool string
var noTLSVerify bool
var disableChunkedEncoding bool
var bastionMode bool
var proxyAddress string
var proxyPort uint
var proxyType string
if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) {
connectTimeout = c.Duration(flag)
}
if flag := ProxyTLSTimeoutFlag; c.IsSet(flag) {
tlsTimeout = c.Duration(flag)
}
if flag := ProxyTCPKeepAlive; c.IsSet(flag) {
tcpKeepAlive = c.Duration(flag)
}
if flag := ProxyNoHappyEyeballsFlag; c.IsSet(flag) {
noHappyEyeballs = c.Bool(flag)
}
if flag := ProxyKeepAliveConnectionsFlag; c.IsSet(flag) {
keepAliveConnections = c.Int(flag)
}
if flag := ProxyKeepAliveTimeoutFlag; c.IsSet(flag) {
keepAliveTimeout = c.Duration(flag)
}
if flag := HTTPHostHeaderFlag; c.IsSet(flag) {
httpHostHeader = c.String(flag)
}
if flag := OriginServerNameFlag; c.IsSet(flag) {
originServerName = c.String(flag)
}
if flag := tlsconfig.OriginCAPoolFlag; c.IsSet(flag) {
caPool = c.String(flag)
}
if flag := NoTLSVerifyFlag; c.IsSet(flag) {
noTLSVerify = c.Bool(flag)
}
if flag := NoChunkedEncodingFlag; c.IsSet(flag) {
disableChunkedEncoding = c.Bool(flag)
}
if flag := config.BastionFlag; c.IsSet(flag) {
bastionMode = c.Bool(flag)
}
if flag := ProxyAddressFlag; c.IsSet(flag) {
proxyAddress = c.String(flag)
}
if flag := ProxyPortFlag; c.IsSet(flag) {
proxyPort = c.Uint(flag)
}
if c.IsSet(Socks5Flag) {
proxyType = socksProxy
}
return OriginRequestConfig{
ConnectTimeout: connectTimeout,
TLSTimeout: tlsTimeout,
TCPKeepAlive: tcpKeepAlive,
NoHappyEyeballs: noHappyEyeballs,
KeepAliveConnections: keepAliveConnections,
KeepAliveTimeout: keepAliveTimeout,
HTTPHostHeader: httpHostHeader,
OriginServerName: originServerName,
CAPool: caPool,
NoTLSVerify: noTLSVerify,
DisableChunkedEncoding: disableChunkedEncoding,
BastionMode: bastionMode,
ProxyAddress: proxyAddress,
ProxyPort: proxyPort,
ProxyType: proxyType,
}
}
func OriginRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig {
out := OriginRequestConfig{
ConnectTimeout: defaultConnectTimeout,
TLSTimeout: defaultTLSTimeout,
TCPKeepAlive: defaultTCPKeepAlive,
KeepAliveConnections: defaultKeepAliveConnections,
KeepAliveTimeout: defaultKeepAliveTimeout,
ProxyAddress: defaultProxyAddress,
}
if y.ConnectTimeout != nil {
out.ConnectTimeout = *y.ConnectTimeout
}
if y.TLSTimeout != nil {
out.TLSTimeout = *y.TLSTimeout
}
if y.TCPKeepAlive != nil {
out.TCPKeepAlive = *y.TCPKeepAlive
}
if y.NoHappyEyeballs != nil {
out.NoHappyEyeballs = *y.NoHappyEyeballs
}
if y.KeepAliveConnections != nil {
out.KeepAliveConnections = *y.KeepAliveConnections
}
if y.KeepAliveTimeout != nil {
out.KeepAliveTimeout = *y.KeepAliveTimeout
}
if y.HTTPHostHeader != nil {
out.HTTPHostHeader = *y.HTTPHostHeader
}
if y.OriginServerName != nil {
out.OriginServerName = *y.OriginServerName
}
if y.CAPool != nil {
out.CAPool = *y.CAPool
}
if y.NoTLSVerify != nil {
out.NoTLSVerify = *y.NoTLSVerify
}
if y.DisableChunkedEncoding != nil {
out.DisableChunkedEncoding = *y.DisableChunkedEncoding
}
if y.BastionMode != nil {
out.BastionMode = *y.BastionMode
}
if y.ProxyAddress != nil {
out.ProxyAddress = *y.ProxyAddress
}
if y.ProxyPort != nil {
out.ProxyPort = *y.ProxyPort
}
if y.ProxyType != nil {
out.ProxyType = *y.ProxyType
}
return out
}
// OriginRequestConfig configures how Cloudflared sends requests to origin
// services.
// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h".
type OriginRequestConfig struct {
// HTTP proxy timeout for establishing a new connection
ConnectTimeout time.Duration `yaml:"connectTimeout"`
// HTTP proxy timeout for completing a TLS handshake
TLSTimeout time.Duration `yaml:"tlsTimeout"`
// HTTP proxy TCP keepalive duration
TCPKeepAlive time.Duration `yaml:"tcpKeepAlive"`
// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
NoHappyEyeballs bool `yaml:"noHappyEyeballs"`
// HTTP proxy maximum keepalive connection pool size
KeepAliveConnections int `yaml:"keepAliveConnections"`
// HTTP proxy timeout for closing an idle connection
KeepAliveTimeout time.Duration `yaml:"keepAliveTimeout"`
// Sets the HTTP Host header for the local webserver.
HTTPHostHeader string `yaml:"httpHostHeader"`
// Hostname on the origin server certificate.
OriginServerName string `yaml:"originServerName"`
// Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare.
CAPool string `yaml:"caPool"`
// Disables TLS verification of the certificate presented by your origin.
// Will allow any certificate from the origin to be accepted.
// Note: The connection from your machine to Cloudflare's Edge is still encrypted.
NoTLSVerify bool `yaml:"noTLSVerify"`
// Disables chunked transfer encoding.
// Useful if you are running a WSGI server.
DisableChunkedEncoding bool `yaml:"disableChunkedEncoding"`
// Runs as jump host
BastionMode bool `yaml:"bastionMode"`
// Listen address for the proxy.
ProxyAddress string `yaml:"proxyAddress"`
// Listen port for the proxy.
ProxyPort uint `yaml:"proxyPort"`
// What sort of proxy should be started
ProxyType string `yaml:"proxyType"`
}
func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) {
if val := overrides.ConnectTimeout; val != nil {
defaults.ConnectTimeout = *val
}
}
func (defaults *OriginRequestConfig) setTLSTimeout(overrides config.OriginRequestConfig) {
if val := overrides.TLSTimeout; val != nil {
defaults.TLSTimeout = *val
}
}
func (defaults *OriginRequestConfig) setNoHappyEyeballs(overrides config.OriginRequestConfig) {
if val := overrides.NoHappyEyeballs; val != nil {
defaults.NoHappyEyeballs = *val
}
}
func (defaults *OriginRequestConfig) setKeepAliveConnections(overrides config.OriginRequestConfig) {
if val := overrides.KeepAliveConnections; val != nil {
defaults.KeepAliveConnections = *val
}
}
func (defaults *OriginRequestConfig) setKeepAliveTimeout(overrides config.OriginRequestConfig) {
if val := overrides.KeepAliveTimeout; val != nil {
defaults.KeepAliveTimeout = *val
}
}
func (defaults *OriginRequestConfig) setTCPKeepAlive(overrides config.OriginRequestConfig) {
if val := overrides.TCPKeepAlive; val != nil {
defaults.TCPKeepAlive = *val
}
}
func (defaults *OriginRequestConfig) setHTTPHostHeader(overrides config.OriginRequestConfig) {
if val := overrides.HTTPHostHeader; val != nil {
defaults.HTTPHostHeader = *val
}
}
func (defaults *OriginRequestConfig) setOriginServerName(overrides config.OriginRequestConfig) {
if val := overrides.OriginServerName; val != nil {
defaults.OriginServerName = *val
}
}
func (defaults *OriginRequestConfig) setCAPool(overrides config.OriginRequestConfig) {
if val := overrides.CAPool; val != nil {
defaults.CAPool = *val
}
}
func (defaults *OriginRequestConfig) setNoTLSVerify(overrides config.OriginRequestConfig) {
if val := overrides.NoTLSVerify; val != nil {
defaults.NoTLSVerify = *val
}
}
func (defaults *OriginRequestConfig) setDisableChunkedEncoding(overrides config.OriginRequestConfig) {
if val := overrides.DisableChunkedEncoding; val != nil {
defaults.DisableChunkedEncoding = *val
}
}
func (defaults *OriginRequestConfig) setBastionMode(overrides config.OriginRequestConfig) {
if val := overrides.BastionMode; val != nil {
defaults.BastionMode = *val
}
}
func (defaults *OriginRequestConfig) setProxyPort(overrides config.OriginRequestConfig) {
if val := overrides.ProxyPort; val != nil {
defaults.ProxyPort = *val
}
}
func (defaults *OriginRequestConfig) setProxyAddress(overrides config.OriginRequestConfig) {
if val := overrides.ProxyAddress; val != nil {
defaults.ProxyAddress = *val
}
}
func (defaults *OriginRequestConfig) setProxyType(overrides config.OriginRequestConfig) {
if val := overrides.ProxyType; val != nil {
defaults.ProxyType = *val
}
}
// SetConfig gets config for the requests that cloudflared sends to origins.
// Each field has a setter method which sets a value for the field by trying to find:
// 1. The user config for this rule
// 2. The user config for the overall ingress config
// 3. Defaults chosen by the cloudflared team
// 4. Golang zero values for that type
// If an earlier option isn't set, it will try the next option down.
func SetConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfig) OriginRequestConfig {
cfg := defaults
cfg.setConnectTimeout(overrides)
cfg.setTLSTimeout(overrides)
cfg.setNoHappyEyeballs(overrides)
cfg.setKeepAliveConnections(overrides)
cfg.setKeepAliveTimeout(overrides)
cfg.setTCPKeepAlive(overrides)
cfg.setHTTPHostHeader(overrides)
cfg.setOriginServerName(overrides)
cfg.setCAPool(overrides)
cfg.setNoTLSVerify(overrides)
cfg.setDisableChunkedEncoding(overrides)
cfg.setBastionMode(overrides)
cfg.setProxyPort(overrides)
cfg.setProxyAddress(overrides)
cfg.setProxyType(overrides)
return cfg
}

View File

@ -0,0 +1,184 @@
package ingress
import (
"testing"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
)
// Ensure that the nullable config from `config` package and the
// non-nullable config from `ingress` package have the same number of
// fields.
// This test ensures that programmers didn't add a new field to
// one struct and forget to add it to the other ;)
func TestCorrespondingFields(t *testing.T) {
require.Equal(
t,
CountFields(t, config.OriginRequestConfig{}),
CountFields(t, OriginRequestConfig{}),
)
}
func CountFields(t *testing.T, val interface{}) int {
b, err := yaml.Marshal(val)
require.NoError(t, err)
m := make(map[string]interface{}, 0)
err = yaml.Unmarshal(b, &m)
require.NoError(t, err)
return len(m)
}
func TestOriginRequestConfigOverrides(t *testing.T) {
rulesYAML := `
originRequest:
connectTimeout: 1m
tlsTimeout: 1s
noHappyEyeballs: true
tcpKeepAlive: 1s
keepAliveConnections: 1
keepAliveTimeout: 1s
httpHostHeader: abc
originServerName: a1
caPool: /tmp/path0
noTLSVerify: true
disableChunkedEncoding: true
bastionMode: True
proxyAddress: 127.1.2.3
proxyPort: 100
proxyType: socks5
ingress:
- hostname: tun.example.com
service: https://localhost:8000
- hostname: "*"
service: https://localhost:8001
originRequest:
connectTimeout: 2m
tlsTimeout: 2s
noHappyEyeballs: false
tcpKeepAlive: 2s
keepAliveConnections: 2
keepAliveTimeout: 2s
httpHostHeader: def
originServerName: b2
caPool: /tmp/path1
noTLSVerify: false
disableChunkedEncoding: false
bastionMode: false
proxyAddress: interface
proxyPort: 200
proxyType: ""
`
ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
if err != nil {
t.Error(err)
}
// Rule 0 didn't override anything, so it inherits the user-specified
// root-level configuration.
actual0 := ing.Rules[0].Config
expected0 := OriginRequestConfig{
ConnectTimeout: 1 * time.Minute,
TLSTimeout: 1 * time.Second,
NoHappyEyeballs: true,
TCPKeepAlive: 1 * time.Second,
KeepAliveConnections: 1,
KeepAliveTimeout: 1 * time.Second,
HTTPHostHeader: "abc",
OriginServerName: "a1",
CAPool: "/tmp/path0",
NoTLSVerify: true,
DisableChunkedEncoding: true,
BastionMode: true,
ProxyAddress: "127.1.2.3",
ProxyPort: uint(100),
ProxyType: "socks5",
}
require.Equal(t, expected0, actual0)
// Rule 1 overrode all the root-level config.
actual1 := ing.Rules[1].Config
expected1 := OriginRequestConfig{
ConnectTimeout: 2 * time.Minute,
TLSTimeout: 2 * time.Second,
NoHappyEyeballs: false,
TCPKeepAlive: 2 * time.Second,
KeepAliveConnections: 2,
KeepAliveTimeout: 2 * time.Second,
HTTPHostHeader: "def",
OriginServerName: "b2",
CAPool: "/tmp/path1",
NoTLSVerify: false,
DisableChunkedEncoding: false,
BastionMode: false,
ProxyAddress: "interface",
ProxyPort: uint(200),
ProxyType: "",
}
require.Equal(t, expected1, actual1)
}
func TestOriginRequestConfigDefaults(t *testing.T) {
rulesYAML := `
ingress:
- hostname: tun.example.com
service: https://localhost:8000
- hostname: "*"
service: https://localhost:8001
originRequest:
connectTimeout: 2m
tlsTimeout: 2s
noHappyEyeballs: false
tcpKeepAlive: 2s
keepAliveConnections: 2
keepAliveTimeout: 2s
httpHostHeader: def
originServerName: b2
caPool: /tmp/path1
noTLSVerify: false
disableChunkedEncoding: false
bastionMode: false
proxyAddress: interface
proxyPort: 200
proxyType: ""
`
ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
if err != nil {
t.Error(err)
}
// Rule 0 didn't override anything, so it inherits the cloudflared defaults
actual0 := ing.Rules[0].Config
expected0 := OriginRequestConfig{
ConnectTimeout: defaultConnectTimeout,
TLSTimeout: defaultTLSTimeout,
TCPKeepAlive: defaultTCPKeepAlive,
KeepAliveConnections: defaultKeepAliveConnections,
KeepAliveTimeout: defaultKeepAliveTimeout,
ProxyAddress: defaultProxyAddress,
}
require.Equal(t, expected0, actual0)
// Rule 1 overrode all defaults.
actual1 := ing.Rules[1].Config
expected1 := OriginRequestConfig{
ConnectTimeout: 2 * time.Minute,
TLSTimeout: 2 * time.Second,
NoHappyEyeballs: false,
TCPKeepAlive: 2 * time.Second,
KeepAliveConnections: 2,
KeepAliveTimeout: 2 * time.Second,
HTTPHostHeader: "def",
OriginServerName: "b2",
CAPool: "/tmp/path1",
NoTLSVerify: false,
DisableChunkedEncoding: false,
BastionMode: false,
ProxyAddress: "interface",
ProxyPort: uint(200),
ProxyType: "",
}
require.Equal(t, expected1, actual1)
}

181
ingress/origin_service.go Normal file
View File

@ -0,0 +1,181 @@
package ingress
import (
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
// OriginService is something a tunnel can proxy traffic to.
type OriginService interface {
Address() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
// starting the origin service.
Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error
String() string
// RewriteOriginURL modifies the HTTP request from cloudflared to the origin, so that it apply
// this particular type of origin service's specific routing logic.
RewriteOriginURL(*url.URL)
}
// UnixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
type UnixSocketPath string
func (o UnixSocketPath) Address() string {
return string(o)
}
func (o UnixSocketPath) String() string {
return "unix socket: " + string(o)
}
func (o UnixSocketPath) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
return nil
}
func (o UnixSocketPath) RewriteOriginURL(u *url.URL) {
// No changes necessary because the origin request URL isn't used.
// Instead, HTTPTransport's dial is already configured to address the unix socket.
}
// URL is an OriginService listening on a TCP address
type URL struct {
// The URL for the user's origin service
RootURL *url.URL
// The URL that cloudflared should send requests to.
// If this origin requires starting a proxy, this is the proxy's address,
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
URL *url.URL
}
func (o *URL) Address() string {
return o.URL.String()
}
func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
staticHost := o.staticHost()
if !originRequiresProxy(staticHost, cfg) {
return nil
}
// Start a listener for the proxy
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
listener, err := net.Listen("tcp", proxyAddress)
if err != nil {
log.Errorf("Cannot start Websocket Proxy Server: %s", err)
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
}
// Start the proxy itself
wg.Add(1)
go func() {
defer wg.Done()
streamHandler := websocket.DefaultStreamHandler
// This origin's config specifies what type of proxy to start.
switch cfg.ProxyType {
case socksProxy:
log.Info("SOCKS5 server started")
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) {
dialer := socks.NewConnDialer(remoteConn)
requestHandler := socks.NewRequestHandler(dialer)
socksServer := socks.NewConnectionHandler(requestHandler)
socksServer.Serve(wsConn)
}
case "":
log.Debug("Not starting any websocket proxy")
default:
log.Errorf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy)
}
errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler)
}()
// Modify this origin, so that it no longer points at the origin service directly.
// Instead, it points at the proxy to the origin service.
newURL, err := url.Parse("http://" + listener.Addr().String())
if err != nil {
return err
}
o.URL = newURL
return nil
}
func (o *URL) String() string {
return o.Address()
}
func (o *URL) RewriteOriginURL(u *url.URL) {
u.Host = o.URL.Host
u.Scheme = o.URL.Scheme
}
func (o *URL) staticHost() string {
addPortIfMissing := func(uri *url.URL, port int) string {
if uri.Port() != "" {
return uri.Host
}
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
}
switch o.URL.Scheme {
case "ssh":
return addPortIfMissing(o.URL, 22)
case "rdp":
return addPortIfMissing(o.URL, 3389)
case "smb":
return addPortIfMissing(o.URL, 445)
case "tcp":
return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case
}
return ""
}
// HelloWorld is the built-in Hello World service. Used for testing and experimenting with cloudflared.
type HelloWorld struct {
server net.Listener
}
func (o *HelloWorld) Address() string {
return o.server.Addr().String()
}
func (o *HelloWorld) String() string {
return "Hello World static HTML service"
}
// Start starts a HelloWorld server and stores its address in the Service receiver.
func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
return errors.Wrap(err, "Cannot start Hello World Server")
}
wg.Add(1)
go func() {
defer wg.Done()
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
}()
o.server = helloListener
return nil
}
func (o *HelloWorld) RewriteOriginURL(u *url.URL) {
u.Host = o.Address()
u.Scheme = "https"
}
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
return staticHost != "" || cfg.BastionMode
}

57
ingress/rule.go Normal file
View File

@ -0,0 +1,57 @@
package ingress
import (
"crypto/tls"
"net/http"
"regexp"
"strings"
)
// Rule routes traffic from a hostname/path on the public internet to the
// service running on the given URL.
type Rule struct {
// Requests for this hostname will be proxied to this rule's service.
Hostname string
// Path is an optional regex that can specify path-driven ingress rules.
Path *regexp.Regexp
// A (probably local) address. Requests for a hostname which matches this
// rule's hostname pattern will be proxied to the service running on this
// address.
Service OriginService
// Configure the request cloudflared sends to this specific origin.
Config OriginRequestConfig
// Configures TLS for the cloudflared -> origin request
ClientTLSConfig *tls.Config
// Configures HTTP for the cloudflared -> origin request
HTTPTransport http.RoundTripper
}
// MultiLineString is for outputting rules in a human-friendly way when Cloudflared
// is used as a CLI tool (not as a daemon).
func (r Rule) MultiLineString() string {
var out strings.Builder
if r.Hostname != "" {
out.WriteString("\thostname: ")
out.WriteString(r.Hostname)
out.WriteRune('\n')
}
if r.Path != nil {
out.WriteString("\tpath: ")
out.WriteString(r.Path.String())
out.WriteRune('\n')
}
out.WriteString("\tservice: ")
out.WriteString(r.Service.String())
return out.String()
}
// Matches checks if the rule matches a given hostname/path combination.
func (r *Rule) Matches(hostname, path string) bool {
hostMatch := r.Hostname == "" || r.Hostname == "*" || matchHost(r.Hostname, hostname)
pathMatch := r.Path == nil || r.Path.MatchString(path)
return hostMatch && pathMatch
}

119
ingress/rule_test.go Normal file
View File

@ -0,0 +1,119 @@
package ingress
import (
"net/url"
"regexp"
"testing"
)
func Test_rule_matches(t *testing.T) {
type fields struct {
Hostname string
Path *regexp.Regexp
Service OriginService
}
type args struct {
requestURL *url.URL
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{
name: "Just hostname, pass",
fields: fields{
Hostname: "example.com",
},
args: args{
requestURL: MustParseURL(t, "https://example.com"),
},
want: true,
},
{
name: "Entire hostname is wildcard, should match everything",
fields: fields{
Hostname: "*",
},
args: args{
requestURL: MustParseURL(t, "https://example.com"),
},
want: true,
},
{
name: "Just hostname, fail",
fields: fields{
Hostname: "example.com",
},
args: args{
requestURL: MustParseURL(t, "https://foo.bar"),
},
want: false,
},
{
name: "Just wildcard hostname, pass",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://adam.example.com"),
},
want: true,
},
{
name: "Just wildcard hostname, fail",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://tunnel.com"),
},
want: false,
},
{
name: "Just wildcard outside of subdomain in hostname, fail",
fields: fields{
Hostname: "*example.com",
},
args: args{
requestURL: MustParseURL(t, "https://www.example.com"),
},
want: false,
},
{
name: "Wildcard over multiple subdomains",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://adam.chalmers.example.com"),
},
want: true,
},
{
name: "Hostname and path",
fields: fields{
Hostname: "*.example.com",
Path: regexp.MustCompile("/static/.*\\.html"),
},
args: args{
requestURL: MustParseURL(t, "https://www.example.com/static/index.html"),
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := Rule{
Hostname: tt.fields.Hostname,
Path: tt.fields.Path,
Service: tt.fields.Service,
}
u := tt.args.requestURL
if got := r.Matches(u.Hostname(), u.Path); got != tt.want {
t.Errorf("rule.matches() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -30,7 +30,6 @@ import (
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
) )
@ -57,16 +56,13 @@ const (
type TunnelConfig struct { type TunnelConfig struct {
BuildInfo *buildinfo.BuildInfo BuildInfo *buildinfo.BuildInfo
ClientID string ClientID string
ClientTlsConfig *tls.Config
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
CompressionQuality uint64 CompressionQuality uint64
EdgeAddrs []string EdgeAddrs []string
GracePeriod time.Duration GracePeriod time.Duration
HAConnections int HAConnections int
HTTPTransport http.RoundTripper
HeartbeatInterval time.Duration HeartbeatInterval time.Duration
Hostname string Hostname string
HTTPHostHeader string
IncidentLookup IncidentLookup IncidentLookup IncidentLookup
IsAutoupdated bool IsAutoupdated bool
IsFreeTunnel bool IsFreeTunnel bool
@ -76,7 +72,6 @@ type TunnelConfig struct {
MaxHeartbeats uint64 MaxHeartbeats uint64
Metrics *TunnelMetrics Metrics *TunnelMetrics
MetricsUpdateFreq time.Duration MetricsUpdateFreq time.Duration
NoChunkedEncoding bool
OriginCert []byte OriginCert []byte
ReportedVersion string ReportedVersion string
Retries uint Retries uint
@ -84,8 +79,6 @@ type TunnelConfig struct {
Tags []tunnelpogs.Tag Tags []tunnelpogs.Tag
TlsConfig *tls.Config TlsConfig *tls.Config
WSGI bool WSGI bool
// OriginUrl may not be used if a user specifies a unix socket.
OriginUrl string
// feature-flag to use new edge reconnect tokens // feature-flag to use new edge reconnect tokens
UseReconnectToken bool UseReconnectToken bool
@ -618,18 +611,13 @@ func LogServerInfo(
} }
type TunnelHandler struct { type TunnelHandler struct {
originUrl string
ingressRules ingress.Ingress ingressRules ingress.Ingress
httpHostHeader string
muxer *h2mux.Muxer muxer *h2mux.Muxer
httpClient http.RoundTripper
tlsConfig *tls.Config
tags []tunnelpogs.Tag tags []tunnelpogs.Tag
metrics *TunnelMetrics metrics *TunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string // connectionID is only used by metrics, and prometheus requires labels to be string
connectionID string connectionID string
logger logger.Service logger logger.Service
noChunkedEncoding bool
bufferPool *buffer.Pool bufferPool *buffer.Pool
} }
@ -642,32 +630,14 @@ func NewTunnelHandler(ctx context.Context,
bufferPool *buffer.Pool, bufferPool *buffer.Pool,
) (*TunnelHandler, string, error) { ) (*TunnelHandler, string, error) {
// Check single-origin config
var originURL string
var err error
if config.IngressRules.IsEmpty() {
originURL, err = validation.ValidateUrl(config.OriginUrl)
if err != nil {
return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL)
}
}
h := &TunnelHandler{ h := &TunnelHandler{
originUrl: originURL,
ingressRules: config.IngressRules, ingressRules: config.IngressRules,
httpHostHeader: config.HTTPHostHeader,
httpClient: config.HTTPTransport,
tlsConfig: config.ClientTlsConfig,
tags: config.Tags, tags: config.Tags,
metrics: config.Metrics, metrics: config.Metrics,
connectionID: uint8ToString(connectionID), connectionID: uint8ToString(connectionID),
logger: config.Logger, logger: config.Logger,
noChunkedEncoding: config.NoChunkedEncoding,
bufferPool: bufferPool, bufferPool: bufferPool,
} }
if h.httpClient == nil {
h.httpClient = http.DefaultTransport
}
edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
if err != nil { if err != nil {
@ -692,7 +662,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
h.metrics.incrementRequests(h.connectionID) h.metrics.incrementRequests(h.connectionID)
defer h.metrics.decrementConcurrentRequests(h.connectionID) defer h.metrics.decrementConcurrentRequests(h.connectionID)
req, reqErr := h.createRequest(stream) req, rule, reqErr := h.createRequest(stream)
if reqErr != nil { if reqErr != nil {
h.writeErrorResponse(stream, reqErr) h.writeErrorResponse(stream, reqErr)
return reqErr return reqErr
@ -705,9 +675,9 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
var resp *http.Response var resp *http.Response
var respErr error var respErr error
if websocket.IsWebSocketUpgrade(req) { if websocket.IsWebSocketUpgrade(req) {
resp, respErr = h.serveWebsocket(stream, req) resp, respErr = h.serveWebsocket(stream, req, rule)
} else { } else {
resp, respErr = h.serveHTTP(stream, req) resp, respErr = h.serveHTTP(stream, req, rule)
} }
if respErr != nil { if respErr != nil {
h.writeErrorResponse(stream, respErr) h.writeErrorResponse(stream, respErr)
@ -717,32 +687,28 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
return nil return nil
} }
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) { func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) {
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
} }
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "invalid request received") return nil, nil, errors.Wrap(err, "invalid request received")
} }
h.AppendTagHeaders(req) h.AppendTagHeaders(req)
if !h.ingressRules.IsEmpty() { rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
ruleNumber := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) rule.Service.RewriteOriginURL(req.URL)
destination := h.ingressRules.Rules[ruleNumber].Service return req, rule, nil
req.URL.Host = destination.Host
req.URL.Scheme = destination.Scheme
}
return req, nil
} }
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
if h.httpHostHeader != "" { if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
req.Header.Set("Host", h.httpHostHeader) req.Header.Set("Host", hostHeader)
req.Host = h.httpHostHeader req.Host = hostHeader
} }
conn, response, err := websocket.ClientConnect(req, h.tlsConfig) conn, response, err := websocket.ClientConnect(req, rule.ClientTLSConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -758,9 +724,9 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ
return response, nil return response, nil
} }
func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if h.noChunkedEncoding { if rule.Config.DisableChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"} req.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil { if err == nil {
@ -771,12 +737,12 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request)
// Request origin to keep connection alive to improve performance // Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive") req.Header.Set("Connection", "keep-alive")
if h.httpHostHeader != "" { if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
req.Header.Set("Host", h.httpHostHeader) req.Header.Set("Host", hostHeader)
req.Host = h.httpHostHeader req.Host = hostHeader
} }
response, err := h.httpClient.RoundTrip(req) response, err := rule.HTTPTransport.RoundTrip(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin") return nil, errors.Wrap(err, "Error proxying request to origin")
} }

View File

@ -65,10 +65,9 @@ func (cr *CertReloader) LoadCert() error {
return nil return nil
} }
func LoadOriginCA(c *cli.Context, logger logger.Service) (*x509.CertPool, error) { func LoadOriginCA(originCAPoolFilename string, logger logger.Service) (*x509.CertPool, error) {
var originCustomCAPool []byte var originCustomCAPool []byte
originCAPoolFilename := c.String(OriginCAPoolFlag)
if originCAPoolFilename != "" { if originCAPoolFilename != "" {
var err error var err error
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename) originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)