TUN-1781: ServeStream should return early on error
This commit is contained in:
parent
137928ecaf
commit
809d2f3f28
127
origin/tunnel.go
127
origin/tunnel.go
|
@ -596,65 +596,93 @@ func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
|
|||
|
||||
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||
h.metrics.incrementRequests(h.connectionID)
|
||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Panic("Unexpected error from http.NewRequest")
|
||||
defer h.metrics.decrementConcurrentRequests(h.connectionID)
|
||||
|
||||
req, reqErr := h.createRequest(stream)
|
||||
if reqErr != nil {
|
||||
h.logError(stream, reqErr)
|
||||
return reqErr
|
||||
}
|
||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("invalid request received")
|
||||
}
|
||||
h.AppendTagHeaders(req)
|
||||
|
||||
cfRay := FindCfRayHeader(req)
|
||||
lbProbe := isLBProbeRequest(req)
|
||||
h.logRequest(req, cfRay, lbProbe)
|
||||
|
||||
var resp *http.Response
|
||||
var respErr error
|
||||
if websocket.IsWebSocketUpgrade(req) {
|
||||
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
||||
if err != nil {
|
||||
h.logError(stream, err)
|
||||
} else {
|
||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
defer conn.Close()
|
||||
// Copy to/from stream to the undelying connection. Use the underlying
|
||||
// connection because cloudflared doesn't operate on the message themselves
|
||||
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||
h.metrics.incrementResponses(h.connectionID, "200")
|
||||
h.logResponse(response, cfRay, lbProbe)
|
||||
}
|
||||
resp, respErr = h.serveWebsocket(stream, req)
|
||||
} else {
|
||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||
if h.noChunkedEncoding {
|
||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||
if err == nil {
|
||||
req.ContentLength = int64(cLength)
|
||||
}
|
||||
}
|
||||
resp, respErr = h.serveHTTP(stream, req)
|
||||
}
|
||||
if respErr != nil {
|
||||
h.logError(stream, respErr)
|
||||
return respErr
|
||||
}
|
||||
h.logResponseOk(resp, cfRay, lbProbe)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||
}
|
||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid request received")
|
||||
}
|
||||
h.AppendTagHeaders(req)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
response, err := h.httpClient.RoundTrip(req)
|
||||
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
err = stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error writing response header")
|
||||
}
|
||||
// Copy to/from stream to the undelying connection. Use the underlying
|
||||
// connection because cloudflared doesn't operate on the message themselves
|
||||
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
h.logError(stream, err)
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
if h.isEventStream(response) {
|
||||
h.writeEventStream(stream, response.Body)
|
||||
} else {
|
||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||
// compression generates dictionary on first write
|
||||
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
|
||||
}
|
||||
|
||||
h.metrics.incrementResponses(h.connectionID, "200")
|
||||
h.logResponse(response, cfRay, lbProbe)
|
||||
func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||
if h.noChunkedEncoding {
|
||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||
if err == nil {
|
||||
req.ContentLength = int64(cLength)
|
||||
}
|
||||
}
|
||||
h.metrics.decrementConcurrentRequests(h.connectionID)
|
||||
return nil
|
||||
|
||||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
response, err := h.httpClient.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error proxying request to origin")
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
err = stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error writing response header")
|
||||
}
|
||||
if h.isEventStream(response) {
|
||||
h.writeEventStream(stream, response.Body)
|
||||
} else {
|
||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||
// compression generates dictionary on first write
|
||||
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
|
||||
|
@ -702,7 +730,8 @@ func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool
|
|||
}
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string, lbProbe bool) {
|
||||
func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
|
||||
h.metrics.incrementResponses(h.connectionID, "200")
|
||||
logger := log.NewEntry(h.logger)
|
||||
if cfRay != "" {
|
||||
logger = logger.WithField("CF-RAY", cfRay)
|
||||
|
|
Loading…
Reference in New Issue