From 513855df5cd9b2c1aa234c75f7248c969834e068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Oliveirinha?= Date: Mon, 16 Jan 2023 12:42:59 +0000 Subject: [PATCH] TUN-7073: Fix propagating of bad stream request from origin to downstream This changes fixes a bug where cloudflared was not propagating errors when proxying the body of an HTTP request. In a situation where we already sent HTTP status code, the eyeball would see the request as sucessfully when in fact it wasn't. To solve this, we need to guarantee that we produce HTTP RST_STREAM frames. This change was applied to both http2 and quic transports. --- connection/http2.go | 51 +++++++++++++++++++++----------------- connection/quic.go | 60 +++++++++++++++++++++++++++++++++------------ proxy/proxy.go | 4 ++- proxy/proxy_test.go | 2 +- 4 files changed, 78 insertions(+), 39 deletions(-) diff --git a/connection/http2.go b/connection/http2.go index a6756b1d..5fa26bcb 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -115,52 +115,52 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + var requestErr error switch connType { case TypeControlStream: - if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator); err != nil { - c.controlStreamErr = err - c.log.Error().Err(err) - respWriter.WriteErrorResponse() + requestErr = c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator) + if requestErr != nil { + c.controlStreamErr = requestErr } case TypeConfiguration: - if err := c.handleConfigurationUpdate(respWriter, r); err != nil { - c.log.Error().Err(err) - respWriter.WriteErrorResponse() - } + requestErr = c.handleConfigurationUpdate(respWriter, r) case TypeWebsocket, TypeHTTP: stripWebsocketUpgradeHeader(r) // Check for tracing on request tr := tracing.NewTracedHTTPRequest(r, c.log) if err := originProxy.ProxyHTTP(respWriter, tr, connType == TypeWebsocket); err != nil { - err := fmt.Errorf("Failed to proxy HTTP: %w", err) - c.log.Error().Err(err) - respWriter.WriteErrorResponse() + requestErr = fmt.Errorf("Failed to proxy HTTP: %w", err) } case TypeTCP: host, err := getRequestHost(r) if err != nil { - err := fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err) - c.log.Error().Err(err) - respWriter.WriteErrorResponse() + requestErr = fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err) + break } rws := NewHTTPResponseReadWriterAcker(respWriter, r) - if err := originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ + requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ Dest: host, CFRay: FindCfRayHeader(r), LBProbe: IsLBProbeRequest(r), CfTraceID: r.Header.Get(tracing.TracerContextName), - }); err != nil { - respWriter.WriteErrorResponse() - } + }) default: - err := fmt.Errorf("Received unknown connection type: %s", connType) - c.log.Error().Err(err) - respWriter.WriteErrorResponse() + requestErr = fmt.Errorf("Received unknown connection type: %s", connType) + } + + if requestErr != nil { + c.log.Error().Err(requestErr).Msg("failed to serve incoming request") + + // WriteErrorResponse will return false if status was already written. we need to abort handler. + if !respWriter.WriteErrorResponse() { + c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent") + panic(http.ErrAbortHandler) + } } } @@ -275,9 +275,16 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro return nil } -func (rp *http2RespWriter) WriteErrorResponse() { +func (rp *http2RespWriter) WriteErrorResponse() bool { + if rp.statusWritten { + return false + } + rp.setResponseMetaHeader(responseMetaHeaderCfd) rp.w.WriteHeader(http.StatusBadGateway) + rp.statusWritten = true + + return true } func (rp *http2RespWriter) setResponseMetaHeader(value string) { diff --git a/connection/quic.go b/connection/quic.go index 182cb881..4c4a4207 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -193,7 +193,11 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) { // A call to close will simulate a close to the read-side, which will fail subsequent reads. noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream} if err := q.handleStream(ctx, noCloseStream); err != nil { - q.logger.Err(err).Msg("Failed to handle QUIC stream") + q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream") + + // if we received an error at this level, then close write side of stream with an error, which will result in + // RST_STREAM frame. + quicStream.CancelWrite(0) } } @@ -226,43 +230,60 @@ func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs. return err } - if err := q.dispatchRequest(ctx, stream, err, request); err != nil { - _ = stream.WriteConnectResponseData(err) + if err, connectResponseSent := q.dispatchRequest(ctx, stream, err, request); err != nil { q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed") + + // if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is + // closed with an RST_STREAM frame + if connectResponseSent { + return err + } + + if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil { + return writeRespErr + } } return nil } -func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error { +// dispatchRequest will dispatch the request depending on the type and returns an error if it occurs. +// More importantly, it also tells if the during processing of the request the ConnectResponse metadata was sent downstream. +// This is important since it informs +func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) (error, bool) { originProxy, err := q.orchestrator.GetOriginProxy() if err != nil { - return err + return err, false } switch request.Type { case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: tracedReq, err := buildHTTPRequest(ctx, request, stream, q.logger) if err != nil { - return err + return err, false } w := newHTTPResponseAdapter(stream) - return originProxy.ProxyHTTP(w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket) + return originProxy.ProxyHTTP(&w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket), w.connectResponseSent case quicpogs.ConnectionTypeTCP: - rwa := &streamReadWriteAcker{stream} + rwa := &streamReadWriteAcker{RequestServerStream: stream} metadata := request.MetadataMap() return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ Dest: request.Dest, FlowID: metadata[QUICMetadataFlowID], CfTraceID: metadata[tracing.TracerContextName], - }) + }), rwa.connectResponseSent + default: + return errors.Errorf("unsupported error type: %s", request.Type), false } - return nil } func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error { - return rpcStream.Serve(q, q, q.logger) + if err := rpcStream.Serve(q, q, q.logger); err != nil { + q.logger.Err(err).Msg("failed handling RPC stream") + } + + return nil } // RegisterUdpSession is the RPC method invoked by edge to register and run a session @@ -357,6 +378,7 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, // the client. type streamReadWriteAcker struct { *quicpogs.RequestServerStream + connectResponseSent bool } // AckConnection acks response back to the proxy. @@ -365,23 +387,25 @@ func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error { Key: tracing.CanonicalCloudflaredTracingHeader, Val: tracePropagation, } + s.connectResponseSent = true return s.WriteConnectResponseData(nil, metadata) } // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. type httpResponseAdapter struct { *quicpogs.RequestServerStream + connectResponseSent bool } func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter { - return httpResponseAdapter{s} + return httpResponseAdapter{RequestServerStream: s} } -func (hrw httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { +func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { // we do not support trailers over QUIC } -func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { +func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { metadata := make([]quicpogs.Metadata, 0) metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) for k, vv := range header { @@ -390,13 +414,19 @@ func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v}) } } + return hrw.WriteConnectResponseData(nil, metadata...) } -func (hrw httpResponseAdapter) WriteErrorResponse(err error) { +func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) } +func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...quicpogs.Metadata) error { + hrw.connectResponseSent = true + return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...) +} + func buildHTTPRequest( ctx context.Context, connectRequest *quicpogs.ConnectRequest, diff --git a/proxy/proxy.go b/proxy/proxy.go index fa994b94..55c45dc0 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -261,7 +261,9 @@ func (p *Proxy) proxyHTTPRequest( return nil } - _, _ = cfio.Copy(w, resp.Body) + if _, err = cfio.Copy(w, resp.Body); err != nil { + return err + } // copy trailers copyTrailers(w, resp) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 384298c3..58e541d5 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -261,7 +261,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) { defer wg.Done() log := zerolog.Nop() err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false) - require.NoError(t, err) + require.Equal(t, err.Error(), "context canceled") require.Equal(t, http.StatusOK, responseWriter.Code) }()