TUN-3480: Support SSE with http2 connection, and add SSE handler to hello-world server

This commit is contained in:
cthuang 2020-10-23 15:49:24 +01:00
parent 6b86f81c4a
commit eef5b78eac
7 changed files with 156 additions and 62 deletions

View File

@ -4,6 +4,7 @@ import (
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -55,3 +56,7 @@ type ConnectedFuse interface {
func uint8ToString(input uint8) string { func uint8ToString(input uint8) string {
return strconv.FormatUint(uint64(input), 10) return strconv.FormatUint(uint64(input), 10)
} }
func isServerSentEvent(headers http.Header) bool {
return strings.ToLower(headers.Get("content-type")) == "text/event-stream"
}

View File

@ -205,14 +205,14 @@ type h2muxRespWriter struct {
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error { func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
headers := h2mux.H1ResponseToH2ResponseHeaders(resp) headers := h2mux.H1ResponseToH2ResponseHeaders(resp)
headers = append(headers, h2mux.Header{Name: responseMetaHeaderField, Value: responseSourceOrigin}) headers = append(headers, h2mux.Header{Name: responseMetaHeaderField, Value: responseMetaHeaderOrigin})
return rp.WriteHeaders(headers) return rp.WriteHeaders(headers)
} }
func (rp *h2muxRespWriter) WriteErrorResponse(err error) { func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
rp.WriteHeaders([]h2mux.Header{ rp.WriteHeaders([]h2mux.Header{
{Name: ":status", Value: "502"}, {Name: ":status", Value: "502"},
{Name: responseMetaHeaderField, Value: responseSourceCloudflared}, {Name: responseMetaHeaderField, Value: responseMetaHeaderCfd},
}) })
rp.Write([]byte("502 Bad Gateway")) rp.Write([]byte("502 Bad Gateway"))
} }

View File

@ -25,7 +25,7 @@ type responseMetaHeader struct {
func mustInitRespMetaHeader(src string) string { func mustInitRespMetaHeader(src string) string {
header, err := json.Marshal(responseMetaHeader{Source: src}) header, err := json.Marshal(responseMetaHeader{Source: src})
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", responseSourceCloudflared, err)) panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
} }
return string(header) return string(header)
} }

View File

