TUN-1160: pass Host header during origin url validation
This commit is contained in:
parent
2b820d790c
commit
83c6c8713b
|
@ -192,7 +192,7 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, version st
|
||||||
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
|
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validation.ValidateHTTPService(originURL, httpTransport)
|
err = validation.ValidateHTTPService(originURL, hostname, httpTransport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("unable to connect to the origin")
|
logger.WithError(err).Error("unable to connect to the origin")
|
||||||
return nil, errors.Wrap(err, "unable to connect to the origin")
|
return nil, errors.Wrap(err, "unable to connect to the origin")
|
||||||
|
|
|
@ -6,9 +6,10 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
"net/http"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultScheme = "http"
|
const defaultScheme = "http"
|
||||||
|
@ -137,7 +138,7 @@ func validateIP(scheme, host, port string) (string, error) {
|
||||||
return fmt.Sprintf("%s://%s", scheme, host), nil
|
return fmt.Sprintf("%s://%s", scheme, host), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateHTTPService(originURL string, transport http.RoundTripper) error {
|
func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
|
||||||
parsedURL, err := url.Parse(originURL)
|
parsedURL, err := url.Parse(originURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -145,18 +146,27 @@ func ValidateHTTPService(originURL string, transport http.RoundTripper) error {
|
||||||
|
|
||||||
client := &http.Client{Transport: transport}
|
client := &http.Client{Transport: transport}
|
||||||
|
|
||||||
initialResponse, initialErr := client.Get(parsedURL.String())
|
initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
initialRequest.Host = hostname
|
||||||
|
initialResponse, initialErr := client.Do(initialRequest)
|
||||||
if initialErr != nil || initialResponse.StatusCode != http.StatusOK {
|
if initialErr != nil || initialResponse.StatusCode != http.StatusOK {
|
||||||
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
|
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
|
||||||
oldScheme := parsedURL.Scheme
|
oldScheme := parsedURL.Scheme
|
||||||
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
|
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
|
||||||
|
|
||||||
secondResponse, _ := client.Get(parsedURL.String())
|
secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
secondRequest.Host = hostname
|
||||||
|
secondResponse, _ := client.Do(secondRequest)
|
||||||
if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols
|
if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols
|
||||||
return errors.Errorf(
|
return errors.Errorf(
|
||||||
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
|
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
|
||||||
parsedURL.Hostname(),
|
parsedURL.Host,
|
||||||
oldScheme,
|
oldScheme,
|
||||||
parsedURL.Scheme,
|
parsedURL.Scheme,
|
||||||
parsedURL,
|
parsedURL,
|
||||||
|
|
|
@ -7,12 +7,13 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestValidateHostname(t *testing.T) {
|
func TestValidateHostname(t *testing.T) {
|
||||||
|
@ -151,57 +152,88 @@ func TestToggleProtocol(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
||||||
|
originURL := "http://127.0.0.1/"
|
||||||
|
hostname := "example.com"
|
||||||
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport))
|
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) {
|
func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) {
|
||||||
|
originURL := "http://127.0.0.1/"
|
||||||
|
hostname := "example.com"
|
||||||
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "CONNECT" {
|
||||||
|
assert.Equal(t, "127.0.0.1:443", r.Host)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
|
}
|
||||||
w.WriteHeader(400)
|
w.WriteHeader(400)
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
assert.Equal(t, nil, ValidateHTTPService("http://example.com/", client.Transport))
|
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
||||||
|
originURL := "https://127.0.0.1:1234/"
|
||||||
|
hostname := "example.com"
|
||||||
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "CONNECT" {
|
||||||
|
assert.Equal(t, "127.0.0.1:1234", r.Host)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
|
}
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
assert.Equal(t,
|
assert.Equal(t,
|
||||||
"example.com doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://example.com:1234/",
|
"127.0.0.1:1234 doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://127.0.0.1:1234/",
|
||||||
ValidateHTTPService("https://example.com:1234/", client.Transport).Error())
|
ValidateHTTPService(originURL, hostname, client.Transport).Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
||||||
|
originURL := "https://127.0.0.1/"
|
||||||
|
hostname := "example.com"
|
||||||
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "CONNECT" {
|
||||||
|
assert.Equal(t, "127.0.0.1:443", r.Host)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
|
}
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
assert.Equal(t, nil, ValidateHTTPService("https://example.com/", client.Transport))
|
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
||||||
|
originURL := "http://127.0.0.1:1234/"
|
||||||
|
hostname := "example.com"
|
||||||
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "CONNECT" {
|
||||||
|
assert.Equal(t, "127.0.0.1:1234", r.Host)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
|
}
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
assert.Equal(t,
|
assert.Equal(t,
|
||||||
"example.com doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://example.com:1234/",
|
"127.0.0.1:1234 doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://127.0.0.1:1234/",
|
||||||
ValidateHTTPService("http://example.com:1234/", client.Transport).Error())
|
ValidateHTTPService(originURL, hostname, client.Transport).Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
|
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
|
||||||
|
|
Loading…
Reference in New Issue