TUN-6676: Add suport for trailers in http2 connections

This commit is contained in:
João Oliveirinha 2022-08-16 12:21:58 +01:00
parent d2bc15e224
commit f6bd4aa039
7 changed files with 89 additions and 89 deletions

View File

@ -24,9 +24,16 @@ const (
LogFieldConnIndex = "connIndex" LogFieldConnIndex = "connIndex"
MaxGracePeriod = time.Minute * 3 MaxGracePeriod = time.Minute * 3
MaxConcurrentStreams = math.MaxUint32 MaxConcurrentStreams = math.MaxUint32
contentTypeHeader = "content-type"
sseContentType = "text/event-stream"
grpcContentType = "application/grpc"
) )
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) var (
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
flushableContentTypes = []string{sseContentType, grpcContentType}
)
type Orchestrator interface { type Orchestrator interface {
UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse
@ -190,6 +197,7 @@ func (h *HTTPResponseReadWriteAcker) AckConnection(tracePropagation string) erro
type ResponseWriter interface { type ResponseWriter interface {
WriteRespHeaders(status int, header http.Header) error WriteRespHeaders(status int, header http.Header) error
AddTrailer(trailerName, trailerValue string)
io.Writer io.Writer
} }
@ -198,10 +206,18 @@ type ConnectedFuse interface {
IsConnected() bool IsConnected() bool
} }
func IsServerSentEvent(headers http.Header) bool { // Helper method to let the caller know what content-types should require a flush on every
if contentType := headers.Get("content-type"); contentType != "" { // write to a ResponseWriter.
return strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") func shouldFlush(headers http.Header) bool {
if contentType := headers.Get(contentTypeHeader); contentType != "" {
contentType = strings.ToLower(contentType)
for _, c := range flushableContentTypes {
if strings.HasPrefix(contentType, c) {
return true
}
}
} }
return false return false
} }

View File

@ -6,11 +6,9 @@ import (
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"testing"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -197,40 +195,3 @@ func (mcf mockConnectedFuse) Connected() {}
func (mcf mockConnectedFuse) IsConnected() bool { func (mcf mockConnectedFuse) IsConnected() bool {
return true return true
} }
func TestIsEventStream(t *testing.T) {
tests := []struct {
headers http.Header
isEventStream bool
}{
{
headers: newHeader("Content-Type", "text/event-stream"),
isEventStream: true,
},
{
headers: newHeader("content-type", "text/event-stream"),
isEventStream: true,
},
{
headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"),
isEventStream: true,
},
{
headers: newHeader("Content-Type", "application/json"),
isEventStream: false,
},
{
headers: http.Header{},
isEventStream: false,
},
}
for _, test := range tests {
assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers))
}
}
func newHeader(key, value string) http.Header {
header := http.Header{}
header.Add(key, value)
return header
}

View File

@ -259,6 +259,10 @@ type h2muxRespWriter struct {
*h2mux.MuxedStream *h2mux.MuxedStream
} }
func (rp *h2muxRespWriter) AddTrailer(trailerName, trailerValue string) {
// do nothing. we don't support trailers over h2mux
}
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error { func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
headers := H1ResponseToH2ResponseHeaders(status, header) headers := H1ResponseToH2ResponseHeaders(status, header)
headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin}) headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin})

View File

