TUN-3480: Support SSE with http2 connection, and add SSE handler to hello-world server
This commit is contained in:
parent
6b86f81c4a
commit
eef5b78eac
|
@ -4,6 +4,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
@ -55,3 +56,7 @@ type ConnectedFuse interface {
|
||||||
func uint8ToString(input uint8) string {
|
func uint8ToString(input uint8) string {
|
||||||
return strconv.FormatUint(uint64(input), 10)
|
return strconv.FormatUint(uint64(input), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isServerSentEvent(headers http.Header) bool {
|
||||||
|
return strings.ToLower(headers.Get("content-type")) == "text/event-stream"
|
||||||
|
}
|
||||||
|
|
|
@ -205,14 +205,14 @@ type h2muxRespWriter struct {
|
||||||
|
|
||||||
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
|
func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
headers := h2mux.H1ResponseToH2ResponseHeaders(resp)
|
headers := h2mux.H1ResponseToH2ResponseHeaders(resp)
|
||||||
headers = append(headers, h2mux.Header{Name: responseMetaHeaderField, Value: responseSourceOrigin})
|
headers = append(headers, h2mux.Header{Name: responseMetaHeaderField, Value: responseMetaHeaderOrigin})
|
||||||
return rp.WriteHeaders(headers)
|
return rp.WriteHeaders(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
|
func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
|
||||||
rp.WriteHeaders([]h2mux.Header{
|
rp.WriteHeaders([]h2mux.Header{
|
||||||
{Name: ":status", Value: "502"},
|
{Name: ":status", Value: "502"},
|
||||||
{Name: responseMetaHeaderField, Value: responseSourceCloudflared},
|
{Name: responseMetaHeaderField, Value: responseMetaHeaderCfd},
|
||||||
})
|
})
|
||||||
rp.Write([]byte("502 Bad Gateway"))
|
rp.Write([]byte("502 Bad Gateway"))
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ type responseMetaHeader struct {
|
||||||
func mustInitRespMetaHeader(src string) string {
|
func mustInitRespMetaHeader(src string) string {
|
||||||
header, err := json.Marshal(responseMetaHeader{Source: src})
|
header, err := json.Marshal(responseMetaHeader{Source: src})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", responseSourceCloudflared, err))
|
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
|
||||||
}
|
}
|
||||||
return string(header)
|
return string(header)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@ package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
@ -23,6 +23,10 @@ const (
|
||||||
controlStreamUpgrade = "control-stream"
|
controlStreamUpgrade = "control-stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher")
|
||||||
|
)
|
||||||
|
|
||||||
type HTTP2Connection struct {
|
type HTTP2Connection struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
server *http2.Server
|
server *http2.Server
|
||||||
|
@ -37,7 +41,16 @@ type HTTP2Connection struct {
|
||||||
connectedFuse ConnectedFuse
|
connectedFuse ConnectedFuse
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection {
|
func NewHTTP2Connection(
|
||||||
|
conn net.Conn,
|
||||||
|
config *Config,
|
||||||
|
originURL *url.URL,
|
||||||
|
namedTunnelConfig *NamedTunnelConfig,
|
||||||
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
|
observer *Observer,
|
||||||
|
connIndex uint8,
|
||||||
|
connectedFuse ConnectedFuse,
|
||||||
|
) *HTTP2Connection {
|
||||||
return &HTTP2Connection{
|
return &HTTP2Connection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
server: &http2.Server{
|
server: &http2.Server{
|
||||||
|
@ -77,34 +90,33 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
r: r.Body,
|
r: r.Body,
|
||||||
w: w,
|
w: w,
|
||||||
}
|
}
|
||||||
|
flusher, isFlusher := w.(http.Flusher)
|
||||||
|
if !isFlusher {
|
||||||
|
c.observer.Errorf("%T doesn't implement http.Flusher", w)
|
||||||
|
respWriter.WriteErrorResponse(errNotFlusher)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
respWriter.flusher = flusher
|
||||||
if isControlStreamUpgrade(r) {
|
if isControlStreamUpgrade(r) {
|
||||||
|
respWriter.shouldFlush = true
|
||||||
err := c.serveControlStream(r.Context(), respWriter)
|
err := c.serveControlStream(r.Context(), respWriter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respWriter.WriteErrorResponse(err)
|
respWriter.WriteErrorResponse(err)
|
||||||
}
|
}
|
||||||
} else if isWebsocketUpgrade(r) {
|
} else if isWebsocketUpgrade(r) {
|
||||||
wsRespWriter, err := newWSRespWriter(respWriter)
|
respWriter.shouldFlush = true
|
||||||
if err != nil {
|
|
||||||
respWriter.WriteErrorResponse(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
stripWebsocketUpgradeHeader(r)
|
stripWebsocketUpgradeHeader(r)
|
||||||
c.config.OriginClient.Proxy(wsRespWriter, r, true)
|
c.config.OriginClient.Proxy(respWriter, r, true)
|
||||||
} else {
|
} else {
|
||||||
c.config.OriginClient.Proxy(respWriter, r, false)
|
c.config.OriginClient.Proxy(respWriter, r, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HTTP2Connection) serveControlStream(ctx context.Context, h2RespWriter *http2RespWriter) error {
|
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
||||||
stream, err := newWSRespWriter(h2RespWriter)
|
rpcClient := newRegistrationRPCClient(ctx, respWriter, c.observer)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rpcClient := newRegistrationRPCClient(ctx, stream, c.observer)
|
|
||||||
defer rpcClient.close()
|
defer rpcClient.close()
|
||||||
|
|
||||||
if err = registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
|
if err := registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.connectedFuse.Connected()
|
c.connectedFuse.Connected()
|
||||||
|
@ -148,6 +160,8 @@ func (c *HTTP2Connection) close() {
|
||||||
type http2RespWriter struct {
|
type http2RespWriter struct {
|
||||||
r io.Reader
|
r io.Reader
|
||||||
w http.ResponseWriter
|
w http.ResponseWriter
|
||||||
|
flusher http.Flusher
|
||||||
|
shouldFlush bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
|
@ -172,13 +186,19 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
|
|
||||||
// Perform user header serialization and set them in the single header
|
// Perform user header serialization and set them in the single header
|
||||||
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
|
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
||||||
status := resp.StatusCode
|
status := resp.StatusCode
|
||||||
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
||||||
if status == http.StatusSwitchingProtocols {
|
if status == http.StatusSwitchingProtocols {
|
||||||
status = http.StatusOK
|
status = http.StatusOK
|
||||||
}
|
}
|
||||||
rp.w.WriteHeader(status)
|
rp.w.WriteHeader(status)
|
||||||
|
if isServerSentEvent(resp.Header) {
|
||||||
|
rp.shouldFlush = true
|
||||||
|
}
|
||||||
|
if rp.shouldFlush {
|
||||||
|
rp.flusher.Flush()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,43 +215,15 @@ func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
|
||||||
return rp.r.Read(p)
|
return rp.r.Read(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wr *http2RespWriter) Write(p []byte) (n int, err error) {
|
func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
|
||||||
return wr.w.Write(p)
|
n, err = rp.w.Write(p)
|
||||||
}
|
if err == nil && rp.shouldFlush {
|
||||||
|
rp.flusher.Flush()
|
||||||
type wsRespWriter struct {
|
|
||||||
*http2RespWriter
|
|
||||||
flusher http.Flusher
|
|
||||||
}
|
|
||||||
|
|
||||||
func newWSRespWriter(h2 *http2RespWriter) (*wsRespWriter, error) {
|
|
||||||
flusher, ok := h2.w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher")
|
|
||||||
}
|
}
|
||||||
return &wsRespWriter{
|
return n, err
|
||||||
h2,
|
|
||||||
flusher,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *wsRespWriter) WriteRespHeaders(resp *http.Response) (err error) {
|
func (rp *http2RespWriter) Close() error {
|
||||||
err = rw.http2RespWriter.WriteRespHeaders(resp)
|
|
||||||
if err == nil {
|
|
||||||
rw.flusher.Flush()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *wsRespWriter) Write(p []byte) (n int, err error) {
|
|
||||||
n, err = rw.http2RespWriter.Write(p)
|
|
||||||
if err == nil {
|
|
||||||
rw.flusher.Flush()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *wsRespWriter) Close() error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,8 @@ import (
|
||||||
const (
|
const (
|
||||||
UptimeRoute = "/uptime"
|
UptimeRoute = "/uptime"
|
||||||
WSRoute = "/ws"
|
WSRoute = "/ws"
|
||||||
|
SSERoute = "/sse"
|
||||||
|
defaultSSEFreq = time.Second * 10
|
||||||
)
|
)
|
||||||
|
|
||||||
type templateData struct {
|
type templateData struct {
|
||||||
|
@ -111,6 +113,7 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow
|
||||||
muxer := http.NewServeMux()
|
muxer := http.NewServeMux()
|
||||||
muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
|
muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
|
||||||
muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader))
|
muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader))
|
||||||
|
muxer.HandleFunc(SSERoute, sseHandler(logger))
|
||||||
muxer.HandleFunc("/", rootHandler(serverName))
|
muxer.HandleFunc("/", rootHandler(serverName))
|
||||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
|
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -182,6 +185,42 @@ func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.H
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sseHandler(logger logger.Service) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
logger.Errorf("Can't support SSE. ResponseWriter %T doesn't implement http.Flusher interface", w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
freq := defaultSSEFreq
|
||||||
|
if requestedFreq := r.URL.Query()["freq"]; len(requestedFreq) > 0 {
|
||||||
|
parsedFreq, err := time.ParseDuration(requestedFreq[0])
|
||||||
|
if err == nil {
|
||||||
|
freq = parsedFreq
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Infof("Server Sent Events every %s", freq)
|
||||||
|
ticker := time.NewTicker(freq)
|
||||||
|
counter := 0
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
_, err := fmt.Fprintf(w, "%d\n\n", counter)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
counter++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func rootHandler(serverName string) http.HandlerFunc {
|
func rootHandler(serverName string) http.HandlerFunc {
|
||||||
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
|
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt
|
||||||
return nil, errors.Wrap(err, "Error writing response header")
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
}
|
}
|
||||||
if isEventStream(resp) {
|
if isEventStream(resp) {
|
||||||
//h.observer.Debug("Detected Server-Side Events from Origin")
|
c.logger.Debug("Detected Server-Side Events from Origin")
|
||||||
c.writeEventStream(w, resp.Body)
|
c.writeEventStream(w, resp.Body)
|
||||||
} else {
|
} else {
|
||||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
|
@ -55,9 +56,9 @@ type mockWSRespWriter struct {
|
||||||
reader io.Reader
|
reader io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockWSRespWriter(httpRespWriter *mockHTTPRespWriter, reader io.Reader) *mockWSRespWriter {
|
func newMockWSRespWriter(reader io.Reader) *mockWSRespWriter {
|
||||||
return &mockWSRespWriter{
|
return &mockWSRespWriter{
|
||||||
httpRespWriter,
|
newMockHTTPRespWriter(),
|
||||||
make(chan []byte),
|
make(chan []byte),
|
||||||
reader,
|
reader,
|
||||||
}
|
}
|
||||||
|
@ -77,6 +78,27 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
||||||
return w.reader.Read(data)
|
return w.reader.Read(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockSSERespWriter struct {
|
||||||
|
*mockHTTPRespWriter
|
||||||
|
writeNotification chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockSSERespWriter() *mockSSERespWriter {
|
||||||
|
return &mockSSERespWriter{
|
||||||
|
newMockHTTPRespWriter(),
|
||||||
|
make(chan []byte),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockSSERespWriter) Write(data []byte) (int, error) {
|
||||||
|
w.writeNotification <- data
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockSSERespWriter) ReadBytes() []byte {
|
||||||
|
return <-w.writeNotification
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxy(t *testing.T) {
|
func TestProxy(t *testing.T) {
|
||||||
logger, err := logger.New()
|
logger, err := logger.New()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -112,6 +134,7 @@ func TestProxy(t *testing.T) {
|
||||||
client := NewClient(proxyConfig, logger)
|
client := NewClient(proxyConfig, logger)
|
||||||
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
|
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
|
||||||
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
|
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
|
||||||
|
t.Run("testProxySSE", testProxySSE(t, client, originURL))
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,7 +158,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
|
||||||
|
|
||||||
readPipe, writePipe := io.Pipe()
|
readPipe, writePipe := io.Pipe()
|
||||||
respWriter := newMockWSRespWriter(newMockHTTPRespWriter(), readPipe)
|
respWriter := newMockWSRespWriter(readPipe)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -167,3 +190,38 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
pushCount = 50
|
||||||
|
pushFreq = time.Duration(time.Millisecond * 10)
|
||||||
|
)
|
||||||
|
respWriter := newMockSSERespWriter()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?freq=%s", originURL, hello.SSERoute, pushFreq), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err = client.Proxy(respWriter, req, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, respWriter.Code)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < pushCount; i++ {
|
||||||
|
line := respWriter.ReadBytes()
|
||||||
|
expect := fmt.Sprintf("%d\n", i)
|
||||||
|
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
|
||||||
|
|
||||||
|
line = respWriter.ReadBytes()
|
||||||
|
require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line))
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue