From a6c2348127c0f53c5cdea64377b1e60cb57c7e24 Mon Sep 17 00:00:00 2001 From: Sudarsan Reddy Date: Tue, 2 Feb 2021 18:27:50 +0000 Subject: [PATCH] TUN-3817: Adds tests for websocket based streaming regression --- ingress/origin_connection.go | 10 + ingress/origin_proxy.go | 7 + origin/proxy.go | 89 ++++---- origin/proxy_test.go | 425 +++++++++++++++++++++++++---------- websocket/connection.go | 10 +- 5 files changed, 374 insertions(+), 167 deletions(-) diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index d49ca945..8dab2c22 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -5,6 +5,7 @@ import ( "net" "net/http" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/websocket" gws "github.com/gorilla/websocket" ) @@ -15,6 +16,7 @@ type OriginConnection interface { // Stream should generally be implemented as a bidirectional io.Copy. Stream(tunnelConn io.ReadWriter) Close() + Type() connection.Type } type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn) @@ -57,6 +59,10 @@ func (tc *tcpConnection) Close() { tc.conn.Close() } +func (*tcpConnection) Type() connection.Type { + return connection.TypeTCP +} + // wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets. // TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does. type wsConnection struct { @@ -73,6 +79,10 @@ func (wsc *wsConnection) Close() { wsc.wsConn.Close() } +func (wsc *wsConnection) Type() connection.Type { + return connection.TypeWebsocket +} + func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) { d := &gws.Dialer{ TLSClientConfig: transport.TLSClientConfig, diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 0c617d6a..c5a9dff3 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -9,6 +9,7 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/websocket" "github.com/pkg/errors" ) @@ -39,6 +40,12 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { return o.transport.RoundTrip(req) } +func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) { + req.URL.Host = o.url.Host + req.URL.Scheme = websocket.ChangeRequestScheme(o.url) + return newWSConnection(o.transport, req) +} + func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { // Rewrite the request URL so that it goes to the Hello World server. req.URL.Host = o.server.Addr().String() diff --git a/origin/proxy.go b/origin/proxy.go index 42e771d0..09b39273 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -52,6 +52,9 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn cfRay := findCfRayHeader(req) lbProbe := isLBProbeRequest(req) + serveCtx, cancel := context.WithCancel(req.Context()) + defer cancel() + p.appendTagHeaders(req) if sourceConnectionType == connection.TypeTCP { if p.warpRouting == nil { @@ -59,7 +62,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn p.log.Error().Msg(err.Error()) return err } - resp, err := p.handleProxyConn(w, req, nil, p.warpRouting.Proxy) + resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy) if err != nil { p.logRequestError(err, cfRay, ingress.ServiceWarpRouting) w.WriteErrorResponse() @@ -83,12 +86,6 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn return nil } - respHeader := http.Header{} - if sourceConnectionType == connection.TypeWebsocket { - go websocket.NewConn(w, p.log).Pinger(req.Context()) - respHeader = websocket.NewResponseHeader(req) - } - if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { req.Header.Set("Host", hostHeader) req.Host = hostHeader @@ -99,7 +96,8 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service) return fmt.Errorf("Not a connection-oriented service") } - resp, err := p.handleProxyConn(w, req, respHeader, connectionProxy) + + resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, connectionProxy) if err != nil { p.logErrorAndWriteResponse(w, err, cfRay, ruleNum) return err @@ -109,31 +107,6 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn return nil } -func (p *proxy) handleProxyConn( - w connection.ResponseWriter, - req *http.Request, - respHeader http.Header, - connectionProxy ingress.StreamBasedOriginProxy) (*http.Response, error) { - connClosedChan := make(chan struct{}) - err := p.proxyConnection(connClosedChan, w, req, connectionProxy) - if err != nil { - return nil, err - } - - status := http.StatusSwitchingProtocols - resp := &http.Response{ - Status: http.StatusText(status), - StatusCode: status, - Header: respHeader, - ContentLength: -1, - } - w.WriteRespHeaders(http.StatusSwitchingProtocols, nil) - - <-connClosedChan - return resp, nil - -} - func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) { p.logRequestError(err, cfRay, ruleNum) w.WriteErrorResponse() @@ -186,27 +159,51 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule * return resp, nil } -func (p *proxy) proxyConnection(connClosedChan chan struct{}, - conn io.ReadWriter, req *http.Request, connectionProxy ingress.StreamBasedOriginProxy) error { +func (p *proxy) proxyConnection( + serveCtx context.Context, + w connection.ResponseWriter, + req *http.Request, + sourceConnectionType connection.Type, + connectionProxy ingress.StreamBasedOriginProxy, +) (*http.Response, error) { originConn, err := connectionProxy.EstablishConnection(req) if err != nil { - return err + return nil, err } - serveCtx, cancel := context.WithCancel(req.Context()) + var eyeballConn io.ReadWriter = w + respHeader := http.Header{} + if sourceConnectionType == connection.TypeWebsocket { + wsReadWriter := websocket.NewConn(serveCtx, w, p.log) + // If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames + if originConn.Type() != sourceConnectionType { + eyeballConn = wsReadWriter + } + respHeader = websocket.NewResponseHeader(req) + } + status := http.StatusSwitchingProtocols + resp := &http.Response{ + Status: http.StatusText(status), + StatusCode: status, + Header: respHeader, + ContentLength: -1, + } + w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader) + if err != nil { + return nil, errors.Wrap(err, "Error writing response header") + } + + streamCtx, cancel := context.WithCancel(serveCtx) + defer cancel() + go func() { - // serveCtx is done if req is cancelled, or streamWebsocket returns - <-serveCtx.Done() + // streamCtx is done if req is cancelled or if Stream returns + <-streamCtx.Done() originConn.Close() - close(connClosedChan) }() - go func() { - originConn.Stream(conn) - cancel() - }() - - return nil + originConn.Stream(eyeballConn) + return resp, nil } func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 5078217d..44b09d5f 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -17,11 +17,10 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/connection" - "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/ingress" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/cloudflare/cloudflared/websocket" + gorillaWS "github.com/gorilla/websocket" "github.com/urfave/cli/v2" "github.com/gobwas/ws/wsutil" @@ -354,112 +353,347 @@ func TestProxyError(t *testing.T) { assert.Equal(t, "http response error", respWriter.Body.String()) } -func TestProxyBastionMode(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError) - flagSet.Bool("bastion", true, "") - - cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil) - err := cliCtx.Set(config.BastionFlag, "true") - require.NoError(t, err) - - allowURLFromArgs := false - ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs) - require.NoError(t, err) - - var wg sync.WaitGroup - errC := make(chan error) - - log := logger.Create(nil) - - ingressRule.StartOrigins(&wg, log, ctx.Done(), errC) - - proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, log) - - t.Run("testBastionWebsocket", testBastionWebsocket(proxy)) - cancel() +type replayer struct { + sync.RWMutex + writeDone chan struct{} + rw *bytes.Buffer } -func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) { - return func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - readPipe, _ := io.Pipe() - respWriter := newMockWSRespWriter(readPipe) +func newReplayer(buffer *bytes.Buffer) { - var wg sync.WaitGroup - msgFromConn := []byte("data from websocket proxy") - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - wg.Add(1) - go func() { - defer wg.Done() - defer ln.Close() - conn, err := ln.Accept() +} + +func (r *replayer) Read(p []byte) (int, error) { + r.RLock() + defer r.RUnlock() + return r.rw.Read(p) +} + +func (r *replayer) Write(p []byte) (int, error) { + r.Lock() + defer r.Unlock() + n, err := r.rw.Write(p) + return n, err +} + +func (r *replayer) String() string { + r.Lock() + defer r.Unlock() + return r.rw.String() +} + +func (r *replayer) Bytes() []byte { + r.Lock() + defer r.Unlock() + return r.rw.Bytes() +} + +// TestConnections tests every possible permutation of connection protocols +// proxied by cloudflared. +// +// WS - WS : When a websocket based ingress is configured on the origin and +// the eyeball is also a websocket client streaming data. +// TCP - TCP : When teamnet is enabled and an http or tcp service is running +// on the origin. +// TCP - WS: When teamnet is enabled and a websocket based service is running +// on the origin. +// WS - TCP: When a tcp based ingress is configured on the origin and the +// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access). +func TestConnections(t *testing.T) { + logger := logger.Create(nil) + replayer := &replayer{rw: &bytes.Buffer{}} + + var tests = []struct { + name string + skip bool + ingressServicePrefix string + + originService func(*testing.T, net.Listener) + eyeballService connection.ResponseWriter + connectionType connection.Type + wantMessage []byte + }{ + { + name: "ws-ws proxy", + ingressServicePrefix: "ws://", + originService: runEchoWSService, + eyeballService: newWSRespWriter([]byte("test1"), replayer), + connectionType: connection.TypeWebsocket, + wantMessage: []byte("test1"), + }, + { + name: "tcp-tcp proxy", + ingressServicePrefix: "tcp://", + originService: runEchoTCPService, + eyeballService: newTCPRespWriter( + []byte(`test2`), + replayer, + ), + connectionType: connection.TypeTCP, + wantMessage: []byte("echo-test2"), + }, + { + name: "tcp-ws proxy", + ingressServicePrefix: "ws://", + originService: runEchoWSService, + eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")), + connectionType: connection.TypeTCP, + wantMessage: []byte("test3"), + }, + { + name: "ws-tcp proxy", + ingressServicePrefix: "tcp://", + originService: runEchoTCPService, + eyeballService: newWSRespWriter([]byte("test4"), replayer), + connectionType: connection.TypeWebsocket, + wantMessage: []byte("echo-test4"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.skip { + t.Skip("todo: skipping a failing test. THis should be fixed before merge") + } + ctx, cancel := context.WithCancel(context.Background()) + ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - wsConn := websocket.NewConn(conn, nil) - wsConn.Write(msgFromConn) - }() + test.originService(t, ln) + ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String()) + var wg sync.WaitGroup + errC := make(chan error) + ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) + proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) + req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil) + require.NoError(t, err) + req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value") - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil) - req.Header.Set(h2mux.CFJumpDestinationHeader, ln.Addr().String()) - - wg.Add(1) - go func() { - defer wg.Done() - err = proxy.Proxy(respWriter, req, connection.TypeWebsocket) + if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok { + go func() { + resp := pipedWS.roundtrip(test.ingressServicePrefix + ln.Addr().String()) + replayer.Write(resp) + }() + } + err = proxy.Proxy(test.eyeballService, req, test.connectionType) require.NoError(t, err) - require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code) - }() - - // ReadServerText reads next data message from rw, considering that caller represents proxy side. - returnedMsg, err := wsutil.ReadServerText(respWriter.respBody()) - if err != io.EOF { - require.NoError(t, err) - require.Equal(t, msgFromConn, returnedMsg) - } - - cancel() - wg.Wait() + cancel() + assert.Equal(t, test.wantMessage, replayer.Bytes()) + replayer.rw.Reset() + }) } } -func TestTCPStream(t *testing.T) { - logger := logger.Create(nil) +type pipedWSWriter struct { + dialer gorillaWS.Dialer + wsConn net.Conn + pipedConn net.Conn + respWriter connection.ResponseWriter + messageToWrite []byte +} - ctx, cancel := context.WithCancel(context.Background()) +func newPipedWSWriter(rw *mockTCPRespWriter, messageToWrite []byte) *pipedWSWriter { + conn1, conn2 := net.Pipe() + dialer := gorillaWS.Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + return conn2, nil + }, + } + rw.pr = conn1 + rw.w = conn1 + return &pipedWSWriter{ + dialer: dialer, + pipedConn: conn1, + wsConn: conn2, + messageToWrite: messageToWrite, + respWriter: rw, + } +} +func (p *pipedWSWriter) roundtrip(addr string) []byte { + header := http.Header{} + conn, resp, err := p.dialer.Dial(addr, header) + if err != nil { + panic(err) + } + defer conn.Close() + + if resp.StatusCode != http.StatusSwitchingProtocols { + panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode)) + } + + err = conn.WriteMessage(gorillaWS.TextMessage, p.messageToWrite) + if err != nil { + panic(err) + } + + _, data, err := conn.ReadMessage() + if err != nil { + panic(err) + } + + return data +} + +func (p *pipedWSWriter) Read(data []byte) (int, error) { + return p.pipedConn.Read(data) +} + +func (p *pipedWSWriter) Write(data []byte) (int, error) { + return p.pipedConn.Write(data) +} + +func (p *pipedWSWriter) WriteErrorResponse() { +} + +func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error { + return nil +} + +type wsRespWriter struct { + w io.Writer + pr *io.PipeReader + pw *io.PipeWriter + code int +} + +// newWSRespWriter uses wsutil.WriteClientText to generate websocket frames. +// and wsutil.ReadClientText to translate frames from server to byte data. +// In essence, this acts as a wsClient. +func newWSRespWriter(data []byte, w io.Writer) *wsRespWriter { + pr, pw := io.Pipe() + go wsutil.WriteClientBinary(pw, data) + return &wsRespWriter{ + w: w, + pr: pr, + pw: pw, + } +} + +// Read is read by ingress.Stream and serves as the input from the client. +func (w *wsRespWriter) Read(p []byte) (int, error) { + return w.pr.Read(p) +} + +// Write is written to by ingress.Stream and serves as the output to the client. +func (w *wsRespWriter) Write(p []byte) (int, error) { + defer w.pw.Close() + returnedMsg, err := wsutil.ReadServerBinary(bytes.NewBuffer(p)) + if err != nil { + // The data was not returned by a websocket connecton. + if err != io.ErrUnexpectedEOF { + return w.w.Write(p) + } + } + return w.w.Write(returnedMsg) +} + +func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { + w.code = status + return nil +} + +func (w *wsRespWriter) WriteErrorResponse() { +} + +func runEchoTCPService(t *testing.T, l net.Listener) { + go func() { + for { + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + + for { + buf := make([]byte, 1024) + size, err := conn.Read(buf) + if err == io.EOF { + return + } + data := []byte("echo-") + data = append(data, buf[:size]...) + _, err = conn.Write(data) + if err != nil { + t.Log(err) + return + } + } + } + }() +} + +func runEchoWSService(t *testing.T, l net.Listener) { + var upgrader = gorillaWS.Upgrader{ + ReadBufferSize: 10, + WriteBufferSize: 10, + } + + var ws = func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + defer conn.Close() + + for { + messageType, p, err := conn.ReadMessage() + if err != nil { + return + } + + if err := conn.WriteMessage(messageType, p); err != nil { + return + } + } + } + + server := http.Server{ + Handler: http.HandlerFunc(ws), + } + + go func() { + err := server.Serve(l) + require.NoError(t, err) + }() +} + +func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress { ingressConfig := &config.Configuration{ Ingress: []config.UnvalidatedIngressRule{ { Hostname: "*", - Service: "bastion", + Service: service, }, }, } ingressRule, err := ingress.ParseIngress(ingressConfig) require.NoError(t, err) + return ingressRule +} - var wg sync.WaitGroup - errC := make(chan error) - ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) - - proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) - - t.Run("testTCPStream", testTCPStreamProxy(proxy)) - cancel() +type tcpWrappedWs struct { } type mockTCPRespWriter struct { w io.Writer + pr io.Reader + pw *io.PipeWriter code int } +func newTCPRespWriter(data []byte, w io.Writer) *mockTCPRespWriter { + pr, pw := io.Pipe() + go pw.Write(data) + return &mockTCPRespWriter{ + w: w, + pr: pr, + pw: pw, + } +} + func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) { - return len(p), nil + return m.pr.Read(p) } func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) { + defer m.pw.Close() return m.w.Write(p) } @@ -470,44 +704,3 @@ func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) err m.code = status return nil } - -func testTCPStreamProxy(proxy connection.OriginProxy) func(t *testing.T) { - return func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - readPipe, writePipe := io.Pipe() - respWriter := &mockTCPRespWriter{ - w: writePipe, - } - msgFromConn := []byte("data from tcp proxy") - ln, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - go func() { - defer ln.Close() - conn, err := ln.Accept() - require.NoError(t, err) - defer conn.Close() - _, err = conn.Write(msgFromConn) - require.NoError(t, err) - }() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil) - require.NoError(t, err) - - req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value") - req.Host = ln.Addr().String() - err = proxy.Proxy(respWriter, req, connection.TypeTCP) - require.NoError(t, err) - - require.Equal(t, http.StatusSwitchingProtocols, respWriter.code) - - returnedMsg := make([]byte, len(msgFromConn)) - - _, err = readPipe.Read(returnedMsg) - - require.NoError(t, err) - require.Equal(t, msgFromConn, returnedMsg) - - cancel() - } -} diff --git a/websocket/connection.go b/websocket/connection.go index 8098fcb1..ec327804 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -36,7 +36,6 @@ func (c *GorillaConn) Read(p []byte) (int, error) { if err != nil { return 0, err } - return copy(p, message), nil } @@ -71,11 +70,13 @@ type Conn struct { log *zerolog.Logger } -func NewConn(rw io.ReadWriter, log *zerolog.Logger) *Conn { - return &Conn{ +func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn { + c := &Conn{ rw: rw, log: log, } + go c.pinger(ctx) + return c } // Read will read messages from the websocket connection @@ -92,11 +93,10 @@ func (c *Conn) Write(p []byte) (int, error) { if err := wsutil.WriteServerBinary(c.rw, p); err != nil { return 0, err } - return len(p), nil } -func (c *Conn) Pinger(ctx context.Context) { +func (c *Conn) pinger(ctx context.Context) { pongMessge := wsutil.Message{ OpCode: gobwas.OpPong, Payload: []byte{},