TUN-1250: ValidateHTTPService shouldn't follow 302s
This commit is contained in:
parent
446c5cf60c
commit
3e8d886c25
|
@ -144,15 +144,20 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{Transport: transport}
|
client := &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
|
initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
initialRequest.Host = hostname
|
initialRequest.Host = hostname
|
||||||
initialResponse, initialErr := client.Do(initialRequest)
|
_, initialErr := client.Do(initialRequest)
|
||||||
if initialErr != nil || initialResponse.StatusCode != http.StatusOK {
|
if initialErr != nil {
|
||||||
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
|
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
|
||||||
oldScheme := parsedURL.Scheme
|
oldScheme := parsedURL.Scheme
|
||||||
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
|
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
|
||||||
|
@ -162,8 +167,8 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
secondRequest.Host = hostname
|
secondRequest.Host = hostname
|
||||||
secondResponse, _ := client.Do(secondRequest)
|
_, secondErr := client.Do(secondRequest)
|
||||||
if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols
|
if secondErr == nil { // Worked this time--advise the user to switch protocols
|
||||||
return errors.Errorf(
|
return errors.Errorf(
|
||||||
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
|
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
|
||||||
parsedURL.Host,
|
parsedURL.Host,
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package validation
|
package validation
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
|
@ -151,58 +153,89 @@ func TestToggleProtocol(t *testing.T) {
|
||||||
assert.Equal(t, "", toggleProtocol(""))
|
assert.Equal(t, "", toggleProtocol(""))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Happy path 1: originURL is HTTP, and HTTP connections work
|
||||||
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
||||||
originURL := "http://127.0.0.1/"
|
originURL := "http://127.0.0.1/"
|
||||||
hostname := "example.com"
|
hostname := "example.com"
|
||||||
|
|
||||||
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, req.Host, hostname)
|
||||||
|
if req.URL.Scheme == "http" {
|
||||||
|
return emptyResponse(200), nil
|
||||||
|
}
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
t.Fatal("http works, shouldn't have tried with https")
|
||||||
|
}
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
})))
|
||||||
|
|
||||||
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, req.Host, hostname)
|
||||||
|
if req.URL.Scheme == "http" {
|
||||||
|
return emptyResponse(503), nil
|
||||||
|
}
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
t.Fatal("http works, shouldn't have tried with https")
|
||||||
|
}
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
})))
|
||||||
|
|
||||||
|
// Integration-style test with a mock server
|
||||||
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, hostname, r.Host)
|
assert.Equal(t, hostname, r.Host)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, client.Transport))
|
||||||
|
|
||||||
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
|
// this will fail if the client follows the 302
|
||||||
}
|
redirectServer, redirectClient, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/followedRedirect" {
|
||||||
func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) {
|
t.Fatal("shouldn't have followed the 302")
|
||||||
originURL := "http://127.0.0.1/"
|
}
|
||||||
hostname := "example.com"
|
|
||||||
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method == "CONNECT" {
|
if r.Method == "CONNECT" {
|
||||||
assert.Equal(t, "127.0.0.1:443", r.Host)
|
assert.Equal(t, "127.0.0.1:443", r.Host)
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, hostname, r.Host)
|
assert.Equal(t, hostname, r.Host)
|
||||||
}
|
}
|
||||||
w.WriteHeader(400)
|
w.Header().Set("Location", "/followedRedirect")
|
||||||
|
w.WriteHeader(302)
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer redirectServer.Close()
|
||||||
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport))
|
||||||
|
|
||||||
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
|
||||||
originURL := "https://127.0.0.1:1234/"
|
|
||||||
hostname := "example.com"
|
|
||||||
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method == "CONNECT" {
|
|
||||||
assert.Equal(t, "127.0.0.1:1234", r.Host)
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, hostname, r.Host)
|
|
||||||
}
|
|
||||||
w.WriteHeader(200)
|
|
||||||
}))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
assert.Equal(t,
|
|
||||||
"127.0.0.1:1234 doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://127.0.0.1:1234/",
|
|
||||||
ValidateHTTPService(originURL, hostname, client.Transport).Error())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Happy path 2: originURL is HTTPS, and HTTPS connections work
|
||||||
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
||||||
originURL := "https://127.0.0.1/"
|
originURL := "https://127.0.0.1/"
|
||||||
hostname := "example.com"
|
hostname := "example.com"
|
||||||
|
|
||||||
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, req.Host, hostname)
|
||||||
|
if req.URL.Scheme == "http" {
|
||||||
|
t.Fatal("https works, shouldn't have tried with http")
|
||||||
|
}
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
return emptyResponse(200), nil
|
||||||
|
}
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
})))
|
||||||
|
|
||||||
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, req.Host, hostname)
|
||||||
|
if req.URL.Scheme == "http" {
|
||||||
|
t.Fatal("https works, shouldn't have tried with http")
|
||||||
|
}
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
return emptyResponse(503), nil
|
||||||
|
}
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
})))
|
||||||
|
|
||||||
|
// Integration-style test with a mock server
|
||||||
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "CONNECT" {
|
if r.Method == "CONNECT" {
|
||||||
assert.Equal(t, "127.0.0.1:443", r.Host)
|
assert.Equal(t, "127.0.0.1:443", r.Host)
|
||||||
|
@ -213,13 +246,113 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, client.Transport))
|
||||||
|
|
||||||
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
|
// this will fail if the client follows the 302
|
||||||
|
redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/followedRedirect" {
|
||||||
|
t.Fatal("shouldn't have followed the 302")
|
||||||
|
}
|
||||||
|
if r.Method == "CONNECT" {
|
||||||
|
assert.Equal(t, "127.0.0.1:443", r.Host)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
|
}
|
||||||
|
w.Header().Set("Location", "/followedRedirect")
|
||||||
|
w.WriteHeader(302)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer redirectServer.Close()
|
||||||
|
assert.Nil(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Error path 1: originURL is HTTPS, but HTTP connections work
|
||||||
|
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
||||||
|
originURL := "https://127.0.0.1:1234/"
|
||||||
|
hostname := "example.com"
|
||||||
|
|
||||||
|
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, req.Host, hostname)
|
||||||
|
if req.URL.Scheme == "http" {
|
||||||
|
return emptyResponse(200), nil
|
||||||
|
}
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
return nil, assert.AnError
|
||||||
|
}
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
})))
|
||||||
|
|
||||||
|
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, req.Host, hostname)
|
||||||
|
if req.URL.Scheme == "http" {
|
||||||
|
return emptyResponse(503), nil
|
||||||
|
}
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
return nil, assert.AnError
|
||||||
|
}
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
})))
|
||||||
|
|
||||||
|
// Integration-style test with a mock server
|
||||||
|
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == "CONNECT" {
|
||||||
|
assert.Equal(t, "127.0.0.1:1234", r.Host)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
|
}
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer server.Close()
|
||||||
|
assert.Error(t, ValidateHTTPService(originURL, hostname, client.Transport))
|
||||||
|
|
||||||
|
// this will fail if the client follows the 302
|
||||||
|
redirectServer, redirectClient, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/followedRedirect" {
|
||||||
|
t.Fatal("shouldn't have followed the 302")
|
||||||
|
}
|
||||||
|
if r.Method == "CONNECT" {
|
||||||
|
assert.Equal(t, "127.0.0.1:1234", r.Host)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
|
}
|
||||||
|
w.Header().Set("Location", "/followedRedirect")
|
||||||
|
w.WriteHeader(302)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer redirectServer.Close()
|
||||||
|
assert.Error(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error path 2: originURL is HTTP, but HTTPS connections work
|
||||||
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
||||||
originURL := "http://127.0.0.1:1234/"
|
originURL := "http://127.0.0.1:1234/"
|
||||||
hostname := "example.com"
|
hostname := "example.com"
|
||||||
|
|
||||||
|
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, req.Host, hostname)
|
||||||
|
if req.URL.Scheme == "http" {
|
||||||
|
return nil, assert.AnError
|
||||||
|
}
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
return emptyResponse(200), nil
|
||||||
|
}
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
})))
|
||||||
|
|
||||||
|
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||||
|
assert.Equal(t, req.Host, hostname)
|
||||||
|
if req.URL.Scheme == "http" {
|
||||||
|
return nil, assert.AnError
|
||||||
|
}
|
||||||
|
if req.URL.Scheme == "https" {
|
||||||
|
return emptyResponse(503), nil
|
||||||
|
}
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
})))
|
||||||
|
|
||||||
|
// Integration-style test with a mock server
|
||||||
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "CONNECT" {
|
if r.Method == "CONNECT" {
|
||||||
assert.Equal(t, "127.0.0.1:1234", r.Host)
|
assert.Equal(t, "127.0.0.1:1234", r.Host)
|
||||||
|
@ -230,10 +363,38 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
assert.Error(t, ValidateHTTPService(originURL, hostname, client.Transport))
|
||||||
|
|
||||||
assert.Equal(t,
|
// this will fail if the client follows the 302
|
||||||
"127.0.0.1:1234 doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://127.0.0.1:1234/",
|
redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ValidateHTTPService(originURL, hostname, client.Transport).Error())
|
if r.URL.Path == "/followedRedirect" {
|
||||||
|
t.Fatal("shouldn't have followed the 302")
|
||||||
|
}
|
||||||
|
if r.Method == "CONNECT" {
|
||||||
|
assert.Equal(t, "127.0.0.1:443", r.Host)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, hostname, r.Host)
|
||||||
|
}
|
||||||
|
w.Header().Set("Location", "/followedRedirect")
|
||||||
|
w.WriteHeader(302)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer redirectServer.Close()
|
||||||
|
assert.Error(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport))
|
||||||
|
}
|
||||||
|
|
||||||
|
type testRoundTripper func(req *http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func emptyResponse(statusCode int) *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Body: ioutil.NopCloser(bytes.NewReader(nil)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
|
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {
|
||||||
|
|
Loading…
Reference in New Issue