TUN-7073: Fix propagating of bad stream request from origin to downstream
This changes fixes a bug where cloudflared was not propagating errors when proxying the body of an HTTP request. In a situation where we already sent HTTP status code, the eyeball would see the request as sucessfully when in fact it wasn't. To solve this, we need to guarantee that we produce HTTP RST_STREAM frames. This change was applied to both http2 and quic transports.
This commit is contained in:
parent
bd917d294c
commit
513855df5c
|
@ -115,52 +115,52 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var requestErr error
|
||||||
switch connType {
|
switch connType {
|
||||||
case TypeControlStream:
|
case TypeControlStream:
|
||||||
if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator); err != nil {
|
requestErr = c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator)
|
||||||
c.controlStreamErr = err
|
if requestErr != nil {
|
||||||
c.log.Error().Err(err)
|
c.controlStreamErr = requestErr
|
||||||
respWriter.WriteErrorResponse()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TypeConfiguration:
|
case TypeConfiguration:
|
||||||
if err := c.handleConfigurationUpdate(respWriter, r); err != nil {
|
requestErr = c.handleConfigurationUpdate(respWriter, r)
|
||||||
c.log.Error().Err(err)
|
|
||||||
respWriter.WriteErrorResponse()
|
|
||||||
}
|
|
||||||
|
|
||||||
case TypeWebsocket, TypeHTTP:
|
case TypeWebsocket, TypeHTTP:
|
||||||
stripWebsocketUpgradeHeader(r)
|
stripWebsocketUpgradeHeader(r)
|
||||||
// Check for tracing on request
|
// Check for tracing on request
|
||||||
tr := tracing.NewTracedHTTPRequest(r, c.log)
|
tr := tracing.NewTracedHTTPRequest(r, c.log)
|
||||||
if err := originProxy.ProxyHTTP(respWriter, tr, connType == TypeWebsocket); err != nil {
|
if err := originProxy.ProxyHTTP(respWriter, tr, connType == TypeWebsocket); err != nil {
|
||||||
err := fmt.Errorf("Failed to proxy HTTP: %w", err)
|
requestErr = fmt.Errorf("Failed to proxy HTTP: %w", err)
|
||||||
c.log.Error().Err(err)
|
|
||||||
respWriter.WriteErrorResponse()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TypeTCP:
|
case TypeTCP:
|
||||||
host, err := getRequestHost(r)
|
host, err := getRequestHost(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err)
|
requestErr = fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err)
|
||||||
c.log.Error().Err(err)
|
break
|
||||||
respWriter.WriteErrorResponse()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rws := NewHTTPResponseReadWriterAcker(respWriter, r)
|
rws := NewHTTPResponseReadWriterAcker(respWriter, r)
|
||||||
if err := originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
|
requestErr = originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
|
||||||
Dest: host,
|
Dest: host,
|
||||||
CFRay: FindCfRayHeader(r),
|
CFRay: FindCfRayHeader(r),
|
||||||
LBProbe: IsLBProbeRequest(r),
|
LBProbe: IsLBProbeRequest(r),
|
||||||
CfTraceID: r.Header.Get(tracing.TracerContextName),
|
CfTraceID: r.Header.Get(tracing.TracerContextName),
|
||||||
}); err != nil {
|
})
|
||||||
respWriter.WriteErrorResponse()
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
err := fmt.Errorf("Received unknown connection type: %s", connType)
|
requestErr = fmt.Errorf("Received unknown connection type: %s", connType)
|
||||||
c.log.Error().Err(err)
|
}
|
||||||
respWriter.WriteErrorResponse()
|
|
||||||
|
if requestErr != nil {
|
||||||
|
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
|
||||||
|
|
||||||
|
// WriteErrorResponse will return false if status was already written. we need to abort handler.
|
||||||
|
if !respWriter.WriteErrorResponse() {
|
||||||
|
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
|
||||||
|
panic(http.ErrAbortHandler)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -275,9 +275,16 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) WriteErrorResponse() {
|
func (rp *http2RespWriter) WriteErrorResponse() bool {
|
||||||
|
if rp.statusWritten {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
||||||
rp.w.WriteHeader(http.StatusBadGateway)
|
rp.w.WriteHeader(http.StatusBadGateway)
|
||||||
|
rp.statusWritten = true
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) setResponseMetaHeader(value string) {
|
func (rp *http2RespWriter) setResponseMetaHeader(value string) {
|
||||||
|
|
|
@ -193,7 +193,11 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) {
|
||||||
// A call to close will simulate a close to the read-side, which will fail subsequent reads.
|
// A call to close will simulate a close to the read-side, which will fail subsequent reads.
|
||||||
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
|
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
|
||||||
if err := q.handleStream(ctx, noCloseStream); err != nil {
|
if err := q.handleStream(ctx, noCloseStream); err != nil {
|
||||||
q.logger.Err(err).Msg("Failed to handle QUIC stream")
|
q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream")
|
||||||
|
|
||||||
|
// if we received an error at this level, then close write side of stream with an error, which will result in
|
||||||
|
// RST_STREAM frame.
|
||||||
|
quicStream.CancelWrite(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,43 +230,60 @@ func (q *QUICConnection) handleDataStream(ctx context.Context, stream *quicpogs.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := q.dispatchRequest(ctx, stream, err, request); err != nil {
|
if err, connectResponseSent := q.dispatchRequest(ctx, stream, err, request); err != nil {
|
||||||
_ = stream.WriteConnectResponseData(err)
|
|
||||||
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
|
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
|
||||||
|
|
||||||
|
// if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is
|
||||||
|
// closed with an RST_STREAM frame
|
||||||
|
if connectResponseSent {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
|
||||||
|
return writeRespErr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) error {
|
// dispatchRequest will dispatch the request depending on the type and returns an error if it occurs.
|
||||||
|
// More importantly, it also tells if the during processing of the request the ConnectResponse metadata was sent downstream.
|
||||||
|
// This is important since it informs
|
||||||
|
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *quicpogs.RequestServerStream, err error, request *quicpogs.ConnectRequest) (error, bool) {
|
||||||
originProxy, err := q.orchestrator.GetOriginProxy()
|
originProxy, err := q.orchestrator.GetOriginProxy()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err, false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch request.Type {
|
switch request.Type {
|
||||||
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
|
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
|
||||||
tracedReq, err := buildHTTPRequest(ctx, request, stream, q.logger)
|
tracedReq, err := buildHTTPRequest(ctx, request, stream, q.logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err, false
|
||||||
}
|
}
|
||||||
w := newHTTPResponseAdapter(stream)
|
w := newHTTPResponseAdapter(stream)
|
||||||
return originProxy.ProxyHTTP(w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket)
|
return originProxy.ProxyHTTP(&w, tracedReq, request.Type == quicpogs.ConnectionTypeWebsocket), w.connectResponseSent
|
||||||
|
|
||||||
case quicpogs.ConnectionTypeTCP:
|
case quicpogs.ConnectionTypeTCP:
|
||||||
rwa := &streamReadWriteAcker{stream}
|
rwa := &streamReadWriteAcker{RequestServerStream: stream}
|
||||||
metadata := request.MetadataMap()
|
metadata := request.MetadataMap()
|
||||||
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
|
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
|
||||||
Dest: request.Dest,
|
Dest: request.Dest,
|
||||||
FlowID: metadata[QUICMetadataFlowID],
|
FlowID: metadata[QUICMetadataFlowID],
|
||||||
CfTraceID: metadata[tracing.TracerContextName],
|
CfTraceID: metadata[tracing.TracerContextName],
|
||||||
})
|
}), rwa.connectResponseSent
|
||||||
|
default:
|
||||||
|
return errors.Errorf("unsupported error type: %s", request.Type), false
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error {
|
func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error {
|
||||||
return rpcStream.Serve(q, q, q.logger)
|
if err := rpcStream.Serve(q, q, q.logger); err != nil {
|
||||||
|
q.logger.Err(err).Msg("failed handling RPC stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterUdpSession is the RPC method invoked by edge to register and run a session
|
// RegisterUdpSession is the RPC method invoked by edge to register and run a session
|
||||||
|
@ -357,6 +378,7 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32,
|
||||||
// the client.
|
// the client.
|
||||||
type streamReadWriteAcker struct {
|
type streamReadWriteAcker struct {
|
||||||
*quicpogs.RequestServerStream
|
*quicpogs.RequestServerStream
|
||||||
|
connectResponseSent bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// AckConnection acks response back to the proxy.
|
// AckConnection acks response back to the proxy.
|
||||||
|
@ -365,23 +387,25 @@ func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {
|
||||||
Key: tracing.CanonicalCloudflaredTracingHeader,
|
Key: tracing.CanonicalCloudflaredTracingHeader,
|
||||||
Val: tracePropagation,
|
Val: tracePropagation,
|
||||||
}
|
}
|
||||||
|
s.connectResponseSent = true
|
||||||
return s.WriteConnectResponseData(nil, metadata)
|
return s.WriteConnectResponseData(nil, metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
|
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
|
||||||
type httpResponseAdapter struct {
|
type httpResponseAdapter struct {
|
||||||
*quicpogs.RequestServerStream
|
*quicpogs.RequestServerStream
|
||||||
|
connectResponseSent bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter {
|
func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter {
|
||||||
return httpResponseAdapter{s}
|
return httpResponseAdapter{RequestServerStream: s}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hrw httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
|
func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
|
||||||
// we do not support trailers over QUIC
|
// we do not support trailers over QUIC
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
|
func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
metadata := make([]quicpogs.Metadata, 0)
|
metadata := make([]quicpogs.Metadata, 0)
|
||||||
metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
|
metadata = append(metadata, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
|
||||||
for k, vv := range header {
|
for k, vv := range header {
|
||||||
|
@ -390,13 +414,19 @@ func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header)
|
||||||
metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v})
|
metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return hrw.WriteConnectResponseData(nil, metadata...)
|
return hrw.WriteConnectResponseData(nil, metadata...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hrw httpResponseAdapter) WriteErrorResponse(err error) {
|
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
|
||||||
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...quicpogs.Metadata) error {
|
||||||
|
hrw.connectResponseSent = true
|
||||||
|
return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...)
|
||||||
|
}
|
||||||
|
|
||||||
func buildHTTPRequest(
|
func buildHTTPRequest(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
connectRequest *quicpogs.ConnectRequest,
|
connectRequest *quicpogs.ConnectRequest,
|
||||||
|
|
|
@ -261,7 +261,9 @@ func (p *Proxy) proxyHTTPRequest(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = cfio.Copy(w, resp.Body)
|
if _, err = cfio.Copy(w, resp.Body); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// copy trailers
|
// copy trailers
|
||||||
copyTrailers(w, resp)
|
copyTrailers(w, resp)
|
||||||
|
|
|
@ -261,7 +261,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false)
|
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, &log), false)
|
||||||
require.NoError(t, err)
|
require.Equal(t, err.Error(), "context canceled")
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, responseWriter.Code)
|
require.Equal(t, http.StatusOK, responseWriter.Code)
|
||||||
}()
|
}()
|
||||||
|
|
Loading…
Reference in New Issue