diff --git a/cmd/cloudflared/config/configuration.go b/cmd/cloudflared/config/configuration.go index 8233b160..58f48305 100644 --- a/cmd/cloudflared/config/configuration.go +++ b/cmd/cloudflared/config/configuration.go @@ -34,7 +34,12 @@ var ( ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories()) ) -const DefaultCredentialFile = "cert.pem" +const ( + DefaultCredentialFile = "cert.pem" + + // BastionFlag is to enable bastion, or jump host, operation + BastionFlag = "bastion" +) // DefaultConfigDirectory returns the default directory of the config file func DefaultConfigDirectory() string { @@ -197,15 +202,59 @@ func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) { } type UnvalidatedIngressRule struct { - Hostname string - Path string - Service string + Hostname string + Path string + Service string + OriginRequest OriginRequestConfig `yaml:"originRequest"` +} + +// OriginRequestConfig is a set of optional fields that users may set to +// customize how cloudflared sends requests to origin services. It is used to set +// up general config that apply to all rules, and also, specific per-rule +// config. +// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h". +type OriginRequestConfig struct { + // HTTP proxy timeout for establishing a new connection + ConnectTimeout *time.Duration `yaml:"connectTimeout"` + // HTTP proxy timeout for completing a TLS handshake + TLSTimeout *time.Duration `yaml:"tlsTimeout"` + // HTTP proxy TCP keepalive duration + TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive"` + // HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback + NoHappyEyeballs *bool `yaml:"noHappyEyeballs"` + // HTTP proxy maximum keepalive connection pool size + KeepAliveConnections *int `yaml:"keepAliveConnections"` + // HTTP proxy timeout for closing an idle connection + KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout"` + // Sets the HTTP Host header for the local webserver. + HTTPHostHeader *string `yaml:"httpHostHeader"` + // Hostname on the origin server certificate. + OriginServerName *string `yaml:"originServerName"` + // Path to the CA for the certificate of your origin. + // This option should be used only if your certificate is not signed by Cloudflare. + CAPool *string `yaml:"caPool"` + // Disables TLS verification of the certificate presented by your origin. + // Will allow any certificate from the origin to be accepted. + // Note: The connection from your machine to Cloudflare's Edge is still encrypted. + NoTLSVerify *bool `yaml:"noTLSVerify"` + // Disables chunked transfer encoding. + // Useful if you are running a WSGI server. + DisableChunkedEncoding *bool `yaml:"disableChunkedEncoding"` + // Runs as jump host + BastionMode *bool `yaml:"bastionMode"` + // Listen address for the proxy. + ProxyAddress *string `yaml:"proxyAddress"` + // Listen port for the proxy. + ProxyPort *uint `yaml:"proxyPort"` + // Valid options are 'socks', 'ssh' or empty. + ProxyType *string `yaml:"proxyType"` } type Configuration struct { - TunnelID string `yaml:"tunnel"` - Ingress []UnvalidatedIngressRule - sourceFile string + TunnelID string `yaml:"tunnel"` + Ingress []UnvalidatedIngressRule + OriginRequest OriginRequestConfig `yaml:"originRequest"` + sourceFile string } type configFileSettings struct { diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 0de66fcd..03232378 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -5,43 +5,32 @@ import ( "context" "fmt" "io/ioutil" - "net" - "net/http" "net/url" "os" "reflect" - "runtime" "runtime/trace" - "strconv" "strings" "sync" "time" - "github.com/cloudflare/cloudflared/awsuploader" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/dbconnect" - "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/hello" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/metrics" "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/signal" - "github.com/cloudflare/cloudflared/socks" - "github.com/cloudflare/cloudflared/sshlog" - "github.com/cloudflare/cloudflared/sshserver" "github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tunneldns" "github.com/cloudflare/cloudflared/tunnelstore" - "github.com/cloudflare/cloudflared/websocket" "github.com/coreos/go-systemd/daemon" "github.com/facebookgo/grace/gracenet" "github.com/getsentry/raven-go" - "github.com/gliderlabs/ssh" "github.com/google/uuid" "github.com/mitchellh/go-homedir" "github.com/pkg/errors" @@ -84,15 +73,6 @@ const ( // hostKeyPath is the path of the dir to save SSH host keys too hostKeyPath = "host-key-path" - //sshServerFlag enables cloudflared ssh proxy server - sshServerFlag = "ssh-server" - - // socks5Flag is to enable the socks server to deframe - socks5Flag = "socks5" - - // bastionFlag is to enable bastion, or jump host, operation - bastionFlag = "bastion" - // uiFlag is to enable launching cloudflared in interactive UI mode uiFlag = "ui" @@ -373,72 +353,6 @@ func StartServer( return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0, log) } - if c.IsSet("hello-world") { - log.Infof("hello-world set") - helloListener, err := hello.CreateTLSListener("127.0.0.1:") - if err != nil { - log.Errorf("Cannot start Hello World Server: %s", err) - return errors.Wrap(err, "Cannot start Hello World Server") - } - defer helloListener.Close() - wg.Add(1) - go func() { - defer wg.Done() - _ = hello.StartHelloWorldServer(log, helloListener, shutdownC) - }() - forceSetFlag(c, "url", "https://"+helloListener.Addr().String()) - } - - if c.IsSet(sshServerFlag) { - if runtime.GOOS != "darwin" && runtime.GOOS != "linux" { - msg := fmt.Sprintf("--ssh-server is not supported on %s", runtime.GOOS) - log.Error(msg) - return errors.New(msg) - } - - log.Infof("ssh-server set") - - logManager := sshlog.NewEmptyManager() - if c.IsSet(bucketNameFlag) && c.IsSet(regionNameFlag) && c.IsSet(accessKeyIDFlag) && c.IsSet(secretIDFlag) { - uploader, err := awsuploader.NewFileUploader(c.String(bucketNameFlag), c.String(regionNameFlag), - c.String(accessKeyIDFlag), c.String(secretIDFlag), c.String(sessionTokenIDFlag), c.String(s3URLFlag)) - if err != nil { - msg := "Cannot create uploader for SSH Server" - log.Errorf("%s: %s", msg, err) - return errors.Wrap(err, msg) - } - - if err := os.MkdirAll(sshLogFileDirectory, 0700); err != nil { - msg := fmt.Sprintf("Cannot create SSH log file directory %s", sshLogFileDirectory) - log.Errorf("%s: %s", msg, err) - return errors.Wrap(err, msg) - } - - logManager = sshlog.New(sshLogFileDirectory) - - uploadManager := awsuploader.NewDirectoryUploadManager(log, uploader, sshLogFileDirectory, 30*time.Minute, shutdownC) - uploadManager.Start() - } - - localServerAddress := "127.0.0.1:" + c.String(sshPortFlag) - server, err := sshserver.New(logManager, log, version, localServerAddress, c.String("hostname"), c.Path(hostKeyPath), shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) - if err != nil { - msg := "Cannot create new SSH Server" - log.Errorf("%s: %s", msg, err) - return errors.Wrap(err, msg) - } - wg.Add(1) - go func() { - defer wg.Done() - if err = server.Start(); err != nil && err != ssh.ErrServerClosed { - log.Errorf("SSH server error: %s", err) - // TODO: remove when declarative tunnels are implemented. - close(shutdownC) - } - }() - forceSetFlag(c, "url", "ssh://"+localServerAddress) - } - url := c.String("url") hostname := c.String("hostname") if url == hostname && url != "" && hostname != "" { @@ -447,42 +361,6 @@ func StartServer( return fmt.Errorf(errText) } - if staticHost := hostnameFromURI(c.String("url")); isProxyDestinationConfigured(staticHost, c) { - listener, err := net.Listen("tcp", net.JoinHostPort(c.String("proxy-address"), strconv.Itoa(c.Int("proxy-port")))) - if err != nil { - log.Errorf("Cannot start Websocket Proxy Server: %s", err) - return errors.Wrap(err, "Cannot start Websocket Proxy Server") - } - wg.Add(1) - go func() { - defer wg.Done() - streamHandler := websocket.DefaultStreamHandler - if c.IsSet(socks5Flag) { - log.Info("SOCKS5 server started") - streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) { - dialer := socks.NewConnDialer(remoteConn) - requestHandler := socks.NewRequestHandler(dialer) - socksServer := socks.NewConnectionHandler(requestHandler) - - socksServer.Serve(wsConn) - } - } else if c.IsSet(sshServerFlag) { - streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, requestHeaders http.Header) { - if finalDestination := requestHeaders.Get(h2mux.CFJumpDestinationHeader); finalDestination != "" { - token := requestHeaders.Get(h2mux.CFAccessTokenHeader) - if err := websocket.SendSSHPreamble(remoteConn, finalDestination, token); err != nil { - log.Errorf("Failed to send SSH preamble: %s", err) - return - } - } - websocket.DefaultStreamHandler(wsConn, remoteConn, requestHeaders) - } - } - errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler) - }() - forceSetFlag(c, "url", "http://"+listener.Addr().String()) - } - transportLogger, err := createLogger(c, true, false) if err != nil { return errors.Wrap(err, "error setting up transport logger") @@ -493,6 +371,8 @@ func StartServer( return err } + tunnelConfig.IngressRules.StartOrigins(&wg, log, shutdownC, errC) + reconnectCh := make(chan origin.ReconnectSignal, 1) if c.IsSet("stdin-control") { log.Info("Enabling control through stdin") @@ -514,7 +394,8 @@ func StartServer( version, hostname, metricsListener.Addr().String(), - tunnelConfig.OriginUrl, + // TODO (TUN-3461): Update UI to show multiple origin URLs + tunnelConfig.IngressRules.CatchAll().Service.Address(), tunnelConfig.HAConnections, ) logLevels, err := logger.ParseLevelString(c.String("loglevel")) @@ -559,11 +440,6 @@ func SetFlagsFromConfigFile(c *cli.Context) error { return nil } -// isProxyDestinationConfigured returns true if there is a static host set or if bastion mode is set. -func isProxyDestinationConfigured(staticHost string, c *cli.Context) bool { - return staticHost != "" || c.IsSet(bastionFlag) -} - func waitToShutdown(wg *sync.WaitGroup, errC chan error, shutdownC, graceShutdownC chan struct{}, @@ -910,67 +786,67 @@ func configureProxyFlags(shouldHide bool) []cli.Flag { Hidden: shouldHide, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: socks5Flag, + Name: ingress.Socks5Flag, Usage: "specify if this tunnel is running as a SOCK5 Server", EnvVars: []string{"TUNNEL_SOCKS"}, Value: false, Hidden: shouldHide, }), altsrc.NewDurationFlag(&cli.DurationFlag{ - Name: "proxy-connect-timeout", + Name: ingress.ProxyConnectTimeoutFlag, Usage: "HTTP proxy timeout for establishing a new connection", Value: time.Second * 30, Hidden: shouldHide, }), altsrc.NewDurationFlag(&cli.DurationFlag{ - Name: "proxy-tls-timeout", + Name: ingress.ProxyTLSTimeoutFlag, Usage: "HTTP proxy timeout for completing a TLS handshake", Value: time.Second * 10, Hidden: shouldHide, }), altsrc.NewDurationFlag(&cli.DurationFlag{ - Name: "proxy-tcp-keepalive", + Name: ingress.ProxyTCPKeepAlive, Usage: "HTTP proxy TCP keepalive duration", Value: time.Second * 30, Hidden: shouldHide, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: "proxy-no-happy-eyeballs", + Name: ingress.ProxyNoHappyEyeballsFlag, Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback", Hidden: shouldHide, }), altsrc.NewIntFlag(&cli.IntFlag{ - Name: "proxy-keepalive-connections", + Name: ingress.ProxyKeepAliveConnectionsFlag, Usage: "HTTP proxy maximum keepalive connection pool size", Value: 100, Hidden: shouldHide, }), altsrc.NewDurationFlag(&cli.DurationFlag{ - Name: "proxy-keepalive-timeout", + Name: ingress.ProxyKeepAliveTimeoutFlag, Usage: "HTTP proxy timeout for closing an idle connection", Value: time.Second * 90, Hidden: shouldHide, }), altsrc.NewDurationFlag(&cli.DurationFlag{ Name: "proxy-connection-timeout", - Usage: "HTTP proxy timeout for closing an idle connection", + Usage: "DEPRECATED. No longer has any effect.", Value: time.Second * 90, Hidden: shouldHide, }), altsrc.NewDurationFlag(&cli.DurationFlag{ Name: "proxy-expect-continue-timeout", - Usage: "HTTP proxy timeout for closing an idle connection", + Usage: "DEPRECATED. No longer has any effect.", Value: time.Second * 90, Hidden: shouldHide, }), altsrc.NewStringFlag(&cli.StringFlag{ - Name: "http-host-header", + Name: ingress.HTTPHostHeaderFlag, Usage: "Sets the HTTP Host header for the local webserver.", EnvVars: []string{"TUNNEL_HTTP_HOST_HEADER"}, Hidden: shouldHide, }), altsrc.NewStringFlag(&cli.StringFlag{ - Name: "origin-server-name", + Name: ingress.OriginServerNameFlag, Usage: "Hostname on the origin server certificate.", EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"}, Hidden: shouldHide, @@ -988,13 +864,13 @@ func configureProxyFlags(shouldHide bool) []cli.Flag { Hidden: shouldHide, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: "no-tls-verify", + Name: ingress.NoTLSVerifyFlag, Usage: "Disables TLS verification of the certificate presented by your origin. Will allow any certificate from the origin to be accepted. Note: The connection from your machine to Cloudflare's Edge is still encrypted.", EnvVars: []string{"NO_TLS_VERIFY"}, Hidden: shouldHide, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: "no-chunked-encoding", + Name: ingress.NoChunkedEncodingFlag, Usage: "Disables chunked transfer encoding; useful if you are running a WSGI server.", EnvVars: []string{"TUNNEL_NO_CHUNKED_ENCODING"}, Hidden: shouldHide, @@ -1067,28 +943,28 @@ func sshFlags(shouldHide bool) []cli.Flag { Hidden: true, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: sshServerFlag, + Name: ingress.SSHServerFlag, Value: false, Usage: "Run an SSH Server", EnvVars: []string{"TUNNEL_SSH_SERVER"}, Hidden: true, // TODO: remove when feature is complete }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: bastionFlag, + Name: config.BastionFlag, Value: false, Usage: "Runs as jump host", EnvVars: []string{"TUNNEL_BASTION"}, Hidden: shouldHide, }), altsrc.NewStringFlag(&cli.StringFlag{ - Name: "proxy-address", + Name: ingress.ProxyAddressFlag, Usage: "Listen address for the proxy.", Value: "127.0.0.1", EnvVars: []string{"TUNNEL_PROXY_ADDRESS"}, Hidden: shouldHide, }), altsrc.NewIntFlag(&cli.IntFlag{ - Name: "proxy-port", + Name: ingress.ProxyPortFlag, Usage: "Listen port for the proxy.", Value: 0, EnvVars: []string{"TUNNEL_PROXY_PORT"}, diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index e5a55e4b..7fb6c24f 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -1,16 +1,11 @@ package tunnel import ( - "context" - "crypto/tls" "fmt" "io/ioutil" - "net" - "net/http" "os" "path/filepath" "strings" - "time" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" @@ -193,31 +188,7 @@ func prepareTunnelConfig( } } - originCertPool, err := tlsconfig.LoadOriginCA(c, logger) - if err != nil { - logger.Errorf("Error loading cert pool: %s", err) - return nil, errors.Wrap(err, "Error loading cert pool") - } - tunnelMetrics := origin.NewTunnelMetrics() - httpTransport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - MaxIdleConns: c.Int("proxy-keepalive-connections"), - MaxIdleConnsPerHost: c.Int("proxy-keepalive-connections"), - IdleConnTimeout: c.Duration("proxy-keepalive-timeout"), - TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"), - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")}, - } - - dialer := &net.Dialer{ - Timeout: c.Duration("proxy-connect-timeout"), - KeepAlive: c.Duration("proxy-tcp-keepalive"), - } - if c.Bool("proxy-no-happy-eyeballs") { - dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs" - } - dialContext := dialer.DialContext var ingressRules ingress.Ingress if namedTunnel != nil { @@ -231,7 +202,7 @@ func prepareTunnelConfig( Version: version, Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch), } - ingressRules, err = ingress.ParseIngress(config.GetConfiguration()) + ingressRules, err = ingress.ParseIngress(config.GetConfiguration(), logger) if err != nil && err != ingress.ErrNoIngressRules { return nil, err } @@ -240,53 +211,11 @@ func prepareTunnelConfig( } } - var originURL string + // Convert single-origin configuration into multi-origin configuration. if ingressRules.IsEmpty() { - originURL, err = config.ValidateUrl(c, compatibilityMode) + ingressRules, err = ingress.NewSingleOrigin(c, compatibilityMode, logger) if err != nil { - logger.Errorf("Error validating origin URL: %s", err) - return nil, errors.Wrap(err, "Error validating origin URL") - } - } - - if c.IsSet("unix-socket") { - unixSocket, err := config.ValidateUnixSocket(c) - if err != nil { - logger.Errorf("Error validating --unix-socket: %s", err) - return nil, errors.Wrap(err, "Error validating --unix-socket") - } - - logger.Infof("Proxying tunnel requests to unix:%s", unixSocket) - httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - // if --unix-socket specified, enforce network type "unix" - return dialContext(ctx, "unix", unixSocket) - } - } else { - logger.Infof("Proxying tunnel requests to %s", originURL) - httpTransport.DialContext = dialContext - } - - if !c.IsSet("hello-world") && c.IsSet("origin-server-name") { - httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name") - } - // If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client. - if !c.IsSet(bastionFlag) { - - // List all origin URLs that require validation - var originURLs []string - if ingressRules.IsEmpty() { - originURLs = append(originURLs, originURL) - } else { - for _, rule := range ingressRules.Rules { - originURLs = append(originURLs, rule.Service.String()) - } - } - - // Validate each origin URL - for _, u := range originURLs { - if err = validation.ValidateHTTPService(u, hostname, httpTransport); err != nil { - logger.Errorf("unable to connect to the origin: %s", err) - } + return nil, err } } @@ -298,15 +227,12 @@ func prepareTunnelConfig( return &origin.TunnelConfig{ BuildInfo: buildInfo, ClientID: clientID, - ClientTlsConfig: httpTransport.TLSClientConfig, CompressionQuality: c.Uint64("compression-quality"), EdgeAddrs: c.StringSlice("edge"), GracePeriod: c.Duration("grace-period"), HAConnections: c.Int("ha-connections"), - HTTPTransport: httpTransport, HeartbeatInterval: c.Duration("heartbeat-interval"), Hostname: hostname, - HTTPHostHeader: c.String("http-host-header"), IncidentLookup: origin.NewIncidentLookup(), IsAutoupdated: c.Bool("is-autoupdated"), IsFreeTunnel: isFreeTunnel, @@ -316,9 +242,7 @@ func prepareTunnelConfig( MaxHeartbeats: c.Uint64("heartbeat-count"), Metrics: tunnelMetrics, MetricsUpdateFreq: c.Duration("metrics-update-freq"), - NoChunkedEncoding: c.Bool("no-chunked-encoding"), OriginCert: originCert, - OriginUrl: originURL, ReportedVersion: version, Retries: c.Uint("retries"), RunFromTerminal: isRunningFromTerminal(), diff --git a/cmd/cloudflared/tunnel/ingress_subcommands.go b/cmd/cloudflared/tunnel/ingress_subcommands.go index 20b97a29..ec03c2d0 100644 --- a/cmd/cloudflared/tunnel/ingress_subcommands.go +++ b/cmd/cloudflared/tunnel/ingress_subcommands.go @@ -71,7 +71,7 @@ func buildTestURLCommand() *cli.Command { func validateIngressCommand(c *cli.Context) error { conf := config.GetConfiguration() fmt.Println("Validating rules from", conf.Source()) - if _, err := ingress.ParseIngress(conf); err != nil { + if _, err := ingress.ParseIngressDryRun(conf); err != nil { return errors.Wrap(err, "Validation failed") } if c.IsSet("url") { @@ -98,12 +98,12 @@ func testURLCommand(c *cli.Context) error { conf := config.GetConfiguration() fmt.Println("Using rules from", conf.Source()) - ing, err := ingress.ParseIngress(conf) + ing, err := ingress.ParseIngressDryRun(conf) if err != nil { return errors.Wrap(err, "Validation failed") } - i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path) + _, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path) fmt.Printf("Matched rule #%d\n", i+1) fmt.Println(ing.Rules[i].MultiLineString()) return nil diff --git a/ingress/ingress.go b/ingress/ingress.go index 30e8b079..a1c16ec4 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -1,14 +1,24 @@ package ingress import ( + "context" + "crypto/tls" "fmt" + "net" + "net/http" "net/url" "regexp" "strings" + "sync" + "time" "github.com/pkg/errors" + "github.com/urfave/cli/v2" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/tlsconfig" + "github.com/cloudflare/cloudflared/validation" ) var ( @@ -18,54 +28,93 @@ var ( ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules") ) -// Each rule route traffic from a hostname/path on the public -// internet to the service running on the given URL. -type Rule struct { - // Requests for this hostname will be proxied to this rule's service. - Hostname string +// Finalize the rules by adding missing struct fields and validating each origin. +func (ing *Ingress) setHTTPTransport(logger logger.Service) error { + for ruleNumber, rule := range ing.Rules { + cfg := rule.Config + originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil) + if err != nil { + return errors.Wrap(err, "Error loading cert pool") + } - // Path is an optional regex that can specify path-driven ingress rules. - Path *regexp.Regexp + httpTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + MaxIdleConns: cfg.KeepAliveConnections, + MaxIdleConnsPerHost: cfg.KeepAliveConnections, + IdleConnTimeout: cfg.KeepAliveTimeout, + TLSHandshakeTimeout: cfg.TLSTimeout, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify}, + } + if _, isHelloWorld := rule.Service.(*HelloWorld); !isHelloWorld && cfg.OriginServerName != "" { + httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName + } - // A (probably local) address. Requests for a hostname which matches this - // rule's hostname pattern will be proxied to the service running on this - // address. - Service *url.URL -} + dialer := &net.Dialer{ + Timeout: cfg.ConnectTimeout, + KeepAlive: cfg.TCPKeepAlive, + } + if cfg.NoHappyEyeballs { + dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs" + } -func (r Rule) MultiLineString() string { - var out strings.Builder - if r.Hostname != "" { - out.WriteString("\thostname: ") - out.WriteString(r.Hostname) - out.WriteRune('\n') + // DialContext depends on which kind of origin is being used. + dialContext := dialer.DialContext + switch service := rule.Service.(type) { + + // If this origin is a unix socket, enforce network type "unix". + case UnixSocketPath: + httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialContext(ctx, "unix", service.Address()) + } + // Otherwise, use the regular network config. + default: + httpTransport.DialContext = dialContext + } + + ing.Rules[ruleNumber].HTTPTransport = httpTransport + ing.Rules[ruleNumber].ClientTLSConfig = httpTransport.TLSClientConfig } - if r.Path != nil { - out.WriteString("\tpath: ") - out.WriteString(r.Path.String()) - out.WriteRune('\n') - } - out.WriteString("\tservice: ") - out.WriteString(r.Service.String()) - return out.String() -} -func (r *Rule) Matches(hostname, path string) bool { - hostMatch := r.Hostname == "" || r.Hostname == "*" || matchHost(r.Hostname, hostname) - pathMatch := r.Path == nil || r.Path.MatchString(path) - return hostMatch && pathMatch + // Validate each origin + for _, rule := range ing.Rules { + // If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client. + if rule.Config.BastionMode { + continue + } + + // Unix sockets don't have validation + if _, ok := rule.Service.(UnixSocketPath); ok { + continue + } + switch service := rule.Service.(type) { + + case UnixSocketPath: + continue + + case *HelloWorld: + continue + + default: + if err := validation.ValidateHTTPService(service.Address(), rule.Hostname, rule.HTTPTransport); err != nil { + logger.Errorf("unable to connect to the origin: %s", err) + } + } + } + return nil } // FindMatchingRule returns the index of the Ingress Rule which matches the given // hostname and path. This function assumes the last rule matches everything, // which is the case if the rules were instantiated via the ingress#Validate method -func (ing Ingress) FindMatchingRule(hostname, path string) int { +func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) { for i, rule := range ing.Rules { if rule.Matches(hostname, path) { - return i + return &rule, i } } - return len(ing.Rules) - 1 + i := len(ing.Rules) - 1 + return &ing.Rules[i], i } func matchHost(ruleHost, reqHost string) bool { @@ -83,7 +132,56 @@ func matchHost(ruleHost, reqHost string) bool { // Ingress maps eyeball requests to origins. type Ingress struct { - Rules []Rule + Rules []Rule + defaults OriginRequestConfig +} + +// NewSingleOrigin constructs an Ingress set with only one rule, constructed from +// legacy CLI parameters like --url or --no-chunked-encoding. +func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Service) (Ingress, error) { + + service, err := parseSingleOriginService(c, compatibilityMode) + if err != nil { + return Ingress{}, err + } + + // Construct an Ingress with the single rule. + ing := Ingress{ + Rules: []Rule{ + { + Service: service, + }, + }, + defaults: originRequestFromSingeRule(c), + } + err = ing.setHTTPTransport(logger) + return ing, err +} + +// Get a single origin service from the CLI/config. +func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) { + if c.IsSet("hello-world") { + return new(HelloWorld), nil + } + if c.IsSet("url") { + originURLStr, err := config.ValidateUrl(c, compatibilityMode) + if err != nil { + return nil, errors.Wrap(err, "Error validating origin URL") + } + originURL, err := url.Parse(originURLStr) + if err != nil { + return nil, errors.Wrap(err, "couldn't parse origin URL") + } + return &URL{URL: originURL, RootURL: originURL}, nil + } + if c.IsSet("unix-socket") { + unixSocket, err := config.ValidateUnixSocket(c) + if err != nil { + return nil, errors.Wrap(err, "Error validating --unix-socket") + } + return UnixSocketPath(unixSocket), nil + } + return nil, errors.New("You must either set ingress rules in your config file, or use --url or use --unix-socket") } // IsEmpty checks if there are any ingress rules. @@ -91,19 +189,47 @@ func (ing Ingress) IsEmpty() bool { return len(ing.Rules) == 0 } -func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) { +// StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World. +func (ing Ingress) StartOrigins(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error) error { + for _, rule := range ing.Rules { + if err := rule.Service.Start(wg, log, shutdownC, errC, rule.Config); err != nil { + return err + } + } + return nil +} + +// CatchAll returns the catch-all rule (i.e. the last rule) +func (ing Ingress) CatchAll() *Rule { + return &ing.Rules[len(ing.Rules)-1] +} + +func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestConfig) (Ingress, error) { rules := make([]Rule, len(ingress)) for i, r := range ingress { - service, err := url.Parse(r.Service) - if err != nil { - return Ingress{}, err - } - if service.Scheme == "" || service.Hostname() == "" { - return Ingress{}, fmt.Errorf("The service %s must have a scheme and a hostname", r.Service) - } + var service OriginService - if service.Path != "" { - return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path.", r.Service) + if strings.HasPrefix(r.Service, "unix:") { + // No validation necessary for unix socket filepath services + service = UnixSocketPath(strings.TrimPrefix(r.Service, "unix:")) + } else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" { + service = new(HelloWorld) + } else { + // Validate URL services + u, err := url.Parse(r.Service) + if err != nil { + return Ingress{}, err + } + + if u.Scheme == "" || u.Hostname() == "" { + return Ingress{}, fmt.Errorf("The service %s must have a scheme and a hostname", r.Service) + } + + if u.Path != "" { + return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service) + } + serviceURL := URL{URL: u} + service = &serviceURL } // Ensure that there are no wildcards anywhere except the first character @@ -125,6 +251,7 @@ func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) { var pathRegex *regexp.Regexp if r.Path != "" { + var err error pathRegex, err = regexp.Compile(r.Path) if err != nil { return Ingress{}, errors.Wrapf(err, "Rule #%d has an invalid regex", i+1) @@ -135,9 +262,10 @@ func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) { Hostname: r.Hostname, Service: service, Path: pathRegex, + Config: SetConfig(defaults, r.OriginRequest), } } - return Ingress{Rules: rules}, nil + return Ingress{Rules: rules, defaults: defaults}, nil } type errRuleShouldNotBeCatchAll struct { @@ -151,9 +279,20 @@ func (e errRuleShouldNotBeCatchAll) Error() string { "will never be triggered.", e.i+1, e.hostname) } -func ParseIngress(conf *config.Configuration) (Ingress, error) { +// ParseIngress parses, validates and initializes HTTP transports to each origin. +func ParseIngress(conf *config.Configuration, logger logger.Service) (Ingress, error) { + ing, err := ParseIngressDryRun(conf) + if err != nil { + return Ingress{}, err + } + err = ing.setHTTPTransport(logger) + return ing, err +} + +// ParseIngressDryRun parses ingress rules, but does not send HTTP requests to the origins. +func ParseIngressDryRun(conf *config.Configuration) (Ingress, error) { if len(conf.Ingress) == 0 { return Ingress{}, ErrNoIngressRules } - return validate(conf.Ingress) + return validate(conf.Ingress, OriginRequestFromYAML(conf.OriginRequest)) } diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index aa3482a2..5a426bf5 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -2,7 +2,6 @@ package ingress import ( "net/url" - "regexp" "testing" "github.com/stretchr/testify/assert" @@ -12,16 +11,29 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/config" ) +func TestParseUnixSocket(t *testing.T) { + rawYAML := ` +ingress: +- service: unix:/tmp/echo.sock +` + ing, err := ParseIngressDryRun(MustReadIngress(rawYAML)) + require.NoError(t, err) + _, ok := ing.Rules[0].Service.(UnixSocketPath) + require.True(t, ok) +} + func Test_parseIngress(t *testing.T) { localhost8000 := MustParseURL(t, "https://localhost:8000") localhost8001 := MustParseURL(t, "https://localhost:8001") + defaultConfig := SetConfig(OriginRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{}) + require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections) type args struct { rawYAML string } tests := []struct { name string args args - want Ingress + want []Rule wantErr bool }{ { @@ -38,16 +50,18 @@ ingress: - hostname: "*" service: https://localhost:8001 `}, - want: Ingress{Rules: []Rule{ + want: []Rule{ { Hostname: "tunnel1.example.com", - Service: localhost8000, + Service: &URL{URL: localhost8000}, + Config: defaultConfig, }, { Hostname: "*", - Service: localhost8001, + Service: &URL{URL: localhost8001}, + Config: defaultConfig, }, - }}, + }, }, { name: "Extra keys", @@ -57,12 +71,13 @@ ingress: service: https://localhost:8000 extraKey: extraValue `}, - want: Ingress{Rules: []Rule{ + want: []Rule{ { Hostname: "*", - Service: localhost8000, + Service: &URL{URL: localhost8000}, + Config: defaultConfig, }, - }}, + }, }, { name: "Hostname can be omitted", @@ -70,11 +85,12 @@ extraKey: extraValue ingress: - service: https://localhost:8000 `}, - want: Ingress{Rules: []Rule{ + want: []Rule{ { - Service: localhost8000, + Service: &URL{URL: localhost8000}, + Config: defaultConfig, }, - }}, + }, }, { name: "Invalid service", @@ -152,12 +168,12 @@ ingress: } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ParseIngress(MustReadIngress(tt.args.rawYAML)) + got, err := ParseIngressDryRun(MustReadIngress(tt.args.rawYAML)) if (err != nil) != tt.wantErr { - t.Errorf("ParseIngress() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("ParseIngressDryRun() error = %v, wantErr %v", err, tt.wantErr) return } - assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want, got.Rules) }) } } @@ -168,118 +184,6 @@ func MustParseURL(t *testing.T, rawURL string) *url.URL { return u } -func Test_rule_matches(t *testing.T) { - type fields struct { - Hostname string - Path *regexp.Regexp - Service *url.URL - } - type args struct { - requestURL *url.URL - } - tests := []struct { - name string - fields fields - args args - want bool - }{ - { - name: "Just hostname, pass", - fields: fields{ - Hostname: "example.com", - }, - args: args{ - requestURL: MustParseURL(t, "https://example.com"), - }, - want: true, - }, - { - name: "Entire hostname is wildcard, should match everything", - fields: fields{ - Hostname: "*", - }, - args: args{ - requestURL: MustParseURL(t, "https://example.com"), - }, - want: true, - }, - { - name: "Just hostname, fail", - fields: fields{ - Hostname: "example.com", - }, - args: args{ - requestURL: MustParseURL(t, "https://foo.bar"), - }, - want: false, - }, - { - name: "Just wildcard hostname, pass", - fields: fields{ - Hostname: "*.example.com", - }, - args: args{ - requestURL: MustParseURL(t, "https://adam.example.com"), - }, - want: true, - }, - { - name: "Just wildcard hostname, fail", - fields: fields{ - Hostname: "*.example.com", - }, - args: args{ - requestURL: MustParseURL(t, "https://tunnel.com"), - }, - want: false, - }, - { - name: "Just wildcard outside of subdomain in hostname, fail", - fields: fields{ - Hostname: "*example.com", - }, - args: args{ - requestURL: MustParseURL(t, "https://www.example.com"), - }, - want: false, - }, - { - name: "Wildcard over multiple subdomains", - fields: fields{ - Hostname: "*.example.com", - }, - args: args{ - requestURL: MustParseURL(t, "https://adam.chalmers.example.com"), - }, - want: true, - }, - { - name: "Hostname and path", - fields: fields{ - Hostname: "*.example.com", - Path: regexp.MustCompile("/static/.*\\.html"), - }, - args: args{ - requestURL: MustParseURL(t, "https://www.example.com/static/index.html"), - }, - want: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := Rule{ - Hostname: tt.fields.Hostname, - Path: tt.fields.Path, - Service: tt.fields.Service, - } - u := tt.args.requestURL - if got := r.Matches(u.Hostname(), u.Path); got != tt.want { - t.Errorf("rule.matches() = %v, want %v", got, tt.want) - } - }) - } -} - func BenchmarkFindMatch(b *testing.B) { rulesYAML := ` ingress: @@ -291,7 +195,7 @@ ingress: service: https://localhost:8002 ` - ing, err := ParseIngress(MustReadIngress(rulesYAML)) + ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) if err != nil { b.Error(err) } diff --git a/ingress/origin_request_config.go b/ingress/origin_request_config.go new file mode 100644 index 00000000..03e0fcab --- /dev/null +++ b/ingress/origin_request_config.go @@ -0,0 +1,331 @@ +package ingress + +import ( + "time" + + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/tlsconfig" + "github.com/urfave/cli/v2" +) + +const ( + defaultConnectTimeout = 30 * time.Second + defaultTLSTimeout = 10 * time.Second + defaultTCPKeepAlive = 30 * time.Second + defaultKeepAliveConnections = 100 + defaultKeepAliveTimeout = 90 * time.Second + defaultProxyAddress = "127.0.0.1" + + SSHServerFlag = "ssh-server" + Socks5Flag = "socks5" + ProxyConnectTimeoutFlag = "proxy-connect-timeout" + ProxyTLSTimeoutFlag = "proxy-tls-timeout" + ProxyTCPKeepAlive = "proxy-tcp-keepalive" + ProxyNoHappyEyeballsFlag = "proxy-no-happy-eyeballs" + ProxyKeepAliveConnectionsFlag = "proxy-keepalive-connections" + ProxyKeepAliveTimeoutFlag = "proxy-keepalive-timeout" + HTTPHostHeaderFlag = "http-host-header" + OriginServerNameFlag = "origin-server-name" + NoTLSVerifyFlag = "no-tls-verify" + NoChunkedEncodingFlag = "no-chunked-encoding" + ProxyAddressFlag = "proxy-address" + ProxyPortFlag = "proxy-port" +) + +const ( + socksProxy = "socks" +) + +func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig { + var connectTimeout time.Duration = defaultConnectTimeout + var tlsTimeout time.Duration = defaultTLSTimeout + var tcpKeepAlive time.Duration = defaultTCPKeepAlive + var noHappyEyeballs bool + var keepAliveConnections int = defaultKeepAliveConnections + var keepAliveTimeout time.Duration = defaultKeepAliveTimeout + var httpHostHeader string + var originServerName string + var caPool string + var noTLSVerify bool + var disableChunkedEncoding bool + var bastionMode bool + var proxyAddress string + var proxyPort uint + var proxyType string + if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) { + connectTimeout = c.Duration(flag) + } + if flag := ProxyTLSTimeoutFlag; c.IsSet(flag) { + tlsTimeout = c.Duration(flag) + } + if flag := ProxyTCPKeepAlive; c.IsSet(flag) { + tcpKeepAlive = c.Duration(flag) + } + if flag := ProxyNoHappyEyeballsFlag; c.IsSet(flag) { + noHappyEyeballs = c.Bool(flag) + } + if flag := ProxyKeepAliveConnectionsFlag; c.IsSet(flag) { + keepAliveConnections = c.Int(flag) + } + if flag := ProxyKeepAliveTimeoutFlag; c.IsSet(flag) { + keepAliveTimeout = c.Duration(flag) + } + if flag := HTTPHostHeaderFlag; c.IsSet(flag) { + httpHostHeader = c.String(flag) + } + if flag := OriginServerNameFlag; c.IsSet(flag) { + originServerName = c.String(flag) + } + if flag := tlsconfig.OriginCAPoolFlag; c.IsSet(flag) { + caPool = c.String(flag) + } + if flag := NoTLSVerifyFlag; c.IsSet(flag) { + noTLSVerify = c.Bool(flag) + } + if flag := NoChunkedEncodingFlag; c.IsSet(flag) { + disableChunkedEncoding = c.Bool(flag) + } + if flag := config.BastionFlag; c.IsSet(flag) { + bastionMode = c.Bool(flag) + } + if flag := ProxyAddressFlag; c.IsSet(flag) { + proxyAddress = c.String(flag) + } + if flag := ProxyPortFlag; c.IsSet(flag) { + proxyPort = c.Uint(flag) + } + if c.IsSet(Socks5Flag) { + proxyType = socksProxy + } + return OriginRequestConfig{ + ConnectTimeout: connectTimeout, + TLSTimeout: tlsTimeout, + TCPKeepAlive: tcpKeepAlive, + NoHappyEyeballs: noHappyEyeballs, + KeepAliveConnections: keepAliveConnections, + KeepAliveTimeout: keepAliveTimeout, + HTTPHostHeader: httpHostHeader, + OriginServerName: originServerName, + CAPool: caPool, + NoTLSVerify: noTLSVerify, + DisableChunkedEncoding: disableChunkedEncoding, + BastionMode: bastionMode, + ProxyAddress: proxyAddress, + ProxyPort: proxyPort, + ProxyType: proxyType, + } +} + +func OriginRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig { + out := OriginRequestConfig{ + ConnectTimeout: defaultConnectTimeout, + TLSTimeout: defaultTLSTimeout, + TCPKeepAlive: defaultTCPKeepAlive, + KeepAliveConnections: defaultKeepAliveConnections, + KeepAliveTimeout: defaultKeepAliveTimeout, + ProxyAddress: defaultProxyAddress, + } + if y.ConnectTimeout != nil { + out.ConnectTimeout = *y.ConnectTimeout + } + if y.TLSTimeout != nil { + out.TLSTimeout = *y.TLSTimeout + } + if y.TCPKeepAlive != nil { + out.TCPKeepAlive = *y.TCPKeepAlive + } + if y.NoHappyEyeballs != nil { + out.NoHappyEyeballs = *y.NoHappyEyeballs + } + if y.KeepAliveConnections != nil { + out.KeepAliveConnections = *y.KeepAliveConnections + } + if y.KeepAliveTimeout != nil { + out.KeepAliveTimeout = *y.KeepAliveTimeout + } + if y.HTTPHostHeader != nil { + out.HTTPHostHeader = *y.HTTPHostHeader + } + if y.OriginServerName != nil { + out.OriginServerName = *y.OriginServerName + } + if y.CAPool != nil { + out.CAPool = *y.CAPool + } + if y.NoTLSVerify != nil { + out.NoTLSVerify = *y.NoTLSVerify + } + if y.DisableChunkedEncoding != nil { + out.DisableChunkedEncoding = *y.DisableChunkedEncoding + } + if y.BastionMode != nil { + out.BastionMode = *y.BastionMode + } + if y.ProxyAddress != nil { + out.ProxyAddress = *y.ProxyAddress + } + if y.ProxyPort != nil { + out.ProxyPort = *y.ProxyPort + } + if y.ProxyType != nil { + out.ProxyType = *y.ProxyType + } + return out +} + +// OriginRequestConfig configures how Cloudflared sends requests to origin +// services. +// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h". +type OriginRequestConfig struct { + // HTTP proxy timeout for establishing a new connection + ConnectTimeout time.Duration `yaml:"connectTimeout"` + // HTTP proxy timeout for completing a TLS handshake + TLSTimeout time.Duration `yaml:"tlsTimeout"` + // HTTP proxy TCP keepalive duration + TCPKeepAlive time.Duration `yaml:"tcpKeepAlive"` + // HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback + NoHappyEyeballs bool `yaml:"noHappyEyeballs"` + // HTTP proxy maximum keepalive connection pool size + KeepAliveConnections int `yaml:"keepAliveConnections"` + // HTTP proxy timeout for closing an idle connection + KeepAliveTimeout time.Duration `yaml:"keepAliveTimeout"` + // Sets the HTTP Host header for the local webserver. + HTTPHostHeader string `yaml:"httpHostHeader"` + // Hostname on the origin server certificate. + OriginServerName string `yaml:"originServerName"` + // Path to the CA for the certificate of your origin. + // This option should be used only if your certificate is not signed by Cloudflare. + CAPool string `yaml:"caPool"` + // Disables TLS verification of the certificate presented by your origin. + // Will allow any certificate from the origin to be accepted. + // Note: The connection from your machine to Cloudflare's Edge is still encrypted. + NoTLSVerify bool `yaml:"noTLSVerify"` + // Disables chunked transfer encoding. + // Useful if you are running a WSGI server. + DisableChunkedEncoding bool `yaml:"disableChunkedEncoding"` + // Runs as jump host + BastionMode bool `yaml:"bastionMode"` + // Listen address for the proxy. + ProxyAddress string `yaml:"proxyAddress"` + // Listen port for the proxy. + ProxyPort uint `yaml:"proxyPort"` + // What sort of proxy should be started + ProxyType string `yaml:"proxyType"` +} + +func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) { + if val := overrides.ConnectTimeout; val != nil { + defaults.ConnectTimeout = *val + } +} + +func (defaults *OriginRequestConfig) setTLSTimeout(overrides config.OriginRequestConfig) { + if val := overrides.TLSTimeout; val != nil { + defaults.TLSTimeout = *val + } +} + +func (defaults *OriginRequestConfig) setNoHappyEyeballs(overrides config.OriginRequestConfig) { + if val := overrides.NoHappyEyeballs; val != nil { + defaults.NoHappyEyeballs = *val + } +} + +func (defaults *OriginRequestConfig) setKeepAliveConnections(overrides config.OriginRequestConfig) { + if val := overrides.KeepAliveConnections; val != nil { + defaults.KeepAliveConnections = *val + } +} + +func (defaults *OriginRequestConfig) setKeepAliveTimeout(overrides config.OriginRequestConfig) { + if val := overrides.KeepAliveTimeout; val != nil { + defaults.KeepAliveTimeout = *val + } +} + +func (defaults *OriginRequestConfig) setTCPKeepAlive(overrides config.OriginRequestConfig) { + if val := overrides.TCPKeepAlive; val != nil { + defaults.TCPKeepAlive = *val + } +} + +func (defaults *OriginRequestConfig) setHTTPHostHeader(overrides config.OriginRequestConfig) { + if val := overrides.HTTPHostHeader; val != nil { + defaults.HTTPHostHeader = *val + } +} + +func (defaults *OriginRequestConfig) setOriginServerName(overrides config.OriginRequestConfig) { + if val := overrides.OriginServerName; val != nil { + defaults.OriginServerName = *val + } +} + +func (defaults *OriginRequestConfig) setCAPool(overrides config.OriginRequestConfig) { + if val := overrides.CAPool; val != nil { + defaults.CAPool = *val + } +} + +func (defaults *OriginRequestConfig) setNoTLSVerify(overrides config.OriginRequestConfig) { + if val := overrides.NoTLSVerify; val != nil { + defaults.NoTLSVerify = *val + } +} + +func (defaults *OriginRequestConfig) setDisableChunkedEncoding(overrides config.OriginRequestConfig) { + if val := overrides.DisableChunkedEncoding; val != nil { + defaults.DisableChunkedEncoding = *val + } +} + +func (defaults *OriginRequestConfig) setBastionMode(overrides config.OriginRequestConfig) { + if val := overrides.BastionMode; val != nil { + defaults.BastionMode = *val + } +} + +func (defaults *OriginRequestConfig) setProxyPort(overrides config.OriginRequestConfig) { + if val := overrides.ProxyPort; val != nil { + defaults.ProxyPort = *val + } +} + +func (defaults *OriginRequestConfig) setProxyAddress(overrides config.OriginRequestConfig) { + if val := overrides.ProxyAddress; val != nil { + defaults.ProxyAddress = *val + } +} + +func (defaults *OriginRequestConfig) setProxyType(overrides config.OriginRequestConfig) { + if val := overrides.ProxyType; val != nil { + defaults.ProxyType = *val + } +} + +// SetConfig gets config for the requests that cloudflared sends to origins. +// Each field has a setter method which sets a value for the field by trying to find: +// 1. The user config for this rule +// 2. The user config for the overall ingress config +// 3. Defaults chosen by the cloudflared team +// 4. Golang zero values for that type +// If an earlier option isn't set, it will try the next option down. +func SetConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfig) OriginRequestConfig { + cfg := defaults + cfg.setConnectTimeout(overrides) + cfg.setTLSTimeout(overrides) + cfg.setNoHappyEyeballs(overrides) + cfg.setKeepAliveConnections(overrides) + cfg.setKeepAliveTimeout(overrides) + cfg.setTCPKeepAlive(overrides) + cfg.setHTTPHostHeader(overrides) + cfg.setOriginServerName(overrides) + cfg.setCAPool(overrides) + cfg.setNoTLSVerify(overrides) + cfg.setDisableChunkedEncoding(overrides) + cfg.setBastionMode(overrides) + cfg.setProxyPort(overrides) + cfg.setProxyAddress(overrides) + cfg.setProxyType(overrides) + return cfg +} diff --git a/ingress/origin_request_config_test.go b/ingress/origin_request_config_test.go new file mode 100644 index 00000000..4b874bff --- /dev/null +++ b/ingress/origin_request_config_test.go @@ -0,0 +1,184 @@ +package ingress + +import ( + "testing" + "time" + + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +// Ensure that the nullable config from `config` package and the +// non-nullable config from `ingress` package have the same number of +// fields. +// This test ensures that programmers didn't add a new field to +// one struct and forget to add it to the other ;) +func TestCorrespondingFields(t *testing.T) { + require.Equal( + t, + CountFields(t, config.OriginRequestConfig{}), + CountFields(t, OriginRequestConfig{}), + ) +} + +func CountFields(t *testing.T, val interface{}) int { + b, err := yaml.Marshal(val) + require.NoError(t, err) + m := make(map[string]interface{}, 0) + err = yaml.Unmarshal(b, &m) + require.NoError(t, err) + return len(m) +} + +func TestOriginRequestConfigOverrides(t *testing.T) { + rulesYAML := ` +originRequest: + connectTimeout: 1m + tlsTimeout: 1s + noHappyEyeballs: true + tcpKeepAlive: 1s + keepAliveConnections: 1 + keepAliveTimeout: 1s + httpHostHeader: abc + originServerName: a1 + caPool: /tmp/path0 + noTLSVerify: true + disableChunkedEncoding: true + bastionMode: True + proxyAddress: 127.1.2.3 + proxyPort: 100 + proxyType: socks5 +ingress: +- hostname: tun.example.com + service: https://localhost:8000 +- hostname: "*" + service: https://localhost:8001 + originRequest: + connectTimeout: 2m + tlsTimeout: 2s + noHappyEyeballs: false + tcpKeepAlive: 2s + keepAliveConnections: 2 + keepAliveTimeout: 2s + httpHostHeader: def + originServerName: b2 + caPool: /tmp/path1 + noTLSVerify: false + disableChunkedEncoding: false + bastionMode: false + proxyAddress: interface + proxyPort: 200 + proxyType: "" +` + ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) + if err != nil { + t.Error(err) + } + + // Rule 0 didn't override anything, so it inherits the user-specified + // root-level configuration. + actual0 := ing.Rules[0].Config + expected0 := OriginRequestConfig{ + ConnectTimeout: 1 * time.Minute, + TLSTimeout: 1 * time.Second, + NoHappyEyeballs: true, + TCPKeepAlive: 1 * time.Second, + KeepAliveConnections: 1, + KeepAliveTimeout: 1 * time.Second, + HTTPHostHeader: "abc", + OriginServerName: "a1", + CAPool: "/tmp/path0", + NoTLSVerify: true, + DisableChunkedEncoding: true, + BastionMode: true, + ProxyAddress: "127.1.2.3", + ProxyPort: uint(100), + ProxyType: "socks5", + } + require.Equal(t, expected0, actual0) + + // Rule 1 overrode all the root-level config. + actual1 := ing.Rules[1].Config + expected1 := OriginRequestConfig{ + ConnectTimeout: 2 * time.Minute, + TLSTimeout: 2 * time.Second, + NoHappyEyeballs: false, + TCPKeepAlive: 2 * time.Second, + KeepAliveConnections: 2, + KeepAliveTimeout: 2 * time.Second, + HTTPHostHeader: "def", + OriginServerName: "b2", + CAPool: "/tmp/path1", + NoTLSVerify: false, + DisableChunkedEncoding: false, + BastionMode: false, + ProxyAddress: "interface", + ProxyPort: uint(200), + ProxyType: "", + } + require.Equal(t, expected1, actual1) +} + +func TestOriginRequestConfigDefaults(t *testing.T) { + rulesYAML := ` +ingress: +- hostname: tun.example.com + service: https://localhost:8000 +- hostname: "*" + service: https://localhost:8001 + originRequest: + connectTimeout: 2m + tlsTimeout: 2s + noHappyEyeballs: false + tcpKeepAlive: 2s + keepAliveConnections: 2 + keepAliveTimeout: 2s + httpHostHeader: def + originServerName: b2 + caPool: /tmp/path1 + noTLSVerify: false + disableChunkedEncoding: false + bastionMode: false + proxyAddress: interface + proxyPort: 200 + proxyType: "" +` + ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) + if err != nil { + t.Error(err) + } + + // Rule 0 didn't override anything, so it inherits the cloudflared defaults + actual0 := ing.Rules[0].Config + expected0 := OriginRequestConfig{ + ConnectTimeout: defaultConnectTimeout, + TLSTimeout: defaultTLSTimeout, + TCPKeepAlive: defaultTCPKeepAlive, + KeepAliveConnections: defaultKeepAliveConnections, + KeepAliveTimeout: defaultKeepAliveTimeout, + ProxyAddress: defaultProxyAddress, + } + require.Equal(t, expected0, actual0) + + // Rule 1 overrode all defaults. + actual1 := ing.Rules[1].Config + expected1 := OriginRequestConfig{ + ConnectTimeout: 2 * time.Minute, + TLSTimeout: 2 * time.Second, + NoHappyEyeballs: false, + TCPKeepAlive: 2 * time.Second, + KeepAliveConnections: 2, + KeepAliveTimeout: 2 * time.Second, + HTTPHostHeader: "def", + OriginServerName: "b2", + CAPool: "/tmp/path1", + NoTLSVerify: false, + DisableChunkedEncoding: false, + BastionMode: false, + ProxyAddress: "interface", + ProxyPort: uint(200), + ProxyType: "", + } + require.Equal(t, expected1, actual1) +} diff --git a/ingress/origin_service.go b/ingress/origin_service.go new file mode 100644 index 00000000..89b4a57e --- /dev/null +++ b/ingress/origin_service.go @@ -0,0 +1,181 @@ +package ingress + +import ( + "fmt" + "net" + "net/http" + "net/url" + "strconv" + "sync" + + "github.com/cloudflare/cloudflared/hello" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/socks" + "github.com/cloudflare/cloudflared/websocket" + "github.com/pkg/errors" +) + +// OriginService is something a tunnel can proxy traffic to. +type OriginService interface { + Address() string + // Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World. + // If it's not managed by cloudflared, this is a no-op because the user is responsible for + // starting the origin service. + Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error + String() string + // RewriteOriginURL modifies the HTTP request from cloudflared to the origin, so that it apply + // this particular type of origin service's specific routing logic. + RewriteOriginURL(*url.URL) +} + +// UnixSocketPath is an OriginService representing a unix socket (which accepts HTTP) +type UnixSocketPath string + +func (o UnixSocketPath) Address() string { + return string(o) +} + +func (o UnixSocketPath) String() string { + return "unix socket: " + string(o) +} + +func (o UnixSocketPath) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + return nil +} + +func (o UnixSocketPath) RewriteOriginURL(u *url.URL) { + // No changes necessary because the origin request URL isn't used. + // Instead, HTTPTransport's dial is already configured to address the unix socket. +} + +// URL is an OriginService listening on a TCP address +type URL struct { + // The URL for the user's origin service + RootURL *url.URL + // The URL that cloudflared should send requests to. + // If this origin requires starting a proxy, this is the proxy's address, + // and that proxy points to RootURL. Otherwise, this is equal to RootURL. + URL *url.URL +} + +func (o *URL) Address() string { + return o.URL.String() +} + +func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + staticHost := o.staticHost() + if !originRequiresProxy(staticHost, cfg) { + return nil + } + + // Start a listener for the proxy + proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort))) + listener, err := net.Listen("tcp", proxyAddress) + if err != nil { + log.Errorf("Cannot start Websocket Proxy Server: %s", err) + return errors.Wrap(err, "Cannot start Websocket Proxy Server") + } + + // Start the proxy itself + wg.Add(1) + go func() { + defer wg.Done() + streamHandler := websocket.DefaultStreamHandler + // This origin's config specifies what type of proxy to start. + switch cfg.ProxyType { + case socksProxy: + log.Info("SOCKS5 server started") + streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) { + dialer := socks.NewConnDialer(remoteConn) + requestHandler := socks.NewRequestHandler(dialer) + socksServer := socks.NewConnectionHandler(requestHandler) + + socksServer.Serve(wsConn) + } + case "": + log.Debug("Not starting any websocket proxy") + default: + log.Errorf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy) + } + + errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler) + }() + + // Modify this origin, so that it no longer points at the origin service directly. + // Instead, it points at the proxy to the origin service. + newURL, err := url.Parse("http://" + listener.Addr().String()) + if err != nil { + return err + } + o.URL = newURL + return nil +} + +func (o *URL) String() string { + return o.Address() +} + +func (o *URL) RewriteOriginURL(u *url.URL) { + u.Host = o.URL.Host + u.Scheme = o.URL.Scheme +} + +func (o *URL) staticHost() string { + + addPortIfMissing := func(uri *url.URL, port int) string { + if uri.Port() != "" { + return uri.Host + } + return fmt.Sprintf("%s:%d", uri.Hostname(), port) + } + + switch o.URL.Scheme { + case "ssh": + return addPortIfMissing(o.URL, 22) + case "rdp": + return addPortIfMissing(o.URL, 3389) + case "smb": + return addPortIfMissing(o.URL, 445) + case "tcp": + return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case + } + return "" + +} + +// HelloWorld is the built-in Hello World service. Used for testing and experimenting with cloudflared. +type HelloWorld struct { + server net.Listener +} + +func (o *HelloWorld) Address() string { + return o.server.Addr().String() +} + +func (o *HelloWorld) String() string { + return "Hello World static HTML service" +} + +// Start starts a HelloWorld server and stores its address in the Service receiver. +func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + helloListener, err := hello.CreateTLSListener("127.0.0.1:") + if err != nil { + return errors.Wrap(err, "Cannot start Hello World Server") + } + wg.Add(1) + go func() { + defer wg.Done() + _ = hello.StartHelloWorldServer(log, helloListener, shutdownC) + }() + o.server = helloListener + return nil +} + +func (o *HelloWorld) RewriteOriginURL(u *url.URL) { + u.Host = o.Address() + u.Scheme = "https" +} + +func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool { + return staticHost != "" || cfg.BastionMode +} diff --git a/ingress/rule.go b/ingress/rule.go new file mode 100644 index 00000000..c47b8bb9 --- /dev/null +++ b/ingress/rule.go @@ -0,0 +1,57 @@ +package ingress + +import ( + "crypto/tls" + "net/http" + "regexp" + "strings" +) + +// Rule routes traffic from a hostname/path on the public internet to the +// service running on the given URL. +type Rule struct { + // Requests for this hostname will be proxied to this rule's service. + Hostname string + + // Path is an optional regex that can specify path-driven ingress rules. + Path *regexp.Regexp + + // A (probably local) address. Requests for a hostname which matches this + // rule's hostname pattern will be proxied to the service running on this + // address. + Service OriginService + + // Configure the request cloudflared sends to this specific origin. + Config OriginRequestConfig + + // Configures TLS for the cloudflared -> origin request + ClientTLSConfig *tls.Config + // Configures HTTP for the cloudflared -> origin request + HTTPTransport http.RoundTripper +} + +// MultiLineString is for outputting rules in a human-friendly way when Cloudflared +// is used as a CLI tool (not as a daemon). +func (r Rule) MultiLineString() string { + var out strings.Builder + if r.Hostname != "" { + out.WriteString("\thostname: ") + out.WriteString(r.Hostname) + out.WriteRune('\n') + } + if r.Path != nil { + out.WriteString("\tpath: ") + out.WriteString(r.Path.String()) + out.WriteRune('\n') + } + out.WriteString("\tservice: ") + out.WriteString(r.Service.String()) + return out.String() +} + +// Matches checks if the rule matches a given hostname/path combination. +func (r *Rule) Matches(hostname, path string) bool { + hostMatch := r.Hostname == "" || r.Hostname == "*" || matchHost(r.Hostname, hostname) + pathMatch := r.Path == nil || r.Path.MatchString(path) + return hostMatch && pathMatch +} diff --git a/ingress/rule_test.go b/ingress/rule_test.go new file mode 100644 index 00000000..ef908090 --- /dev/null +++ b/ingress/rule_test.go @@ -0,0 +1,119 @@ +package ingress + +import ( + "net/url" + "regexp" + "testing" +) + +func Test_rule_matches(t *testing.T) { + type fields struct { + Hostname string + Path *regexp.Regexp + Service OriginService + } + type args struct { + requestURL *url.URL + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + { + name: "Just hostname, pass", + fields: fields{ + Hostname: "example.com", + }, + args: args{ + requestURL: MustParseURL(t, "https://example.com"), + }, + want: true, + }, + { + name: "Entire hostname is wildcard, should match everything", + fields: fields{ + Hostname: "*", + }, + args: args{ + requestURL: MustParseURL(t, "https://example.com"), + }, + want: true, + }, + { + name: "Just hostname, fail", + fields: fields{ + Hostname: "example.com", + }, + args: args{ + requestURL: MustParseURL(t, "https://foo.bar"), + }, + want: false, + }, + { + name: "Just wildcard hostname, pass", + fields: fields{ + Hostname: "*.example.com", + }, + args: args{ + requestURL: MustParseURL(t, "https://adam.example.com"), + }, + want: true, + }, + { + name: "Just wildcard hostname, fail", + fields: fields{ + Hostname: "*.example.com", + }, + args: args{ + requestURL: MustParseURL(t, "https://tunnel.com"), + }, + want: false, + }, + { + name: "Just wildcard outside of subdomain in hostname, fail", + fields: fields{ + Hostname: "*example.com", + }, + args: args{ + requestURL: MustParseURL(t, "https://www.example.com"), + }, + want: false, + }, + { + name: "Wildcard over multiple subdomains", + fields: fields{ + Hostname: "*.example.com", + }, + args: args{ + requestURL: MustParseURL(t, "https://adam.chalmers.example.com"), + }, + want: true, + }, + { + name: "Hostname and path", + fields: fields{ + Hostname: "*.example.com", + Path: regexp.MustCompile("/static/.*\\.html"), + }, + args: args{ + requestURL: MustParseURL(t, "https://www.example.com/static/index.html"), + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := Rule{ + Hostname: tt.fields.Hostname, + Path: tt.fields.Path, + Service: tt.fields.Service, + } + u := tt.args.requestURL + if got := r.Matches(u.Hostname(), u.Path); got != tt.want { + t.Errorf("rule.matches() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/origin/tunnel.go b/origin/tunnel.go index f300bba6..7655ff10 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -30,7 +30,6 @@ import ( "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/websocket" ) @@ -57,16 +56,13 @@ const ( type TunnelConfig struct { BuildInfo *buildinfo.BuildInfo ClientID string - ClientTlsConfig *tls.Config CloseConnOnce *sync.Once // Used to close connectedSignal no more than once CompressionQuality uint64 EdgeAddrs []string GracePeriod time.Duration HAConnections int - HTTPTransport http.RoundTripper HeartbeatInterval time.Duration Hostname string - HTTPHostHeader string IncidentLookup IncidentLookup IsAutoupdated bool IsFreeTunnel bool @@ -76,7 +72,6 @@ type TunnelConfig struct { MaxHeartbeats uint64 Metrics *TunnelMetrics MetricsUpdateFreq time.Duration - NoChunkedEncoding bool OriginCert []byte ReportedVersion string Retries uint @@ -84,8 +79,6 @@ type TunnelConfig struct { Tags []tunnelpogs.Tag TlsConfig *tls.Config WSGI bool - // OriginUrl may not be used if a user specifies a unix socket. - OriginUrl string // feature-flag to use new edge reconnect tokens UseReconnectToken bool @@ -618,18 +611,13 @@ func LogServerInfo( } type TunnelHandler struct { - originUrl string - ingressRules ingress.Ingress - httpHostHeader string - muxer *h2mux.Muxer - httpClient http.RoundTripper - tlsConfig *tls.Config - tags []tunnelpogs.Tag - metrics *TunnelMetrics + ingressRules ingress.Ingress + muxer *h2mux.Muxer + tags []tunnelpogs.Tag + metrics *TunnelMetrics // connectionID is only used by metrics, and prometheus requires labels to be string - connectionID string - logger logger.Service - noChunkedEncoding bool + connectionID string + logger logger.Service bufferPool *buffer.Pool } @@ -642,31 +630,13 @@ func NewTunnelHandler(ctx context.Context, bufferPool *buffer.Pool, ) (*TunnelHandler, string, error) { - // Check single-origin config - var originURL string - var err error - if config.IngressRules.IsEmpty() { - originURL, err = validation.ValidateUrl(config.OriginUrl) - if err != nil { - return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL) - } - } - h := &TunnelHandler{ - originUrl: originURL, - ingressRules: config.IngressRules, - httpHostHeader: config.HTTPHostHeader, - httpClient: config.HTTPTransport, - tlsConfig: config.ClientTlsConfig, - tags: config.Tags, - metrics: config.Metrics, - connectionID: uint8ToString(connectionID), - logger: config.Logger, - noChunkedEncoding: config.NoChunkedEncoding, - bufferPool: bufferPool, - } - if h.httpClient == nil { - h.httpClient = http.DefaultTransport + ingressRules: config.IngressRules, + tags: config.Tags, + metrics: config.Metrics, + connectionID: uint8ToString(connectionID), + logger: config.Logger, + bufferPool: bufferPool, } edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) @@ -692,7 +662,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { h.metrics.incrementRequests(h.connectionID) defer h.metrics.decrementConcurrentRequests(h.connectionID) - req, reqErr := h.createRequest(stream) + req, rule, reqErr := h.createRequest(stream) if reqErr != nil { h.writeErrorResponse(stream, reqErr) return reqErr @@ -705,9 +675,9 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { var resp *http.Response var respErr error if websocket.IsWebSocketUpgrade(req) { - resp, respErr = h.serveWebsocket(stream, req) + resp, respErr = h.serveWebsocket(stream, req, rule) } else { - resp, respErr = h.serveHTTP(stream, req) + resp, respErr = h.serveHTTP(stream, req, rule) } if respErr != nil { h.writeErrorResponse(stream, respErr) @@ -717,32 +687,28 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { return nil } -func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) { - req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) +func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) { + req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream}) if err != nil { - return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") + return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest") } err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { - return nil, errors.Wrap(err, "invalid request received") + return nil, nil, errors.Wrap(err, "invalid request received") } h.AppendTagHeaders(req) - if !h.ingressRules.IsEmpty() { - ruleNumber := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) - destination := h.ingressRules.Rules[ruleNumber].Service - req.URL.Host = destination.Host - req.URL.Scheme = destination.Scheme - } - return req, nil + rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) + rule.Service.RewriteOriginURL(req.URL) + return req, rule, nil } -func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { - if h.httpHostHeader != "" { - req.Header.Set("Host", h.httpHostHeader) - req.Host = h.httpHostHeader +func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { + if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { + req.Header.Set("Host", hostHeader) + req.Host = hostHeader } - conn, response, err := websocket.ClientConnect(req, h.tlsConfig) + conn, response, err := websocket.ClientConnect(req, rule.ClientTLSConfig) if err != nil { return nil, err } @@ -758,9 +724,9 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ return response, nil } -func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { +func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate - if h.noChunkedEncoding { + if rule.Config.DisableChunkedEncoding { req.TransferEncoding = []string{"gzip", "deflate"} cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) if err == nil { @@ -771,12 +737,12 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) // Request origin to keep connection alive to improve performance req.Header.Set("Connection", "keep-alive") - if h.httpHostHeader != "" { - req.Header.Set("Host", h.httpHostHeader) - req.Host = h.httpHostHeader + if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { + req.Header.Set("Host", hostHeader) + req.Host = hostHeader } - response, err := h.httpClient.RoundTrip(req) + response, err := rule.HTTPTransport.RoundTrip(req) if err != nil { return nil, errors.Wrap(err, "Error proxying request to origin") } diff --git a/tlsconfig/certreloader.go b/tlsconfig/certreloader.go index 5830a4a4..041392bc 100644 --- a/tlsconfig/certreloader.go +++ b/tlsconfig/certreloader.go @@ -65,10 +65,9 @@ func (cr *CertReloader) LoadCert() error { return nil } -func LoadOriginCA(c *cli.Context, logger logger.Service) (*x509.CertPool, error) { +func LoadOriginCA(originCAPoolFilename string, logger logger.Service) (*x509.CertPool, error) { var originCustomCAPool []byte - originCAPoolFilename := c.String(OriginCAPoolFlag) if originCAPoolFilename != "" { var err error originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)