TUN-7558: Flush on Writes for StreamBasedOriginProxy

In the streambased origin proxy flow (example ssh over access), there is
a chance when we do not flush on http.ResponseWriter writes. This PR
guarantees that the response writer passed to proxy stream has a flusher
embedded after writes. This means we write much more often back to the
ResponseWriter and are not waiting. Note, this is only something we do
when proxyHTTP-ing to a StreamBasedOriginProxy because we do not want to
have situations where we are not sending information that is needed by
the other side (eyeball).
This commit is contained in:
Sudarsan Reddy 2023-07-06 14:42:44 +01:00
parent d1e338ee48
commit 39847a70f2
7 changed files with 30 additions and 9 deletions

View File

@ -157,14 +157,16 @@ type ReadWriteAcker interface {
type HTTPResponseReadWriteAcker struct { type HTTPResponseReadWriteAcker struct {
r io.Reader r io.Reader
w ResponseWriter w ResponseWriter
f http.Flusher
req *http.Request req *http.Request
} }
// NewHTTPResponseReadWriterAcker returns a new instance of HTTPResponseReadWriteAcker. // NewHTTPResponseReadWriterAcker returns a new instance of HTTPResponseReadWriteAcker.
func NewHTTPResponseReadWriterAcker(w ResponseWriter, req *http.Request) *HTTPResponseReadWriteAcker { func NewHTTPResponseReadWriterAcker(w ResponseWriter, flusher http.Flusher, req *http.Request) *HTTPResponseReadWriteAcker {
return &HTTPResponseReadWriteAcker{ return &HTTPResponseReadWriteAcker{
r: req.Body, r: req.Body,
w: w, w: w,
f: flusher,
req: req, req: req,
} }
} }
@ -174,7 +176,11 @@ func (h *HTTPResponseReadWriteAcker) Read(p []byte) (int, error) {
} }
func (h *HTTPResponseReadWriteAcker) Write(p []byte) (int, error) { func (h *HTTPResponseReadWriteAcker) Write(p []byte) (int, error) {
return h.w.Write(p) n, err := h.w.Write(p)
if n > 0 {
h.f.Flush()
}
return n, err
} }
// AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to // AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to

View File

@ -130,7 +130,8 @@ func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
} }
wsCtx, cancel := context.WithCancel(r.Context()) wsCtx, cancel := context.WithCancel(r.Context())
readPipe, writePipe := io.Pipe() readPipe, writePipe := io.Pipe()
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
go func() { go func() {
select { select {
case <-wsCtx.Done(): case <-wsCtx.Done():
@ -175,7 +176,7 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
} }
wsCtx, cancel := context.WithCancel(r.Context()) wsCtx, cancel := context.WithCancel(r.Context())
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
closedAfter := time.Millisecond * time.Duration(rand.Intn(50)) closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)} originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}

View File

@ -142,7 +142,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
break break
} }
rws := NewHTTPResponseReadWriterAcker(respWriter, r) rws := NewHTTPResponseReadWriterAcker(respWriter, respWriter, r)
requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
Dest: host, Dest: host,
CFRay: FindCfRayHeader(r), CFRay: FindCfRayHeader(r),
@ -289,6 +289,10 @@ func (rp *http2RespWriter) Header() http.Header {
return rp.respHeaders return rp.respHeaders
} }
func (rp *http2RespWriter) Flush() {
rp.flusher.Flush()
}
func (rp *http2RespWriter) WriteHeader(status int) { func (rp *http2RespWriter) WriteHeader(status int) {
if rp.hijacked() { if rp.hijacked() {
rp.log.Warn().Msg("WriteHeader after hijack") rp.log.Warn().Msg("WriteHeader after hijack")

View File

@ -461,6 +461,9 @@ func (hrw *httpResponseAdapter) Header() http.Header {
return hrw.headers return hrw.headers
} }
// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here.
func (hrw *httpResponseAdapter) Flush() {}
func (hrw *httpResponseAdapter) WriteHeader(status int) { func (hrw *httpResponseAdapter) WriteHeader(status int) {
hrw.WriteRespHeaders(status, hrw.headers) hrw.WriteRespHeaders(status, hrw.headers)
} }

View File

@ -450,7 +450,7 @@ func proxyTCP(ctx context.Context, originProxy connection.OriginProxy, originAdd
CFRay: "123", CFRay: "123",
LBProbe: false, LBProbe: false,
} }
rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req) rws := connection.NewHTTPResponseReadWriterAcker(respWriter, w.(http.Flusher), req)
return originProxy.ProxyTCP(ctx, rws, tcpReq) return originProxy.ProxyTCP(ctx, rws, tcpReq)
} }

View File

@ -136,8 +136,11 @@ func (p *Proxy) ProxyHTTP(
if err != nil { if err != nil {
return err return err
} }
flusher, ok := w.(http.Flusher)
rws := connection.NewHTTPResponseReadWriterAcker(w, req) if !ok {
return fmt.Errorf("response writer is not a flusher")
}
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy); err != nil { if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum) rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, "", rule, srv) p.logRequestError(err, cfRay, "", rule, srv)

View File

@ -698,7 +698,7 @@ func TestConnections(t *testing.T) {
}() }()
} }
if test.args.connectionType == connection.TypeTCP { if test.args.connectionType == connection.TypeTCP {
rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, req) rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, respWriter.(http.Flusher), req)
err = proxy.ProxyTCP(ctx, rwa, &connection.TCPRequest{Dest: dest}) err = proxy.ProxyTCP(ctx, rwa, &connection.TCPRequest{Dest: dest})
} else { } else {
log := zerolog.Nop() log := zerolog.Nop()
@ -834,6 +834,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
return nil return nil
} }
func (w *wsRespWriter) Flush() {}
func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) { func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) {
// do nothing // do nothing
} }
@ -873,6 +875,8 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
return m.w.Write(p) return m.w.Write(p)
} }
func (m *mockTCPRespWriter) Flush() {}
func (m *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) { func (m *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) {
// do nothing // do nothing
} }