From 63a29f421aa9a8812b5c19205604773c69167dc7 Mon Sep 17 00:00:00 2001 From: cthuang Date: Wed, 10 Feb 2021 16:19:55 +0000 Subject: [PATCH] TUN-3895: Tests for socks stream handler --- carrier/carrier_test.go | 4 +- carrier/websocket.go | 15 +--- cmd/cloudflared/access/carrier.go | 7 +- ingress/origin_connection_test.go | 123 +++++++++++++++++++++++++++++- socks/request_handler.go | 4 +- 5 files changed, 131 insertions(+), 22 deletions(-) diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index 0300e3aa..364b4296 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -44,7 +44,7 @@ func (s *testStreamer) Write(p []byte) (int, error) { func TestStartClient(t *testing.T) { message := "Good morning Austin! Time for another sunny day in the great state of Texas." log := zerolog.Nop() - wsConn := NewWSConnection(&log, false) + wsConn := NewWSConnection(&log) ts := newTestWebSocketServer() defer ts.Close() @@ -70,7 +70,7 @@ func TestStartServer(t *testing.T) { message := "Good morning Austin! Time for another sunny day in the great state of Texas." log := zerolog.Nop() shutdownC := make(chan struct{}) - wsConn := NewWSConnection(&log, false) + wsConn := NewWSConnection(&log) ts := newTestWebSocketServer() defer ts.Close() options := &StartOptions{ diff --git a/carrier/websocket.go b/carrier/websocket.go index 7a5976fa..d13ad0a8 100644 --- a/carrier/websocket.go +++ b/carrier/websocket.go @@ -38,10 +38,9 @@ func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, er } // NewWSConnection returns a new connection object -func NewWSConnection(log *zerolog.Logger, isSocks bool) Connection { +func NewWSConnection(log *zerolog.Logger) Connection { return &Websocket{ - log: log, - isSocks: isSocks, + log: log, } } @@ -55,15 +54,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro } defer wsConn.Close() - if ws.isSocks { - dialer := &wsdialer{conn: wsConn} - requestHandler := socks.NewRequestHandler(dialer) - socksServer := socks.NewConnectionHandler(requestHandler) - - _ = socksServer.Serve(conn) - } else { - ingress.Stream(wsConn, conn, ws.log) - } + ingress.Stream(wsConn, conn, ws.log) return nil } diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index c0fdaf71..f8ff4d80 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -48,7 +48,7 @@ func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, log *z } // we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side - wsConn := carrier.NewWSConnection(log, false) + wsConn := carrier.NewWSConnection(log) log.Info().Str(LogFieldHost, validURL.Host).Msg("Start Websocket listener") return carrier.StartForwarder(wsConn, validURL.Host, shutdown, options) @@ -100,7 +100,7 @@ func ssh(c *cli.Context) error { options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1]) options.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, - ServerName: parts[0], + ServerName: parts[0], } log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0]) default: @@ -109,7 +109,7 @@ func ssh(c *cli.Context) error { } // we could add a cmd line variable for this bool if we want the SOCK5 server to be on the client side - wsConn := carrier.NewWSConnection(log, false) + wsConn := carrier.NewWSConnection(log) if c.NArg() > 0 || c.IsSet(sshURLFlag) { forwarder, err := config.ValidateUrl(c, true) @@ -117,7 +117,6 @@ func ssh(c *cli.Context) error { log.Err(err).Msg("Error validating origin URL") return errors.Wrap(err, "error validating origin URL") } - log.Info().Str(LogFieldHost, forwarder.Host).Msg("Start Websocket listener") err = carrier.StartForwarder(wsConn, forwarder.Host, shutdownC, options) if err != nil { diff --git a/ingress/origin_connection_test.go b/ingress/origin_connection_test.go index 80bcf562..4e7b7ee8 100644 --- a/ingress/origin_connection_test.go +++ b/ingress/origin_connection_test.go @@ -1,25 +1,31 @@ package ingress import ( + "bytes" "context" "crypto/tls" "fmt" + "io/ioutil" "net" "net/http" "net/http/httptest" + "net/url" "testing" "time" "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/socks" "github.com/gobwas/ws/wsutil" - "github.com/gorilla/websocket" + gorillaWS "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/proxy" "golang.org/x/sync/errgroup" ) const ( testStreamTimeout = time.Second * 3 + echoHeaderName = "Test-Cloudflared-Echo" ) var ( @@ -61,7 +67,7 @@ func TestStreamTCPConnection(t *testing.T) { require.NoError(t, errGroup.Wait()) } -func TestStreamWSOverTCPConnection(t *testing.T) { +func TestDefaultStreamWSOverTCPConnection(t *testing.T) { cfdConn, originConn := net.Pipe() tcpOverWSConn := tcpOverWSConnection{ conn: cfdConn, @@ -88,6 +94,100 @@ func TestStreamWSOverTCPConnection(t *testing.T) { require.NoError(t, errGroup.Wait()) } +// TestSocksStreamWSOverTCPConnection simulates proxying in socks mode. +// Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which +// wraps SOCKS5 traffic in websocket +// Origin side runs a tcpOverWSConnection with socks.StreamHandler +func TestSocksStreamWSOverTCPConnection(t *testing.T) { + var ( + sendMessage = t.Name() + echoHeaderIncomingValue = fmt.Sprintf("header-%s", sendMessage) + echoMessage = fmt.Sprintf("echo-%s", sendMessage) + echoHeaderReturnValue = fmt.Sprintf("echo-%s", echoHeaderIncomingValue) + ) + + statusCodes := []int{ + http.StatusOK, + http.StatusTemporaryRedirect, + http.StatusBadRequest, + http.StatusInternalServerError, + } + for _, status := range statusCodes { + handler := func(w http.ResponseWriter, r *http.Request) { + body, err := ioutil.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, []byte(sendMessage), body) + + require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName)) + w.Header().Set(echoHeaderName, echoHeaderReturnValue) + + w.WriteHeader(status) + w.Write([]byte(echoMessage)) + } + origin := httptest.NewServer(http.HandlerFunc(handler)) + defer origin.Close() + + originURL, err := url.Parse(origin.URL) + require.NoError(t, err) + + originConn, err := net.Dial("tcp", originURL.Host) + require.NoError(t, err) + + tcpOverWSConn := tcpOverWSConnection{ + conn: originConn, + streamHandler: socks.StreamHandler, + } + + wsForwarderOutConn, edgeConn := net.Pipe() + ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) + defer cancel() + + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + tcpOverWSConn.Stream(ctx, edgeConn, testLogger) + return nil + }) + + wsForwarderListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + errGroup.Go(func() error { + wsForwarderInConn, err := wsForwarderListener.Accept() + require.NoError(t, err) + defer wsForwarderInConn.Close() + + Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger) + return nil + }) + + eyeballDialer, err := proxy.SOCKS5("tcp", wsForwarderListener.Addr().String(), nil, proxy.Direct) + require.NoError(t, err) + + transport := &http.Transport{ + Dial: eyeballDialer.Dial, + } + + // Request URL doesn't matter because the transport is using eyeballDialer to connectq + req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage))) + assert.NoError(t, err) + req.Header.Set(echoHeaderName, echoHeaderIncomingValue) + + resp, err := transport.RoundTrip(req) + assert.NoError(t, err) + assert.Equal(t, status, resp.StatusCode) + require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName)) + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, []byte(echoMessage), body) + + wsForwarderOutConn.Close() + edgeConn.Close() + tcpOverWSConn.Close() + + require.NoError(t, errGroup.Wait()) + } +} + func TestStreamWSConnection(t *testing.T) { eyeballConn, edgeConn := net.Pipe() @@ -121,6 +221,23 @@ func TestStreamWSConnection(t *testing.T) { require.NoError(t, errGroup.Wait()) } +type wsEyeball struct { + conn net.Conn +} + +func (wse *wsEyeball) Read(p []byte) (int, error) { + data, err := wsutil.ReadServerBinary(wse.conn) + if err != nil { + return 0, err + } + return copy(p, data), nil +} + +func (wse *wsEyeball) Write(p []byte) (int, error) { + err := wsutil.WriteClientBinary(wse.conn, p) + return len(p), err +} + func echoWSEyeball(t *testing.T, conn net.Conn) { require.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) @@ -133,7 +250,7 @@ func echoWSEyeball(t *testing.T, conn net.Conn) { } func echoWSOrigin(t *testing.T) *httptest.Server { - var upgrader = websocket.Upgrader{ + var upgrader = gorillaWS.Upgrader{ ReadBufferSize: 10, WriteBufferSize: 10, } diff --git a/socks/request_handler.go b/socks/request_handler.go index 9a1e0bea..904751c9 100644 --- a/socks/request_handler.go +++ b/socks/request_handler.go @@ -113,5 +113,7 @@ func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn, log *zerolog.L requestHandler := NewRequestHandler(dialer) socksServer := NewConnectionHandler(requestHandler) - socksServer.Serve(tunnelConn) + if err := socksServer.Serve(tunnelConn); err != nil { + log.Debug().Err(err).Msg("Socks stream handler error") + } }