TUN-1781: ServeStream should return early on error
This commit is contained in:
		
							parent
							
								
									137928ecaf
								
							
						
					
					
						commit
						809d2f3f28
					
				
							
								
								
									
										127
									
								
								origin/tunnel.go
								
								
								
								
							
							
						
						
									
										127
									
								
								origin/tunnel.go
								
								
								
								
							|  | @ -596,65 +596,93 @@ func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { | |||
| 
 | ||||
| func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { | ||||
| 	h.metrics.incrementRequests(h.connectionID) | ||||
| 	req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) | ||||
| 	if err != nil { | ||||
| 		h.logger.WithError(err).Panic("Unexpected error from http.NewRequest") | ||||
| 	defer h.metrics.decrementConcurrentRequests(h.connectionID) | ||||
| 
 | ||||
| 	req, reqErr := h.createRequest(stream) | ||||
| 	if reqErr != nil { | ||||
| 		h.logError(stream, reqErr) | ||||
| 		return reqErr | ||||
| 	} | ||||
| 	err = H2RequestHeadersToH1Request(stream.Headers, req) | ||||
| 	if err != nil { | ||||
| 		h.logger.WithError(err).Error("invalid request received") | ||||
| 	} | ||||
| 	h.AppendTagHeaders(req) | ||||
| 
 | ||||
| 	cfRay := FindCfRayHeader(req) | ||||
| 	lbProbe := isLBProbeRequest(req) | ||||
| 	h.logRequest(req, cfRay, lbProbe) | ||||
| 
 | ||||
| 	var resp *http.Response | ||||
| 	var respErr error | ||||
| 	if websocket.IsWebSocketUpgrade(req) { | ||||
| 		conn, response, err := websocket.ClientConnect(req, h.tlsConfig) | ||||
| 		if err != nil { | ||||
| 			h.logError(stream, err) | ||||
| 		} else { | ||||
| 			stream.WriteHeaders(H1ResponseToH2Response(response)) | ||||
| 			defer conn.Close() | ||||
| 			// Copy to/from stream to the undelying connection. Use the underlying
 | ||||
| 			// connection because cloudflared doesn't operate on the message themselves
 | ||||
| 			websocket.Stream(conn.UnderlyingConn(), stream) | ||||
| 			h.metrics.incrementResponses(h.connectionID, "200") | ||||
| 			h.logResponse(response, cfRay, lbProbe) | ||||
| 		} | ||||
| 		resp, respErr = h.serveWebsocket(stream, req) | ||||
| 	} else { | ||||
| 		// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
 | ||||
| 		if h.noChunkedEncoding { | ||||
| 			req.TransferEncoding = []string{"gzip", "deflate"} | ||||
| 			cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) | ||||
| 			if err == nil { | ||||
| 				req.ContentLength = int64(cLength) | ||||
| 			} | ||||
| 		} | ||||
| 		resp, respErr = h.serveHTTP(stream, req) | ||||
| 	} | ||||
| 	if respErr != nil { | ||||
| 		h.logError(stream, respErr) | ||||
| 		return respErr | ||||
| 	} | ||||
| 	h.logResponseOk(resp, cfRay, lbProbe) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| 		// Request origin to keep connection alive to improve performance
 | ||||
| 		req.Header.Set("Connection", "keep-alive") | ||||
| func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) { | ||||
| 	req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") | ||||
| 	} | ||||
| 	err = H2RequestHeadersToH1Request(stream.Headers, req) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "invalid request received") | ||||
| 	} | ||||
| 	h.AppendTagHeaders(req) | ||||
| 	return req, nil | ||||
| } | ||||
| 
 | ||||
| 		response, err := h.httpClient.RoundTrip(req) | ||||
| func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { | ||||
| 	conn, response, err := websocket.ClientConnect(req, h.tlsConfig) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 	err = stream.WriteHeaders(H1ResponseToH2Response(response)) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error writing response header") | ||||
| 	} | ||||
| 	// Copy to/from stream to the undelying connection. Use the underlying
 | ||||
| 	// connection because cloudflared doesn't operate on the message themselves
 | ||||
| 	websocket.Stream(conn.UnderlyingConn(), stream) | ||||
| 	return response, nil | ||||
| } | ||||
| 
 | ||||
| 		if err != nil { | ||||
| 			h.logError(stream, err) | ||||
| 		} else { | ||||
| 			defer response.Body.Close() | ||||
| 			stream.WriteHeaders(H1ResponseToH2Response(response)) | ||||
| 			if h.isEventStream(response) { | ||||
| 				h.writeEventStream(stream, response.Body) | ||||
| 			} else { | ||||
| 				// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
 | ||||
| 				// compression generates dictionary on first write
 | ||||
| 				io.CopyBuffer(stream, response.Body, make([]byte, 512*1024)) | ||||
| 			} | ||||
| 
 | ||||
| 			h.metrics.incrementResponses(h.connectionID, "200") | ||||
| 			h.logResponse(response, cfRay, lbProbe) | ||||
| func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { | ||||
| 	// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
 | ||||
| 	if h.noChunkedEncoding { | ||||
| 		req.TransferEncoding = []string{"gzip", "deflate"} | ||||
| 		cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) | ||||
| 		if err == nil { | ||||
| 			req.ContentLength = int64(cLength) | ||||
| 		} | ||||
| 	} | ||||
| 	h.metrics.decrementConcurrentRequests(h.connectionID) | ||||
| 	return nil | ||||
| 
 | ||||
| 	// Request origin to keep connection alive to improve performance
 | ||||
| 	req.Header.Set("Connection", "keep-alive") | ||||
| 
 | ||||
| 	response, err := h.httpClient.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error proxying request to origin") | ||||
| 	} | ||||
| 	defer response.Body.Close() | ||||
| 
 | ||||
| 	err = stream.WriteHeaders(H1ResponseToH2Response(response)) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error writing response header") | ||||
| 	} | ||||
| 	if h.isEventStream(response) { | ||||
| 		h.writeEventStream(stream, response.Body) | ||||
| 	} else { | ||||
| 		// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
 | ||||
| 		// compression generates dictionary on first write
 | ||||
| 		io.CopyBuffer(stream, response.Body, make([]byte, 512*1024)) | ||||
| 	} | ||||
| 	return response, nil | ||||
| } | ||||
| 
 | ||||
| func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) { | ||||
|  | @ -702,7 +730,8 @@ func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (h *TunnelHandler) logResponse(r *http.Response, cfRay string, lbProbe bool) { | ||||
| func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { | ||||
| 	h.metrics.incrementResponses(h.connectionID, "200") | ||||
| 	logger := log.NewEntry(h.logger) | ||||
| 	if cfRay != "" { | ||||
| 		logger = logger.WithField("CF-RAY", cfRay) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue