From d01770107e5001176a1cd658896acfc9bfbd2834 Mon Sep 17 00:00:00 2001 From: Adam Chalmers Date: Fri, 30 Oct 2020 16:37:40 -0500 Subject: [PATCH] TUN-3492: Refactor OriginService, shrink its interface --- cmd/cloudflared/tunnel/cmd.go | 2 +- cmd/cloudflared/tunnel/configuration.go | 2 +- cmd/cloudflared/tunnel/ingress_subcommands.go | 4 +- ingress/ingress.go | 121 ++---------- ingress/ingress_test.go | 45 +++-- ingress/origin_request_config.go | 4 +- ingress/origin_request_config_test.go | 4 +- ingress/origin_service.go | 176 +++++++++++++----- ingress/rule.go | 7 - origin/tunnel.go | 9 +- websocket/websocket.go | 22 ++- websocket/websocket_test.go | 3 +- 12 files changed, 214 insertions(+), 185 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 03232378..434bc9d9 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -395,7 +395,7 @@ func StartServer( hostname, metricsListener.Addr().String(), // TODO (TUN-3461): Update UI to show multiple origin URLs - tunnelConfig.IngressRules.CatchAll().Service.Address(), + tunnelConfig.IngressRules.CatchAll().Service.String(), tunnelConfig.HAConnections, ) logLevels, err := logger.ParseLevelString(c.String("loglevel")) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 7fb6c24f..bb033835 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -202,7 +202,7 @@ func prepareTunnelConfig( Version: version, Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch), } - ingressRules, err = ingress.ParseIngress(config.GetConfiguration(), logger) + ingressRules, err = ingress.ParseIngress(config.GetConfiguration()) if err != nil && err != ingress.ErrNoIngressRules { return nil, err } diff --git a/cmd/cloudflared/tunnel/ingress_subcommands.go b/cmd/cloudflared/tunnel/ingress_subcommands.go index ec03c2d0..de3b1350 100644 --- a/cmd/cloudflared/tunnel/ingress_subcommands.go +++ b/cmd/cloudflared/tunnel/ingress_subcommands.go @@ -71,7 +71,7 @@ func buildTestURLCommand() *cli.Command { func validateIngressCommand(c *cli.Context) error { conf := config.GetConfiguration() fmt.Println("Validating rules from", conf.Source()) - if _, err := ingress.ParseIngressDryRun(conf); err != nil { + if _, err := ingress.ParseIngress(conf); err != nil { return errors.Wrap(err, "Validation failed") } if c.IsSet("url") { @@ -98,7 +98,7 @@ func testURLCommand(c *cli.Context) error { conf := config.GetConfiguration() fmt.Println("Using rules from", conf.Source()) - ing, err := ingress.ParseIngressDryRun(conf) + ing, err := ingress.ParseIngress(conf) if err != nil { return errors.Wrap(err, "Validation failed") } diff --git a/ingress/ingress.go b/ingress/ingress.go index a1c16ec4..614a6b5a 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -1,24 +1,17 @@ package ingress import ( - "context" - "crypto/tls" "fmt" - "net" - "net/http" "net/url" "regexp" "strings" "sync" - "time" "github.com/pkg/errors" "github.com/urfave/cli/v2" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/logger" - "github.com/cloudflare/cloudflared/tlsconfig" - "github.com/cloudflare/cloudflared/validation" ) var ( @@ -28,82 +21,6 @@ var ( ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules") ) -// Finalize the rules by adding missing struct fields and validating each origin. -func (ing *Ingress) setHTTPTransport(logger logger.Service) error { - for ruleNumber, rule := range ing.Rules { - cfg := rule.Config - originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil) - if err != nil { - return errors.Wrap(err, "Error loading cert pool") - } - - httpTransport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - MaxIdleConns: cfg.KeepAliveConnections, - MaxIdleConnsPerHost: cfg.KeepAliveConnections, - IdleConnTimeout: cfg.KeepAliveTimeout, - TLSHandshakeTimeout: cfg.TLSTimeout, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify}, - } - if _, isHelloWorld := rule.Service.(*HelloWorld); !isHelloWorld && cfg.OriginServerName != "" { - httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName - } - - dialer := &net.Dialer{ - Timeout: cfg.ConnectTimeout, - KeepAlive: cfg.TCPKeepAlive, - } - if cfg.NoHappyEyeballs { - dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs" - } - - // DialContext depends on which kind of origin is being used. - dialContext := dialer.DialContext - switch service := rule.Service.(type) { - - // If this origin is a unix socket, enforce network type "unix". - case UnixSocketPath: - httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialContext(ctx, "unix", service.Address()) - } - // Otherwise, use the regular network config. - default: - httpTransport.DialContext = dialContext - } - - ing.Rules[ruleNumber].HTTPTransport = httpTransport - ing.Rules[ruleNumber].ClientTLSConfig = httpTransport.TLSClientConfig - } - - // Validate each origin - for _, rule := range ing.Rules { - // If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client. - if rule.Config.BastionMode { - continue - } - - // Unix sockets don't have validation - if _, ok := rule.Service.(UnixSocketPath); ok { - continue - } - switch service := rule.Service.(type) { - - case UnixSocketPath: - continue - - case *HelloWorld: - continue - - default: - if err := validation.ValidateHTTPService(service.Address(), rule.Hostname, rule.HTTPTransport); err != nil { - logger.Errorf("unable to connect to the origin: %s", err) - } - } - } - return nil -} - // FindMatchingRule returns the index of the Ingress Rule which matches the given // hostname and path. This function assumes the last rule matches everything, // which is the case if the rules were instantiated via the ingress#Validate method @@ -154,14 +71,13 @@ func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Servi }, defaults: originRequestFromSingeRule(c), } - err = ing.setHTTPTransport(logger) return ing, err } // Get a single origin service from the CLI/config. func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) { if c.IsSet("hello-world") { - return new(HelloWorld), nil + return new(helloWorld), nil } if c.IsSet("url") { originURLStr, err := config.ValidateUrl(c, compatibilityMode) @@ -172,14 +88,14 @@ func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginSer if err != nil { return nil, errors.Wrap(err, "couldn't parse origin URL") } - return &URL{URL: originURL, RootURL: originURL}, nil + return &localService{URL: originURL, RootURL: originURL}, nil } if c.IsSet("unix-socket") { - unixSocket, err := config.ValidateUnixSocket(c) + path, err := config.ValidateUnixSocket(c) if err != nil { return nil, errors.Wrap(err, "Error validating --unix-socket") } - return UnixSocketPath(unixSocket), nil + return &unixSocketPath{path: path}, nil } return nil, errors.New("You must either set ingress rules in your config file, or use --url or use --unix-socket") } @@ -192,7 +108,7 @@ func (ing Ingress) IsEmpty() bool { // StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World. func (ing Ingress) StartOrigins(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error) error { for _, rule := range ing.Rules { - if err := rule.Service.Start(wg, log, shutdownC, errC, rule.Config); err != nil { + if err := rule.Service.start(wg, log, shutdownC, errC, rule.Config); err != nil { return err } } @@ -209,11 +125,12 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon for i, r := range ingress { var service OriginService - if strings.HasPrefix(r.Service, "unix:") { + if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) { // No validation necessary for unix socket filepath services - service = UnixSocketPath(strings.TrimPrefix(r.Service, "unix:")) + path := strings.TrimPrefix(r.Service, prefix) + service = &unixSocketPath{path: path} } else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" { - service = new(HelloWorld) + service = new(helloWorld) } else { // Validate URL services u, err := url.Parse(r.Service) @@ -228,7 +145,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon if u.Path != "" { return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service) } - serviceURL := URL{URL: u} + serviceURL := localService{URL: u} service = &serviceURL } @@ -262,7 +179,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon Hostname: r.Hostname, Service: service, Path: pathRegex, - Config: SetConfig(defaults, r.OriginRequest), + Config: setConfig(defaults, r.OriginRequest), } } return Ingress{Rules: rules, defaults: defaults}, nil @@ -279,20 +196,10 @@ func (e errRuleShouldNotBeCatchAll) Error() string { "will never be triggered.", e.i+1, e.hostname) } -// ParseIngress parses, validates and initializes HTTP transports to each origin. -func ParseIngress(conf *config.Configuration, logger logger.Service) (Ingress, error) { - ing, err := ParseIngressDryRun(conf) - if err != nil { - return Ingress{}, err - } - err = ing.setHTTPTransport(logger) - return ing, err -} - -// ParseIngressDryRun parses ingress rules, but does not send HTTP requests to the origins. -func ParseIngressDryRun(conf *config.Configuration) (Ingress, error) { +// ParseIngress parses ingress rules, but does not send HTTP requests to the origins. +func ParseIngress(conf *config.Configuration) (Ingress, error) { if len(conf.Ingress) == 0 { return Ingress{}, ErrNoIngressRules } - return validate(conf.Ingress, OriginRequestFromYAML(conf.OriginRequest)) + return validate(conf.Ingress, originRequestFromYAML(conf.OriginRequest)) } diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 5a426bf5..7c98bbb3 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -4,7 +4,6 @@ import ( "net/url" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" @@ -16,16 +15,16 @@ func TestParseUnixSocket(t *testing.T) { ingress: - service: unix:/tmp/echo.sock ` - ing, err := ParseIngressDryRun(MustReadIngress(rawYAML)) + ing, err := ParseIngress(MustReadIngress(rawYAML)) require.NoError(t, err) - _, ok := ing.Rules[0].Service.(UnixSocketPath) + _, ok := ing.Rules[0].Service.(*unixSocketPath) require.True(t, ok) } func Test_parseIngress(t *testing.T) { localhost8000 := MustParseURL(t, "https://localhost:8000") localhost8001 := MustParseURL(t, "https://localhost:8001") - defaultConfig := SetConfig(OriginRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{}) + defaultConfig := setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{}) require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections) type args struct { rawYAML string @@ -53,12 +52,12 @@ ingress: want: []Rule{ { Hostname: "tunnel1.example.com", - Service: &URL{URL: localhost8000}, + Service: &localService{URL: localhost8000}, Config: defaultConfig, }, { Hostname: "*", - Service: &URL{URL: localhost8001}, + Service: &localService{URL: localhost8001}, Config: defaultConfig, }, }, @@ -74,7 +73,7 @@ extraKey: extraValue want: []Rule{ { Hostname: "*", - Service: &URL{URL: localhost8000}, + Service: &localService{URL: localhost8000}, Config: defaultConfig, }, }, @@ -87,7 +86,7 @@ ingress: `}, want: []Rule{ { - Service: &URL{URL: localhost8000}, + Service: &localService{URL: localhost8000}, Config: defaultConfig, }, }, @@ -165,15 +164,37 @@ ingress: `}, wantErr: true, }, + { + name: "Invalid HTTP status", + args: args{rawYAML: ` +ingress: + - service: http_status:asdf +`}, + wantErr: true, + }, + { + name: "Valid hello world service", + args: args{rawYAML: ` +ingress: + - service: hello_world +`}, + want: []Rule{ + { + Hostname: "", + Service: new(helloWorld), + Config: defaultConfig, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ParseIngressDryRun(MustReadIngress(tt.args.rawYAML)) + got, err := ParseIngress(MustReadIngress(tt.args.rawYAML)) if (err != nil) != tt.wantErr { - t.Errorf("ParseIngressDryRun() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("ParseIngress() error = %v, wantErr %v", err, tt.wantErr) return } - assert.Equal(t, tt.want, got.Rules) + require.Equal(t, tt.want, got.Rules) }) } } @@ -195,7 +216,7 @@ ingress: service: https://localhost:8002 ` - ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) + ing, err := ParseIngress(MustReadIngress(rulesYAML)) if err != nil { b.Error(err) } diff --git a/ingress/origin_request_config.go b/ingress/origin_request_config.go index 03e0fcab..3180fd80 100644 --- a/ingress/origin_request_config.go +++ b/ingress/origin_request_config.go @@ -116,7 +116,7 @@ func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig { } } -func OriginRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig { +func originRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig { out := OriginRequestConfig{ ConnectTimeout: defaultConnectTimeout, TLSTimeout: defaultTLSTimeout, @@ -310,7 +310,7 @@ func (defaults *OriginRequestConfig) setProxyType(overrides config.OriginRequest // 3. Defaults chosen by the cloudflared team // 4. Golang zero values for that type // If an earlier option isn't set, it will try the next option down. -func SetConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfig) OriginRequestConfig { +func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfig) OriginRequestConfig { cfg := defaults cfg.setConnectTimeout(overrides) cfg.setTLSTimeout(overrides) diff --git a/ingress/origin_request_config_test.go b/ingress/origin_request_config_test.go index 4b874bff..ab33896f 100644 --- a/ingress/origin_request_config_test.go +++ b/ingress/origin_request_config_test.go @@ -71,7 +71,7 @@ ingress: proxyPort: 200 proxyType: "" ` - ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) + ing, err := ParseIngress(MustReadIngress(rulesYAML)) if err != nil { t.Error(err) } @@ -144,7 +144,7 @@ ingress: proxyPort: 200 proxyType: "" ` - ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML)) + ing, err := ParseIngress(MustReadIngress(rulesYAML)) if err != nil { t.Error(err) } diff --git a/ingress/origin_service.go b/ingress/origin_service.go index 89b4a57e..7ac32bf4 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -1,72 +1,103 @@ package ingress import ( + "context" + "crypto/tls" "fmt" "net" "net/http" "net/url" "strconv" "sync" + "time" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/socks" + "github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/websocket" + gws "github.com/gorilla/websocket" "github.com/pkg/errors" ) // OriginService is something a tunnel can proxy traffic to. type OriginService interface { - Address() string + // RoundTrip is how cloudflared proxies eyeball requests to the actual origin services + http.RoundTripper + String() string // Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World. // If it's not managed by cloudflared, this is a no-op because the user is responsible for // starting the origin service. - Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error - String() string - // RewriteOriginURL modifies the HTTP request from cloudflared to the origin, so that it apply - // this particular type of origin service's specific routing logic. - RewriteOriginURL(*url.URL) + start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error } -// UnixSocketPath is an OriginService representing a unix socket (which accepts HTTP) -type UnixSocketPath string - -func (o UnixSocketPath) Address() string { - return string(o) +// unixSocketPath is an OriginService representing a unix socket (which accepts HTTP) +type unixSocketPath struct { + path string + transport *http.Transport } -func (o UnixSocketPath) String() string { - return "unix socket: " + string(o) +func (o *unixSocketPath) String() string { + return "unix socket: " + o.path } -func (o UnixSocketPath) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *unixSocketPath) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + transport, err := newHTTPTransport(o, cfg) + if err != nil { + return err + } + o.transport = transport return nil } -func (o UnixSocketPath) RewriteOriginURL(u *url.URL) { - // No changes necessary because the origin request URL isn't used. - // Instead, HTTPTransport's dial is already configured to address the unix socket. +func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { + return o.transport.RoundTrip(req) } -// URL is an OriginService listening on a TCP address -type URL struct { +func (o *unixSocketPath) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { + d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} + return d.Dial(url, headers) +} + +// localService is an OriginService listening on a TCP/IP address the user's origin can route to. +type localService struct { // The URL for the user's origin service RootURL *url.URL // The URL that cloudflared should send requests to. // If this origin requires starting a proxy, this is the proxy's address, // and that proxy points to RootURL. Otherwise, this is equal to RootURL. - URL *url.URL + URL *url.URL + transport *http.Transport } -func (o *URL) Address() string { +func (o *localService) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { + d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} + return d.Dial(url, headers) +} + +func (o *localService) address() string { return o.URL.String() } -func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { - staticHost := o.staticHost() - if !originRequiresProxy(staticHost, cfg) { - return nil +func (o *localService) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + transport, err := newHTTPTransport(o, cfg) + if err != nil { + return err } + o.transport = transport + + // Start a proxy if one is needed + staticHost := o.staticHost() + if originRequiresProxy(staticHost, cfg) { + if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil { + return err + } + } + + return nil +} + +func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { // Start a listener for the proxy proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort))) @@ -111,16 +142,18 @@ func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan str return nil } -func (o *URL) String() string { - return o.Address() +func (o *localService) String() string { + return o.address() } -func (o *URL) RewriteOriginURL(u *url.URL) { - u.Host = o.URL.Host - u.Scheme = o.URL.Scheme +func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) { + // Rewrite the request URL so that it goes to the origin service. + req.URL.Host = o.URL.Host + req.URL.Scheme = o.URL.Scheme + return o.transport.RoundTrip(req) } -func (o *URL) staticHost() string { +func (o *localService) staticHost() string { addPortIfMissing := func(uri *url.URL, port int) string { if uri.Port() != "" { @@ -143,21 +176,24 @@ func (o *URL) staticHost() string { } -// HelloWorld is the built-in Hello World service. Used for testing and experimenting with cloudflared. -type HelloWorld struct { - server net.Listener +// HelloWorld is an OriginService for the built-in Hello World server. +// Users only use this for testing and experimenting with cloudflared. +type helloWorld struct { + server net.Listener + transport *http.Transport } -func (o *HelloWorld) Address() string { - return o.server.Addr().String() -} - -func (o *HelloWorld) String() string { - return "Hello World static HTML service" +func (o *helloWorld) String() string { + return "Hello World test origin" } // Start starts a HelloWorld server and stores its address in the Service receiver. -func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *helloWorld) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + transport, err := newHTTPTransport(o, cfg) + if err != nil { + return err + } + o.transport = transport helloListener, err := hello.CreateTLSListener("127.0.0.1:") if err != nil { return errors.Wrap(err, "Cannot start Hello World Server") @@ -171,11 +207,63 @@ func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-c return nil } -func (o *HelloWorld) RewriteOriginURL(u *url.URL) { - u.Host = o.Address() - u.Scheme = "https" +func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { + // Rewrite the request URL so that it goes to the Hello World server. + req.URL.Host = o.server.Addr().String() + req.URL.Scheme = "https" + return o.transport.RoundTrip(req) +} + +func (o *helloWorld) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { + d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} + return d.Dial(url, headers) } func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool { return staticHost != "" || cfg.BastionMode } + +func newHTTPTransport(service OriginService, cfg OriginRequestConfig) (*http.Transport, error) { + originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil) + if err != nil { + return nil, errors.Wrap(err, "Error loading cert pool") + } + + httpTransport := http.Transport{ + Proxy: http.ProxyFromEnvironment, + MaxIdleConns: cfg.KeepAliveConnections, + MaxIdleConnsPerHost: cfg.KeepAliveConnections, + IdleConnTimeout: cfg.KeepAliveTimeout, + TLSHandshakeTimeout: cfg.TLSTimeout, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify}, + } + if _, isHelloWorld := service.(*helloWorld); !isHelloWorld && cfg.OriginServerName != "" { + httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName + } + + dialer := &net.Dialer{ + Timeout: cfg.ConnectTimeout, + KeepAlive: cfg.TCPKeepAlive, + } + if cfg.NoHappyEyeballs { + dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs" + } + + // DialContext depends on which kind of origin is being used. + dialContext := dialer.DialContext + switch service := service.(type) { + + // If this origin is a unix socket, enforce network type "unix". + case *unixSocketPath: + httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialContext(ctx, "unix", service.path) + } + + // Otherwise, use the regular network config. + default: + httpTransport.DialContext = dialContext + } + + return &httpTransport, nil +} diff --git a/ingress/rule.go b/ingress/rule.go index c47b8bb9..e91b4139 100644 --- a/ingress/rule.go +++ b/ingress/rule.go @@ -1,8 +1,6 @@ package ingress import ( - "crypto/tls" - "net/http" "regexp" "strings" ) @@ -23,11 +21,6 @@ type Rule struct { // Configure the request cloudflared sends to this specific origin. Config OriginRequestConfig - - // Configures TLS for the cloudflared -> origin request - ClientTLSConfig *tls.Config - // Configures HTTP for the cloudflared -> origin request - HTTPTransport http.RoundTripper } // MultiLineString is for outputting rules in a human-friendly way when Cloudflared diff --git a/origin/tunnel.go b/origin/tunnel.go index 7655ff10..4558f3c4 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -698,7 +698,6 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, } h.AppendTagHeaders(req) rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) - rule.Service.RewriteOriginURL(req.URL) return req, rule, nil } @@ -708,7 +707,11 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ req.Host = hostHeader } - conn, response, err := websocket.ClientConnect(req, rule.ClientTLSConfig) + dialler, ok := rule.Service.(websocket.Dialler) + if !ok { + return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service) + } + conn, response, err := websocket.ClientConnect(req, dialler) if err != nil { return nil, err } @@ -742,7 +745,7 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, req.Host = hostHeader } - response, err := rule.HTTPTransport.RoundTrip(req) + response, err := rule.Service.RoundTrip(req) if err != nil { return nil, errors.Wrap(err, "Error proxying request to origin") } diff --git a/websocket/websocket.go b/websocket/websocket.go index 961a10f9..26367113 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -69,15 +69,31 @@ func IsWebSocketUpgrade(req *http.Request) bool { return websocket.IsWebSocketUpgrade(req) } +// Dialler is something that can proxy websocket requests. +type Dialler interface { + Dial(url string, headers http.Header) (*websocket.Conn, *http.Response, error) +} + +type defaultDialler struct { + tlsConfig *tls.Config +} + +func (dd *defaultDialler) Dial(url string, header http.Header) (*websocket.Conn, *http.Response, error) { + d := &websocket.Dialer{TLSClientConfig: dd.tlsConfig} + return d.Dial(url, header) +} + // ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing // the connection. The response body may not contain the entire response and does // not need to be closed by the application. -func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) { +func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) { req.URL.Scheme = changeRequestScheme(req) wsHeaders := websocketHeaders(req) - d := &websocket.Dialer{TLSClientConfig: tlsClientConfig} - conn, response, err := d.Dial(req.URL.String(), wsHeaders) + if dialler == nil { + dialler = new(defaultDialler) + } + conn, response, err := dialler.Dial(req.URL.String(), wsHeaders) if err != nil { return nil, response, err } diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 615753a0..3a94ed51 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -77,7 +77,8 @@ func TestServe(t *testing.T) { tlsConfig := websocketClientTLSConfig(t) assert.NotNil(t, tlsConfig) - conn, resp, err := ClientConnect(req, tlsConfig) + d := defaultDialler{tlsConfig: tlsConfig} + conn, resp, err := ClientConnect(req, &d) assert.NoError(t, err) assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))