169 lines
4.2 KiB
Go
169 lines
4.2 KiB
Go
package websocket
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
gobwas "github.com/gobwas/ws"
|
|
"github.com/gobwas/ws/wsutil"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
// PingPeriodContext is the type of the context key under which a custom ping
// period can be stored. A dedicated type keeps the key from colliding with
// context keys defined by other packages.
type PingPeriodContext string

const (
	// writeWait is the time allowed to write a message to the peer.
	writeWait = 10 * time.Second

	// defaultPongWait is the time allowed to read the next pong message from
	// the peer.
	defaultPongWait = 60 * time.Second

	// defaultPingPeriod is how often pings are sent to the peer. It must be
	// less than defaultPongWait.
	defaultPingPeriod = defaultPongWait * 9 / 10

	// PingPeriodContextKey is the context key looked up to override
	// defaultPingPeriod.
	PingPeriodContextKey = PingPeriodContext("pingPeriod")
)
|
|
|
|
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
|
|
// This is still used by access carrier
|
|
type GorillaConn struct {
|
|
*websocket.Conn
|
|
log *zerolog.Logger
|
|
readBuf bytes.Buffer
|
|
}
|
|
|
|
// Read will read messages from the websocket connection
|
|
func (c *GorillaConn) Read(p []byte) (int, error) {
|
|
// Intermediate buffer may contain unread bytes from the last read, start there before blocking on a new frame
|
|
if c.readBuf.Len() > 0 {
|
|
return c.readBuf.Read(p)
|
|
}
|
|
|
|
_, message, err := c.Conn.ReadMessage()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
copied := copy(p, message)
|
|
|
|
// Write unread bytes to readBuf; if everything was read this is a no-op
|
|
// Write returns a nil error always and grows the buffer; everything is always written or panic
|
|
c.readBuf.Write(message[copied:])
|
|
|
|
return copied, nil
|
|
}
|
|
|
|
// Write will write messages to the websocket connection
|
|
func (c *GorillaConn) Write(p []byte) (int, error) {
|
|
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(p), nil
|
|
}
|
|
|
|
// SetDeadline sets both read and write deadlines, as per net.Conn interface docs:
|
|
// "It is equivalent to calling both SetReadDeadline and SetWriteDeadline."
|
|
// Note there is no synchronization here, but the gorilla implementation isn't thread safe anyway
|
|
func (c *GorillaConn) SetDeadline(t time.Time) error {
|
|
if err := c.Conn.SetReadDeadline(t); err != nil {
|
|
return fmt.Errorf("error setting read deadline: %w", err)
|
|
}
|
|
if err := c.Conn.SetWriteDeadline(t); err != nil {
|
|
return fmt.Errorf("error setting write deadline: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// pinger simulates the websocket connection to keep it alive
|
|
func (c *GorillaConn) pinger(ctx context.Context) {
|
|
ticker := time.NewTicker(defaultPingPeriod)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
|
|
c.log.Debug().Msgf("failed to send ping message: %s", err)
|
|
}
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
type Conn struct {
|
|
rw io.ReadWriter
|
|
log *zerolog.Logger
|
|
// closed is a channel to indicate if Conn has been fully terminated
|
|
shutdownC chan struct{}
|
|
}
|
|
|
|
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
|
c := &Conn{
|
|
rw: rw,
|
|
log: log,
|
|
shutdownC: make(chan struct{}),
|
|
}
|
|
go c.pinger(ctx)
|
|
return c
|
|
}
|
|
|
|
// Read will read messages from the websocket connection
|
|
func (c *Conn) Read(reader []byte) (int, error) {
|
|
data, err := wsutil.ReadClientBinary(c.rw)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return copy(reader, data), nil
|
|
}
|
|
|
|
// Write will write messages to the websocket connection
|
|
func (c *Conn) Write(p []byte) (int, error) {
|
|
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
|
return 0, err
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
func (c *Conn) pinger(ctx context.Context) {
|
|
defer close(c.shutdownC)
|
|
pongMessge := wsutil.Message{
|
|
OpCode: gobwas.OpPong,
|
|
Payload: []byte{},
|
|
}
|
|
|
|
ticker := time.NewTicker(c.pingPeriod(ctx))
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
|
|
c.log.Debug().Err(err).Msgf("failed to write ping message")
|
|
}
|
|
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
|
c.log.Debug().Err(err).Msgf("failed to write pong message")
|
|
}
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
|
|
if val := ctx.Value(PingPeriodContextKey); val != nil {
|
|
if period, ok := val.(time.Duration); ok {
|
|
return period
|
|
}
|
|
}
|
|
return defaultPingPeriod
|
|
}
|
|
|
|
// Close waits for pinger to terminate
|
|
func (c *Conn) WaitForShutdown() {
|
|
<-c.shutdownC
|
|
}
|