TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown
This commit is contained in:
parent
2ca4633f89
commit
f8fbbcd806
|
@ -4,16 +4,17 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -50,7 +51,15 @@ func (moc *mockOriginProxy) ProxyHTTP(
|
|||
isWebsocket bool,
|
||||
) error {
|
||||
if isWebsocket {
|
||||
return wsEndpoint(w, req)
|
||||
switch req.URL.Path {
|
||||
case "/ws/echo":
|
||||
return wsEchoEndpoint(w, req)
|
||||
case "/ws/flaky":
|
||||
return wsFlakyEndpoint(w, req)
|
||||
default:
|
||||
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
|
||||
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
|
||||
}
|
||||
}
|
||||
switch req.URL.Path {
|
||||
case "/ok":
|
||||
|
@ -78,32 +87,82 @@ func (moc *mockOriginProxy) ProxyTCP(
|
|||
return nil
|
||||
}
|
||||
|
||||
type nowriter struct {
|
||||
io.Reader
|
||||
type echoPipe struct {
|
||||
reader *io.PipeReader
|
||||
writer *io.PipeWriter
|
||||
}
|
||||
|
||||
func (nowriter) Write(p []byte) (int, error) {
|
||||
return 0, fmt.Errorf("Writer not implemented")
|
||||
func (ep *echoPipe) Read(p []byte) (int, error) {
|
||||
return ep.reader.Read(p)
|
||||
}
|
||||
|
||||
func wsEndpoint(w ResponseWriter, r *http.Request) error {
|
||||
func (ep *echoPipe) Write(p []byte) (int, error) {
|
||||
return ep.writer.Write(p)
|
||||
}
|
||||
|
||||
// A mock origin that echos data by streaming like a tcpOverWSConnection
|
||||
// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
|
||||
func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
}
|
||||
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
||||
clientReader := nowriter{r.Body}
|
||||
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
|
||||
return err
|
||||
}
|
||||
wsCtx, cancel := context.WithCancel(r.Context())
|
||||
readPipe, writePipe := io.Pipe()
|
||||
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
|
||||
go func() {
|
||||
for {
|
||||
data, err := wsutil.ReadClientText(clientReader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := wsutil.WriteServerText(w, data); err != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-wsCtx.Done():
|
||||
case <-r.Context().Done():
|
||||
}
|
||||
readPipe.Close()
|
||||
writePipe.Close()
|
||||
}()
|
||||
<-r.Context().Done()
|
||||
|
||||
originConn := &echoPipe{reader: readPipe, writer: writePipe}
|
||||
websocket.Stream(wsConn, originConn, &log)
|
||||
cancel()
|
||||
wsConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type flakyConn struct {
|
||||
closeAt time.Time
|
||||
}
|
||||
|
||||
func (fc *flakyConn) Read(p []byte) (int, error) {
|
||||
if time.Now().After(fc.closeAt) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, []byte("Read from flaky connection"))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (fc *flakyConn) Write(p []byte) (int, error) {
|
||||
if time.Now().After(fc.closeAt) {
|
||||
return 0, fmt.Errorf("Flaky connection closed")
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
}
|
||||
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
|
||||
return err
|
||||
}
|
||||
wsCtx, cancel := context.WithCancel(r.Context())
|
||||
|
||||
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
|
||||
|
||||
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
|
||||
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
|
||||
websocket.Stream(wsConn, originConn, &log)
|
||||
cancel()
|
||||
wsConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) {
|
|||
headers := []h2mux.Header{
|
||||
{
|
||||
Name: ":path",
|
||||
Value: "/ws",
|
||||
Value: "/ws/echo",
|
||||
},
|
||||
{
|
||||
Name: "connection",
|
||||
|
@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) {
|
|||
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
||||
|
||||
data := []byte("test websocket")
|
||||
err = wsutil.WriteClientText(writePipe, data)
|
||||
err = wsutil.WriteClientBinary(writePipe, data)
|
||||
require.NoError(t, err)
|
||||
|
||||
respBody, err := wsutil.ReadServerText(stream)
|
||||
respBody, err := wsutil.ReadServerBinary(stream)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ var (
|
|||
)
|
||||
|
||||
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
||||
edgeConn, originConn := net.Pipe()
|
||||
edgeConn, cfdConn := net.Pipe()
|
||||
var connIndex = uint8(0)
|
||||
log := zerolog.Nop()
|
||||
obs := NewObserver(&log, &log, false)
|
||||
|
@ -41,7 +41,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
|||
1*time.Second,
|
||||
)
|
||||
return NewHTTP2Connection(
|
||||
originConn,
|
||||
cfdConn,
|
||||
// OriginProxy is set in testConfig
|
||||
testConfig,
|
||||
&pogs.ConnectionOptions{},
|
||||
obs,
|
||||
|
@ -166,6 +167,7 @@ type wsRespWriter struct {
|
|||
*httptest.ResponseRecorder
|
||||
readPipe *io.PipeReader
|
||||
writePipe *io.PipeWriter
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newWSRespWriter() *wsRespWriter {
|
||||
|
@ -174,46 +176,58 @@ func newWSRespWriter() *wsRespWriter {
|
|||
httptest.NewRecorder(),
|
||||
readPipe,
|
||||
writePipe,
|
||||
false,
|
||||
}
|
||||
}
|
||||
|
||||
type nowriter struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (nowriter) Write(p []byte) (int, error) {
|
||||
return 0, fmt.Errorf("Writer not implemented")
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) RespBody() io.ReadWriter {
|
||||
return nowriter{w.readPipe}
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
|
||||
if w.closed {
|
||||
// Simulate writing to http2 ResponseWriter after ServeHTTP has returned
|
||||
panic("Write to closed ResponseWriter")
|
||||
}
|
||||
return w.writePipe.Write(data)
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) close() {
|
||||
w.closed = true
|
||||
}
|
||||
|
||||
func TestServeWS(t *testing.T) {
|
||||
http2Conn, _ := newTestHTTP2Connection()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
respWriter := newWSRespWriter()
|
||||
readPipe, writePipe := io.Pipe()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||
|
||||
wg.Add(1)
|
||||
serveDone := make(chan struct{})
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer close(serveDone)
|
||||
http2Conn.ServeHTTP(respWriter, req)
|
||||
respWriter.close()
|
||||
}()
|
||||
|
||||
data := []byte("test websocket")
|
||||
err = wsutil.WriteClientText(writePipe, data)
|
||||
err = wsutil.WriteClientBinary(writePipe, data)
|
||||
require.NoError(t, err)
|
||||
|
||||
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
|
||||
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||
|
||||
|
@ -223,7 +237,64 @@ func TestServeWS(t *testing.T) {
|
|||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||
|
||||
<-serveDone
|
||||
}
|
||||
|
||||
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
|
||||
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
|
||||
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
|
||||
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
cfdHTTP2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
edgeTransport := http2.Transport{}
|
||||
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
|
||||
require.NoError(t, err)
|
||||
message := []byte(t.Name())
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
readPipe, writePipe := io.Pipe()
|
||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||
|
||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
// http2RespWriter should rewrite status 101 to 200
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-reqCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
_ = wsutil.WriteClientBinary(writePipe, message)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
reqCancel()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
cancel()
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func TestServeControlStream(t *testing.T) {
|
||||
|
|
|
@ -61,7 +61,7 @@ func TestQUICServer(t *testing.T) {
|
|||
|
||||
// This is simply a sample websocket frame message.
|
||||
wsBuf := &bytes.Buffer{}
|
||||
wsutil.WriteClientText(wsBuf, []byte("Hello"))
|
||||
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
|
||||
|
||||
var tests = []struct {
|
||||
desc string
|
||||
|
@ -118,7 +118,7 @@ func TestQUICServer(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "test ws proxy",
|
||||
dest: "/ok",
|
||||
dest: "/ws/echo",
|
||||
connectionType: quicpogs.ConnectionTypeWebsocket,
|
||||
metadata: []quicpogs.Metadata{
|
||||
quicpogs.Metadata{
|
||||
|
@ -139,7 +139,7 @@ func TestQUICServer(t *testing.T) {
|
|||
},
|
||||
},
|
||||
message: wsBuf.Bytes(),
|
||||
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
|
||||
expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
|
||||
},
|
||||
{
|
||||
desc: "test tcp proxy",
|
||||
|
@ -278,7 +278,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
|
|||
}
|
||||
|
||||
if isWebsocket {
|
||||
return wsEndpoint(w, r)
|
||||
return wsEchoEndpoint(w, r)
|
||||
}
|
||||
switch r.URL.Path {
|
||||
case "/ok":
|
||||
|
|
|
@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri
|
|||
wc.streamHandler(wsConn, wc.conn, log)
|
||||
cancel()
|
||||
// Makes sure wsConn stops sending ping before terminating the stream
|
||||
wsConn.WaitForShutdown()
|
||||
wsConn.Close()
|
||||
}
|
||||
|
||||
func (wc *tcpOverWSConnection) Close() {
|
||||
|
@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
|
|||
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
|
||||
cancel()
|
||||
// Makes sure wsConn stops sending ping before terminating the stream
|
||||
wsConn.WaitForShutdown()
|
||||
wsConn.Close()
|
||||
}
|
||||
|
||||
func (sp *socksProxyOverWSConnection) Close() {
|
||||
|
|
|
@ -3,8 +3,10 @@ package websocket
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gobwas "github.com/gobwas/ws"
|
||||
|
@ -98,15 +100,17 @@ func (c *GorillaConn) pinger(ctx context.Context) {
|
|||
type Conn struct {
|
||||
rw io.ReadWriter
|
||||
log *zerolog.Logger
|
||||
// closed is a channel to indicate if Conn has been fully terminated
|
||||
shutdownC chan struct{}
|
||||
// writeLock makes sure
|
||||
// 1. Only one write at a time. The pinger and Stream function can both call write.
|
||||
// 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close.
|
||||
writeLock sync.Mutex
|
||||
done bool
|
||||
}
|
||||
|
||||
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
||||
c := &Conn{
|
||||
rw: rw,
|
||||
log: log,
|
||||
shutdownC: make(chan struct{}),
|
||||
rw: rw,
|
||||
log: log,
|
||||
}
|
||||
go c.pinger(ctx)
|
||||
return c
|
||||
|
@ -121,16 +125,22 @@ func (c *Conn) Read(reader []byte) (int, error) {
|
|||
return copy(reader, data), nil
|
||||
}
|
||||
|
||||
// Write will write messages to the websocket connection
|
||||
// Write will write messages to the websocket connection.
|
||||
// It will not write to the connection after Close is called to fix TUN-5184
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
c.writeLock.Lock()
|
||||
defer c.writeLock.Unlock()
|
||||
if c.done {
|
||||
return 0, errors.New("Write to closed websocket connection")
|
||||
}
|
||||
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *Conn) pinger(ctx context.Context) {
|
||||
defer close(c.shutdownC)
|
||||
pongMessge := wsutil.Message{
|
||||
OpCode: gobwas.OpPong,
|
||||
Payload: []byte{},
|
||||
|
@ -140,11 +150,12 @@ func (c *Conn) pinger(ctx context.Context) {
|
|||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
// Ping/Pong messages will not be written after the connection is done
|
||||
case <-ticker.C:
|
||||
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
|
||||
if err := wsutil.WriteServerMessage(c, gobwas.OpPing, []byte{}); err != nil {
|
||||
c.log.Debug().Err(err).Msgf("failed to write ping message")
|
||||
}
|
||||
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
||||
if err := wsutil.HandleClientControlMessage(c, pongMessge); err != nil {
|
||||
c.log.Debug().Err(err).Msgf("failed to write pong message")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
|
@ -162,7 +173,9 @@ func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
|
|||
return defaultPingPeriod
|
||||
}
|
||||
|
||||
// Close waits for pinger to terminate
|
||||
func (c *Conn) WaitForShutdown() {
|
||||
<-c.shutdownC
|
||||
// Close waits for the current write to finish. Further writes will return error
|
||||
func (c *Conn) Close() {
|
||||
c.writeLock.Lock()
|
||||
defer c.writeLock.Unlock()
|
||||
c.done = true
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue