From 1ff5fd3fdcb76a36b9ade1b1bf367daa4b95f24e Mon Sep 17 00:00:00 2001 From: cthuang Date: Fri, 4 Feb 2022 16:51:37 +0000 Subject: [PATCH] TUN-5744: Add a test to make sure cloudflared uses scheme defined in ingress rule, not X-Forwarded-Proto header --- ingress/origin_proxy_test.go | 41 ++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index b14408b8..5716bba1 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -147,7 +147,48 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { respBody, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, respBody, []byte(originURL.Host)) +} +// TestHTTPServiceUsesIngressRuleScheme makes sure httpService uses scheme defined in ingress rule and not by eyeball request +func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + require.NotNil(t, r.TLS) + // Echo the X-Forwarded-Proto header for assertions + w.Write([]byte(r.Header.Get("X-Forwarded-Proto"))) + } + origin := httptest.NewTLSServer(http.HandlerFunc(handler)) + defer origin.Close() + + originURL, err := url.Parse(origin.URL) + require.NoError(t, err) + require.Equal(t, "https", originURL.Scheme) + + cfg := OriginRequestConfig{ + NoTLSVerify: true, + } + httpService := &httpService{ + url: originURL, + } + var wg sync.WaitGroup + shutdownC := make(chan struct{}) + errC := make(chan error) + require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg)) + + // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header + protos := []string{"https", "http", "dne"} + for _, p := range protos { + req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) + require.NoError(t, err) + req.Header.Add("X-Forwarded-Proto", p) + + resp, err := httpService.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, respBody, []byte(p)) + } } func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {