TUN-1781: ServeStream should return early on error

This commit is contained in:
Chung-Ting Huang 2019-04-25 18:13:06 -05:00
parent 137928ecaf
commit 809d2f3f28
1 changed files with 78 additions and 49 deletions

View File

@ -596,32 +596,63 @@ func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
h.metrics.incrementRequests(h.connectionID)
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
h.logger.WithError(err).Panic("Unexpected error from http.NewRequest")
defer h.metrics.decrementConcurrentRequests(h.connectionID)
req, reqErr := h.createRequest(stream)
if reqErr != nil {
h.logError(stream, reqErr)
return reqErr
}
err = H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
h.logger.WithError(err).Error("invalid request received")
}
h.AppendTagHeaders(req)
cfRay := FindCfRayHeader(req)
lbProbe := isLBProbeRequest(req)
h.logRequest(req, cfRay, lbProbe)
var resp *http.Response
var respErr error
if websocket.IsWebSocketUpgrade(req) {
resp, respErr = h.serveWebsocket(stream, req)
} else {
resp, respErr = h.serveHTTP(stream, req)
}
if respErr != nil {
h.logError(stream, respErr)
return respErr
}
h.logResponseOk(resp, cfRay, lbProbe)
return nil
}
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
}
err = H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
return nil, errors.Wrap(err, "invalid request received")
}
h.AppendTagHeaders(req)
return req, nil
}
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
if err != nil {
h.logError(stream, err)
} else {
stream.WriteHeaders(H1ResponseToH2Response(response))
return nil, err
}
defer conn.Close()
err = stream.WriteHeaders(H1ResponseToH2Response(response))
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), stream)
h.metrics.incrementResponses(h.connectionID, "200")
h.logResponse(response, cfRay, lbProbe)
}
} else {
return response, nil
}
func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if h.noChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
@ -635,12 +666,15 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
req.Header.Set("Connection", "keep-alive")
response, err := h.httpClient.RoundTrip(req)
if err != nil {
h.logError(stream, err)
} else {
return nil, errors.Wrap(err, "Error proxying request to origin")
}
defer response.Body.Close()
stream.WriteHeaders(H1ResponseToH2Response(response))
err = stream.WriteHeaders(H1ResponseToH2Response(response))
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
}
if h.isEventStream(response) {
h.writeEventStream(stream, response.Body)
} else {
@ -648,13 +682,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
// compression generates dictionary on first write
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
}
h.metrics.incrementResponses(h.connectionID, "200")
h.logResponse(response, cfRay, lbProbe)
}
}
h.metrics.decrementConcurrentRequests(h.connectionID)
return nil
return response, nil
}
func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
@ -702,7 +730,8 @@ func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool
}
}
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string, lbProbe bool) {
func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
h.metrics.incrementResponses(h.connectionID, "200")
logger := log.NewEntry(h.logger)
if cfRay != "" {
logger = logger.WithField("CF-RAY", cfRay)