TUN-3895: Tests for socks stream handler
This commit is contained in:
parent
e20c4f8752
commit
63a29f421a
|
@ -44,7 +44,7 @@ func (s *testStreamer) Write(p []byte) (int, error) {
|
|||
func TestStartClient(t *testing.T) {
|
||||
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
||||
log := zerolog.Nop()
|
||||
wsConn := NewWSConnection(&log, false)
|
||||
wsConn := NewWSConnection(&log)
|
||||
ts := newTestWebSocketServer()
|
||||
defer ts.Close()
|
||||
|
||||
|
@ -70,7 +70,7 @@ func TestStartServer(t *testing.T) {
|
|||
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
||||
log := zerolog.Nop()
|
||||
shutdownC := make(chan struct{})
|
||||
wsConn := NewWSConnection(&log, false)
|
||||
wsConn := NewWSConnection(&log)
|
||||
ts := newTestWebSocketServer()
|
||||
defer ts.Close()
|
||||
options := &StartOptions{
|
||||
|
|
|
@ -38,10 +38,9 @@ func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, er
|
|||
}
|
||||
|
||||
// NewWSConnection returns a new connection object
|
||||
func NewWSConnection(log *zerolog.Logger, isSocks bool) Connection {
|
||||
func NewWSConnection(log *zerolog.Logger) Connection {
|
||||
return &Websocket{
|
||||
log: log,
|
||||
isSocks: isSocks,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,15 +54,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro
|
|||
}
|
||||
defer wsConn.Close()
|
||||
|
||||
if ws.isSocks {
|
||||
dialer := &wsdialer{conn: wsConn}
|
||||
requestHandler := socks.NewRequestHandler(dialer)
|
||||
socksServer := socks.NewConnectionHandler(requestHandler)
|
||||
|
||||
_ = socksServer.Serve(conn)
|
||||
} else {
|
||||
ingress.Stream(wsConn, conn, ws.log)
|
||||
}
|
||||
ingress.Stream(wsConn, conn, ws.log)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, log *z
|
|||
}
|
||||
|
||||
// we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side
|
||||
wsConn := carrier.NewWSConnection(log, false)
|
||||
wsConn := carrier.NewWSConnection(log)
|
||||
|
||||
log.Info().Str(LogFieldHost, validURL.Host).Msg("Start Websocket listener")
|
||||
return carrier.StartForwarder(wsConn, validURL.Host, shutdown, options)
|
||||
|
@ -100,7 +100,7 @@ func ssh(c *cli.Context) error {
|
|||
options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1])
|
||||
options.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: parts[0],
|
||||
ServerName: parts[0],
|
||||
}
|
||||
log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0])
|
||||
default:
|
||||
|
@ -109,7 +109,7 @@ func ssh(c *cli.Context) error {
|
|||
}
|
||||
|
||||
// we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side
|
||||
wsConn := carrier.NewWSConnection(log, false)
|
||||
wsConn := carrier.NewWSConnection(log)
|
||||
|
||||
if c.NArg() > 0 || c.IsSet(sshURLFlag) {
|
||||
forwarder, err := config.ValidateUrl(c, true)
|
||||
|
@ -117,7 +117,6 @@ func ssh(c *cli.Context) error {
|
|||
log.Err(err).Msg("Error validating origin URL")
|
||||
return errors.Wrap(err, "error validating origin URL")
|
||||
}
|
||||
|
||||
log.Info().Str(LogFieldHost, forwarder.Host).Msg("Start Websocket listener")
|
||||
err = carrier.StartForwarder(wsConn, forwarder.Host, shutdownC, options)
|
||||
if err != nil {
|
||||
|
|
|
@ -1,25 +1,31 @@
|
|||
package ingress
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/socks"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/gorilla/websocket"
|
||||
gorillaWS "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
testStreamTimeout = time.Second * 3
|
||||
echoHeaderName = "Test-Cloudflared-Echo"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -61,7 +67,7 @@ func TestStreamTCPConnection(t *testing.T) {
|
|||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
func TestStreamWSOverTCPConnection(t *testing.T) {
|
||||
func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
|
||||
cfdConn, originConn := net.Pipe()
|
||||
tcpOverWSConn := tcpOverWSConnection{
|
||||
conn: cfdConn,
|
||||
|
@ -88,6 +94,100 @@ func TestStreamWSOverTCPConnection(t *testing.T) {
|
|||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
// TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
|
||||
// Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which
|
||||
// wraps SOCKS5 traffic in websocket
|
||||
// Origin side runs a tcpOverWSConnection with socks.StreamHandler
|
||||
func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||
var (
|
||||
sendMessage = t.Name()
|
||||
echoHeaderIncomingValue = fmt.Sprintf("header-%s", sendMessage)
|
||||
echoMessage = fmt.Sprintf("echo-%s", sendMessage)
|
||||
echoHeaderReturnValue = fmt.Sprintf("echo-%s", echoHeaderIncomingValue)
|
||||
)
|
||||
|
||||
statusCodes := []int{
|
||||
http.StatusOK,
|
||||
http.StatusTemporaryRedirect,
|
||||
http.StatusBadRequest,
|
||||
http.StatusInternalServerError,
|
||||
}
|
||||
for _, status := range statusCodes {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte(sendMessage), body)
|
||||
|
||||
require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
|
||||
w.Header().Set(echoHeaderName, echoHeaderReturnValue)
|
||||
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(echoMessage))
|
||||
}
|
||||
origin := httptest.NewServer(http.HandlerFunc(handler))
|
||||
defer origin.Close()
|
||||
|
||||
originURL, err := url.Parse(origin.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
originConn, err := net.Dial("tcp", originURL.Host)
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpOverWSConn := tcpOverWSConnection{
|
||||
conn: originConn,
|
||||
streamHandler: socks.StreamHandler,
|
||||
}
|
||||
|
||||
wsForwarderOutConn, edgeConn := net.Pipe()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
||||
return nil
|
||||
})
|
||||
|
||||
wsForwarderListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
errGroup.Go(func() error {
|
||||
wsForwarderInConn, err := wsForwarderListener.Accept()
|
||||
require.NoError(t, err)
|
||||
defer wsForwarderInConn.Close()
|
||||
|
||||
Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
||||
return nil
|
||||
})
|
||||
|
||||
eyeballDialer, err := proxy.SOCKS5("tcp", wsForwarderListener.Addr().String(), nil, proxy.Direct)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport := &http.Transport{
|
||||
Dial: eyeballDialer.Dial,
|
||||
}
|
||||
|
||||
// Request URL doesn't matter because the transport is using eyeballDialer to connectq
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage)))
|
||||
assert.NoError(t, err)
|
||||
req.Header.Set(echoHeaderName, echoHeaderIncomingValue)
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, status, resp.StatusCode)
|
||||
require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName))
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte(echoMessage), body)
|
||||
|
||||
wsForwarderOutConn.Close()
|
||||
edgeConn.Close()
|
||||
tcpOverWSConn.Close()
|
||||
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamWSConnection(t *testing.T) {
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
|
||||
|
@ -121,6 +221,23 @@ func TestStreamWSConnection(t *testing.T) {
|
|||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
type wsEyeball struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (wse *wsEyeball) Read(p []byte) (int, error) {
|
||||
data, err := wsutil.ReadServerBinary(wse.conn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return copy(p, data), nil
|
||||
}
|
||||
|
||||
func (wse *wsEyeball) Write(p []byte) (int, error) {
|
||||
err := wsutil.WriteClientBinary(wse.conn, p)
|
||||
return len(p), err
|
||||
}
|
||||
|
||||
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
||||
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
|
||||
|
||||
|
@ -133,7 +250,7 @@ func echoWSEyeball(t *testing.T, conn net.Conn) {
|
|||
}
|
||||
|
||||
func echoWSOrigin(t *testing.T) *httptest.Server {
|
||||
var upgrader = websocket.Upgrader{
|
||||
var upgrader = gorillaWS.Upgrader{
|
||||
ReadBufferSize: 10,
|
||||
WriteBufferSize: 10,
|
||||
}
|
||||
|
|
|
@ -113,5 +113,7 @@ func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn, log *zerolog.L
|
|||
requestHandler := NewRequestHandler(dialer)
|
||||
socksServer := NewConnectionHandler(requestHandler)
|
||||
|
||||
socksServer.Serve(tunnelConn)
|
||||
if err := socksServer.Serve(tunnelConn); err != nil {
|
||||
log.Debug().Err(err).Msg("Socks stream handler error")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue