cloudflared-mirror/websocket/connection.go

182 lines
4.7 KiB
Go
Raw Normal View History

package websocket
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"sync"
"time"
gobwas "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
const (
// writeWait is the time allowed to write a message to the peer.
// GorillaConn.pinger uses it as the deadline of every ping control frame it sends.
writeWait = 10 * time.Second
// defaultPongWait is the time allowed to read the next pong message from the peer.
// In the code visible here it only serves as the base from which defaultPingPeriod is derived.
defaultPongWait = 60 * time.Second
// defaultPingPeriod is the interval at which pings are sent to the peer.
// It is 9/10 of defaultPongWait because it must be less than defaultPongWait.
defaultPingPeriod = (defaultPongWait * 9) / 10
// PingPeriodContextKey is the context key under which a time.Duration can be stored
// to override defaultPingPeriod; Conn.pingPeriod looks it up.
PingPeriodContextKey = PingPeriodContext("pingPeriod")
)
// PingPeriodContext is the type of the context key that carries a custom ping period
// (see PingPeriodContextKey). It is a distinct string type rather than a plain string.
type PingPeriodContext string
// GorillaConn is a wrapper around the standard gorilla websocket connection that adds
// Read and Write methods so it can be used as an io.ReadWriter.
// This is still used by access carrier.
type GorillaConn struct {
// Conn is the embedded gorilla websocket connection; its methods are promoted onto GorillaConn.
*websocket.Conn
// log receives debug output, such as ping failures reported by pinger.
log *zerolog.Logger
// readBuf keeps the part of a received message that did not fit into the buffer
// passed to Read, so the next Read call can return it before fetching a new message.
readBuf bytes.Buffer
}
// Read exposes the message-oriented websocket connection as a byte stream.
// Bytes left over from a previously received message are returned first; only
// when none remain does Read block waiting for the next message. Whatever part
// of a new message does not fit into p is kept for subsequent calls.
func (c *GorillaConn) Read(p []byte) (int, error) {
	// Drain the remainder of the previous message before asking for a new frame.
	if pending := c.readBuf.Len(); pending != 0 {
		return c.readBuf.Read(p)
	}

	_, payload, readErr := c.Conn.ReadMessage()
	if readErr != nil {
		return 0, readErr
	}

	n := copy(p, payload)
	// Stash the bytes that did not fit; with nothing left over this is a no-op.
	// bytes.Buffer.Write always returns a nil error (it panics if it cannot grow).
	_, _ = c.readBuf.Write(payload[n:])
	return n, nil
}
// Write sends p to the peer as a single binary websocket message.
// It returns len(p) on success, or 0 and the underlying error on failure.
func (c *GorillaConn) Write(p []byte) (int, error) {
	writeErr := c.Conn.WriteMessage(websocket.BinaryMessage, p)
	if writeErr != nil {
		return 0, writeErr
	}
	return len(p), nil
}
// SetDeadline applies t as both the read and the write deadline, matching the
// net.Conn contract: "It is equivalent to calling both SetReadDeadline and
// SetWriteDeadline." The read deadline is set first; if that fails the write
// deadline is left untouched.
// Nothing is synchronized here, but the gorilla implementation isn't thread safe anyway.
func (c *GorillaConn) SetDeadline(t time.Time) error {
	readErr := c.Conn.SetReadDeadline(t)
	if readErr != nil {
		return fmt.Errorf("error setting read deadline: %w", readErr)
	}

	writeErr := c.Conn.SetWriteDeadline(t)
	if writeErr != nil {
		return fmt.Errorf("error setting write deadline: %w", writeErr)
	}

	return nil
}
// pinger keeps the websocket connection alive by sending a ping control frame
// every defaultPingPeriod until ctx is done. Each ping must be written within
// writeWait; a failed ping is only logged at debug level and the loop goes on.
func (c *GorillaConn) pinger(ctx context.Context) {
	tick := time.NewTicker(defaultPingPeriod)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
			deadline := time.Now().Add(writeWait)
			pingErr := c.WriteControl(websocket.PingMessage, []byte{}, deadline)
			if pingErr != nil {
				c.log.Debug().Msgf("failed to send ping message: %s", pingErr)
			}
		}
	}
}
// Conn turns an io.ReadWriter carrying websocket frames into a byte stream with
// Read, Write and Close, using the gobwas/ws server-side helpers.
type Conn struct {
	rw  io.ReadWriter
	log *zerolog.Logger
	// readBuf holds the tail of a received message that did not fit into the
	// buffer handed to Read, so it can be returned by the following Read calls.
	// It is not guarded by a lock; Read must not be called concurrently.
	readBuf bytes.Buffer
	// writeLock makes sure
	// 1. Only one write at a time. The pinger and Stream function can both call write.
	// 2. Close only returns after in progress Write is finished, and no more Write will succeed after calling Close.
	writeLock sync.Mutex
	done      bool
}

