From 594380874669efacdff334960621d949d8758db8 Mon Sep 17 00:00:00 2001 From: cthuang Date: Mon, 8 Feb 2021 19:25:08 +0000 Subject: [PATCH] TUN-3889: Move host header override logic to httpService --- ingress/origin_proxy.go | 8 ++++++ ingress/origin_proxy_test.go | 51 ++++++++++++++++++++++++++++++++++++ ingress/origin_service.go | 6 +++-- origin/proxy.go | 10 ------- 4 files changed, 63 insertions(+), 12 deletions(-) diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index affdc2c3..23d89cf5 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -37,12 +37,20 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { // Rewrite the request URL so that it goes to the origin service. req.URL.Host = o.url.Host req.URL.Scheme = o.url.Scheme + if o.hostHeader != "" { + // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. + req.Host = o.hostHeader + } return o.transport.RoundTrip(req) } func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { req.URL.Host = o.url.Host req.URL.Scheme = websocket.ChangeRequestScheme(o.url) + if o.hostHeader != "" { + // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. + req.Host = o.hostHeader + } return newWSConnection(o.transport, req) } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index e2238a40..bf664838 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -1,11 +1,18 @@ package ingress import ( + "context" "net/http" + "net/http/httptest" + "net/url" + "sync" "testing" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/websocket" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBridgeServiceDestination(t *testing.T) { @@ -105,3 +112,47 @@ func TestBridgeServiceDestination(t *testing.T) { } } } + +func TestHTTPServiceHostHeaderOverride(t *testing.T) { + cfg := OriginRequestConfig{ + HTTPHostHeader: t.Name(), + } + handler := func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, r.Host, t.Name()) + if websocket.IsWebSocketUpgrade(r) { + respHeaders := websocket.NewResponseHeader(r) + for k, v := range respHeaders { + w.Header().Set(k, v[0]) + } + w.WriteHeader(http.StatusSwitchingProtocols) + return + } + w.Write([]byte("ok")) + } + origin := httptest.NewServer(http.HandlerFunc(handler)) + defer origin.Close() + + originURL, err := url.Parse(origin.URL) + require.NoError(t, err) + + httpService := &httpService{ + url: originURL, + } + var wg sync.WaitGroup + log := zerolog.Nop() + shutdownC := make(chan struct{}) + errC := make(chan error) + require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg)) + + req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) + require.NoError(t, err) + + resp, err := httpService.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + req = req.Clone(context.Background()) + _, resp, err = httpService.EstablishConnection(req) + require.NoError(t, err) + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) +} diff --git a/ingress/origin_service.go b/ingress/origin_service.go index ccf1e008..cc89abfd 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -59,8 +59,9 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, } type httpService struct { - url *url.URL - transport *http.Transport + url *url.URL + hostHeader string + transport *http.Transport } func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { @@ -68,6 +69,7 @@ func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC < if err != nil { return err } + o.hostHeader = cfg.HTTPHostHeader o.transport = transport return nil } diff --git a/origin/proxy.go b/origin/proxy.go index 2edaaf07..1acbff4d 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -86,11 +86,6 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn return nil } - if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { - req.Header.Set("Host", hostHeader) - req.Host = hostHeader - } - connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy) if !ok { p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service) @@ -125,11 +120,6 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule * // Request origin to keep connection alive to improve performance req.Header.Set("Connection", "keep-alive") - if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { - req.Header.Set("Host", hostHeader) - req.Host = hostHeader - } - httpService, ok := rule.Service.(ingress.HTTPOriginProxy) if !ok { p.log.Error().Msgf("%s is not a http service", rule.Service)