TUN-3895: Tests for socks stream handler
This commit is contained in:
parent
e20c4f8752
commit
63a29f421a
|
@ -44,7 +44,7 @@ func (s *testStreamer) Write(p []byte) (int, error) {
|
||||||
func TestStartClient(t *testing.T) {
|
func TestStartClient(t *testing.T) {
|
||||||
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
wsConn := NewWSConnection(&log, false)
|
wsConn := NewWSConnection(&log)
|
||||||
ts := newTestWebSocketServer()
|
ts := newTestWebSocketServer()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ func TestStartServer(t *testing.T) {
|
||||||
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
shutdownC := make(chan struct{})
|
shutdownC := make(chan struct{})
|
||||||
wsConn := NewWSConnection(&log, false)
|
wsConn := NewWSConnection(&log)
|
||||||
ts := newTestWebSocketServer()
|
ts := newTestWebSocketServer()
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
options := &StartOptions{
|
options := &StartOptions{
|
||||||
|
|
|
@ -38,10 +38,9 @@ func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, er
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWSConnection returns a new connection object
|
// NewWSConnection returns a new connection object
|
||||||
func NewWSConnection(log *zerolog.Logger, isSocks bool) Connection {
|
func NewWSConnection(log *zerolog.Logger) Connection {
|
||||||
return &Websocket{
|
return &Websocket{
|
||||||
log: log,
|
log: log,
|
||||||
isSocks: isSocks,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,15 +54,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro
|
||||||
}
|
}
|
||||||
defer wsConn.Close()
|
defer wsConn.Close()
|
||||||
|
|
||||||
if ws.isSocks {
|
ingress.Stream(wsConn, conn, ws.log)
|
||||||
dialer := &wsdialer{conn: wsConn}
|
|
||||||
requestHandler := socks.NewRequestHandler(dialer)
|
|
||||||
socksServer := socks.NewConnectionHandler(requestHandler)
|
|
||||||
|
|
||||||
_ = socksServer.Serve(conn)
|
|
||||||
} else {
|
|
||||||
ingress.Stream(wsConn, conn, ws.log)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, log *z
|
||||||
}
|
}
|
||||||
|
|
||||||
// we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side
|
// we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side
|
||||||
wsConn := carrier.NewWSConnection(log, false)
|
wsConn := carrier.NewWSConnection(log)
|
||||||
|
|
||||||
log.Info().Str(LogFieldHost, validURL.Host).Msg("Start Websocket listener")
|
log.Info().Str(LogFieldHost, validURL.Host).Msg("Start Websocket listener")
|
||||||
return carrier.StartForwarder(wsConn, validURL.Host, shutdown, options)
|
return carrier.StartForwarder(wsConn, validURL.Host, shutdown, options)
|
||||||
|
@ -100,7 +100,7 @@ func ssh(c *cli.Context) error {
|
||||||
options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1])
|
options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1])
|
||||||
options.TLSClientConfig = &tls.Config{
|
options.TLSClientConfig = &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
ServerName: parts[0],
|
ServerName: parts[0],
|
||||||
}
|
}
|
||||||
log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0])
|
log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0])
|
||||||
default:
|
default:
|
||||||
|
@ -109,7 +109,7 @@ func ssh(c *cli.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side
|
// we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side
|
||||||
wsConn := carrier.NewWSConnection(log, false)
|
wsConn := carrier.NewWSConnection(log)
|
||||||
|
|
||||||
if c.NArg() > 0 || c.IsSet(sshURLFlag) {
|
if c.NArg() > 0 || c.IsSet(sshURLFlag) {
|
||||||
forwarder, err := config.ValidateUrl(c, true)
|
forwarder, err := config.ValidateUrl(c, true)
|
||||||
|
@ -117,7 +117,6 @@ func ssh(c *cli.Context) error {
|
||||||
log.Err(err).Msg("Error validating origin URL")
|
log.Err(err).Msg("Error validating origin URL")
|
||||||
return errors.Wrap(err, "error validating origin URL")
|
return errors.Wrap(err, "error validating origin URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Str(LogFieldHost, forwarder.Host).Msg("Start Websocket listener")
|
log.Info().Str(LogFieldHost, forwarder.Host).Msg("Start Websocket listener")
|
||||||
err = carrier.StartForwarder(wsConn, forwarder.Host, shutdownC, options)
|
err = carrier.StartForwarder(wsConn, forwarder.Host, shutdownC, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,25 +1,31 @@
|
||||||
package ingress
|
package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/cloudflare/cloudflared/socks"
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
"github.com/gorilla/websocket"
|
gorillaWS "github.com/gorilla/websocket"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testStreamTimeout = time.Second * 3
|
testStreamTimeout = time.Second * 3
|
||||||
|
echoHeaderName = "Test-Cloudflared-Echo"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -61,7 +67,7 @@ func TestStreamTCPConnection(t *testing.T) {
|
||||||
require.NoError(t, errGroup.Wait())
|
require.NoError(t, errGroup.Wait())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamWSOverTCPConnection(t *testing.T) {
|
func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
|
||||||
cfdConn, originConn := net.Pipe()
|
cfdConn, originConn := net.Pipe()
|
||||||
tcpOverWSConn := tcpOverWSConnection{
|
tcpOverWSConn := tcpOverWSConnection{
|
||||||
conn: cfdConn,
|
conn: cfdConn,
|
||||||
|
@ -88,6 +94,100 @@ func TestStreamWSOverTCPConnection(t *testing.T) {
|
||||||
require.NoError(t, errGroup.Wait())
|
require.NoError(t, errGroup.Wait())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
|
||||||
|
// Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which
|
||||||
|
// wraps SOCKS5 traffic in websocket
|
||||||
|
// Origin side runs a tcpOverWSConnection with socks.StreamHandler
|
||||||
|
func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||||
|
var (
|
||||||
|
sendMessage = t.Name()
|
||||||
|
echoHeaderIncomingValue = fmt.Sprintf("header-%s", sendMessage)
|
||||||
|
echoMessage = fmt.Sprintf("echo-%s", sendMessage)
|
||||||
|
echoHeaderReturnValue = fmt.Sprintf("echo-%s", echoHeaderIncomingValue)
|
||||||
|
)
|
||||||
|
|
||||||
|
statusCodes := []int{
|
||||||
|
http.StatusOK,
|
||||||
|
http.StatusTemporaryRedirect,
|
||||||
|
http.StatusBadRequest,
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
}
|
||||||
|
for _, status := range statusCodes {
|
||||||
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte(sendMessage), body)
|
||||||
|
|
||||||
|
require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
|
||||||
|
w.Header().Set(echoHeaderName, echoHeaderReturnValue)
|
||||||
|
|
||||||
|
w.WriteHeader(status)
|
||||||
|
w.Write([]byte(echoMessage))
|
||||||
|
}
|
||||||
|
origin := httptest.NewServer(http.HandlerFunc(handler))
|
||||||
|
defer origin.Close()
|
||||||
|
|
||||||
|
originURL, err := url.Parse(origin.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
originConn, err := net.Dial("tcp", originURL.Host)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tcpOverWSConn := tcpOverWSConnection{
|
||||||
|
conn: originConn,
|
||||||
|
streamHandler: socks.StreamHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
wsForwarderOutConn, edgeConn := net.Pipe()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
wsForwarderListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
wsForwarderInConn, err := wsForwarderListener.Accept()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer wsForwarderInConn.Close()
|
||||||
|
|
||||||
|
Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
eyeballDialer, err := proxy.SOCKS5("tcp", wsForwarderListener.Addr().String(), nil, proxy.Direct)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
Dial: eyeballDialer.Dial,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request URL doesn't matter because the transport is using eyeballDialer to connectq
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
req.Header.Set(echoHeaderName, echoHeaderIncomingValue)
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, status, resp.StatusCode)
|
||||||
|
require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName))
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []byte(echoMessage), body)
|
||||||
|
|
||||||
|
wsForwarderOutConn.Close()
|
||||||
|
edgeConn.Close()
|
||||||
|
tcpOverWSConn.Close()
|
||||||
|
|
||||||
|
require.NoError(t, errGroup.Wait())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamWSConnection(t *testing.T) {
|
func TestStreamWSConnection(t *testing.T) {
|
||||||
eyeballConn, edgeConn := net.Pipe()
|
eyeballConn, edgeConn := net.Pipe()
|
||||||
|
|
||||||
|
@ -121,6 +221,23 @@ func TestStreamWSConnection(t *testing.T) {
|
||||||
require.NoError(t, errGroup.Wait())
|
require.NoError(t, errGroup.Wait())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type wsEyeball struct {
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wse *wsEyeball) Read(p []byte) (int, error) {
|
||||||
|
data, err := wsutil.ReadServerBinary(wse.conn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return copy(p, data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wse *wsEyeball) Write(p []byte) (int, error) {
|
||||||
|
err := wsutil.WriteClientBinary(wse.conn, p)
|
||||||
|
return len(p), err
|
||||||
|
}
|
||||||
|
|
||||||
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
||||||
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
|
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
|
||||||
|
|
||||||
|
@ -133,7 +250,7 @@ func echoWSEyeball(t *testing.T, conn net.Conn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func echoWSOrigin(t *testing.T) *httptest.Server {
|
func echoWSOrigin(t *testing.T) *httptest.Server {
|
||||||
var upgrader = websocket.Upgrader{
|
var upgrader = gorillaWS.Upgrader{
|
||||||
ReadBufferSize: 10,
|
ReadBufferSize: 10,
|
||||||
WriteBufferSize: 10,
|
WriteBufferSize: 10,
|
||||||
}
|
}
|
||||||
|
|
|
@ -113,5 +113,7 @@ func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn, log *zerolog.L
|
||||||
requestHandler := NewRequestHandler(dialer)
|
requestHandler := NewRequestHandler(dialer)
|
||||||
socksServer := NewConnectionHandler(requestHandler)
|
socksServer := NewConnectionHandler(requestHandler)
|
||||||
|
|
||||||
socksServer.Serve(tunnelConn)
|
if err := socksServer.Serve(tunnelConn); err != nil {
|
||||||
|
log.Debug().Err(err).Msg("Socks stream handler error")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue