TUN-6676: Add suport for trailers in http2 connections
This commit is contained in:
parent
d2bc15e224
commit
f6bd4aa039
|
@ -24,9 +24,16 @@ const (
|
||||||
LogFieldConnIndex = "connIndex"
|
LogFieldConnIndex = "connIndex"
|
||||||
MaxGracePeriod = time.Minute * 3
|
MaxGracePeriod = time.Minute * 3
|
||||||
MaxConcurrentStreams = math.MaxUint32
|
MaxConcurrentStreams = math.MaxUint32
|
||||||
|
|
||||||
|
contentTypeHeader = "content-type"
|
||||||
|
sseContentType = "text/event-stream"
|
||||||
|
grpcContentType = "application/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
var (
|
||||||
|
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||||
|
flushableContentTypes = []string{sseContentType, grpcContentType}
|
||||||
|
)
|
||||||
|
|
||||||
type Orchestrator interface {
|
type Orchestrator interface {
|
||||||
UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse
|
UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse
|
||||||
|
@ -190,6 +197,7 @@ func (h *HTTPResponseReadWriteAcker) AckConnection(tracePropagation string) erro
|
||||||
|
|
||||||
type ResponseWriter interface {
|
type ResponseWriter interface {
|
||||||
WriteRespHeaders(status int, header http.Header) error
|
WriteRespHeaders(status int, header http.Header) error
|
||||||
|
AddTrailer(trailerName, trailerValue string)
|
||||||
io.Writer
|
io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,10 +206,18 @@ type ConnectedFuse interface {
|
||||||
IsConnected() bool
|
IsConnected() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsServerSentEvent(headers http.Header) bool {
|
// Helper method to let the caller know what content-types should require a flush on every
|
||||||
if contentType := headers.Get("content-type"); contentType != "" {
|
// write to a ResponseWriter.
|
||||||
return strings.HasPrefix(strings.ToLower(contentType), "text/event-stream")
|
func shouldFlush(headers http.Header) bool {
|
||||||
|
if contentType := headers.Get(contentTypeHeader); contentType != "" {
|
||||||
|
contentType = strings.ToLower(contentType)
|
||||||
|
for _, c := range flushableContentTypes {
|
||||||
|
if strings.HasPrefix(contentType, c) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,11 +6,9 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/tracing"
|
"github.com/cloudflare/cloudflared/tracing"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
@ -197,40 +195,3 @@ func (mcf mockConnectedFuse) Connected() {}
|
||||||
func (mcf mockConnectedFuse) IsConnected() bool {
|
func (mcf mockConnectedFuse) IsConnected() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsEventStream(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
headers http.Header
|
|
||||||
isEventStream bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
headers: newHeader("Content-Type", "text/event-stream"),
|
|
||||||
isEventStream: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
headers: newHeader("content-type", "text/event-stream"),
|
|
||||||
isEventStream: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"),
|
|
||||||
isEventStream: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
headers: newHeader("Content-Type", "application/json"),
|
|
||||||
isEventStream: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
headers: http.Header{},
|
|
||||||
isEventStream: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, test := range tests {
|
|
||||||
assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHeader(key, value string) http.Header {
|
|
||||||
header := http.Header{}
|
|
||||||
header.Add(key, value)
|
|
||||||
return header
|
|
||||||
}
|
|
||||||
|
|
|
@ -259,6 +259,10 @@ type h2muxRespWriter struct {
|
||||||
*h2mux.MuxedStream
|
*h2mux.MuxedStream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *h2muxRespWriter) AddTrailer(trailerName, trailerValue string) {
|
||||||
|
// do nothing. we don't support trailers over h2mux
|
||||||
|
}
|
||||||
|
|
||||||
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
headers := H1ResponseToH2ResponseHeaders(status, header)
|
headers := H1ResponseToH2ResponseHeaders(status, header)
|
||||||
headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin})
|
headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin})
|
||||||
|
|
|
@ -191,11 +191,12 @@ func (c *HTTP2Connection) close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
type http2RespWriter struct {
|
type http2RespWriter struct {
|
||||||
r io.Reader
|
r io.Reader
|
||||||
w http.ResponseWriter
|
w http.ResponseWriter
|
||||||
flusher http.Flusher
|
flusher http.Flusher
|
||||||
shouldFlush bool
|
shouldFlush bool
|
||||||
log *zerolog.Logger
|
statusWritten bool
|
||||||
|
log *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, log *zerolog.Logger) (*http2RespWriter, error) {
|
func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, log *zerolog.Logger) (*http2RespWriter, error) {
|
||||||
|
@ -219,11 +220,20 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *http2RespWriter) AddTrailer(trailerName, trailerValue string) {
|
||||||
|
if !rp.statusWritten {
|
||||||
|
rp.log.Warn().Msg("Tried to add Trailer to response before status written. Ignoring...")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rp.w.Header().Add(http2.TrailerPrefix+trailerName, trailerValue)
|
||||||
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
dest := rp.w.Header()
|
dest := rp.w.Header()
|
||||||
userHeaders := make(http.Header, len(header))
|
userHeaders := make(http.Header, len(header))
|
||||||
for name, values := range header {
|
for name, values := range header {
|
||||||
// Since these are http2 headers, they're required to be lowercase
|
// lowercase headers for simplicity check
|
||||||
h2name := strings.ToLower(name)
|
h2name := strings.ToLower(name)
|
||||||
|
|
||||||
if h2name == "content-length" {
|
if h2name == "content-length" {
|
||||||
|
@ -234,7 +244,7 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
|
||||||
|
|
||||||
if h2name == tracing.IntCloudflaredTracingHeader {
|
if h2name == tracing.IntCloudflaredTracingHeader {
|
||||||
// Add cf-int-cloudflared-tracing header outside of serialized userHeaders
|
// Add cf-int-cloudflared-tracing header outside of serialized userHeaders
|
||||||
rp.w.Header()[tracing.CanonicalCloudflaredTracingHeader] = values
|
dest[tracing.CanonicalCloudflaredTracingHeader] = values
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,18 +257,21 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
|
||||||
|
|
||||||
// Perform user header serialization and set them in the single header
|
// Perform user header serialization and set them in the single header
|
||||||
dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders))
|
dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders))
|
||||||
|
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
||||||
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
||||||
if status == http.StatusSwitchingProtocols {
|
if status == http.StatusSwitchingProtocols {
|
||||||
status = http.StatusOK
|
status = http.StatusOK
|
||||||
}
|
}
|
||||||
rp.w.WriteHeader(status)
|
rp.w.WriteHeader(status)
|
||||||
if IsServerSentEvent(header) {
|
if shouldFlush(header) {
|
||||||
rp.shouldFlush = true
|
rp.shouldFlush = true
|
||||||
}
|
}
|
||||||
if rp.shouldFlush {
|
if rp.shouldFlush {
|
||||||
rp.flusher.Flush()
|
rp.flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rp.statusWritten = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -329,6 +329,10 @@ func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter
|
||||||
return httpResponseAdapter{s}
|
return httpResponseAdapter{s}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hrw httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
|
||||||
|
// we do not support trailers over QUIC
|
||||||
|
}
|
||||||
|
|
||||||
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
|
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
metadata := make([]quicpogs.Metadata, 0)
|
metadata := make([]quicpogs.Metadata, 0)
|
||||||
metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
|
metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -29,6 +28,8 @@ const (
|
||||||
LogFieldRule = "ingressRule"
|
LogFieldRule = "ingressRule"
|
||||||
LogFieldOriginService = "originService"
|
LogFieldOriginService = "originService"
|
||||||
LogFieldFlowID = "flowID"
|
LogFieldFlowID = "flowID"
|
||||||
|
|
||||||
|
trailerHeaderName = "Trailer"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Proxy represents a means to Proxy between cloudflared and the origin services.
|
// Proxy represents a means to Proxy between cloudflared and the origin services.
|
||||||
|
@ -207,15 +208,16 @@ func (p *Proxy) proxyHTTPRequest(
|
||||||
tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode)
|
tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode)
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// resp headers can be nil
|
headers := make(http.Header, len(resp.Header))
|
||||||
if resp.Header == nil {
|
// copy headers
|
||||||
resp.Header = make(http.Header)
|
for k, v := range resp.Header {
|
||||||
|
headers[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add spans to response header (if available)
|
// Add spans to response header (if available)
|
||||||
tr.AddSpans(resp.Header)
|
tr.AddSpans(headers)
|
||||||
|
|
||||||
err = w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
err = w.WriteRespHeaders(resp.StatusCode, headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Error writing response header")
|
return errors.Wrap(err, "Error writing response header")
|
||||||
}
|
}
|
||||||
|
@ -236,12 +238,10 @@ func (p *Proxy) proxyHTTPRequest(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if connection.IsServerSentEvent(resp.Header) {
|
_, _ = cfio.Copy(w, resp.Body)
|
||||||
p.log.Debug().Msg("Detected Server-Side Events from Origin")
|
|
||||||
p.writeEventStream(w, resp.Body)
|
// copy trailers
|
||||||
} else {
|
copyTrailers(w, resp)
|
||||||
_, _ = cfio.Copy(w, resp.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.logOriginResponse(resp, fields)
|
p.logOriginResponse(resp, fields)
|
||||||
return nil
|
return nil
|
||||||
|
@ -296,26 +296,6 @@ func (wr *bidirectionalStream) Write(p []byte) (n int, err error) {
|
||||||
return wr.writer.Write(p)
|
return wr.writer.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
|
||||||
reader := bufio.NewReader(respBody)
|
|
||||||
for {
|
|
||||||
line, readErr := reader.ReadBytes('\n')
|
|
||||||
|
|
||||||
// We first try to write whatever we read even if an error occurred
|
|
||||||
// The reason for doing it is to guarantee we really push everything to the eyeball side
|
|
||||||
// before returning
|
|
||||||
if len(line) > 0 {
|
|
||||||
if _, writeErr := w.Write(line); writeErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if readErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Proxy) appendTagHeaders(r *http.Request) {
|
func (p *Proxy) appendTagHeaders(r *http.Request) {
|
||||||
for _, tag := range p.tags {
|
for _, tag := range p.tags {
|
||||||
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
||||||
|
@ -329,6 +309,14 @@ type logFields struct {
|
||||||
flowID string
|
flowID string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyTrailers(w connection.ResponseWriter, response *http.Response) {
|
||||||
|
for trailerHeader, trailerValues := range response.Trailer {
|
||||||
|
for _, trailerValue := range trailerValues {
|
||||||
|
w.AddTrailer(trailerHeader, trailerValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Proxy) logRequest(r *http.Request, fields logFields) {
|
func (p *Proxy) logRequest(r *http.Request, fields logFields) {
|
||||||
if fields.cfRay != "" {
|
if fields.cfRay != "" {
|
||||||
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
|
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
|
||||||
|
|
|
@ -22,6 +22,8 @@ import (
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/cfio"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
|
@ -62,6 +64,10 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) er
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *mockHTTPRespWriter) AddTrailer(trailerName, trailerValue string) {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
|
||||||
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
||||||
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
||||||
}
|
}
|
||||||
|
@ -117,7 +123,10 @@ func newMockSSERespWriter() *mockSSERespWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mockSSERespWriter) Write(data []byte) (int, error) {
|
func (w *mockSSERespWriter) Write(data []byte) (int, error) {
|
||||||
w.writeNotification <- data
|
newData := make([]byte, len(data))
|
||||||
|
copy(newData, data)
|
||||||
|
|
||||||
|
w.writeNotification <- newData
|
||||||
return len(data), nil
|
return len(data), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,11 +265,8 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
|
|
||||||
for i := 0; i < pushCount; i++ {
|
for i := 0; i < pushCount; i++ {
|
||||||
line := responseWriter.ReadBytes()
|
line := responseWriter.ReadBytes()
|
||||||
expect := fmt.Sprintf("%d\n", i)
|
expect := fmt.Sprintf("%d\n\n", i)
|
||||||
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
|
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
|
||||||
|
|
||||||
line = responseWriter.ReadBytes()
|
|
||||||
require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
|
@ -276,7 +282,7 @@ func testProxySSEAllData(proxy *Proxy) func(t *testing.T) {
|
||||||
responseWriter := newMockSSERespWriter()
|
responseWriter := newMockSSERespWriter()
|
||||||
|
|
||||||
// responseWriter uses an unbuffered channel, so we call in a different go-routine
|
// responseWriter uses an unbuffered channel, so we call in a different go-routine
|
||||||
go proxy.writeEventStream(responseWriter, eyeballReader)
|
go cfio.Copy(responseWriter, eyeballReader)
|
||||||
|
|
||||||
result := string(<-responseWriter.writeNotification)
|
result := string(<-responseWriter.writeNotification)
|
||||||
require.Equal(t, "data\r\r", result)
|
require.Equal(t, "data\r\r", result)
|
||||||
|
@ -825,6 +831,10 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
|
||||||
// respHeaders is a test function to read respHeaders
|
// respHeaders is a test function to read respHeaders
|
||||||
func (w *wsRespWriter) headers() http.Header {
|
func (w *wsRespWriter) headers() http.Header {
|
||||||
// Removing indeterminstic header because it cannot be asserted.
|
// Removing indeterminstic header because it cannot be asserted.
|
||||||
|
@ -852,6 +862,10 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
|
||||||
return m.w.Write(p)
|
return m.w.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
m.responseHeaders = header
|
m.responseHeaders = header
|
||||||
m.code = status
|
m.code = status
|
||||||
|
|
Loading…
Reference in New Issue