diff --git a/connection/http2.go b/connection/http2.go index a6756b1d..5fa26bcb 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -115,52 +115,52 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + var requestErr error switch connType { case TypeControlStream: - if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator); err != nil { - c.controlStreamErr = err - c.log.Error().Err(err) - respWriter.WriteErrorResponse() + requestErr = c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator) + if requestErr != nil { + c.controlStreamErr = requestErr } case TypeConfiguration: - if err := c.handleConfigurationUpdate(respWriter, r); err != nil { - c.log.Error().Err(err) - respWriter.WriteErrorResponse() - } + requestErr = c.handleConfigurationUpdate(respWriter, r) case TypeWebsocket, TypeHTTP: stripWebsocketUpgradeHeader(r) // Check for tracing on request tr := tracing.NewTracedHTTPRequest(r, c.log) if err := originProxy.ProxyHTTP(respWriter, tr, connType == TypeWebsocket); err != nil { - err := fmt.Errorf("Failed to proxy HTTP: %w", err) - c.log.Error().Err(err) - respWriter.WriteErrorResponse() + requestErr = fmt.Errorf("Failed to proxy HTTP: %w", err) } case TypeTCP: host, err := getRequestHost(r) if err != nil { - err := fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err) - c.log.Error().Err(err) - respWriter.WriteErrorResponse() + requestErr = fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err) + break } rws := NewHTTPResponseReadWriterAcker(respWriter, r) - if err := originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ + requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ Dest: host, CFRay: FindCfRayHeader(r), LBProbe: IsLBProbeRequest(r), CfTraceID: r.Header.Get(tracing.TracerContextName), - }); err != nil { - respWriter.WriteErrorResponse() - } + }) default: - err := fmt.Errorf("Received unknown connection type: %s", connType) - c.log.Error().Err(err) - respWriter.WriteErrorResponse() + requestErr = fmt.Errorf("Received unknown connection type: %s", connType) + } + + if requestErr != nil { + c.log.Error().Err(requestErr).Msg("failed to serve incoming request") + + // WriteErrorResponse will return false if status was already written. we need to abort handler. + if !respWriter.WriteErrorResponse() { + c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent") + panic(http.ErrAbortHandler) + } } } @@ -275,9 +275,16 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro return nil } -func (rp *http2RespWriter) WriteErrorResponse() { +func (rp *http2RespWriter) WriteErrorResponse() bool { + if rp.statusWritten { + return false + } + rp.setResponseMetaHeader(responseMetaHeaderCfd) rp.w.WriteHeader(http.StatusBadGateway) + rp.statusWritten = true + + return true } func (rp *http2RespWriter) setResponseMetaHeader(value string) { diff --git a/connection/quic.go b/connection/quic.go index 182cb881..4c4a4207 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -193,7 +193,11 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) { // A call to close will simulate a close to the read-side, which will fail subsequent reads. noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream} if err := q.handleStream(ctx, noCloseStream); err != nil { - q.logger.Err(err).Msg("Failed to handle QUIC stream") + q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream") + + // if we received an error at this level, then close write side of stream with an error, which will result in + // RST_STREAM frame. + quicStream.CancelWrite(0) } } @@ -226,43 +230,60 @@ func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs. return err } - if err := q.dispatchRequest(ctx, stream, err, request); err != nil { - _ = stream.WriteConnectResponseData(err) + if err, connectResponseSent := q.dispatchRequest(ctx, stream, err, request); err != nil { q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed") + + // if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is + // closed with an RST_STREAM frame + if connectResponseSent { + return err + } + + if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil { + return writeRespErr + } } return nil } -func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error { +// dispatchRequest will dispatch the request depending on the type and returns an error if it occurs. +// More importantly, it also tells if the during processing of the request the ConnectResponse metadata was sent downstream. +// This is important since it informs +func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) (error, bool) { originProxy, err := q.orchestrator.GetOriginProxy() if err != nil { - return err + return err, false } switch request.Type { case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: tracedReq, err := buildHTTPRequest(ctx, request, stream, q.logger) if err != nil { - return err + return err, false } w := newHTTPResponseAdapter(stream) - return originProxy.ProxyHTTP(w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket) + return originProxy.ProxyHTTP(&w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket), w.connectResponseSent case quicpogs.ConnectionTypeTCP: - rwa := &streamReadWriteAcker{stream} + rwa := &streamReadWriteAcker{RequestServerStream: stream} metadata := request.MetadataMap() return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ Dest: request.Dest, FlowID: metadata[QUICMetadataFlowID], CfTraceID: metadata[tracing.TracerContextName], - }) + }), rwa.connectResponseSent + default: + return errors.Errorf("unsupported error type: %s", request.Type), false } - return nil } func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error { - return rpcStream.Serve(q, q, q.logger) + if err := rpcStream.Serve(q, q, q.logger); err != nil { + q.logger.Err(err).Msg("failed handling RPC stream") + } + + return nil } // RegisterUdpSession is the RPC method invoked by edge to register and run a session @@ -357,6 +378,7 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, // the client. type streamReadWriteAcker struct { *quicpogs.RequestServerStream + connectResponseSent bool } // AckConnection acks response back to the proxy. @@ -365,23 +387,25 @@ func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error { Key: tracing.CanonicalCloudflaredTracingHeader, Val: tracePropagation, } + s.connectResponseSent = true return s.WriteConnectResponseData(nil, metadata) } // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. type httpResponseAdapter struct { *quicpogs.RequestServerStream + connectResponseSent bool } func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter { - return httpResponseAdapter{s} + return httpResponseAdapter{RequestServerStream: s} } -func (hrw httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { +func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { // we do not support trailers over QUIC } -func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { +func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { metadata := make([]quicpogs.Metadata, 0) metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) for k, vv := range header { @@ -390,13 +414,19 @@ func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v}) } } + return hrw.WriteConnectResponseData(nil, metadata...) } -func (hrw httpResponseAdapter) WriteErrorResponse(err error) { +func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) } +func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...quicpogs.Metadata) error { + hrw.connectResponseSent = true + return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...) +} + func buildHTTPRequest( ctx context.Context, connectRequest *quicpogs.ConnectRequest, diff --git a/proxy/proxy.go b/proxy/proxy.go index fa994b94..55c45dc0 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -261,7 +261,9 @@ func (p *Proxy) proxyHTTPRequest( return nil } - _, _ = cfio.Copy(w, resp.Body) + if _, err = cfio.Copy(w, resp.Body); err != nil { + return err + } // copy trailers copyTrailers(w, resp) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 384298c3..58e541d5 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -261,7 +261,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) { defer wg.Done() log := zerolog.Nop() err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false) - require.NoError(t, err) + require.Equal(t, err.Error(), "context canceled") require.Equal(t, http.StatusOK, responseWriter.Code) }()