From dd2b5e4f3e0355065f594c052cae59c4b02f1f6b Mon Sep 17 00:00:00 2001 From: Areg Harutyunyan Date: Tue, 11 Sep 2018 11:44:16 -0500 Subject: [PATCH] TUN-868: HTTP/HTTPS mismatch should have a better error message --- cmd/cloudflared/tunnel/cmd.go | 2 +- cmd/cloudflared/tunnel/configuration.go | 16 ++-- origin/tunnel.go | 2 +- validation/validation.go | 45 +++++++++- validation/validation_test.go | 109 +++++++++++++++++++++++- 5 files changed, 165 insertions(+), 9 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 25347c7b..459df3b4 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -87,7 +87,7 @@ func Flags() []cli.Flag { }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "url", - Value: "https://localhost:8080", + Value: "http://localhost:8080", Usage: "Connect to the local webserver at `URL`.", EnvVars: []string{"TUNNEL_URL"}, }), diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index fd43146e..961e55c2 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -171,12 +171,12 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) - url, err := validateUrl(c) + originURL, err := validateUrl(c) if err != nil { - logger.WithError(err).Error("Error validating url") - return nil, errors.Wrap(err, "Error validating url") + logger.WithError(err).Error("Error validating origin URL") + return nil, errors.Wrap(err, "Error validating origin URL") } - logger.Infof("Proxying tunnel requests to %s", url) + logger.Infof("Proxying tunnel requests to %s", originURL) originCert, err := getOriginCert(c) if err != nil { @@ -208,9 +208,15 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name") } + err = validation.ValidateHTTPService(originURL, httpTransport) + if err != nil { + logger.WithError(err).Error("unable to connect to the origin") + return nil, errors.Wrap(err, "unable to connect to the origin") + } + return &origin.TunnelConfig{ EdgeAddrs: c.StringSlice("edge"), - OriginUrl: url, + OriginUrl: originURL, Hostname: hostname, OriginCert: originCert, TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")), diff --git a/origin/tunnel.go b/origin/tunnel.go index 543601b3..57ebd9aa 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -464,7 +464,7 @@ func NewTunnelHandler(ctx context.Context, ) (*TunnelHandler, string, error) { originURL, err := validation.ValidateUrl(config.OriginUrl) if err != nil { - return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL) + return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL) } h := &TunnelHandler{ originUrl: originURL, diff --git a/validation/validation.go b/validation/validation.go index be254419..71251b94 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -6,7 +6,9 @@ import ( "net/url" "strings" + "github.com/pkg/errors" "golang.org/x/net/idna" + "net/http" ) const defaultScheme = "http" @@ -48,7 +50,7 @@ func ValidateHostname(hostname string) (string, error) { func ValidateUrl(originUrl string) (string, error) { if originUrl == "" { - return "", fmt.Errorf("Url should not be empty") + return "", fmt.Errorf("URL should not be empty") } if net.ParseIP(originUrl) != nil { @@ -134,3 +136,44 @@ func validateIP(scheme, host, port string) (string, error) { } return fmt.Sprintf("%s://%s", scheme, host), nil } + +func ValidateHTTPService(originURL string, transport http.RoundTripper) error { + parsedURL, err := url.Parse(originURL) + if err != nil { + return err + } + + client := &http.Client{Transport: transport} + + initialResponse, initialErr := client.Get(parsedURL.String()) + if initialErr != nil || initialResponse.StatusCode != http.StatusOK { + // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? + oldScheme := parsedURL.Scheme + parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) + + secondResponse, _ := client.Get(parsedURL.String()) + + if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols + return errors.Errorf( + "%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s", + parsedURL.Hostname(), + oldScheme, + parsedURL.Scheme, + parsedURL, + ) + } + } + + return initialErr +} + +func toggleProtocol(httpProtocol string) string { + switch httpProtocol { + case "http": + return "https" + case "https": + return "http" + default: + return httpProtocol + } +} diff --git a/validation/validation_test.go b/validation/validation_test.go index 3cd4cbd2..e4f381f7 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -4,7 +4,15 @@ import ( "fmt" "testing" + "context" + "crypto/tls" + "crypto/x509" "github.com/stretchr/testify/assert" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" ) func TestValidateHostname(t *testing.T) { @@ -42,7 +50,7 @@ func TestValidateHostname(t *testing.T) { func TestValidateUrl(t *testing.T) { validUrl, err := ValidateUrl("") - assert.Equal(t, fmt.Errorf("Url should not be empty"), err) + assert.Equal(t, fmt.Errorf("URL should not be empty"), err) assert.Empty(t, validUrl) validUrl, err = ValidateUrl("https://localhost:8080") @@ -134,3 +142,102 @@ func TestValidateUrl(t *testing.T) { assert.Equal(t, "https://hello.example.com:8080", validUrl) } + +func TestToggleProtocol(t *testing.T) { + assert.Equal(t, "https", toggleProtocol("http")) + assert.Equal(t, "http", toggleProtocol("https")) + assert.Equal(t, "random", toggleProtocol("random")) + assert.Equal(t, "", toggleProtocol("")) +} + +func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { + server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + assert.NoError(t, err) + defer server.Close() + + assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport)) +} + +func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) { + server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(400) + })) + assert.NoError(t, err) + defer server.Close() + + assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport)) +} + +func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { + server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + assert.NoError(t, err) + defer server.Close() + + assert.Equal(t, + "example.com doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://example.com:1234/", + ValidateHTTPService("https://example.com:1234/", client.Transport).Error()) +} + +func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { + server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + assert.NoError(t, err) + defer server.Close() + + assert.Equal(t, nil, ValidateHTTPService("https://example.com/", client.Transport)) +} + +func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { + server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + assert.NoError(t, err) + defer server.Close() + + assert.Equal(t, + "example.com doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://example.com:1234/", + ValidateHTTPService("http://example.com:1234/", client.Transport).Error()) +} + +func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) { + client := http.DefaultClient + server := httptest.NewServer(handler) + + client.Transport = &http.Transport{ + Proxy: func(req *http.Request) (*url.URL, error) { + return url.Parse(server.URL) + }, + } + + return server, client, nil +} + +func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) { + client := http.DefaultClient + server := httptest.NewTLSServer(handler) + + cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0]) + if err != nil { + server.Close() + return nil, nil, err + } + + certpool := x509.NewCertPool() + certpool.AddCert(cert) + + client.Transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("tcp", server.URL[strings.LastIndex(server.URL, "/")+1:]) + }, + TLSClientConfig: &tls.Config{ + RootCAs: certpool, + }, + } + + return server, client, nil +}