cloudflared-mirror/websocket/websocket.go

162 lines
4.5 KiB
Go
Raw Permalink Normal View History

package websocket
import (
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"runtime/debug"
"sync/atomic"
"time"
"github.com/getsentry/raven-go"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
// IsWebSocketUpgrade reports whether req asks to upgrade the connection to the
// WebSocket protocol. The decision is delegated entirely to the
// gorilla/websocket helper of the same name.
func IsWebSocketUpgrade(req *http.Request) bool {
	isUpgrade := websocket.IsWebSocketUpgrade(req)
	return isUpgrade
}
// NewResponseHeader builds the response headers that complete a WebSocket
// handshake for req: "Connection: Upgrade", "Upgrade: websocket" and a
// "Sec-Websocket-Accept" value derived from the request's Sec-WebSocket-Key
// header via generateAcceptKey.
func NewResponseHeader(req *http.Request) http.Header {
	acceptKey := generateAcceptKey(req.Header.Get("Sec-WebSocket-Key"))
	// The keys below are already in canonical MIME form, so this literal yields
	// the same map that successive Header.Add calls would produce.
	return http.Header{
		"Connection":           []string{"Upgrade"},
		"Sec-Websocket-Accept": []string{acceptKey},
		"Upgrade":              []string{"websocket"},
	}
}
// bidirectionalStreamStatus tracks completion of the one-way copies that make
// up a bidirectional stream.
type bidirectionalStreamStatus struct {
	// doneChan receives one value every time a direction reports completion.
	doneChan chan struct{}
	// anyDone is written and read atomically; it becomes 1 once any direction
	// has reported completion.
	anyDone uint32
}

// newBiStreamStatus returns a status on which nothing has completed yet.
// The channel is buffered for two signals, so two markUniStreamDone calls can
// complete even when nobody is receiving.
func newBiStreamStatus() *bidirectionalStreamStatus {
	const directions = 2
	status := new(bidirectionalStreamStatus) // anyDone starts at its zero value
	status.doneChan = make(chan struct{}, directions)
	return status
}

// markUniStreamDone records that one direction has finished: it raises the
// anyDone flag first and then pushes a signal onto doneChan.
func (b *bidirectionalStreamStatus) markUniStreamDone() {
	atomic.StoreUint32(&b.anyDone, 1)
	var signal struct{}
	b.doneChan <- signal
}

// waitAnyDone blocks until a completion signal can be received from doneChan.
func (b *bidirectionalStreamStatus) waitAnyDone() {
	<-b.doneChan
}

// isAnyDone reports whether at least one direction has been marked as done.
func (b *bidirectionalStreamStatus) isAnyDone() bool {
	flag := atomic.LoadUint32(&b.anyDone)
	return flag != 0
}
// Stream pipes data in both directions between tunnelConn and originConn.
// It starts one goroutine per direction and returns as soon as either of them
// reports completion; it does not wait for the remaining goroutine.
func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
	// Direction labels, used by the copy goroutines for logging.
	const (
		toTunnel = "origin->tunnel"
		toOrigin = "tunnel->origin"
	)

	progress := newBiStreamStatus()
	go unidirectionalStream(tunnelConn, originConn, toTunnel, progress, log)
	go unidirectionalStream(originConn, tunnelConn, toOrigin, progress, log)

	// The first direction to finish ends the whole stream.
	progress.waitAnyDone()
}
// unidirectionalStream copies src into dst for one direction of a
// bidirectional stream and then records on status that this direction is done.
//
// dir is a label (for example "origin->tunnel") used in log lines and as a tag
// on reports sent through raven. A copy error is logged at debug level and is
// not returned. A panic raised while copying is recovered and logged rather
// than propagated.
//
// Completion is signalled from a deferred call so that it also happens when
// the copy panics. Without it, a recovered panic would end this goroutine
// without ever signalling, and the caller waiting in waitAnyDone would stay
// blocked for as long as the other direction keeps running.
func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) {
	// Registered first so that it runs last (deferred calls run in LIFO order):
	// the recover handler below still sees isAnyDone() as set only by the other
	// direction, never by this one.
	defer status.markUniStreamDone()

	defer func() {
		// The bidirectional streaming spawns 2 goroutines to stream each direction.
		// If any ends, the callstack returns, meaning the Tunnel request/stream (depending on http2 vs quic) will
		// close. In such case, if the other direction did not stop (due to application level stopping, e.g., if a
		// server/origin listens forever until closure), it may read/write from the underlying ReadWriter (backed by
		// the Edge<->cloudflared transport) in an unexpected state.
		// Because of this, we set this recover() logic.
		r := recover()
		if r == nil {
			return
		}
		if status.isAnyDone() {
			// The other side already finished, so a failure here is an expected consequence of teardown.
			log.Debug().Msgf("Gracefully handled error %v in Streaming for %s, error %s", r, dir, debug.Stack())
			return
		}
		// Otherwise, this is unexpected, but we prevent the program from crashing anyway.
		log.Warn().Msgf("Gracefully handled unexpected error %v in Streaming for %s, error %s", r, dir, debug.Stack())
		tags := map[string]string{
			"root": "websocket.stream",
			"dir":  dir,
		}
		switch rval := r.(type) {
		case error:
			raven.CaptureError(rval, tags)
		default:
			raven.CaptureMessage(fmt.Sprint(rval), tags)
		}
	}()

	if _, err := copyData(dst, src, dir); err != nil {
		log.Debug().Msgf("%s copy: %v", dir, err)
	}
}
// debugCopy, when set to true, makes copyData print a hex dump of every chunk
// it forwards between origin and tunnel.
const debugCopy = false

// copyData copies from src to dst until src returns io.EOF or an error occurs.
// It returns the number of bytes written and the first error encountered;
// io.EOF from src is not reported as an error.
//
// With debugCopy disabled this is a plain io.Copy. With it enabled, the copy is
// done by a loop modelled on io.Copy that additionally prints dir, the time
// (in Unix nanoseconds) taken just before each read, and a hex dump of the
// bytes read.
func copyData(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
	if !debugCopy {
		return io.Copy(dst, src)
	}

	chunk := make([]byte, 32*1024)
	for {
		readAt := time.Now()
		n, readErr := src.Read(chunk)
		if n > 0 {
			fmt.Println(dir, readAt.UnixNano(), "\n"+hex.Dump(chunk[:n]))
			wrote, writeErr := dst.Write(chunk[:n])
			// A writer reporting a negative count, or more bytes than it was
			// given, is misbehaving: count nothing and surface an error.
			if wrote < 0 || wrote > n {
				wrote = 0
				if writeErr == nil {
					writeErr = errors.New("invalid write")
				}
			}
			written += int64(wrote)
			switch {
			case writeErr != nil:
				return written, writeErr
			case wrote != n:
				return written, io.ErrShortWrite
			}
		}
		if readErr != nil {
			if readErr == io.EOF {
				// End of input is the normal way for the copy to finish.
				return written, nil
			}
			return written, readErr
		}
	}
}
// keyGUID is the fixed GUID from RFC 6455 that is appended to the client's
// challenge key before hashing.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// generateAcceptKey returns the base64 (standard encoding) of the SHA-1 digest
// of challengeKey followed by keyGUID.
func generateAcceptKey(challengeKey string) string {
	// Appending to a freshly converted slice leaves keyGUID itself untouched.
	payload := append([]byte(challengeKey), keyGUID...)
	digest := sha1.Sum(payload)
	return base64.StdEncoding.EncodeToString(digest[:])
}