TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown
This commit is contained in:
parent
1ff5fd3fdc
commit
db01127191
|
@ -4,16 +4,17 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gobwas/ws/wsutil"
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -50,7 +51,15 @@ func (moc *mockOriginProxy) ProxyHTTP(
|
||||||
isWebsocket bool,
|
isWebsocket bool,
|
||||||
) error {
|
) error {
|
||||||
if isWebsocket {
|
if isWebsocket {
|
||||||
return wsEndpoint(w, req)
|
switch req.URL.Path {
|
||||||
|
case "/ws/echo":
|
||||||
|
return wsEchoEndpoint(w, req)
|
||||||
|
case "/ws/flaky":
|
||||||
|
return wsFlakyEndpoint(w, req)
|
||||||
|
default:
|
||||||
|
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
|
||||||
|
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
switch req.URL.Path {
|
switch req.URL.Path {
|
||||||
case "/ok":
|
case "/ok":
|
||||||
|
@ -78,32 +87,82 @@ func (moc *mockOriginProxy) ProxyTCP(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type nowriter struct {
|
type echoPipe struct {
|
||||||
io.Reader
|
reader *io.PipeReader
|
||||||
|
writer *io.PipeWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nowriter) Write(p []byte) (int, error) {
|
func (ep *echoPipe) Read(p []byte) (int, error) {
|
||||||
return 0, fmt.Errorf("Writer not implemented")
|
return ep.reader.Read(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func wsEndpoint(w ResponseWriter, r *http.Request) error {
|
func (ep *echoPipe) Write(p []byte) (int, error) {
|
||||||
|
return ep.writer.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A mock origin that echos data by streaming like a tcpOverWSConnection
|
||||||
|
// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
|
||||||
|
func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
StatusCode: http.StatusSwitchingProtocols,
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
}
|
}
|
||||||
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
|
||||||
clientReader := nowriter{r.Body}
|
return err
|
||||||
|
}
|
||||||
|
wsCtx, cancel := context.WithCancel(r.Context())
|
||||||
|
readPipe, writePipe := io.Pipe()
|
||||||
|
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
select {
|
||||||
data, err := wsutil.ReadClientText(clientReader)
|
case <-wsCtx.Done():
|
||||||
if err != nil {
|
case <-r.Context().Done():
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := wsutil.WriteServerText(w, data); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
readPipe.Close()
|
||||||
|
writePipe.Close()
|
||||||
}()
|
}()
|
||||||
<-r.Context().Done()
|
|
||||||
|
originConn := &echoPipe{reader: readPipe, writer: writePipe}
|
||||||
|
websocket.Stream(wsConn, originConn, &log)
|
||||||
|
cancel()
|
||||||
|
wsConn.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type flakyConn struct {
|
||||||
|
closeAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *flakyConn) Read(p []byte) (int, error) {
|
||||||
|
if time.Now().After(fc.closeAt) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := copy(p, "Read from flaky connection")
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *flakyConn) Write(p []byte) (int, error) {
|
||||||
|
if time.Now().After(fc.closeAt) {
|
||||||
|
return 0, fmt.Errorf("flaky connection closed")
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
|
}
|
||||||
|
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
wsCtx, cancel := context.WithCancel(r.Context())
|
||||||
|
|
||||||
|
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
|
||||||
|
|
||||||
|
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
|
||||||
|
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
|
||||||
|
websocket.Stream(wsConn, originConn, &log)
|
||||||
|
cancel()
|
||||||
|
wsConn.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) {
|
||||||
headers := []h2mux.Header{
|
headers := []h2mux.Header{
|
||||||
{
|
{
|
||||||
Name: ":path",
|
Name: ":path",
|
||||||
Value: "/ws",
|
Value: "/ws/echo",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "connection",
|
Name: "connection",
|
||||||
|
@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) {
|
||||||
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
||||||
|
|
||||||
data := []byte("test websocket")
|
data := []byte("test websocket")
|
||||||
err = wsutil.WriteClientText(writePipe, data)
|
err = wsutil.WriteClientBinary(writePipe, data)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
respBody, err := wsutil.ReadServerText(stream)
|
respBody, err := wsutil.ReadServerBinary(stream)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -27,7 +28,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
||||||
edgeConn, originConn := net.Pipe()
|
edgeConn, cfdConn := net.Pipe()
|
||||||
var connIndex = uint8(0)
|
var connIndex = uint8(0)
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
obs := NewObserver(&log, &log, false)
|
obs := NewObserver(&log, &log, false)
|
||||||
|
@ -41,7 +42,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
||||||
1*time.Second,
|
1*time.Second,
|
||||||
)
|
)
|
||||||
return NewHTTP2Connection(
|
return NewHTTP2Connection(
|
||||||
originConn,
|
cfdConn,
|
||||||
|
// OriginProxy is set in testConfig
|
||||||
testConfig,
|
testConfig,
|
||||||
&pogs.ConnectionOptions{},
|
&pogs.ConnectionOptions{},
|
||||||
obs,
|
obs,
|
||||||
|
@ -166,6 +168,8 @@ type wsRespWriter struct {
|
||||||
*httptest.ResponseRecorder
|
*httptest.ResponseRecorder
|
||||||
readPipe *io.PipeReader
|
readPipe *io.PipeReader
|
||||||
writePipe *io.PipeWriter
|
writePipe *io.PipeWriter
|
||||||
|
closed bool
|
||||||
|
panicked bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWSRespWriter() *wsRespWriter {
|
func newWSRespWriter() *wsRespWriter {
|
||||||
|
@ -174,46 +178,59 @@ func newWSRespWriter() *wsRespWriter {
|
||||||
httptest.NewRecorder(),
|
httptest.NewRecorder(),
|
||||||
readPipe,
|
readPipe,
|
||||||
writePipe,
|
writePipe,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type nowriter struct {
|
||||||
|
io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nowriter) Write(_ []byte) (int, error) {
|
||||||
|
return 0, fmt.Errorf("writer not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (w *wsRespWriter) RespBody() io.ReadWriter {
|
func (w *wsRespWriter) RespBody() io.ReadWriter {
|
||||||
return nowriter{w.readPipe}
|
return nowriter{w.readPipe}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
|
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
|
||||||
|
if w.closed {
|
||||||
|
w.panicked = true
|
||||||
|
return 0, errors.New("wsRespWriter panicked")
|
||||||
|
}
|
||||||
return w.writePipe.Write(data)
|
return w.writePipe.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *wsRespWriter) close() {
|
||||||
|
w.closed = true
|
||||||
|
}
|
||||||
|
|
||||||
func TestServeWS(t *testing.T) {
|
func TestServeWS(t *testing.T) {
|
||||||
http2Conn, _ := newTestHTTP2Connection()
|
http2Conn, _ := newTestHTTP2Connection()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
http2Conn.Serve(ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
respWriter := newWSRespWriter()
|
respWriter := newWSRespWriter()
|
||||||
readPipe, writePipe := io.Pipe()
|
readPipe, writePipe := io.Pipe()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||||
|
|
||||||
wg.Add(1)
|
serveDone := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer close(serveDone)
|
||||||
http2Conn.ServeHTTP(respWriter, req)
|
http2Conn.ServeHTTP(respWriter, req)
|
||||||
|
respWriter.close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
data := []byte("test websocket")
|
data := []byte("test websocket")
|
||||||
err = wsutil.WriteClientText(writePipe, data)
|
err = wsutil.WriteClientBinary(writePipe, data)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
|
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||||
|
|
||||||
|
@ -223,7 +240,65 @@ func TestServeWS(t *testing.T) {
|
||||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||||
|
|
||||||
|
<-serveDone
|
||||||
|
require.False(t, respWriter.panicked)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
|
||||||
|
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
|
||||||
|
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
|
||||||
|
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
serverDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(serverDone)
|
||||||
|
cfdHTTP2Conn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
edgeTransport := http2.Transport{}
|
||||||
|
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
message := []byte(t.Name())
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
readPipe, writePipe := io.Pipe()
|
||||||
|
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||||
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||||
|
|
||||||
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// http2RespWriter should rewrite status 101 to 200
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-reqCtx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
_ = wsutil.WriteClientBinary(writePipe, message)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
reqCancel()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
cancel()
|
||||||
|
<-serverDone
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServeControlStream(t *testing.T) {
|
func TestServeControlStream(t *testing.T) {
|
||||||
|
|
|
@ -47,7 +47,7 @@ func TestQUICServer(t *testing.T) {
|
||||||
|
|
||||||
// This is simply a sample websocket frame message.
|
// This is simply a sample websocket frame message.
|
||||||
wsBuf := &bytes.Buffer{}
|
wsBuf := &bytes.Buffer{}
|
||||||
wsutil.WriteClientText(wsBuf, []byte("Hello"))
|
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
|
||||||
|
|
||||||
var tests = []struct {
|
var tests = []struct {
|
||||||
desc string
|
desc string
|
||||||
|
@ -104,7 +104,7 @@ func TestQUICServer(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "test ws proxy",
|
desc: "test ws proxy",
|
||||||
dest: "/ok",
|
dest: "/ws/echo",
|
||||||
connectionType: quicpogs.ConnectionTypeWebsocket,
|
connectionType: quicpogs.ConnectionTypeWebsocket,
|
||||||
metadata: []quicpogs.Metadata{
|
metadata: []quicpogs.Metadata{
|
||||||
{
|
{
|
||||||
|
@ -125,7 +125,7 @@ func TestQUICServer(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
message: wsBuf.Bytes(),
|
message: wsBuf.Bytes(),
|
||||||
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
|
expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
desc: "test tcp proxy",
|
desc: "test tcp proxy",
|
||||||
|
@ -233,7 +233,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
|
||||||
}
|
}
|
||||||
|
|
||||||
if isWebsocket {
|
if isWebsocket {
|
||||||
return wsEndpoint(w, r)
|
return wsEchoEndpoint(w, r)
|
||||||
}
|
}
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/ok":
|
case "/ok":
|
||||||
|
|
|
@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri
|
||||||
wc.streamHandler(wsConn, wc.conn, log)
|
wc.streamHandler(wsConn, wc.conn, log)
|
||||||
cancel()
|
cancel()
|
||||||
// Makes sure wsConn stops sending ping before terminating the stream
|
// Makes sure wsConn stops sending ping before terminating the stream
|
||||||
wsConn.WaitForShutdown()
|
wsConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wc *tcpOverWSConnection) Close() {
|
func (wc *tcpOverWSConnection) Close() {
|
||||||
|
@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
|
||||||
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
|
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
|
||||||
cancel()
|
cancel()
|
||||||
// Makes sure wsConn stops sending ping before terminating the stream
|
// Makes sure wsConn stops sending ping before terminating the stream
|
||||||
wsConn.WaitForShutdown()
|
wsConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sp *socksProxyOverWSConnection) Close() {
|
func (sp *socksProxyOverWSConnection) Close() {
|
||||||
|
|
|
@ -3,8 +3,10 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
gobwas "github.com/gobwas/ws"
|
gobwas "github.com/gobwas/ws"
|
||||||
|
@ -14,9 +16,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Time allowed to write a message to the peer.
|
|
||||||
writeWait = 10 * time.Second
|
|
||||||
|
|
||||||
// Time allowed to read the next pong message from the peer.
|
// Time allowed to read the next pong message from the peer.
|
||||||
defaultPongWait = 60 * time.Second
|
defaultPongWait = 60 * time.Second
|
||||||
|
|
||||||
|
@ -79,34 +78,20 @@ func (c *GorillaConn) SetDeadline(t time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// pinger simulates the websocket connection to keep it alive
|
|
||||||
func (c *GorillaConn) pinger(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(defaultPingPeriod)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
|
|
||||||
c.log.Debug().Msgf("failed to send ping message: %s", err)
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
rw io.ReadWriter
|
rw io.ReadWriter
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
// closed is a channel to indicate if Conn has been fully terminated
|
// writeLock makes sure
|
||||||
shutdownC chan struct{}
|
// 1. Only one write at a time. The pinger and Stream function can both call write.
|
||||||
|
// 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close.
|
||||||
|
writeLock sync.Mutex
|
||||||
|
done bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
rw: rw,
|
rw: rw,
|
||||||
log: log,
|
log: log,
|
||||||
shutdownC: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
go c.pinger(ctx)
|
go c.pinger(ctx)
|
||||||
return c
|
return c
|
||||||
|
@ -121,16 +106,22 @@ func (c *Conn) Read(reader []byte) (int, error) {
|
||||||
return copy(reader, data), nil
|
return copy(reader, data), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write will write messages to the websocket connection
|
// Write will write messages to the websocket connection.
|
||||||
|
// It will not write to the connection after Close is called to fix TUN-5184
|
||||||
func (c *Conn) Write(p []byte) (int, error) {
|
func (c *Conn) Write(p []byte) (int, error) {
|
||||||
|
c.writeLock.Lock()
|
||||||
|
defer c.writeLock.Unlock()
|
||||||
|
if c.done {
|
||||||
|
return 0, errors.New("write to closed websocket connection")
|
||||||
|
}
|
||||||
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) pinger(ctx context.Context) {
|
func (c *Conn) pinger(ctx context.Context) {
|
||||||
defer close(c.shutdownC)
|
|
||||||
pongMessge := wsutil.Message{
|
pongMessge := wsutil.Message{
|
||||||
OpCode: gobwas.OpPong,
|
OpCode: gobwas.OpPong,
|
||||||
Payload: []byte{},
|
Payload: []byte{},
|
||||||
|
@ -141,7 +132,11 @@ func (c *Conn) pinger(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
|
done, err := c.ping()
|
||||||
|
if done {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
c.log.Debug().Err(err).Msgf("failed to write ping message")
|
c.log.Debug().Err(err).Msgf("failed to write ping message")
|
||||||
}
|
}
|
||||||
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
||||||
|
@ -153,6 +148,17 @@ func (c *Conn) pinger(ctx context.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ping() (bool, error) {
|
||||||
|
c.writeLock.Lock()
|
||||||
|
defer c.writeLock.Unlock()
|
||||||
|
|
||||||
|
if c.done {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{})
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
|
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
|
||||||
if val := ctx.Value(PingPeriodContextKey); val != nil {
|
if val := ctx.Value(PingPeriodContextKey); val != nil {
|
||||||
if period, ok := val.(time.Duration); ok {
|
if period, ok := val.(time.Duration); ok {
|
||||||
|
@ -162,7 +168,9 @@ func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
|
||||||
return defaultPingPeriod
|
return defaultPingPeriod
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close waits for pinger to terminate
|
// Close waits for the current write to finish. Further writes will return error
|
||||||
func (c *Conn) WaitForShutdown() {
|
func (c *Conn) Close() {
|
||||||
<-c.shutdownC
|
c.writeLock.Lock()
|
||||||
|
defer c.writeLock.Unlock()
|
||||||
|
c.done = true
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue