TUN-3557: Detect SSE if content-type starts with text/event-stream

This commit is contained in:
cthuang 2020-11-18 11:53:59 +00:00 committed by Chung Ting Huang
parent 293b9af4a7
commit fdb1f961b3
5 changed files with 49 additions and 18 deletions

View File

@ -53,10 +53,13 @@ type ConnectedFuse interface {
IsConnected() bool IsConnected() bool
} }
func IsServerSentEvent(headers http.Header) bool {
if contentType := headers.Get("content-type"); contentType != "" {
return strings.HasPrefix(strings.ToLower(contentType), "text/event-stream")
}
return false
}
func uint8ToString(input uint8) string { func uint8ToString(input uint8) string {
return strconv.FormatUint(uint64(input), 10) return strconv.FormatUint(uint64(input), 10)
} }
func isServerSentEvent(headers http.Header) bool {
return strings.ToLower(headers.Get("content-type")) == "text/event-stream"
}

View File

@ -5,11 +5,13 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"testing"
"time" "time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"github.com/stretchr/testify/assert"
) )
const ( const (
@ -111,3 +113,40 @@ func (mcf mockConnectedFuse) Connected() {}
func (mcf mockConnectedFuse) IsConnected() bool { func (mcf mockConnectedFuse) IsConnected() bool {
return true return true
} }
func TestIsEventStream(t *testing.T) {
tests := []struct {
headers http.Header
isEventStream bool
}{
{
headers: newHeader("Content-Type", "text/event-stream"),
isEventStream: true,
},
{
headers: newHeader("content-type", "text/event-stream"),
isEventStream: true,
},
{
headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"),
isEventStream: true,
},
{
headers: newHeader("Content-Type", "application/json"),
isEventStream: false,
},
{
headers: http.Header{},
isEventStream: false,
},
}
for _, test := range tests {
assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers))
}
}
func newHeader(key, value string) http.Header {
header := http.Header{}
header.Add(key, value)
return header
}

View File

@ -167,7 +167,7 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
status = http.StatusOK status = http.StatusOK
} }
rp.w.WriteHeader(status) rp.w.WriteHeader(status)
if isServerSentEvent(resp.Header) { if IsServerSentEvent(resp.Header) {
rp.shouldFlush = true rp.shouldFlush = true
} }
if rp.shouldFlush { if rp.shouldFlush {

View File

@ -189,7 +189,7 @@ func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.H
func sseHandler(logger logger.Service) http.HandlerFunc { func sseHandler(logger logger.Service) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
if !ok { if !ok {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)

View File

@ -96,7 +96,7 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error writing response header") return nil, errors.Wrap(err, "Error writing response header")
} }
if isEventStream(resp) { if connection.IsServerSentEvent(resp.Header) {
c.logger.Debug("Detected Server-Side Events from Origin") c.logger.Debug("Detected Server-Side Events from Origin")
c.writeEventStream(w, resp.Body) c.writeEventStream(w, resp.Body)
} else { } else {
@ -222,14 +222,3 @@ func findCfRayHeader(req *http.Request) string {
func isLBProbeRequest(req *http.Request) bool { func isLBProbeRequest(req *http.Request) bool {
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
} }
func uint8ToString(input uint8) string {
return strconv.FormatUint(uint64(input), 10)
}
func isEventStream(response *http.Response) bool {
if response.Header.Get("content-type") == "text/event-stream" {
return true
}
return false
}