diff --git a/connection/connection.go b/connection/connection.go index 6140e3db..e089483b 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -200,6 +200,7 @@ func (h *HTTPResponseReadWriteAcker) AckConnection(tracePropagation string) erro type ResponseWriter interface { WriteRespHeaders(status int, header http.Header) error AddTrailer(trailerName, trailerValue string) + http.ResponseWriter io.Writer } diff --git a/connection/http2.go b/connection/http2.go index b895b7db..98bcbfbc 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -197,6 +197,7 @@ type http2RespWriter struct { flusher http.Flusher shouldFlush bool statusWritten bool + respHeaders http.Header log *zerolog.Logger } @@ -276,6 +277,14 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro return nil } +func (rp *http2RespWriter) Header() http.Header { + return rp.respHeaders +} + +func (rp *http2RespWriter) WriteHeader(status int) { + rp.WriteRespHeaders(status, rp.respHeaders) +} + func (rp *http2RespWriter) WriteErrorResponse() bool { if rp.statusWritten { return false diff --git a/connection/quic.go b/connection/quic.go index 49f77e49..328a42c0 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -402,6 +402,7 @@ func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error { // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. type httpResponseAdapter struct { *quicpogs.RequestServerStream + headers http.Header connectResponseSent bool } @@ -426,6 +427,14 @@ func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) return hrw.WriteConnectResponseData(nil, metadata...) } +func (hrw *httpResponseAdapter) Header() http.Header { + return hrw.headers +} + +func (hrw *httpResponseAdapter) WriteHeader(status int) { + hrw.WriteRespHeaders(status, hrw.headers) +} + func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 8277eeaa..81e351e5 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -700,8 +700,7 @@ func TestConnections(t *testing.T) { cancel() assert.Equal(t, test.want.err, err != nil) assert.Equal(t, test.want.message, replayer.Bytes()) - respPrinter := respWriter.(responsePrinter) - assert.Equal(t, test.want.headers, respPrinter.headers()) + assert.Equal(t, test.want.headers, respWriter.Header()) replayer.rw.Reset() }) } @@ -794,10 +793,6 @@ func (p *pipedRequestBody) Close() error { return nil } -type responsePrinter interface { - headers() http.Header -} - type wsRespWriter struct { w io.Writer responseHeaders http.Header @@ -836,12 +831,16 @@ func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) { } // respHeaders is a test function to read respHeaders -func (w *wsRespWriter) headers() http.Header { +func (w *wsRespWriter) Header() http.Header { // Removing indeterminstic header because it cannot be asserted. w.responseHeaders.Del("Date") return w.responseHeaders } +func (w *wsRespWriter) WriteHeader(status int) { + // unused +} + type mockTCPRespWriter struct { w io.Writer responseHeaders http.Header @@ -862,7 +861,7 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) { return m.w.Write(p) } -func (w *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) { +func (m *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) { // do nothing } @@ -873,10 +872,14 @@ func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) err } // respHeaders is a test function to read respHeaders -func (m *mockTCPRespWriter) headers() http.Header { +func (m *mockTCPRespWriter) Header() http.Header { return m.responseHeaders } +func (m *mockTCPRespWriter) WriteHeader(status int) { + // do nothing +} + func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress { ingressConfig := &config.Configuration{ Ingress: []config.UnvalidatedIngressRule{