diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 8a3330d3..2b6defde 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -90,10 +90,6 @@ type Muxer struct { compressionQuality CompressionPreset } -type Header struct { - Name, Value string -} - func RPCHeaders() []Header { return []Header{ {Name: ":method", Value: "RPC"}, diff --git a/h2mux/header.go b/h2mux/header.go new file mode 100644 index 00000000..4da9e03c --- /dev/null +++ b/h2mux/header.go @@ -0,0 +1,143 @@ +package h2mux + +import ( + "bytes" + "encoding/base64" + "fmt" + "github.com/pkg/errors" + "net/http" + "net/url" + "strconv" + "strings" +) + +type Header struct { + Name, Value string +} + +var headerEncoding = base64.RawStdEncoding + +// H2RequestHeadersToH1Request converts the HTTP/2 headers to an HTTP/1 Request +// object. This includes conversion of the pseudo-headers into their closest +// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3 +func H2RequestHeadersToH1Request(h2 []Header, h1 *http.Request) error { + for _, header := range h2 { + switch header.Name { + case ":method": + h1.Method = header.Value + case ":scheme": + // noop - use the preexisting scheme from h1.URL + case ":authority": + // Otherwise the host header will be based on the origin URL + h1.Host = header.Value + case ":path": + // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. + // However, this HTTP/1 Request object belongs to the Go standard library, + // whose URL package makes some opinionated decisions about the encoding of + // URL characters: see the docs of https://godoc.org/net/url#URL, + // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, + // which is always used when computing url.URL.String(), whether we'd like it or not. + // + // Well, not *always*. We could circumvent this by using url.URL.Opaque. But + // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque + // is treated differently when HTTP_PROXY is set! + // See https://github.com/golang/go/issues/5684#issuecomment-66080888 + // + // This means we are subject to the behavior of net/url's function `shouldEscape` + // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 + + if header.Value == "*" { + h1.URL.Path = "*" + continue + } + // Due to the behavior of validation.ValidateUrl, h1.URL may + // already have a partial value, with or without a trailing slash. + base := h1.URL.String() + base = strings.TrimRight(base, "/") + // But we know :path begins with '/', because we handled '*' above - see RFC7540 + url, err := url.Parse(base + header.Value) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) + } + h1.URL = url + case "content-length": + contentLength, err := strconv.ParseInt(header.Value, 10, 64) + if err != nil { + return fmt.Errorf("unparseable content length") + } + h1.ContentLength = contentLength + default: + h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value) + } + } + return nil +} + +func H1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) { + h2 = []Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}} + for headerName, headerValues := range h1.Header { + for _, headerValue := range headerValues { + h2 = append(h2, Header{Name: strings.ToLower(headerName), Value: headerValue}) + } + } + + return h2 +} + +// Serialize HTTP1.x headers by base64-encoding each header name and value, +// and then joining them in the format of [key:value;] +func SerializeHeaders(h1 *http.Request) []byte { + var serializedHeaders [][]byte + for headerName, headerValues := range h1.Header { + for _, headerValue := range headerValues { + encodedName := make([]byte, headerEncoding.EncodedLen(len(headerName))) + headerEncoding.Encode(encodedName, []byte(headerName)) + + encodedValue := make([]byte, headerEncoding.EncodedLen(len(headerValue))) + headerEncoding.Encode(encodedValue, []byte(headerValue)) + + serializedHeaders = append( + serializedHeaders, + bytes.Join( + [][]byte{encodedName, encodedValue}, + []byte(":"), + ), + ) + } + } + + return bytes.Join(serializedHeaders, []byte(";")) +} + +// Deserialize headers serialized by `SerializeHeader` +func DeserializeHeaders(serializedHeaders []byte) (http.Header, error) { + const unableToDeserializeErr = "Unable to deserialize headers" + + deserialized := http.Header{} + for _, serializedPair := range bytes.Split(serializedHeaders, []byte(";")) { + if len(serializedPair) == 0 { + continue + } + + serializedHeaderParts := bytes.Split(serializedPair, []byte(":")) + if len(serializedHeaderParts) != 2 { + return nil, errors.New(unableToDeserializeErr) + } + + serializedName := serializedHeaderParts[0] + serializedValue := serializedHeaderParts[1] + deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName))) + deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue))) + + if _, err := headerEncoding.Decode(deserializedName, serializedName); err != nil { + return nil, errors.Wrap(err, unableToDeserializeErr) + } + if _, err := headerEncoding.Decode(deserializedValue, serializedValue); err != nil { + return nil, errors.Wrap(err, unableToDeserializeErr) + } + + deserialized.Add(string(deserializedName), string(deserializedValue)) + } + + return deserialized, nil +} diff --git a/streamhandler/request_test.go b/h2mux/header_test.go similarity index 77% rename from streamhandler/request_test.go rename to h2mux/header_test.go index d3a85f76..83c7ac35 100644 --- a/streamhandler/request_test.go +++ b/h2mux/header_test.go @@ -1,4 +1,4 @@ -package streamhandler +package h2mux import ( "fmt" @@ -11,7 +11,6 @@ import ( "testing" "testing/quick" - "github.com/cloudflare/cloudflared/h2mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,12 +20,12 @@ func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) { assert.NoError(t, err) headersConversionErr := H2RequestHeadersToH1Request( - []h2mux.Header{ - h2mux.Header{ + []Header{ + { Name: "Mock header 1", Value: "Mock value 1", }, - h2mux.Header{ + { Name: "Mock header 2", Value: "Mock value 2", }, @@ -47,7 +46,7 @@ func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) { assert.NoError(t, err) headersConversionErr := H2RequestHeadersToH1Request( - []h2mux.Header{}, + []Header{}, request, ) @@ -61,12 +60,12 @@ func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) { assert.NoError(t, err) headersConversionErr := H2RequestHeadersToH1Request( - []h2mux.Header{ - h2mux.Header{ + []Header{ + { Name: ":path", Value: "//bad_path/", }, - h2mux.Header{ + { Name: "Mock header", Value: "Mock value", }, @@ -88,12 +87,12 @@ func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) { assert.NoError(t, err) headersConversionErr := H2RequestHeadersToH1Request( - []h2mux.Header{ - h2mux.Header{ + []Header{ + { Name: ":path", Value: "/?query=mock%20value", }, - h2mux.Header{ + { Name: "Mock header", Value: "Mock value", }, @@ -115,12 +114,12 @@ func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) { assert.NoError(t, err) headersConversionErr := H2RequestHeadersToH1Request( - []h2mux.Header{ - h2mux.Header{ + []Header{ + { Name: ":path", Value: "/mock%20path", }, - h2mux.Header{ + { Name: "Mock header", Value: "Mock value", }, @@ -278,12 +277,12 @@ func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) { request, err := http.NewRequest(http.MethodGet, requestURL, nil) assert.NoError(t, err) headersConversionErr := H2RequestHeadersToH1Request( - []h2mux.Header{ - h2mux.Header{ + []Header{ + { Name: ":path", Value: testCase.path, }, - h2mux.Header{ + { Name: "Mock header", Value: "Mock value", }, @@ -354,11 +353,11 @@ func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) { const expectedMethod = "POST" const expectedHostname = "request.hostname.example.com" - h2 := []h2mux.Header{ - h2mux.Header{Name: ":method", Value: expectedMethod}, - h2mux.Header{Name: ":scheme", Value: testScheme}, - h2mux.Header{Name: ":authority", Value: expectedHostname}, - h2mux.Header{Name: ":path", Value: testPath}, + h2 := []Header{ + {Name: ":method", Value: expectedMethod}, + {Name: ":scheme", Value: testScheme}, + {Name: ":authority", Value: expectedHostname}, + {Name: ":path", Value: testPath}, } h1, err := http.NewRequest("GET", testOrigin.url, nil) require.NoError(t, err) @@ -406,14 +405,14 @@ func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string { // i.e. one that can pass unchanged through url.URL.String() func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { text := randomASCIIText(rand, minLength, maxLength) - regexp, err := regexp.Compile("[^/;,]*") + re, err := regexp.Compile("[^/;,]*") require.NoError(t, err) - return "/" + regexp.ReplaceAllStringFunc(text, url.PathEscape) + return "/" + re.ReplaceAllStringFunc(text, url.PathEscape) } // Calls `randomASCIIText` and ensures the result is a valid URL query, // i.e. one that can pass unchanged through url.URL.String() -func randomHTTP1Query(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { +func randomHTTP1Query(rand *rand.Rand, minLength int, maxLength int) string { text := randomASCIIText(rand, minLength, maxLength) return "?" + strings.ReplaceAll(text, "#", "%23") } @@ -422,9 +421,9 @@ func randomHTTP1Query(t *testing.T, rand *rand.Rand, minLength int, maxLength in // i.e. one that can pass unchanged through url.URL.String() func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { text := randomASCIIText(rand, minLength, maxLength) - url, err := url.Parse("#" + text) + u, err := url.Parse("#" + text) require.NoError(t, err) - return url.String() + return u.String() } // Assemble a random :path pseudoheader that is legal by Go stdlib standards @@ -432,10 +431,78 @@ func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength func randomHTTP2Path(t *testing.T, rand *rand.Rand) string { result := randomHTTP1Path(t, rand, 1, 64) if rand.Intn(2) == 1 { - result += randomHTTP1Query(t, rand, 1, 32) + result += randomHTTP1Query(rand, 1, 32) } if rand.Intn(2) == 1 { result += randomHTTP1Fragment(t, rand, 1, 16) } return result } + +func TestSerializeHeaders(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + assert.NoError(t, err) + + mockHeaders := map[string][]string{ + "Mock-Header-One": {"Mock header one value", "three"}, + "Mock-Header-Two-Long": {"Mock header two value\nlong"}, + ":;": {":;", ";:"}, + ":": {":"}, + ";": {";"}, + ";;": {";;"}, + "Empty values": {"", ""}, + "": {"Empty key"}, + "control\tcharacter\b\n": {"value\n\b\t"}, + ";\v:": {":\v;"}, + } + + for header, values := range mockHeaders { + for _, value := range values { + // Note that Golang's http library is opinionated; + // at this point every header name will be title-cased in order to comply with the HTTP RFC + // This means our proxy is not completely transparent when it comes to proxying headers + request.Header.Add(header, value) + } + } + + serializedHeaders := SerializeHeaders(request) + + // Sanity check: the headers serialized to something that's not an empty string + assert.NotEqual(t, "", serializedHeaders) + + // Deserialize back, and ensure we get the same set of headers + deserializedHeaders, err := DeserializeHeaders(serializedHeaders) + assert.NoError(t, err) + + assert.Equal(t, len(mockHeaders), len(deserializedHeaders)) + for header, value := range deserializedHeaders { + assert.NotEqual(t, "", value) + assert.Equal(t, mockHeaders[header], value) + } +} + +func TestSerializeNoHeaders(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + assert.NoError(t, err) + + serializedHeaders := SerializeHeaders(request) + deserializedHeaders, err := DeserializeHeaders(serializedHeaders) + assert.NoError(t, err) + assert.Equal(t, 0, len(deserializedHeaders)) +} + +func TestDeserializeMalformed(t *testing.T) { + var err error + + malformedData := []string{ + "malformed data", + "bW9jawo=", // "mock" + "bW9jawo=:ZGF0YQo=:bW9jawo=", // "mock:data:mock" + "::", + } + + for _, malformedValue := range malformedData { + _, err = DeserializeHeaders([]byte(malformedValue)) + assert.Error(t, err) + } +} diff --git a/origin/tunnel.go b/origin/tunnel.go index 63dca0da..4ee6b55e 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -488,16 +488,6 @@ func LogServerInfo( metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) } -func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) { - h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}} - for headerName, headerValues := range h1.Header { - for _, headerValue := range headerValues { - h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue}) - } - } - return -} - type TunnelHandler struct { originUrl string httpHostHeader string @@ -512,8 +502,6 @@ type TunnelHandler struct { noChunkedEncoding bool } -var dialer = net.Dialer{} - // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error func NewTunnelHandler(ctx context.Context, config *TunnelConfig, @@ -592,7 +580,7 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, if err != nil { return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") } - err = streamhandler.H2RequestHeadersToH1Request(stream.Headers, req) + err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { return nil, errors.Wrap(err, "invalid request received") } @@ -611,7 +599,7 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ return nil, err } defer conn.Close() - err = stream.WriteHeaders(H1ResponseToH2Response(response)) + err = stream.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(response)) if err != nil { return nil, errors.Wrap(err, "Error writing response header") } @@ -645,7 +633,7 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) } defer response.Body.Close() - err = stream.WriteHeaders(H1ResponseToH2Response(response)) + err = stream.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(response)) if err != nil { return nil, errors.Wrap(err, "Error writing response header") } diff --git a/streamhandler/request.go b/streamhandler/request.go index f82c7c7d..b6bb55d3 100644 --- a/streamhandler/request.go +++ b/streamhandler/request.go @@ -1,10 +1,8 @@ package streamhandler import ( - "fmt" "net/http" "net/url" - "strconv" "strings" "github.com/cloudflare/cloudflared/h2mux" @@ -28,65 +26,9 @@ func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, erro if err != nil { return nil, errors.Wrap(err, "unexpected error from http.NewRequest") } - err = H2RequestHeadersToH1Request(stream.Headers, req) + err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { return nil, errors.Wrap(err, "invalid request received") } return req, nil } - -// H2RequestHeadersToH1Request converts the HTTP/2 headers to an HTTP/1 Request -// object. This includes conversion of the pseudo-headers into their closest -// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3 -func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { - for _, header := range h2 { - switch header.Name { - case ":method": - h1.Method = header.Value - case ":scheme": - // noop - use the preexisting scheme from h1.URL - case ":authority": - // Otherwise the host header will be based on the origin URL - h1.Host = header.Value - case ":path": - // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. - // However, this HTTP/1 Request object belongs to the Go standard library, - // whose URL package makes some opinionated decisions about the encoding of - // URL characters: see the docs of https://godoc.org/net/url#URL, - // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, - // which is always used when computing url.URL.String(), whether we'd like it or not. - // - // Well, not *always*. We could circumvent this by using url.URL.Opaque. But - // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque - // is treated differently when HTTP_PROXY is set! - // See https://github.com/golang/go/issues/5684#issuecomment-66080888 - // - // This means we are subject to the behavior of net/url's function `shouldEscape` - // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 - - if header.Value == "*" { - h1.URL.Path = "*" - continue - } - // Due to the behavior of validation.ValidateUrl, h1.URL may - // already have a partial value, with or without a trailing slash. - base := h1.URL.String() - base = strings.TrimRight(base, "/") - // But we know :path begins with '/', because we handled '*' above - see RFC7540 - url, err := url.Parse(base + header.Value) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) - } - h1.URL = url - case "content-length": - contentLength, err := strconv.ParseInt(header.Value, 10, 64) - if err != nil { - return fmt.Errorf("unparseable content length") - } - h1.ContentLength = contentLength - default: - h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value) - } - } - return nil -}