From 5bd4028ea709486a8867d44b43b9ef964a34a8bc Mon Sep 17 00:00:00 2001 From: Areg Harutyunyan Date: Fri, 6 Mar 2020 13:49:09 +0000 Subject: [PATCH] TUN-2761: Use the new header management functions in cloudflared --- connection/features.go | 6 +-- h2mux/header.go | 66 ----------------------- origin/tunnel.go | 8 +-- streamhandler/request.go | 2 +- streamhandler/stream_handler_test.go | 78 ++++++++++++++++++++-------- 5 files changed, 65 insertions(+), 95 deletions(-) diff --git a/connection/features.go b/connection/features.go index c10f7cbf..7e555fa4 100644 --- a/connection/features.go +++ b/connection/features.go @@ -1,9 +1,9 @@ package connection const ( - FEATURE_SERIALIZED_HEADERS = "serialized_headers" + FeatureSerializedHeaders = "serialized_headers" ) -var SUPPORTED_FEATURES = []string{ - //FEATURE_SERIALIZED_HEADERS, +var SupportedFeatures = []string{ + FeatureSerializedHeaders, } diff --git a/h2mux/header.go b/h2mux/header.go index 20a878e8..62e71164 100644 --- a/h2mux/header.go +++ b/h2mux/header.go @@ -139,72 +139,6 @@ func H1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) { return h2 } -// Obsolete version of H2RequestHeadersToH1Request -func OldH2RequestHeadersToH1Request(h2 []Header, h1 *http.Request) error { - for _, header := range h2 { - switch header.Name { - case ":method": - h1.Method = header.Value - case ":scheme": - // noop - use the preexisting scheme from h1.URL - case ":authority": - // Otherwise the host header will be based on the origin URL - h1.Host = header.Value - case ":path": - // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. - // However, this HTTP/1 Request object belongs to the Go standard library, - // whose URL package makes some opinionated decisions about the encoding of - // URL characters: see the docs of https://godoc.org/net/url#URL, - // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, - // which is always used when computing url.URL.String(), whether we'd like it or not. - // - // Well, not *always*. We could circumvent this by using url.URL.Opaque. But - // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque - // is treated differently when HTTP_PROXY is set! - // See https://github.com/golang/go/issues/5684#issuecomment-66080888 - // - // This means we are subject to the behavior of net/url's function `shouldEscape` - // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 - - if header.Value == "*" { - h1.URL.Path = "*" - continue - } - // Due to the behavior of validation.ValidateUrl, h1.URL may - // already have a partial value, with or without a trailing slash. - base := h1.URL.String() - base = strings.TrimRight(base, "/") - // But we know :path begins with '/', because we handled '*' above - see RFC7540 - url, err := url.Parse(base + header.Value) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) - } - h1.URL = url - case "content-length": - contentLength, err := strconv.ParseInt(header.Value, 10, 64) - if err != nil { - return fmt.Errorf("unparseable content length") - } - h1.ContentLength = contentLength - default: - h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value) - } - } - return nil -} - -// Obsolete version of H1ResponseToH2ResponseHeaders -func OldH1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) { - h2 = []Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}} - for headerName, headerValues := range h1.Header { - for _, headerValue := range headerValues { - h2 = append(h2, Header{Name: strings.ToLower(headerName), Value: headerValue}) - } - } - - return h2 -} - // Serialize HTTP1.x headers by base64-encoding each header name and value, // and then joining them in the format of [key:value;] func SerializeHeaders(h1Headers http.Header) string { diff --git a/origin/tunnel.go b/origin/tunnel.go index 3bc5c3ba..2aaba194 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -165,7 +165,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str RunFromTerminal: c.RunFromTerminal, CompressionQuality: c.CompressionQuality, UUID: uuid.String(), - Features: connection.SUPPORTED_FEATURES, + Features: connection.SupportedFeatures, } } @@ -603,7 +603,7 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, if err != nil { return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") } - err = h2mux.OldH2RequestHeadersToH1Request(stream.Headers, req) + err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { return nil, errors.Wrap(err, "invalid request received") } @@ -622,7 +622,7 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ return nil, err } defer conn.Close() - err = stream.WriteHeaders(h2mux.OldH1ResponseToH2ResponseHeaders(response)) + err = stream.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(response)) if err != nil { return nil, errors.Wrap(err, "Error writing response header") } @@ -656,7 +656,7 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) } defer response.Body.Close() - err = stream.WriteHeaders(h2mux.OldH1ResponseToH2ResponseHeaders(response)) + err = stream.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(response)) if err != nil { return nil, errors.Wrap(err, "Error writing response header") } diff --git a/streamhandler/request.go b/streamhandler/request.go index 52d69bc2..b6bb55d3 100644 --- a/streamhandler/request.go +++ b/streamhandler/request.go @@ -26,7 +26,7 @@ func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, erro if err != nil { return nil, errors.Wrap(err, "unexpected error from http.NewRequest") } - err = h2mux.OldH2RequestHeadersToH1Request(stream.Headers, req) + err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { return nil, errors.Wrap(err, "invalid request received") } diff --git a/streamhandler/stream_handler_test.go b/streamhandler/stream_handler_test.go index af1eb45b..721f5a6c 100644 --- a/streamhandler/stream_handler_test.go +++ b/streamhandler/stream_handler_test.go @@ -31,11 +31,11 @@ var ( {Name: ":scheme", Value: "http"}, {Name: ":authority", Value: "example.com"}, {Name: ":path", Value: "/"}, + + // Regular headers must always come after the pseudoheaders + {Name: h2mux.RequestUserHeadersField, Value: ""}, } - tunnelHostnameHeader = h2mux.Header{ - Name: h2mux.CloudflaredProxyTunnelHostnameHeader, - Value: testTunnelHostname.String(), - } + tunnelHostnameHeader = h2mux.Header{Name: h2mux.CloudflaredProxyTunnelHostnameHeader, Value: testTunnelHostname.String()} ) func TestServeRequest(t *testing.T) { @@ -69,29 +69,73 @@ func TestServeRequest(t *testing.T) { assertRespBody(t, message, stream) } -func TestServeBadRequest(t *testing.T) { +func createStreamHandler() *StreamHandler { configChan := make(chan *pogs.ClientConfig) useConfigResultChan := make(chan *pogs.UseConfigurationResult) - streamHandler := NewStreamHandler(configChan, useConfigResultChan, logrus.New()) + return NewStreamHandler(configChan, useConfigResultChan, logrus.New()) +} + +func createRequestMuxPair(t *testing.T, streamHandler *StreamHandler) *DefaultMuxerPair { muxPair := NewDefaultMuxerPair(t, streamHandler) muxPair.Serve(t) + return muxPair +} + +func TestServeStatusBadRequest(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout) defer cancel() // No tunnel hostname header, expect to get 400 Bad Request - stream, err := muxPair.EdgeMux.OpenStream(ctx, baseHeaders, nil) + stream, err := createRequestMuxPair(t, createStreamHandler()).EdgeMux.OpenStream(ctx, baseHeaders, nil) assert.NoError(t, err) assertStatusHeader(t, http.StatusBadRequest, stream.Headers) assertRespBody(t, statusBadRequest.text, stream) +} + +func TestServeInvalidContentLength(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout) + defer cancel() + + // Invalid content-length, wouldn't be able to create a request + // Expect to get 400 Bad Request + headers := append(baseHeaders, tunnelHostnameHeader) + headers = append(headers, h2mux.Header{ + Name: "content-length", + Value: "x", + }) + streamHandler := createStreamHandler() + streamHandler.UpdateConfig([]*pogs.ReverseProxyConfig{ + { + TunnelHostname: testTunnelHostname, + OriginConfig: &pogs.HTTPOriginConfig{ + URLString: "", + }, + }, + }) + mux := createRequestMuxPair(t, streamHandler).EdgeMux + stream, err := mux.OpenStream(ctx, headers, nil) + assert.NoError(t, err) + assertStatusHeader(t, http.StatusBadRequest, stream.Headers) + assertRespBody(t, statusBadRequest.text, stream) +} + +func TestServeStatusNotFound(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout) + defer cancel() // No mapping for the tunnel hostname, expect to get 404 Not Found headers := append(baseHeaders, tunnelHostnameHeader) - stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil) + stream, err := createRequestMuxPair(t, createStreamHandler()).EdgeMux.OpenStream(ctx, headers, nil) assert.NoError(t, err) assertStatusHeader(t, http.StatusNotFound, stream.Headers) assertRespBody(t, statusNotFound.text, stream) +} + +func TestServeStatusBadGateway(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout) + defer cancel() // Nothing listening on empty url, so proxy would fail. Expect to get 502 Bad Gateway reverseProxyConfigs := []*pogs.ReverseProxyConfig{ @@ -102,21 +146,13 @@ func TestServeBadRequest(t *testing.T) { }, }, } + streamHandler := createStreamHandler() streamHandler.UpdateConfig(reverseProxyConfigs) - stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil) + headers := append(baseHeaders, tunnelHostnameHeader) + stream, err := createRequestMuxPair(t, streamHandler).EdgeMux.OpenStream(ctx, headers, nil) assert.NoError(t, err) assertStatusHeader(t, http.StatusBadGateway, stream.Headers) assertRespBody(t, statusBadGateway.text, stream) - - // Invalid content-length, wouldn't not be able to create a request. Expect to get 400 Bad Request - headers = append(headers, h2mux.Header{ - Name: "content-length", - Value: "x", - }) - stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil) - assert.NoError(t, err) - assertStatusHeader(t, http.StatusBadRequest, stream.Headers) - assertRespBody(t, statusBadRequest.text, stream) } func assertStatusHeader(t *testing.T, expectedStatus int, headers []h2mux.Header) { @@ -218,6 +254,6 @@ type mockHTTPHandler struct { message []byte } -func (mth *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Write(mth.message) +func (mth *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write(mth.message) }