178 lines
4.1 KiB
Go
178 lines
4.1 KiB
Go
package ingress
|
|
|
|
import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/cloudflare/cloudflared/logger"
	"github.com/gobwas/ws/wsutil"
	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/sync/errgroup"
)
|
|
|
|
const (
|
|
testStreamTimeout = time.Second * 3
|
|
)
|
|
|
|
var (
|
|
testLogger = logger.Create(nil)
|
|
testMessage = []byte("TestStreamOriginConnection")
|
|
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
|
|
)
|
|
|
|
// TestStreamTCPConnection proxies a raw TCP exchange through
// tcpConnection.Stream. The eyeball writes testMessage on one net.Pipe, the
// origin (echoTCPOrigin) on the other pipe answers with testResponse, and the
// eyeball must read that response back intact.
func TestStreamTCPConnection(t *testing.T) {
	cfdConn, originConn := net.Pipe()
	tcpConn := tcpConnection{
		conn: cfdConn,
	}

	eyeballConn, edgeConn := net.Pipe()

	ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
	defer cancel()

	errGroup, ctx := errgroup.WithContext(ctx)
	// This closure runs off the test goroutine, where require (t.FailNow) must
	// not be called. Failures are returned instead and surfaced by
	// errGroup.Wait on the test goroutine.
	errGroup.Go(func() error {
		if _, err := eyeballConn.Write(testMessage); err != nil {
			return fmt.Errorf("eyeball write: %w", err)
		}

		readBuffer := make([]byte, len(testResponse))
		// A single Read may return fewer bytes than requested, so read until
		// the buffer is full or an error occurs.
		if _, err := io.ReadFull(eyeballConn, readBuffer); err != nil {
			return fmt.Errorf("eyeball read: %w", err)
		}
		if string(readBuffer) != string(testResponse) {
			return fmt.Errorf("eyeball got %q, want %q", readBuffer, testResponse)
		}
		return nil
	})
	errGroup.Go(func() error {
		echoTCPOrigin(t, originConn)
		originConn.Close()
		return nil
	})

	tcpConn.Stream(ctx, edgeConn, testLogger)
	require.NoError(t, errGroup.Wait())
}
|
|
|
|
// TestStreamWSOverTCPConnection proxies a websocket-framed exchange from an
// eyeball net.Pipe to a raw TCP origin through tcpOverWSConnection.Stream,
// using DefaultStreamHandler.
func TestStreamWSOverTCPConnection(t *testing.T) {
	eyeball, edge := net.Pipe()
	proxySide, originSide := net.Pipe()

	conn := tcpOverWSConnection{
		conn:          proxySide,
		streamHandler: DefaultStreamHandler,
	}

	timeoutCtx, cancelTimeout := context.WithTimeout(context.Background(), testStreamTimeout)
	defer cancelTimeout()

	group, groupCtx := errgroup.WithContext(timeoutCtx)

	runEyeball := func() error {
		echoWSEyeball(t, eyeball)
		return nil
	}
	runOrigin := func() error {
		echoTCPOrigin(t, originSide)
		originSide.Close()
		return nil
	}
	group.Go(runEyeball)
	group.Go(runOrigin)

	conn.Stream(groupCtx, edge, testLogger)
	require.NoError(t, group.Wait())
}
|
|
|
|
// TestStreamWSConnection dials a TLS websocket echo origin via
// newWSConnection and checks the upgrade handshake response headers. It then
// streams a websocket-framed message from a piped eyeball connection through
// the resulting connection.
func TestStreamWSConnection(t *testing.T) {
	eyeball, edge := net.Pipe()

	echoServer := echoWSOrigin(t)
	defer echoServer.Close()

	upgradeReq, err := http.NewRequest(http.MethodGet, echoServer.URL, nil)
	require.NoError(t, err)
	// Sample key from RFC 6455, so the expected Sec-Websocket-Accept value is known.
	upgradeReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")

	// The httptest TLS server presents a self-signed certificate.
	conn, upgradeResp, err := newWSConnection(&tls.Config{InsecureSkipVerify: true}, upgradeReq)
	require.NoError(t, err)
	require.Equal(t, http.StatusSwitchingProtocols, upgradeResp.StatusCode)
	require.Equal(t, "Upgrade", upgradeResp.Header.Get("Connection"))
	require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", upgradeResp.Header.Get("Sec-Websocket-Accept"))
	require.Equal(t, "websocket", upgradeResp.Header.Get("Upgrade"))

	streamCtx, cancelStream := context.WithTimeout(context.Background(), testStreamTimeout)
	defer cancelStream()

	group, groupCtx := errgroup.WithContext(streamCtx)
	group.Go(func() error {
		echoWSEyeball(t, eyeball)
		return nil
	})

	conn.Stream(groupCtx, edge, testLogger)
	require.NoError(t, group.Wait())
}
|
|
|
|
// echoWSEyeball plays the eyeball side of a websocket exchange over conn. It
// sends testMessage as a client binary frame, expects testResponse back as a
// server binary frame, and always closes conn when done.
//
// Callers run it on errgroup goroutines, so failures are reported with assert
// rather than require: require calls t.FailNow, which must only be called from
// the goroutine running the test. Closing conn on every path (including
// failures) lets the proxy side of the pipe observe the end of the stream.
func echoWSEyeball(t *testing.T, conn net.Conn) {
	defer func() {
		assert.NoError(t, conn.Close())
	}()

	if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) {
		return
	}

	readMsg, err := wsutil.ReadServerBinary(conn)
	if !assert.NoError(t, err) {
		return
	}

	assert.Equal(t, testResponse, readMsg)
}
|
|
|
|
func echoWSOrigin(t *testing.T) *httptest.Server {
|
|
var upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 10,
|
|
WriteBufferSize: 10,
|
|
}
|
|
|
|
ws := func(w http.ResponseWriter, r *http.Request) {
|
|
header := make(http.Header)
|
|
for k, vs := range r.Header {
|
|
if k == "Test-Cloudflared-Echo" {
|
|
header[k] = vs
|
|
}
|
|
}
|
|
conn, err := upgrader.Upgrade(w, r, header)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
for {
|
|
messageType, p, err := conn.ReadMessage()
|
|
if err != nil {
|
|
return
|
|
}
|
|
require.Equal(t, testMessage, p)
|
|
if err := conn.WriteMessage(messageType, testResponse); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// NewTLSServer starts the server in another thread
|
|
return httptest.NewTLSServer(http.HandlerFunc(ws))
|
|
}
|
|
|
|
// echoTCPOrigin plays a raw TCP origin on conn. It reads exactly
// len(testMessage) bytes, checks that they equal testMessage, and writes
// testResponse back. It does not close conn; callers do that.
//
// Callers run it on errgroup goroutines, so failures are reported with assert
// (never require).
func echoTCPOrigin(t *testing.T, conn net.Conn) {
	readBuffer := make([]byte, len(testMessage))
	// A single Read may return fewer bytes than requested (for example, when
	// the proxy forwards the payload in several writes). ReadFull keeps reading
	// until the buffer is full or an error occurs.
	_, err := io.ReadFull(conn, readBuffer)
	if !assert.NoError(t, err) {
		// Nothing valid was received, so there is nothing to answer.
		return
	}

	assert.Equal(t, testMessage, readBuffer)

	_, err = conn.Write(testResponse)
	assert.NoError(t, err)
}
|