@ -191,11 +191,12 @@ func (c *HTTP2Connection) close() {
} }
type http2RespWriter struct { type http2RespWriter struct {
r io.Reader r io.Reader
w http.ResponseWriter w http.ResponseWriter
flusher http.Flusher flusher http.Flusher
shouldFlush bool shouldFlush bool
log *zerolog.Logger statusWritten bool
log *zerolog.Logger
} }
func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, log *zerolog.Logger) (*http2RespWriter, error) { func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, log *zerolog.Logger) (*http2RespWriter, error) {
@ -219,11 +220,20 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
}, nil }, nil
} }
func (rp *http2RespWriter) AddTrailer(trailerName, trailerValue string) {
if !rp.statusWritten {
rp.log.Warn().Msg("Tried to add Trailer to response before status written. Ignoring...")
return
}
rp.w.Header().Add(http2.TrailerPrefix+trailerName, trailerValue)
}
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error { func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
dest := rp.w.Header() dest := rp.w.Header()
userHeaders := make(http.Header, len(header)) userHeaders := make(http.Header, len(header))
for name, values := range header { for name, values := range header {
// Since these are http2 headers, they're required to be lowercase // lowercase headers for simplicity check
h2name := strings.ToLower(name) h2name := strings.ToLower(name)
if h2name == "content-length" { if h2name == "content-length" {
@ -234,7 +244,7 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
if h2name == tracing.IntCloudflaredTracingHeader { if h2name == tracing.IntCloudflaredTracingHeader {
// Add cf-int-cloudflared-tracing header outside of serialized userHeaders // Add cf-int-cloudflared-tracing header outside of serialized userHeaders
rp.w.Header()[tracing.CanonicalCloudflaredTracingHeader] = values dest[tracing.CanonicalCloudflaredTracingHeader] = values
continue continue
} }
@ -247,18 +257,21 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
// Perform user header serialization and set them in the single header // Perform user header serialization and set them in the single header
dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders)) dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders))
rp.setResponseMetaHeader(responseMetaHeaderOrigin) rp.setResponseMetaHeader(responseMetaHeaderOrigin)
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
if status == http.StatusSwitchingProtocols { if status == http.StatusSwitchingProtocols {
status = http.StatusOK status = http.StatusOK
} }
rp.w.WriteHeader(status) rp.w.WriteHeader(status)
if IsServerSentEvent(header) { if shouldFlush(header) {
rp.shouldFlush = true rp.shouldFlush = true
} }
if rp.shouldFlush { if rp.shouldFlush {
rp.flusher.Flush() rp.flusher.Flush()
} }
rp.statusWritten = true
return nil return nil
} }

View File

@ -329,6 +329,10 @@ func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter
return httpResponseAdapter{s} return httpResponseAdapter{s}
} }
func (hrw httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
// we do not support trailers over QUIC
}
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
metadata := make([]quicpogs.Metadata, 0) metadata := make([]quicpogs.Metadata, 0)
metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})

View File

@ -1,7 +1,6 @@
package proxy package proxy
import ( import (
"bufio"
"context" "context"
"fmt" "fmt"
"io" "io"
@ -29,6 +28,8 @@ const (
LogFieldRule = "ingressRule" LogFieldRule = "ingressRule"
LogFieldOriginService = "originService" LogFieldOriginService = "originService"
LogFieldFlowID = "flowID" LogFieldFlowID = "flowID"
trailerHeaderName = "Trailer"
) )
// Proxy represents a means to Proxy between cloudflared and the origin services. // Proxy represents a means to Proxy between cloudflared and the origin services.
@ -207,15 +208,16 @@ func (p *Proxy) proxyHTTPRequest(
tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode) tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode)
defer resp.Body.Close() defer resp.Body.Close()
// resp headers can be nil headers := make(http.Header, len(resp.Header))
if resp.Header == nil { // copy headers
resp.Header = make(http.Header) for k, v := range resp.Header {
headers[k] = v
} }
// Add spans to response header (if available) // Add spans to response header (if available)
tr.AddSpans(resp.Header) tr.AddSpans(headers)
err = w.WriteRespHeaders(resp.StatusCode, resp.Header) err = w.WriteRespHeaders(resp.StatusCode, headers)
if err != nil { if err != nil {
return errors.Wrap(err, "Error writing response header") return errors.Wrap(err, "Error writing response header")
} }
@ -236,12 +238,10 @@ func (p *Proxy) proxyHTTPRequest(
return nil return nil
} }
if connection.IsServerSentEvent(resp.Header) { _, _ = cfio.Copy(w, resp.Body)
p.log.Debug().Msg("Detected Server-Side Events from Origin")
p.writeEventStream(w, resp.Body) // copy trailers
} else { copyTrailers(w, resp)
_, _ = cfio.Copy(w, resp.Body)
}
p.logOriginResponse(resp, fields) p.logOriginResponse(resp, fields)
return nil return nil
@ -296,26 +296,6 @@ func (wr *bidirectionalStream) Write(p []byte) (n int, err error) {
return wr.writer.Write(p) return wr.writer.Write(p)
} }
func (p *Proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
reader := bufio.NewReader(respBody)
for {
line, readErr := reader.ReadBytes('\n')
// We first try to write whatever we read even if an error occurred
// The reason for doing it is to guarantee we really push everything to the eyeball side
// before returning
if len(line) > 0 {
if _, writeErr := w.Write(line); writeErr != nil {
return
}
}
if readErr != nil {
return
}
}
}
func (p *Proxy) appendTagHeaders(r *http.Request) { func (p *Proxy) appendTagHeaders(r *http.Request) {
for _, tag := range p.tags { for _, tag := range p.tags {
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
@ -329,6 +309,14 @@ type logFields struct {
flowID string flowID string
} }
func copyTrailers(w connection.ResponseWriter, response *http.Response) {
for trailerHeader, trailerValues := range response.Trailer {
for _, trailerValue := range trailerValues {
w.AddTrailer(trailerHeader, trailerValue)
}
}
}
func (p *Proxy) logRequest(r *http.Request, fields logFields) { func (p *Proxy) logRequest(r *http.Request, fields logFields) {
if fields.cfRay != "" { if fields.cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto) p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)

View File

@ -22,6 +22,8 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/hello"
@ -62,6 +64,10 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) er
return nil return nil
} }
func (w *mockHTTPRespWriter) AddTrailer(trailerName, trailerValue string) {
// do nothing
}
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) { func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader") return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
} }
@ -117,7 +123,10 @@ func newMockSSERespWriter() *mockSSERespWriter {
} }
func (w *mockSSERespWriter) Write(data []byte) (int, error) { func (w *mockSSERespWriter) Write(data []byte) (int, error) {
w.writeNotification <- data newData := make([]byte, len(data))
copy(newData, data)
w.writeNotification <- newData
return len(data), nil return len(data), nil
} }
@ -256,11 +265,8 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
for i := 0; i < pushCount; i++ { for i := 0; i < pushCount; i++ {
line := responseWriter.ReadBytes() line := responseWriter.ReadBytes()
expect := fmt.Sprintf("%d\n", i) expect := fmt.Sprintf("%d\n\n", i)
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line)) require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
line = responseWriter.ReadBytes()
require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line))
} }
cancel() cancel()
@ -276,7 +282,7 @@ func testProxySSEAllData(proxy *Proxy) func(t *testing.T) {
responseWriter := newMockSSERespWriter() responseWriter := newMockSSERespWriter()
// responseWriter uses an unbuffered channel, so we call in a different go-routine // responseWriter uses an unbuffered channel, so we call in a different go-routine
go proxy.writeEventStream(responseWriter, eyeballReader) go cfio.Copy(responseWriter, eyeballReader)
result := string(<-responseWriter.writeNotification) result := string(<-responseWriter.writeNotification)
require.Equal(t, "data\r\r", result) require.Equal(t, "data\r\r", result)
@ -825,6 +831,10 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
return nil return nil
} }
func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) {
// do nothing
}
// respHeaders is a test function to read respHeaders // respHeaders is a test function to read respHeaders
func (w *wsRespWriter) headers() http.Header { func (w *wsRespWriter) headers() http.Header {
// Removing indeterminstic header because it cannot be asserted. // Removing indeterminstic header because it cannot be asserted.
@ -852,6 +862,10 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
return m.w.Write(p) return m.w.Write(p)
} }
func (w *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) {
// do nothing
}
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error { func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
m.responseHeaders = header m.responseHeaders = header
m.code = status m.code = status