diff --git a/streamhandler/request.go b/streamhandler/request.go index 40e36d2b..f82c7c7d 100644 --- a/streamhandler/request.go +++ b/streamhandler/request.go @@ -35,26 +35,49 @@ func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, erro return req, nil } +// H2RequestHeadersToH1Request converts the HTTP/2 headers to an HTTP/1 Request +// object. This includes conversion of the pseudo-headers into their closest +// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3 func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { for _, header := range h2 { switch header.Name { case ":method": h1.Method = header.Value case ":scheme": + // noop - use the preexisting scheme from h1.URL case ":authority": // Otherwise the host header will be based on the origin URL h1.Host = header.Value case ":path": - u, err := url.Parse(header.Value) + // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. + // However, this HTTP/1 Request object belongs to the Go standard library, + // whose URL package makes some opinionated decisions about the encoding of + // URL characters: see the docs of https://godoc.org/net/url#URL, + // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, + // which is always used when computing url.URL.String(), whether we'd like it or not. + // + // Well, not *always*. We could circumvent this by using url.URL.Opaque. But + // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque + // is treated differently when HTTP_PROXY is set! 
+ // See https://github.com/golang/go/issues/5684#issuecomment-66080888 + // + // This means we are subject to the behavior of net/url's function `shouldEscape` + // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 + + if header.Value == "*" { + h1.URL.Path = "*" + continue + } + // Due to the behavior of validation.ValidateUrl, h1.URL may + // already have a partial value, with or without a trailing slash. + base := h1.URL.String() + base = strings.TrimRight(base, "/") + // But we know :path begins with '/', because we handled '*' above - see RFC7540 + url, err := url.Parse(base + header.Value) if err != nil { - return fmt.Errorf("unparseable path") + return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) } - resolved := h1.URL.ResolveReference(u) - // prevent escaping base URL - if !strings.HasPrefix(resolved.String(), h1.URL.String()) { - return fmt.Errorf("invalid path %s", header.Value) - } - h1.URL = resolved + h1.URL = url case "content-length": contentLength, err := strconv.ParseInt(header.Value, 10, 64) if err != nil { diff --git a/streamhandler/request_test.go b/streamhandler/request_test.go new file mode 100644 index 00000000..d3a85f76 --- /dev/null +++ b/streamhandler/request_test.go @@ -0,0 +1,441 @@ +package streamhandler + +import ( + "fmt" + "math/rand" + "net/http" + "net/url" + "reflect" + "regexp" + "strings" + "testing" + "testing/quick" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + assert.NoError(t, err) + + headersConversionErr := H2RequestHeadersToH1Request( + []h2mux.Header{ + h2mux.Header{ + Name: "Mock header 1", + Value: "Mock value 1", + }, + h2mux.Header{ + Name: "Mock header 2", + Value: "Mock value 2", + }, + }, + request, + ) + + 
assert.Equal(t, http.Header{ + "Mock header 1": []string{"Mock value 1"}, + "Mock header 2": []string{"Mock value 2"}, + }, request.Header) + + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + assert.NoError(t, err) + + headersConversionErr := H2RequestHeadersToH1Request( + []h2mux.Header{}, + request, + ) + + assert.Equal(t, http.Header{}, request.Header) + + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + assert.NoError(t, err) + + headersConversionErr := H2RequestHeadersToH1Request( + []h2mux.Header{ + h2mux.Header{ + Name: ":path", + Value: "//bad_path/", + }, + h2mux.Header{ + Name: "Mock header", + Value: "Mock value", + }, + }, + request, + ) + + assert.Equal(t, http.Header{ + "Mock header": []string{"Mock value"}, + }, request.Header) + + assert.Equal(t, "http://example.com//bad_path/", request.URL.String()) + + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + assert.NoError(t, err) + + headersConversionErr := H2RequestHeadersToH1Request( + []h2mux.Header{ + h2mux.Header{ + Name: ":path", + Value: "/?query=mock%20value", + }, + h2mux.Header{ + Name: "Mock header", + Value: "Mock value", + }, + }, + request, + ) + + assert.Equal(t, http.Header{ + "Mock header": []string{"Mock value"}, + }, request.Header) + + assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String()) + + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + assert.NoError(t, err) + + 
headersConversionErr := H2RequestHeadersToH1Request( + []h2mux.Header{ + h2mux.Header{ + Name: ":path", + Value: "/mock%20path", + }, + h2mux.Header{ + Name: "Mock header", + Value: "Mock value", + }, + }, + request, + ) + + assert.Equal(t, http.Header{ + "Mock header": []string{"Mock value"}, + }, request.Header) + + assert.Equal(t, "http://example.com/mock%20path", request.URL.String()) + + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) { + type testCase struct { + path string + want string + } + testCases := []testCase{ + { + path: "", + want: "", + }, + { + path: "/", + want: "/", + }, + { + path: "//", + want: "//", + }, + { + path: "/test", + want: "/test", + }, + { + path: "//test", + want: "//test", + }, + { + // https://github.com/cloudflare/cloudflared/issues/81 + path: "//test/", + want: "//test/", + }, + { + path: "/%2Ftest", + want: "/%2Ftest", + }, + { + path: "//%20test", + want: "//%20test", + }, + { + // https://github.com/cloudflare/cloudflared/issues/124 + path: "/test?get=somthing%20a", + want: "/test?get=somthing%20a", + }, + { + path: "/%20", + want: "/%20", + }, + { + // stdlib's EscapedPath() will always percent-encode ' ' + path: "/ ", + want: "/%20", + }, + { + path: "/ a ", + want: "/%20a%20", + }, + { + path: "/a%20b", + want: "/a%20b", + }, + { + path: "/foo/bar;param?query#frag", + want: "/foo/bar;param?query#frag", + }, + { + // stdlib's EscapedPath() will always percent-encode non-ASCII chars + path: "/a␠b", + want: "/a%E2%90%A0b", + }, + { + path: "/a-umlaut-ä", + want: "/a-umlaut-%C3%A4", + }, + { + path: "/a-umlaut-%C3%A4", + want: "/a-umlaut-%C3%A4", + }, + { + path: "/a-umlaut-%c3%a4", + want: "/a-umlaut-%c3%a4", + }, + { + // here the second '#' is treated as part of the fragment + path: "/a#b#c", + want: "/a#b%23c", + }, + { + path: "/a#b␠c", + want: "/a#b%E2%90%A0c", + }, + { + path: "/a#b%20c", + want: "/a#b%20c", + }, + { + path: "/a#b c", + want: "/a#b%20c", + }, 
+ { + // stdlib's EscapedPath() will always percent-encode '\' + path: "/\\", + want: "/%5C", + }, + { + path: "/a\\", + want: "/a%5C", + }, + { + path: "/a,b.c.", + want: "/a,b.c.", + }, + { + path: "/.", + want: "/.", + }, + { + // stdlib's EscapedPath() will always percent-encode '`' + path: "/a`", + want: "/a%60", + }, + { + path: "/a[0]", + want: "/a[0]", + }, + { + path: "/?a[0]=5 &b[]=", + want: "/?a[0]=5 &b[]=", + }, + { + path: "/?a=%22b%20%22", + want: "/?a=%22b%20%22", + }, + } + + for index, testCase := range testCases { + requestURL := "https://example.com" + + request, err := http.NewRequest(http.MethodGet, requestURL, nil) + assert.NoError(t, err) + headersConversionErr := H2RequestHeadersToH1Request( + []h2mux.Header{ + h2mux.Header{ + Name: ":path", + Value: testCase.path, + }, + h2mux.Header{ + Name: "Mock header", + Value: "Mock value", + }, + }, + request, + ) + assert.NoError(t, headersConversionErr) + + assert.Equal(t, + http.Header{ + "Mock header": []string{"Mock value"}, + }, + request.Header) + + assert.Equal(t, + "https://example.com"+testCase.want, + request.URL.String(), + "Failed URL index: %v %#v", index, testCase) + } +} + +func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) { + config := &quick.Config{ + Values: func(args []reflect.Value, rand *rand.Rand) { + args[0] = reflect.ValueOf(randomHTTP2Path(t, rand)) + }, + } + + type testOrigin struct { + url string + + expectedScheme string + expectedBasePath string + } + testOrigins := []testOrigin{ + { + url: "http://origin.hostname.example.com:8080", + expectedScheme: "http", + expectedBasePath: "http://origin.hostname.example.com:8080", + }, + { + url: "http://origin.hostname.example.com:8080/", + expectedScheme: "http", + expectedBasePath: "http://origin.hostname.example.com:8080", + }, + { + url: "http://origin.hostname.example.com:8080/api", + expectedScheme: "http", + expectedBasePath: "http://origin.hostname.example.com:8080/api", + }, + { + url: 
"http://origin.hostname.example.com:8080/api/", + expectedScheme: "http", + expectedBasePath: "http://origin.hostname.example.com:8080/api", + }, + { + url: "https://origin.hostname.example.com:8080/api", + expectedScheme: "https", + expectedBasePath: "https://origin.hostname.example.com:8080/api", + }, + } + + // use multiple schemes to demonstrate that the URL is based on the + // origin's scheme, not the :scheme header + for _, testScheme := range []string{"http", "https"} { + for _, testOrigin := range testOrigins { + assertion := func(testPath string) bool { + const expectedMethod = "POST" + const expectedHostname = "request.hostname.example.com" + + h2 := []h2mux.Header{ + h2mux.Header{Name: ":method", Value: expectedMethod}, + h2mux.Header{Name: ":scheme", Value: testScheme}, + h2mux.Header{Name: ":authority", Value: expectedHostname}, + h2mux.Header{Name: ":path", Value: testPath}, + } + h1, err := http.NewRequest("GET", testOrigin.url, nil) + require.NoError(t, err) + + err = H2RequestHeadersToH1Request(h2, h1) + return assert.NoError(t, err) && + assert.Equal(t, expectedMethod, h1.Method) && + assert.Equal(t, expectedHostname, h1.Host) && + assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) && + assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String()) + } + err := quick.Check(assertion, config) + assert.NoError(t, err) + } + } +} + +func randomASCIIPrintableChar(rand *rand.Rand) int { + // smallest printable ASCII char is 32, largest is 126 + const startPrintable = 32 + const endPrintable = 127 + return startPrintable + rand.Intn(endPrintable-startPrintable) +} + +// randomASCIIText generates an ASCII string, some of whose characters may be +// percent-encoded. Its "logical length" (ignoring percent-encoding) is +// between `minLength` and `minLength+maxLength-1`.
+func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string { + length := minLength + rand.Intn(maxLength) + result := "" + for i := 0; i < length; i++ { + c := randomASCIIPrintableChar(rand) + + // 1/4 chance of using percent encoding when not necessary + if c == '%' || rand.Intn(4) == 0 { + result += fmt.Sprintf("%%%02X", c) + } else { + result += string(c) + } + } + return result +} + +// Calls `randomASCIIText` and ensures the result is a valid URL path, +// i.e. one that can pass unchanged through url.URL.String() +func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { + text := randomASCIIText(rand, minLength, maxLength) + regexp, err := regexp.Compile("[^/;,]*") + require.NoError(t, err) + return "/" + regexp.ReplaceAllStringFunc(text, url.PathEscape) +} + +// Calls `randomASCIIText` and ensures the result is a valid URL query, +// i.e. one that can pass unchanged through url.URL.String() +func randomHTTP1Query(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { + text := randomASCIIText(rand, minLength, maxLength) + return "?" + strings.ReplaceAll(text, "#", "%23") +} + +// Calls `randomASCIIText` and ensures the result is a valid URL fragment, +// i.e. one that can pass unchanged through url.URL.String() +func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { + text := randomASCIIText(rand, minLength, maxLength) + url, err := url.Parse("#" + text) + require.NoError(t, err) + return url.String() +} + +// Assemble a random :path pseudoheader that is legal by Go stdlib standards +// (i.e. 
all characters will satisfy "net/url".shouldEscape for their respective locations) +func randomHTTP2Path(t *testing.T, rand *rand.Rand) string { + result := randomHTTP1Path(t, rand, 1, 64) + if rand.Intn(2) == 1 { + result += randomHTTP1Query(t, rand, 1, 32) + } + if rand.Intn(2) == 1 { + result += randomHTTP1Fragment(t, rand, 1, 16) + } + return result +} diff --git a/validation/validation.go b/validation/validation.go index 7d7589d1..83857230 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -60,6 +60,12 @@ func ValidateHostname(hostname string) (string, error) { } +// ValidateUrl returns a validated version of `originUrl` with a scheme prepended (by default http://). +// Note: when originUrl contains a scheme, the path is removed: +// ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080" +// but when it does not, the path is preserved: +// ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/" +// This is arguably a bug, but changing it might break some cloudflared users. func ValidateUrl(originUrl string) (string, error) { if originUrl == "" { return "", fmt.Errorf("URL should not be empty") @@ -121,6 +127,8 @@ func ValidateUrl(originUrl string) (string, error) { if err != nil { return "", fmt.Errorf("URL %s has invalid format", originUrl) } + // This is why the path is preserved when `originUrl` doesn't have a scheme.
+ // Using `parsedUrl.Port()` here, instead of `port`, would remove the path return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil } } diff --git a/validation/validation_test.go b/validation/validation_test.go index 6a7ec48b..866414ad 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -53,98 +53,65 @@ func TestValidateHostname(t *testing.T) { } func TestValidateUrl(t *testing.T) { + type testCase struct { + input string + expectedOutput string + } + testCases := []testCase{ + {"http://localhost", "http://localhost"}, + {"http://localhost/", "http://localhost"}, + {"http://localhost/api", "http://localhost"}, + {"http://localhost/api/", "http://localhost"}, + {"https://localhost", "https://localhost"}, + {"https://localhost/", "https://localhost"}, + {"https://localhost/api", "https://localhost"}, + {"https://localhost/api/", "https://localhost"}, + {"https://localhost:8080", "https://localhost:8080"}, + {"https://localhost:8080/", "https://localhost:8080"}, + {"https://localhost:8080/api", "https://localhost:8080"}, + {"https://localhost:8080/api/", "https://localhost:8080"}, + {"localhost", "http://localhost"}, + {"localhost/", "http://localhost/"}, + {"localhost/api", "http://localhost/api"}, + {"localhost/api/", "http://localhost/api/"}, + {"localhost:8080", "http://localhost:8080"}, + {"localhost:8080/", "http://localhost:8080/"}, + {"localhost:8080/api", "http://localhost:8080/api"}, + {"localhost:8080/api/", "http://localhost:8080/api/"}, + {"localhost:8080/api/?asdf", "http://localhost:8080/api/?asdf"}, + {"http://127.0.0.1:8080", "http://127.0.0.1:8080"}, + {"127.0.0.1:8080", "http://127.0.0.1:8080"}, + {"127.0.0.1", "http://127.0.0.1"}, + {"https://127.0.0.1:8080", "https://127.0.0.1:8080"}, + {"[::1]:8080", "http://[::1]:8080"}, + {"http://[::1]", "http://[::1]"}, + {"http://[::1]:8080", "http://[::1]:8080"}, + {"[::1]", "http://[::1]"}, + {"https://example.com", "https://example.com"}, + 
{"example.com", "http://example.com"}, + {"http://hello.example.com", "http://hello.example.com"}, + {"hello.example.com", "http://hello.example.com"}, + {"hello.example.com:8080", "http://hello.example.com:8080"}, + {"https://hello.example.com:8080", "https://hello.example.com:8080"}, + {"https://bücher.example.com", "https://xn--bcher-kva.example.com"}, + {"bücher.example.com", "http://xn--bcher-kva.example.com"}, + {"https%3A%2F%2Fhello.example.com", "https://hello.example.com"}, + {"https://alex:12345@hello.example.com:8080", "https://hello.example.com:8080"}, + } + for i, testCase := range testCases { + validUrl, err := ValidateUrl(testCase.input) + assert.NoError(t, err, "test case %v", i) + assert.Equal(t, testCase.expectedOutput, validUrl, "test case %v", i) + } + validUrl, err := ValidateUrl("") assert.Equal(t, fmt.Errorf("URL should not be empty"), err) assert.Empty(t, validUrl) - validUrl, err = ValidateUrl("https://localhost:8080") - assert.Nil(t, err) - assert.Equal(t, "https://localhost:8080", validUrl) - - validUrl, err = ValidateUrl("localhost:8080") - assert.Nil(t, err) - assert.Equal(t, "http://localhost:8080", validUrl) - - validUrl, err = ValidateUrl("http://localhost") - assert.Nil(t, err) - assert.Equal(t, "http://localhost", validUrl) - - validUrl, err = ValidateUrl("http://127.0.0.1:8080") - assert.Nil(t, err) - assert.Equal(t, "http://127.0.0.1:8080", validUrl) - - validUrl, err = ValidateUrl("127.0.0.1:8080") - assert.Nil(t, err) - assert.Equal(t, "http://127.0.0.1:8080", validUrl) - - validUrl, err = ValidateUrl("127.0.0.1") - assert.Nil(t, err) - assert.Equal(t, "http://127.0.0.1", validUrl) - - validUrl, err = ValidateUrl("https://127.0.0.1:8080") - assert.Nil(t, err) - assert.Equal(t, "https://127.0.0.1:8080", validUrl) - - validUrl, err = ValidateUrl("[::1]:8080") - assert.Nil(t, err) - assert.Equal(t, "http://[::1]:8080", validUrl) - - validUrl, err = ValidateUrl("http://[::1]") - assert.Nil(t, err) - assert.Equal(t, "http://[::1]", 
validUrl) - - validUrl, err = ValidateUrl("http://[::1]:8080") - assert.Nil(t, err) - assert.Equal(t, "http://[::1]:8080", validUrl) - - validUrl, err = ValidateUrl("[::1]") - assert.Nil(t, err) - assert.Equal(t, "http://[::1]", validUrl) - - validUrl, err = ValidateUrl("https://example.com") - assert.Nil(t, err) - assert.Equal(t, "https://example.com", validUrl) - - validUrl, err = ValidateUrl("example.com") - assert.Nil(t, err) - assert.Equal(t, "http://example.com", validUrl) - - validUrl, err = ValidateUrl("http://hello.example.com") - assert.Nil(t, err) - assert.Equal(t, "http://hello.example.com", validUrl) - - validUrl, err = ValidateUrl("hello.example.com") - assert.Nil(t, err) - assert.Equal(t, "http://hello.example.com", validUrl) - - validUrl, err = ValidateUrl("hello.example.com:8080") - assert.Nil(t, err) - assert.Equal(t, "http://hello.example.com:8080", validUrl) - - validUrl, err = ValidateUrl("https://hello.example.com:8080") - assert.Nil(t, err) - assert.Equal(t, "https://hello.example.com:8080", validUrl) - - validUrl, err = ValidateUrl("https://bücher.example.com") - assert.Nil(t, err) - assert.Equal(t, "https://xn--bcher-kva.example.com", validUrl) - - validUrl, err = ValidateUrl("bücher.example.com") - assert.Nil(t, err) - assert.Equal(t, "http://xn--bcher-kva.example.com", validUrl) - - validUrl, err = ValidateUrl("https%3A%2F%2Fhello.example.com") - assert.Nil(t, err) - assert.Equal(t, "https://hello.example.com", validUrl) - validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt") assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error()) assert.Empty(t, validUrl) - validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080") - assert.Nil(t, err) - assert.Equal(t, "https://hello.example.com:8080", validUrl) - } func TestToggleProtocol(t *testing.T) {