From 809d2f3f28fa90e732f5063e704e7ae371df4201 Mon Sep 17 00:00:00 2001 From: Chung-Ting Huang Date: Thu, 25 Apr 2019 18:13:06 -0500 Subject: [PATCH] TUN-1781: ServeStream should return early on error --- origin/tunnel.go | 127 +++++++++++++++++++++++++++++------------------ 1 file changed, 78 insertions(+), 49 deletions(-) diff --git a/origin/tunnel.go b/origin/tunnel.go index 960ec42f..f8b66a28 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -596,65 +596,93 @@ func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { h.metrics.incrementRequests(h.connectionID) - req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) - if err != nil { - h.logger.WithError(err).Panic("Unexpected error from http.NewRequest") + defer h.metrics.decrementConcurrentRequests(h.connectionID) + + req, reqErr := h.createRequest(stream) + if reqErr != nil { + h.logError(stream, reqErr) + return reqErr } - err = H2RequestHeadersToH1Request(stream.Headers, req) - if err != nil { - h.logger.WithError(err).Error("invalid request received") - } - h.AppendTagHeaders(req) + cfRay := FindCfRayHeader(req) lbProbe := isLBProbeRequest(req) h.logRequest(req, cfRay, lbProbe) + + var resp *http.Response + var respErr error if websocket.IsWebSocketUpgrade(req) { - conn, response, err := websocket.ClientConnect(req, h.tlsConfig) - if err != nil { - h.logError(stream, err) - } else { - stream.WriteHeaders(H1ResponseToH2Response(response)) - defer conn.Close() - // Copy to/from stream to the undelying connection. Use the underlying - // connection because cloudflared doesn't operate on the message themselves - websocket.Stream(conn.UnderlyingConn(), stream) - h.metrics.incrementResponses(h.connectionID, "200") - h.logResponse(response, cfRay, lbProbe) - } + resp, respErr = h.serveWebsocket(stream, req) } else { - // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate - if h.noChunkedEncoding { - req.TransferEncoding = []string{"gzip", "deflate"} - cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) - if err == nil { - req.ContentLength = int64(cLength) - } - } + resp, respErr = h.serveHTTP(stream, req) + } + if respErr != nil { + h.logError(stream, respErr) + return respErr + } + h.logResponseOk(resp, cfRay, lbProbe) + return nil +} - // Request origin to keep connection alive to improve performance - req.Header.Set("Connection", "keep-alive") +func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) { + req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) + if err != nil { + return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") + } + err = H2RequestHeadersToH1Request(stream.Headers, req) + if err != nil { + return nil, errors.Wrap(err, "invalid request received") + } + h.AppendTagHeaders(req) + return req, nil +} - response, err := h.httpClient.RoundTrip(req) +func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { + conn, response, err := websocket.ClientConnect(req, h.tlsConfig) + if err != nil { + return nil, err + } + defer conn.Close() + err = stream.WriteHeaders(H1ResponseToH2Response(response)) + if err != nil { + return nil, errors.Wrap(err, "Error writing response header") + } + // Copy to/from stream to the undelying connection. Use the underlying + // connection because cloudflared doesn't operate on the message themselves + websocket.Stream(conn.UnderlyingConn(), stream) + return response, nil +} - if err != nil { - h.logError(stream, err) - } else { - defer response.Body.Close() - stream.WriteHeaders(H1ResponseToH2Response(response)) - if h.isEventStream(response) { - h.writeEventStream(stream, response.Body) - } else { - // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream - // compression generates dictionary on first write - io.CopyBuffer(stream, response.Body, make([]byte, 512*1024)) - } - - h.metrics.incrementResponses(h.connectionID, "200") - h.logResponse(response, cfRay, lbProbe) +func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { + // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate + if h.noChunkedEncoding { + req.TransferEncoding = []string{"gzip", "deflate"} + cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) + if err == nil { + req.ContentLength = int64(cLength) } } - h.metrics.decrementConcurrentRequests(h.connectionID) - return nil + + // Request origin to keep connection alive to improve performance + req.Header.Set("Connection", "keep-alive") + + response, err := h.httpClient.RoundTrip(req) + if err != nil { + return nil, errors.Wrap(err, "Error proxying request to origin") + } + defer response.Body.Close() + + err = stream.WriteHeaders(H1ResponseToH2Response(response)) + if err != nil { + return nil, errors.Wrap(err, "Error writing response header") + } + if h.isEventStream(response) { + h.writeEventStream(stream, response.Body) + } else { + // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream + // compression generates dictionary on first write + io.CopyBuffer(stream, response.Body, make([]byte, 512*1024)) + } + return response, nil } func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) { @@ -702,7 +730,8 @@ func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool } } -func (h *TunnelHandler) logResponse(r *http.Response, cfRay string, lbProbe bool) { +func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { + h.metrics.incrementResponses(h.connectionID, "200") logger := log.NewEntry(h.logger) if cfRay != "" { logger = logger.WithField("CF-RAY", cfRay)