111 lines
2.6 KiB
Go
111 lines
2.6 KiB
Go
package websocket
|
|
|
|
import (
|
|
"crypto/sha1"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
|
func IsWebSocketUpgrade(req *http.Request) bool {
|
|
return websocket.IsWebSocketUpgrade(req)
|
|
}
|
|
|
|
// NewResponseHeader returns headers needed to return to origin for completing handshake
|
|
func NewResponseHeader(req *http.Request) http.Header {
|
|
header := http.Header{}
|
|
header.Add("Connection", "Upgrade")
|
|
header.Add("Sec-Websocket-Accept", generateAcceptKey(req.Header.Get("Sec-WebSocket-Key")))
|
|
header.Add("Upgrade", "websocket")
|
|
return header
|
|
}
|
|
|
|
// Stream copies copy data to & from provided io.ReadWriters.
|
|
func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
|
|
proxyDone := make(chan struct{}, 2)
|
|
|
|
go func() {
|
|
_, err := copyData(tunnelConn, originConn, "origin->tunnel")
|
|
if err != nil {
|
|
log.Debug().Msgf("origin to tunnel copy: %v", err)
|
|
}
|
|
proxyDone <- struct{}{}
|
|
}()
|
|
|
|
go func() {
|
|
_, err := copyData(originConn, tunnelConn, "tunnel->origin")
|
|
if err != nil {
|
|
log.Debug().Msgf("tunnel to origin copy: %v", err)
|
|
}
|
|
proxyDone <- struct{}{}
|
|
}()
|
|
|
|
// If one side is done, we are done.
|
|
<-proxyDone
|
|
}
|
|
|
|
// debugCopy is a compile-time switch. When true, copyData prints a
// timestamped hex dump of every chunk it copies to stdout. Intended for local
// debugging only.
const debugCopy = false

// copyData copies from src to dst until src reports EOF or an error occurs.
// It returns the number of bytes written and the first error encountered;
// io.EOF from src is not treated as an error.
//
// dir is a free-form label (e.g. "origin->tunnel"). It is only used to tag
// the debug output when debugCopy is enabled. Otherwise this is a plain
// io.Copy.
func copyData(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
	if debugCopy {
		return copyDataWithDump(dst, src, dir)
	}
	return io.Copy(dst, src)
}

// copyDataWithDump follows the same loop as the standard library's io.Copy.
// It also prints dir, the read timestamp, and a hex dump of every chunk read
// from src to stdout before writing that chunk to dst.
func copyDataWithDump(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
	buf := make([]byte, 32*1024)
	for {
		readAt := time.Now()
		nr, er := src.Read(buf)
		if nr > 0 {
			chunk := buf[:nr]
			fmt.Println(dir, readAt.UnixNano(), "\n"+hex.Dump(chunk))
			nw, ew := dst.Write(chunk)
			// Guard against writers that report an impossible byte count.
			if nw < 0 || nr < nw {
				nw = 0
				if ew == nil {
					ew = errors.New("invalid write")
				}
			}
			written += int64(nw)
			if ew != nil {
				return written, ew
			}
			if nr != nw {
				return written, io.ErrShortWrite
			}
		}
		if er != nil {
			// Per the io.Reader contract, EOF marks a clean end of input.
			if er == io.EOF {
				return written, nil
			}
			return written, er
		}
	}
}
|
|
|
|
// keyGUID is the fixed GUID defined by RFC 6455. The server appends it to the
// client's Sec-WebSocket-Key before hashing to produce Sec-WebSocket-Accept.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// generateAcceptKey derives the Sec-WebSocket-Accept value for challengeKey
// as base64(SHA-1(challengeKey || keyGUID)), using standard base64 with
// padding.
func generateAcceptKey(challengeKey string) string {
	payload := make([]byte, 0, len(challengeKey)+len(keyGUID))
	payload = append(payload, challengeKey...)
	payload = append(payload, keyGUID...)

	digest := sha1.Sum(payload)
	return base64.StdEncoding.EncodeToString(digest[:])
}
|