TUN-3764: Actively flush data for TCP streams

This commit is contained in:
cthuang 2021-01-15 17:25:56 +00:00 committed by Nuno Diegues
parent b4700a52e3
commit 3b93914612
2 changed files with 73 additions and 33 deletions

View File

@ -1,6 +1,7 @@
package connection package connection
import ( import (
"fmt"
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
@ -56,9 +57,35 @@ type Type int
const ( const (
TypeWebsocket Type = iota TypeWebsocket Type = iota
TypeTCP TypeTCP
TypeControlStream
TypeHTTP TypeHTTP
) )
// ShouldFlush returns whether this kind of connection should actively flush data
func (t Type) shouldFlush() bool {
switch t {
case TypeWebsocket, TypeTCP, TypeControlStream:
return true
default:
return false
}
}
func (t Type) String() string {
switch t {
case TypeWebsocket:
return "websocket"
case TypeTCP:
return "tcp"
case TypeControlStream:
return "control stream"
case TypeHTTP:
return "http"
default:
return fmt.Sprintf("Unknown Type %d", t)
}
}
type OriginProxy interface { type OriginProxy interface {
Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error
} }

View File

@ -97,44 +97,25 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.activeRequestsWG.Add(1) c.activeRequestsWG.Add(1)
defer c.activeRequestsWG.Done() defer c.activeRequestsWG.Done()
respWriter := &http2RespWriter{ connType := determineHTTP2Type(r)
r: r.Body, respWriter, err := newHTTP2RespWriter(r, w, connType)
w: w, if err != nil {
} c.observer.log.Error().Msg(err.Error())
flusher, isFlusher := w.(http.Flusher)
if !isFlusher {
c.observer.log.Error().Msgf("%T doesn't implement http.Flusher", w)
respWriter.WriteErrorResponse()
return return
} }
respWriter.flusher = flusher
switch { var proxyErr error
case isControlStreamUpgrade(r): switch connType {
respWriter.shouldFlush = true case TypeControlStream:
if err := c.serveControlStream(r.Context(), respWriter); err != nil { proxyErr = c.serveControlStream(r.Context(), respWriter)
respWriter.WriteErrorResponse() case TypeWebsocket:
}
return
case isWebsocketUpgrade(r):
respWriter.shouldFlush = true
stripWebsocketUpgradeHeader(r) stripWebsocketUpgradeHeader(r)
if err := c.config.OriginProxy.Proxy(respWriter, r, TypeWebsocket); err != nil { proxyErr = c.config.OriginProxy.Proxy(respWriter, r, TypeWebsocket)
respWriter.WriteErrorResponse()
}
return
case IsTCPStream(r):
if err := c.config.OriginProxy.Proxy(respWriter, r, TypeTCP); err != nil {
respWriter.WriteErrorResponse()
}
return
default: default:
if err := c.config.OriginProxy.Proxy(respWriter, r, TypeHTTP); err != nil { proxyErr = c.config.OriginProxy.Proxy(respWriter, r, connType)
respWriter.WriteErrorResponse() }
} if proxyErr != nil {
respWriter.WriteErrorResponse()
} }
} }
@ -174,6 +155,25 @@ type http2RespWriter struct {
shouldFlush bool shouldFlush bool
} }
func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) {
flusher, isFlusher := w.(http.Flusher)
if !isFlusher {
respWriter := &http2RespWriter{
r: r.Body,
w: w,
}
respWriter.WriteErrorResponse()
return nil, fmt.Errorf("%T doesn't implement http.Flusher", w)
}
return &http2RespWriter{
r: r.Body,
w: w,
flusher: flusher,
shouldFlush: connType.shouldFlush(),
}, nil
}
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error { func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
dest := rp.w.Header() dest := rp.w.Header()
userHeaders := make(http.Header, len(header)) userHeaders := make(http.Header, len(header))
@ -243,6 +243,19 @@ func (rp *http2RespWriter) Close() error {
return nil return nil
} }
func determineHTTP2Type(r *http.Request) Type {
switch {
case isWebsocketUpgrade(r):
return TypeWebsocket
case IsTCPStream(r):
return TypeTCP
case isControlStreamUpgrade(r):
return TypeControlStream
default:
return TypeHTTP
}
}
func isControlStreamUpgrade(r *http.Request) bool { func isControlStreamUpgrade(r *http.Request) bool {
return r.Header.Get(internalUpgradeHeader) == controlStreamUpgrade return r.Header.Get(internalUpgradeHeader) == controlStreamUpgrade
} }