TUN-1781: ServeStream should return early on error

This commit is contained in:
Chung-Ting Huang 2019-04-25 18:13:06 -05:00
parent 137928ecaf
commit 809d2f3f28
1 changed files with 78 additions and 49 deletions

View File

@ -596,65 +596,93 @@ func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
h.metrics.incrementRequests(h.connectionID) h.metrics.incrementRequests(h.connectionID)
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) defer h.metrics.decrementConcurrentRequests(h.connectionID)
if err != nil {
h.logger.WithError(err).Panic("Unexpected error from http.NewRequest") req, reqErr := h.createRequest(stream)
if reqErr != nil {
h.logError(stream, reqErr)
return reqErr
} }
err = H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
h.logger.WithError(err).Error("invalid request received")
}
h.AppendTagHeaders(req)
cfRay := FindCfRayHeader(req) cfRay := FindCfRayHeader(req)
lbProbe := isLBProbeRequest(req) lbProbe := isLBProbeRequest(req)
h.logRequest(req, cfRay, lbProbe) h.logRequest(req, cfRay, lbProbe)
var resp *http.Response
var respErr error
if websocket.IsWebSocketUpgrade(req) { if websocket.IsWebSocketUpgrade(req) {
conn, response, err := websocket.ClientConnect(req, h.tlsConfig) resp, respErr = h.serveWebsocket(stream, req)
if err != nil {
h.logError(stream, err)
} else {
stream.WriteHeaders(H1ResponseToH2Response(response))
defer conn.Close()
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), stream)
h.metrics.incrementResponses(h.connectionID, "200")
h.logResponse(response, cfRay, lbProbe)
}
} else { } else {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate resp, respErr = h.serveHTTP(stream, req)
if h.noChunkedEncoding { }
req.TransferEncoding = []string{"gzip", "deflate"} if respErr != nil {
cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) h.logError(stream, respErr)
if err == nil { return respErr
req.ContentLength = int64(cLength) }
} h.logResponseOk(resp, cfRay, lbProbe)
} return nil
}
// Request origin to keep connection alive to improve performance func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
req.Header.Set("Connection", "keep-alive") req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
}
err = H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
return nil, errors.Wrap(err, "invalid request received")
}
h.AppendTagHeaders(req)
return req, nil
}
response, err := h.httpClient.RoundTrip(req) func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
if err != nil {
return nil, err
}
defer conn.Close()
err = stream.WriteHeaders(H1ResponseToH2Response(response))
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), stream)
return response, nil
}
if err != nil { func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
h.logError(stream, err) // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
} else { if h.noChunkedEncoding {
defer response.Body.Close() req.TransferEncoding = []string{"gzip", "deflate"}
stream.WriteHeaders(H1ResponseToH2Response(response)) cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if h.isEventStream(response) { if err == nil {
h.writeEventStream(stream, response.Body) req.ContentLength = int64(cLength)
} else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
}
h.metrics.incrementResponses(h.connectionID, "200")
h.logResponse(response, cfRay, lbProbe)
} }
} }
h.metrics.decrementConcurrentRequests(h.connectionID)
return nil // Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
response, err := h.httpClient.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin")
}
defer response.Body.Close()
err = stream.WriteHeaders(H1ResponseToH2Response(response))
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
if h.isEventStream(response) {
h.writeEventStream(stream, response.Body)
} else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
}
return response, nil
} }
func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) { func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
@ -702,7 +730,8 @@ func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool
} }
} }
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string, lbProbe bool) { func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
h.metrics.incrementResponses(h.connectionID, "200")
logger := log.NewEntry(h.logger) logger := log.NewEntry(h.logger)
if cfRay != "" { if cfRay != "" {
logger = logger.WithField("CF-RAY", cfRay) logger = logger.WithField("CF-RAY", cfRay)