package wsutil

import (
	"bufio"
	"bytes"
	"context"
	"io"
	"io/ioutil"
	"net"
	"net/http"

	"github.com/gobwas/ws"
)

// DebugDialer wraps ws.Dialer and tracks the i/o of the WebSocket handshake.
// It lets the caller receive copies of the HTTP request and response bytes
// that are exchanged inside Dialer.Dial().
//
// Note that it must not be used in production applications that require
// Dial() to be efficient.
type DebugDialer struct {
	// Dialer contains WebSocket connection establishment options.
	Dialer ws.Dialer

	// OnRequest, when non-nil, is called with the bytes of the HTTP request.
	OnRequest func([]byte)

	// OnResponse, when non-nil, is called with the bytes of the HTTP
	// response.
	OnResponse func([]byte)
}

// Dial connects to the url host and upgrades connection to WebSocket. It makes
// it by calling Dial() on a copy of d.Dialer whose WrapConn hook is replaced
// with one that captures the handshake bytes. A WrapConn provided by the user
// is still applied first.
//
// After the underlying Dial() returns, whether it succeeded or not,
// d.OnRequest (if set) is called with the bytes written to the connection and
// d.OnResponse (if set) is called with the response head followed by the body
// bytes consumed while parsing it. If no complete response head was received
// (for example, the connection was never established or the server sent a
// truncated reply), OnResponse receives whatever bytes were read, possibly
// none.
//
// The returned conn is the connection as seen after the user's WrapConn (or
// the dialed connection itself when there is none), not the capturing wrapper.
// It is nil if WrapConn was never invoked.
func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
	// Need to copy Dialer to prevent original object mutation.
	dialer := d.Dialer
	var (
		reqBuf bytes.Buffer
		resBuf bytes.Buffer

		resContentLength int64
	)
	userWrap := dialer.WrapConn
	dialer.WrapConn = func(c net.Conn) net.Conn {
		if userWrap != nil {
			c = userWrap(c)
		}

		// Save the pointer to the raw connection.
		conn = c

		var (
			r io.Reader = conn
			w io.Writer = conn
		)
		if d.OnResponse != nil {
			// Reads go through a reader that records the whole response into
			// resBuf before replaying it to the dialer.
			r = &prefetchResponseReader{
				source:        conn,
				buffer:        &resBuf,
				contentLength: &resContentLength,
			}
		}
		if d.OnRequest != nil {
			// Writes are duplicated into reqBuf.
			w = io.MultiWriter(conn, &reqBuf)
		}
		return rwConn{conn, r, w}
	}

	_, br, hs, err = dialer.Dial(ctx, urlstr)

	if onRequest := d.OnRequest; onRequest != nil {
		onRequest(reqBuf.Bytes())
	}
	if onResponse := d.OnResponse; onResponse != nil {
		// We must split response inside buffered bytes from other received
		// bytes from server.
		p := resBuf.Bytes()

		// h is the index just past the response head, n is the index just
		// past the response body. When the head terminator is missing (dial
		// failed early, truncated or malformed reply) everything captured is
		// reported as is instead of slicing at a bogus offset.
		h, n := len(p), len(p)
		if i := bytes.Index(p, headEnd); i >= 0 {
			h = i + len(headEnd)
			// Never let the body end run past the captured bytes.
			if rest := int64(len(p) - h); resContentLength < rest {
				n = h + int(resContentLength)
			}
		}

		onResponse(p[:n])

		if br != nil {
			// If br is non-nil, then it mean two things. First is that
			// handshake is OK and server has sent additional bytes – probably
			// immediate sent frames (or weird but possible response body).
			// Second, the bad one, is that br buffer's source is now rwConn
			// instance from above WrapConn call. It is incorrect, so we must
			// fix it.
			var r io.Reader = conn
			if len(p) > h {
				// Buffer contains more than just HTTP headers bytes.
				r = io.MultiReader(
					bytes.NewReader(p[h:]),
					conn,
				)
			}
			br.Reset(r)
			// Must make br.Buffered() to be non-zero. The result is ignored
			// on purpose: this call only primes the buffer, and any read
			// error will surface again on the caller's next read.
			_, _ = br.Peek(len(p) - h)
		}
	}

	return conn, br, hs, err
}

// rwConn is a net.Conn whose Read and Write calls are routed to r and w
// respectively. All other net.Conn methods are served by the embedded
// connection.
type rwConn struct {
	net.Conn

	r io.Reader
	w io.Writer
}

// Read reads from c.r rather than from the embedded connection.
func (c rwConn) Read(p []byte) (n int, err error) {
	n, err = c.r.Read(p)
	return n, err
}

// Write writes to c.w rather than to the embedded connection.
func (c rwConn) Write(p []byte) (n int, err error) {
	n, err = c.w.Write(p)
	return n, err
}

// headEnd is the CRLF CRLF byte sequence searched for to locate the end of
// an HTTP message head (the empty line that follows the last header).
var headEnd = []byte("\r\n\r\n")

// prefetchResponseReader is an io.Reader that, on its first Read, parses one
// HTTP response from source while recording every byte pulled from source
// into buffer. Afterwards it replays the recorded bytes and then continues
// with source, so the consumer still observes the full original stream.
type prefetchResponseReader struct {
	source io.Reader // Original connection source.
	reader io.Reader // Wrapped reader used to read from by clients.
	buffer *bytes.Buffer

	// contentLength receives the number of body bytes read from the parsed
	// response. It is left untouched when the response can not be parsed.
	contentLength *int64
}

// Read serves p from the replaying reader, building that reader first if this
// is the initial call.
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
	if r.reader != nil {
		return r.reader.Read(p)
	}

	// Everything the HTTP parser pulls from source is mirrored into buffer.
	tee := io.TeeReader(r.source, r.buffer)
	if resp, err := http.ReadResponse(bufio.NewReader(tee), nil); err == nil {
		// Drain the body so that its bytes are recorded and counted.
		bodyLen, _ := io.Copy(ioutil.Discard, resp.Body)
		*r.contentLength = bodyLen
		resp.Body.Close()
	}
	// A parse error is not reported here: the recorded bytes are replayed
	// below, so the consumer parses them itself.

	recorded := bytes.NewReader(r.buffer.Bytes())
	r.reader = io.MultiReader(recorded, r.source)

	return r.reader.Read(p)
}