TUN-3889: Move host header override logic to httpService

This commit is contained in:
cthuang 2021-02-08 19:25:08 +00:00 committed by Nuno Diegues
parent ed57ee64e8
commit 5943808746
4 changed files with 63 additions and 12 deletions

View File

@ -37,12 +37,20 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service. // Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.url.Host req.URL.Host = o.url.Host
req.URL.Scheme = o.url.Scheme req.URL.Scheme = o.url.Scheme
if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
}
return o.transport.RoundTrip(req) return o.transport.RoundTrip(req)
} }
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.URL.Host = o.url.Host req.URL.Host = o.url.Host
req.URL.Scheme = websocket.ChangeRequestScheme(o.url) req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
}
return newWSConnection(o.transport, req) return newWSConnection(o.transport, req)
} }

View File

@ -1,11 +1,18 @@
package ingress package ingress
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest"
"net/url"
"sync"
"testing" "testing"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/websocket"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestBridgeServiceDestination(t *testing.T) { func TestBridgeServiceDestination(t *testing.T) {
@ -105,3 +112,47 @@ func TestBridgeServiceDestination(t *testing.T) {
} }
} }
} }
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
cfg := OriginRequestConfig{
HTTPHostHeader: t.Name(),
}
handler := func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, r.Host, t.Name())
if websocket.IsWebSocketUpgrade(r) {
respHeaders := websocket.NewResponseHeader(r)
for k, v := range respHeaders {
w.Header().Set(k, v[0])
}
w.WriteHeader(http.StatusSwitchingProtocols)
return
}
w.Write([]byte("ok"))
}
origin := httptest.NewServer(http.HandlerFunc(handler))
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
httpService := &httpService{
url: originURL,
}
var wg sync.WaitGroup
log := zerolog.Nop()
shutdownC := make(chan struct{})
errC := make(chan error)
require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg))
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err)
resp, err := httpService.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
req = req.Clone(context.Background())
_, resp, err = httpService.EstablishConnection(req)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
}

View File

@ -60,6 +60,7 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn,
type httpService struct { type httpService struct {
url *url.URL url *url.URL
hostHeader string
transport *http.Transport transport *http.Transport
} }
@ -68,6 +69,7 @@ func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <
if err != nil { if err != nil {
return err return err
} }
o.hostHeader = cfg.HTTPHostHeader
o.transport = transport o.transport = transport
return nil return nil
} }

View File

@ -86,11 +86,6 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return nil return nil
} }
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
req.Header.Set("Host", hostHeader)
req.Host = hostHeader
}
connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy) connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy)
if !ok { if !ok {
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service) p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
@ -125,11 +120,6 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *
// Request origin to keep connection alive to improve performance // Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive") req.Header.Set("Connection", "keep-alive")
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
req.Header.Set("Host", hostHeader)
req.Host = hostHeader
}
httpService, ok := rule.Service.(ingress.HTTPOriginProxy) httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
if !ok { if !ok {
p.log.Error().Msgf("%s is not a http service", rule.Service) p.log.Error().Msgf("%s is not a http service", rule.Service)