147 lines
3.5 KiB
Go
147 lines
3.5 KiB
Go
|
package wsutil
|
|||
|
|
|||
|
import (
|
|||
|
"bufio"
|
|||
|
"bytes"
|
|||
|
"context"
|
|||
|
"io"
|
|||
|
"io/ioutil"
|
|||
|
"net"
|
|||
|
"net/http"
|
|||
|
|
|||
|
"github.com/gobwas/ws"
|
|||
|
)
|
|||
|
|
|||
|
// DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
|
|||
|
// handshake. That is, it gives ability to receive copied HTTP request and
|
|||
|
// response bytes that made inside Dialer.Dial().
|
|||
|
//
|
|||
|
// Note that it must not be used in production applications that requires
|
|||
|
// Dial() to be efficient.
|
|||
|
type DebugDialer struct {
|
|||
|
// Dialer contains WebSocket connection establishment options.
|
|||
|
Dialer ws.Dialer
|
|||
|
|
|||
|
// OnRequest and OnResponse are the callbacks that will be called with the
|
|||
|
// HTTP request and response respectively.
|
|||
|
OnRequest, OnResponse func([]byte)
|
|||
|
}
|
|||
|
|
|||
|
// Dial connects to the url host and upgrades connection to WebSocket. It makes
|
|||
|
// it by calling d.Dialer.Dial().
|
|||
|
func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
|
|||
|
// Need to copy Dialer to prevent original object mutation.
|
|||
|
dialer := d.Dialer
|
|||
|
var (
|
|||
|
reqBuf bytes.Buffer
|
|||
|
resBuf bytes.Buffer
|
|||
|
|
|||
|
resContentLength int64
|
|||
|
)
|
|||
|
userWrap := dialer.WrapConn
|
|||
|
dialer.WrapConn = func(c net.Conn) net.Conn {
|
|||
|
if userWrap != nil {
|
|||
|
c = userWrap(c)
|
|||
|
}
|
|||
|
|
|||
|
// Save the pointer to the raw connection.
|
|||
|
conn = c
|
|||
|
|
|||
|
var (
|
|||
|
r io.Reader = conn
|
|||
|
w io.Writer = conn
|
|||
|
)
|
|||
|
if d.OnResponse != nil {
|
|||
|
r = &prefetchResponseReader{
|
|||
|
source: conn,
|
|||
|
buffer: &resBuf,
|
|||
|
contentLength: &resContentLength,
|
|||
|
}
|
|||
|
}
|
|||
|
if d.OnRequest != nil {
|
|||
|
w = io.MultiWriter(conn, &reqBuf)
|
|||
|
}
|
|||
|
return rwConn{conn, r, w}
|
|||
|
}
|
|||
|
|
|||
|
_, br, hs, err = dialer.Dial(ctx, urlstr)
|
|||
|
|
|||
|
if onRequest := d.OnRequest; onRequest != nil {
|
|||
|
onRequest(reqBuf.Bytes())
|
|||
|
}
|
|||
|
if onResponse := d.OnResponse; onResponse != nil {
|
|||
|
// We must split response inside buffered bytes from other received
|
|||
|
// bytes from server.
|
|||
|
p := resBuf.Bytes()
|
|||
|
n := bytes.Index(p, headEnd)
|
|||
|
h := n + len(headEnd) // Head end index.
|
|||
|
n = h + int(resContentLength) // Body end index.
|
|||
|
|
|||
|
onResponse(p[:n])
|
|||
|
|
|||
|
if br != nil {
|
|||
|
// If br is non-nil, then it mean two things. First is that
|
|||
|
// handshake is OK and server has sent additional bytes – probably
|
|||
|
// immediate sent frames (or weird but possible response body).
|
|||
|
// Second, the bad one, is that br buffer's source is now rwConn
|
|||
|
// instance from above WrapConn call. It is incorrect, so we must
|
|||
|
// fix it.
|
|||
|
var r io.Reader = conn
|
|||
|
if len(p) > h {
|
|||
|
// Buffer contains more than just HTTP headers bytes.
|
|||
|
r = io.MultiReader(
|
|||
|
bytes.NewReader(p[h:]),
|
|||
|
conn,
|
|||
|
)
|
|||
|
}
|
|||
|
br.Reset(r)
|
|||
|
// Must make br.Buffered() to be non-zero.
|
|||
|
br.Peek(len(p[h:]))
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
return conn, br, hs, err
|
|||
|
}
|
|||
|
|
|||
|
// rwConn is a net.Conn whose Read and Write are redirected to the supplied
// reader and writer; every other method comes from the embedded net.Conn.
type rwConn struct {
	net.Conn

	r io.Reader // Source used by Read.
	w io.Writer // Destination used by Write.
}

// Read reads into p from the configured reader rather than from the embedded
// connection.
func (c rwConn) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	return n, err
}

// Write writes p to the configured writer rather than to the embedded
// connection.
func (c rwConn) Write(p []byte) (int, error) {
	n, err := c.w.Write(p)
	return n, err
}
|
|||
|
|
|||
|
// headEnd is the CRLF CRLF sequence searched for in the buffered response
// bytes to locate the end of the HTTP response head.
var headEnd = []byte("\r\n\r\n")
|
|||
|
|
|||
|
// prefetchResponseReader is an io.Reader that, on its first Read, parses one
// HTTP response from source while copying every byte it consumes into
// buffer. Afterwards it serves the buffered bytes followed by the rest of
// source, so its consumer still observes the unmodified byte stream.
type prefetchResponseReader struct {
	source io.Reader // Original connection source.
	reader io.Reader // Replay reader handed to clients; nil until the first Read.
	buffer *bytes.Buffer

	// contentLength receives the number of response body bytes that were
	// drained while prefetching. It is written only if the response parsed
	// successfully.
	contentLength *int64
}

// Read implements io.Reader. The first call prefetches the HTTP response as
// described on the type; that call and all later ones read from the replay
// reader.
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
	if r.reader != nil {
		return r.reader.Read(p)
	}

	// Everything the HTTP parser pulls from source is mirrored into buffer.
	tee := io.TeeReader(r.source, r.buffer)
	if resp, err := http.ReadResponse(bufio.NewReader(tee), nil); err == nil {
		// Drain the body so that its bytes land in buffer too. A copy error
		// is ignored: only the count of bytes actually read is recorded.
		drained, _ := io.Copy(ioutil.Discard, resp.Body)
		*r.contentLength = drained
		resp.Body.Close()
	}

	// Replay what was captured (even after a parse failure) and then
	// continue with the live source.
	replay := bytes.NewReader(r.buffer.Bytes())
	r.reader = io.MultiReader(replay, r.source)

	return r.reader.Read(p)
}
|