@ -2,7 +2,7 @@ package connection
import ( import (
"context" "context"
"fmt" "errors"
"io" "io"
"math" "math"
"net" "net"
@ -23,6 +23,10 @@ const (
controlStreamUpgrade = "control-stream" controlStreamUpgrade = "control-stream"
) )
var (
errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher")
)
type HTTP2Connection struct { type HTTP2Connection struct {
conn net.Conn conn net.Conn
server *http2.Server server *http2.Server
@ -37,7 +41,16 @@ type HTTP2Connection struct {
connectedFuse ConnectedFuse connectedFuse ConnectedFuse
} }
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection { func NewHTTP2Connection(
conn net.Conn,
config *Config,
originURL *url.URL,
namedTunnelConfig *NamedTunnelConfig,
connOptions *tunnelpogs.ConnectionOptions,
observer *Observer,
connIndex uint8,
connectedFuse ConnectedFuse,
) *HTTP2Connection {
return &HTTP2Connection{ return &HTTP2Connection{
conn: conn, conn: conn,
server: &http2.Server{ server: &http2.Server{
@ -77,34 +90,33 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r: r.Body, r: r.Body,
w: w, w: w,
} }
flusher, isFlusher := w.(http.Flusher)
if !isFlusher {
c.observer.Errorf("%T doesn't implement http.Flusher", w)
respWriter.WriteErrorResponse(errNotFlusher)
return
}
respWriter.flusher = flusher
if isControlStreamUpgrade(r) { if isControlStreamUpgrade(r) {
respWriter.shouldFlush = true
err := c.serveControlStream(r.Context(), respWriter) err := c.serveControlStream(r.Context(), respWriter)
if err != nil { if err != nil {
respWriter.WriteErrorResponse(err) respWriter.WriteErrorResponse(err)
} }
} else if isWebsocketUpgrade(r) { } else if isWebsocketUpgrade(r) {
wsRespWriter, err := newWSRespWriter(respWriter) respWriter.shouldFlush = true
if err != nil {
respWriter.WriteErrorResponse(err)
return
}
stripWebsocketUpgradeHeader(r) stripWebsocketUpgradeHeader(r)
c.config.OriginClient.Proxy(wsRespWriter, r, true) c.config.OriginClient.Proxy(respWriter, r, true)
} else { } else {
c.config.OriginClient.Proxy(respWriter, r, false) c.config.OriginClient.Proxy(respWriter, r, false)
} }
} }
func (c *HTTP2Connection) serveControlStream(ctx context.Context, h2RespWriter *http2RespWriter) error { func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
stream, err := newWSRespWriter(h2RespWriter) rpcClient := newRegistrationRPCClient(ctx, respWriter, c.observer)
if err != nil {
return err
}
rpcClient := newRegistrationRPCClient(ctx, stream, c.observer)
defer rpcClient.close() defer rpcClient.close()
if err = registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil { if err := registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
return err return err
} }
c.connectedFuse.Connected() c.connectedFuse.Connected()
@ -146,8 +158,10 @@ func (c *HTTP2Connection) close() {
} }
type http2RespWriter struct { type http2RespWriter struct {
r io.Reader r io.Reader
w http.ResponseWriter w http.ResponseWriter
flusher http.Flusher
shouldFlush bool
} }
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error { func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
@ -172,13 +186,19 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
// Perform user header serialization and set them in the single header // Perform user header serialization and set them in the single header
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
rp.setResponseMetaHeader(responseMetaHeaderCfd) rp.setResponseMetaHeader(responseMetaHeaderOrigin)
status := resp.StatusCode status := resp.StatusCode
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
if status == http.StatusSwitchingProtocols { if status == http.StatusSwitchingProtocols {
status = http.StatusOK status = http.StatusOK
} }
rp.w.WriteHeader(status) rp.w.WriteHeader(status)
if isServerSentEvent(resp.Header) {
rp.shouldFlush = true
}
if rp.shouldFlush {
rp.flusher.Flush()
}
return nil return nil
} }
@ -195,43 +215,15 @@ func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
return rp.r.Read(p) return rp.r.Read(p)
} }
func (wr *http2RespWriter) Write(p []byte) (n int, err error) { func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
return wr.w.Write(p) n, err = rp.w.Write(p)
} if err == nil && rp.shouldFlush {
rp.flusher.Flush()
type wsRespWriter struct {
*http2RespWriter
flusher http.Flusher
}
func newWSRespWriter(h2 *http2RespWriter) (*wsRespWriter, error) {
flusher, ok := h2.w.(http.Flusher)
if !ok {
return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher")
} }
return &wsRespWriter{ return n, err
h2,
flusher,
}, nil
} }
func (rw *wsRespWriter) WriteRespHeaders(resp *http.Response) (err error) { func (rp *http2RespWriter) Close() error {
err = rw.http2RespWriter.WriteRespHeaders(resp)
if err == nil {
rw.flusher.Flush()
}
return
}
func (rw *wsRespWriter) Write(p []byte) (n int, err error) {
n, err = rw.http2RespWriter.Write(p)
if err == nil {
rw.flusher.Flush()
}
return
}
func (rw *wsRespWriter) Close() error {
return nil return nil
} }

View File

