diff --git a/websocket/connection.go b/websocket/connection.go index ec327804..40a68fd3 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -1,15 +1,16 @@ package websocket import ( + "bytes" "context" + "fmt" "io" "time" - "github.com/rs/zerolog" - gobwas "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/gorilla/websocket" + "github.com/rs/zerolog" ) const ( @@ -27,17 +28,29 @@ const ( // This is still used by access carrier type GorillaConn struct { *websocket.Conn - log *zerolog.Logger + log *zerolog.Logger + readBuf bytes.Buffer } // Read will read messages from the websocket connection func (c *GorillaConn) Read(p []byte) (int, error) { + // Intermediate buffer may contain unread bytes from the last read, start there before blocking on a new frame + if c.readBuf.Len() > 0 { + return c.readBuf.Read(p) + } + _, message, err := c.Conn.ReadMessage() if err != nil { return 0, err } - return copy(p, message), nil + copied := copy(p, message) + + // Write unread bytes to readBuf; if everything was read this is a no-op + // Write returns a nil error always and grows the buffer; everything is always written or panic + c.readBuf.Write(message[copied:]) + + return copied, nil } // Write will write messages to the websocket connection @@ -49,6 +62,19 @@ func (c *GorillaConn) Write(p []byte) (int, error) { return len(p), nil } +// SetDeadline sets both read and write deadlines, as per net.Conn interface docs: +// "It is equivalent to calling both SetReadDeadline and SetWriteDeadline." +// Note there is no synchronization here, but the gorilla implementation isn't thread safe anyway +func (c *GorillaConn) SetDeadline(t time.Time) error { + if err := c.Conn.SetReadDeadline(t); err != nil { + return fmt.Errorf("error setting read deadline: %w", err) + } + if err := c.Conn.SetWriteDeadline(t); err != nil { + return fmt.Errorf("error setting write deadline: %w", err) + } + return nil +} + // pinger simulates the websocket connection to keep it alive func (c *GorillaConn) pinger(ctx context.Context) { ticker := time.NewTicker(pingPeriod) diff --git a/websocket/websocket.go b/websocket/websocket.go index 9103da28..60e6edfd 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -114,7 +114,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } _ = conn.SetReadDeadline(time.Now().Add(pongWait)) conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - gorillaConn := &GorillaConn{conn, h.log} + gorillaConn := &GorillaConn{Conn: conn, log: h.log} go gorillaConn.pinger(r.Context()) defer conn.Close() diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 179098d6..035d83a5 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -1,18 +1,23 @@ package websocket import ( + "context" "crypto/tls" "crypto/x509" - "github.com/rs/zerolog" + "fmt" "io" "math/rand" "net/http" "testing" + "time" + + "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/tlsconfig" gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/websocket" ) @@ -102,6 +107,51 @@ func TestServe(t *testing.T) { <-errC } +func TestWebsocketWrapper(t *testing.T) { + + listener, err := hello.CreateTLSListener("localhost:0") + require.NoError(t, err) + + serverErrorChan := make(chan error) + helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background()) + defer func() { <-serverErrorChan }() + defer cancelHelloSvr() + go func() { + log := zerolog.Nop() + serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done()) + }() + + tlsConfig := websocketClientTLSConfig(t) + d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute} + testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String()) + req := testRequest(t, testAddr, nil) + conn, resp, err := ClientConnect(req, &d) + require.NoError(t, err) + require.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept")) + + // Websocket now connected to test server so lets check our wrapper + wrapper := GorillaConn{Conn: conn} + buf := make([]byte, 100) + wrapper.Write([]byte("abc")) + n, err := wrapper.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 3) + require.Equal(t, "abc", string(buf[:n])) + + // Test partial read, read 1 of 3 bytes in one read and the other 2 in another read + wrapper.Write([]byte("abc")) + buf = buf[:1] + n, err = wrapper.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 1) + require.Equal(t, "a", string(buf[:n])) + buf = buf[:cap(buf)] + n, err = wrapper.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 2) + require.Equal(t, "bc", string(buf[:n])) +} + // func TestStartProxyServer(t *testing.T) { // var wg sync.WaitGroup // remoteAddress := "localhost:1113"