TUN-2640: Users can configure per-origin config. Unify single-rule CLI
flow with multi-rule config file code.
This commit is contained in:
		
							parent
							
								
									ea71b78e6d
								
							
						
					
					
						commit
						e933ef9e1a
					
				|  | @ -34,7 +34,12 @@ var ( | |||
| 	ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories()) | ||||
| ) | ||||
| 
 | ||||
| const DefaultCredentialFile = "cert.pem" | ||||
| const ( | ||||
| 	DefaultCredentialFile = "cert.pem" | ||||
| 
 | ||||
| 	// BastionFlag is to enable bastion, or jump host, operation
 | ||||
| 	BastionFlag = "bastion" | ||||
| ) | ||||
| 
 | ||||
| // DefaultConfigDirectory returns the default directory of the config file
 | ||||
| func DefaultConfigDirectory() string { | ||||
|  | @ -200,11 +205,55 @@ type UnvalidatedIngressRule struct { | |||
| 	Hostname      string | ||||
| 	Path          string | ||||
| 	Service       string | ||||
| 	OriginRequest OriginRequestConfig `yaml:"originRequest"` | ||||
| } | ||||
| 
 | ||||
| // OriginRequestConfig is a set of optional fields that users may set to
 | ||||
| // customize how cloudflared sends requests to origin services. It is used to set
 | ||||
| // up general config that apply to all rules, and also, specific per-rule
 | ||||
| // config.
 | ||||
| // Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h".
 | ||||
| type OriginRequestConfig struct { | ||||
| 	// HTTP proxy timeout for establishing a new connection
 | ||||
| 	ConnectTimeout *time.Duration `yaml:"connectTimeout"` | ||||
| 	// HTTP proxy timeout for completing a TLS handshake
 | ||||
| 	TLSTimeout *time.Duration `yaml:"tlsTimeout"` | ||||
| 	// HTTP proxy TCP keepalive duration
 | ||||
| 	TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive"` | ||||
| 	// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
 | ||||
| 	NoHappyEyeballs *bool `yaml:"noHappyEyeballs"` | ||||
| 	// HTTP proxy maximum keepalive connection pool size
 | ||||
| 	KeepAliveConnections *int `yaml:"keepAliveConnections"` | ||||
| 	// HTTP proxy timeout for closing an idle connection
 | ||||
| 	KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout"` | ||||
| 	// Sets the HTTP Host header for the local webserver.
 | ||||
| 	HTTPHostHeader *string `yaml:"httpHostHeader"` | ||||
| 	// Hostname on the origin server certificate.
 | ||||
| 	OriginServerName *string `yaml:"originServerName"` | ||||
| 	// Path to the CA for the certificate of your origin.
 | ||||
| 	// This option should be used only if your certificate is not signed by Cloudflare.
 | ||||
| 	CAPool *string `yaml:"caPool"` | ||||
| 	// Disables TLS verification of the certificate presented by your origin.
 | ||||
| 	// Will allow any certificate from the origin to be accepted.
 | ||||
| 	// Note: The connection from your machine to Cloudflare's Edge is still encrypted.
 | ||||
| 	NoTLSVerify *bool `yaml:"noTLSVerify"` | ||||
| 	// Disables chunked transfer encoding.
 | ||||
| 	// Useful if you are running a WSGI server.
 | ||||
| 	DisableChunkedEncoding *bool `yaml:"disableChunkedEncoding"` | ||||
| 	// Runs as jump host
 | ||||
| 	BastionMode *bool `yaml:"bastionMode"` | ||||
| 	// Listen address for the proxy.
 | ||||
| 	ProxyAddress *string `yaml:"proxyAddress"` | ||||
| 	// Listen port for the proxy.
 | ||||
| 	ProxyPort *uint `yaml:"proxyPort"` | ||||
| 	// Valid options are 'socks', 'ssh' or empty.
 | ||||
| 	ProxyType *string `yaml:"proxyType"` | ||||
| } | ||||
| 
 | ||||
| type Configuration struct { | ||||
| 	TunnelID      string `yaml:"tunnel"` | ||||
| 	Ingress       []UnvalidatedIngressRule | ||||
| 	OriginRequest OriginRequestConfig `yaml:"originRequest"` | ||||
| 	sourceFile    string | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -5,43 +5,32 @@ import ( | |||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"runtime" | ||||
| 	"runtime/trace" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/awsuploader" | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/config" | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/ui" | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/updater" | ||||
| 	"github.com/cloudflare/cloudflared/dbconnect" | ||||
| 	"github.com/cloudflare/cloudflared/h2mux" | ||||
| 	"github.com/cloudflare/cloudflared/hello" | ||||
| 	"github.com/cloudflare/cloudflared/ingress" | ||||
| 	"github.com/cloudflare/cloudflared/logger" | ||||
| 	"github.com/cloudflare/cloudflared/metrics" | ||||
| 	"github.com/cloudflare/cloudflared/origin" | ||||
| 	"github.com/cloudflare/cloudflared/signal" | ||||
| 	"github.com/cloudflare/cloudflared/socks" | ||||
| 	"github.com/cloudflare/cloudflared/sshlog" | ||||
| 	"github.com/cloudflare/cloudflared/sshserver" | ||||
| 	"github.com/cloudflare/cloudflared/tlsconfig" | ||||
| 	"github.com/cloudflare/cloudflared/tunneldns" | ||||
| 	"github.com/cloudflare/cloudflared/tunnelstore" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| 
 | ||||
| 	"github.com/coreos/go-systemd/daemon" | ||||
| 	"github.com/facebookgo/grace/gracenet" | ||||
| 	"github.com/getsentry/raven-go" | ||||
| 	"github.com/gliderlabs/ssh" | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/mitchellh/go-homedir" | ||||
| 	"github.com/pkg/errors" | ||||
|  | @ -84,15 +73,6 @@ const ( | |||
| 	// hostKeyPath is the path of the dir to save SSH host keys too
 | ||||
| 	hostKeyPath = "host-key-path" | ||||
| 
 | ||||
| 	//sshServerFlag enables cloudflared ssh proxy server
 | ||||
| 	sshServerFlag = "ssh-server" | ||||
| 
 | ||||
| 	// socks5Flag is to enable the socks server to deframe
 | ||||
| 	socks5Flag = "socks5" | ||||
| 
 | ||||
| 	// bastionFlag is to enable bastion, or jump host, operation
 | ||||
| 	bastionFlag = "bastion" | ||||
| 
 | ||||
| 	// uiFlag is to enable launching cloudflared in interactive UI mode
 | ||||
| 	uiFlag = "ui" | ||||
| 
 | ||||
|  | @ -373,72 +353,6 @@ func StartServer( | |||
| 		return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0, log) | ||||
| 	} | ||||
| 
 | ||||
| 	if c.IsSet("hello-world") { | ||||
| 		log.Infof("hello-world set") | ||||
| 		helloListener, err := hello.CreateTLSListener("127.0.0.1:") | ||||
| 		if err != nil { | ||||
| 			log.Errorf("Cannot start Hello World Server: %s", err) | ||||
| 			return errors.Wrap(err, "Cannot start Hello World Server") | ||||
| 		} | ||||
| 		defer helloListener.Close() | ||||
| 		wg.Add(1) | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			_ = hello.StartHelloWorldServer(log, helloListener, shutdownC) | ||||
| 		}() | ||||
| 		forceSetFlag(c, "url", "https://"+helloListener.Addr().String()) | ||||
| 	} | ||||
| 
 | ||||
| 	if c.IsSet(sshServerFlag) { | ||||
| 		if runtime.GOOS != "darwin" && runtime.GOOS != "linux" { | ||||
| 			msg := fmt.Sprintf("--ssh-server is not supported on %s", runtime.GOOS) | ||||
| 			log.Error(msg) | ||||
| 			return errors.New(msg) | ||||
| 		} | ||||
| 
 | ||||
| 		log.Infof("ssh-server set") | ||||
| 
 | ||||
| 		logManager := sshlog.NewEmptyManager() | ||||
| 		if c.IsSet(bucketNameFlag) && c.IsSet(regionNameFlag) && c.IsSet(accessKeyIDFlag) && c.IsSet(secretIDFlag) { | ||||
| 			uploader, err := awsuploader.NewFileUploader(c.String(bucketNameFlag), c.String(regionNameFlag), | ||||
| 				c.String(accessKeyIDFlag), c.String(secretIDFlag), c.String(sessionTokenIDFlag), c.String(s3URLFlag)) | ||||
| 			if err != nil { | ||||
| 				msg := "Cannot create uploader for SSH Server" | ||||
| 				log.Errorf("%s: %s", msg, err) | ||||
| 				return errors.Wrap(err, msg) | ||||
| 			} | ||||
| 
 | ||||
| 			if err := os.MkdirAll(sshLogFileDirectory, 0700); err != nil { | ||||
| 				msg := fmt.Sprintf("Cannot create SSH log file directory %s", sshLogFileDirectory) | ||||
| 				log.Errorf("%s: %s", msg, err) | ||||
| 				return errors.Wrap(err, msg) | ||||
| 			} | ||||
| 
 | ||||
| 			logManager = sshlog.New(sshLogFileDirectory) | ||||
| 
 | ||||
| 			uploadManager := awsuploader.NewDirectoryUploadManager(log, uploader, sshLogFileDirectory, 30*time.Minute, shutdownC) | ||||
| 			uploadManager.Start() | ||||
| 		} | ||||
| 
 | ||||
| 		localServerAddress := "127.0.0.1:" + c.String(sshPortFlag) | ||||
| 		server, err := sshserver.New(logManager, log, version, localServerAddress, c.String("hostname"), c.Path(hostKeyPath), shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) | ||||
| 		if err != nil { | ||||
| 			msg := "Cannot create new SSH Server" | ||||
| 			log.Errorf("%s: %s", msg, err) | ||||
| 			return errors.Wrap(err, msg) | ||||
| 		} | ||||
| 		wg.Add(1) | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			if err = server.Start(); err != nil && err != ssh.ErrServerClosed { | ||||
| 				log.Errorf("SSH server error: %s", err) | ||||
| 				// TODO: remove when declarative tunnels are implemented.
 | ||||
| 				close(shutdownC) | ||||
| 			} | ||||
| 		}() | ||||
| 		forceSetFlag(c, "url", "ssh://"+localServerAddress) | ||||
| 	} | ||||
| 
 | ||||
| 	url := c.String("url") | ||||
| 	hostname := c.String("hostname") | ||||
| 	if url == hostname && url != "" && hostname != "" { | ||||
|  | @ -447,42 +361,6 @@ func StartServer( | |||
| 		return fmt.Errorf(errText) | ||||
| 	} | ||||
| 
 | ||||
| 	if staticHost := hostnameFromURI(c.String("url")); isProxyDestinationConfigured(staticHost, c) { | ||||
| 		listener, err := net.Listen("tcp", net.JoinHostPort(c.String("proxy-address"), strconv.Itoa(c.Int("proxy-port")))) | ||||
| 		if err != nil { | ||||
| 			log.Errorf("Cannot start Websocket Proxy Server: %s", err) | ||||
| 			return errors.Wrap(err, "Cannot start Websocket Proxy Server") | ||||
| 		} | ||||
| 		wg.Add(1) | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			streamHandler := websocket.DefaultStreamHandler | ||||
| 			if c.IsSet(socks5Flag) { | ||||
| 				log.Info("SOCKS5 server started") | ||||
| 				streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) { | ||||
| 					dialer := socks.NewConnDialer(remoteConn) | ||||
| 					requestHandler := socks.NewRequestHandler(dialer) | ||||
| 					socksServer := socks.NewConnectionHandler(requestHandler) | ||||
| 
 | ||||
| 					socksServer.Serve(wsConn) | ||||
| 				} | ||||
| 			} else if c.IsSet(sshServerFlag) { | ||||
| 				streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, requestHeaders http.Header) { | ||||
| 					if finalDestination := requestHeaders.Get(h2mux.CFJumpDestinationHeader); finalDestination != "" { | ||||
| 						token := requestHeaders.Get(h2mux.CFAccessTokenHeader) | ||||
| 						if err := websocket.SendSSHPreamble(remoteConn, finalDestination, token); err != nil { | ||||
| 							log.Errorf("Failed to send SSH preamble: %s", err) | ||||
| 							return | ||||
| 						} | ||||
| 					} | ||||
| 					websocket.DefaultStreamHandler(wsConn, remoteConn, requestHeaders) | ||||
| 				} | ||||
| 			} | ||||
| 			errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler) | ||||
| 		}() | ||||
| 		forceSetFlag(c, "url", "http://"+listener.Addr().String()) | ||||
| 	} | ||||
| 
 | ||||
