From f6bd4aa03910c3a3995956df244e6286bb31592c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Oliveirinha?= Date: Tue, 16 Aug 2022 12:21:58 +0100 Subject: [PATCH] TUN-6676: Add suport for trailers in http2 connections --- connection/connection.go | 24 +++++++++++++--- connection/connection_test.go | 39 -------------------------- connection/h2mux.go | 4 +++ connection/http2.go | 29 +++++++++++++------ connection/quic.go | 4 +++ proxy/proxy.go | 52 ++++++++++++++--------------------- proxy/proxy_test.go | 26 ++++++++++++++---- 7 files changed, 89 insertions(+), 89 deletions(-) diff --git a/connection/connection.go b/connection/connection.go index be72b5d2..5d2db19c 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -24,9 +24,16 @@ const ( LogFieldConnIndex = "connIndex" MaxGracePeriod = time.Minute * 3 MaxConcurrentStreams = math.MaxUint32 + + contentTypeHeader = "content-type" + sseContentType = "text/event-stream" + grpcContentType = "application/grpc" ) -var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) +var ( + switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) + flushableContentTypes = []string{sseContentType, grpcContentType} +) type Orchestrator interface { UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse @@ -190,6 +197,7 @@ func (h *HTTPResponseReadWriteAcker) AckConnection(tracePropagation string) erro type ResponseWriter interface { WriteRespHeaders(status int, header http.Header) error + AddTrailer(trailerName, trailerValue string) io.Writer } @@ -198,10 +206,18 @@ type ConnectedFuse interface { IsConnected() bool } -func IsServerSentEvent(headers http.Header) bool { - if contentType := headers.Get("content-type"); contentType != "" { - return strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") +// Helper method to let the caller know what content-types should require a flush on every +// write to a ResponseWriter. +func shouldFlush(headers http.Header) bool { + if contentType := headers.Get(contentTypeHeader); contentType != "" { + contentType = strings.ToLower(contentType) + for _, c := range flushableContentTypes { + if strings.HasPrefix(contentType, c) { + return true + } + } } + return false } diff --git a/connection/connection_test.go b/connection/connection_test.go index ae37db75..3708e16a 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -6,11 +6,9 @@ import ( "io" "math/rand" "net/http" - "testing" "time" "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" "github.com/cloudflare/cloudflared/tracing" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -197,40 +195,3 @@ func (mcf mockConnectedFuse) Connected() {} func (mcf mockConnectedFuse) IsConnected() bool { return true } - -func TestIsEventStream(t *testing.T) { - tests := []struct { - headers http.Header - isEventStream bool - }{ - { - headers: newHeader("Content-Type", "text/event-stream"), - isEventStream: true, - }, - { - headers: newHeader("content-type", "text/event-stream"), - isEventStream: true, - }, - { - headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"), - isEventStream: true, - }, - { - headers: newHeader("Content-Type", "application/json"), - isEventStream: false, - }, - { - headers: http.Header{}, - isEventStream: false, - }, - } - for _, test := range tests { - assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers)) - } -} - -func newHeader(key, value string) http.Header { - header := http.Header{} - header.Add(key, value) - return header -} diff --git a/connection/h2mux.go b/connection/h2mux.go index b78b433f..d8291b62 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -259,6 +259,10 @@ type h2muxRespWriter struct { *h2mux.MuxedStream } +func (rp *h2muxRespWriter) AddTrailer(trailerName, trailerValue string) { + // do nothing. we don't support trailers over h2mux +} + func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error { headers := H1ResponseToH2ResponseHeaders(status, header) headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin}) diff --git a/connection/http2.go b/connection/http2.go index 8b488010..a6756b1d 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -191,11 +191,12 @@ func (c *HTTP2Connection) close() { } type http2RespWriter struct { - r io.Reader - w http.ResponseWriter - flusher http.Flusher - shouldFlush bool - log *zerolog.Logger + r io.Reader + w http.ResponseWriter + flusher http.Flusher + shouldFlush bool + statusWritten bool + log *zerolog.Logger } func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, log *zerolog.Logger) (*http2RespWriter, error) { @@ -219,11 +220,20 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l }, nil } +func (rp *http2RespWriter) AddTrailer(trailerName, trailerValue string) { + if !rp.statusWritten { + rp.log.Warn().Msg("Tried to add Trailer to response before status written. Ignoring...") + return + } + + rp.w.Header().Add(http2.TrailerPrefix+trailerName, trailerValue) +} + func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error { dest := rp.w.Header() userHeaders := make(http.Header, len(header)) for name, values := range header { - // Since these are http2 headers, they're required to be lowercase + // lowercase headers for simplicity check h2name := strings.ToLower(name) if h2name == "content-length" { @@ -234,7 +244,7 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro if h2name == tracing.IntCloudflaredTracingHeader { // Add cf-int-cloudflared-tracing header outside of serialized userHeaders - rp.w.Header()[tracing.CanonicalCloudflaredTracingHeader] = values + dest[tracing.CanonicalCloudflaredTracingHeader] = values continue } @@ -247,18 +257,21 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro // Perform user header serialization and set them in the single header dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders)) + rp.setResponseMetaHeader(responseMetaHeaderOrigin) // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 if status == http.StatusSwitchingProtocols { status = http.StatusOK } rp.w.WriteHeader(status) - if IsServerSentEvent(header) { + if shouldFlush(header) { rp.shouldFlush = true } if rp.shouldFlush { rp.flusher.Flush() } + + rp.statusWritten = true return nil } diff --git a/connection/quic.go b/connection/quic.go index edc17af2..7a61bdbb 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -329,6 +329,10 @@ func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter return httpResponseAdapter{s} } +func (hrw httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { + // we do not support trailers over QUIC +} + func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { metadata := make([]quicpogs.Metadata, 0) metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) diff --git a/proxy/proxy.go b/proxy/proxy.go index 9bc37615..6fe4efa4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,7 +1,6 @@ package proxy import ( - "bufio" "context" "fmt" "io" @@ -29,6 +28,8 @@ const ( LogFieldRule = "ingressRule" LogFieldOriginService = "originService" LogFieldFlowID = "flowID" + + trailerHeaderName = "Trailer" ) // Proxy represents a means to Proxy between cloudflared and the origin services. @@ -207,15 +208,16 @@ func (p *Proxy) proxyHTTPRequest( tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode) defer resp.Body.Close() - // resp headers can be nil - if resp.Header == nil { - resp.Header = make(http.Header) + headers := make(http.Header, len(resp.Header)) + // copy headers + for k, v := range resp.Header { + headers[k] = v } // Add spans to response header (if available) - tr.AddSpans(resp.Header) + tr.AddSpans(headers) - err = w.WriteRespHeaders(resp.StatusCode, resp.Header) + err = w.WriteRespHeaders(resp.StatusCode, headers) if err != nil { return errors.Wrap(err, "Error writing response header") } @@ -236,12 +238,10 @@ func (p *Proxy) proxyHTTPRequest( return nil } - if connection.IsServerSentEvent(resp.Header) { - p.log.Debug().Msg("Detected Server-Side Events from Origin") - p.writeEventStream(w, resp.Body) - } else { - _, _ = cfio.Copy(w, resp.Body) - } + _, _ = cfio.Copy(w, resp.Body) + + // copy trailers + copyTrailers(w, resp) p.logOriginResponse(resp, fields) return nil @@ -296,26 +296,6 @@ func (wr *bidirectionalStream) Write(p []byte) (n int, err error) { return wr.writer.Write(p) } -func (p *Proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { - reader := bufio.NewReader(respBody) - for { - line, readErr := reader.ReadBytes('\n') - - // We first try to write whatever we read even if an error occurred - // The reason for doing it is to guarantee we really push everything to the eyeball side - // before returning - if len(line) > 0 { - if _, writeErr := w.Write(line); writeErr != nil { - return - } - } - - if readErr != nil { - return - } - } -} - func (p *Proxy) appendTagHeaders(r *http.Request) { for _, tag := range p.tags { r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) @@ -329,6 +309,14 @@ type logFields struct { flowID string } +func copyTrailers(w connection.ResponseWriter, response *http.Response) { + for trailerHeader, trailerValues := range response.Trailer { + for _, trailerValue := range trailerValues { + w.AddTrailer(trailerHeader, trailerValue) + } + } +} + func (p *Proxy) logRequest(r *http.Request, fields logFields) { if fields.cfRay != "" { p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index da90288e..5407acaa 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -22,6 +22,8 @@ import ( "github.com/urfave/cli/v2" "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/cfio" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/hello" @@ -62,6 +64,10 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) er return nil } +func (w *mockHTTPRespWriter) AddTrailer(trailerName, trailerValue string) { + // do nothing +} + func (w *mockHTTPRespWriter) Read(data []byte) (int, error) { return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader") } @@ -117,7 +123,10 @@ func newMockSSERespWriter() *mockSSERespWriter { } func (w *mockSSERespWriter) Write(data []byte) (int, error) { - w.writeNotification <- data + newData := make([]byte, len(data)) + copy(newData, data) + + w.writeNotification <- newData return len(data), nil } @@ -256,11 +265,8 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) { for i := 0; i < pushCount; i++ { line := responseWriter.ReadBytes() - expect := fmt.Sprintf("%d\n", i) + expect := fmt.Sprintf("%d\n\n", i) require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line)) - - line = responseWriter.ReadBytes() - require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line)) } cancel() @@ -276,7 +282,7 @@ func testProxySSEAllData(proxy *Proxy) func(t *testing.T) { responseWriter := newMockSSERespWriter() // responseWriter uses an unbuffered channel, so we call in a different go-routine - go proxy.writeEventStream(responseWriter, eyeballReader) + go cfio.Copy(responseWriter, eyeballReader) result := string(<-responseWriter.writeNotification) require.Equal(t, "data\r\r", result) @@ -825,6 +831,10 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { return nil } +func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) { + // do nothing +} + // respHeaders is a test function to read respHeaders func (w *wsRespWriter) headers() http.Header { // Removing indeterminstic header because it cannot be asserted. @@ -852,6 +862,10 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) { return m.w.Write(p) } +func (w *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) { + // do nothing +} + func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error { m.responseHeaders = header m.code = status