@ -19,8 +19,10 @@ import (
) )
const ( const (
UptimeRoute = "/uptime" UptimeRoute = "/uptime"
WSRoute = "/ws" WSRoute = "/ws"
SSERoute = "/sse"
defaultSSEFreq = time.Second * 10
) )
type templateData struct { type templateData struct {
@ -111,6 +113,7 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow
muxer := http.NewServeMux() muxer := http.NewServeMux()
muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now())) muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader)) muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader))
muxer.HandleFunc(SSERoute, sseHandler(logger))
muxer.HandleFunc("/", rootHandler(serverName)) muxer.HandleFunc("/", rootHandler(serverName))
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer} httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
go func() { go func() {
@ -182,6 +185,42 @@ func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.H
} }
} }
func sseHandler(logger logger.Service) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, ok := w.(http.Flusher)
if !ok {
w.WriteHeader(http.StatusInternalServerError)
logger.Errorf("Can't support SSE. ResponseWriter %T doesn't implement http.Flusher interface", w)
return
}
freq := defaultSSEFreq
if requestedFreq := r.URL.Query()["freq"]; len(requestedFreq) > 0 {
parsedFreq, err := time.ParseDuration(requestedFreq[0])
if err == nil {
freq = parsedFreq
}
}
logger.Infof("Server Sent Events every %s", freq)
ticker := time.NewTicker(freq)
counter := 0
for {
select {
case <-r.Context().Done():
return
case <-ticker.C:
}
_, err := fmt.Fprintf(w, "%d\n\n", counter)
if err != nil {
return
}
flusher.Flush()
counter++
}
}
}
func rootHandler(serverName string) http.HandlerFunc { func rootHandler(serverName string) http.HandlerFunc {
responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {

View File

@ -99,7 +99,7 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt
return nil, errors.Wrap(err, "Error writing response header") return nil, errors.Wrap(err, "Error writing response header")
} }
if isEventStream(resp) { if isEventStream(resp) {
//h.observer.Debug("Detected Server-Side Events from Origin") c.logger.Debug("Detected Server-Side Events from Origin")
c.writeEventStream(w, resp.Body) c.writeEventStream(w, resp.Body)
} else { } else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream

View File

@ -12,6 +12,7 @@ import (
"net/url" "net/url"
"sync" "sync"
"testing" "testing"
"time"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/hello"
@ -55,9 +56,9 @@ type mockWSRespWriter struct {
reader io.Reader reader io.Reader
} }
func newMockWSRespWriter(httpRespWriter *mockHTTPRespWriter, reader io.Reader) *mockWSRespWriter { func newMockWSRespWriter(reader io.Reader) *mockWSRespWriter {
return &mockWSRespWriter{ return &mockWSRespWriter{
httpRespWriter, newMockHTTPRespWriter(),
make(chan []byte), make(chan []byte),
reader, reader,
} }
@ -77,6 +78,27 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
return w.reader.Read(data) return w.reader.Read(data)
} }
type mockSSERespWriter struct {
*mockHTTPRespWriter
writeNotification chan []byte
}
func newMockSSERespWriter() *mockSSERespWriter {
return &mockSSERespWriter{
newMockHTTPRespWriter(),
make(chan []byte),
}
}
func (w *mockSSERespWriter) Write(data []byte) (int, error) {
w.writeNotification <- data
return len(data), nil
}
func (w *mockSSERespWriter) ReadBytes() []byte {
return <-w.writeNotification
}
func TestProxy(t *testing.T) { func TestProxy(t *testing.T) {
logger, err := logger.New() logger, err := logger.New()
require.NoError(t, err) require.NoError(t, err)
@ -112,6 +134,7 @@ func TestProxy(t *testing.T) {
client := NewClient(proxyConfig, logger) client := NewClient(proxyConfig, logger)
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL)) t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS)) t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
t.Run("testProxySSE", testProxySSE(t, client, originURL))
cancel() cancel()
} }
@ -135,7 +158,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
readPipe, writePipe := io.Pipe() readPipe, writePipe := io.Pipe()
respWriter := newMockWSRespWriter(newMockHTTPRespWriter(), readPipe) respWriter := newMockWSRespWriter(readPipe)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@ -167,3 +190,38 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL
wg.Wait() wg.Wait()
} }
} }
func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
return func(t *testing.T) {
var (
pushCount = 50
pushFreq = time.Duration(time.Millisecond * 10)
)
respWriter := newMockSSERespWriter()
ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?freq=%s", originURL, hello.SSERoute, pushFreq), nil)
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err = client.Proxy(respWriter, req, false)
require.NoError(t, err)
require.Equal(t, http.StatusOK, respWriter.Code)
}()
for i := 0; i < pushCount; i++ {
line := respWriter.ReadBytes()
expect := fmt.Sprintf("%d\n", i)
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
line = respWriter.ReadBytes()
require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line))
}
cancel()
wg.Wait()
}
}