diff --git a/connection/connection.go b/connection/connection.go index 25d796dd..eade8ded 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -157,14 +157,16 @@ type ReadWriteAcker interface { type HTTPResponseReadWriteAcker struct { r io.Reader w ResponseWriter + f http.Flusher req *http.Request } // NewHTTPResponseReadWriterAcker returns a new instance of HTTPResponseReadWriteAcker. -func NewHTTPResponseReadWriterAcker(w ResponseWriter, req *http.Request) *HTTPResponseReadWriteAcker { +func NewHTTPResponseReadWriterAcker(w ResponseWriter, flusher http.Flusher, req *http.Request) *HTTPResponseReadWriteAcker { return &HTTPResponseReadWriteAcker{ r: req.Body, w: w, + f: flusher, req: req, } } @@ -174,7 +176,11 @@ func (h *HTTPResponseReadWriteAcker) Read(p []byte) (int, error) { } func (h *HTTPResponseReadWriteAcker) Write(p []byte) (int, error) { - return h.w.Write(p) + n, err := h.w.Write(p) + if n > 0 { + h.f.Flush() + } + return n, err } // AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to diff --git a/connection/connection_test.go b/connection/connection_test.go index 3d6801e0..ffd483d2 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -130,7 +130,8 @@ func wsEchoEndpoint(w ResponseWriter, r *http.Request) error { } wsCtx, cancel := context.WithCancel(r.Context()) readPipe, writePipe := io.Pipe() - wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) + + wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log) go func() { select { case <-wsCtx.Done(): @@ -175,7 +176,7 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error { } wsCtx, cancel := context.WithCancel(r.Context()) - wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) + wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log) closedAfter := time.Millisecond * time.Duration(rand.Intn(50)) originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)} diff --git a/connection/http2.go b/connection/http2.go index 4945855b..124746cb 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -142,7 +142,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { break } - rws := NewHTTPResponseReadWriterAcker(respWriter, r) + rws := NewHTTPResponseReadWriterAcker(respWriter, respWriter, r) requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ Dest: host, CFRay: FindCfRayHeader(r), @@ -289,6 +289,10 @@ func (rp *http2RespWriter) Header() http.Header { return rp.respHeaders } +func (rp *http2RespWriter) Flush() { + rp.flusher.Flush() +} + func (rp *http2RespWriter) WriteHeader(status int) { if rp.hijacked() { rp.log.Warn().Msg("WriteHeader after hijack") diff --git a/connection/quic.go b/connection/quic.go index df7188e3..ab968594 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -461,6 +461,9 @@ func (hrw *httpResponseAdapter) Header() http.Header { return hrw.headers } +// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here. +func (hrw *httpResponseAdapter) Flush() {} + func (hrw *httpResponseAdapter) WriteHeader(status int) { hrw.WriteRespHeaders(status, hrw.headers) } diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index b07b67bd..ff39b7f4 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -450,7 +450,7 @@ func proxyTCP(ctx context.Context, originProxy connection.OriginProxy, originAdd CFRay: "123", LBProbe: false, } - rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req) + rws := connection.NewHTTPResponseReadWriterAcker(respWriter, w.(http.Flusher), req) return originProxy.ProxyTCP(ctx, rws, tcpReq) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 18e53dc2..2acab46a 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -136,8 +136,11 @@ func (p *Proxy) ProxyHTTP( if err != nil { return err } - - rws := connection.NewHTTPResponseReadWriterAcker(w, req) + flusher, ok := w.(http.Flusher) + if !ok { + return fmt.Errorf("response writer is not a flusher") + } + rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req) if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy); err != nil { rule, srv := ruleField(p.ingressRules, ruleNum) p.logRequestError(err, cfRay, "", rule, srv) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 942bbe30..86132b7c 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -698,7 +698,7 @@ func TestConnections(t *testing.T) { }() } if test.args.connectionType == connection.TypeTCP { - rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, req) + rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, respWriter.(http.Flusher), req) err = proxy.ProxyTCP(ctx, rwa, &connection.TCPRequest{Dest: dest}) } else { log := zerolog.Nop() @@ -834,6 +834,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { return nil } +func (w *wsRespWriter) Flush() {} + func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) { // do nothing } @@ -873,6 +875,8 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) { return m.w.Write(p) } +func (m *mockTCPRespWriter) Flush() {} + func (m *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) { // do nothing }