TUN-7073: Fix propagating of bad stream request from origin to downstream

This changes fixes a bug where cloudflared was not propagating errors
when proxying the body of an HTTP request.

In a situation where we already sent HTTP status code, the eyeball would
see the request as sucessfully when in fact it wasn't.

To solve this, we need to guarantee that we produce HTTP RST_STREAM
frames.
This change was applied to both http2 and quic transports.
This commit is contained in:
João Oliveirinha 2023-01-16 12:42:59 +00:00
parent bd917d294c
commit 513855df5c
4 changed files with 78 additions and 39 deletions

View File

@ -115,52 +115,52 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
var requestErr error
switch connType { switch connType {
case TypeControlStream: case TypeControlStream:
if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator); err != nil { requestErr = c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator)
c.controlStreamErr = err if requestErr != nil {
c.log.Error().Err(err) c.controlStreamErr = requestErr
respWriter.WriteErrorResponse()
} }
case TypeConfiguration: case TypeConfiguration:
if err := c.handleConfigurationUpdate(respWriter, r); err != nil { requestErr = c.handleConfigurationUpdate(respWriter, r)
c.log.Error().Err(err)
respWriter.WriteErrorResponse()
}
case TypeWebsocket, TypeHTTP: case TypeWebsocket, TypeHTTP:
stripWebsocketUpgradeHeader(r) stripWebsocketUpgradeHeader(r)
// Check for tracing on request // Check for tracing on request
tr := tracing.NewTracedHTTPRequest(r, c.log) tr := tracing.NewTracedHTTPRequest(r, c.log)
if err := originProxy.ProxyHTTP(respWriter, tr, connType == TypeWebsocket); err != nil { if err := originProxy.ProxyHTTP(respWriter, tr, connType == TypeWebsocket); err != nil {
err := fmt.Errorf("Failed to proxy HTTP: %w", err) requestErr = fmt.Errorf("Failed to proxy HTTP: %w", err)
c.log.Error().Err(err)
respWriter.WriteErrorResponse()
} }
case TypeTCP: case TypeTCP:
host, err := getRequestHost(r) host, err := getRequestHost(r)
if err != nil { if err != nil {
err := fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err) requestErr = fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err)
c.log.Error().Err(err) break
respWriter.WriteErrorResponse()
} }
rws := NewHTTPResponseReadWriterAcker(respWriter, r) rws := NewHTTPResponseReadWriterAcker(respWriter, r)
if err := originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
Dest: host, Dest: host,
CFRay: FindCfRayHeader(r), CFRay: FindCfRayHeader(r),
LBProbe: IsLBProbeRequest(r), LBProbe: IsLBProbeRequest(r),
CfTraceID: r.Header.Get(tracing.TracerContextName), CfTraceID: r.Header.Get(tracing.TracerContextName),
}); err != nil { })
respWriter.WriteErrorResponse()
}
default: default:
err := fmt.Errorf("Received unknown connection type: %s", connType) requestErr = fmt.Errorf("Received unknown connection type: %s", connType)
c.log.Error().Err(err) }
respWriter.WriteErrorResponse()
if requestErr != nil {
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
// WriteErrorResponse will return false if status was already written. we need to abort handler.
if !respWriter.WriteErrorResponse() {
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
panic(http.ErrAbortHandler)
}
} }
} }
@ -275,9 +275,16 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
return nil return nil
} }
func (rp *http2RespWriter) WriteErrorResponse() { func (rp *http2RespWriter) WriteErrorResponse() bool {
if rp.statusWritten {
return false
}
rp.setResponseMetaHeader(responseMetaHeaderCfd) rp.setResponseMetaHeader(responseMetaHeaderCfd)
rp.w.WriteHeader(http.StatusBadGateway) rp.w.WriteHeader(http.StatusBadGateway)
rp.statusWritten = true
return true
} }
func (rp *http2RespWriter) setResponseMetaHeader(value string) { func (rp *http2RespWriter) setResponseMetaHeader(value string) {

View File

@ -193,7 +193,11 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) {
// A call to close will simulate a close to the read-side, which will fail subsequent reads. // A call to close will simulate a close to the read-side, which will fail subsequent reads.
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream} noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
if err := q.handleStream(ctx, noCloseStream); err != nil { if err := q.handleStream(ctx, noCloseStream); err != nil {
q.logger.Err(err).Msg("Failed to handle QUIC stream") q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream")
// if we received an error at this level, then close write side of stream with an error, which will result in
// RST_STREAM frame.
quicStream.CancelWrite(0)
} }
} }
@ -226,43 +230,60 @@ func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs.
return err return err
} }
if err := q.dispatchRequest(ctx, stream, err, request); err != nil { if err, connectResponseSent := q.dispatchRequest(ctx, stream, err, request); err != nil {
_ = stream.WriteConnectResponseData(err)
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed") q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
// if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is
// closed with an RST_STREAM frame
if connectResponseSent {
return err
}
if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
return writeRespErr
}
} }
return nil return nil
} }
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error { // dispatchRequest will dispatch the request depending on the type and returns an error if it occurs.
// More importantly, it also tells if the during processing of the request the ConnectResponse metadata was sent downstream.
// This is important since it informs
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) (error, bool) {
originProxy, err := q.orchestrator.GetOriginProxy() originProxy, err := q.orchestrator.GetOriginProxy()
if err != nil { if err != nil {
return err return err, false
} }
switch request.Type { switch request.Type {
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
tracedReq, err := buildHTTPRequest(ctx, request, stream, q.logger) tracedReq, err := buildHTTPRequest(ctx, request, stream, q.logger)
if err != nil { if err != nil {
return err return err, false
} }
w := newHTTPResponseAdapter(stream) w := newHTTPResponseAdapter(stream)
return originProxy.ProxyHTTP(w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket) return originProxy.ProxyHTTP(&w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket), w.connectResponseSent
case quicpogs.ConnectionTypeTCP: case quicpogs.ConnectionTypeTCP:
rwa := &streamReadWriteAcker{stream} rwa := &streamReadWriteAcker{RequestServerStream: stream}
metadata := request.MetadataMap() metadata := request.MetadataMap()
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{ return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
Dest: request.Dest, Dest: request.Dest,
FlowID: metadata[QUICMetadataFlowID], FlowID: metadata[QUICMetadataFlowID],
CfTraceID: metadata[tracing.TracerContextName], CfTraceID: metadata[tracing.TracerContextName],
}) }), rwa.connectResponseSent
default:
return errors.Errorf("unsupported error type: %s", request.Type), false
} }
return nil
} }
func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error { func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error {
return rpcStream.Serve(q, q, q.logger) if err := rpcStream.Serve(q, q, q.logger); err != nil {
q.logger.Err(err).Msg("failed handling RPC stream")
}
return nil
} }
// RegisterUdpSession is the RPC method invoked by edge to register and run a session // RegisterUdpSession is the RPC method invoked by edge to register and run a session
@ -357,6 +378,7 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32,
// the client. // the client.
type streamReadWriteAcker struct { type streamReadWriteAcker struct {
*quicpogs.RequestServerStream *quicpogs.RequestServerStream
connectResponseSent bool
} }
// AckConnection acks response back to the proxy. // AckConnection acks response back to the proxy.
@ -365,23 +387,25 @@ func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {
Key: tracing.CanonicalCloudflaredTracingHeader, Key: tracing.CanonicalCloudflaredTracingHeader,
Val: tracePropagation, Val: tracePropagation,
} }
s.connectResponseSent = true
return s.WriteConnectResponseData(nil, metadata) return s.WriteConnectResponseData(nil, metadata)
} }
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
type httpResponseAdapter struct { type httpResponseAdapter struct {
*quicpogs.RequestServerStream *quicpogs.RequestServerStream
connectResponseSent bool
} }
func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter { func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter {
return httpResponseAdapter{s} return httpResponseAdapter{RequestServerStream: s}
} }
func (hrw httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
// we do not support trailers over QUIC // we do not support trailers over QUIC
} }
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
metadata := make([]quicpogs.Metadata, 0) metadata := make([]quicpogs.Metadata, 0)
metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)}) metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
for k, vv := range header { for k, vv := range header {
@ -390,13 +414,19 @@ func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header)
metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v}) metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v})
} }
} }
return hrw.WriteConnectResponseData(nil, metadata...) return hrw.WriteConnectResponseData(nil, metadata...)
} }
func (hrw httpResponseAdapter) WriteErrorResponse(err error) { func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
} }
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...quicpogs.Metadata) error {
hrw.connectResponseSent = true
return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...)
}
func buildHTTPRequest( func buildHTTPRequest(
ctx context.Context, ctx context.Context,
connectRequest *quicpogs.ConnectRequest, connectRequest *quicpogs.ConnectRequest,

View File

@ -261,7 +261,9 @@ func (p *Proxy) proxyHTTPRequest(
return nil return nil
} }
_, _ = cfio.Copy(w, resp.Body) if _, err = cfio.Copy(w, resp.Body); err != nil {
return err
}
// copy trailers // copy trailers
copyTrailers(w, resp) copyTrailers(w, resp)

View File

@ -261,7 +261,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
defer wg.Done() defer wg.Done()
log := zerolog.Nop() log := zerolog.Nop()
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false) err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false)
require.NoError(t, err) require.Equal(t, err.Error(), "context canceled")
require.Equal(t, http.StatusOK, responseWriter.Code) require.Equal(t, http.StatusOK, responseWriter.Code)
}() }()