TUN-1781: ServeStream should return early on error
This commit is contained in:
parent
137928ecaf
commit
809d2f3f28
127
origin/tunnel.go
127
origin/tunnel.go
|
@ -596,65 +596,93 @@ func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
|
||||||
|
|
||||||
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
h.metrics.incrementRequests(h.connectionID)
|
h.metrics.incrementRequests(h.connectionID)
|
||||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
defer h.metrics.decrementConcurrentRequests(h.connectionID)
|
||||||
if err != nil {
|
|
||||||
h.logger.WithError(err).Panic("Unexpected error from http.NewRequest")
|
req, reqErr := h.createRequest(stream)
|
||||||
|
if reqErr != nil {
|
||||||
|
h.logError(stream, reqErr)
|
||||||
|
return reqErr
|
||||||
}
|
}
|
||||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.WithError(err).Error("invalid request received")
|
|
||||||
}
|
|
||||||
h.AppendTagHeaders(req)
|
|
||||||
cfRay := FindCfRayHeader(req)
|
cfRay := FindCfRayHeader(req)
|
||||||
lbProbe := isLBProbeRequest(req)
|
lbProbe := isLBProbeRequest(req)
|
||||||
h.logRequest(req, cfRay, lbProbe)
|
h.logRequest(req, cfRay, lbProbe)
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
|
var respErr error
|
||||||
if websocket.IsWebSocketUpgrade(req) {
|
if websocket.IsWebSocketUpgrade(req) {
|
||||||
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
resp, respErr = h.serveWebsocket(stream, req)
|
||||||
if err != nil {
|
|
||||||
h.logError(stream, err)
|
|
||||||
} else {
|
|
||||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
|
||||||
defer conn.Close()
|
|
||||||
// Copy to/from stream to the undelying connection. Use the underlying
|
|
||||||
// connection because cloudflared doesn't operate on the message themselves
|
|
||||||
websocket.Stream(conn.UnderlyingConn(), stream)
|
|
||||||
h.metrics.incrementResponses(h.connectionID, "200")
|
|
||||||
h.logResponse(response, cfRay, lbProbe)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
resp, respErr = h.serveHTTP(stream, req)
|
||||||
if h.noChunkedEncoding {
|
}
|
||||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
if respErr != nil {
|
||||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
h.logError(stream, respErr)
|
||||||
if err == nil {
|
return respErr
|
||||||
req.ContentLength = int64(cLength)
|
}
|
||||||
}
|
h.logResponseOk(resp, cfRay, lbProbe)
|
||||||
}
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Request origin to keep connection alive to improve performance
|
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
||||||
req.Header.Set("Connection", "keep-alive")
|
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||||
|
}
|
||||||
|
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "invalid request received")
|
||||||
|
}
|
||||||
|
h.AppendTagHeaders(req)
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
response, err := h.httpClient.RoundTrip(req)
|
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||||
|
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
err = stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
|
}
|
||||||
|
// Copy to/from stream to the undelying connection. Use the underlying
|
||||||
|
// connection because cloudflared doesn't operate on the message themselves
|
||||||
|
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||||
h.logError(stream, err)
|
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||||
} else {
|
if h.noChunkedEncoding {
|
||||||
defer response.Body.Close()
|
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||||
if h.isEventStream(response) {
|
if err == nil {
|
||||||
h.writeEventStream(stream, response.Body)
|
req.ContentLength = int64(cLength)
|
||||||
} else {
|
|
||||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
|
||||||
// compression generates dictionary on first write
|
|
||||||
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
|
|
||||||
}
|
|
||||||
|
|
||||||
h.metrics.incrementResponses(h.connectionID, "200")
|
|
||||||
h.logResponse(response, cfRay, lbProbe)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
h.metrics.decrementConcurrentRequests(h.connectionID)
|
|
||||||
return nil
|
// Request origin to keep connection alive to improve performance
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
response, err := h.httpClient.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error proxying request to origin")
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
err = stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
|
}
|
||||||
|
if h.isEventStream(response) {
|
||||||
|
h.writeEventStream(stream, response.Body)
|
||||||
|
} else {
|
||||||
|
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||||
|
// compression generates dictionary on first write
|
||||||
|
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
|
func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
|
||||||
|
@ -702,7 +730,8 @@ func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string, lbProbe bool) {
|
func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
|
||||||
|
h.metrics.incrementResponses(h.connectionID, "200")
|
||||||
logger := log.NewEntry(h.logger)
|
logger := log.NewEntry(h.logger)
|
||||||
if cfRay != "" {
|
if cfRay != "" {
|
||||||
logger = logger.WithField("CF-RAY", cfRay)
|
logger = logger.WithField("CF-RAY", cfRay)
|
||||||
|
|
Loading…
Reference in New Issue