TUN-3889: Move host header override logic to httpService
This commit is contained in:
parent
ed57ee64e8
commit
5943808746
|
@ -37,12 +37,20 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
// Rewrite the request URL so that it goes to the origin service.
|
// Rewrite the request URL so that it goes to the origin service.
|
||||||
req.URL.Host = o.url.Host
|
req.URL.Host = o.url.Host
|
||||||
req.URL.Scheme = o.url.Scheme
|
req.URL.Scheme = o.url.Scheme
|
||||||
|
if o.hostHeader != "" {
|
||||||
|
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||||
|
req.Host = o.hostHeader
|
||||||
|
}
|
||||||
return o.transport.RoundTrip(req)
|
return o.transport.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
req.URL.Host = o.url.Host
|
req.URL.Host = o.url.Host
|
||||||
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
||||||
|
if o.hostHeader != "" {
|
||||||
|
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||||
|
req.Host = o.hostHeader
|
||||||
|
}
|
||||||
return newWSConnection(o.transport, req)
|
return newWSConnection(o.transport, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,18 @@
|
||||||
package ingress
|
package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBridgeServiceDestination(t *testing.T) {
|
func TestBridgeServiceDestination(t *testing.T) {
|
||||||
|
@ -105,3 +112,47 @@ func TestBridgeServiceDestination(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||||
|
cfg := OriginRequestConfig{
|
||||||
|
HTTPHostHeader: t.Name(),
|
||||||
|
}
|
||||||
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, r.Host, t.Name())
|
||||||
|
if websocket.IsWebSocketUpgrade(r) {
|
||||||
|
respHeaders := websocket.NewResponseHeader(r)
|
||||||
|
for k, v := range respHeaders {
|
||||||
|
w.Header().Set(k, v[0])
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
}
|
||||||
|
origin := httptest.NewServer(http.HandlerFunc(handler))
|
||||||
|
defer origin.Close()
|
||||||
|
|
||||||
|
originURL, err := url.Parse(origin.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
httpService := &httpService{
|
||||||
|
url: originURL,
|
||||||
|
}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
log := zerolog.Nop()
|
||||||
|
shutdownC := make(chan struct{})
|
||||||
|
errC := make(chan error)
|
||||||
|
require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg))
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, err := httpService.RoundTrip(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
req = req.Clone(context.Background())
|
||||||
|
_, resp, err = httpService.EstablishConnection(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
|
@ -60,6 +60,7 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn,
|
||||||
|
|
||||||
type httpService struct {
|
type httpService struct {
|
||||||
url *url.URL
|
url *url.URL
|
||||||
|
hostHeader string
|
||||||
transport *http.Transport
|
transport *http.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,6 +69,7 @@ func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
o.hostHeader = cfg.HTTPHostHeader
|
||||||
o.transport = transport
|
o.transport = transport
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,11 +86,6 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
|
||||||
req.Header.Set("Host", hostHeader)
|
|
||||||
req.Host = hostHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy)
|
connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy)
|
||||||
if !ok {
|
if !ok {
|
||||||
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
|
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
|
||||||
|
@ -125,11 +120,6 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *
|
||||||
// Request origin to keep connection alive to improve performance
|
// Request origin to keep connection alive to improve performance
|
||||||
req.Header.Set("Connection", "keep-alive")
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
|
||||||
req.Header.Set("Host", hostHeader)
|
|
||||||
req.Host = hostHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
|
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
|
||||||
if !ok {
|
if !ok {
|
||||||
p.log.Error().Msgf("%s is not a http service", rule.Service)
|
p.log.Error().Msgf("%s is not a http service", rule.Service)
|
||||||
|
|
Loading…
Reference in New Issue