TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown

This commit is contained in:
cthuang 2021-10-19 20:01:17 +01:00 committed by Chung Ting Huang
parent 2ca4633f89
commit f8fbbcd806
6 changed files with 195 additions and 52 deletions

View File

@ -4,16 +4,17 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"math/rand"
"net/http" "net/http"
"net/url" "net/url"
"testing" "testing"
"time" "time"
"github.com/gobwas/ws/wsutil"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/websocket"
) )
const ( const (
@ -50,7 +51,15 @@ func (moc *mockOriginProxy) ProxyHTTP(
isWebsocket bool, isWebsocket bool,
) error { ) error {
if isWebsocket { if isWebsocket {
return wsEndpoint(w, req) switch req.URL.Path {
case "/ws/echo":
return wsEchoEndpoint(w, req)
case "/ws/flaky":
return wsFlakyEndpoint(w, req)
default:
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
}
} }
switch req.URL.Path { switch req.URL.Path {
case "/ok": case "/ok":
@ -78,32 +87,82 @@ func (moc *mockOriginProxy) ProxyTCP(
return nil return nil
} }
type nowriter struct { type echoPipe struct {
io.Reader reader *io.PipeReader
writer *io.PipeWriter
} }
func (nowriter) Write(p []byte) (int, error) { func (ep *echoPipe) Read(p []byte) (int, error) {
return 0, fmt.Errorf("Writer not implemented") return ep.reader.Read(p)
} }
func wsEndpoint(w ResponseWriter, r *http.Request) error { func (ep *echoPipe) Write(p []byte) (int, error) {
return ep.writer.Write(p)
}
// A mock origin that echos data by streaming like a tcpOverWSConnection
// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
resp := &http.Response{ resp := &http.Response{
StatusCode: http.StatusSwitchingProtocols, StatusCode: http.StatusSwitchingProtocols,
} }
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header) if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
clientReader := nowriter{r.Body} return err
}
wsCtx, cancel := context.WithCancel(r.Context())
readPipe, writePipe := io.Pipe()
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
go func() { go func() {
for { select {
data, err := wsutil.ReadClientText(clientReader) case <-wsCtx.Done():
if err != nil { case <-r.Context().Done():
return
}
if err := wsutil.WriteServerText(w, data); err != nil {
return
}
} }
readPipe.Close()
writePipe.Close()
}() }()
<-r.Context().Done()
originConn := &echoPipe{reader: readPipe, writer: writePipe}
websocket.Stream(wsConn, originConn, &log)
cancel()
wsConn.Close()
return nil
}
type flakyConn struct {
closeAt time.Time
}
func (fc *flakyConn) Read(p []byte) (int, error) {
if time.Now().After(fc.closeAt) {
return 0, io.EOF
}
n := copy(p, []byte("Read from flaky connection"))
return n, nil
}
func (fc *flakyConn) Write(p []byte) (int, error) {
if time.Now().After(fc.closeAt) {
return 0, fmt.Errorf("Flaky connection closed")
}
return len(p), nil
}
func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
resp := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
return err
}
wsCtx, cancel := context.WithCancel(r.Context())
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
websocket.Stream(wsConn, originConn, &log)
cancel()
wsConn.Close()
return nil return nil
} }

View File

@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) {
headers := []h2mux.Header{ headers := []h2mux.Header{
{ {
Name: ":path", Name: ":path",
Value: "/ws", Value: "/ws/echo",
}, },
{ {
Name: "connection", Name: "connection",
@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) {
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin)) assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
data := []byte("test websocket") data := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, data) err = wsutil.WriteClientBinary(writePipe, data)
require.NoError(t, err) require.NoError(t, err)
respBody, err := wsutil.ReadServerText(stream) respBody, err := wsutil.ReadServerBinary(stream)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))

View File

