diff --git a/validation/validation.go b/validation/validation.go index 0fdeb143..f0d80964 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -144,15 +144,20 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round return err } - client := &http.Client{Transport: transport} + client := &http.Client{ + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil) if err != nil { return err } initialRequest.Host = hostname - initialResponse, initialErr := client.Do(initialRequest) - if initialErr != nil || initialResponse.StatusCode != http.StatusOK { + _, initialErr := client.Do(initialRequest) + if initialErr != nil { // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? oldScheme := parsedURL.Scheme parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) @@ -162,8 +167,8 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round return err } secondRequest.Host = hostname - secondResponse, _ := client.Do(secondRequest) - if secondResponse != nil && secondResponse.StatusCode == http.StatusOK { // Worked this time--advise the user to switch protocols + _, secondErr := client.Do(secondRequest) + if secondErr == nil { // Worked this time--advise the user to switch protocols return errors.Errorf( "%s doesn't seem to work over %s, but does seem to work over %s. Consider changing the origin URL to %s", parsedURL.Host, diff --git a/validation/validation_test.go b/validation/validation_test.go index 01a68cac..e4713a68 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -1,7 +1,9 @@ package validation import ( + "bytes" "fmt" + "io/ioutil" "testing" "context" @@ -151,58 +153,89 @@ func TestToggleProtocol(t *testing.T) { assert.Equal(t, "", toggleProtocol("")) } +// Happy path 1: originURL is HTTP, and HTTP connections work func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { originURL := "http://127.0.0.1/" hostname := "example.com" + + assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Host, hostname) + if req.URL.Scheme == "http" { + return emptyResponse(200), nil + } + if req.URL.Scheme == "https" { + t.Fatal("http works, shouldn't have tried with https") + } + panic("Shouldn't reach here") + }))) + + assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Host, hostname) + if req.URL.Scheme == "http" { + return emptyResponse(503), nil + } + if req.URL.Scheme == "https" { + t.Fatal("http works, shouldn't have tried with https") + } + panic("Shouldn't reach here") + }))) + + // Integration-style test with a mock server server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, hostname, r.Host) w.WriteHeader(200) })) assert.NoError(t, err) defer server.Close() + assert.Nil(t, ValidateHTTPService(originURL, hostname, client.Transport)) - assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport)) -} - -func TestValidateHTTPService_ServerNonOKResponse(t *testing.T) { - originURL := "http://127.0.0.1/" - hostname := "example.com" - server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // this will fail if the client follows the 302 + redirectServer, redirectClient, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/followedRedirect" { + t.Fatal("shouldn't have followed the 302") + } if r.Method == "CONNECT" { assert.Equal(t, "127.0.0.1:443", r.Host) } else { assert.Equal(t, hostname, r.Host) } - w.WriteHeader(400) + w.Header().Set("Location", "/followedRedirect") + w.WriteHeader(302) })) assert.NoError(t, err) - defer server.Close() + defer redirectServer.Close() + assert.Nil(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport)) - assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport)) -} - -func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { - originURL := "https://127.0.0.1:1234/" - hostname := "example.com" - server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "CONNECT" { - assert.Equal(t, "127.0.0.1:1234", r.Host) - } else { - assert.Equal(t, hostname, r.Host) - } - w.WriteHeader(200) - })) - assert.NoError(t, err) - defer server.Close() - - assert.Equal(t, - "127.0.0.1:1234 doesn't seem to work over https, but does seem to work over http. Consider changing the origin URL to http://127.0.0.1:1234/", - ValidateHTTPService(originURL, hostname, client.Transport).Error()) } +// Happy path 2: originURL is HTTPS, and HTTPS connections work func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { originURL := "https://127.0.0.1/" hostname := "example.com" + + assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Host, hostname) + if req.URL.Scheme == "http" { + t.Fatal("https works, shouldn't have tried with http") + } + if req.URL.Scheme == "https" { + return emptyResponse(200), nil + } + panic("Shouldn't reach here") + }))) + + assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Host, hostname) + if req.URL.Scheme == "http" { + t.Fatal("https works, shouldn't have tried with http") + } + if req.URL.Scheme == "https" { + return emptyResponse(503), nil + } + panic("Shouldn't reach here") + }))) + + // Integration-style test with a mock server server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == "CONNECT" { assert.Equal(t, "127.0.0.1:443", r.Host) @@ -213,13 +246,113 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { })) assert.NoError(t, err) defer server.Close() + assert.Nil(t, ValidateHTTPService(originURL, hostname, client.Transport)) - assert.Equal(t, nil, ValidateHTTPService(originURL, hostname, client.Transport)) + // this will fail if the client follows the 302 + redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/followedRedirect" { + t.Fatal("shouldn't have followed the 302") + } + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:443", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } + w.Header().Set("Location", "/followedRedirect") + w.WriteHeader(302) + })) + assert.NoError(t, err) + defer redirectServer.Close() + assert.Nil(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport)) } +// Error path 1: originURL is HTTPS, but HTTP connections work +func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { + originURL := "https://127.0.0.1:1234/" + hostname := "example.com" + + assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Host, hostname) + if req.URL.Scheme == "http" { + return emptyResponse(200), nil + } + if req.URL.Scheme == "https" { + return nil, assert.AnError + } + panic("Shouldn't reach here") + }))) + + assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Host, hostname) + if req.URL.Scheme == "http" { + return emptyResponse(503), nil + } + if req.URL.Scheme == "https" { + return nil, assert.AnError + } + panic("Shouldn't reach here") + }))) + + // Integration-style test with a mock server + server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:1234", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } + w.WriteHeader(200) + })) + assert.NoError(t, err) + defer server.Close() + assert.Error(t, ValidateHTTPService(originURL, hostname, client.Transport)) + + // this will fail if the client follows the 302 + redirectServer, redirectClient, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/followedRedirect" { + t.Fatal("shouldn't have followed the 302") + } + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:1234", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } + w.Header().Set("Location", "/followedRedirect") + w.WriteHeader(302) + })) + assert.NoError(t, err) + defer redirectServer.Close() + assert.Error(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport)) + +} + +// Error path 2: originURL is HTTP, but HTTPS connections work func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { originURL := "http://127.0.0.1:1234/" hostname := "example.com" + + assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Host, hostname) + if req.URL.Scheme == "http" { + return nil, assert.AnError + } + if req.URL.Scheme == "https" { + return emptyResponse(200), nil + } + panic("Shouldn't reach here") + }))) + + assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, req.Host, hostname) + if req.URL.Scheme == "http" { + return nil, assert.AnError + } + if req.URL.Scheme == "https" { + return emptyResponse(503), nil + } + panic("Shouldn't reach here") + }))) + + // Integration-style test with a mock server server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == "CONNECT" { assert.Equal(t, "127.0.0.1:1234", r.Host) @@ -230,10 +363,38 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { })) assert.NoError(t, err) defer server.Close() + assert.Error(t, ValidateHTTPService(originURL, hostname, client.Transport)) - assert.Equal(t, - "127.0.0.1:1234 doesn't seem to work over http, but does seem to work over https. Consider changing the origin URL to https://127.0.0.1:1234/", - ValidateHTTPService(originURL, hostname, client.Transport).Error()) + // this will fail if the client follows the 302 + redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/followedRedirect" { + t.Fatal("shouldn't have followed the 302") + } + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:443", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } + w.Header().Set("Location", "/followedRedirect") + w.WriteHeader(302) + })) + assert.NoError(t, err) + defer redirectServer.Close() + assert.Error(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport)) +} + +type testRoundTripper func(req *http.Request) (*http.Response, error) + +func (f testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func emptyResponse(statusCode int) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Body: ioutil.NopCloser(bytes.NewReader(nil)), + Header: make(http.Header), + } } func createMockServerAndClient(handler http.Handler) (*httptest.Server, *http.Client, error) {