cloudflared-mirror/origin/proxy_test.go

228 lines
5.6 KiB
Go

package origin
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"time"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/gobwas/ws/wsutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockHTTPRespWriter is the response writer handed to OriginClient.Proxy in
// the HTTP subtest. Status, headers and body are captured by the embedded
// httptest.ResponseRecorder so the test can inspect them afterwards.
type mockHTTPRespWriter struct {
	*httptest.ResponseRecorder
}

// newMockHTTPRespWriter returns a mockHTTPRespWriter backed by a fresh recorder.
func newMockHTTPRespWriter() *mockHTTPRespWriter {
	return &mockHTTPRespWriter{
		httptest.NewRecorder(),
	}
}

// WriteRespHeaders copies resp's headers into the recorder and then records
// resp's status code. It always returns nil.
//
// The headers must be copied before WriteHeader: ResponseRecorder.WriteHeader
// takes a snapshot of the header map, and anything added after that call is
// missing from Result() (mirroring net/http, where headers set after
// WriteHeader are not sent).
func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
	for header, val := range resp.Header {
		w.Header()[header] = val
	}
	w.WriteHeader(resp.StatusCode)
	return nil
}

// WriteErrorResponse records a 502 Bad Gateway status; err is ignored.
func (w *mockHTTPRespWriter) WriteErrorResponse(err error) {
	w.WriteHeader(http.StatusBadGateway)
}

// Read always fails: this mock has no data to supply.
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
	return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
}
// mockWSRespWriter is a mockHTTPRespWriter for the websocket subtest. Each
// Write call's data is handed to the test over an unbuffered channel (see
// respBody), and Read is served from the reader given to newMockWSRespWriter.
type mockWSRespWriter struct {
	*mockHTTPRespWriter
	writeNotification chan []byte
	reader            io.Reader
}

// newMockWSRespWriter returns a mockWSRespWriter whose Read calls are
// forwarded to reader.
func newMockWSRespWriter(reader io.Reader) *mockWSRespWriter {
	return &mockWSRespWriter{
		mockHTTPRespWriter: newMockHTTPRespWriter(),
		writeNotification:  make(chan []byte),
		reader:             reader,
	}
}

// Write blocks until respBody receives the data, then reports it fully
// written. A copy is sent because io.Writer implementations must not retain
// data: the caller may reuse its buffer as soon as Write returns, while the
// test may still be parsing the previous contents.
func (w *mockWSRespWriter) Write(data []byte) (int, error) {
	buf := make([]byte, len(data))
	copy(buf, data)
	w.writeNotification <- buf
	return len(data), nil
}

// respBody blocks until the next Write and returns that call's data wrapped
// in a buffer.
func (w *mockWSRespWriter) respBody() io.ReadWriter {
	data := <-w.writeNotification
	return bytes.NewBuffer(data)
}

// Read forwards to the reader supplied to newMockWSRespWriter.
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
	return w.reader.Read(data)
}
// mockSSERespWriter is a mockHTTPRespWriter for the server-sent-events
// subtest. Each Write call's data is handed to the test over an unbuffered
// channel so the stream can be consumed one write at a time via ReadBytes.
type mockSSERespWriter struct {
	*mockHTTPRespWriter
	writeNotification chan []byte
}

// newMockSSERespWriter returns a mockSSERespWriter backed by a fresh recorder.
func newMockSSERespWriter() *mockSSERespWriter {
	return &mockSSERespWriter{
		mockHTTPRespWriter: newMockHTTPRespWriter(),
		writeNotification:  make(chan []byte),
	}
}

// Write blocks until ReadBytes receives the data, then reports it fully
// written. A copy is sent because io.Writer implementations must not retain
// data: the caller may reuse its buffer as soon as Write returns.
func (w *mockSSERespWriter) Write(data []byte) (int, error) {
	buf := make([]byte, len(data))
	copy(buf, data)
	w.writeNotification <- buf
	return len(data), nil
}

// ReadBytes blocks until the next Write and returns that call's data.
func (w *mockSSERespWriter) ReadBytes() []byte {
	return <-w.writeNotification
}
// TestProxy serves the hello-world origin over TLS on a runtime-chosen local
// port, then runs the HTTP, websocket and SSE proxy subtests against it
// through a single client built by NewClient.
func TestProxy(t *testing.T) {
	// Named testLogger so the imported logger package is not shadowed.
	testLogger, err := logger.New()
	require.NoError(t, err)
	// let runtime pick an available port
	listener, err := hello.CreateTLSListener("127.0.0.1:0")
	require.NoError(t, err)
	originURL := &url.URL{
		Scheme: "https",
		Host:   listener.Addr().String(),
	}

	// The client's only trusted root is the hello server's certificate.
	originCA := x509.NewCertPool()
	helloCert, err := tlsconfig.GetHelloCertificateX509()
	require.NoError(t, err)
	originCA.AddCert(helloCert)
	clientTLS := &tls.Config{
		RootCAs: originCA,
	}
	proxyConfig := &ProxyConfig{
		Client: &http.Transport{
			TLSClientConfig: clientTLS,
		},
		URL:       originURL,
		TLSConfig: clientTLS,
	}

	ctx, cancel := context.WithCancel(context.Background())
	// ctx.Done() is handed to the hello server, presumably as its stop signal.
	// Deferring cancel guarantees that channel is closed even if a subtest panics.
	defer cancel()
	go func() {
		hello.StartHelloWorldServer(testLogger, listener, ctx.Done())
	}()

	client := NewClient(proxyConfig, testLogger)
	t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
	t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
	t.Run("testProxySSE", testProxySSE(t, client, originURL))
}
// testProxyHTTP returns a subtest that sends a plain GET for originURL through
// client and expects a 200 OK to be recorded. The outer t is unused; the
// returned closure reports through its own *testing.T.
func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
	return func(t *testing.T) {
		req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
		require.NoError(t, err)

		recorder := newMockHTTPRespWriter()
		require.NoError(t, client.Proxy(recorder, req, false))
		assert.Equal(t, http.StatusOK, recorder.Code)
	}
}
// testProxyWebsocket returns a subtest that proxies a websocket request to
// hello.WSRoute through client. It writes one text frame and one binary frame
// as the client side and expects each to be echoed back unchanged, then
// expects Proxy to return nil with a 101 Switching Protocols status.
//
// The outer t and tlsConfig are unused; they are kept so the call site in
// TestProxy does not change.
func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL *url.URL, tlsConfig *tls.Config) func(t *testing.T) {
	return func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		// WSRoute is a websocket echo handler
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
		require.NoError(t, err)

		// Frames written to writePipe become readable through respWriter.Read.
		readPipe, writePipe := io.Pipe()
		respWriter := newMockWSRespWriter(readPipe)

		// Proxy runs in its own goroutine because it blocks for the lifetime of
		// the websocket. Its result goes into proxyErr, a variable only that
		// goroutine writes, and is checked after wg.Wait(). Two reasons:
		// require's FailNow must be called from the test goroutine, and sharing
		// err with this goroutine would be a data race.
		var (
			wg       sync.WaitGroup
			proxyErr error
		)
		wg.Add(1)
		go func() {
			defer wg.Done()
			proxyErr = client.Proxy(respWriter, req, true)
		}()

		msg := []byte("test websocket")
		err = wsutil.WriteClientText(writePipe, msg)
		require.NoError(t, err)
		// ReadServerText reads next data message from rw, considering that caller represents client side.
		returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
		require.NoError(t, err)
		require.Equal(t, msg, returnedMsg)

		err = wsutil.WriteClientBinary(writePipe, msg)
		require.NoError(t, err)
		returnedMsg, err = wsutil.ReadServerBinary(respWriter.respBody())
		require.NoError(t, err)
		require.Equal(t, msg, returnedMsg)

		cancel()
		wg.Wait()
		require.NoError(t, proxyErr)
		require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
	}
}
// testProxySSE returns a subtest that requests hello.SSERoute through client
// with a push frequency of pushFreq. It expects the first pushCount events in
// order, each arriving as two writes: "<i>\n" followed by "\n". After
// cancelling the request it expects Proxy to have returned nil with a 200 OK.
// The outer t is unused; the returned closure reports through its own
// *testing.T.
func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
	return func(t *testing.T) {
		var (
			pushCount = 50
			pushFreq  = 10 * time.Millisecond
		)
		respWriter := newMockSSERespWriter()
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?freq=%s", originURL, hello.SSERoute, pushFreq), nil)
		require.NoError(t, err)

		// Proxy's result goes into proxyErr, a variable only the goroutine
		// writes, and is checked after wg.Wait(): require's FailNow must not be
		// called from a goroutine other than the test's.
		var (
			wg       sync.WaitGroup
			proxyErr error
		)
		wg.Add(1)
		go func() {
			defer wg.Done()
			proxyErr = client.Proxy(respWriter, req, false)
		}()

		for i := 0; i < pushCount; i++ {
			line := respWriter.ReadBytes()
			expect := fmt.Sprintf("%d\n", i)
			require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line)
			line = respWriter.ReadBytes()
			require.Equal(t, []byte("\n"), line, "Expect to read '\n', got %v", line)
		}

		cancel()
		wg.Wait()
		require.NoError(t, proxyErr)
		require.Equal(t, http.StatusOK, respWriter.Code)
	}
}