Allow partial reads from a GorillaConn; add SetDeadline (from net.Conn) (#330)

* Allow partial reads from a GorillaConn; add SetDeadline (from net.Conn)

The current implementation of GorillaConn will drop data if the
websocket frame isn't read 100%. For example, if a websocket frame is
size=3, and Read() is called with a []byte of len=1, the 2 other bytes
in the frame are lost forever.

This is currently masked by the fact that this is used primarily in
io.Copy to another socket (in ingress.Stream) - as long as the read buffer
used by io.Copy is big enough (it is 32*1024, so in theory we could see
this today?) then data is copied over to the other socket.

The client then can do partial reads just fine as the kernel will take
care of the buffer from here on out.

I hit this by trying to create my own tunnel and avoiding
ingress.Stream, but this could be a real bug today I think if a
websocket frame bigger than 32*1024 was received, although it is also
possible that we are lucky and the upstream size which I haven't checked
uses a smaller buffer than that always.

The test I added hangs before my change, succeeds after.

Also add SetDeadline so that GorillaConn fully implements net.Conn

* Comment formatting; fast path

* Avoid intermediate buffer for first len(p) bytes; import order
This commit is contained in:
Benjamin Buzbee 2021-03-09 07:57:04 -08:00 committed by GitHub
parent 39065377b5
commit 452f8cef79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 6 deletions

View File

@ -1,15 +1,16 @@
package websocket package websocket
import ( import (
"bytes"
"context" "context"
"fmt"
"io" "io"
"time" "time"
"github.com/rs/zerolog"
gobwas "github.com/gobwas/ws" gobwas "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/rs/zerolog"
) )
const ( const (
@ -28,16 +29,28 @@ const (
type GorillaConn struct { type GorillaConn struct {
*websocket.Conn *websocket.Conn
log *zerolog.Logger log *zerolog.Logger
readBuf bytes.Buffer
} }
// Read will read messages from the websocket connection // Read will read messages from the websocket connection
func (c *GorillaConn) Read(p []byte) (int, error) { func (c *GorillaConn) Read(p []byte) (int, error) {
// Intermediate buffer may contain unread bytes from the last read, start there before blocking on a new frame
if c.readBuf.Len() > 0 {
return c.readBuf.Read(p)
}
_, message, err := c.Conn.ReadMessage() _, message, err := c.Conn.ReadMessage()
if err != nil { if err != nil {
return 0, err return 0, err
} }
return copy(p, message), nil
copied := copy(p, message)
// Write unread bytes to readBuf; if everything was read this is a no-op
// Write returns a nil error always and grows the buffer; everything is always written or panic
c.readBuf.Write(message[copied:])
return copied, nil
} }
// Write will write messages to the websocket connection // Write will write messages to the websocket connection
@ -49,6 +62,19 @@ func (c *GorillaConn) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
// SetDeadline sets both read and write deadlines, as per net.Conn interface docs:
// "It is equivalent to calling both SetReadDeadline and SetWriteDeadline."
// Note there is no synchronization here, but the gorilla implementation isn't thread safe anyway
func (c *GorillaConn) SetDeadline(t time.Time) error {
if err := c.Conn.SetReadDeadline(t); err != nil {
return fmt.Errorf("error setting read deadline: %w", err)
}
if err := c.Conn.SetWriteDeadline(t); err != nil {
return fmt.Errorf("error setting write deadline: %w", err)
}
return nil
}
// pinger simulates the websocket connection to keep it alive // pinger simulates the websocket connection to keep it alive
func (c *GorillaConn) pinger(ctx context.Context) { func (c *GorillaConn) pinger(ctx context.Context) {
ticker := time.NewTicker(pingPeriod) ticker := time.NewTicker(pingPeriod)

View File

@ -114,7 +114,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
_ = conn.SetReadDeadline(time.Now().Add(pongWait)) _ = conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
gorillaConn := &GorillaConn{conn, h.log} gorillaConn := &GorillaConn{Conn: conn, log: h.log}
go gorillaConn.pinger(r.Context()) go gorillaConn.pinger(r.Context())
defer conn.Close() defer conn.Close()

View File

@ -1,18 +1,23 @@
package websocket package websocket
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"github.com/rs/zerolog" "fmt"
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"testing" "testing"
"time"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
gws "github.com/gorilla/websocket" gws "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
@ -102,6 +107,51 @@ func TestServe(t *testing.T) {
<-errC <-errC
} }
func TestWebsocketWrapper(t *testing.T) {
listener, err := hello.CreateTLSListener("localhost:0")
require.NoError(t, err)
serverErrorChan := make(chan error)
helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
defer func() { <-serverErrorChan }()
defer cancelHelloSvr()
go func() {
log := zerolog.Nop()
serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
}()
tlsConfig := websocketClientTLSConfig(t)
d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
req := testRequest(t, testAddr, nil)
conn, resp, err := ClientConnect(req, &d)
require.NoError(t, err)
require.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
// Websocket now connected to test server so lets check our wrapper
wrapper := GorillaConn{Conn: conn}
buf := make([]byte, 100)
wrapper.Write([]byte("abc"))
n, err := wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 3)
require.Equal(t, "abc", string(buf[:n]))
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
wrapper.Write([]byte("abc"))
buf = buf[:1]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 1)
require.Equal(t, "a", string(buf[:n]))
buf = buf[:cap(buf)]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 2)
require.Equal(t, "bc", string(buf[:n]))
}
// func TestStartProxyServer(t *testing.T) { // func TestStartProxyServer(t *testing.T) {
// var wg sync.WaitGroup // var wg sync.WaitGroup
// remoteAddress := "localhost:1113" // remoteAddress := "localhost:1113"