package ws

import (
	"bufio"
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/gobwas/httphead"
	"github.com/gobwas/pool/pbufio"
)

// Default sizes of the I/O buffers used by Dialer for the HTTP upgrade
// exchange. Each one is used when the corresponding Dialer buffer size field
// is zero.
const (
	DefaultClientReadBufferSize  = 4096
	DefaultClientWriteBufferSize = 4096
)

// Handshake represents the result of a WebSocket handshake.
type Handshake struct {
	// Protocol is the subprotocol selected by the server during the
	// handshake. It is left empty when the server did not select one.
	Protocol string

	// Extensions is the list of extensions negotiated with the server.
	Extensions []httphead.Option
}

// Errors used by the websocket client.
//
// ErrHandshakeBadSubProtocol is returned when the server echoes a subprotocol
// that the client did not request. ErrHandshakeBadExtensions is returned when
// the server selects an extension that the client did not request; its
// message names the extensions header (previously it named the subprotocol
// header by mistake).
var (
	ErrHandshakeBadStatus      = fmt.Errorf("unexpected http status")
	ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
	ErrHandshakeBadExtensions  = fmt.Errorf("unexpected extensions in %q header", headerSecExtensionsCanonical)
)

// DefaultDialer is a zero-value Dialer, i.e. a dialer with no options set.
// It is the dialer used by the package-level Dial function.
var DefaultDialer Dialer

// Dial connects to urlstr and upgrades the connection to WebSocket using
// DefaultDialer. It is equivalent to DefaultDialer.Dial(ctx, urlstr).
func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
	return DefaultDialer.Dial(ctx, urlstr)
}

// Dialer contains options for establishing a WebSocket connection to a URL.
type Dialer struct {
	// ReadBufferSize and WriteBufferSize are the I/O buffer sizes.
	// The buffers are used to read and write HTTP data while upgrading to
	// WebSocket. Allocated buffers are pooled to avoid extra allocations.
	//
	// If a size is zero then the default value is used.
	ReadBufferSize, WriteBufferSize int

	// Timeout is the maximum amount of time a Dial() will wait for a connect
	// and a handshake to complete.
	//
	// The default is no timeout.
	Timeout time.Duration

	// Protocols is the list of subprotocols that the client wants to speak,
	// ordered by preference.
	//
	// See https://tools.ietf.org/html/rfc6455#section-4.1
	Protocols []string

	// Extensions is the list of extensions that the client wants to speak.
	//
	// Note that if the server decides to use some of these extensions, Dial()
	// will return a Handshake struct containing a slice of items which are
	// shallow copies of the items from this list. That is, internals of
	// Extensions items are shared during Dial().
	//
	// See https://tools.ietf.org/html/rfc6455#section-4.1
	// See https://tools.ietf.org/html/rfc6455#section-9.1
	Extensions []httphead.Option

	// Header is an optional HandshakeHeader instance that could be used to
	// write additional headers to the handshake request.
	//
	// It is used instead of any key-value mappings to avoid allocations in
	// user land.
	Header HandshakeHeader

	// OnStatusError is the callback that will be called after receiving an
	// HTTP response status other than "101 Switching Protocols". It receives
	// an io.Reader object representing the server response bytes. That is, it
	// gives the ability to parse the HTTP response somehow (probably with an
	// http.ReadResponse call) and decide on further logic.
	//
	// The arguments are only valid until the callback returns.
	OnStatusError func(status int, reason []byte, resp io.Reader)

	// OnHeader is the callback that will be called after successful parsing
	// of a header that is not used during the WebSocket handshake procedure.
	// That is, it will be called with non-websocket headers, which could be
	// relevant for application-level logic.
	//
	// The arguments are only valid until the callback returns.
	//
	// A non-nil returned error stops processing of the response and is
	// returned from the handshake.
	OnHeader func(key, value []byte) (err error)

	// NetDial is the function that is used to get a plain tcp connection.
	// If it is not nil, then it is used instead of net.Dialer.
	NetDial func(ctx context.Context, network, addr string) (net.Conn, error)

	// TLSClient is the callback that will be called after a successful dial
	// with the received connection and its remote host name. If it is nil,
	// then the default tls.Client() will be used.
	// If it is not nil, then the TLSConfig field is ignored.
	TLSClient func(conn net.Conn, hostname string) net.Conn

	// TLSConfig is passed to tls.Client() to start TLS over the established
	// connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
	// non-nil and its ServerName is empty, then for every Dial() it will be
	// cloned and the appropriate ServerName will be set.
	TLSConfig *tls.Config

	// WrapConn is the optional callback that will be called when the
	// connection is ready for I/O. That is, it will be called after a
	// successful dial and TLS initialization (for "wss" schemes). It may be
	// helpful for different user land purposes such as end to end encryption.
	//
	// Note that for debugging purposes of an http handshake (e.g. sent request
	// and received response), there is a wsutil.DebugDialer struct.
	WrapConn func(conn net.Conn) net.Conn
}

// Dial connects to the url host and upgrades connection to WebSocket.
//
// If the server has sent frames right after a successful handshake then the
// returned buffer will be non-nil. In other cases the buffer is always nil.
// For better memory efficiency a received non-nil bufio.Reader should be
// returned to the inner pool with the PutReader() function after use.
//
// Both the connect and the handshake are bounded by ctx and, when it is
// non-zero, by d.Timeout, whichever expires first. On a non-nil error the
// established connection (if any) is closed before returning.
//
// Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
// If you want to dial non-ascii host name, take care of its name serialization
// avoiding bad request issues. For more info see net/http Request.Write()
// implementation, especially cleanHost() function.
func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
	u, err := url.ParseRequestURI(urlstr)
	if err != nil {
		return
	}

	// Prepare the context to dial with. Initially it is the same as the
	// original, but if d.Timeout is non-zero and points to a time that is
	// before ctx.Deadline, a shorter derived context is used instead.
	dialctx := ctx

	var deadline time.Time
	if t := d.Timeout; t != 0 {
		deadline = time.Now().Add(t)
		if ctxDeadline, ok := ctx.Deadline(); !ok || deadline.Before(ctxDeadline) {
			var cancel context.CancelFunc
			dialctx, cancel = context.WithDeadline(ctx, deadline)
			// Deferred first, so it runs last: the interrupter below is
			// stopped before the derived context gets canceled.
			defer cancel()
		}
	}
	if conn, err = d.dial(dialctx, u); err != nil {
		return
	}
	defer func() {
		if err != nil {
			conn.Close()
		}
	}()
	if ctx == context.Background() {
		// No need to start I/O interrupter goroutine which is not zero-cost.
		// A zero deadline (no d.Timeout) means "no deadline" here.
		conn.SetDeadline(deadline)
		defer conn.SetDeadline(noDeadline)
	} else {
		// Context could be canceled or its deadline could be exceeded.
		// Start the interrupter goroutine to handle context cancelation.
		//
		// dialctx is watched rather than ctx: it is either ctx itself or a
		// child of it carrying the d.Timeout deadline, so the handshake
		// honors both the caller's context and d.Timeout.
		done := setupContextDeadliner(dialctx, conn)
		defer func() {
			// Map Upgrade() error to a possible context expiration error.
			// That is, even if Upgrade() err is nil, the context could be
			// already expired and the connection be "poisoned" by the
			// SetDeadline() call. In that case the context error is returned
			// instead.
			done(&err)
		}()
	}

	br, hs, err = d.Upgrade(conn, u)

	return
}

var (
	// netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
	// Dialer.NetDial is not provided.
	netEmptyDialer net.Dialer
	// tlsEmptyConfig is an empty tls.Config used as the default one. It is
	// handed out by pointer from tlsDefaultConfig().
	tlsEmptyConfig tls.Config
)

// tlsDefaultConfig returns a pointer to the package-level empty tls.Config.
// Every call returns the same pointer.
func tlsDefaultConfig() *tls.Config {
	return &tlsEmptyConfig
}

// hostport splits a URL host into the bare host name and a dialable address.
//
// A port is considered present when the last ':' in host comes after the
// first ']' (so colons inside a bracketed IPv6 literal do not count). In that
// case hostname is everything before that colon and addr is host unchanged.
// Otherwise hostname is host itself and addr is host with defaultPort
// appended; defaultPort is expected to include the leading colon, e.g. ":80".
func hostport(host string, defaultPort string) (hostname, addr string) {
	lastColon := strings.LastIndexByte(host, ':')
	closeBracket := strings.IndexByte(host, ']')

	if lastColon <= closeBracket {
		// No explicit port in host.
		hostname = host
		addr = host + defaultPort
		return hostname, addr
	}
	hostname = host[:lastColon]
	addr = host
	return hostname, addr
}

// dial establishes the underlying connection for u.
//
// It dials over tcp with d.NetDial, or with a plain net.Dialer when NetDial
// is nil. For the "ws" scheme the default port is 80; for "wss" it is 443 and
// the connection is additionally wrapped by d.TLSClient (or the default TLS
// client when TLSClient is nil). Any other scheme is an error.
//
// d.WrapConn, if set, is applied only to a successfully established
// connection. On a dial error dial returns (nil, err) and never invokes
// WrapConn.
func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
	dial := d.NetDial
	if dial == nil {
		dial = netEmptyDialer.DialContext
	}
	switch u.Scheme {
	case "ws":
		_, addr := hostport(u.Host, ":80")
		conn, err = dial(ctx, "tcp", addr)
		if err != nil {
			// Do not fall through to WrapConn with a failed connection.
			return nil, err
		}
	case "wss":
		hostname, addr := hostport(u.Host, ":443")
		conn, err = dial(ctx, "tcp", addr)
		if err != nil {
			return nil, err
		}
		tlsClient := d.TLSClient
		if tlsClient == nil {
			tlsClient = d.tlsClient
		}
		conn = tlsClient(conn, hostname)
	default:
		return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
	}
	if wrap := d.WrapConn; wrap != nil {
		conn = wrap(conn)
	}
	return conn, nil
}

// tlsClient is the default TLS wrapper used when Dialer.TLSClient is nil.
//
// It uses d.TLSConfig, falling back to the package default config. When the
// chosen config has an empty ServerName, a clone with ServerName set to
// hostname is used so that the original config is not modified.
func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
	cfg := d.TLSConfig
	if cfg == nil {
		cfg = tlsDefaultConfig()
	}
	// The TLS handshake is intentionally not performed here: the caller
	// prepares I/O on this conn with proper handling of the context's timeout.
	if cfg.ServerName != "" {
		return tls.Client(conn, cfg)
	}
	cfg = tlsCloneConfig(cfg)
	cfg.ServerName = hostname
	return tls.Client(conn, cfg)
}

var (
	// These variables are set like in net/net.go.
	// noDeadline is just the zero value, named for readability.
	noDeadline = time.Time{}
	// aLongTimeAgo is a non-zero time, far in the past, used for immediate
	// cancelation of I/O on a connection.
	aLongTimeAgo = time.Unix(42, 0)
)

// Upgrade writes an upgrade request to the given io.ReadWriter conn at given
// url u and reads a response from it.
//
// It is a caller responsibility to manage I/O deadlines on conn.
//
// It returns handshake info and, in br, a buffered reader holding bytes which
// could be written by the peer right after the response and be caught by us
// during the buffered read. br is nil when an error occurred or when no such
// bytes are buffered.
func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
	// headerSeen constants help to report whether or not some header was seen
	// while reading the response headers.
	const (
		headerSeenUpgrade = 1 << iota
		headerSeenConnection
		headerSeenSecAccept

		// headerSeenAll is the value that we expect to receive at the end of
		// headers read/parse loop.
		headerSeenAll = 0 |
			headerSeenUpgrade |
			headerSeenConnection |
			headerSeenSecAccept
	)

	br = pbufio.GetReader(conn,
		nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
	)
	bw := pbufio.GetWriter(conn,
		nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
	)
	defer func() {
		pbufio.PutWriter(bw)
		if br.Buffered() == 0 || err != nil {
			// The server did not write additional bytes to the connection,
			// or an error occurred. That is, no reason to return the buffer.
			pbufio.PutReader(br)
			br = nil
		}
	}()

	nonce := make([]byte, nonceSize)
	initNonce(nonce)

	// Write errors, if any, are reported by the Flush() call below.
	httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
	if err = bw.Flush(); err != nil {
		return
	}

	// Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
	sl, err := readLine(br)
	if err != nil {
		return
	}
	// Begin validation of the response.
	// See https://tools.ietf.org/html/rfc6455#section-4.2.2
	// Parse status line data like HTTP version, status code and reason.
	resp, err := httpParseResponseLine(sl)
	if err != nil {
		return
	}
	// Even if RFC says "1.1 or higher" without mentioning the part of the
	// version, we apply it only to minor part.
	if resp.major != 1 || resp.minor < 1 {
		err = ErrHandshakeBadProtocol
		return
	}
	if resp.status != 101 {
		err = StatusError(resp.status)
		if onStatusError := d.OnStatusError; onStatusError != nil {
			// Invoke callback with a reader that replays the status line
			// bytes and a CRLF, followed by the remaining data from br.
			onStatusError(resp.status, resp.reason,
				io.MultiReader(
					bytes.NewReader(sl),
					strings.NewReader(crlf),
					br,
				),
			)
		}
		return
	}
	// If response status is 101 then we expect all technical headers to be
	// valid. If not, then we stop processing response without giving user
	// ability to read non-technical headers. That is, we do not distinguish
	// technical errors (such as parsing error) and protocol errors.
	var headerSeen byte
	for {
		line, e := readLine(br)
		if e != nil {
			err = e
			return
		}
		if len(line) == 0 {
			// Blank line, no more lines to read.
			break
		}

		k, v, ok := httpParseHeaderLine(line)
		if !ok {
			err = ErrMalformedResponse
			return
		}

		switch btsToString(k) {
		case headerUpgradeCanonical:
			headerSeen |= headerSeenUpgrade
			// Exact comparison first, case-insensitive comparison second.
			if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
				err = ErrHandshakeBadUpgrade
				return
			}

		case headerConnectionCanonical:
			headerSeen |= headerSeenConnection
			// Note that as RFC6455 says:
			//   > A |Connection| header field with value "Upgrade".
			// That is, on the server side, the "Connection" header could
			// contain multiple tokens. But in a response it must contain
			// exactly one.
			if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
				err = ErrHandshakeBadConnection
				return
			}

		case headerSecAcceptCanonical:
			headerSeen |= headerSeenSecAccept
			if !checkAcceptFromNonce(v, nonce) {
				err = ErrHandshakeBadSecAccept
				return
			}

		case headerSecProtocolCanonical:
			// RFC6455 1.3:
			//   "The server selects one or none of the acceptable protocols
			//   and echoes that value in its handshake to indicate that it has
			//   selected that protocol."
			for _, want := range d.Protocols {
				if string(v) == want {
					hs.Protocol = want
					break
				}
			}
			if hs.Protocol == "" {
				// Server echoed subprotocol that is not present in client
				// requested protocols.
				err = ErrHandshakeBadSubProtocol
				return
			}

		case headerSecExtensionsCanonical:
			hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
			if err != nil {
				return
			}

		default:
			// Not a handshake header: hand it to the user callback, if any.
			if onHeader := d.OnHeader; onHeader != nil {
				if e := onHeader(k, v); e != nil {
					err = e
					return
				}
			}
		}
	}
	// All three mandatory headers must have been seen; report the first
	// missing one.
	if err == nil && headerSeen != headerSeenAll {
		switch {
		case headerSeen&headerSeenUpgrade == 0:
			err = ErrHandshakeBadUpgrade
		case headerSeen&headerSeenConnection == 0:
			err = ErrHandshakeBadConnection
		case headerSeen&headerSeenSecAccept == 0:
			err = ErrHandshakeBadSecAccept
		default:
			panic("unknown headers state")
		}
	}
	return
}

// PutReader returns a bufio.Reader instance to the inner reuse pool.
// It is useful in the rare cases when Dialer.Dial() returns a non-nil buffer
// which contains unprocessed buffered data that was sent by the server right
// after the handshake. br should not be used after the call.
func PutReader(br *bufio.Reader) {
	pbufio.PutReader(br)
}

// StatusError is the error carrying an unexpected HTTP status code received
// in the server's status line.
type StatusError int

// Error implements the error interface. The message includes the decimal
// status code.
func (s StatusError) Error() string {
	const prefix = "unexpected HTTP response status: "
	code := strconv.Itoa(int(s))
	return prefix + code
}

// isTimeoutError reports whether err is a net.Error whose Timeout() method
// returns true. A nil err, or an err that is not a net.Error, yields false.
func isTimeoutError(err error) bool {
	if netErr, ok := err.(net.Error); ok {
		return netErr.Timeout()
	}
	return false
}

// matchSelectedExtensions parses the extensions header value selected and
// checks every parsed extension against the client's wanted list.
//
// Each matching item from wanted is appended to received, and the resulting
// slice is returned. It returns ErrMalformedResponse when selected cannot be
// parsed and ErrHandshakeBadExtensions when the server selected an extension
// that is not present in wanted. An empty selected returns received as is.
func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
	if len(selected) == 0 {
		return received, nil
	}

	var (
		current httphead.Option // Extension being assembled by the scanner.
		scanErr error
		last    = -1 // Index of the extension currently being assembled.
	)
	// accept looks current up in wanted; on a hit the client's own item is
	// appended to received, so parsed data is never retained.
	accept := func() bool {
		for i := range wanted {
			if current.Equal(wanted[i]) {
				received = append(received, wanted[i])
				return true
			}
		}
		return false
	}

	wellFormed := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
		if i != last {
			// A new extension begins: the previous one (if any) is complete
			// and must be among the requested ones.
			last = i
			if i != 0 && !accept() {
				scanErr = ErrHandshakeBadExtensions
				return httphead.ControlBreak
			}
			current = httphead.Option{Name: name}
		}
		if attr != nil {
			current.Parameters.Set(attr, val)
		}
		return httphead.ControlContinue
	})

	switch {
	case !wellFormed:
		return received, ErrMalformedResponse
	case !accept():
		// The last assembled extension has not been checked by the scanner
		// callback yet.
		return received, ErrHandshakeBadExtensions
	}
	return received, scanErr
}

// setupContextDeadliner starts a goroutine that interrupts I/O on conn when
// ctx is done.
//
// When ctx is done, the goroutine calls conn.SetDeadline() with a time far in
// the past so that I/O operations on conn fail. The returned done function
// stops the goroutine, waits for it to finish and maps the caller's I/O error
// to a possible context error.
//
// Because of the possible SetDeadline() call inside the goroutine, the caller
// passes a pointer to its I/O error (even if that error is nil) to
// done(&err). If the goroutine observed ctx being done and *err is either nil
// or a timeout error, *err is replaced with ctx.Err(). If *err is not caused
// by a timeout, it is left untouched.
func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
	stop := make(chan struct{})
	// Buffered so that the goroutine never blocks on reporting its outcome.
	outcome := make(chan error, 1)

	go func() {
		select {
		case <-ctx.Done():
			// Cancel i/o immediately.
			conn.SetDeadline(aLongTimeAgo)
			outcome <- ctx.Err()
		case <-stop:
			outcome <- nil
		}
	}()

	done = func(err *error) {
		close(stop)
		ctxErr := <-outcome
		if ctxErr == nil {
			// The goroutine was stopped before the context was done.
			return
		}
		// A timeout error here means that I/O was canceled either by our
		// SetDeadline(aLongTimeAgo) call or by somebody else's earlier
		// conn.SetDeadline(x). Even if both deadlines expired, ctx.Err() is
		// preferred.
		if *err == nil || isTimeoutError(*err) {
			*err = ctxErr
		}
	}
	return done
}