From 83c6c8713b722d5670cca3dc74e72dd17ea7c80f Mon Sep 17 00:00:00 2001 From: Nick Vollmar Date: Mon, 29 Oct 2018 11:57:58 -0500 Subject: [PATCH] TUN-1160: pass Host header during origin url validation --- cmd/cloudflared/tunnel/configuration.go | 2 +- validation/validation.go | 22 ++++++++---- validation/validation_test.go | 48 ++++++++++++++++++++----- 3 files changed, 57 insertions(+), 15 deletions(-) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 776b9f92..c08a9e5d 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -192,7 +192,7 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name") } - err = validation.ValidateHTTPService(originURL, httpTransport) + err = validation.ValidateHTTPService(originURL, hostname, httpTransport) if err != nil { logger.WithError(err).Error("unable to connect to the origin") return nil, errors.Wrap(err, "unable to connect to the origin") diff --git a/validation/validation.go b/validation/validation.go index 71251b94..0fdeb143 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -6,9 +6,10 @@ import ( "net/url" "strings" + "net/http" + "github.com/pkg/errors" "golang.org/x/net/idna" - "net/http" ) const defaultScheme = "http" @@ -137,7 +138,7 @@ func validateIP(scheme, host, port string) (string, error) { return fmt.Sprintf("%s://%s", scheme, host), nil } -func ValidateHTTPService(originURL string, transport http.RoundTripper) error { +func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error { parsedURL, err := url.Parse(originURL) if err != nil { return err @@ -145,18 +146,27 @@ func ValidateHTTPService(originURL string, transport http.RoundTripper) error { client := &http.Client{Transport: transport} - initialResponse, initialErr := client.Get(parsedURL.String()) + initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil) + if err != nil { + return err + } + initialRequest.Host = hostname + initialResponse, initialErr := client.Do(initialRequest) if initialErr != nil || initialResponse.StatusCode != http.StatusOK { // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? oldScheme := parsedURL.Scheme parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) - secondResponse, _ := client.Get(parsedURL.String()) - + secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil) + if err != nil { + return err + } + secondRequest.Host = hostname + secondResponse, _ := client.Do(secondRequest) if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols return errors.Errorf( "%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s", - parsedURL.Hostname(), + parsedURL.Host, oldScheme, parsedURL.Scheme, parsedURL, diff --git a/validation/validation_test.go b/validation/validation_test.go index e4f381f7..01a68cac 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -7,12 +7,13 @@ import ( "context" "crypto/tls" "crypto/x509" - "github.com/stretchr/testify/assert" "net" "net/http" "net/http/httptest" "net/url" "strings" + + "github.com/stretchr/testify/assert" ) func TestValidateHostname(t *testing.T) { @@ -151,57 +152,88 @@ func TestToggleProtocol(t *testing.T) { } func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { + originURL := "http://127.0.0.1/" + hostname := "example.com" server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, hostname, r.Host) w.WriteHeader(200) })) assert.NoError(t, err) defer server.Close() - assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport)) + assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport)) } func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) { + originURL := "http://127.0.0.1/" + hostname := "example.com" server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:443", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } w.WriteHeader(400) })) assert.NoError(t, err) defer server.Close() - assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport)) + assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport)) } func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { + originURL := "https://127.0.0.1:1234/" + hostname := "example.com" server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:1234", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } w.WriteHeader(200) })) assert.NoError(t, err) defer server.Close() assert.Equal(t, - "example.com doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://example.com:1234/", - ValidateHTTPService("https://example.com:1234/", client.Transport).Error()) + "127.0.0.1:1234 doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://127.0.0.1:1234/", + ValidateHTTPService(originURL, hostname, client.Transport).Error()) } func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { + originURL := "https://127.0.0.1/" + hostname := "example.com" server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:443", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } w.WriteHeader(200) })) assert.NoError(t, err) defer server.Close() - assert.Equal(t, nil, ValidateHTTPService("https://example.com/", client.Transport)) + assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport)) } func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { + originURL := "http://127.0.0.1:1234/" + hostname := "example.com" server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:1234", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } w.WriteHeader(200) })) assert.NoError(t, err) defer server.Close() assert.Equal(t, - "example.com doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://example.com:1234/", - ValidateHTTPService("http://example.com:1234/", client.Transport).Error()) + "127.0.0.1:1234 doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://127.0.0.1:1234/", + ValidateHTTPService(originURL, hostname, client.Transport).Error()) } func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {