diff --git a/connection/connection_test.go b/connection/connection_test.go index 325df5d0..e8e477ea 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -4,17 +4,16 @@ import ( "context" "fmt" "io" - "math/rand" "net/http" "net/url" "testing" "time" + "github.com/gobwas/ws/wsutil" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/websocket" ) const ( @@ -51,15 +50,7 @@ func (moc *mockOriginProxy) ProxyHTTP( isWebsocket bool, ) error { if isWebsocket { - switch req.URL.Path { - case "/ws/echo": - return wsEchoEndpoint(w, req) - case "/ws/flaky": - return wsFlakyEndpoint(w, req) - default: - originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found")) - return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path) - } + return wsEndpoint(w, req) } switch req.URL.Path { case "/ok": @@ -87,82 +78,32 @@ func (moc *mockOriginProxy) ProxyTCP( return nil } -type echoPipe struct { - reader *io.PipeReader - writer *io.PipeWriter +type nowriter struct { + io.Reader } -func (ep *echoPipe) Read(p []byte) (int, error) { - return ep.reader.Read(p) +func (nowriter) Write(p []byte) (int, error) { + return 0, fmt.Errorf("Writer not implemented") } -func (ep *echoPipe) Write(p []byte) (int, error) { - return ep.writer.Write(p) -} - -// A mock origin that echos data by streaming like a tcpOverWSConnection -// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go -func wsEchoEndpoint(w ResponseWriter, r *http.Request) error { +func wsEndpoint(w ResponseWriter, r *http.Request) error { resp := &http.Response{ StatusCode: http.StatusSwitchingProtocols, } - if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { - return err - } - wsCtx, cancel := context.WithCancel(r.Context()) - readPipe, writePipe := io.Pipe() - wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) + _ = w.WriteRespHeaders(resp.StatusCode, resp.Header) + clientReader := nowriter{r.Body} go func() { - select { - case <-wsCtx.Done(): - case <-r.Context().Done(): + for { + data, err := wsutil.ReadClientText(clientReader) + if err != nil { + return + } + if err := wsutil.WriteServerText(w, data); err != nil { + return + } } - readPipe.Close() - writePipe.Close() }() - - originConn := &echoPipe{reader: readPipe, writer: writePipe} - websocket.Stream(wsConn, originConn, &log) - cancel() - wsConn.Close() - return nil -} - -type flakyConn struct { - closeAt time.Time -} - -func (fc *flakyConn) Read(p []byte) (int, error) { - if time.Now().After(fc.closeAt) { - return 0, io.EOF - } - n := copy(p, []byte("Read from flaky connection")) - return n, nil -} - -func (fc *flakyConn) Write(p []byte) (int, error) { - if time.Now().After(fc.closeAt) { - return 0, fmt.Errorf("Flaky connection closed") - } - return len(p), nil -} - -func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error { - resp := &http.Response{ - StatusCode: http.StatusSwitchingProtocols, - } - if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { - return err - } - wsCtx, cancel := context.WithCancel(r.Context()) - - wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) - - closedAfter := time.Millisecond * time.Duration(rand.Intn(50)) - originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)} - websocket.Stream(wsConn, originConn, &log) - cancel() - wsConn.Close() + <-r.Context().Done() return nil } diff --git a/connection/h2mux_test.go b/connection/h2mux_test.go index 83a39589..e6eab072 100644 --- a/connection/h2mux_test.go +++ b/connection/h2mux_test.go @@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) { headers := []h2mux.Header{ { Name: ":path", - Value: "/ws/echo", + Value: "/ws", }, { Name: "connection", @@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) { assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin)) data := []byte("test websocket") - err = wsutil.WriteClientBinary(writePipe, data) + err = wsutil.WriteClientText(writePipe, data) require.NoError(t, err) - respBody, err := wsutil.ReadServerBinary(stream) + respBody, err := wsutil.ReadServerText(stream) require.NoError(t, err) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) diff --git a/connection/http2_test.go b/connection/http2_test.go index 23555837..4b7435bd 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -27,7 +27,7 @@ var ( ) func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { - edgeConn, cfdConn := net.Pipe() + edgeConn, originConn := net.Pipe() var connIndex = uint8(0) log := zerolog.Nop() obs := NewObserver(&log, &log, false) @@ -41,8 +41,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { 1*time.Second, ) return NewHTTP2Connection( - cfdConn, - // OriginProxy is set in testConfig + originConn, testConfig, &pogs.ConnectionOptions{}, obs, @@ -167,7 +166,6 @@ type wsRespWriter struct { *httptest.ResponseRecorder readPipe *io.PipeReader writePipe *io.PipeWriter - closed bool } func newWSRespWriter() *wsRespWriter { @@ -176,58 +174,46 @@ func newWSRespWriter() *wsRespWriter { httptest.NewRecorder(), readPipe, writePipe, - false, } } -type nowriter struct { - io.Reader -} - -func (nowriter) Write(p []byte) (int, error) { - return 0, fmt.Errorf("Writer not implemented") -} - func (w *wsRespWriter) RespBody() io.ReadWriter { return nowriter{w.readPipe} } func (w *wsRespWriter) Write(data []byte) (n int, err error) { - if w.closed { - // Simulate writing to http2 ResponseWriter after ServeHTTP has returned - panic("Write to closed ResponseWriter") - } return w.writePipe.Write(data) } -func (w *wsRespWriter) close() { - w.closed = true -} - func TestServeWS(t *testing.T) { http2Conn, _ := newTestHTTP2Connection() ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + http2Conn.Serve(ctx) + }() respWriter := newWSRespWriter() readPipe, writePipe := io.Pipe() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe) require.NoError(t, err) req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) - serveDone := make(chan struct{}) + wg.Add(1) go func() { - defer close(serveDone) + defer wg.Done() http2Conn.ServeHTTP(respWriter, req) - respWriter.close() }() data := []byte("test websocket") - err = wsutil.WriteClientBinary(writePipe, data) + err = wsutil.WriteClientText(writePipe, data) require.NoError(t, err) - respBody, err := wsutil.ReadServerBinary(respWriter.RespBody()) + respBody, err := wsutil.ReadServerText(respWriter.RespBody()) require.NoError(t, err) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) @@ -237,64 +223,7 @@ func TestServeWS(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader)) - <-serveDone -} - -// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184 -// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns -func TestNoWriteAfterServeHTTPReturns(t *testing.T) { - cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection() - - ctx, cancel := context.WithCancel(context.Background()) - var wg sync.WaitGroup - - serverDone := make(chan struct{}) - go func() { - defer close(serverDone) - cfdHTTP2Conn.Serve(ctx) - }() - - edgeTransport := http2.Transport{} - edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn) - require.NoError(t, err) - message := []byte(t.Name()) - - for i := 0; i < 100; i++ { - wg.Add(1) - go func() { - defer wg.Done() - readPipe, writePipe := io.Pipe() - reqCtx, reqCancel := context.WithCancel(ctx) - req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe) - require.NoError(t, err) - req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) - - resp, err := edgeHTTP2Conn.RoundTrip(req) - require.NoError(t, err) - // http2RespWriter should rewrite status 101 to 200 - require.Equal(t, http.StatusOK, resp.StatusCode) - - wg.Add(1) - go func() { - defer wg.Done() - for { - select { - case <-reqCtx.Done(): - return - default: - } - _ = wsutil.WriteClientBinary(writePipe, message) - } - }() - - time.Sleep(time.Millisecond * 100) - reqCancel() - }() - } - wg.Wait() - cancel() - <-serverDone } func TestServeControlStream(t *testing.T) { diff --git a/connection/quic_test.go b/connection/quic_test.go index d5a11ec0..fc551a4b 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -61,7 +61,7 @@ func TestQUICServer(t *testing.T) { // This is simply a sample websocket frame message. wsBuf := &bytes.Buffer{} - wsutil.WriteClientBinary(wsBuf, []byte("Hello")) + wsutil.WriteClientText(wsBuf, []byte("Hello")) var tests = []struct { desc string @@ -118,7 +118,7 @@ func TestQUICServer(t *testing.T) { }, { desc: "test ws proxy", - dest: "/ws/echo", + dest: "/ok", connectionType: quicpogs.ConnectionTypeWebsocket, metadata: []quicpogs.Metadata{ quicpogs.Metadata{ @@ -139,7 +139,7 @@ func TestQUICServer(t *testing.T) { }, }, message: wsBuf.Bytes(), - expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, + expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, }, { desc: "test tcp proxy", @@ -278,7 +278,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque } if isWebsocket { - return wsEchoEndpoint(w, r) + return wsEndpoint(w, r) } switch r.URL.Path { case "/ok": diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index bdfb19b6..9588ce36 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri wc.streamHandler(wsConn, wc.conn, log) cancel() // Makes sure wsConn stops sending ping before terminating the stream - wsConn.Close() + wsConn.WaitForShutdown() } func (wc *tcpOverWSConnection) Close() { @@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io. socks.StreamNetHandler(wsConn, sp.accessPolicy, log) cancel() // Makes sure wsConn stops sending ping before terminating the stream - wsConn.Close() + wsConn.WaitForShutdown() } func (sp *socksProxyOverWSConnection) Close() { diff --git a/websocket/connection.go b/websocket/connection.go index 9d64fc64..79665902 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -3,10 +3,8 @@ package websocket import ( "bytes" "context" - "errors" "fmt" "io" - "sync" "time" gobwas "github.com/gobwas/ws" @@ -100,17 +98,15 @@ func (c *GorillaConn) pinger(ctx context.Context) { type Conn struct { rw io.ReadWriter log *zerolog.Logger - // writeLock makes sure - // 1. Only one write at a time. The pinger and Stream function can both call write. - // 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close. - writeLock sync.Mutex - done bool + // closed is a channel to indicate if Conn has been fully terminated + shutdownC chan struct{} } func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn { c := &Conn{ - rw: rw, - log: log, + rw: rw, + log: log, + shutdownC: make(chan struct{}), } go c.pinger(ctx) return c @@ -125,22 +121,16 @@ func (c *Conn) Read(reader []byte) (int, error) { return copy(reader, data), nil } -// Write will write messages to the websocket connection. -// It will not write to the connection after Close is called to fix TUN-5184 +// Write will write messages to the websocket connection func (c *Conn) Write(p []byte) (int, error) { - c.writeLock.Lock() - defer c.writeLock.Unlock() - if c.done { - return 0, errors.New("Write to closed websocket connection") - } if err := wsutil.WriteServerBinary(c.rw, p); err != nil { return 0, err } - return len(p), nil } func (c *Conn) pinger(ctx context.Context) { + defer close(c.shutdownC) pongMessge := wsutil.Message{ OpCode: gobwas.OpPong, Payload: []byte{}, @@ -150,12 +140,11 @@ func (c *Conn) pinger(ctx context.Context) { defer ticker.Stop() for { select { - // Ping/Pong messages will not be written after the connection is done case <-ticker.C: - if err := wsutil.WriteServerMessage(c, gobwas.OpPing, []byte{}); err != nil { + if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil { c.log.Debug().Err(err).Msgf("failed to write ping message") } - if err := wsutil.HandleClientControlMessage(c, pongMessge); err != nil { + if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil { c.log.Debug().Err(err).Msgf("failed to write pong message") } case <-ctx.Done(): @@ -173,9 +162,7 @@ func (c *Conn) pingPeriod(ctx context.Context) time.Duration { return defaultPingPeriod } -// Close waits for the current write to finish. Further writes will return error -func (c *Conn) Close() { - c.writeLock.Lock() - defer c.writeLock.Unlock() - c.done = true +// Close waits for pinger to terminate +func (c *Conn) WaitForShutdown() { + <-c.shutdownC }