diff --git a/proxy/proxy.go b/proxy/proxy.go index d3c2afd8..c16555ea 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -213,16 +213,6 @@ func (p *Proxy) proxyHTTPRequest( roundTripReq.Header.Set("Connection", "keep-alive") } - // Handle GOAWAY frame to correctly retry a request - if roundTripReq.Body != nil { - roundTripReq.GetBody = func() (io.ReadCloser, err error) { - if err.Error() == "http2: Transport received Server's graceful shutdown GOAWAY" { - return roundTripReq.Body, nil - } - return nil, err - } - } - // Set the User-Agent as an empty string if not provided to avoid inserting golang default UA if roundTripReq.Header.Get("User-Agent") == "" { roundTripReq.Header.Set("User-Agent", "") @@ -231,11 +221,23 @@ func (p *Proxy) proxyHTTPRequest( _, ttfbSpan := tr.Tracer().Start(tr.Context(), "ttfb_origin") resp, err := httpService.RoundTrip(roundTripReq) if err != nil { - tracing.EndWithErrorStatus(ttfbSpan, err) - if err := roundTripReq.Context().Err(); err != nil { - return errors.Wrap(err, "Incoming request ended abruptly") + // Check for GOAWAY error and retry once if applicable + const goawayMsg = "http2: Transport received Server's graceful shutdown GOAWAY" + if err.Error() == goawayMsg && roundTripReq.GetBody != nil { + // Reset the body for retry + newBody, getBodyErr := roundTripReq.GetBody() + if getBodyErr == nil { + roundTripReq.Body = newBody + resp, err = httpService.RoundTrip(roundTripReq) + } + } + if err != nil { + tracing.EndWithErrorStatus(ttfbSpan, err) + if err := roundTripReq.Context().Err(); err != nil { + return errors.Wrap(err, "Incoming request ended abruptly") + } + return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared") } - return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared") } tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 56c9cab9..27a7283e 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "errors" "flag" "fmt" "io" @@ -1018,3 +1019,61 @@ func runEchoWSService(t *testing.T, l net.Listener) { } }() } + +func TestHandleGOAWAYRetry(t *testing.T) { + // Simulate a request body + bodyContent := "test body content" + body := io.NopCloser(strings.NewReader(bodyContent)) + + // Create a mock request with a body + roundTripReq := &http.Request{ + Body: body, + } + + // Simulate the GOAWAY error + goawayError := errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + + // Assign the GetBody function + roundTripReq.GetBody = func() (io.ReadCloser, error) { + if goawayError.Error() == "http2: Transport received Server's graceful shutdown GOAWAY" { + return roundTripReq.Body, nil + } + return nil, goawayError + } + + // Test the GetBody function + retriedBody, err := roundTripReq.GetBody() + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Verify the retried body content + retriedContent, _ := io.ReadAll(retriedBody) + if string(retriedContent) != bodyContent { + t.Fatalf("Expected body content '%s', got '%s'", bodyContent, string(retriedContent)) + } +} +func TestHandleGOAWAYRetryError(t *testing.T) { + // Simulate a request body + bodyContent := "test body content" + body := io.NopCloser(strings.NewReader(bodyContent)) + + // Create a mock request with a body + roundTripReq := &http.Request{ + Body: body, + } + + // Simulate the GOAWAY error + goawayError := errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + + // Assign the GetBody function to return an error + roundTripReq.GetBody = func() (io.ReadCloser, error) { + return nil, goawayError + } + + // Test the GetBody function + retriedBody, err := roundTripReq.GetBody() + if retriedBody != nil || err == nil { + t.Fatalf("Expected error, got body: %v, error: %v", retriedBody, err) + } +}