338 lines
8.3 KiB
Go
338 lines
8.3 KiB
Go
package ingress
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gobwas/ws/wsutil"
|
|
gorillaWS "github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/net/proxy"
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/cloudflare/cloudflared/logger"
|
|
"github.com/cloudflare/cloudflared/socks"
|
|
"github.com/cloudflare/cloudflared/stream"
|
|
"github.com/cloudflare/cloudflared/websocket"
|
|
)
|
|
|
|
// Shared settings for the streaming tests in this file.
const (
	// testStreamTimeout is the deadline applied to the contexts that drive a
	// single streaming scenario.
	testStreamTimeout = 3 * time.Second
	// echoHeaderName is the HTTP header the test origins copy back to the caller.
	echoHeaderName = "Test-Cloudflared-Echo"
)
|
|
|
|
var (
|
|
testLogger = logger.Create(nil)
|
|
testMessage = []byte("TestStreamOriginConnection")
|
|
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
|
|
)
|
|
|
|
func TestStreamTCPConnection(t *testing.T) {
|
|
cfdConn, originConn := net.Pipe()
|
|
tcpConn := tcpConnection{
|
|
conn: cfdConn,
|
|
}
|
|
|
|
eyeballConn, edgeConn := net.Pipe()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
|
defer cancel()
|
|
|
|
errGroup, ctx := errgroup.WithContext(ctx)
|
|
errGroup.Go(func() error {
|
|
_, err := eyeballConn.Write(testMessage)
|
|
require.NoError(t, err)
|
|
|
|
readBuffer := make([]byte, len(testResponse))
|
|
_, err = eyeballConn.Read(readBuffer)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, testResponse, readBuffer)
|
|
|
|
return nil
|
|
})
|
|
errGroup.Go(func() error {
|
|
echoTCPOrigin(t, originConn)
|
|
originConn.Close()
|
|
return nil
|
|
})
|
|
|
|
tcpConn.Stream(ctx, edgeConn, testLogger)
|
|
require.NoError(t, errGroup.Wait())
|
|
}
|
|
|
|
func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
|
|
cfdConn, originConn := net.Pipe()
|
|
tcpOverWSConn := tcpOverWSConnection{
|
|
conn: cfdConn,
|
|
streamHandler: DefaultStreamHandler,
|
|
}
|
|
|
|
eyeballConn, edgeConn := net.Pipe()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
|
defer cancel()
|
|
|
|
errGroup, ctx := errgroup.WithContext(ctx)
|
|
errGroup.Go(func() error {
|
|
echoWSEyeball(t, eyeballConn)
|
|
return nil
|
|
})
|
|
errGroup.Go(func() error {
|
|
echoTCPOrigin(t, originConn)
|
|
originConn.Close()
|
|
return nil
|
|
})
|
|
|
|
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
|
require.NoError(t, errGroup.Wait())
|
|
}
|
|
|
|
// TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
|
|
// Eyeball side runs cloudflared access tcp with --url flag to start a websocket forwarder which
|
|
// wraps SOCKS5 traffic in websocket
|
|
// Origin side runs a tcpOverWSConnection with socks.StreamHandler
|
|
func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
|
var (
|
|
sendMessage = t.Name()
|
|
echoHeaderIncomingValue = fmt.Sprintf("header-%s", sendMessage)
|
|
echoMessage = fmt.Sprintf("echo-%s", sendMessage)
|
|
echoHeaderReturnValue = fmt.Sprintf("echo-%s", echoHeaderIncomingValue)
|
|
)
|
|
|
|
statusCodes := []int{
|
|
http.StatusOK,
|
|
http.StatusTemporaryRedirect,
|
|
http.StatusBadRequest,
|
|
http.StatusInternalServerError,
|
|
}
|
|
for _, status := range statusCodes {
|
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
|
body, err := io.ReadAll(r.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte(sendMessage), body)
|
|
|
|
require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
|
|
w.Header().Set(echoHeaderName, echoHeaderReturnValue)
|
|
|
|
w.WriteHeader(status)
|
|
w.Write([]byte(echoMessage))
|
|
}
|
|
origin := httptest.NewServer(http.HandlerFunc(handler))
|
|
defer origin.Close()
|
|
|
|
originURL, err := url.Parse(origin.URL)
|
|
require.NoError(t, err)
|
|
|
|
originConn, err := net.Dial("tcp", originURL.Host)
|
|
require.NoError(t, err)
|
|
|
|
tcpOverWSConn := tcpOverWSConnection{
|
|
conn: originConn,
|
|
streamHandler: socks.StreamHandler,
|
|
}
|
|
|
|
wsForwarderOutConn, edgeConn := net.Pipe()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
|
defer cancel()
|
|
|
|
errGroup, ctx := errgroup.WithContext(ctx)
|
|
errGroup.Go(func() error {
|
|
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
|
return nil
|
|
})
|
|
|
|
wsForwarderListener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
errGroup.Go(func() error {
|
|
wsForwarderInConn, err := wsForwarderListener.Accept()
|
|
require.NoError(t, err)
|
|
defer wsForwarderInConn.Close()
|
|
|
|
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
|
return nil
|
|
})
|
|
|
|
eyeballDialer, err := proxy.SOCKS5("tcp", wsForwarderListener.Addr().String(), nil, proxy.Direct)
|
|
require.NoError(t, err)
|
|
|
|
transport := &http.Transport{
|
|
Dial: eyeballDialer.Dial,
|
|
}
|
|
|
|
// Request URL doesn't matter because the transport is using eyeballDialer to connectq
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage)))
|
|
assert.NoError(t, err)
|
|
req.Header.Set(echoHeaderName, echoHeaderIncomingValue)
|
|
|
|
resp, err := transport.RoundTrip(req)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, status, resp.StatusCode)
|
|
require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName))
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte(echoMessage), body)
|
|
|
|
wsForwarderOutConn.Close()
|
|
edgeConn.Close()
|
|
tcpOverWSConn.Close()
|
|
|
|
require.NoError(t, errGroup.Wait())
|
|
}
|
|
}
|
|
|
|
func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
eyeballConn := &readWriter{
|
|
w: w,
|
|
r: r.Body,
|
|
}
|
|
|
|
cfdConn, originConn := net.Pipe()
|
|
tcpOverWSConn := tcpOverWSConnection{
|
|
conn: cfdConn,
|
|
streamHandler: DefaultStreamHandler,
|
|
}
|
|
go func() {
|
|
time.Sleep(time.Millisecond * 10)
|
|
// Simulate losing connection to origin
|
|
originConn.Close()
|
|
}()
|
|
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
|
|
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
|
|
})
|
|
server := httptest.NewServer(handler)
|
|
defer server.Close()
|
|
client := server.Client()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer cancel()
|
|
|
|
errGroup, ctx := errgroup.WithContext(ctx)
|
|
for i := 0; i < 50; i++ {
|
|
eyeballConn, edgeConn := net.Pipe()
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, server.URL, edgeConn)
|
|
assert.NoError(t, err)
|
|
|
|
resp, err := client.Transport.RoundTrip(req)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
|
|
|
errGroup.Go(func() error {
|
|
for {
|
|
if err := wsutil.WriteClientBinary(eyeballConn, testMessage); err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
assert.NoError(t, errGroup.Wait())
|
|
}
|
|
|
|
type wsEyeball struct {
|
|
conn net.Conn
|
|
}
|
|
|
|
func (wse *wsEyeball) Read(p []byte) (int, error) {
|
|
data, err := wsutil.ReadServerBinary(wse.conn)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return copy(p, data), nil
|
|
}
|
|
|
|
func (wse *wsEyeball) Write(p []byte) (int, error) {
|
|
err := wsutil.WriteClientBinary(wse.conn, p)
|
|
return len(p), err
|
|
}
|
|
|
|
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
|
defer func() {
|
|
assert.NoError(t, conn.Close())
|
|
}()
|
|
|
|
if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) {
|
|
return
|
|
}
|
|
|
|
readMsg, err := wsutil.ReadServerBinary(conn)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
|
|
assert.Equal(t, testResponse, readMsg)
|
|
}
|
|
|
|
func echoWSOrigin(t *testing.T, expectMessages bool) *httptest.Server {
|
|
var upgrader = gorillaWS.Upgrader{
|
|
ReadBufferSize: 10,
|
|
WriteBufferSize: 10,
|
|
}
|
|
|
|
ws := func(w http.ResponseWriter, r *http.Request) {
|
|
header := make(http.Header)
|
|
for k, vs := range r.Header {
|
|
if k == "Test-Cloudflared-Echo" {
|
|
header[k] = vs
|
|
}
|
|
}
|
|
conn, err := upgrader.Upgrade(w, r, header)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
sawMessage := false
|
|
for {
|
|
messageType, p, err := conn.ReadMessage()
|
|
if err != nil {
|
|
if expectMessages && !sawMessage {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
return
|
|
}
|
|
assert.Equal(t, testMessage, p)
|
|
sawMessage = true
|
|
if err := conn.WriteMessage(messageType, testResponse); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// NewTLSServer starts the server in another thread
|
|
return httptest.NewTLSServer(http.HandlerFunc(ws))
|
|
}
|
|
|
|
func echoTCPOrigin(t *testing.T, conn net.Conn) {
|
|
readBuffer := make([]byte, len(testMessage))
|
|
_, err := conn.Read(readBuffer)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, testMessage, readBuffer)
|
|
|
|
_, err = conn.Write(testResponse)
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// readWriter joins an independent writer and reader into one value that
// satisfies io.ReadWriter, forwarding each call unchanged.
type readWriter struct {
	w io.Writer // receives everything passed to Write
	r io.Reader // supplies everything returned by Read
}

// Read delegates to the wrapped reader and returns its result as is.
func (rw *readWriter) Read(p []byte) (int, error) {
	return rw.r.Read(p)
}

// Write delegates to the wrapped writer and returns its result as is.
func (rw *readWriter) Write(p []byte) (int, error) {
	return rw.w.Write(p)
}
|