// NewConn wraps rw in a Conn and starts a background goroutine that pings the
// peer periodically. That goroutine returns once ctx is done.
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
	c := &Conn{
		rw:  rw,
		log: log,
	}
	go c.pinger(ctx)
	return c
}

// Read will read messages from the websocket connection.
//
// Bytes left over from a previously received message are returned first. When
// none remain, Read blocks for the next client binary message, copies as much
// of it as fits into reader and keeps the rest for later calls, so a message
// larger than reader is delivered across several reads instead of being
// truncated.
func (c *Conn) Read(reader []byte) (int, error) {
	// Serve the remainder of the last message before blocking on a new one.
	if c.readBuf.Len() > 0 {
		return c.readBuf.Read(reader)
	}
	data, err := wsutil.ReadClientBinary(c.rw)
	if err != nil {
		return 0, err
	}
	n := copy(reader, data)
	// Keep whatever did not fit; a no-op when the whole message was copied.
	// bytes.Buffer.Write always returns a nil error (it panics if it cannot grow).
	_, _ = c.readBuf.Write(data[n:])
	return n, nil
}
// Write sends p to the peer as one server-side binary websocket message.
// Writes are serialized through writeLock, and once Close has been called every
// Write is rejected with an error instead of touching the connection (fix for
// TUN-5184).
func (c *Conn) Write(p []byte) (int, error) {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	if c.done {
		return 0, errors.New("Write to closed websocket connection")
	}

	writeErr := wsutil.WriteServerBinary(c.rw, p)
	if writeErr != nil {
		return 0, writeErr
	}
	return len(p), nil
}
// pinger runs until ctx is done. On every tick of the period returned by
// pingPeriod it writes a ping frame through c and then passes an empty pong
// message to wsutil.HandleClientControlMessage.
// NOTE(review): what the control handler does with that pong is defined by
// gobwas/ws — confirm against its documentation.
// Both calls write via Conn.Write, so nothing reaches the connection after
// Close; failures are only logged at debug level.
func (c *Conn) pinger(ctx context.Context) {
	tick := time.NewTicker(c.pingPeriod(ctx))
	defer tick.Stop()

	pong := wsutil.Message{
		OpCode:  gobwas.OpPong,
		Payload: []byte{},
	}

	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
			// Ping/Pong messages will not be written after the connection is done
			pingErr := wsutil.WriteServerMessage(c, gobwas.OpPing, []byte{})
			if pingErr != nil {
				c.log.Debug().Err(pingErr).Msgf("failed to write ping message")
			}
			pongErr := wsutil.HandleClientControlMessage(c, pong)
			if pongErr != nil {
				c.log.Debug().Err(pongErr).Msgf("failed to write pong message")
			}
		}
	}
}
// pingPeriod returns the interval between pings for this connection.
// A positive time.Duration stored in ctx under PingPeriodContextKey overrides
// the default. A missing value, a value of any other type, or a non-positive
// duration yields defaultPingPeriod.
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
	// The result is handed to time.NewTicker by pinger, and NewTicker panics on
	// a non-positive duration, so such overrides are ignored instead of returned.
	if period, ok := ctx.Value(PingPeriodContextKey).(time.Duration); ok && period > 0 {
		return period
	}
	return defaultPingPeriod
}
// Close marks the connection as done. Because it takes writeLock, it returns
// only after any Write in progress has finished; every later Write fails with
// an error. It does not close the underlying io.ReadWriter.
func (c *Conn) Close() {
	c.writeLock.Lock()
	c.done = true
	c.writeLock.Unlock()
}