From 3b93914612ce97e3b4e36e25cfcf60216ec117df Mon Sep 17 00:00:00 2001 From: cthuang Date: Fri, 15 Jan 2021 17:25:56 +0000 Subject: [PATCH] TUN-3764: Actively flush data for TCP streams --- connection/connection.go | 27 ++++++++++++++ connection/http2.go | 79 +++++++++++++++++++++++----------------- 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/connection/connection.go b/connection/connection.go index bc59b1e2..145beae2 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -1,6 +1,7 @@ package connection import ( + "fmt" "io" "net/http" "strconv" @@ -56,9 +57,35 @@ type Type int const ( TypeWebsocket Type = iota TypeTCP + TypeControlStream TypeHTTP ) +// ShouldFlush returns whether this kind of connection should actively flush data +func (t Type) shouldFlush() bool { + switch t { + case TypeWebsocket, TypeTCP, TypeControlStream: + return true + default: + return false + } +} + +func (t Type) String() string { + switch t { + case TypeWebsocket: + return "websocket" + case TypeTCP: + return "tcp" + case TypeControlStream: + return "control stream" + case TypeHTTP: + return "http" + default: + return fmt.Sprintf("Unknown Type %d", t) + } +} + type OriginProxy interface { Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error } diff --git a/connection/http2.go b/connection/http2.go index c46a8ead..086975a0 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -97,44 +97,25 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.activeRequestsWG.Add(1) defer c.activeRequestsWG.Done() - respWriter := &http2RespWriter{ - r: r.Body, - w: w, - } - flusher, isFlusher := w.(http.Flusher) - if !isFlusher { - c.observer.log.Error().Msgf("%T doesn't implement http.Flusher", w) - respWriter.WriteErrorResponse() + connType := determineHTTP2Type(r) + respWriter, err := newHTTP2RespWriter(r, w, connType) + if err != nil { + c.observer.log.Error().Msg(err.Error()) return } - respWriter.flusher = flusher - switch { - case isControlStreamUpgrade(r): - respWriter.shouldFlush = true - if err := c.serveControlStream(r.Context(), respWriter); err != nil { - respWriter.WriteErrorResponse() - } - return - - case isWebsocketUpgrade(r): - respWriter.shouldFlush = true + var proxyErr error + switch connType { + case TypeControlStream: + proxyErr = c.serveControlStream(r.Context(), respWriter) + case TypeWebsocket: stripWebsocketUpgradeHeader(r) - if err := c.config.OriginProxy.Proxy(respWriter, r, TypeWebsocket); err != nil { - respWriter.WriteErrorResponse() - } - return - - case IsTCPStream(r): - if err := c.config.OriginProxy.Proxy(respWriter, r, TypeTCP); err != nil { - respWriter.WriteErrorResponse() - } - return - + proxyErr = c.config.OriginProxy.Proxy(respWriter, r, TypeWebsocket) default: - if err := c.config.OriginProxy.Proxy(respWriter, r, TypeHTTP); err != nil { - respWriter.WriteErrorResponse() - } + proxyErr = c.config.OriginProxy.Proxy(respWriter, r, connType) + } + if proxyErr != nil { + respWriter.WriteErrorResponse() } } @@ -174,6 +155,25 @@ type http2RespWriter struct { shouldFlush bool } +func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) { + flusher, isFlusher := w.(http.Flusher) + if !isFlusher { + respWriter := &http2RespWriter{ + r: r.Body, + w: w, + } + respWriter.WriteErrorResponse() + return nil, fmt.Errorf("%T doesn't implement http.Flusher", w) + } + + return &http2RespWriter{ + r: r.Body, + w: w, + flusher: flusher, + shouldFlush: connType.shouldFlush(), + }, nil +} + func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error { dest := rp.w.Header() userHeaders := make(http.Header, len(header)) @@ -243,6 +243,19 @@ func (rp *http2RespWriter) Close() error { return nil } +func determineHTTP2Type(r *http.Request) Type { + switch { + case isWebsocketUpgrade(r): + return TypeWebsocket + case IsTCPStream(r): + return TypeTCP + case isControlStreamUpgrade(r): + return TypeControlStream + default: + return TypeHTTP + } +} + func isControlStreamUpgrade(r *http.Request) bool { return r.Header.Get(internalUpgradeHeader) == controlStreamUpgrade }