240 lines
6.5 KiB
Go
240 lines
6.5 KiB
Go
package ingress
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/cloudflare/cloudflared/carrier"
|
|
"github.com/cloudflare/cloudflared/websocket"
|
|
)
|
|
|
|
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
|
|
// the expected response
|
|
func assertEstablishConnectionResponse(t *testing.T,
|
|
originProxy StreamBasedOriginProxy,
|
|
req *http.Request,
|
|
expectHeader http.Header,
|
|
) {
|
|
_, resp, err := originProxy.EstablishConnection(req)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, switchingProtocolText, resp.Status)
|
|
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
assert.Equal(t, expectHeader, resp.Header)
|
|
}
|
|
|
|
func TestHTTPServiceEstablishConnection(t *testing.T) {
|
|
origin := echoWSOrigin(t)
|
|
defer origin.Close()
|
|
originURL, err := url.Parse(origin.URL)
|
|
require.NoError(t, err)
|
|
|
|
httpService := &httpService{
|
|
url: originURL,
|
|
hostHeader: origin.URL,
|
|
transport: &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
},
|
|
},
|
|
}
|
|
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
|
req.Header.Set("Test-Cloudflared-Echo", t.Name())
|
|
|
|
expectHeader := http.Header{
|
|
"Connection": {"Upgrade"},
|
|
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
|
"Upgrade": {"websocket"},
|
|
"Test-Cloudflared-Echo": {t.Name()},
|
|
}
|
|
assertEstablishConnectionResponse(t, httpService, req, expectHeader)
|
|
}
|
|
|
|
func TestHelloWorldEstablishConnection(t *testing.T) {
|
|
var wg sync.WaitGroup
|
|
shutdownC := make(chan struct{})
|
|
errC := make(chan error)
|
|
helloWorldSerivce := &helloWorld{}
|
|
helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{})
|
|
|
|
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
|
|
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
|
|
require.NoError(t, err)
|
|
|
|
expectHeader := http.Header{
|
|
"Connection": {"Upgrade"},
|
|
// Accept key when Sec-Websocket-Key is not specified
|
|
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
|
|
"Upgrade": {"websocket"},
|
|
}
|
|
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
|
|
|
|
close(shutdownC)
|
|
}
|
|
|
|
func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
|
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
listenerClosed := make(chan struct{})
|
|
tcpListenRoutine(originListener, listenerClosed)
|
|
|
|
rawTCPService := &rawTCPService{name: ServiceWarpRouting}
|
|
|
|
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
|
require.NoError(t, err)
|
|
|
|
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
|
|
|
|
originListener.Close()
|
|
<-listenerClosed
|
|
|
|
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
|
require.NoError(t, err)
|
|
|
|
// Origin not listening for new connection, should return an error
|
|
_, resp, err := rawTCPService.EstablishConnection(req)
|
|
require.Error(t, err)
|
|
require.Nil(t, resp)
|
|
}
|
|
|
|
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
|
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
listenerClosed := make(chan struct{})
|
|
tcpListenRoutine(originListener, listenerClosed)
|
|
|
|
originURL := &url.URL{
|
|
Scheme: "tcp",
|
|
Host: originListener.Addr().String(),
|
|
}
|
|
|
|
baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
|
|
require.NoError(t, err)
|
|
baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
|
|
|
bastionReq := baseReq.Clone(context.Background())
|
|
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
|
|
|
expectHeader := http.Header{
|
|
"Connection": {"Upgrade"},
|
|
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
|
"Upgrade": {"websocket"},
|
|
}
|
|
|
|
tests := []struct {
|
|
testCase string
|
|
service *tcpOverWSService
|
|
req *http.Request
|
|
expectErr bool
|
|
}{
|
|
{
|
|
testCase: "specific TCP service",
|
|
service: newTCPOverWSService(originURL),
|
|
req: baseReq,
|
|
},
|
|
{
|
|
testCase: "bastion service",
|
|
service: newBastionService(),
|
|
req: bastionReq,
|
|
},
|
|
{
|
|
testCase: "invalid bastion request",
|
|
service: newBastionService(),
|
|
req: baseReq,
|
|
expectErr: true,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.testCase, func(t *testing.T) {
|
|
if test.expectErr {
|
|
_, resp, err := test.service.EstablishConnection(test.req)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, resp)
|
|
} else {
|
|
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
|
}
|
|
})
|
|
}
|
|
|
|
originListener.Close()
|
|
<-listenerClosed
|
|
|
|
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
|
// Origin not listening for new connection, should return an error
|
|
_, resp, err := service.EstablishConnection(bastionReq)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, resp)
|
|
}
|
|
}
|
|
|
|
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
|
cfg := OriginRequestConfig{
|
|
HTTPHostHeader: t.Name(),
|
|
}
|
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
|
require.Equal(t, r.Host, t.Name())
|
|
if websocket.IsWebSocketUpgrade(r) {
|
|
respHeaders := websocket.NewResponseHeader(r)
|
|
for k, v := range respHeaders {
|
|
w.Header().Set(k, v[0])
|
|
}
|
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
|
return
|
|
}
|
|
w.Write([]byte("ok"))
|
|
}
|
|
origin := httptest.NewServer(http.HandlerFunc(handler))
|
|
defer origin.Close()
|
|
|
|
originURL, err := url.Parse(origin.URL)
|
|
require.NoError(t, err)
|
|
|
|
httpService := &httpService{
|
|
url: originURL,
|
|
}
|
|
var wg sync.WaitGroup
|
|
shutdownC := make(chan struct{})
|
|
errC := make(chan error)
|
|
require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg))
|
|
|
|
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := httpService.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
req = req.Clone(context.Background())
|
|
_, resp, err = httpService.EstablishConnection(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
|
}
|
|
|
|
// tcpListenRoutine starts a goroutine that accepts connections on listener and closes
// each accepted connection right away. When Accept returns an error, the goroutine
// closes closeChan and exits.
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
	acceptLoop := func() {
		// Tell the caller that the accept loop has stopped.
		defer close(closeChan)
		for {
			accepted, acceptErr := listener.Accept()
			if acceptErr != nil {
				return
			}
			// Close immediately, this test is not about testing read/write on connection
			accepted.Close()
		}
	}
	go acceptLoop()
}
|