TUN-1160: pass Host header during origin url validation

This commit is contained in:
Nick Vollmar 2018-10-29 11:57:58 -05:00
parent 2b820d790c
commit 83c6c8713b
3 changed files with 57 additions and 15 deletions

View File

@ -192,7 +192,7 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name") httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
} }
err = validation.ValidateHTTPService(originURL, httpTransport) err = validation.ValidateHTTPService(originURL, hostname, httpTransport)
if err != nil { if err != nil {
logger.WithError(err).Error("unable to connect to the origin") logger.WithError(err).Error("unable to connect to the origin")
return nil, errors.Wrap(err, "unable to connect to the origin") return nil, errors.Wrap(err, "unable to connect to the origin")

View File

@ -6,9 +6,10 @@ import (
"net/url" "net/url"
"strings" "strings"
"net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/net/idna" "golang.org/x/net/idna"
"net/http"
) )
const defaultScheme = "http" const defaultScheme = "http"
@ -137,7 +138,7 @@ func validateIP(scheme, host, port string) (string, error) {
return fmt.Sprintf("%s://%s", scheme, host), nil return fmt.Sprintf("%s://%s", scheme, host), nil
} }
func ValidateHTTPService(originURL string, transport http.RoundTripper) error { func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
parsedURL, err := url.Parse(originURL) parsedURL, err := url.Parse(originURL)
if err != nil { if err != nil {
return err return err
@ -145,18 +146,27 @@ func ValidateHTTPService(originURL string, transport http.RoundTripper) error {
client := &http.Client{Transport: transport} client := &http.Client{Transport: transport}
initialResponse, initialErr := client.Get(parsedURL.String()) initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
if err != nil {
return err
}
initialRequest.Host = hostname
initialResponse, initialErr := client.Do(initialRequest)
if initialErr != nil || initialResponse.StatusCode != http.StatusOK { if initialErr != nil || initialResponse.StatusCode != http.StatusOK {
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
oldScheme := parsedURL.Scheme oldScheme := parsedURL.Scheme
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
secondResponse, _ := client.Get(parsedURL.String()) secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
if err != nil {
return err
}
secondRequest.Host = hostname
secondResponse, _ := client.Do(secondRequest)
if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols
return errors.Errorf( return errors.Errorf(
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s", "%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
parsedURL.Hostname(), parsedURL.Host,
oldScheme, oldScheme,
parsedURL.Scheme, parsedURL.Scheme,
parsedURL, parsedURL,

View File

@ -7,12 +7,13 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"github.com/stretchr/testify/assert"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"github.com/stretchr/testify/assert"
) )
func TestValidateHostname(t *testing.T) { func TestValidateHostname(t *testing.T) {
@ -151,57 +152,88 @@ func TestToggleProtocol(t *testing.T) {
} }
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
originURL := "http://127.0.0.1/"
hostname := "example.com"
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, hostname, r.Host)
w.WriteHeader(200) w.WriteHeader(200)
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer server.Close()
assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport)) assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
} }
func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) { func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) {
originURL := "http://127.0.0.1/"
hostname := "example.com"
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:443", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.WriteHeader(400) w.WriteHeader(400)
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer server.Close()
assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport)) assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
} }
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
originURL := "https://127.0.0.1:1234/"
hostname := "example.com"
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:1234", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.WriteHeader(200) w.WriteHeader(200)
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer server.Close()
assert.Equal(t, assert.Equal(t,
"example.com doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://example.com:1234/", "127.0.0.1:1234 doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://127.0.0.1:1234/",
ValidateHTTPService("https://example.com:1234/", client.Transport).Error()) ValidateHTTPService(originURL, hostname, client.Transport).Error())
} }
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
originURL := "https://127.0.0.1/"
hostname := "example.com"
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:443", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.WriteHeader(200) w.WriteHeader(200)
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer server.Close()
assert.Equal(t, nil, ValidateHTTPService("https://example.com/", client.Transport)) assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
} }
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
originURL := "http://127.0.0.1:1234/"
hostname := "example.com"
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:1234", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.WriteHeader(200) w.WriteHeader(200)
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer server.Close()
assert.Equal(t, assert.Equal(t,
"example.com doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://example.com:1234/", "127.0.0.1:1234 doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://127.0.0.1:1234/",
ValidateHTTPService("http://example.com:1234/", client.Transport).Error()) ValidateHTTPService(originURL, hostname, client.Transport).Error())
} }
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) { func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {