diff --git a/h2mux/header.go b/h2mux/header.go index 822f99bb..20a878e8 100644 --- a/h2mux/header.go +++ b/h2mux/header.go @@ -1,7 +1,6 @@ package h2mux import ( - "bytes" "encoding/base64" "fmt" "github.com/pkg/errors" @@ -17,9 +16,130 @@ type Header struct { var headerEncoding = base64.RawStdEncoding -// OldH2RequestHeadersToH1Request converts the HTTP/2 headers to an HTTP/1 Request -// object. This includes conversion of the pseudo-headers into their closest +const ( + RequestUserHeadersField = "cf-cloudflared-request-headers" + ResponseUserHeadersField = "cf-cloudflared-response-headers" +) + +// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld +// to an HTTP/1 Request object destined for the local origin web service. +// This operation includes conversion of the pseudo-headers into their closest // HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3 +func H2RequestHeadersToH1Request(h2 []Header, h1 *http.Request) error { + for _, header := range h2 { + switch strings.ToLower(header.Name) { + case ":method": + h1.Method = header.Value + case ":scheme": + // noop - use the preexisting scheme from h1.URL + case ":authority": + // Otherwise the host header will be based on the origin URL + h1.Host = header.Value + case ":path": + // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. + // However, this HTTP/1 Request object belongs to the Go standard library, + // whose URL package makes some opinionated decisions about the encoding of + // URL characters: see the docs of https://godoc.org/net/url#URL, + // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, + // which is always used when computing url.URL.String(), whether we'd like it or not. + // + // Well, not *always*. We could circumvent this by using url.URL.Opaque. But + // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque + // is treated differently when HTTP_PROXY is set! + // See https://github.com/golang/go/issues/5684#issuecomment-66080888 + // + // This means we are subject to the behavior of net/url's function `shouldEscape` + // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 + + if header.Value == "*" { + h1.URL.Path = "*" + continue + } + // Due to the behavior of validation.ValidateUrl, h1.URL may + // already have a partial value, with or without a trailing slash. + base := h1.URL.String() + base = strings.TrimRight(base, "/") + // But we know :path begins with '/', because we handled '*' above - see RFC7540 + requestURL, err := url.Parse(base + header.Value) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) + } + h1.URL = requestURL + case "content-length": + contentLength, err := strconv.ParseInt(header.Value, 10, 64) + if err != nil { + return fmt.Errorf("unparseable content length") + } + h1.ContentLength = contentLength + default: + // Ignore any other header; + // User headers will be read from `RequestUserHeadersField` + continue + } + } + + // Find and parse user headers serialized into a single one + userHeaders, err := ParseUserHeaders(RequestUserHeadersField, h2) + if err != nil { + return errors.Wrap(err, "Unable to parse user headers") + } + for _, userHeader := range userHeaders { + h1.Header.Add(http.CanonicalHeaderKey(userHeader.Name), userHeader.Value) + } + + return nil +} + +func ParseUserHeaders(headerNameToParseFrom string, headers []Header) ([]Header, error) { + for _, header := range headers { + if header.Name == headerNameToParseFrom { + return DeserializeHeaders(header.Value) + } + } + + return nil, fmt.Errorf("%v header not found", RequestUserHeadersField) +} + +func IsControlHeader(headerName string) bool { + headerName = strings.ToLower(headerName) + + return strings.ToLower(headerName) == "content-length" || + strings.HasPrefix(headerName, ":") || + strings.HasPrefix(headerName, "cf-") +} + +func H1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) { + h2 = []Header{ + {Name: ":status", Value: strconv.Itoa(h1.StatusCode)}, + } + userHeaders := http.Header{} + for header, values := range h1.Header { + for _, value := range values { + if strings.ToLower(header) == "content-length" { + // This header has meaning in HTTP/2 and will be used by the edge, + // so it should be sent as an HTTP/2 response header. + + // Since these are http2 headers, they're required to be lowercase + h2 = append(h2, Header{Name: strings.ToLower(header), Value: value}) + } else if !IsControlHeader(header) { + // User headers, on the other hand, must all be serialized so that + // HTTP/2 header validation won't be applied to HTTP/1 header values + if _, ok := userHeaders[header]; ok { + userHeaders[header] = append(userHeaders[header], value) + } else { + userHeaders[header] = []string{value} + } + } + } + } + + // Perform user header serialization and set them in the single header + h2 = append(h2, CreateSerializedHeaders(ResponseUserHeadersField, userHeaders)...) + + return h2 +} + +// Obsolete version of H2RequestHeadersToH1Request func OldH2RequestHeadersToH1Request(h2 []Header, h1 *http.Request) error { for _, header := range h2 { switch header.Name { @@ -73,6 +193,7 @@ func OldH2RequestHeadersToH1Request(h2 []Header, h1 *http.Request) error { return nil } +// Obsolete version of H1ResponseToH2ResponseHeaders func OldH1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) { h2 = []Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}} for headerName, headerValues := range h1.Header { @@ -86,9 +207,9 @@ func OldH1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) { // Serialize HTTP1.x headers by base64-encoding each header name and value, // and then joining them in the format of [key:value;] -func SerializeHeaders(h1 *http.Request) []byte { - var serializedHeaders [][]byte - for headerName, headerValues := range h1.Header { +func SerializeHeaders(h1Headers http.Header) string { + var serializedHeaders []string + for headerName, headerValues := range h1Headers { for _, headerValue := range headerValues { encodedName := make([]byte, headerEncoding.EncodedLen(len(headerName))) headerEncoding.Encode(encodedName, []byte(headerName)) @@ -98,28 +219,28 @@ func SerializeHeaders(h1 *http.Request) []byte { serializedHeaders = append( serializedHeaders, - bytes.Join( - [][]byte{encodedName, encodedValue}, - []byte(":"), + strings.Join( + []string{string(encodedName), string(encodedValue)}, + ":", ), ) } } - return bytes.Join(serializedHeaders, []byte(";")) + return strings.Join(serializedHeaders, ";") } // Deserialize headers serialized by `SerializeHeader` -func DeserializeHeaders(serializedHeaders []byte) (http.Header, error) { +func DeserializeHeaders(serializedHeaders string) ([]Header, error) { const unableToDeserializeErr = "Unable to deserialize headers" - deserialized := http.Header{} - for _, serializedPair := range bytes.Split(serializedHeaders, []byte(";")) { + var deserialized []Header + for _, serializedPair := range strings.Split(serializedHeaders, ";") { if len(serializedPair) == 0 { continue } - serializedHeaderParts := bytes.Split(serializedPair, []byte(":")) + serializedHeaderParts := strings.Split(serializedPair, ":") if len(serializedHeaderParts) != 2 { return nil, errors.New(unableToDeserializeErr) } @@ -129,15 +250,30 @@ func DeserializeHeaders(serializedHeaders []byte) (http.Header, error) { deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName))) deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue))) - if _, err := headerEncoding.Decode(deserializedName, serializedName); err != nil { + if _, err := headerEncoding.Decode(deserializedName, []byte(serializedName)); err != nil { return nil, errors.Wrap(err, unableToDeserializeErr) } - if _, err := headerEncoding.Decode(deserializedValue, serializedValue); err != nil { + if _, err := headerEncoding.Decode(deserializedValue, []byte(serializedValue)); err != nil { return nil, errors.Wrap(err, unableToDeserializeErr) } - deserialized.Add(string(deserializedName), string(deserializedValue)) + deserialized = append(deserialized, Header{ + Name: string(deserializedName), + Value: string(deserializedValue), + }) } return deserialized, nil } + +func CreateSerializedHeaders(headersField string, headers ...http.Header) []Header { + var serializedHeaderChunks []string + for _, headerChunk := range headers { + serializedHeaderChunks = append(serializedHeaderChunks, SerializeHeaders(headerChunk)) + } + + return []Header{{ + headersField, + strings.Join(serializedHeaderChunks, ";"), + }} +} diff --git a/h2mux/header_test.go b/h2mux/header_test.go index 8c5301eb..b5781d91 100644 --- a/h2mux/header_test.go +++ b/h2mux/header_test.go @@ -7,6 +7,7 @@ import ( "net/url" "reflect" "regexp" + "sort" "strings" "testing" "testing/quick" @@ -15,29 +16,30 @@ import ( "github.com/stretchr/testify/require" ) +type ByName []Header + +func (a ByName) Len() int { return len(a) } +func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ByName) Less(i, j int) bool { + if a[i].Name == a[j].Name { + return a[i].Value < a[j].Value + } + + return a[i].Name < a[j].Name +} + func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) assert.NoError(t, err) - headersConversionErr := OldH2RequestHeadersToH1Request( - []Header{ - { - Name: "Mock header 1", - Value: "Mock value 1", - }, - { - Name: "Mock header 2", - Value: "Mock value 2", - }, - }, - request, - ) + mockHeaders := http.Header{ + "Mock header 1": {"Mock value 1"}, + "Mock header 2": {"Mock value 2"}, + } - assert.Equal(t, http.Header{ - "Mock header 1": []string{"Mock value 1"}, - "Mock header 2": []string{"Mock value 2"}, - }, request.Header) + headersConversionErr := H2RequestHeadersToH1Request(CreateSerializedHeaders(RequestUserHeadersField, mockHeaders), request) + assert.True(t, reflect.DeepEqual(mockHeaders, request.Header)) assert.NoError(t, headersConversionErr) } @@ -45,13 +47,15 @@ func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) assert.NoError(t, err) - headersConversionErr := OldH2RequestHeadersToH1Request( - []Header{}, + headersConversionErr := H2RequestHeadersToH1Request( + []Header{{ + RequestUserHeadersField, + SerializeHeaders(http.Header{}), + }}, request, ) - assert.Equal(t, http.Header{}, request.Header) - + assert.True(t, reflect.DeepEqual(http.Header{}, request.Header)) assert.NoError(t, headersConversionErr) } @@ -59,19 +63,12 @@ func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) assert.NoError(t, err) - headersConversionErr := OldH2RequestHeadersToH1Request( - []Header{ - { - Name: ":path", - Value: "//bad_path/", - }, - { - Name: "Mock header", - Value: "Mock value", - }, - }, - request, - ) + mockRequestHeaders := []Header{ + {Name: ":path", Value: "//bad_path/"}, + {Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + } + + headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) assert.Equal(t, http.Header{ "Mock header": []string{"Mock value"}, @@ -86,19 +83,12 @@ func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) assert.NoError(t, err) - headersConversionErr := OldH2RequestHeadersToH1Request( - []Header{ - { - Name: ":path", - Value: "/?query=mock%20value", - }, - { - Name: "Mock header", - Value: "Mock value", - }, - }, - request, - ) + mockRequestHeaders := []Header{ + {Name: ":path", Value: "/?query=mock%20value"}, + {Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + } + + headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) assert.Equal(t, http.Header{ "Mock header": []string{"Mock value"}, @@ -113,19 +103,12 @@ func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) assert.NoError(t, err) - headersConversionErr := OldH2RequestHeadersToH1Request( - []Header{ - { - Name: ":path", - Value: "/mock%20path", - }, - { - Name: "Mock header", - Value: "Mock value", - }, - }, - request, - ) + mockRequestHeaders := []Header{ + {Name: ":path", Value: "/mock%20path"}, + {Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + } + + headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) assert.Equal(t, http.Header{ "Mock header": []string{"Mock value"}, @@ -276,19 +259,13 @@ func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) { request, err := http.NewRequest(http.MethodGet, requestURL, nil) assert.NoError(t, err) - headersConversionErr := OldH2RequestHeadersToH1Request( - []Header{ - { - Name: ":path", - Value: testCase.path, - }, - { - Name: "Mock header", - Value: "Mock value", - }, - }, - request, - ) + + mockRequestHeaders := []Header{ + {Name: ":path", Value: testCase.path}, + {Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + } + + headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) assert.NoError(t, headersConversionErr) assert.Equal(t, @@ -358,11 +335,12 @@ func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) { {Name: ":scheme", Value: testScheme}, {Name: ":authority", Value: expectedHostname}, {Name: ":path", Value: testPath}, + {Name: RequestUserHeadersField, Value: ""}, } h1, err := http.NewRequest("GET", testOrigin.url, nil) require.NoError(t, err) - err = OldH2RequestHeadersToH1Request(h2, h1) + err = H2RequestHeadersToH1Request(h2, h1) return assert.NoError(t, err) && assert.Equal(t, expectedMethod, h1.Method) && assert.Equal(t, expectedHostname, h1.Host) && @@ -439,11 +417,21 @@ func randomHTTP2Path(t *testing.T, rand *rand.Rand) string { return result } +func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []Header) { + for name, values := range headers { + for _, value := range values { + h2muxHeaders = append(h2muxHeaders, Header{name, value}) + } + } + + return h2muxHeaders +} + func TestSerializeHeaders(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) assert.NoError(t, err) - mockHeaders := map[string][]string{ + mockHeaders := http.Header{ "Mock-Header-One": {"Mock header one value", "three"}, "Mock-Header-Two-Long": {"Mock header two value\nlong"}, ":;": {":;", ";:"}, @@ -465,7 +453,7 @@ func TestSerializeHeaders(t *testing.T) { } } - serializedHeaders := SerializeHeaders(request) + serializedHeaders := SerializeHeaders(request.Header) // Sanity check: the headers serialized to something that's not an empty string assert.NotEqual(t, "", serializedHeaders) @@ -474,18 +462,24 @@ func TestSerializeHeaders(t *testing.T) { deserializedHeaders, err := DeserializeHeaders(serializedHeaders) assert.NoError(t, err) - assert.Equal(t, len(mockHeaders), len(deserializedHeaders)) - for header, value := range deserializedHeaders { - assert.NotEqual(t, "", value) - assert.Equal(t, mockHeaders[header], value) - } + assert.Equal(t, 13, len(deserializedHeaders)) + h2muxExpectedHeaders := stdlibHeaderToH2muxHeader(mockHeaders) + + sort.Sort(ByName(deserializedHeaders)) + sort.Sort(ByName(h2muxExpectedHeaders)) + + assert.True( + t, + reflect.DeepEqual(h2muxExpectedHeaders, deserializedHeaders), + fmt.Sprintf("got = %#v, want = %#v\n", deserializedHeaders, h2muxExpectedHeaders), + ) } func TestSerializeNoHeaders(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) assert.NoError(t, err) - serializedHeaders := SerializeHeaders(request) + serializedHeaders := SerializeHeaders(request.Header) deserializedHeaders, err := DeserializeHeaders(serializedHeaders) assert.NoError(t, err) assert.Equal(t, 0, len(deserializedHeaders)) @@ -502,7 +496,117 @@ func TestDeserializeMalformed(t *testing.T) { } for _, malformedValue := range malformedData { - _, err = DeserializeHeaders([]byte(malformedValue)) + _, err = DeserializeHeaders(malformedValue) assert.Error(t, err) } } + +func TestParseHeaders(t *testing.T) { + mockUserHeadersToSerialize := http.Header{ + "Mock-Header-One": {"1", "1.5"}, + "Mock-Header-Two": {"2"}, + "Mock-Header-Three": {"3"}, + } + + mockHeaders := []Header{ + {Name: "One", Value: "1"}, + {Name: "Cf-Two", Value: "cf-value-1"}, + {Name: "Cf-Two", Value: "cf-value-2"}, + {Name: RequestUserHeadersField, Value: SerializeHeaders(mockUserHeadersToSerialize)}, + } + + expectedHeaders := []Header{ + {Name: "Mock-Header-One", Value: "1"}, + {Name: "Mock-Header-One", Value: "1.5"}, + {Name: "Mock-Header-Two", Value: "2"}, + {Name: "Mock-Header-Three", Value: "3"}, + } + parsedHeaders, err := ParseUserHeaders(RequestUserHeadersField, mockHeaders) + assert.NoError(t, err) + assert.ElementsMatch(t, expectedHeaders, parsedHeaders) +} + +func TestParseHeadersNoSerializedHeader(t *testing.T) { + mockHeaders := []Header{ + {Name: "One", Value: "1"}, + {Name: "Cf-Two", Value: "cf-value-1"}, + {Name: "Cf-Two", Value: "cf-value-2"}, + } + + _, err := ParseUserHeaders(RequestUserHeadersField, mockHeaders) + assert.EqualError(t, err, fmt.Sprintf("%s header not found", RequestUserHeadersField)) +} + +func TestIsControlHeader(t *testing.T) { + controlHeaders := []string{ + // Anything that begins with cf- + "cf-sample-header", + "CF-SAMPLE-HEADER", + "Cf-Sample-Header", + + // Any http2 pseudoheader + ":sample-pseudo-header", + + // content-length is a special case, it has to be there + // for some requests to work (per the HTTP2 spec) + "content-length", + } + + for _, header := range controlHeaders { + assert.True(t, IsControlHeader(header)) + } +} + +func TestIsNotControlHeader(t *testing.T) { + notControlHeaders := []string{ + "Mock-header", + "Another-sample-header", + } + + for _, header := range notControlHeaders { + assert.False(t, IsControlHeader(header)) + } +} + +func TestH1ResponseToH2ResponseHeaders(t *testing.T) { + mockHeaders := http.Header{ + "User-header-one": {""}, + "User-header-two": {"1", "2"}, + "cf-header": {"cf-value"}, + "Content-Length": {"123"}, + } + mockResponse := http.Response{ + StatusCode: 200, + Header: mockHeaders, + } + + headers := H1ResponseToH2ResponseHeaders(&mockResponse) + + serializedHeadersIndex := -1 + for i, header := range headers { + if header.Name == ResponseUserHeadersField { + serializedHeadersIndex = i + break + } + } + assert.NotEqual(t, -1, serializedHeadersIndex) + actualControlHeaders := append( + headers[:serializedHeadersIndex], + headers[serializedHeadersIndex+1:]..., + ) + expectedControlHeaders := []Header{ + {Name: ":status", Value: "200"}, + {Name: "content-length", Value: "123"}, + } + + assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders) + + actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value) + expectedUserHeaders := []Header{ + {Name: "User-header-one", Value: ""}, + {Name: "User-header-two", Value: "1"}, + {Name: "User-header-two", Value: "2"}, + } + assert.NoError(t, err) + assert.ElementsMatch(t, expectedUserHeaders, actualUserHeaders) +}