TUN-3889: Move host header override logic to httpService
This commit is contained in:
parent
ed57ee64e8
commit
5943808746
|
@ -37,12 +37,20 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
// Rewrite the request URL so that it goes to the origin service.
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = o.url.Scheme
|
||||
if o.hostHeader != "" {
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
}
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
||||
if o.hostHeader != "" {
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
}
|
||||
return newWSConnection(o.transport, req)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,11 +1,18 @@
|
|||
package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBridgeServiceDestination(t *testing.T) {
|
||||
|
@ -105,3 +112,47 @@ func TestBridgeServiceDestination(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||
cfg := OriginRequestConfig{
|
||||
HTTPHostHeader: t.Name(),
|
||||
}
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, r.Host, t.Name())
|
||||
if websocket.IsWebSocketUpgrade(r) {
|
||||
respHeaders := websocket.NewResponseHeader(r)
|
||||
for k, v := range respHeaders {
|
||||
w.Header().Set(k, v[0])
|
||||
}
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
return
|
||||
}
|
||||
w.Write([]byte("ok"))
|
||||
}
|
||||
origin := httptest.NewServer(http.HandlerFunc(handler))
|
||||
defer origin.Close()
|
||||
|
||||
originURL, err := url.Parse(origin.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
httpService := &httpService{
|
||||
url: originURL,
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
log := zerolog.Nop()
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg))
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := httpService.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
req = req.Clone(context.Background())
|
||||
_, resp, err = httpService.EstablishConnection(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
}
|
||||
|
|
|
@ -60,6 +60,7 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn,
|
|||
|
||||
type httpService struct {
|
||||
url *url.URL
|
||||
hostHeader string
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
|
@ -68,6 +69,7 @@ func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.hostHeader = cfg.HTTPHostHeader
|
||||
o.transport = transport
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -86,11 +86,6 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
|||
return nil
|
||||
}
|
||||
|
||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
||||
req.Header.Set("Host", hostHeader)
|
||||
req.Host = hostHeader
|
||||
}
|
||||
|
||||
connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy)
|
||||
if !ok {
|
||||
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
|
||||
|
@ -125,11 +120,6 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *
|
|||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
||||
req.Header.Set("Host", hostHeader)
|
||||
req.Host = hostHeader
|
||||
}
|
||||
|
||||
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
|
||||
if !ok {
|
||||
p.log.Error().Msgf("%s is not a http service", rule.Service)
|
||||
|
|
Loading…
Reference in New Issue