TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown

This commit is contained in:
cthuang 2021-10-19 20:01:17 +01:00 committed by Nuno Diegues
parent 1ff5fd3fdc
commit db01127191
6 changed files with 212 additions and 70 deletions

View File

@ -4,16 +4,17 @@ import (
"context"
"fmt"
"io"
"math/rand"
"net/http"
"net/url"
"testing"
"time"
"github.com/gobwas/ws/wsutil"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/websocket"
)
const (
@ -50,7 +51,15 @@ func (moc *mockOriginProxy) ProxyHTTP(
isWebsocket bool,
) error {
if isWebsocket {
return wsEndpoint(w, req)
switch req.URL.Path {
case "/ws/echo":
return wsEchoEndpoint(w, req)
case "/ws/flaky":
return wsFlakyEndpoint(w, req)
default:
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
}
}
switch req.URL.Path {
case "/ok":
@ -78,32 +87,82 @@ func (moc *mockOriginProxy) ProxyTCP(
return nil
}
type nowriter struct {
io.Reader
type echoPipe struct {
reader *io.PipeReader
writer *io.PipeWriter
}
func (nowriter) Write(p []byte) (int, error) {
return 0, fmt.Errorf("Writer not implemented")
func (ep *echoPipe) Read(p []byte) (int, error) {
return ep.reader.Read(p)
}
func wsEndpoint(w ResponseWriter, r *http.Request) error {
func (ep *echoPipe) Write(p []byte) (int, error) {
return ep.writer.Write(p)
}
// A mock origin that echos data by streaming like a tcpOverWSConnection
// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
resp := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
clientReader := nowriter{r.Body}
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
return err
}
wsCtx, cancel := context.WithCancel(r.Context())
readPipe, writePipe := io.Pipe()
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
go func() {
for {
data, err := wsutil.ReadClientText(clientReader)
if err != nil {
return
}
if err := wsutil.WriteServerText(w, data); err != nil {
return
}
select {
case <-wsCtx.Done():
case <-r.Context().Done():
}
readPipe.Close()
writePipe.Close()
}()
<-r.Context().Done()
originConn := &echoPipe{reader: readPipe, writer: writePipe}
websocket.Stream(wsConn, originConn, &log)
cancel()
wsConn.Close()
return nil
}
type flakyConn struct {
closeAt time.Time
}
func (fc *flakyConn) Read(p []byte) (int, error) {
if time.Now().After(fc.closeAt) {
return 0, io.EOF
}
n := copy(p, "Read from flaky connection")
return n, nil
}
func (fc *flakyConn) Write(p []byte) (int, error) {
if time.Now().After(fc.closeAt) {
return 0, fmt.Errorf("flaky connection closed")
}
return len(p), nil
}
func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
resp := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
return err
}
wsCtx, cancel := context.WithCancel(r.Context())
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
websocket.Stream(wsConn, originConn, &log)
cancel()
wsConn.Close()
return nil
}

View File

@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) {
headers := []h2mux.Header{
{
Name: ":path",
Value: "/ws",
Value: "/ws/echo",
},
{
Name: "connection",
@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) {
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
data := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, data)
err = wsutil.WriteClientBinary(writePipe, data)
require.NoError(t, err)
respBody, err := wsutil.ReadServerText(stream)
respBody, err := wsutil.ReadServerBinary(stream)
require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))

View File

@ -2,6 +2,7 @@ package connection
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
@ -27,7 +28,7 @@ var (
)
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
edgeConn, originConn := net.Pipe()
edgeConn, cfdConn := net.Pipe()
var connIndex = uint8(0)
log := zerolog.Nop()
obs := NewObserver(&log, &log, false)
@ -41,7 +42,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
1*time.Second,
)
return NewHTTP2Connection(
originConn,
cfdConn,
// OriginProxy is set in testConfig
testConfig,
&pogs.ConnectionOptions{},
obs,
@ -166,6 +168,8 @@ type wsRespWriter struct {
*httptest.ResponseRecorder
readPipe *io.PipeReader
writePipe *io.PipeWriter
closed bool
panicked bool
}
func newWSRespWriter() *wsRespWriter {
@ -174,46 +178,59 @@ func newWSRespWriter() *wsRespWriter {
httptest.NewRecorder(),
readPipe,
writePipe,
false,
false,
}
}
type nowriter struct {
io.Reader
}
func (nowriter) Write(_ []byte) (int, error) {
return 0, fmt.Errorf("writer not implemented")
}
func (w *wsRespWriter) RespBody() io.ReadWriter {
return nowriter{w.readPipe}
}
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
if w.closed {
w.panicked = true
return 0, errors.New("wsRespWriter panicked")
}
return w.writePipe.Write(data)
}
func (w *wsRespWriter) close() {
w.closed = true
}
func TestServeWS(t *testing.T) {
http2Conn, _ := newTestHTTP2Connection()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
}()
respWriter := newWSRespWriter()
readPipe, writePipe := io.Pipe()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
wg.Add(1)
serveDone := make(chan struct{})
go func() {
defer wg.Done()
defer close(serveDone)
http2Conn.ServeHTTP(respWriter, req)
respWriter.close()
}()
data := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, data)
err = wsutil.WriteClientBinary(writePipe, data)
require.NoError(t, err)
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
@ -223,7 +240,65 @@ func TestServeWS(t *testing.T) {
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
<-serveDone
require.False(t, respWriter.panicked)
}
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
cfdHTTP2Conn.Serve(ctx)
}()
edgeTransport := http2.Transport{}
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
require.NoError(t, err)
message := []byte(t.Name())
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
readPipe, writePipe := io.Pipe()
reqCtx, reqCancel := context.WithCancel(ctx)
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-reqCtx.Done():
return
default:
}
_ = wsutil.WriteClientBinary(writePipe, message)
}
}()
time.Sleep(time.Millisecond * 100)
reqCancel()
}()
}
wg.Wait()
cancel()
<-serverDone
}
func TestServeControlStream(t *testing.T) {

View File

@ -47,7 +47,7 @@ func TestQUICServer(t *testing.T) {
// This is simply a sample websocket frame message.
wsBuf := &bytes.Buffer{}
wsutil.WriteClientText(wsBuf, []byte("Hello"))
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
var tests = []struct {
desc string
@ -104,7 +104,7 @@ func TestQUICServer(t *testing.T) {
},
{
desc: "test ws proxy",
dest: "/ok",
dest: "/ws/echo",
connectionType: quicpogs.ConnectionTypeWebsocket,
metadata: []quicpogs.Metadata{
{
@ -125,7 +125,7 @@ func TestQUICServer(t *testing.T) {
},
},
message: wsBuf.Bytes(),
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
},
{
desc: "test tcp proxy",
@ -233,7 +233,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
}
if isWebsocket {
return wsEndpoint(w, r)
return wsEchoEndpoint(w, r)
}
switch r.URL.Path {
case "/ok":

View File

@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri
wc.streamHandler(wsConn, wc.conn, log)
cancel()
// Makes sure wsConn stops sending ping before terminating the stream
wsConn.WaitForShutdown()
wsConn.Close()
}
func (wc *tcpOverWSConnection) Close() {
@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
cancel()
// Makes sure wsConn stops sending ping before terminating the stream
wsConn.WaitForShutdown()
wsConn.Close()
}
func (sp *socksProxyOverWSConnection) Close() {

View File

@ -3,8 +3,10 @@ package websocket
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"sync"
"time"
gobwas "github.com/gobwas/ws"
@ -14,9 +16,6 @@ import (
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
defaultPongWait = 60 * time.Second
@ -79,34 +78,20 @@ func (c *GorillaConn) SetDeadline(t time.Time) error {
return nil
}
// pinger simulates the websocket connection to keep it alive
func (c *GorillaConn) pinger(ctx context.Context) {
ticker := time.NewTicker(defaultPingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
c.log.Debug().Msgf("failed to send ping message: %s", err)
}
case <-ctx.Done():
return
}
}
}
type Conn struct {
rw io.ReadWriter
log *zerolog.Logger
// closed is a channel to indicate if Conn has been fully terminated
shutdownC chan struct{}
// writeLock makes sure
// 1. Only one write at a time. The pinger and Stream function can both call write.
// 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close.
writeLock sync.Mutex
done bool
}
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
c := &Conn{
rw: rw,
log: log,
shutdownC: make(chan struct{}),
rw: rw,
log: log,
}
go c.pinger(ctx)
return c
@ -121,16 +106,22 @@ func (c *Conn) Read(reader []byte) (int, error) {
return copy(reader, data), nil
}
// Write will write messages to the websocket connection
// Write will write messages to the websocket connection.
// It will not write to the connection after Close is called to fix TUN-5184
func (c *Conn) Write(p []byte) (int, error) {
c.writeLock.Lock()
defer c.writeLock.Unlock()
if c.done {
return 0, errors.New("write to closed websocket connection")
}
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
return 0, err
}
return len(p), nil
}
func (c *Conn) pinger(ctx context.Context) {
defer close(c.shutdownC)
pongMessge := wsutil.Message{
OpCode: gobwas.OpPong,
Payload: []byte{},
@ -141,7 +132,11 @@ func (c *Conn) pinger(ctx context.Context) {
for {
select {
case <-ticker.C:
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
done, err := c.ping()
if done {
return
}
if err != nil {
c.log.Debug().Err(err).Msgf("failed to write ping message")
}
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
@ -153,6 +148,17 @@ func (c *Conn) pinger(ctx context.Context) {
}
}
func (c *Conn) ping() (bool, error) {
c.writeLock.Lock()
defer c.writeLock.Unlock()
if c.done {
return true, nil
}
return false, wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{})
}
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
if val := ctx.Value(PingPeriodContextKey); val != nil {
if period, ok := val.(time.Duration); ok {
@ -162,7 +168,9 @@ func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
return defaultPingPeriod
}
// Close waits for pinger to terminate
func (c *Conn) WaitForShutdown() {
<-c.shutdownC
// Close waits for the current write to finish. Further writes will return error
func (c *Conn) Close() {
c.writeLock.Lock()
defer c.writeLock.Unlock()
c.done = true
}