TUN-7558: Flush on Writes for StreamBasedOriginProxy

In the streambased origin proxy flow (example ssh over access), there is
a chance when we do not flush on http.ResponseWriter writes. This PR
guarantees that the response writer passed to proxy stream has a flusher
embedded after writes. This means we write much more often back to the
ResponseWriter and are not waiting. Note, this is only something we do
when proxyHTTP-ing to a StreamBasedOriginProxy because we do not want to
have situations where we are not sending information that is needed by
the other side (eyeball).
This commit is contained in:
Sudarsan Reddy 2023-07-06 14:42:44 +01:00 committed by Jean Khawand
parent 286addc102
commit 4f79a2baba
7 changed files with 30 additions and 9 deletions

View File

@ -157,14 +157,16 @@ type ReadWriteAcker interface {
type HTTPResponseReadWriteAcker struct {
r io.Reader
w ResponseWriter
f http.Flusher
req *http.Request
}
// NewHTTPResponseReadWriterAcker returns a new instance of HTTPResponseReadWriteAcker.
func NewHTTPResponseReadWriterAcker(w ResponseWriter, req *http.Request) *HTTPResponseReadWriteAcker {
func NewHTTPResponseReadWriterAcker(w ResponseWriter, flusher http.Flusher, req *http.Request) *HTTPResponseReadWriteAcker {
return &HTTPResponseReadWriteAcker{
r: req.Body,
w: w,
f: flusher,
req: req,
}
}
@ -174,7 +176,11 @@ func (h *HTTPResponseReadWriteAcker) Read(p []byte) (int, error) {
}
func (h *HTTPResponseReadWriteAcker) Write(p []byte) (int, error) {
return h.w.Write(p)
n, err := h.w.Write(p)
if n > 0 {
h.f.Flush()
}
return n, err
}
// AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to

View File

@ -130,7 +130,8 @@ func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
}
wsCtx, cancel := context.WithCancel(r.Context())
readPipe, writePipe := io.Pipe()
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
go func() {
select {
case <-wsCtx.Done():
@ -175,7 +176,7 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
}
wsCtx, cancel := context.WithCancel(r.Context())
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}

View File

@ -142,7 +142,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
break
}
rws := NewHTTPResponseReadWriterAcker(respWriter, r)
rws := NewHTTPResponseReadWriterAcker(respWriter, respWriter, r)
requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
Dest: host,
CFRay: FindCfRayHeader(r),
@ -289,6 +289,10 @@ func (rp *http2RespWriter) Header() http.Header {
return rp.respHeaders
}
func (rp *http2RespWriter) Flush() {
rp.flusher.Flush()
}
func (rp *http2RespWriter) WriteHeader(status int) {
if rp.hijacked() {
rp.log.Warn().Msg("WriteHeader after hijack")

View File

@ -461,6 +461,9 @@ func (hrw *httpResponseAdapter) Header() http.Header {
return hrw.headers
}
// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here.
func (hrw *httpResponseAdapter) Flush() {}
func (hrw *httpResponseAdapter) WriteHeader(status int) {
hrw.WriteRespHeaders(status, hrw.headers)
}

View File

@ -450,7 +450,7 @@ func proxyTCP(ctx context.Context, originProxy connection.OriginProxy, originAdd
CFRay: "123",
LBProbe: false,
}
rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req)
rws := connection.NewHTTPResponseReadWriterAcker(respWriter, w.(http.Flusher), req)
return originProxy.ProxyTCP(ctx, rws, tcpReq)
}

View File

@ -136,8 +136,11 @@ func (p *Proxy) ProxyHTTP(
if err != nil {
return err
}
rws := connection.NewHTTPResponseReadWriterAcker(w, req)
flusher, ok := w.(http.Flusher)
if !ok {
return fmt.Errorf("response writer is not a flusher")
}
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, "", rule, srv)

View File

@ -698,7 +698,7 @@ func TestConnections(t *testing.T) {
}()
}
if test.args.connectionType == connection.TypeTCP {
rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, req)
rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, respWriter.(http.Flusher), req)
err = proxy.ProxyTCP(ctx, rwa, &connection.TCPRequest{Dest: dest})
} else {
log := zerolog.Nop()
@ -834,6 +834,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
return nil
}
func (w *wsRespWriter) Flush() {}
func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) {
// do nothing
}
@ -873,6 +875,8 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
return m.w.Write(p)
}
func (m *mockTCPRespWriter) Flush() {}
func (m *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) {
// do nothing
}