| 	transportLogger, err := createLogger(c, true, false) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "error setting up transport logger") | ||||
|  | @ -493,6 +371,8 @@ func StartServer( | |||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	tunnelConfig.IngressRules.StartOrigins(&wg, log, shutdownC, errC) | ||||
| 
 | ||||
| 	reconnectCh := make(chan origin.ReconnectSignal, 1) | ||||
| 	if c.IsSet("stdin-control") { | ||||
| 		log.Info("Enabling control through stdin") | ||||
|  | @ -514,7 +394,8 @@ func StartServer( | |||
| 			version, | ||||
| 			hostname, | ||||
| 			metricsListener.Addr().String(), | ||||
| 			tunnelConfig.OriginUrl, | ||||
| 			// TODO (TUN-3461): Update UI to show multiple origin URLs
 | ||||
| 			tunnelConfig.IngressRules.CatchAll().Service.Address(), | ||||
| 			tunnelConfig.HAConnections, | ||||
| 		) | ||||
| 		logLevels, err := logger.ParseLevelString(c.String("loglevel")) | ||||
|  | @ -559,11 +440,6 @@ func SetFlagsFromConfigFile(c *cli.Context) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // isProxyDestinationConfigured returns true if there is a static host set or if bastion mode is set.
 | ||||
| func isProxyDestinationConfigured(staticHost string, c *cli.Context) bool { | ||||
| 	return staticHost != "" || c.IsSet(bastionFlag) | ||||
| } | ||||
| 
 | ||||
| func waitToShutdown(wg *sync.WaitGroup, | ||||
| 	errC chan error, | ||||
| 	shutdownC, graceShutdownC chan struct{}, | ||||
|  | @ -910,67 +786,67 @@ func configureProxyFlags(shouldHide bool) []cli.Flag { | |||
| 			Hidden:  shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||
| 			Name:    socks5Flag, | ||||
| 			Name:    ingress.Socks5Flag, | ||||
| 			Usage:   "specify if this tunnel is running as a SOCK5 Server", | ||||
| 			EnvVars: []string{"TUNNEL_SOCKS"}, | ||||
| 			Value:   false, | ||||
| 			Hidden:  shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{ | ||||
| 			Name:   "proxy-connect-timeout", | ||||
| 			Name:   ingress.ProxyConnectTimeoutFlag, | ||||
| 			Usage:  "HTTP proxy timeout for establishing a new connection", | ||||
| 			Value:  time.Second * 30, | ||||
| 			Hidden: shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{ | ||||
| 			Name:   "proxy-tls-timeout", | ||||
| 			Name:   ingress.ProxyTLSTimeoutFlag, | ||||
| 			Usage:  "HTTP proxy timeout for completing a TLS handshake", | ||||
| 			Value:  time.Second * 10, | ||||
| 			Hidden: shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{ | ||||
| 			Name:   "proxy-tcp-keepalive", | ||||
| 			Name:   ingress.ProxyTCPKeepAlive, | ||||
| 			Usage:  "HTTP proxy TCP keepalive duration", | ||||
| 			Value:  time.Second * 30, | ||||
| 			Hidden: shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||
| 			Name:   "proxy-no-happy-eyeballs", | ||||
| 			Name:   ingress.ProxyNoHappyEyeballsFlag, | ||||
| 			Usage:  "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback", | ||||
| 			Hidden: shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewIntFlag(&cli.IntFlag{ | ||||
| 			Name:   "proxy-keepalive-connections", | ||||
| 			Name:   ingress.ProxyKeepAliveConnectionsFlag, | ||||
| 			Usage:  "HTTP proxy maximum keepalive connection pool size", | ||||
| 			Value:  100, | ||||
| 			Hidden: shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{ | ||||
| 			Name:   "proxy-keepalive-timeout", | ||||
| 			Name:   ingress.ProxyKeepAliveTimeoutFlag, | ||||
| 			Usage:  "HTTP proxy timeout for closing an idle connection", | ||||
| 			Value:  time.Second * 90, | ||||
| 			Hidden: shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{ | ||||
| 			Name:   "proxy-connection-timeout", | ||||
| 			Usage:  "HTTP proxy timeout for closing an idle connection", | ||||
| 			Usage:  "DEPRECATED. No longer has any effect.", | ||||
| 			Value:  time.Second * 90, | ||||
| 			Hidden: shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{ | ||||
| 			Name:   "proxy-expect-continue-timeout", | ||||
| 			Usage:  "HTTP proxy timeout for closing an idle connection", | ||||
| 			Usage:  "DEPRECATED. No longer has any effect.", | ||||
| 			Value:  time.Second * 90, | ||||
| 			Hidden: shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||
| 			Name:    "http-host-header", | ||||
| 			Name:    ingress.HTTPHostHeaderFlag, | ||||
| 			Usage:   "Sets the HTTP Host header for the local webserver.", | ||||
| 			EnvVars: []string{"TUNNEL_HTTP_HOST_HEADER"}, | ||||
| 			Hidden:  shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||
| 			Name:    "origin-server-name", | ||||
| 			Name:    ingress.OriginServerNameFlag, | ||||
| 			Usage:   "Hostname on the origin server certificate.", | ||||
| 			EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"}, | ||||
| 			Hidden:  shouldHide, | ||||
|  | @ -988,13 +864,13 @@ func configureProxyFlags(shouldHide bool) []cli.Flag { | |||
| 			Hidden:  shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||
| 			Name:    "no-tls-verify", | ||||
| 			Name:    ingress.NoTLSVerifyFlag, | ||||
| 			Usage:   "Disables TLS verification of the certificate presented by your origin. Will allow any certificate from the origin to be accepted. Note: The connection from your machine to Cloudflare's Edge is still encrypted.", | ||||
| 			EnvVars: []string{"NO_TLS_VERIFY"}, | ||||
| 			Hidden:  shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||
| 			Name:    "no-chunked-encoding", | ||||
| 			Name:    ingress.NoChunkedEncodingFlag, | ||||
| 			Usage:   "Disables chunked transfer encoding; useful if you are running a WSGI server.", | ||||
| 			EnvVars: []string{"TUNNEL_NO_CHUNKED_ENCODING"}, | ||||
| 			Hidden:  shouldHide, | ||||
|  | @ -1067,28 +943,28 @@ func sshFlags(shouldHide bool) []cli.Flag { | |||
| 			Hidden:  true, | ||||
| 		}), | ||||
| 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||
| 			Name:    sshServerFlag, | ||||
| 			Name:    ingress.SSHServerFlag, | ||||
| 			Value:   false, | ||||
| 			Usage:   "Run an SSH Server", | ||||
| 			EnvVars: []string{"TUNNEL_SSH_SERVER"}, | ||||
| 			Hidden:  true, // TODO: remove when feature is complete
 | ||||
| 		}), | ||||
| 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||
| 			Name:    bastionFlag, | ||||
| 			Name:    config.BastionFlag, | ||||
| 			Value:   false, | ||||
| 			Usage:   "Runs as jump host", | ||||
| 			EnvVars: []string{"TUNNEL_BASTION"}, | ||||
| 			Hidden:  shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||
| 			Name:    "proxy-address", | ||||
| 			Name:    ingress.ProxyAddressFlag, | ||||
| 			Usage:   "Listen address for the proxy.", | ||||
| 			Value:   "127.0.0.1", | ||||
| 			EnvVars: []string{"TUNNEL_PROXY_ADDRESS"}, | ||||
| 			Hidden:  shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewIntFlag(&cli.IntFlag{ | ||||
| 			Name:    "proxy-port", | ||||
| 			Name:    ingress.ProxyPortFlag, | ||||
| 			Usage:   "Listen port for the proxy.", | ||||
| 			Value:   0, | ||||
| 			EnvVars: []string{"TUNNEL_PROXY_PORT"}, | ||||
|  |  | |||
|  | @ -1,16 +1,11 @@ | |||
| package tunnel | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/config" | ||||
|  | @ -193,31 +188,7 @@ func prepareTunnelConfig( | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	originCertPool, err := tlsconfig.LoadOriginCA(c, logger) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error loading cert pool: %s", err) | ||||
| 		return nil, errors.Wrap(err, "Error loading cert pool") | ||||
| 	} | ||||
| 
 | ||||
| 	tunnelMetrics := origin.NewTunnelMetrics() | ||||
| 	httpTransport := &http.Transport{ | ||||
| 		Proxy:                 http.ProxyFromEnvironment, | ||||
| 		MaxIdleConns:          c.Int("proxy-keepalive-connections"), | ||||
| 		MaxIdleConnsPerHost:   c.Int("proxy-keepalive-connections"), | ||||
| 		IdleConnTimeout:       c.Duration("proxy-keepalive-timeout"), | ||||
| 		TLSHandshakeTimeout:   c.Duration("proxy-tls-timeout"), | ||||
| 		ExpectContinueTimeout: 1 * time.Second, | ||||
| 		TLSClientConfig:       &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")}, | ||||
| 	} | ||||
| 
 | ||||
| 	dialer := &net.Dialer{ | ||||
| 		Timeout:   c.Duration("proxy-connect-timeout"), | ||||
| 		KeepAlive: c.Duration("proxy-tcp-keepalive"), | ||||
| 	} | ||||
| 	if c.Bool("proxy-no-happy-eyeballs") { | ||||
| 		dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
 | ||||
| 	} | ||||
| 	dialContext := dialer.DialContext | ||||
| 
 | ||||
| 	var ingressRules ingress.Ingress | ||||
| 	if namedTunnel != nil { | ||||
|  | @ -231,7 +202,7 @@ func prepareTunnelConfig( | |||
| 			Version:  version, | ||||
| 			Arch:     fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch), | ||||
| 		} | ||||
| 		ingressRules, err = ingress.ParseIngress(config.GetConfiguration()) | ||||
| 		ingressRules, err = ingress.ParseIngress(config.GetConfiguration(), logger) | ||||
| 		if err != nil && err != ingress.ErrNoIngressRules { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | @ -240,53 +211,11 @@ func prepareTunnelConfig( | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	var originURL string | ||||
| 	// Convert single-origin configuration into multi-origin configuration.
 | ||||
| 	if ingressRules.IsEmpty() { | ||||
| 		originURL, err = config.ValidateUrl(c, compatibilityMode) | ||||
| 		ingressRules, err = ingress.NewSingleOrigin(c, compatibilityMode, logger) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf("Error validating origin URL: %s", err) | ||||
| 			return nil, errors.Wrap(err, "Error validating origin URL") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if c.IsSet("unix-socket") { | ||||
| 		unixSocket, err := config.ValidateUnixSocket(c) | ||||
| 		if err != nil { | ||||
| 			logger.Errorf("Error validating --unix-socket: %s", err) | ||||
| 			return nil, errors.Wrap(err, "Error validating --unix-socket") | ||||
| 		} | ||||
| 
 | ||||
| 		logger.Infof("Proxying tunnel requests to unix:%s", unixSocket) | ||||
| 		httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { | ||||
| 			// if --unix-socket specified, enforce network type "unix"
 | ||||
| 			return dialContext(ctx, "unix", unixSocket) | ||||
| 		} | ||||
| 	} else { | ||||
| 		logger.Infof("Proxying tunnel requests to %s", originURL) | ||||
| 		httpTransport.DialContext = dialContext | ||||
| 	} | ||||
| 
 | ||||
| 	if !c.IsSet("hello-world") && c.IsSet("origin-server-name") { | ||||
| 		httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name") | ||||
| 	} | ||||
| 	// If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client.
 | ||||
| 	if !c.IsSet(bastionFlag) { | ||||
| 
 | ||||
| 		// List all origin URLs that require validation
 | ||||
| 		var originURLs []string | ||||
| 		if ingressRules.IsEmpty() { | ||||
| 			originURLs = append(originURLs, originURL) | ||||
| 		} else { | ||||
| 			for _, rule := range ingressRules.Rules { | ||||
| 				originURLs = append(originURLs, rule.Service.String()) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		// Validate each origin URL
 | ||||
| 		for _, u := range originURLs { | ||||
| 			if err = validation.ValidateHTTPService(u, hostname, httpTransport); err != nil { | ||||
| 				logger.Errorf("unable to connect to the origin: %s", err) | ||||
| 			} | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
|  | @ -298,15 +227,12 @@ func prepareTunnelConfig( | |||
| 	return &origin.TunnelConfig{ | ||||
| 		BuildInfo:          buildInfo, | ||||
| 		ClientID:           clientID, | ||||
| 		ClientTlsConfig:    httpTransport.TLSClientConfig, | ||||
| 		CompressionQuality: c.Uint64("compression-quality"), | ||||
| 		EdgeAddrs:          c.StringSlice("edge"), | ||||
| 		GracePeriod:        c.Duration("grace-period"), | ||||
| 		HAConnections:      c.Int("ha-connections"), | ||||
| 		HTTPTransport:      httpTransport, | ||||
| 		HeartbeatInterval:  c.Duration("heartbeat-interval"), | ||||
| 		Hostname:           hostname, | ||||
| 		HTTPHostHeader:     c.String("http-host-header"), | ||||
| 		IncidentLookup:     origin.NewIncidentLookup(), | ||||
| 		IsAutoupdated:      c.Bool("is-autoupdated"), | ||||
| 		IsFreeTunnel:       isFreeTunnel, | ||||
|  | @ -316,9 +242,7 @@ func prepareTunnelConfig( | |||
| 		MaxHeartbeats:      c.Uint64("heartbeat-count"), | ||||
| 		Metrics:            tunnelMetrics, | ||||
| 		MetricsUpdateFreq:  c.Duration("metrics-update-freq"), | ||||
| 		NoChunkedEncoding:  c.Bool("no-chunked-encoding"), | ||||
| 		OriginCert:         originCert, | ||||
| 		OriginUrl:          originURL, | ||||
| 		ReportedVersion:    version, | ||||
| 		Retries:            c.Uint("retries"), | ||||
| 		RunFromTerminal:    isRunningFromTerminal(), | ||||
|  |  | |||
|  | @ -71,7 +71,7 @@ func buildTestURLCommand() *cli.Command { | |||
| func validateIngressCommand(c *cli.Context) error { | ||||
| 	conf := config.GetConfiguration() | ||||
| 	fmt.Println("Validating rules from", conf.Source()) | ||||
| 	if _, err := ingress.ParseIngress(conf); err != nil { | ||||
| 	if _, err := ingress.ParseIngressDryRun(conf); err != nil { | ||||
| 		return errors.Wrap(err, "Validation failed") | ||||
| 	} | ||||
| 	if c.IsSet("url") { | ||||
|  | @ -98,12 +98,12 @@ func testURLCommand(c *cli.Context) error { | |||
| 
 | ||||
| 	conf := config.GetConfiguration() | ||||
| 	fmt.Println("Using rules from", conf.Source()) | ||||
| 	ing, err := ingress.ParseIngress(conf) | ||||
| 	ing, err := ingress.ParseIngressDryRun(conf) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "Validation failed") | ||||
| 	} | ||||
| 
 | ||||
| 	i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path) | ||||
| 	_, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path) | ||||
| 	fmt.Printf("Matched rule #%d\n", i+1) | ||||
| 	fmt.Println(ing.Rules[i].MultiLineString()) | ||||
| 	return nil | ||||
|  |  | |||
|  | @ -1,14 +1,24 @@ | |||
| package ingress | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/urfave/cli/v2" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/config" | ||||
| 	"github.com/cloudflare/cloudflared/logger" | ||||
| 	"github.com/cloudflare/cloudflared/tlsconfig" | ||||
| 	"github.com/cloudflare/cloudflared/validation" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
|  | @ -18,54 +28,93 @@ var ( | |||
| 	ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules") | ||||
| ) | ||||
| 
 | ||||
| // Each rule route traffic from a hostname/path on the public
 | ||||
| // internet to the service running on the given URL.
 | ||||
| type Rule struct { | ||||
| 	// Requests for this hostname will be proxied to this rule's service.
 | ||||
| 	Hostname string | ||||
| 
 | ||||
| 	// Path is an optional regex that can specify path-driven ingress rules.
 | ||||
| 	Path *regexp.Regexp | ||||
| 
 | ||||
| 	// A (probably local) address. Requests for a hostname which matches this
 | ||||
| 	// rule's hostname pattern will be proxied to the service running on this
 | ||||
| 	// address.
 | ||||
| 	Service *url.URL | ||||
| // Finalize the rules by adding missing struct fields and validating each origin.
 | ||||
| func (ing *Ingress) setHTTPTransport(logger logger.Service) error { | ||||
| 	for ruleNumber, rule := range ing.Rules { | ||||
| 		cfg := rule.Config | ||||
| 		originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil) | ||||
| 		if err != nil { | ||||
| 			return errors.Wrap(err, "Error loading cert pool") | ||||
| 		} | ||||
| 
 | ||||
| func (r Rule) MultiLineString() string { | ||||
| 	var out strings.Builder | ||||
| 	if r.Hostname != "" { | ||||
| 		out.WriteString("\thostname: ") | ||||
| 		out.WriteString(r.Hostname) | ||||
| 		out.WriteRune('\n') | ||||
| 		httpTransport := &http.Transport{ | ||||
| 			Proxy:                 http.ProxyFromEnvironment, | ||||
| 			MaxIdleConns:          cfg.KeepAliveConnections, | ||||
| 			MaxIdleConnsPerHost:   cfg.KeepAliveConnections, | ||||
| 			IdleConnTimeout:       cfg.KeepAliveTimeout, | ||||
| 			TLSHandshakeTimeout:   cfg.TLSTimeout, | ||||
| 			ExpectContinueTimeout: 1 * time.Second, | ||||
| 			TLSClientConfig:       &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify}, | ||||
| 		} | ||||
| 	if r.Path != nil { | ||||
| 		out.WriteString("\tpath: ") | ||||
| 		out.WriteString(r.Path.String()) | ||||
| 		out.WriteRune('\n') | ||||
| 	} | ||||
| 	out.WriteString("\tservice: ") | ||||
| 	out.WriteString(r.Service.String()) | ||||
| 	return out.String() | ||||
| 		if _, isHelloWorld := rule.Service.(*HelloWorld); !isHelloWorld && cfg.OriginServerName != "" { | ||||
| 			httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName | ||||
| 		} | ||||
| 
 | ||||
| func (r *Rule) Matches(hostname, path string) bool { | ||||
| 	hostMatch := r.Hostname == "" || r.Hostname == "*" || matchHost(r.Hostname, hostname) | ||||
| 	pathMatch := r.Path == nil || r.Path.MatchString(path) | ||||
| 	return hostMatch && pathMatch | ||||
| 		dialer := &net.Dialer{ | ||||
| 			Timeout:   cfg.ConnectTimeout, | ||||
| 			KeepAlive: cfg.TCPKeepAlive, | ||||
| 		} | ||||
| 		if cfg.NoHappyEyeballs { | ||||
| 			dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
 | ||||
| 		} | ||||
| 
 | ||||
| 		// DialContext depends on which kind of origin is being used.
 | ||||
| 		dialContext := dialer.DialContext | ||||
| 		switch service := rule.Service.(type) { | ||||
| 
 | ||||
| 		// If this origin is a unix socket, enforce network type "unix".
 | ||||
| 		case UnixSocketPath: | ||||
| 			httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { | ||||
| 				return dialContext(ctx, "unix", service.Address()) | ||||
| 			} | ||||
| 		// Otherwise, use the regular network config.
 | ||||
| 		default: | ||||
| 			httpTransport.DialContext = dialContext | ||||
| 		} | ||||
| 
 | ||||
| 		ing.Rules[ruleNumber].HTTPTransport = httpTransport | ||||
| 		ing.Rules[ruleNumber].ClientTLSConfig = httpTransport.TLSClientConfig | ||||
| 	} | ||||
| 
 | ||||
| 	// Validate each origin
 | ||||
| 	for _, rule := range ing.Rules { | ||||
| 		// If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client.
 | ||||
| 		if rule.Config.BastionMode { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		// Unix sockets don't have validation
 | ||||
| 		if _, ok := rule.Service.(UnixSocketPath); ok { | ||||
| 			continue | ||||
| 		} | ||||
| 		switch service := rule.Service.(type) { | ||||
| 
 | ||||
| 		case UnixSocketPath: | ||||
| 			continue | ||||
| 
 | ||||
| 		case *HelloWorld: | ||||
| 			continue | ||||
| 
 | ||||
| 		default: | ||||
| 			if err := validation.ValidateHTTPService(service.Address(), rule.Hostname, rule.HTTPTransport); err != nil { | ||||
| 				logger.Errorf("unable to connect to the origin: %s", err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // FindMatchingRule returns the index of the Ingress Rule which matches the given
 | ||||
| // hostname and path. This function assumes the last rule matches everything,
 | ||||
| // which is the case if the rules were instantiated via the ingress#Validate method
 | ||||
| func (ing Ingress) FindMatchingRule(hostname, path string) int { | ||||
| func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) { | ||||
| 	for i, rule := range ing.Rules { | ||||
| 		if rule.Matches(hostname, path) { | ||||
| 			return i | ||||
| 			return &rule, i | ||||
| 		} | ||||
| 	} | ||||
| 	return len(ing.Rules) - 1 | ||||
| 	i := len(ing.Rules) - 1 | ||||
| 	return &ing.Rules[i], i | ||||
| } | ||||
| 
 | ||||
| func matchHost(ruleHost, reqHost string) bool { | ||||
|  | @ -84,6 +133,55 @@ func matchHost(ruleHost, reqHost string) bool { | |||
| // Ingress maps eyeball requests to origins.
 | ||||
| type Ingress struct { | ||||
| 	Rules    []Rule | ||||
| 	defaults OriginRequestConfig | ||||
| } | ||||
| 
 | ||||
| // NewSingleOrigin constructs an Ingress set with only one rule, constructed from
 | ||||
| // legacy CLI parameters like --url or --no-chunked-encoding.
 | ||||
| func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Service) (Ingress, error) { | ||||
| 
 | ||||
| 	service, err := parseSingleOriginService(c, compatibilityMode) | ||||
| 	if err != nil { | ||||
| 		return Ingress{}, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Construct an Ingress with the single rule.
 | ||||
| 	ing := Ingress{ | ||||
| 		Rules: []Rule{ | ||||
| 			{ | ||||
| 				Service: service, | ||||
| 			}, | ||||
| 		}, | ||||
| 		defaults: originRequestFromSingeRule(c), | ||||
| 	} | ||||
| 	err = ing.setHTTPTransport(logger) | ||||
| 	return ing, err | ||||
| } | ||||
| 
 | ||||
| // Get a single origin service from the CLI/config.
 | ||||
| func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) { | ||||
| 	if c.IsSet("hello-world") { | ||||
| 		return new(HelloWorld), nil | ||||
| 	} | ||||
| 	if c.IsSet("url") { | ||||
| 		originURLStr, err := config.ValidateUrl(c, compatibilityMode) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "Error validating origin URL") | ||||
| 		} | ||||
| 		originURL, err := url.Parse(originURLStr) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "couldn't parse origin URL") | ||||
| 		} | ||||
| 		return &URL{URL: originURL, RootURL: originURL}, nil | ||||
| 	} | ||||
| 	if c.IsSet("unix-socket") { | ||||
| 		unixSocket, err := config.ValidateUnixSocket(c) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "Error validating --unix-socket") | ||||
| 		} | ||||
| 		return UnixSocketPath(unixSocket), nil | ||||
| 	} | ||||
| 	return nil, errors.New("You must either set ingress rules in your config file, or use --url or use --unix-socket") | ||||
| } | ||||
| 
 | ||||
| // IsEmpty checks if there are any ingress rules.
 | ||||
|  | @ -91,19 +189,47 @@ func (ing Ingress) IsEmpty() bool { | |||
| 	return len(ing.Rules) == 0 | ||||
| } | ||||
| 
 | ||||
| func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) { | ||||
| // StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World.
 | ||||
| func (ing Ingress) StartOrigins(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error) error { | ||||
| 	for _, rule := range ing.Rules { | ||||
| 		if err := rule.Service.Start(wg, log, shutdownC, errC, rule.Config); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // CatchAll returns the catch-all rule (i.e. the last rule)
 | ||||
| func (ing Ingress) CatchAll() *Rule { | ||||
| 	return &ing.Rules[len(ing.Rules)-1] | ||||
| } | ||||
| 
 | ||||
| func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestConfig) (Ingress, error) { | ||||
| 	rules := make([]Rule, len(ingress)) | ||||
| 	for i, r := range ingress { | ||||
| 		service, err := url.Parse(r.Service) | ||||
| 		var service OriginService | ||||
| 
 | ||||
| 		if strings.HasPrefix(r.Service, "unix:") { | ||||
| 			// No validation necessary for unix socket filepath services
 | ||||
| 			service = UnixSocketPath(strings.TrimPrefix(r.Service, "unix:")) | ||||
| 		} else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" { | ||||
| 			service = new(HelloWorld) | ||||
| 		} else { | ||||
| 			// Validate URL services
 | ||||
| 			u, err := url.Parse(r.Service) | ||||
| 			if err != nil { | ||||
| 				return Ingress{}, err | ||||
| 			} | ||||
| 		if service.Scheme == "" || service.Hostname() == "" { | ||||
| 
 | ||||
| 			if u.Scheme == "" || u.Hostname() == "" { | ||||
| 				return Ingress{}, fmt.Errorf("The service %s must have a scheme and a hostname", r.Service) | ||||
| 			} | ||||
| 
 | ||||
| 		if service.Path != "" { | ||||
| 			return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path.", r.Service) | ||||
| 			if u.Path != "" { | ||||
| 				return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service) | ||||
| 			} | ||||
| 			serviceURL := URL{URL: u} | ||||
| 			service = &serviceURL | ||||
| 		} | ||||
| 
 | ||||
| 		// Ensure that there are no wildcards anywhere except the first character
 | ||||
|  | @ -125,6 +251,7 @@ func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) { | |||
| 
 | ||||
| 		var pathRegex *regexp.Regexp | ||||
| 		if r.Path != "" { | ||||
| 			var err error | ||||
| 			pathRegex, err = regexp.Compile(r.Path) | ||||
| 			if err != nil { | ||||
| 				return Ingress{}, errors.Wrapf(err, "Rule #%d has an invalid regex", i+1) | ||||
|  | @ -135,9 +262,10 @@ func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) { | |||
| 			Hostname: r.Hostname, | ||||
| 			Service:  service, | ||||
| 			Path:     pathRegex, | ||||
| 			Config:   SetConfig(defaults, r.OriginRequest), | ||||
| 		} | ||||
| 	} | ||||
| 	return Ingress{Rules: rules}, nil | ||||
| 	return Ingress{Rules: rules, defaults: defaults}, nil | ||||
| } | ||||
| 
 | ||||
| type errRuleShouldNotBeCatchAll struct { | ||||
|  | @ -151,9 +279,20 @@ func (e errRuleShouldNotBeCatchAll) Error() string { | |||
| 		"will never be triggered.", e.i+1, e.hostname) | ||||
| } | ||||
| 
 | ||||
| func ParseIngress(conf *config.Configuration) (Ingress, error) { | ||||
| // ParseIngress parses, validates and initializes HTTP transports to each origin.
 | ||||
| func ParseIngress(conf *config.Configuration, logger logger.Service) (Ingress, error) { | ||||
| 	ing, err := ParseIngressDryRun(conf) | ||||
| 	if err != nil { | ||||
| 		return Ingress{}, err | ||||
| 	} | ||||
| 	err = ing.setHTTPTransport(logger) | ||||
| 	return ing, err | ||||
| } | ||||
| 
 | ||||
| // ParseIngressDryRun parses ingress rules, but does not send HTTP requests to the origins.
 | ||||
| func ParseIngressDryRun(conf *config.Configuration) (Ingress, error) { | ||||
| 	if len(conf.Ingress) == 0 { | ||||
| 		return Ingress{}, ErrNoIngressRules | ||||
| 	} | ||||
| 	return validate(conf.Ingress) | ||||
| 	return validate(conf.Ingress, OriginRequestFromYAML(conf.OriginRequest)) | ||||
| } | ||||
|  |  | |||
|  | @ -2,7 +2,6 @@ package ingress | |||
| 
 | ||||
| import ( | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
|  | @ -12,16 +11,29 @@ import ( | |||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/config" | ||||
| ) | ||||
| 
 | ||||
| func TestParseUnixSocket(t *testing.T) { | ||||
| 	rawYAML := ` | ||||
| ingress: | ||||
| - service: unix:/tmp/echo.sock | ||||
| ` | ||||
| 	ing, err := ParseIngressDryRun(MustReadIngress(rawYAML)) | ||||
| 	require.NoError(t, err) | ||||
| 	_, ok := ing.Rules[0].Service.(UnixSocketPath) | ||||
| 	require.True(t, ok) | ||||
| } | ||||
| 
 | ||||
| func Test_parseIngress(t *testing.T) { | ||||
| 	localhost8000 := MustParseURL(t, "https://localhost:8000") | ||||
| 	localhost8001 := MustParseURL(t, "https://localhost:8001") | ||||
| 	defaultConfig := SetConfig(OriginRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{}) | ||||
| 	require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections) | ||||
| 	type args struct { | ||||
| 		rawYAML string | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		args    args | ||||
| 		want    Ingress | ||||
| 		want    []Rule | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
|  | @ -38,16 +50,18 @@ ingress: | |||
|  - hostname: "*" | ||||
|    service: https://localhost:8001
 | ||||
| `}, | ||||
| 			want: Ingress{Rules: []Rule{ | ||||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Hostname: "tunnel1.example.com", | ||||
| 					Service:  localhost8000, | ||||
| 					Service:  &URL{URL: localhost8000}, | ||||
| 					Config:   defaultConfig, | ||||
| 				}, | ||||
| 				{ | ||||
| 					Hostname: "*", | ||||
| 					Service:  localhost8001, | ||||
| 					Service:  &URL{URL: localhost8001}, | ||||
| 					Config:   defaultConfig, | ||||
| 				}, | ||||
| 			}, | ||||
| 			}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Extra keys", | ||||
|  | @ -57,12 +71,13 @@ ingress: | |||
|    service: https://localhost:8000
 | ||||
| extraKey: extraValue | ||||
| `}, | ||||
| 			want: Ingress{Rules: []Rule{ | ||||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Hostname: "*", | ||||
| 					Service:  localhost8000, | ||||
| 					Service:  &URL{URL: localhost8000}, | ||||
| 					Config:   defaultConfig, | ||||
| 				}, | ||||
| 			}, | ||||
| 			}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Hostname can be omitted", | ||||
|  | @ -70,11 +85,12 @@ extraKey: extraValue | |||
| ingress: | ||||
|  - service: https://localhost:8000
 | ||||
| `}, | ||||
| 			want: Ingress{Rules: []Rule{ | ||||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Service: localhost8000, | ||||
| 					Service: &URL{URL: localhost8000}, | ||||
| 					Config:  defaultConfig, | ||||
| 				}, | ||||
| 			}, | ||||
| 			}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Invalid service", | ||||
|  | @ -152,12 +168,12 @@ ingress: | |||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := ParseIngress(MustReadIngress(tt.args.rawYAML)) | ||||
| 			got, err := ParseIngressDryRun(MustReadIngress(tt.args.rawYAML)) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("ParseIngress() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 				t.Errorf("ParseIngressDryRun() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 				return | ||||
| 			} | ||||
| 			assert.Equal(t, tt.want, got) | ||||
| 			assert.Equal(t, tt.want, got.Rules) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | @ -168,118 +184,6 @@ func MustParseURL(t *testing.T, rawURL string) *url.URL { | |||
| 	return u | ||||
| } | ||||
| 
 | ||||
| func Test_rule_matches(t *testing.T) { | ||||
| 	type fields struct { | ||||
| 		Hostname string | ||||
| 		Path     *regexp.Regexp | ||||
| 		Service  *url.URL | ||||
| 	} | ||||
| 	type args struct { | ||||
| 		requestURL *url.URL | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name   string | ||||
| 		fields fields | ||||
| 		args   args | ||||
| 		want   bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "Just hostname, pass", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://example.com"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Entire hostname is wildcard, should match everything", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://example.com"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Just hostname, fail", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://foo.bar"), | ||||
| 			}, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Just wildcard hostname, pass", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*.example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://adam.example.com"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Just wildcard hostname, fail", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*.example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://tunnel.com"), | ||||
| 			}, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Just wildcard outside of subdomain in hostname, fail", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://www.example.com"), | ||||
| 			}, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Wildcard over multiple subdomains", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*.example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://adam.chalmers.example.com"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Hostname and path", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*.example.com", | ||||
| 				Path:     regexp.MustCompile("/static/.*\\.html"), | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://www.example.com/static/index.html"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			r := Rule{ | ||||
| 				Hostname: tt.fields.Hostname, | ||||
| 				Path:     tt.fields.Path, | ||||
| 				Service:  tt.fields.Service, | ||||
| 			} | ||||
| 			u := tt.args.requestURL | ||||
| 			if got := r.Matches(u.Hostname(), u.Path); got != tt.want { | ||||
| 				t.Errorf("rule.matches() = %v, want %v", got, tt.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func BenchmarkFindMatch(b *testing.B) { | ||||
| 	rulesYAML := ` | ||||
| ingress: | ||||
|  | @ -291,7 +195,7 @@ ingress: | |||
|    service: https://localhost:8002
 | ||||
| ` | ||||
| 
 | ||||
| 	ing, err := ParseIngress(MustReadIngress(rulesYAML)) | ||||
| 	ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) | ||||
| 	if err != nil { | ||||
| 		b.Error(err) | ||||
| 	} | ||||
|  |  | |||
|  | @ -0,0 +1,331 @@ | |||
| package ingress | ||||
| 
 | ||||
| import ( | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/config" | ||||
| 	"github.com/cloudflare/cloudflared/tlsconfig" | ||||
| 	"github.com/urfave/cli/v2" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	defaultConnectTimeout       = 30 * time.Second | ||||
| 	defaultTLSTimeout           = 10 * time.Second | ||||
| 	defaultTCPKeepAlive         = 30 * time.Second | ||||
| 	defaultKeepAliveConnections = 100 | ||||
| 	defaultKeepAliveTimeout     = 90 * time.Second | ||||
| 	defaultProxyAddress         = "127.0.0.1" | ||||
| 
 | ||||
| 	SSHServerFlag                 = "ssh-server" | ||||
| 	Socks5Flag                    = "socks5" | ||||
| 	ProxyConnectTimeoutFlag       = "proxy-connect-timeout" | ||||
| 	ProxyTLSTimeoutFlag           = "proxy-tls-timeout" | ||||
| 	ProxyTCPKeepAlive             = "proxy-tcp-keepalive" | ||||
| 	ProxyNoHappyEyeballsFlag      = "proxy-no-happy-eyeballs" | ||||
| 	ProxyKeepAliveConnectionsFlag = "proxy-keepalive-connections" | ||||
| 	ProxyKeepAliveTimeoutFlag     = "proxy-keepalive-timeout" | ||||
| 	HTTPHostHeaderFlag            = "http-host-header" | ||||
| 	OriginServerNameFlag          = "origin-server-name" | ||||
| 	NoTLSVerifyFlag               = "no-tls-verify" | ||||
| 	NoChunkedEncodingFlag         = "no-chunked-encoding" | ||||
| 	ProxyAddressFlag              = "proxy-address" | ||||
| 	ProxyPortFlag                 = "proxy-port" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	socksProxy = "socks" | ||||
| ) | ||||
| 
 | ||||
| func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig { | ||||
| 	var connectTimeout time.Duration = defaultConnectTimeout | ||||
| 	var tlsTimeout time.Duration = defaultTLSTimeout | ||||
| 	var tcpKeepAlive time.Duration = defaultTCPKeepAlive | ||||
| 	var noHappyEyeballs bool | ||||
| 	var keepAliveConnections int = defaultKeepAliveConnections | ||||
| 	var keepAliveTimeout time.Duration = defaultKeepAliveTimeout | ||||
| 	var httpHostHeader string | ||||
| 	var originServerName string | ||||
| 	var caPool string | ||||
| 	var noTLSVerify bool | ||||
| 	var disableChunkedEncoding bool | ||||
| 	var bastionMode bool | ||||
| 	var proxyAddress string | ||||
| 	var proxyPort uint | ||||
| 	var proxyType string | ||||
| 	if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) { | ||||
| 		connectTimeout = c.Duration(flag) | ||||
| 	} | ||||
| 	if flag := ProxyTLSTimeoutFlag; c.IsSet(flag) { | ||||
| 		tlsTimeout = c.Duration(flag) | ||||
| 	} | ||||
| 	if flag := ProxyTCPKeepAlive; c.IsSet(flag) { | ||||
| 		tcpKeepAlive = c.Duration(flag) | ||||
| 	} | ||||
| 	if flag := ProxyNoHappyEyeballsFlag; c.IsSet(flag) { | ||||
| 		noHappyEyeballs = c.Bool(flag) | ||||
| 	} | ||||
| 	if flag := ProxyKeepAliveConnectionsFlag; c.IsSet(flag) { | ||||
| 		keepAliveConnections = c.Int(flag) | ||||
| 	} | ||||
| 	if flag := ProxyKeepAliveTimeoutFlag; c.IsSet(flag) { | ||||
| 		keepAliveTimeout = c.Duration(flag) | ||||
| 	} | ||||
| 	if flag := HTTPHostHeaderFlag; c.IsSet(flag) { | ||||
| 		httpHostHeader = c.String(flag) | ||||
| 	} | ||||
| 	if flag := OriginServerNameFlag; c.IsSet(flag) { | ||||
| 		originServerName = c.String(flag) | ||||
| 	} | ||||
| 	if flag := tlsconfig.OriginCAPoolFlag; c.IsSet(flag) { | ||||
| 		caPool = c.String(flag) | ||||
| 	} | ||||
| 	if flag := NoTLSVerifyFlag; c.IsSet(flag) { | ||||
| 		noTLSVerify = c.Bool(flag) | ||||
| 	} | ||||
| 	if flag := NoChunkedEncodingFlag; c.IsSet(flag) { | ||||
| 		disableChunkedEncoding = c.Bool(flag) | ||||
| 	} | ||||
| 	if flag := config.BastionFlag; c.IsSet(flag) { | ||||
| 		bastionMode = c.Bool(flag) | ||||
| 	} | ||||
| 	if flag := ProxyAddressFlag; c.IsSet(flag) { | ||||
| 		proxyAddress = c.String(flag) | ||||
| 	} | ||||
| 	if flag := ProxyPortFlag; c.IsSet(flag) { | ||||
| 		proxyPort = c.Uint(flag) | ||||
| 	} | ||||
| 	if c.IsSet(Socks5Flag) { | ||||
| 		proxyType = socksProxy | ||||
| 	} | ||||
| 	return OriginRequestConfig{ | ||||
| 		ConnectTimeout:         connectTimeout, | ||||
| 		TLSTimeout:             tlsTimeout, | ||||
| 		TCPKeepAlive:           tcpKeepAlive, | ||||
| 		NoHappyEyeballs:        noHappyEyeballs, | ||||
| 		KeepAliveConnections:   keepAliveConnections, | ||||
| 		KeepAliveTimeout:       keepAliveTimeout, | ||||
| 		HTTPHostHeader:         httpHostHeader, | ||||
| 		OriginServerName:       originServerName, | ||||
| 		CAPool:                 caPool, | ||||
| 		NoTLSVerify:            noTLSVerify, | ||||
| 		DisableChunkedEncoding: disableChunkedEncoding, | ||||
| 		BastionMode:            bastionMode, | ||||
| 		ProxyAddress:           proxyAddress, | ||||
| 		ProxyPort:              proxyPort, | ||||
| 		ProxyType:              proxyType, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func OriginRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig { | ||||
| 	out := OriginRequestConfig{ | ||||
| 		ConnectTimeout:       defaultConnectTimeout, | ||||
| 		TLSTimeout:           defaultTLSTimeout, | ||||
| 		TCPKeepAlive:         defaultTCPKeepAlive, | ||||
| 		KeepAliveConnections: defaultKeepAliveConnections, | ||||
| 		KeepAliveTimeout:     defaultKeepAliveTimeout, | ||||
| 		ProxyAddress:         defaultProxyAddress, | ||||
| 	} | ||||
| 	if y.ConnectTimeout != nil { | ||||
| 		out.ConnectTimeout = *y.ConnectTimeout | ||||
| 	} | ||||
| 	if y.TLSTimeout != nil { | ||||
| 		out.TLSTimeout = *y.TLSTimeout | ||||
| 	} | ||||
| 	if y.TCPKeepAlive != nil { | ||||
| 		out.TCPKeepAlive = *y.TCPKeepAlive | ||||
| 	} | ||||
| 	if y.NoHappyEyeballs != nil { | ||||
| 		out.NoHappyEyeballs = *y.NoHappyEyeballs | ||||
| 	} | ||||
| 	if y.KeepAliveConnections != nil { | ||||
| 		out.KeepAliveConnections = *y.KeepAliveConnections | ||||
| 	} | ||||
| 	if y.KeepAliveTimeout != nil { | ||||
| 		out.KeepAliveTimeout = *y.KeepAliveTimeout | ||||
| 	} | ||||
| 	if y.HTTPHostHeader != nil { | ||||
| 		out.HTTPHostHeader = *y.HTTPHostHeader | ||||
| 	} | ||||
| 	if y.OriginServerName != nil { | ||||
| 		out.OriginServerName = *y.OriginServerName | ||||
| 	} | ||||
| 	if y.CAPool != nil { | ||||
| 		out.CAPool = *y.CAPool | ||||
| 	} | ||||
| 	if y.NoTLSVerify != nil { | ||||
| 		out.NoTLSVerify = *y.NoTLSVerify | ||||
| 	} | ||||
| 	if y.DisableChunkedEncoding != nil { | ||||
| 		out.DisableChunkedEncoding = *y.DisableChunkedEncoding | ||||
| 	} | ||||
| 	if y.BastionMode != nil { | ||||
| 		out.BastionMode = *y.BastionMode | ||||
| 	} | ||||
| 	if y.ProxyAddress != nil { | ||||
| 		out.ProxyAddress = *y.ProxyAddress | ||||
| 	} | ||||
| 	if y.ProxyPort != nil { | ||||
| 		out.ProxyPort = *y.ProxyPort | ||||
| 	} | ||||
| 	if y.ProxyType != nil { | ||||
| 		out.ProxyType = *y.ProxyType | ||||
| 	} | ||||
| 	return out | ||||
| } | ||||
| 
 | ||||
| // OriginRequestConfig configures how Cloudflared sends requests to origin
 | ||||
| // services.
 | ||||
| // Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h".
 | ||||
| type OriginRequestConfig struct { | ||||
| 	// HTTP proxy timeout for establishing a new connection
 | ||||
| 	ConnectTimeout time.Duration `yaml:"connectTimeout"` | ||||
| 	// HTTP proxy timeout for completing a TLS handshake
 | ||||
| 	TLSTimeout time.Duration `yaml:"tlsTimeout"` | ||||
| 	// HTTP proxy TCP keepalive duration
 | ||||
| 	TCPKeepAlive time.Duration `yaml:"tcpKeepAlive"` | ||||
| 	// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
 | ||||
| 	NoHappyEyeballs bool `yaml:"noHappyEyeballs"` | ||||
| 	// HTTP proxy maximum keepalive connection pool size
 | ||||
| 	KeepAliveConnections int `yaml:"keepAliveConnections"` | ||||
| 	// HTTP proxy timeout for closing an idle connection
 | ||||
| 	KeepAliveTimeout time.Duration `yaml:"keepAliveTimeout"` | ||||
| 	// Sets the HTTP Host header for the local webserver.
 | ||||
| 	HTTPHostHeader string `yaml:"httpHostHeader"` | ||||
| 	// Hostname on the origin server certificate.
 | ||||
| 	OriginServerName string `yaml:"originServerName"` | ||||
| 	// Path to the CA for the certificate of your origin.
 | ||||
| 	// This option should be used only if your certificate is not signed by Cloudflare.
 | ||||
| 	CAPool string `yaml:"caPool"` | ||||
| 	// Disables TLS verification of the certificate presented by your origin.
 | ||||
| 	// Will allow any certificate from the origin to be accepted.
 | ||||
| 	// Note: The connection from your machine to Cloudflare's Edge is still encrypted.
 | ||||
| 	NoTLSVerify bool `yaml:"noTLSVerify"` | ||||
| 	// Disables chunked transfer encoding.
 | ||||
| 	// Useful if you are running a WSGI server.
 | ||||
| 	DisableChunkedEncoding bool `yaml:"disableChunkedEncoding"` | ||||
| 	// Runs as jump host
 | ||||
| 	BastionMode bool `yaml:"bastionMode"` | ||||
| 	// Listen address for the proxy.
 | ||||
| 	ProxyAddress string `yaml:"proxyAddress"` | ||||
| 	// Listen port for the proxy.
 | ||||
| 	ProxyPort uint `yaml:"proxyPort"` | ||||
| 	// What sort of proxy should be started
 | ||||
| 	ProxyType string `yaml:"proxyType"` | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.ConnectTimeout; val != nil { | ||||
| 		defaults.ConnectTimeout = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setTLSTimeout(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.TLSTimeout; val != nil { | ||||
| 		defaults.TLSTimeout = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setNoHappyEyeballs(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.NoHappyEyeballs; val != nil { | ||||
| 		defaults.NoHappyEyeballs = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setKeepAliveConnections(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.KeepAliveConnections; val != nil { | ||||
| 		defaults.KeepAliveConnections = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setKeepAliveTimeout(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.KeepAliveTimeout; val != nil { | ||||
| 		defaults.KeepAliveTimeout = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setTCPKeepAlive(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.TCPKeepAlive; val != nil { | ||||
| 		defaults.TCPKeepAlive = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setHTTPHostHeader(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.HTTPHostHeader; val != nil { | ||||
| 		defaults.HTTPHostHeader = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setOriginServerName(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.OriginServerName; val != nil { | ||||
| 		defaults.OriginServerName = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setCAPool(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.CAPool; val != nil { | ||||
| 		defaults.CAPool = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setNoTLSVerify(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.NoTLSVerify; val != nil { | ||||
| 		defaults.NoTLSVerify = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setDisableChunkedEncoding(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.DisableChunkedEncoding; val != nil { | ||||
| 		defaults.DisableChunkedEncoding = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setBastionMode(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.BastionMode; val != nil { | ||||
| 		defaults.BastionMode = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setProxyPort(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.ProxyPort; val != nil { | ||||
| 		defaults.ProxyPort = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setProxyAddress(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.ProxyAddress; val != nil { | ||||
| 		defaults.ProxyAddress = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (defaults *OriginRequestConfig) setProxyType(overrides config.OriginRequestConfig) { | ||||
| 	if val := overrides.ProxyType; val != nil { | ||||
| 		defaults.ProxyType = *val | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SetConfig gets config for the requests that cloudflared sends to origins.
 | ||||
| // Each field has a setter method which sets a value for the field by trying to find:
 | ||||
| //   1. The user config for this rule
 | ||||
| //   2. The user config for the overall ingress config
 | ||||
| //   3. Defaults chosen by the cloudflared team
 | ||||
| //   4. Golang zero values for that type
 | ||||
| // If an earlier option isn't set, it will try the next option down.
 | ||||
| func SetConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfig) OriginRequestConfig { | ||||
| 	cfg := defaults | ||||
| 	cfg.setConnectTimeout(overrides) | ||||
| 	cfg.setTLSTimeout(overrides) | ||||
| 	cfg.setNoHappyEyeballs(overrides) | ||||
| 	cfg.setKeepAliveConnections(overrides) | ||||
| 	cfg.setKeepAliveTimeout(overrides) | ||||
| 	cfg.setTCPKeepAlive(overrides) | ||||
| 	cfg.setHTTPHostHeader(overrides) | ||||
| 	cfg.setOriginServerName(overrides) | ||||
| 	cfg.setCAPool(overrides) | ||||
| 	cfg.setNoTLSVerify(overrides) | ||||
| 	cfg.setDisableChunkedEncoding(overrides) | ||||
| 	cfg.setBastionMode(overrides) | ||||
| 	cfg.setProxyPort(overrides) | ||||
| 	cfg.setProxyAddress(overrides) | ||||
| 	cfg.setProxyType(overrides) | ||||
| 	return cfg | ||||
| } | ||||
|  | @ -0,0 +1,184 @@ | |||
| package ingress | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/config" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"gopkg.in/yaml.v2" | ||||
| ) | ||||
| 
 | ||||
| // Ensure that the nullable config from `config` package and the
 | ||||
| // non-nullable config from `ingress` package have the same number of
 | ||||
| // fields.
 | ||||
| // This test ensures that programmers didn't add a new field to
 | ||||
| // one struct and forget to add it to the other ;)
 | ||||
| func TestCorrespondingFields(t *testing.T) { | ||||
| 	require.Equal( | ||||
| 		t, | ||||
| 		CountFields(t, config.OriginRequestConfig{}), | ||||
| 		CountFields(t, OriginRequestConfig{}), | ||||
| 	) | ||||
| } | ||||
| 
 | ||||
| func CountFields(t *testing.T, val interface{}) int { | ||||
| 	b, err := yaml.Marshal(val) | ||||
| 	require.NoError(t, err) | ||||
| 	m := make(map[string]interface{}, 0) | ||||
| 	err = yaml.Unmarshal(b, &m) | ||||
| 	require.NoError(t, err) | ||||
| 	return len(m) | ||||
| } | ||||
| 
 | ||||
| func TestOriginRequestConfigOverrides(t *testing.T) { | ||||
| 	rulesYAML := ` | ||||
| originRequest: | ||||
|   connectTimeout: 1m | ||||
|   tlsTimeout: 1s | ||||
|   noHappyEyeballs: true | ||||
|   tcpKeepAlive: 1s | ||||
|   keepAliveConnections: 1 | ||||
|   keepAliveTimeout: 1s | ||||
|   httpHostHeader: abc | ||||
|   originServerName: a1 | ||||
|   caPool: /tmp/path0 | ||||
|   noTLSVerify: true | ||||
|   disableChunkedEncoding: true | ||||
|   bastionMode: True | ||||
|   proxyAddress: 127.1.2.3 | ||||
|   proxyPort: 100 | ||||
|   proxyType: socks5 | ||||
| ingress: | ||||
| - hostname: tun.example.com | ||||
|   service: https://localhost:8000
 | ||||
| - hostname: "*" | ||||
|   service: https://localhost:8001
 | ||||
|   originRequest: | ||||
|     connectTimeout: 2m | ||||
|     tlsTimeout: 2s | ||||
|     noHappyEyeballs: false | ||||
|     tcpKeepAlive: 2s | ||||
|     keepAliveConnections: 2 | ||||
|     keepAliveTimeout: 2s | ||||
|     httpHostHeader: def | ||||
|     originServerName: b2 | ||||
|     caPool: /tmp/path1 | ||||
|     noTLSVerify: false | ||||
|     disableChunkedEncoding: false | ||||
|     bastionMode: false | ||||
|     proxyAddress: interface | ||||
|     proxyPort: 200 | ||||
|     proxyType: "" | ||||
| ` | ||||
| 	ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	// Rule 0 didn't override anything, so it inherits the user-specified
 | ||||
| 	// root-level configuration.
 | ||||
| 	actual0 := ing.Rules[0].Config | ||||
| 	expected0 := OriginRequestConfig{ | ||||
| 		ConnectTimeout:         1 * time.Minute, | ||||
| 		TLSTimeout:             1 * time.Second, | ||||
| 		NoHappyEyeballs:        true, | ||||
| 		TCPKeepAlive:           1 * time.Second, | ||||
| 		KeepAliveConnections:   1, | ||||
| 		KeepAliveTimeout:       1 * time.Second, | ||||
| 		HTTPHostHeader:         "abc", | ||||
| 		OriginServerName:       "a1", | ||||
| 		CAPool:                 "/tmp/path0", | ||||
| 		NoTLSVerify:            true, | ||||
| 		DisableChunkedEncoding: true, | ||||
| 		BastionMode:            true, | ||||
| 		ProxyAddress:           "127.1.2.3", | ||||
| 		ProxyPort:              uint(100), | ||||
| 		ProxyType:              "socks5", | ||||
| 	} | ||||
| 	require.Equal(t, expected0, actual0) | ||||
| 
 | ||||
| 	// Rule 1 overrode all the root-level config.
 | ||||
| 	actual1 := ing.Rules[1].Config | ||||
| 	expected1 := OriginRequestConfig{ | ||||
| 		ConnectTimeout:         2 * time.Minute, | ||||
| 		TLSTimeout:             2 * time.Second, | ||||
| 		NoHappyEyeballs:        false, | ||||
| 		TCPKeepAlive:           2 * time.Second, | ||||
| 		KeepAliveConnections:   2, | ||||
| 		KeepAliveTimeout:       2 * time.Second, | ||||
| 		HTTPHostHeader:         "def", | ||||
| 		OriginServerName:       "b2", | ||||
| 		CAPool:                 "/tmp/path1", | ||||
| 		NoTLSVerify:            false, | ||||
| 		DisableChunkedEncoding: false, | ||||
| 		BastionMode:            false, | ||||
| 		ProxyAddress:           "interface", | ||||
| 		ProxyPort:              uint(200), | ||||
| 		ProxyType:              "", | ||||
| 	} | ||||
| 	require.Equal(t, expected1, actual1) | ||||
| } | ||||
| 
 | ||||
| func TestOriginRequestConfigDefaults(t *testing.T) { | ||||
| 	rulesYAML := ` | ||||
| ingress: | ||||
| - hostname: tun.example.com | ||||
|   service: https://localhost:8000
 | ||||
| - hostname: "*" | ||||
|   service: https://localhost:8001
 | ||||
|   originRequest: | ||||
|     connectTimeout: 2m | ||||
|     tlsTimeout: 2s | ||||
|     noHappyEyeballs: false | ||||
|     tcpKeepAlive: 2s | ||||
|     keepAliveConnections: 2 | ||||
|     keepAliveTimeout: 2s | ||||
|     httpHostHeader: def | ||||
|     originServerName: b2 | ||||
|     caPool: /tmp/path1 | ||||
|     noTLSVerify: false | ||||
|     disableChunkedEncoding: false | ||||
|     bastionMode: false | ||||
|     proxyAddress: interface | ||||
|     proxyPort: 200 | ||||
|     proxyType: "" | ||||
| ` | ||||
| 	ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| 
 | ||||
| 	// Rule 0 didn't override anything, so it inherits the cloudflared defaults
 | ||||
| 	actual0 := ing.Rules[0].Config | ||||
| 	expected0 := OriginRequestConfig{ | ||||
| 		ConnectTimeout:       defaultConnectTimeout, | ||||
| 		TLSTimeout:           defaultTLSTimeout, | ||||
| 		TCPKeepAlive:         defaultTCPKeepAlive, | ||||
| 		KeepAliveConnections: defaultKeepAliveConnections, | ||||
| 		KeepAliveTimeout:     defaultKeepAliveTimeout, | ||||
| 		ProxyAddress:         defaultProxyAddress, | ||||
| 	} | ||||
| 	require.Equal(t, expected0, actual0) | ||||
| 
 | ||||
| 	// Rule 1 overrode all defaults.
 | ||||
| 	actual1 := ing.Rules[1].Config | ||||
| 	expected1 := OriginRequestConfig{ | ||||
| 		ConnectTimeout:         2 * time.Minute, | ||||
| 		TLSTimeout:             2 * time.Second, | ||||
| 		NoHappyEyeballs:        false, | ||||
| 		TCPKeepAlive:           2 * time.Second, | ||||
| 		KeepAliveConnections:   2, | ||||
| 		KeepAliveTimeout:       2 * time.Second, | ||||
| 		HTTPHostHeader:         "def", | ||||
| 		OriginServerName:       "b2", | ||||
| 		CAPool:                 "/tmp/path1", | ||||
| 		NoTLSVerify:            false, | ||||
| 		DisableChunkedEncoding: false, | ||||
| 		BastionMode:            false, | ||||
| 		ProxyAddress:           "interface", | ||||
| 		ProxyPort:              uint(200), | ||||
| 		ProxyType:              "", | ||||
| 	} | ||||
| 	require.Equal(t, expected1, actual1) | ||||
| } | ||||
|  | @ -0,0 +1,181 @@ | |||
| package ingress | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/hello" | ||||
| 	"github.com/cloudflare/cloudflared/logger" | ||||
| 	"github.com/cloudflare/cloudflared/socks" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
| 
 | ||||
| // OriginService is something a tunnel can proxy traffic to.
 | ||||
| type OriginService interface { | ||||
| 	Address() string | ||||
| 	// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
 | ||||
| 	// If it's not managed by cloudflared, this is a no-op because the user is responsible for
 | ||||
| 	// starting the origin service.
 | ||||
| 	Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error | ||||
| 	String() string | ||||
| 	// RewriteOriginURL modifies the HTTP request from cloudflared to the origin, so that it apply
 | ||||
| 	// this particular type of origin service's specific routing logic.
 | ||||
| 	RewriteOriginURL(*url.URL) | ||||
| } | ||||
| 
 | ||||
| // UnixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
 | ||||
| type UnixSocketPath string | ||||
| 
 | ||||
| func (o UnixSocketPath) Address() string { | ||||
| 	return string(o) | ||||
| } | ||||
| 
 | ||||
| func (o UnixSocketPath) String() string { | ||||
| 	return "unix socket: " + string(o) | ||||
| } | ||||
| 
 | ||||
| func (o UnixSocketPath) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (o UnixSocketPath) RewriteOriginURL(u *url.URL) { | ||||
| 	// No changes necessary because the origin request URL isn't used.
 | ||||
| 	// Instead, HTTPTransport's dial is already configured to address the unix socket.
 | ||||
| } | ||||
| 
 | ||||
| // URL is an OriginService listening on a TCP address
 | ||||
| type URL struct { | ||||
| 	// The URL for the user's origin service
 | ||||
| 	RootURL *url.URL | ||||
| 	// The URL that cloudflared should send requests to.
 | ||||
| 	// If this origin requires starting a proxy, this is the proxy's address,
 | ||||
| 	// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
 | ||||
| 	URL *url.URL | ||||
| } | ||||
| 
 | ||||
| func (o *URL) Address() string { | ||||
| 	return o.URL.String() | ||||
| } | ||||
| 
 | ||||
| func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { | ||||
| 	staticHost := o.staticHost() | ||||
| 	if !originRequiresProxy(staticHost, cfg) { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	// Start a listener for the proxy
 | ||||
| 	proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort))) | ||||
| 	listener, err := net.Listen("tcp", proxyAddress) | ||||
| 	if err != nil { | ||||
| 		log.Errorf("Cannot start Websocket Proxy Server: %s", err) | ||||
| 		return errors.Wrap(err, "Cannot start Websocket Proxy Server") | ||||
| 	} | ||||
| 
 | ||||
| 	// Start the proxy itself
 | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		streamHandler := websocket.DefaultStreamHandler | ||||
| 		// This origin's config specifies what type of proxy to start.
 | ||||
| 		switch cfg.ProxyType { | ||||
| 		case socksProxy: | ||||
| 			log.Info("SOCKS5 server started") | ||||
| 			streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) { | ||||
| 				dialer := socks.NewConnDialer(remoteConn) | ||||
| 				requestHandler := socks.NewRequestHandler(dialer) | ||||
| 				socksServer := socks.NewConnectionHandler(requestHandler) | ||||
| 
 | ||||
| 				socksServer.Serve(wsConn) | ||||
| 			} | ||||
| 		case "": | ||||
| 			log.Debug("Not starting any websocket proxy") | ||||
| 		default: | ||||
| 			log.Errorf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy) | ||||
| 		} | ||||
| 
 | ||||
| 		errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler) | ||||
| 	}() | ||||
| 
 | ||||
| 	// Modify this origin, so that it no longer points at the origin service directly.
 | ||||
| 	// Instead, it points at the proxy to the origin service.
 | ||||
| 	newURL, err := url.Parse("http://" + listener.Addr().String()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	o.URL = newURL | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (o *URL) String() string { | ||||
| 	return o.Address() | ||||
| } | ||||
| 
 | ||||
| func (o *URL) RewriteOriginURL(u *url.URL) { | ||||
| 	u.Host = o.URL.Host | ||||
| 	u.Scheme = o.URL.Scheme | ||||
| } | ||||
| 
 | ||||
| func (o *URL) staticHost() string { | ||||
| 
 | ||||
| 	addPortIfMissing := func(uri *url.URL, port int) string { | ||||
| 		if uri.Port() != "" { | ||||
| 			return uri.Host | ||||
| 		} | ||||
| 		return fmt.Sprintf("%s:%d", uri.Hostname(), port) | ||||
| 	} | ||||
| 
 | ||||
| 	switch o.URL.Scheme { | ||||
| 	case "ssh": | ||||
| 		return addPortIfMissing(o.URL, 22) | ||||
| 	case "rdp": | ||||
| 		return addPortIfMissing(o.URL, 3389) | ||||
| 	case "smb": | ||||
| 		return addPortIfMissing(o.URL, 445) | ||||
| 	case "tcp": | ||||
| 		return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case
 | ||||
| 	} | ||||
| 	return "" | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| // HelloWorld is the built-in Hello World service. Used for testing and experimenting with cloudflared.
 | ||||
| type HelloWorld struct { | ||||
| 	server net.Listener | ||||
| } | ||||
| 
 | ||||
| func (o *HelloWorld) Address() string { | ||||
| 	return o.server.Addr().String() | ||||
| } | ||||
| 
 | ||||
| func (o *HelloWorld) String() string { | ||||
| 	return "Hello World static HTML service" | ||||
| } | ||||
| 
 | ||||
| // Start starts a HelloWorld server and stores its address in the Service receiver.
 | ||||
| func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { | ||||
| 	helloListener, err := hello.CreateTLSListener("127.0.0.1:") | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "Cannot start Hello World Server") | ||||
| 	} | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		_ = hello.StartHelloWorldServer(log, helloListener, shutdownC) | ||||
| 	}() | ||||
| 	o.server = helloListener | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (o *HelloWorld) RewriteOriginURL(u *url.URL) { | ||||
| 	u.Host = o.Address() | ||||
| 	u.Scheme = "https" | ||||
| } | ||||
| 
 | ||||
| func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool { | ||||
| 	return staticHost != "" || cfg.BastionMode | ||||
| } | ||||
|  | @ -0,0 +1,57 @@ | |||
| package ingress | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // Rule routes traffic from a hostname/path on the public internet to the
 | ||||
| // service running on the given URL.
 | ||||
| type Rule struct { | ||||
| 	// Requests for this hostname will be proxied to this rule's service.
 | ||||
| 	Hostname string | ||||
| 
 | ||||
| 	// Path is an optional regex that can specify path-driven ingress rules.
 | ||||
| 	Path *regexp.Regexp | ||||
| 
 | ||||
| 	// A (probably local) address. Requests for a hostname which matches this
 | ||||
| 	// rule's hostname pattern will be proxied to the service running on this
 | ||||
| 	// address.
 | ||||
| 	Service OriginService | ||||
| 
 | ||||
| 	// Configure the request cloudflared sends to this specific origin.
 | ||||
| 	Config OriginRequestConfig | ||||
| 
 | ||||
| 	// Configures TLS for the cloudflared -> origin request
 | ||||
| 	ClientTLSConfig *tls.Config | ||||
| 	// Configures HTTP for the cloudflared -> origin request
 | ||||
| 	HTTPTransport http.RoundTripper | ||||
| } | ||||
| 
 | ||||
| // MultiLineString is for outputting rules in a human-friendly way when Cloudflared
 | ||||
| // is used as a CLI tool (not as a daemon).
 | ||||
| func (r Rule) MultiLineString() string { | ||||
| 	var out strings.Builder | ||||
| 	if r.Hostname != "" { | ||||
| 		out.WriteString("\thostname: ") | ||||
| 		out.WriteString(r.Hostname) | ||||
| 		out.WriteRune('\n') | ||||
| 	} | ||||
| 	if r.Path != nil { | ||||
| 		out.WriteString("\tpath: ") | ||||
| 		out.WriteString(r.Path.String()) | ||||
| 		out.WriteRune('\n') | ||||
| 	} | ||||
| 	out.WriteString("\tservice: ") | ||||
| 	out.WriteString(r.Service.String()) | ||||
| 	return out.String() | ||||
| } | ||||
| 
 | ||||
| // Matches checks if the rule matches a given hostname/path combination.
 | ||||
| func (r *Rule) Matches(hostname, path string) bool { | ||||
| 	hostMatch := r.Hostname == "" || r.Hostname == "*" || matchHost(r.Hostname, hostname) | ||||
| 	pathMatch := r.Path == nil || r.Path.MatchString(path) | ||||
| 	return hostMatch && pathMatch | ||||
| } | ||||
|  | @ -0,0 +1,119 @@ | |||
| package ingress | ||||
| 
 | ||||
| import ( | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"testing" | ||||
| ) | ||||
| 
 | ||||
| func Test_rule_matches(t *testing.T) { | ||||
| 	type fields struct { | ||||
| 		Hostname string | ||||
| 		Path     *regexp.Regexp | ||||
| 		Service  OriginService | ||||
| 	} | ||||
| 	type args struct { | ||||
| 		requestURL *url.URL | ||||
| 	} | ||||
| 	tests := []struct { | ||||
| 		name   string | ||||
| 		fields fields | ||||
| 		args   args | ||||
| 		want   bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "Just hostname, pass", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://example.com"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Entire hostname is wildcard, should match everything", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://example.com"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Just hostname, fail", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://foo.bar"), | ||||
| 			}, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Just wildcard hostname, pass", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*.example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://adam.example.com"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Just wildcard hostname, fail", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*.example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://tunnel.com"), | ||||
| 			}, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Just wildcard outside of subdomain in hostname, fail", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://www.example.com"), | ||||
| 			}, | ||||
| 			want: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Wildcard over multiple subdomains", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*.example.com", | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://adam.chalmers.example.com"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "Hostname and path", | ||||
| 			fields: fields{ | ||||
| 				Hostname: "*.example.com", | ||||
| 				Path:     regexp.MustCompile("/static/.*\\.html"), | ||||
| 			}, | ||||
| 			args: args{ | ||||
| 				requestURL: MustParseURL(t, "https://www.example.com/static/index.html"), | ||||
| 			}, | ||||
| 			want: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			r := Rule{ | ||||
| 				Hostname: tt.fields.Hostname, | ||||
| 				Path:     tt.fields.Path, | ||||
| 				Service:  tt.fields.Service, | ||||
| 			} | ||||
| 			u := tt.args.requestURL | ||||
| 			if got := r.Matches(u.Hostname(), u.Path); got != tt.want { | ||||
| 				t.Errorf("rule.matches() = %v, want %v", got, tt.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | @ -30,7 +30,6 @@ import ( | |||
| 	"github.com/cloudflare/cloudflared/tunnelrpc" | ||||
| 	"github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||
| 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||
| 	"github.com/cloudflare/cloudflared/validation" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| ) | ||||
| 
 | ||||
|  | @ -57,16 +56,13 @@ const ( | |||
| type TunnelConfig struct { | ||||
| 	BuildInfo          *buildinfo.BuildInfo | ||||
| 	ClientID           string | ||||
| 	ClientTlsConfig    *tls.Config | ||||
| 	CloseConnOnce      *sync.Once // Used to close connectedSignal no more than once
 | ||||
| 	CompressionQuality uint64 | ||||
| 	EdgeAddrs          []string | ||||
| 	GracePeriod        time.Duration | ||||
| 	HAConnections      int | ||||
| 	HTTPTransport      http.RoundTripper | ||||
| 	HeartbeatInterval  time.Duration | ||||
| 	Hostname           string | ||||
| 	HTTPHostHeader     string | ||||
| 	IncidentLookup     IncidentLookup | ||||
| 	IsAutoupdated      bool | ||||
| 	IsFreeTunnel       bool | ||||
|  | @ -76,7 +72,6 @@ type TunnelConfig struct { | |||
| 	MaxHeartbeats      uint64 | ||||
| 	Metrics            *TunnelMetrics | ||||
| 	MetricsUpdateFreq  time.Duration | ||||
| 	NoChunkedEncoding  bool | ||||
| 	OriginCert         []byte | ||||
| 	ReportedVersion    string | ||||
| 	Retries            uint | ||||
|  | @ -84,8 +79,6 @@ type TunnelConfig struct { | |||
| 	Tags               []tunnelpogs.Tag | ||||
| 	TlsConfig          *tls.Config | ||||
| 	WSGI               bool | ||||
| 	// OriginUrl may not be used if a user specifies a unix socket.
 | ||||
| 	OriginUrl string | ||||
| 
 | ||||
| 	// feature-flag to use new edge reconnect tokens
 | ||||
| 	UseReconnectToken bool | ||||
|  | @ -618,18 +611,13 @@ func LogServerInfo( | |||
| } | ||||
| 
 | ||||
| type TunnelHandler struct { | ||||
| 	originUrl      string | ||||
| 	ingressRules ingress.Ingress | ||||
| 	httpHostHeader string | ||||
| 	muxer        *h2mux.Muxer | ||||
| 	httpClient     http.RoundTripper | ||||
| 	tlsConfig      *tls.Config | ||||
| 	tags         []tunnelpogs.Tag | ||||
| 	metrics      *TunnelMetrics | ||||
| 	// connectionID is only used by metrics, and prometheus requires labels to be string
 | ||||
| 	connectionID string | ||||
| 	logger       logger.Service | ||||
| 	noChunkedEncoding bool | ||||
| 
 | ||||
| 	bufferPool *buffer.Pool | ||||
| } | ||||
|  | @ -642,32 +630,14 @@ func NewTunnelHandler(ctx context.Context, | |||
| 	bufferPool *buffer.Pool, | ||||
| ) (*TunnelHandler, string, error) { | ||||
| 
 | ||||
| 	// Check single-origin config
 | ||||
| 	var originURL string | ||||
| 	var err error | ||||
| 	if config.IngressRules.IsEmpty() { | ||||
| 		originURL, err = validation.ValidateUrl(config.OriginUrl) | ||||
| 		if err != nil { | ||||
| 			return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	h := &TunnelHandler{ | ||||
| 		originUrl:         originURL, | ||||
| 		ingressRules: config.IngressRules, | ||||
| 		httpHostHeader:    config.HTTPHostHeader, | ||||
| 		httpClient:        config.HTTPTransport, | ||||
| 		tlsConfig:         config.ClientTlsConfig, | ||||
| 		tags:         config.Tags, | ||||
| 		metrics:      config.Metrics, | ||||
| 		connectionID: uint8ToString(connectionID), | ||||
| 		logger:       config.Logger, | ||||
| 		noChunkedEncoding: config.NoChunkedEncoding, | ||||
| 		bufferPool:   bufferPool, | ||||
| 	} | ||||
| 	if h.httpClient == nil { | ||||
| 		h.httpClient = http.DefaultTransport | ||||
| 	} | ||||
| 
 | ||||
| 	edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) | ||||
| 	if err != nil { | ||||
|  | @ -692,7 +662,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { | |||
| 	h.metrics.incrementRequests(h.connectionID) | ||||
| 	defer h.metrics.decrementConcurrentRequests(h.connectionID) | ||||
| 
 | ||||
| 	req, reqErr := h.createRequest(stream) | ||||
| 	req, rule, reqErr := h.createRequest(stream) | ||||
| 	if reqErr != nil { | ||||
| 		h.writeErrorResponse(stream, reqErr) | ||||
| 		return reqErr | ||||
|  | @ -705,9 +675,9 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { | |||
| 	var resp *http.Response | ||||
| 	var respErr error | ||||
| 	if websocket.IsWebSocketUpgrade(req) { | ||||
| 		resp, respErr = h.serveWebsocket(stream, req) | ||||
| 		resp, respErr = h.serveWebsocket(stream, req, rule) | ||||
| 	} else { | ||||
| 		resp, respErr = h.serveHTTP(stream, req) | ||||
| 		resp, respErr = h.serveHTTP(stream, req, rule) | ||||
| 	} | ||||
| 	if respErr != nil { | ||||
| 		h.writeErrorResponse(stream, respErr) | ||||
|  | @ -717,32 +687,28 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) { | ||||
| 	req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) | ||||
| func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) { | ||||
| 	req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream}) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") | ||||
| 		return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest") | ||||
| 	} | ||||
| 	err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "invalid request received") | ||||
| 		return nil, nil, errors.Wrap(err, "invalid request received") | ||||
| 	} | ||||
| 	h.AppendTagHeaders(req) | ||||
| 	if !h.ingressRules.IsEmpty() { | ||||
| 		ruleNumber := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) | ||||
| 		destination := h.ingressRules.Rules[ruleNumber].Service | ||||
| 		req.URL.Host = destination.Host | ||||
| 		req.URL.Scheme = destination.Scheme | ||||
| 	} | ||||
| 	return req, nil | ||||
| 	rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) | ||||
| 	rule.Service.RewriteOriginURL(req.URL) | ||||
| 	return req, rule, nil | ||||
| } | ||||
| 
 | ||||
| func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { | ||||
| 	if h.httpHostHeader != "" { | ||||
| 		req.Header.Set("Host", h.httpHostHeader) | ||||
| 		req.Host = h.httpHostHeader | ||||
| func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { | ||||
| 	if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { | ||||
| 		req.Header.Set("Host", hostHeader) | ||||
| 		req.Host = hostHeader | ||||
| 	} | ||||
| 
 | ||||
| 	conn, response, err := websocket.ClientConnect(req, h.tlsConfig) | ||||
| 	conn, response, err := websocket.ClientConnect(req, rule.ClientTLSConfig) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -758,9 +724,9 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ | |||
| 	return response, nil | ||||
| } | ||||
| 
 | ||||
| func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { | ||||
| func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { | ||||
| 	// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
 | ||||
| 	if h.noChunkedEncoding { | ||||
| 	if rule.Config.DisableChunkedEncoding { | ||||
| 		req.TransferEncoding = []string{"gzip", "deflate"} | ||||
| 		cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) | ||||
| 		if err == nil { | ||||
|  | @ -771,12 +737,12 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) | |||
| 	// Request origin to keep connection alive to improve performance
 | ||||
| 	req.Header.Set("Connection", "keep-alive") | ||||
| 
 | ||||
| 	if h.httpHostHeader != "" { | ||||
| 		req.Header.Set("Host", h.httpHostHeader) | ||||
| 		req.Host = h.httpHostHeader | ||||
| 	if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { | ||||
| 		req.Header.Set("Host", hostHeader) | ||||
| 		req.Host = hostHeader | ||||
| 	} | ||||
| 
 | ||||
| 	response, err := h.httpClient.RoundTrip(req) | ||||
| 	response, err := rule.HTTPTransport.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error proxying request to origin") | ||||
| 	} | ||||
|  |  | |||
|  | @ -65,10 +65,9 @@ func (cr *CertReloader) LoadCert() error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func LoadOriginCA(c *cli.Context, logger logger.Service) (*x509.CertPool, error) { | ||||
| func LoadOriginCA(originCAPoolFilename string, logger logger.Service) (*x509.CertPool, error) { | ||||
| 	var originCustomCAPool []byte | ||||
| 
 | ||||
| 	originCAPoolFilename := c.String(OriginCAPoolFlag) | ||||
| 	if originCAPoolFilename != "" { | ||||
| 		var err error | ||||
| 		originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue