diff --git a/connection/connection.go b/connection/connection.go index bfc779df..41fd03ab 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -4,6 +4,7 @@ import ( "io" "net/http" "strconv" + "strings" "time" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -55,3 +56,7 @@ type ConnectedFuse interface { func uint8ToString(input uint8) string { return strconv.FormatUint(uint64(input), 10) } + +func isServerSentEvent(headers http.Header) bool { + return strings.ToLower(headers.Get("content-type")) == "text/event-stream" +} diff --git a/connection/h2mux.go b/connection/h2mux.go index 5323fb0e..35932fea 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -205,14 +205,14 @@ type h2muxRespWriter struct { func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error { headers := h2mux.H1ResponseToH2ResponseHeaders(resp) - headers = append(headers, h2mux.Header{Name: responseMetaHeaderField, Value: responseSourceOrigin}) + headers = append(headers, h2mux.Header{Name: responseMetaHeaderField, Value: responseMetaHeaderOrigin}) return rp.WriteHeaders(headers) } func (rp *h2muxRespWriter) WriteErrorResponse(err error) { rp.WriteHeaders([]h2mux.Header{ {Name: ":status", Value: "502"}, - {Name: responseMetaHeaderField, Value: responseSourceCloudflared}, + {Name: responseMetaHeaderField, Value: responseMetaHeaderCfd}, }) rp.Write([]byte("502 Bad Gateway")) } diff --git a/connection/header.go b/connection/header.go index 77b69ab8..5c80c953 100644 --- a/connection/header.go +++ b/connection/header.go @@ -25,7 +25,7 @@ type responseMetaHeader struct { func mustInitRespMetaHeader(src string) string { header, err := json.Marshal(responseMetaHeader{Source: src}) if err != nil { - panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", responseSourceCloudflared, err)) + panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err)) } return string(header) } diff --git a/connection/http2.go b/connection/http2.go index a9417d23..b5724d10 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -2,7 +2,7 @@ package connection import ( "context" - "fmt" + "errors" "io" "math" "net" @@ -23,6 +23,10 @@ const ( controlStreamUpgrade = "control-stream" ) +var ( + errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher") +) + type HTTP2Connection struct { conn net.Conn server *http2.Server @@ -37,7 +41,16 @@ type HTTP2Connection struct { connectedFuse ConnectedFuse } -func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection { +func NewHTTP2Connection( + conn net.Conn, + config *Config, + originURL *url.URL, + namedTunnelConfig *NamedTunnelConfig, + connOptions *tunnelpogs.ConnectionOptions, + observer *Observer, + connIndex uint8, + connectedFuse ConnectedFuse, +) *HTTP2Connection { return &HTTP2Connection{ conn: conn, server: &http2.Server{ @@ -77,34 +90,33 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { r: r.Body, w: w, } + flusher, isFlusher := w.(http.Flusher) + if !isFlusher { + c.observer.Errorf("%T doesn't implement http.Flusher", w) + respWriter.WriteErrorResponse(errNotFlusher) + return + } + respWriter.flusher = flusher if isControlStreamUpgrade(r) { + respWriter.shouldFlush = true err := c.serveControlStream(r.Context(), respWriter) if err != nil { respWriter.WriteErrorResponse(err) } } else if isWebsocketUpgrade(r) { - wsRespWriter, err := newWSRespWriter(respWriter) - if err != nil { - respWriter.WriteErrorResponse(err) - return - } + respWriter.shouldFlush = true stripWebsocketUpgradeHeader(r) - c.config.OriginClient.Proxy(wsRespWriter, r, true) + c.config.OriginClient.Proxy(respWriter, r, true) } else { c.config.OriginClient.Proxy(respWriter, r, false) } } -func (c *HTTP2Connection) serveControlStream(ctx context.Context, h2RespWriter *http2RespWriter) error { - stream, err := newWSRespWriter(h2RespWriter) - if err != nil { - return err - } - - rpcClient := newRegistrationRPCClient(ctx, stream, c.observer) +func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error { + rpcClient := newRegistrationRPCClient(ctx, respWriter, c.observer) defer rpcClient.close() - if err = registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil { + if err := registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil { return err } c.connectedFuse.Connected() @@ -146,8 +158,10 @@ func (c *HTTP2Connection) close() { } type http2RespWriter struct { - r io.Reader - w http.ResponseWriter + r io.Reader + w http.ResponseWriter + flusher http.Flusher + shouldFlush bool } func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error { @@ -172,13 +186,19 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error { // Perform user header serialization and set them in the single header dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) - rp.setResponseMetaHeader(responseMetaHeaderCfd) + rp.setResponseMetaHeader(responseMetaHeaderOrigin) status := resp.StatusCode // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 if status == http.StatusSwitchingProtocols { status = http.StatusOK } rp.w.WriteHeader(status) + if isServerSentEvent(resp.Header) { + rp.shouldFlush = true + } + if rp.shouldFlush { + rp.flusher.Flush() + } return nil } @@ -195,43 +215,15 @@ func (rp *http2RespWriter) Read(p []byte) (n int, err error) { return rp.r.Read(p) } -func (wr *http2RespWriter) Write(p []byte) (n int, err error) { - return wr.w.Write(p) -} - -type wsRespWriter struct { - *http2RespWriter - flusher http.Flusher -} - -func newWSRespWriter(h2 *http2RespWriter) (*wsRespWriter, error) { - flusher, ok := h2.w.(http.Flusher) - if !ok { - return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher") +func (rp *http2RespWriter) Write(p []byte) (n int, err error) { + n, err = rp.w.Write(p) + if err == nil && rp.shouldFlush { + rp.flusher.Flush() } - return &wsRespWriter{ - h2, - flusher, - }, nil + return n, err } -func (rw *wsRespWriter) WriteRespHeaders(resp *http.Response) (err error) { - err = rw.http2RespWriter.WriteRespHeaders(resp) - if err == nil { - rw.flusher.Flush() - } - return -} - -func (rw *wsRespWriter) Write(p []byte) (n int, err error) { - n, err = rw.http2RespWriter.Write(p) - if err == nil { - rw.flusher.Flush() - } - return -} - -func (rw *wsRespWriter) Close() error { +func (rp *http2RespWriter) Close() error { return nil } diff --git a/hello/hello.go b/hello/hello.go index 63d1830e..b78c8f7e 100644 --- a/hello/hello.go +++ b/hello/hello.go @@ -19,8 +19,10 @@ import ( ) const ( - UptimeRoute = "/uptime" - WSRoute = "/ws" + UptimeRoute = "/uptime" + WSRoute = "/ws" + SSERoute = "/sse" + defaultSSEFreq = time.Second * 10 ) type templateData struct { @@ -111,6 +113,7 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow muxer := http.NewServeMux() muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now())) muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader)) + muxer.HandleFunc(SSERoute, sseHandler(logger)) muxer.HandleFunc("/", rootHandler(serverName)) httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer} go func() { @@ -182,6 +185,42 @@ func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.H } } +func sseHandler(logger logger.Service) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + logger.Errorf("Can't support SSE. ResponseWriter %T doesn't implement http.Flusher interface", w) + return + } + + freq := defaultSSEFreq + if requestedFreq := r.URL.Query()["freq"]; len(requestedFreq) > 0 { + parsedFreq, err := time.ParseDuration(requestedFreq[0]) + if err == nil { + freq = parsedFreq + } + } + logger.Infof("Server Sent Events every %s", freq) + ticker := time.NewTicker(freq) + counter := 0 + for { + select { + case <-r.Context().Done(): + return + case <-ticker.C: + } + _, err := fmt.Fprintf(w, "%d\n\n", counter) + if err != nil { + return + } + flusher.Flush() + counter++ + } + } +} + func rootHandler(serverName string) http.HandlerFunc { responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) return func(w http.ResponseWriter, r *http.Request) { diff --git a/origin/proxy.go b/origin/proxy.go index 6042779e..9c52ab24 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -99,7 +99,7 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt return nil, errors.Wrap(err, "Error writing response header") } if isEventStream(resp) { - //h.observer.Debug("Detected Server-Side Events from Origin") + c.logger.Debug("Detected Server-Side Events from Origin") c.writeEventStream(w, resp.Body) } else { // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 3d43e997..9a286a3d 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -12,6 +12,7 @@ import ( "net/url" "sync" "testing" + "time" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/hello" @@ -55,9 +56,9 @@ type mockWSRespWriter struct { reader io.Reader } -func newMockWSRespWriter(httpRespWriter *mockHTTPRespWriter, reader io.Reader) *mockWSRespWriter { +func newMockWSRespWriter(reader io.Reader) *mockWSRespWriter { return &mockWSRespWriter{ - httpRespWriter, + newMockHTTPRespWriter(), make(chan []byte), reader, } @@ -77,6 +78,27 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) { return w.reader.Read(data) } +type mockSSERespWriter struct { + *mockHTTPRespWriter + writeNotification chan []byte +} + +func newMockSSERespWriter() *mockSSERespWriter { + return &mockSSERespWriter{ + newMockHTTPRespWriter(), + make(chan []byte), + } +} + +func (w *mockSSERespWriter) Write(data []byte) (int, error) { + w.writeNotification <- data + return len(data), nil +} + +func (w *mockSSERespWriter) ReadBytes() []byte { + return <-w.writeNotification +} + func TestProxy(t *testing.T) { logger, err := logger.New() require.NoError(t, err) @@ -112,6 +134,7 @@ func TestProxy(t *testing.T) { client := NewClient(proxyConfig, logger) t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL)) t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS)) + t.Run("testProxySSE", testProxySSE(t, client, originURL)) cancel() } @@ -135,7 +158,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil) readPipe, writePipe := io.Pipe() - respWriter := newMockWSRespWriter(newMockHTTPRespWriter(), readPipe) + respWriter := newMockWSRespWriter(readPipe) var wg sync.WaitGroup wg.Add(1) @@ -167,3 +190,38 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL wg.Wait() } } + +func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) { + return func(t *testing.T) { + var ( + pushCount = 50 + pushFreq = time.Duration(time.Millisecond * 10) + ) + respWriter := newMockSSERespWriter() + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?freq=%s", originURL, hello.SSERoute, pushFreq), nil) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err = client.Proxy(respWriter, req, false) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, respWriter.Code) + }() + + for i := 0; i < pushCount; i++ { + line := respWriter.ReadBytes() + expect := fmt.Sprintf("%d\n", i) + require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line)) + + line = respWriter.ReadBytes() + require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line)) + } + + cancel() + wg.Wait() + } +}