TUN-3557: Detect SSE if content-type starts with text/event-stream
This commit is contained in:
parent
293b9af4a7
commit
fdb1f961b3
|
@ -53,10 +53,13 @@ type ConnectedFuse interface {
|
||||||
IsConnected() bool
|
IsConnected() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsServerSentEvent(headers http.Header) bool {
|
||||||
|
if contentType := headers.Get("content-type"); contentType != "" {
|
||||||
|
return strings.HasPrefix(strings.ToLower(contentType), "text/event-stream")
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func uint8ToString(input uint8) string {
|
func uint8ToString(input uint8) string {
|
||||||
return strconv.FormatUint(uint64(input), 10)
|
return strconv.FormatUint(uint64(input), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isServerSentEvent(headers http.Header) bool {
|
|
||||||
return strings.ToLower(headers.Get("content-type")) == "text/event-stream"
|
|
||||||
}
|
|
||||||
|
|
|
@ -5,11 +5,13 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -111,3 +113,40 @@ func (mcf mockConnectedFuse) Connected() {}
|
||||||
func (mcf mockConnectedFuse) IsConnected() bool {
|
func (mcf mockConnectedFuse) IsConnected() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsEventStream(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
headers http.Header
|
||||||
|
isEventStream bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
headers: newHeader("Content-Type", "text/event-stream"),
|
||||||
|
isEventStream: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
headers: newHeader("content-type", "text/event-stream"),
|
||||||
|
isEventStream: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
headers: newHeader("Content-Type", "text/event-stream; charset=utf-8"),
|
||||||
|
isEventStream: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
headers: newHeader("Content-Type", "application/json"),
|
||||||
|
isEventStream: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
headers: http.Header{},
|
||||||
|
isEventStream: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
assert.Equal(t, test.isEventStream, IsServerSentEvent(test.headers))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHeader(key, value string) http.Header {
|
||||||
|
header := http.Header{}
|
||||||
|
header.Add(key, value)
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
|
@ -167,7 +167,7 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
status = http.StatusOK
|
status = http.StatusOK
|
||||||
}
|
}
|
||||||
rp.w.WriteHeader(status)
|
rp.w.WriteHeader(status)
|
||||||
if isServerSentEvent(resp.Header) {
|
if IsServerSentEvent(resp.Header) {
|
||||||
rp.shouldFlush = true
|
rp.shouldFlush = true
|
||||||
}
|
}
|
||||||
if rp.shouldFlush {
|
if rp.shouldFlush {
|
||||||
|
|
|
@ -189,7 +189,7 @@ func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.H
|
||||||
|
|
||||||
func sseHandler(logger logger.Service) http.HandlerFunc {
|
func sseHandler(logger logger.Service) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
|
|
@ -96,7 +96,7 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error writing response header")
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
}
|
}
|
||||||
if isEventStream(resp) {
|
if connection.IsServerSentEvent(resp.Header) {
|
||||||
c.logger.Debug("Detected Server-Side Events from Origin")
|
c.logger.Debug("Detected Server-Side Events from Origin")
|
||||||
c.writeEventStream(w, resp.Body)
|
c.writeEventStream(w, resp.Body)
|
||||||
} else {
|
} else {
|
||||||
|
@ -222,14 +222,3 @@ func findCfRayHeader(req *http.Request) string {
|
||||||
func isLBProbeRequest(req *http.Request) bool {
|
func isLBProbeRequest(req *http.Request) bool {
|
||||||
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
func uint8ToString(input uint8) string {
|
|
||||||
return strconv.FormatUint(uint64(input), 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isEventStream(response *http.Response) bool {
|
|
||||||
if response.Header.Get("content-type") == "text/event-stream" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue