diff --git a/connection/connection_test.go b/connection/connection_test.go index e8e477ea..325df5d0 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -4,16 +4,17 @@ import ( "context" "fmt" "io" + "math/rand" "net/http" "net/url" "testing" "time" - "github.com/gobwas/ws/wsutil" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/websocket" ) const ( @@ -50,7 +51,15 @@ func (moc *mockOriginProxy) ProxyHTTP( isWebsocket bool, ) error { if isWebsocket { - return wsEndpoint(w, req) + switch req.URL.Path { + case "/ws/echo": + return wsEchoEndpoint(w, req) + case "/ws/flaky": + return wsFlakyEndpoint(w, req) + default: + originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found")) + return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path) + } } switch req.URL.Path { case "/ok": @@ -78,32 +87,82 @@ func (moc *mockOriginProxy) ProxyTCP( return nil } -type nowriter struct { - io.Reader +type echoPipe struct { + reader *io.PipeReader + writer *io.PipeWriter } -func (nowriter) Write(p []byte) (int, error) { - return 0, fmt.Errorf("Writer not implemented") +func (ep *echoPipe) Read(p []byte) (int, error) { + return ep.reader.Read(p) } -func wsEndpoint(w ResponseWriter, r *http.Request) error { +func (ep *echoPipe) Write(p []byte) (int, error) { + return ep.writer.Write(p) +} + +// A mock origin that echos data by streaming like a tcpOverWSConnection +// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go +func wsEchoEndpoint(w ResponseWriter, r *http.Request) error { resp := &http.Response{ StatusCode: http.StatusSwitchingProtocols, } - _ = w.WriteRespHeaders(resp.StatusCode, resp.Header) - clientReader := nowriter{r.Body} + if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { + return err + } + wsCtx, cancel := context.WithCancel(r.Context()) + readPipe, writePipe := io.Pipe() + wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) go func() { - for { - data, err := wsutil.ReadClientText(clientReader) - if err != nil { - return - } - if err := wsutil.WriteServerText(w, data); err != nil { - return - } + select { + case <-wsCtx.Done(): + case <-r.Context().Done(): } + readPipe.Close() + writePipe.Close() }() - <-r.Context().Done() + + originConn := &echoPipe{reader: readPipe, writer: writePipe} + websocket.Stream(wsConn, originConn, &log) + cancel() + wsConn.Close() + return nil +} + +type flakyConn struct { + closeAt time.Time +} + +func (fc *flakyConn) Read(p []byte) (int, error) { + if time.Now().After(fc.closeAt) { + return 0, io.EOF + } + n := copy(p, []byte("Read from flaky connection")) + return n, nil +} + +func (fc *flakyConn) Write(p []byte) (int, error) { + if time.Now().After(fc.closeAt) { + return 0, fmt.Errorf("Flaky connection closed") + } + return len(p), nil +} + +func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error { + resp := &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + } + if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { + return err + } + wsCtx, cancel := context.WithCancel(r.Context()) + + wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log) + + closedAfter := time.Millisecond * time.Duration(rand.Intn(50)) + originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)} + websocket.Stream(wsConn, originConn, &log) + cancel() + wsConn.Close() return nil } diff --git a/connection/h2mux_test.go b/connection/h2mux_test.go index e6eab072..83a39589 100644 --- a/connection/h2mux_test.go +++ b/connection/h2mux_test.go @@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) { headers := []h2mux.Header{ { Name: ":path", - Value: "/ws", + Value: "/ws/echo", }, { Name: "connection", @@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) { assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin)) data := []byte("test websocket") - err = wsutil.WriteClientText(writePipe, data) + err = wsutil.WriteClientBinary(writePipe, data) require.NoError(t, err) - respBody, err := wsutil.ReadServerText(stream) + respBody, err := wsutil.ReadServerBinary(stream) require.NoError(t, err) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) diff --git a/connection/http2_test.go b/connection/http2_test.go index 4b7435bd..23555837 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -27,7 +27,7 @@ var ( ) func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { - edgeConn, originConn := net.Pipe() + edgeConn, cfdConn := net.Pipe() var connIndex = uint8(0) log := zerolog.Nop() obs := NewObserver(&log, &log, false) @@ -41,7 +41,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { 1*time.Second, ) return NewHTTP2Connection( - originConn, + cfdConn, + // OriginProxy is set in testConfig testConfig, &pogs.ConnectionOptions{}, obs, @@ -166,6 +167,7 @@ type wsRespWriter struct { *httptest.ResponseRecorder readPipe *io.PipeReader writePipe *io.PipeWriter + closed bool } func newWSRespWriter() *wsRespWriter { @@ -174,46 +176,58 @@ func newWSRespWriter() *wsRespWriter { httptest.NewRecorder(), readPipe, writePipe, + false, } } +type nowriter struct { + io.Reader +} + +func (nowriter) Write(p []byte) (int, error) { + return 0, fmt.Errorf("Writer not implemented") +} + func (w *wsRespWriter) RespBody() io.ReadWriter { return nowriter{w.readPipe} } func (w *wsRespWriter) Write(data []byte) (n int, err error) { + if w.closed { + // Simulate writing to http2 ResponseWriter after ServeHTTP has returned + panic("Write to closed ResponseWriter") + } return w.writePipe.Write(data) } +func (w *wsRespWriter) close() { + w.closed = true +} + func TestServeWS(t *testing.T) { http2Conn, _ := newTestHTTP2Connection() ctx, cancel := context.WithCancel(context.Background()) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - http2Conn.Serve(ctx) - }() respWriter := newWSRespWriter() readPipe, writePipe := io.Pipe() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe) require.NoError(t, err) req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) - wg.Add(1) + serveDone := make(chan struct{}) go func() { - defer wg.Done() + defer close(serveDone) http2Conn.ServeHTTP(respWriter, req) + respWriter.close() }() data := []byte("test websocket") - err = wsutil.WriteClientText(writePipe, data) + err = wsutil.WriteClientBinary(writePipe, data) require.NoError(t, err) - respBody, err := wsutil.ReadServerText(respWriter.RespBody()) + respBody, err := wsutil.ReadServerBinary(respWriter.RespBody()) require.NoError(t, err) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) @@ -223,7 +237,64 @@ func TestServeWS(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader)) + <-serveDone +} + +// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184 +// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns +func TestNoWriteAfterServeHTTPReturns(t *testing.T) { + cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection() + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + cfdHTTP2Conn.Serve(ctx) + }() + + edgeTransport := http2.Transport{} + edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn) + require.NoError(t, err) + message := []byte(t.Name()) + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + readPipe, writePipe := io.Pipe() + reqCtx, reqCancel := context.WithCancel(ctx) + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe) + require.NoError(t, err) + req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) + + resp, err := edgeHTTP2Conn.RoundTrip(req) + require.NoError(t, err) + // http2RespWriter should rewrite status 101 to 200 + require.Equal(t, http.StatusOK, resp.StatusCode) + + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-reqCtx.Done(): + return + default: + } + _ = wsutil.WriteClientBinary(writePipe, message) + } + }() + + time.Sleep(time.Millisecond * 100) + reqCancel() + }() + } + wg.Wait() + cancel() + <-serverDone } func TestServeControlStream(t *testing.T) { diff --git a/connection/quic_test.go b/connection/quic_test.go index fc551a4b..d5a11ec0 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -61,7 +61,7 @@ func TestQUICServer(t *testing.T) { // This is simply a sample websocket frame message. wsBuf := &bytes.Buffer{} - wsutil.WriteClientText(wsBuf, []byte("Hello")) + wsutil.WriteClientBinary(wsBuf, []byte("Hello")) var tests = []struct { desc string @@ -118,7 +118,7 @@ func TestQUICServer(t *testing.T) { }, { desc: "test ws proxy", - dest: "/ok", + dest: "/ws/echo", connectionType: quicpogs.ConnectionTypeWebsocket, metadata: []quicpogs.Metadata{ quicpogs.Metadata{ @@ -139,7 +139,7 @@ func TestQUICServer(t *testing.T) { }, }, message: wsBuf.Bytes(), - expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, + expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, }, { desc: "test tcp proxy", @@ -278,7 +278,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque } if isWebsocket { - return wsEndpoint(w, r) + return wsEchoEndpoint(w, r) } switch r.URL.Path { case "/ok": diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index 9588ce36..bdfb19b6 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri wc.streamHandler(wsConn, wc.conn, log) cancel() // Makes sure wsConn stops sending ping before terminating the stream - wsConn.WaitForShutdown() + wsConn.Close() } func (wc *tcpOverWSConnection) Close() { @@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io. socks.StreamNetHandler(wsConn, sp.accessPolicy, log) cancel() // Makes sure wsConn stops sending ping before terminating the stream - wsConn.WaitForShutdown() + wsConn.Close() } func (sp *socksProxyOverWSConnection) Close() { diff --git a/websocket/connection.go b/websocket/connection.go index 79665902..9d64fc64 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -3,8 +3,10 @@ package websocket import ( "bytes" "context" + "errors" "fmt" "io" + "sync" "time" gobwas "github.com/gobwas/ws" @@ -98,15 +100,17 @@ func (c *GorillaConn) pinger(ctx context.Context) { type Conn struct { rw io.ReadWriter log *zerolog.Logger - // closed is a channel to indicate if Conn has been fully terminated - shutdownC chan struct{} + // writeLock makes sure + // 1. Only one write at a time. The pinger and Stream function can both call write. + // 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close. + writeLock sync.Mutex + done bool } func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn { c := &Conn{ - rw: rw, - log: log, - shutdownC: make(chan struct{}), + rw: rw, + log: log, } go c.pinger(ctx) return c @@ -121,16 +125,22 @@ func (c *Conn) Read(reader []byte) (int, error) { return copy(reader, data), nil } -// Write will write messages to the websocket connection +// Write will write messages to the websocket connection. +// It will not write to the connection after Close is called to fix TUN-5184 func (c *Conn) Write(p []byte) (int, error) { + c.writeLock.Lock() + defer c.writeLock.Unlock() + if c.done { + return 0, errors.New("Write to closed websocket connection") + } if err := wsutil.WriteServerBinary(c.rw, p); err != nil { return 0, err } + return len(p), nil } func (c *Conn) pinger(ctx context.Context) { - defer close(c.shutdownC) pongMessge := wsutil.Message{ OpCode: gobwas.OpPong, Payload: []byte{}, @@ -140,11 +150,12 @@ func (c *Conn) pinger(ctx context.Context) { defer ticker.Stop() for { select { + // Ping/Pong messages will not be written after the connection is done case <-ticker.C: - if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil { + if err := wsutil.WriteServerMessage(c, gobwas.OpPing, []byte{}); err != nil { c.log.Debug().Err(err).Msgf("failed to write ping message") } - if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil { + if err := wsutil.HandleClientControlMessage(c, pongMessge); err != nil { c.log.Debug().Err(err).Msgf("failed to write pong message") } case <-ctx.Done(): @@ -162,7 +173,9 @@ func (c *Conn) pingPeriod(ctx context.Context) time.Duration { return defaultPingPeriod } -// Close waits for pinger to terminate -func (c *Conn) WaitForShutdown() { - <-c.shutdownC +// Close waits for the current write to finish. Further writes will return error +func (c *Conn) Close() { + c.writeLock.Lock() + defer c.writeLock.Unlock() + c.done = true }