@ -27,7 +27,7 @@ var (
) )
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
edgeConn, originConn := net.Pipe() edgeConn, cfdConn := net.Pipe()
var connIndex = uint8(0) var connIndex = uint8(0)
log := zerolog.Nop() log := zerolog.Nop()
obs := NewObserver(&log, &log, false) obs := NewObserver(&log, &log, false)
@ -41,7 +41,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
1*time.Second, 1*time.Second,
) )
return NewHTTP2Connection( return NewHTTP2Connection(
originConn, cfdConn,
// OriginProxy is set in testConfig
testConfig, testConfig,
&pogs.ConnectionOptions{}, &pogs.ConnectionOptions{},
obs, obs,
@ -166,6 +167,7 @@ type wsRespWriter struct {
*httptest.ResponseRecorder *httptest.ResponseRecorder
readPipe *io.PipeReader readPipe *io.PipeReader
writePipe *io.PipeWriter writePipe *io.PipeWriter
closed bool
} }
func newWSRespWriter() *wsRespWriter { func newWSRespWriter() *wsRespWriter {
@ -174,46 +176,58 @@ func newWSRespWriter() *wsRespWriter {
httptest.NewRecorder(), httptest.NewRecorder(),
readPipe, readPipe,
writePipe, writePipe,
false,
} }
} }
type nowriter struct {
io.Reader
}
func (nowriter) Write(p []byte) (int, error) {
return 0, fmt.Errorf("Writer not implemented")
}
func (w *wsRespWriter) RespBody() io.ReadWriter { func (w *wsRespWriter) RespBody() io.ReadWriter {
return nowriter{w.readPipe} return nowriter{w.readPipe}
} }
func (w *wsRespWriter) Write(data []byte) (n int, err error) { func (w *wsRespWriter) Write(data []byte) (n int, err error) {
if w.closed {
// Simulate writing to http2 ResponseWriter after ServeHTTP has returned
panic("Write to closed ResponseWriter")
}
return w.writePipe.Write(data) return w.writePipe.Write(data)
} }
func (w *wsRespWriter) close() {
w.closed = true
}
func TestServeWS(t *testing.T) { func TestServeWS(t *testing.T) {
http2Conn, _ := newTestHTTP2Connection() http2Conn, _ := newTestHTTP2Connection()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
}()
respWriter := newWSRespWriter() respWriter := newWSRespWriter()
readPipe, writePipe := io.Pipe() readPipe, writePipe := io.Pipe()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
require.NoError(t, err) require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
wg.Add(1) serveDone := make(chan struct{})
go func() { go func() {
defer wg.Done() defer close(serveDone)
http2Conn.ServeHTTP(respWriter, req) http2Conn.ServeHTTP(respWriter, req)
respWriter.close()
}() }()
data := []byte("test websocket") data := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, data) err = wsutil.WriteClientBinary(writePipe, data)
require.NoError(t, err) require.NoError(t, err)
respBody, err := wsutil.ReadServerText(respWriter.RespBody()) respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
@ -223,7 +237,64 @@ func TestServeWS(t *testing.T) {
require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader)) require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
<-serveDone
}
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
cfdHTTP2Conn.Serve(ctx)
}()
edgeTransport := http2.Transport{}
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
require.NoError(t, err)
message := []byte(t.Name())
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
readPipe, writePipe := io.Pipe()
reqCtx, reqCancel := context.WithCancel(ctx)
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-reqCtx.Done():
return
default:
}
_ = wsutil.WriteClientBinary(writePipe, message)
}
}()
time.Sleep(time.Millisecond * 100)
reqCancel()
}()
}
wg.Wait() wg.Wait()
cancel()
<-serverDone
} }
func TestServeControlStream(t *testing.T) { func TestServeControlStream(t *testing.T) {

View File

@ -61,7 +61,7 @@ func TestQUICServer(t *testing.T) {
// This is simply a sample websocket frame message. // This is simply a sample websocket frame message.
wsBuf := &bytes.Buffer{} wsBuf := &bytes.Buffer{}
wsutil.WriteClientText(wsBuf, []byte("Hello")) wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
var tests = []struct { var tests = []struct {
desc string desc string
@ -118,7 +118,7 @@ func TestQUICServer(t *testing.T) {
}, },
{ {
desc: "test ws proxy", desc: "test ws proxy",
dest: "/ok", dest: "/ws/echo",
connectionType: quicpogs.ConnectionTypeWebsocket, connectionType: quicpogs.ConnectionTypeWebsocket,
metadata: []quicpogs.Metadata{ metadata: []quicpogs.Metadata{
quicpogs.Metadata{ quicpogs.Metadata{
@ -139,7 +139,7 @@ func TestQUICServer(t *testing.T) {
}, },
}, },
message: wsBuf.Bytes(), message: wsBuf.Bytes(),
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
}, },
{ {
desc: "test tcp proxy", desc: "test tcp proxy",
@ -278,7 +278,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
} }
if isWebsocket { if isWebsocket {
return wsEndpoint(w, r) return wsEchoEndpoint(w, r)
} }
switch r.URL.Path { switch r.URL.Path {
case "/ok": case "/ok":

View File

@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri
wc.streamHandler(wsConn, wc.conn, log) wc.streamHandler(wsConn, wc.conn, log)
cancel() cancel()
// Makes sure wsConn stops sending ping before terminating the stream // Makes sure wsConn stops sending ping before terminating the stream
wsConn.WaitForShutdown() wsConn.Close()
} }
func (wc *tcpOverWSConnection) Close() { func (wc *tcpOverWSConnection) Close() {
@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
socks.StreamNetHandler(wsConn, sp.accessPolicy, log) socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
cancel() cancel()
// Makes sure wsConn stops sending ping before terminating the stream // Makes sure wsConn stops sending ping before terminating the stream
wsConn.WaitForShutdown() wsConn.Close()
} }
func (sp *socksProxyOverWSConnection) Close() { func (sp *socksProxyOverWSConnection) Close() {

View File

@ -3,8 +3,10 @@ package websocket
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"sync"
"time" "time"
gobwas "github.com/gobwas/ws" gobwas "github.com/gobwas/ws"
@ -98,15 +100,17 @@ func (c *GorillaConn) pinger(ctx context.Context) {
type Conn struct { type Conn struct {
rw io.ReadWriter rw io.ReadWriter
log *zerolog.Logger log *zerolog.Logger
// closed is a channel to indicate if Conn has been fully terminated // writeLock makes sure
shutdownC chan struct{} // 1. Only one write at a time. The pinger and Stream function can both call write.
// 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close.
writeLock sync.Mutex
done bool
} }
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn { func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
c := &Conn{ c := &Conn{
rw: rw, rw: rw,
log: log, log: log,
shutdownC: make(chan struct{}),
} }
go c.pinger(ctx) go c.pinger(ctx)
return c return c
@ -121,16 +125,22 @@ func (c *Conn) Read(reader []byte) (int, error) {
return copy(reader, data), nil return copy(reader, data), nil
} }
// Write will write messages to the websocket connection // Write will write messages to the websocket connection.
// It will not write to the connection after Close is called to fix TUN-5184
func (c *Conn) Write(p []byte) (int, error) { func (c *Conn) Write(p []byte) (int, error) {
c.writeLock.Lock()
defer c.writeLock.Unlock()
if c.done {
return 0, errors.New("Write to closed websocket connection")
}
if err := wsutil.WriteServerBinary(c.rw, p); err != nil { if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
return 0, err return 0, err
} }
return len(p), nil return len(p), nil
} }
func (c *Conn) pinger(ctx context.Context) { func (c *Conn) pinger(ctx context.Context) {
defer close(c.shutdownC)
pongMessge := wsutil.Message{ pongMessge := wsutil.Message{
OpCode: gobwas.OpPong, OpCode: gobwas.OpPong,
Payload: []byte{}, Payload: []byte{},
@ -140,11 +150,12 @@ func (c *Conn) pinger(ctx context.Context) {
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
// Ping/Pong messages will not be written after the connection is done
case <-ticker.C: case <-ticker.C:
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil { if err := wsutil.WriteServerMessage(c, gobwas.OpPing, []byte{}); err != nil {
c.log.Debug().Err(err).Msgf("failed to write ping message") c.log.Debug().Err(err).Msgf("failed to write ping message")
} }
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil { if err := wsutil.HandleClientControlMessage(c, pongMessge); err != nil {
c.log.Debug().Err(err).Msgf("failed to write pong message") c.log.Debug().Err(err).Msgf("failed to write pong message")
} }
case <-ctx.Done(): case <-ctx.Done():
@ -162,7 +173,9 @@ func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
return defaultPingPeriod return defaultPingPeriod
} }
// Close waits for pinger to terminate // Close waits for the current write to finish. Further writes will return error
func (c *Conn) WaitForShutdown() { func (c *Conn) Close() {
<-c.shutdownC c.writeLock.Lock()
defer c.writeLock.Unlock()
c.done = true
} }