TUN-868: HTTP/HTTPS mismatch should have a better error message
This commit is contained in:
parent
0e6492342b
commit
dd2b5e4f3e
|
@ -87,7 +87,7 @@ func Flags() []cli.Flag {
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "url",
|
Name: "url",
|
||||||
Value: "https://localhost:8080",
|
Value: "http://localhost:8080",
|
||||||
Usage: "Connect to the local webserver at `URL`.",
|
Usage: "Connect to the local webserver at `URL`.",
|
||||||
EnvVars: []string{"TUNNEL_URL"},
|
EnvVars: []string{"TUNNEL_URL"},
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -171,12 +171,12 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st
|
||||||
|
|
||||||
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
|
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
|
||||||
|
|
||||||
url, err := validateUrl(c)
|
originURL, err := validateUrl(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("Error validating url")
|
logger.WithError(err).Error("Error validating origin URL")
|
||||||
return nil, errors.Wrap(err, "Error validating url")
|
return nil, errors.Wrap(err, "Error validating origin URL")
|
||||||
}
|
}
|
||||||
logger.Infof("Proxying tunnel requests to %s", url)
|
logger.Infof("Proxying tunnel requests to %s", originURL)
|
||||||
|
|
||||||
originCert, err := getOriginCert(c)
|
originCert, err := getOriginCert(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -208,9 +208,15 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st
|
||||||
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
|
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = validation.ValidateHTTPService(originURL, httpTransport)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Error("unable to connect to the origin")
|
||||||
|
return nil, errors.Wrap(err, "unable to connect to the origin")
|
||||||
|
}
|
||||||
|
|
||||||
return &origin.TunnelConfig{
|
return &origin.TunnelConfig{
|
||||||
EdgeAddrs: c.StringSlice("edge"),
|
EdgeAddrs: c.StringSlice("edge"),
|
||||||
OriginUrl: url,
|
OriginUrl: originURL,
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
OriginCert: originCert,
|
OriginCert: originCert,
|
||||||
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
|
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
|
||||||
|
|
|
@ -464,7 +464,7 @@ func NewTunnelHandler(ctx context.Context,
|
||||||
) (*TunnelHandler, string, error) {
|
) (*TunnelHandler, string, error) {
|
||||||
originURL, err := validation.ValidateUrl(config.OriginUrl)
|
originURL, err := validation.ValidateUrl(config.OriginUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL)
|
return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL)
|
||||||
}
|
}
|
||||||
h := &TunnelHandler{
|
h := &TunnelHandler{
|
||||||
originUrl: originURL,
|
originUrl: originURL,
|
||||||
|
|
|
@ -6,7 +6,9 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultScheme = "http"
|
const defaultScheme = "http"
|
||||||
|
@ -48,7 +50,7 @@ func ValidateHostname(hostname string) (string, error) {
|
||||||
|
|
||||||
func ValidateUrl(originUrl string) (string, error) {
|
func ValidateUrl(originUrl string) (string, error) {
|
||||||
if originUrl == "" {
|
if originUrl == "" {
|
||||||
return "", fmt.Errorf("Url should not be empty")
|
return "", fmt.Errorf("URL should not be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
if net.ParseIP(originUrl) != nil {
|
if net.ParseIP(originUrl) != nil {
|
||||||
|
@ -134,3 +136,44 @@ func validateIP(scheme, host, port string) (string, error) {
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s://%s", scheme, host), nil
|
return fmt.Sprintf("%s://%s", scheme, host), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ValidateHTTPService(originURL string, transport http.RoundTripper) error {
|
||||||
|
parsedURL, err := url.Parse(originURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Transport: transport}
|
||||||
|
|
||||||
|
initialResponse, initialErr := client.Get(parsedURL.String())
|
||||||
|
if initialErr != nil || initialResponse.StatusCode != http.StatusOK {
|
||||||
|
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
|
||||||
|
oldScheme := parsedURL.Scheme
|
||||||
|
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
|
||||||
|
|
||||||
|
secondResponse, _ := client.Get(parsedURL.String())
|
||||||
|
|
||||||
|
if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols
|
||||||
|
return errors.Errorf(
|
||||||
|
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
|
||||||
|
parsedURL.Hostname(),
|
||||||
|
oldScheme,
|
||||||
|
parsedURL.Scheme,
|
||||||
|
parsedURL,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return initialErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func toggleProtocol(httpProtocol string) string {
|
||||||
|
switch httpProtocol {
|
||||||
|
case "http":
|
||||||
|
return "https"
|
||||||
|
case "https":
|
||||||
|
return "http"
|
||||||
|
default:
|
||||||
|
return httpProtocol
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,7 +4,15 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestValidateHostname(t *testing.T) {
|
func TestValidateHostname(t *testing.T) {
|
||||||
|
@ -42,7 +50,7 @@ func TestValidateHostname(t *testing.T) {
|
||||||
|
|
||||||
func TestValidateUrl(t *testing.T) {
|
func TestValidateUrl(t *testing.T) {
|
||||||
validUrl, err := ValidateUrl("")
|
validUrl, err := ValidateUrl("")
|
||||||
assert.Equal(t, fmt.Errorf("Url should not be empty"), err)
|
assert.Equal(t, fmt.Errorf("URL should not be empty"), err)
|
||||||
assert.Empty(t, validUrl)
|
assert.Empty(t, validUrl)
|
||||||
|
|
||||||
validUrl, err = ValidateUrl("https://localhost:8080")
|
validUrl, err = ValidateUrl("https://localhost:8080")
|
||||||
|
@ -134,3 +142,102 @@ func TestValidateUrl(t *testing.T) {
|
||||||
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToggleProtocol(t *testing.T) {
|
||||||
|
assert.Equal(t, "https", toggleProtocol("http"))
|
||||||
|
assert.Equal(t, "http", toggleProtocol("https"))
|
||||||
|
assert.Equal(t, "random", toggleProtocol("random"))
|
||||||
|
assert.Equal(t, "", toggleProtocol(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
||||||
|
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) {
|
||||||
|
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(400)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
||||||
|
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
assert.Equal(t,
|
||||||
|
"example.com doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://example.com:1234/",
|
||||||
|
ValidateHTTPService("https://example.com:1234/", client.Transport).Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
||||||
|
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, nil, ValidateHTTPService("https://example.com/", client.Transport))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
||||||
|
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
assert.Equal(t,
|
||||||
|
"example.com doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://example.com:1234/",
|
||||||
|
ValidateHTTPService("http://example.com:1234/", client.Transport).Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
|
||||||
|
client := http.DefaultClient
|
||||||
|
server := httptest.NewServer(handler)
|
||||||
|
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||||
|
return url.Parse(server.URL)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return server, client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
|
||||||
|
client := http.DefaultClient
|
||||||
|
server := httptest.NewTLSServer(handler)
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
server.Close()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certpool := x509.NewCertPool()
|
||||||
|
certpool.AddCert(cert)
|
||||||
|
|
||||||
|
client.Transport = &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return net.Dial("tcp", server.URL[strings.LastIndex(server.URL, "/")+1:])
|
||||||
|
},
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: certpool,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return server, client, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue