From 6ca642e572eb83726df996318302c8396a84fb6f Mon Sep 17 00:00:00 2001 From: Nick Vollmar Date: Wed, 27 Feb 2019 16:47:00 -0600 Subject: [PATCH] TUN-1550: Add validation timeout for non-responsive origins --- validation/validation.go | 7 ++++++- validation/validation_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/validation/validation.go b/validation/validation.go index f0d80964..75fd3052 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -5,6 +5,7 @@ import ( "net" "net/url" "strings" + "time" "net/http" @@ -14,7 +15,10 @@ import ( const defaultScheme = "http" -var supportedProtocol = [2]string{"http", "https"} +var ( + supportedProtocol = [2]string{"http", "https"} + validationTimeout = time.Duration(30 * time.Second) +) func ValidateHostname(hostname string) (string, error) { if hostname == "" { @@ -149,6 +153,7 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, + Timeout: validationTimeout, } initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil) diff --git a/validation/validation_test.go b/validation/validation_test.go index e4713a68..c0369ae9 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "testing" + "time" "context" "crypto/tls" @@ -383,6 +384,36 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { assert.Error(t, ValidateHTTPService(originURL, hostname, redirectClient.Transport)) } +// error path 3: origin URL is nonresponsive +func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) { + originURL := "https://127.0.0.1/" + hostname := "example.com" + oldValidationTimeout := validationTimeout + defer func() { + validationTimeout = oldValidationTimeout + }() + validationTimeout = 500 * time.Millisecond + + server, client, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + assert.Equal(t, "127.0.0.1:443", r.Host) + } else { + assert.Equal(t, hostname, r.Host) + } + time.Sleep(1 * time.Second) + w.WriteHeader(200) + })) + if !assert.NoError(t, err) { + t.FailNow() + } + defer server.Close() + + err = ValidateHTTPService(originURL, hostname, client.Transport) + if err, ok := err.(net.Error); assert.True(t, ok) { + assert.True(t, err.Timeout()) + } +} + type testRoundTripper func(req *http.Request) (*http.Response, error) func (f testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {