cloudflared-mirror/vendor/github.com/gobwas/ws/wsutil/dialer.go

147 lines
3.5 KiB
Go
Raw Permalink Normal View History

package wsutil
import (
"bufio"
"bytes"
"context"
"io"
"io/ioutil"
"net"
"net/http"
"github.com/gobwas/ws"
)
// DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
// handshake. That is, it gives ability to receive copied HTTP request and
// response bytes that made inside Dialer.Dial().
//
// Note that it must not be used in production applications that requires
// Dial() to be efficient.
type DebugDialer struct {
// Dialer contains WebSocket connection establishment options.
Dialer ws.Dialer
// OnRequest and OnResponse are the callbacks that will be called with the
// HTTP request and response respectively.
OnRequest, OnResponse func([]byte)
}
// Dial connects to the url host and upgrades connection to WebSocket. It makes
// it by calling d.Dialer.Dial().
func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
// Need to copy Dialer to prevent original object mutation.
dialer := d.Dialer
var (
reqBuf bytes.Buffer
resBuf bytes.Buffer
resContentLength int64
)
userWrap := dialer.WrapConn
dialer.WrapConn = func(c net.Conn) net.Conn {
if userWrap != nil {
c = userWrap(c)
}
// Save the pointer to the raw connection.
conn = c
var (
r io.Reader = conn
w io.Writer = conn
)
if d.OnResponse != nil {
r = &prefetchResponseReader{
source: conn,
buffer: &resBuf,
contentLength: &resContentLength,
}
}
if d.OnRequest != nil {
w = io.MultiWriter(conn, &reqBuf)
}
return rwConn{conn, r, w}
}
_, br, hs, err = dialer.Dial(ctx, urlstr)
if onRequest := d.OnRequest; onRequest != nil {
onRequest(reqBuf.Bytes())
}
if onResponse := d.OnResponse; onResponse != nil {
// We must split response inside buffered bytes from other received
// bytes from server.
p := resBuf.Bytes()
n := bytes.Index(p, headEnd)
h := n + len(headEnd) // Head end index.
n = h + int(resContentLength) // Body end index.
onResponse(p[:n])
if br != nil {
// If br is non-nil, then it mean two things. First is that
// handshake is OK and server has sent additional bytes probably
// immediate sent frames (or weird but possible response body).
// Second, the bad one, is that br buffer's source is now rwConn
// instance from above WrapConn call. It is incorrect, so we must
// fix it.
var r io.Reader = conn
if len(p) > h {
// Buffer contains more than just HTTP headers bytes.
r = io.MultiReader(
bytes.NewReader(p[h:]),
conn,
)
}
br.Reset(r)
// Must make br.Buffered() to be non-zero.
br.Peek(len(p[h:]))
}
}
return conn, br, hs, err
}
type rwConn struct {
net.Conn
r io.Reader
w io.Writer
}
func (rwc rwConn) Read(p []byte) (int, error) {
return rwc.r.Read(p)
}
func (rwc rwConn) Write(p []byte) (int, error) {
return rwc.w.Write(p)
}
var headEnd = []byte("\r\n\r\n")
type prefetchResponseReader struct {
source io.Reader // Original connection source.
reader io.Reader // Wrapped reader used to read from by clients.
buffer *bytes.Buffer
contentLength *int64
}
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
if r.reader == nil {
resp, err := http.ReadResponse(bufio.NewReader(
io.TeeReader(r.source, r.buffer),
), nil)
if err == nil {
*r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
}
bts := r.buffer.Bytes()
r.reader = io.MultiReader(
bytes.NewReader(bts),
r.source,
)
}
return r.reader.Read(p)
}