diff --git a/connection/connection.go b/connection/connection.go index d1b1081e..07348e70 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -53,10 +53,13 @@ type ConnectedFuse interface { IsConnected() bool } +func IsServerSentEvent(headers http.Header) bool { + if contentType := headers.Get("content-type"); contentType != "" { + return strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") + } + return false +} + func uint8ToString(input uint8) string { return strconv.FormatUint(uint64(input), 10) } - -func isServerSentEvent(headers http.Header) bool { - return strings.ToLower(headers.Get("content-type")) == "text/event-stream" -} diff --git a/connection/connection_test.go b/connection/connection_test.go index f15dacfb..07ea8a30 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -5,11 +5,13 @@ import ( "io" "net/http" "net/url" + "testing" "time" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/logger" "github.com/gobwas/ws/wsutil" + "github.com/stretchr/testify/assert" ) const ( @@ -111,3 +113,40 @@ func (mcf mockConnectedFuse) Connected() {} func (mcf mockConnectedFuse) IsConnected() bool { return true } + +func TestIsEventStream(t *testing.T) { + tests := []struct { + headers http.Header + isEventStream bool + }{ + { + headers: newHeader("Content-Type", "text/event-stream"), + isEventStream: true, + }, + { + headers: newHeader("content-type", "text/event-stream"), + isEventStream: true, + }, + { + headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"), + isEventStream: true, + }, + { + headers: newHeader("Content-Type", "application/json"), + isEventStream: false, + }, + { + headers: http.Header{}, + isEventStream: false, + }, + } + for _, test := range tests { + assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers)) + } +} + +func newHeader(key, value string) http.Header { + header := http.Header{} + header.Add(key, value) + return header +} diff --git a/connection/http2.go b/connection/http2.go index a2f00e2c..f148c91b 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -167,7 +167,7 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error { status = http.StatusOK } rp.w.WriteHeader(status) - if isServerSentEvent(resp.Header) { + if IsServerSentEvent(resp.Header) { rp.shouldFlush = true } if rp.shouldFlush { diff --git a/hello/hello.go b/hello/hello.go index fcb41821..8ccdcd22 100644 --- a/hello/hello.go +++ b/hello/hello.go @@ -189,7 +189,7 @@ func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.H func sseHandler(logger logger.Service) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") flusher, ok := w.(http.Flusher) if !ok { w.WriteHeader(http.StatusInternalServerError) diff --git a/origin/proxy.go b/origin/proxy.go index ddefdbf4..f3b34982 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -96,7 +96,7 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule if err != nil { return nil, errors.Wrap(err, "Error writing response header") } - if isEventStream(resp) { + if connection.IsServerSentEvent(resp.Header) { c.logger.Debug("Detected Server-Side Events from Origin") c.writeEventStream(w, resp.Body) } else { @@ -222,14 +222,3 @@ func findCfRayHeader(req *http.Request) string { func isLBProbeRequest(req *http.Request) bool { return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) } - -func uint8ToString(input uint8) string { - return strconv.FormatUint(uint64(input), 10) -} - -func isEventStream(response *http.Response) bool { - if response.Header.Get("content-type") == "text/event-stream" { - return true - } - return false -}