cloudflared-mirror/ingress/origin_proxy_test.go

349 lines
9.0 KiB
Go

package ingress
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/websocket"
)
// TestRawTCPServiceEstablishConnection verifies that rawTCPService.EstablishConnection
// returns an error once the origin listener has been shut down.
func TestRawTCPServiceEstablishConnection(t *testing.T) {
	originListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	listenerClosed := make(chan struct{})
	tcpListenRoutine(originListener, listenerClosed)

	service := &rawTCPService{
		name:   ServiceWarpRouting,
		dialer: newProxyAwareDialer(30*time.Second, 30*time.Second, nil),
	}

	// Build the request once, while the listener is still open.
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
	require.NoError(t, err)

	// Stop the origin and wait until the accept loop has observed the closure.
	require.NoError(t, originListener.Close())
	<-listenerClosed

	// Origin not listening for new connection, should return an error
	// NOTE(review): the destination passed here still carries the "http://" scheme;
	// confirm EstablishConnection expects that form rather than a bare host:port,
	// otherwise the error may stem from address parsing instead of the closed listener.
	_, err = service.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
	require.Error(t, err)
}
func TestProxyAwareDialer(t *testing.T) {
tests := []struct {
name string
httpProxy string
httpsProxy string
socksProxy string
expectDirect bool
expectProxy bool
}{
{
name: "no proxy configured",
expectDirect: true,
},
{
name: "HTTP proxy configured",
httpProxy: "http://proxy.example.com:8080",
expectProxy: true,
},
{
name: "HTTPS proxy configured",
httpsProxy: "http://proxy.example.com:8080",
expectProxy: true,
},
{
name: "SOCKS proxy configured",
socksProxy: "socks5://proxy.example.com:1080",
expectProxy: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
origHTTP := os.Getenv("HTTP_PROXY")
origHTTPS := os.Getenv("HTTPS_PROXY")
origSOCKS := os.Getenv("ALL_PROXY")
defer func() {
os.Setenv("HTTP_PROXY", origHTTP)
os.Setenv("HTTPS_PROXY", origHTTPS)
os.Setenv("ALL_PROXY", origSOCKS)
}()
os.Setenv("HTTP_PROXY", tt.httpProxy)
os.Setenv("HTTPS_PROXY", tt.httpsProxy)
os.Setenv("ALL_PROXY", tt.socksProxy)
dialer := newProxyAwareDialer(30*time.Second, 30*time.Second, TestLogger)
assert.NotNil(t, dialer)
if tt.expectDirect {
_, ok := dialer.(*net.Dialer)
assert.True(t, ok, "Expected net.Dialer when no proxy configured")
} else if tt.expectProxy {
assert.NotNil(t, dialer, "Expected proxy dialer when proxy configured")
}
})
}
}
func TestProxyAwareDialerHTTPConnect(t *testing.T) {
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "CONNECT" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
w.WriteHeader(http.StatusOK)
}))
defer proxyServer.Close()
origHTTP := os.Getenv("HTTP_PROXY")
defer os.Setenv("HTTP_PROXY", origHTTP)
os.Setenv("HTTP_PROXY", proxyServer.URL)
dialer := newProxyAwareDialer(5*time.Second, 5*time.Second, TestLogger)
assert.NotNil(t, dialer)
// Test actual dial (this will fail because our mock proxy doesn't handle the full protocol)
// but we can verify the proxy detection logic works
proxyAwareDialer, ok := dialer.(*proxyAwareDialer)
assert.True(t, ok, "Expected proxyAwareDialer when HTTP proxy configured")
assert.NotNil(t, proxyAwareDialer.baseDialer)
}
func TestGetEnvProxy(t *testing.T) {
tests := []struct {
name string
upper string
lower string
upperVal string
lowerVal string
expected string
}{
{
name: "upper case takes priority",
upper: "TEST_PROXY",
lower: "test_proxy",
upperVal: "upper_value",
lowerVal: "lower_value",
expected: "upper_value",
},
{
name: "lower case when upper not set",
upper: "TEST_PROXY",
lower: "test_proxy",
lowerVal: "lower_value",
expected: "lower_value",
},
{
name: "empty when neither set",
upper: "TEST_PROXY",
lower: "test_proxy",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Save and restore environment
origUpper := os.Getenv(tt.upper)
origLower := os.Getenv(tt.lower)
defer func() {
os.Setenv(tt.upper, origUpper)
os.Setenv(tt.lower, origLower)
}()
os.Unsetenv(tt.upper)
os.Unsetenv(tt.lower)
if tt.upperVal != "" {
os.Setenv(tt.upper, tt.upperVal)
}
if tt.lowerVal != "" {
os.Setenv(tt.lower, tt.lowerVal)
}
result := getEnvProxy(tt.upper, tt.lower)
assert.Equal(t, tt.expected, result)
})
}
}
// TestTCPOverWSServiceEstablishConnection exercises the error paths of
// tcpOverWSService.EstablishConnection: first for table rows flagged as
// expecting an error, then for every service once the origin listener is closed.
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	closed := make(chan struct{})
	tcpListenRoutine(listener, closed)

	originURL := &url.URL{Scheme: "tcp", Host: listener.Addr().String()}

	plainReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
	require.NoError(t, err)
	plainReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")

	// Same request, but carrying the bastion destination header.
	bastionReq := plainReq.Clone(context.Background())
	carrier.SetBastionDest(bastionReq.Header, listener.Addr().String())

	type scenario struct {
		label   string
		svc     *tcpOverWSService
		request *http.Request
		wantErr bool
	}
	scenarios := []scenario{
		{label: "specific TCP service", svc: newTCPOverWSService(originURL), request: plainReq},
		{label: "bastion service", svc: newBastionService(), request: bastionReq},
		{label: "invalid bastion request", svc: newBastionService(), request: plainReq, wantErr: true},
	}

	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			// Only rows that expect a failure perform a call; the others assert nothing.
			if !sc.wantErr {
				return
			}
			dest, _ := carrier.ResolveBastionDest(sc.request)
			_, connErr := sc.svc.EstablishConnection(context.Background(), dest, TestLogger)
			assert.Error(t, connErr)
		})
	}

	// Shut the origin down and wait for the accept loop to exit.
	listener.Close()
	<-closed

	services := []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()}
	for i := range services {
		// Origin not listening for new connection, should return an error
		dest, _ := carrier.ResolveBastionDest(bastionReq)
		_, connErr := services[i].EstablishConnection(context.Background(), dest, TestLogger)
		assert.Error(t, connErr)
	}
}
// TestHTTPServiceHostHeaderOverride verifies that a configured HTTPHostHeader
// replaces the Host seen by the origin, and that the origin's own host is
// what arrives in X-Forwarded-Host.
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
	cfg := OriginRequestConfig{
		HTTPHostHeader: t.Name(),
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		// assert, not require: the handler runs on a server goroutine, and
		// FailNow may only be called from the goroutine running the test.
		assert.Equal(t, t.Name(), r.Host)
		if websocket.IsWebSocketUpgrade(r) {
			respHeaders := websocket.NewResponseHeader(r)
			for k, v := range respHeaders {
				w.Header().Set(k, v[0])
			}
			w.WriteHeader(http.StatusSwitchingProtocols)
			return
		}
		// return the X-Forwarded-Host header for assertions
		// as the httptest Server URL isn't available here yet
		_, writeErr := w.Write([]byte(r.Header.Get("X-Forwarded-Host")))
		assert.NoError(t, writeErr)
	}
	origin := httptest.NewServer(http.HandlerFunc(handler))
	defer origin.Close()

	originURL, err := url.Parse(origin.URL)
	require.NoError(t, err)

	service := &httpService{
		url: originURL,
	}
	shutdownC := make(chan struct{})
	require.NoError(t, service.start(TestLogger, shutdownC, cfg))

	req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
	require.NoError(t, err)

	resp, err := service.RoundTrip(req)
	require.NoError(t, err)
	// Release the connection even when a later assertion aborts the test.
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)

	respBody, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, []byte(originURL.Host), respBody)
}
// TestHTTPServiceUsesIngressRuleScheme makes sure httpService uses scheme defined in ingress rule and not by eyeball request
func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
handler := func(w http.ResponseWriter, r *http.Request) {
require.NotNil(t, r.TLS)
// Echo the X-Forwarded-Proto header for assertions
w.Write([]byte(r.Header.Get("X-Forwarded-Proto")))
}
origin := httptest.NewTLSServer(http.HandlerFunc(handler))
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
require.Equal(t, "https", originURL.Scheme)
cfg := OriginRequestConfig{
NoTLSVerify: true,
}
httpService := &httpService{
url: originURL,
}
shutdownC := make(chan struct{})
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
protos := []string{"https", "http", "dne"}
for _, p := range protos {
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err)
req.Header.Add("X-Forwarded-Proto", p)
resp, err := httpService.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, respBody, []byte(p))
}
}
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
go func() {
for {
conn, err := listener.Accept()
if err != nil {
close(closeChan)
return
}
// Close immediately, this test is not about testing read/write on connection
conn.Close()
}
}()
}