TUN-1250: ValidateHTTPService shouldn't follow 302s

This commit is contained in:
Nick Vollmar 2018-12-07 11:35:05 -06:00
parent 446c5cf60c
commit 3e8d886c25
2 changed files with 204 additions and 38 deletions

View File

@ -144,15 +144,20 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
return err return err
} }
client := &http.Client{Transport: transport} client := &http.Client{
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil) initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
if err != nil { if err != nil {
return err return err
} }
initialRequest.Host = hostname initialRequest.Host = hostname
initialResponse, initialErr := client.Do(initialRequest) _, initialErr := client.Do(initialRequest)
if initialErr != nil || initialResponse.StatusCode != http.StatusOK { if initialErr != nil {
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
oldScheme := parsedURL.Scheme oldScheme := parsedURL.Scheme
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
@ -162,8 +167,8 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
return err return err
} }
secondRequest.Host = hostname secondRequest.Host = hostname
secondResponse, _ := client.Do(secondRequest) _, secondErr := client.Do(secondRequest)
if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols if secondErr == nil { // Worked this time--advise the user to switch protocols
return errors.Errorf( return errors.Errorf(
"%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s", "%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s",
parsedURL.Host, parsedURL.Host,

View File

@ -1,7 +1,9 @@
package validation package validation
import ( import (
"bytes"
"fmt" "fmt"
"io/ioutil"
"testing" "testing"
"context" "context"
@ -151,58 +153,89 @@ func TestToggleProtocol(t *testing.T) {
assert.Equal(t, "", toggleProtocol("")) assert.Equal(t, "", toggleProtocol(""))
} }
// Happy path 1: originURL is HTTP, and HTTP connections work
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
originURL := "http://127.0.0.1/" originURL := "http://127.0.0.1/"
hostname := "example.com" hostname := "example.com"
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
return emptyResponse(200), nil
}
if req.URL.Scheme == "https" {
t.Fatal("http works, shouldn't have tried with https")
}
panic("Shouldn't reach here")
})))
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
return emptyResponse(503), nil
}
if req.URL.Scheme == "https" {
t.Fatal("http works, shouldn't have tried with https")
}
panic("Shouldn't reach here")
})))
// Integration-style test with a mock server
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, hostname, r.Host) assert.Equal(t, hostname, r.Host)
w.WriteHeader(200) w.WriteHeader(200)
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer server.Close()
assert.Nil(t, ValidateHTTPService(originURL, hostname, client.Transport))
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport)) // this will fail if the client follows the 302
} redirectServer, redirectClient, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/followedRedirect" {
func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) { t.Fatal("shouldn't have followed the 302")
originURL := "http://127.0.0.1/" }
hostname := "example.com"
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" { if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:443", r.Host) assert.Equal(t, "127.0.0.1:443", r.Host)
} else { } else {
assert.Equal(t, hostname, r.Host) assert.Equal(t, hostname, r.Host)
} }
w.WriteHeader(400) w.Header().Set("Location", "/followedRedirect")
w.WriteHeader(302)
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer redirectServer.Close()
assert.Nil(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport))
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport))
}
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
originURL := "https://127.0.0.1:1234/"
hostname := "example.com"
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:1234", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.WriteHeader(200)
}))
assert.NoError(t, err)
defer server.Close()
assert.Equal(t,
"127.0.0.1:1234 doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://127.0.0.1:1234/",
ValidateHTTPService(originURL, hostname, client.Transport).Error())
} }
// Happy path 2: originURL is HTTPS, and HTTPS connections work
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
originURL := "https://127.0.0.1/" originURL := "https://127.0.0.1/"
hostname := "example.com" hostname := "example.com"
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
t.Fatal("https works, shouldn't have tried with http")
}
if req.URL.Scheme == "https" {
return emptyResponse(200), nil
}
panic("Shouldn't reach here")
})))
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
t.Fatal("https works, shouldn't have tried with http")
}
if req.URL.Scheme == "https" {
return emptyResponse(503), nil
}
panic("Shouldn't reach here")
})))
// Integration-style test with a mock server
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" { if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:443", r.Host) assert.Equal(t, "127.0.0.1:443", r.Host)
@ -213,13 +246,113 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer server.Close()
assert.Nil(t, ValidateHTTPService(originURL, hostname, client.Transport))
assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport)) // this will fail if the client follows the 302
redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/followedRedirect" {
t.Fatal("shouldn't have followed the 302")
}
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:443", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.Header().Set("Location", "/followedRedirect")
w.WriteHeader(302)
}))
assert.NoError(t, err)
defer redirectServer.Close()
assert.Nil(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport))
} }
// Error path 1: originURL is HTTPS, but HTTP connections work
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
originURL := "https://127.0.0.1:1234/"
hostname := "example.com"
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
return emptyResponse(200), nil
}
if req.URL.Scheme == "https" {
return nil, assert.AnError
}
panic("Shouldn't reach here")
})))
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
return emptyResponse(503), nil
}
if req.URL.Scheme == "https" {
return nil, assert.AnError
}
panic("Shouldn't reach here")
})))
// Integration-style test with a mock server
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:1234", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.WriteHeader(200)
}))
assert.NoError(t, err)
defer server.Close()
assert.Error(t, ValidateHTTPService(originURL, hostname, client.Transport))
// this will fail if the client follows the 302
redirectServer, redirectClient, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/followedRedirect" {
t.Fatal("shouldn't have followed the 302")
}
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:1234", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.Header().Set("Location", "/followedRedirect")
w.WriteHeader(302)
}))
assert.NoError(t, err)
defer redirectServer.Close()
assert.Error(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport))
}
// Error path 2: originURL is HTTP, but HTTPS connections work
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
originURL := "http://127.0.0.1:1234/" originURL := "http://127.0.0.1:1234/"
hostname := "example.com" hostname := "example.com"
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
return nil, assert.AnError
}
if req.URL.Scheme == "https" {
return emptyResponse(200), nil
}
panic("Shouldn't reach here")
})))
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
return nil, assert.AnError
}
if req.URL.Scheme == "https" {
return emptyResponse(503), nil
}
panic("Shouldn't reach here")
})))
// Integration-style test with a mock server
server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" { if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:1234", r.Host) assert.Equal(t, "127.0.0.1:1234", r.Host)
@ -230,10 +363,38 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
})) }))
assert.NoError(t, err) assert.NoError(t, err)
defer server.Close() defer server.Close()
assert.Error(t, ValidateHTTPService(originURL, hostname, client.Transport))
assert.Equal(t, // this will fail if the client follows the 302
"127.0.0.1:1234 doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://127.0.0.1:1234/", redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ValidateHTTPService(originURL, hostname, client.Transport).Error()) if r.URL.Path == "/followedRedirect" {
t.Fatal("shouldn't have followed the 302")
}
if r.Method == "CONNECT" {
assert.Equal(t, "127.0.0.1:443", r.Host)
} else {
assert.Equal(t, hostname, r.Host)
}
w.Header().Set("Location", "/followedRedirect")
w.WriteHeader(302)
}))
assert.NoError(t, err)
defer redirectServer.Close()
assert.Error(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport))
}
type testRoundTripper func(req *http.Request) (*http.Response, error)
func (f testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func emptyResponse(statusCode int) *http.Response {
return &http.Response{
StatusCode: statusCode,
Body: ioutil.NopCloser(bytes.NewReader(nil)),
Header: make(http.Header),
}
} }
func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) { func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {