469 lines
12 KiB
Go
469 lines
12 KiB
Go
|
package ws
|
|||
|
|
|||
|
import (
|
|||
|
"bufio"
|
|||
|
"bytes"
|
|||
|
"io"
|
|||
|
"net/http"
|
|||
|
"net/textproto"
|
|||
|
"net/url"
|
|||
|
"strconv"
|
|||
|
|
|||
|
"github.com/gobwas/httphead"
|
|||
|
)
|
|||
|
|
|||
|
// Literal fragments used when serializing HTTP messages by hand.
const (
	crlf          = "\r\n" // terminates the start line and every header line
	colonAndSpace = ": "   // separates a header key from its value
	commaAndSpace = ", "   // separates list elements inside a header value
)
|
|||
|
|
|||
|
// textHeadUpgrade is the fixed beginning of a successful handshake response:
// the 101 status line followed by the Upgrade and Connection header lines.
const textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\n" +
	"Upgrade: websocket\r\n" +
	"Connection: Upgrade\r\n"
|
|||
|
|
|||
|
var (
|
|||
|
textHeadBadRequest = statusText(http.StatusBadRequest)
|
|||
|
textHeadInternalServerError = statusText(http.StatusInternalServerError)
|
|||
|
textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
|
|||
|
|
|||
|
textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
|
|||
|
textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
|
|||
|
textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
|
|||
|
textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
|
|||
|
textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
|
|||
|
textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
|
|||
|
textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
|
|||
|
textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
|
|||
|
textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
|
|||
|
)
|
|||
|
|
|||
|
// Names of the HTTP headers used in the WebSocket opening handshake, spelled
// as they are written on the wire.
var (
	headerHost          = "Host"
	headerUpgrade       = "Upgrade"
	headerConnection    = "Connection"
	headerSecVersion    = "Sec-WebSocket-Version"
	headerSecProtocol   = "Sec-WebSocket-Protocol"
	headerSecExtensions = "Sec-WebSocket-Extensions"
	headerSecKey        = "Sec-WebSocket-Key"
	headerSecAccept     = "Sec-WebSocket-Accept"
)

// The same names in canonical MIME form (e.g. "Sec-Websocket-Key"), the form
// used for http.Header keys, so lookups such as httpGetHeader can skip
// canonicalization.
var (
	headerHostCanonical          = textproto.CanonicalMIMEHeaderKey(headerHost)
	headerUpgradeCanonical       = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
	headerConnectionCanonical    = textproto.CanonicalMIMEHeaderKey(headerConnection)
	headerSecVersionCanonical    = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
	headerSecProtocolCanonical   = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
	headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
	headerSecKeyCanonical        = textproto.CanonicalMIMEHeaderKey(headerSecKey)
	headerSecAcceptCanonical     = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
)
|
|||
|
|
|||
|
// Header values the WebSocket handshake uses (RFC 6455): the Upgrade token,
// the Connection token, and the protocol version number.
var (
	specHeaderValueUpgrade    = []byte("websocket")
	specHeaderValueConnection = []byte("Upgrade")
	// Lower-cased Connection token; presumably used for case-insensitive
	// matching elsewhere in the package — confirm at the use site.
	specHeaderValueConnectionLower = []byte("upgrade")
	specHeaderValueSecVersion      = []byte("13")
)
|
|||
|
|
|||
|
// Byte forms of HTTP protocol version strings, used by httpParseVersion: two
// complete versions for a fast exact match, and the common "HTTP/" prefix.
var (
	httpVersion1_0    = []byte("HTTP/1.0")
	httpVersion1_1    = []byte("HTTP/1.1")
	httpVersionPrefix = []byte("HTTP/")
)
|
|||
|
|
|||
|
// httpRequestLine holds the pieces of an HTTP request line such as
// "GET /chat HTTP/1.1", as produced by httpParseRequestLine.
type httpRequestLine struct {
	method []byte // request method, e.g. "GET"
	uri    []byte // request target exactly as it appeared on the line
	major  int    // protocol major version
	minor  int    // protocol minor version
}
|
|||
|
|
|||
|
// httpResponseLine holds the pieces of an HTTP status line such as
// "HTTP/1.1 101 Switching Protocols", as produced by httpParseResponseLine.
type httpResponseLine struct {
	major  int    // protocol major version
	minor  int    // protocol minor version
	status int    // numeric status code
	reason []byte // reason phrase following the status code
}
|
|||
|
|
|||
|
// httpParseRequestLine parses http request line like "GET / HTTP/1.0".
|
|||
|
func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
|
|||
|
var proto []byte
|
|||
|
req.method, req.uri, proto = bsplit3(line, ' ')
|
|||
|
|
|||
|
var ok bool
|
|||
|
req.major, req.minor, ok = httpParseVersion(proto)
|
|||
|
if !ok {
|
|||
|
err = ErrMalformedRequest
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
|
|||
|
var (
|
|||
|
proto []byte
|
|||
|
status []byte
|
|||
|
)
|
|||
|
proto, status, resp.reason = bsplit3(line, ' ')
|
|||
|
|
|||
|
var ok bool
|
|||
|
resp.major, resp.minor, ok = httpParseVersion(proto)
|
|||
|
if !ok {
|
|||
|
return resp, ErrMalformedResponse
|
|||
|
}
|
|||
|
|
|||
|
var convErr error
|
|||
|
resp.status, convErr = asciiToInt(status)
|
|||
|
if convErr != nil {
|
|||
|
return resp, ErrMalformedResponse
|
|||
|
}
|
|||
|
|
|||
|
return resp, nil
|
|||
|
}
|
|||
|
|
|||
|
// httpParseVersion parses major and minor version of HTTP protocol. It returns
|
|||
|
// parsed values and true if parse is ok.
|
|||
|
func httpParseVersion(bts []byte) (major, minor int, ok bool) {
|
|||
|
switch {
|
|||
|
case bytes.Equal(bts, httpVersion1_0):
|
|||
|
return 1, 0, true
|
|||
|
case bytes.Equal(bts, httpVersion1_1):
|
|||
|
return 1, 1, true
|
|||
|
case len(bts) < 8:
|
|||
|
return
|
|||
|
case !bytes.Equal(bts[:5], httpVersionPrefix):
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
bts = bts[5:]
|
|||
|
|
|||
|
dot := bytes.IndexByte(bts, '.')
|
|||
|
if dot == -1 {
|
|||
|
return
|
|||
|
}
|
|||
|
var err error
|
|||
|
major, err = asciiToInt(bts[:dot])
|
|||
|
if err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
minor, err = asciiToInt(bts[dot+1:])
|
|||
|
if err != nil {
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
return major, minor, true
|
|||
|
}
|
|||
|
|
|||
|
// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
|
|||
|
// values and true if parse is ok.
|
|||
|
func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
|
|||
|
colon := bytes.IndexByte(line, ':')
|
|||
|
if colon == -1 {
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
k = btrim(line[:colon])
|
|||
|
// TODO(gobwas): maybe use just lower here?
|
|||
|
canonicalizeHeaderKey(k)
|
|||
|
|
|||
|
v = btrim(line[colon+1:])
|
|||
|
|
|||
|
return k, v, true
|
|||
|
}
|
|||
|
|
|||
|
// httpGetHeader returns the first value stored under key in h, or "" when
// there is none.
//
// Unlike http.Header.Get (textproto.MIMEHeader.Get) it does not canonicalize
// key. The caller must pass an already canonical key; skipping that step is
// the point of this helper.
func httpGetHeader(h http.Header, key string) string {
	// Indexing a nil map yields nil, so a nil h needs no special case.
	if values := h[key]; len(values) > 0 {
		return values[0]
	}
	return ""
}
|
|||
|
|
|||
|
// The request MAY include a header field with the name
|
|||
|
// |Sec-WebSocket-Protocol|. If present, this value indicates one or more
|
|||
|
// comma-separated subprotocol the client wishes to speak, ordered by
|
|||
|
// preference. The elements that comprise this value MUST be non-empty strings
|
|||
|
// with characters in the range U+0021 to U+007E not including separator
|
|||
|
// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
|
|||
|
// for the value of this header field is 1#token, where the definitions of
|
|||
|
// constructs and rules are as given in [RFC2616].
|
|||
|
func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
|
|||
|
ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
|
|||
|
if check(btsToString(v)) {
|
|||
|
ret = string(v)
|
|||
|
return false
|
|||
|
}
|
|||
|
return true
|
|||
|
})
|
|||
|
return
|
|||
|
}
|
|||
|
func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
|
|||
|
var selected []byte
|
|||
|
ok = httphead.ScanTokens(h, func(v []byte) bool {
|
|||
|
if check(v) {
|
|||
|
selected = v
|
|||
|
return false
|
|||
|
}
|
|||
|
return true
|
|||
|
})
|
|||
|
if ok && selected != nil {
|
|||
|
return string(selected), true
|
|||
|
}
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
func strSelectExtensions(h string, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
|
|||
|
return btsSelectExtensions(strToBytes(h), selected, check)
|
|||
|
}
|
|||
|
|
|||
|
func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
|
|||
|
s := httphead.OptionSelector{
|
|||
|
Flags: httphead.SelectUnique | httphead.SelectCopy,
|
|||
|
Check: check,
|
|||
|
}
|
|||
|
return s.Select(h, selected)
|
|||
|
}
|
|||
|
|
|||
|
func httpWriteHeader(bw *bufio.Writer, key, value string) {
|
|||
|
httpWriteHeaderKey(bw, key)
|
|||
|
bw.WriteString(value)
|
|||
|
bw.WriteString(crlf)
|
|||
|
}
|
|||
|
|
|||
|
func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
|
|||
|
httpWriteHeaderKey(bw, key)
|
|||
|
bw.Write(value)
|
|||
|
bw.WriteString(crlf)
|
|||
|
}
|
|||
|
|
|||
|
func httpWriteHeaderKey(bw *bufio.Writer, key string) {
|
|||
|
bw.WriteString(key)
|
|||
|
bw.WriteString(colonAndSpace)
|
|||
|
}
|
|||
|
|
|||
|
func httpWriteUpgradeRequest(
|
|||
|
bw *bufio.Writer,
|
|||
|
u *url.URL,
|
|||
|
nonce []byte,
|
|||
|
protocols []string,
|
|||
|
extensions []httphead.Option,
|
|||
|
header HandshakeHeader,
|
|||
|
) {
|
|||
|
bw.WriteString("GET ")
|
|||
|
bw.WriteString(u.RequestURI())
|
|||
|
bw.WriteString(" HTTP/1.1\r\n")
|
|||
|
|
|||
|
httpWriteHeader(bw, headerHost, u.Host)
|
|||
|
|
|||
|
httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
|
|||
|
httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
|
|||
|
httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
|
|||
|
|
|||
|
// NOTE: write nonce bytes as a string to prevent heap allocation –
|
|||
|
// WriteString() copy given string into its inner buffer, unlike Write()
|
|||
|
// which may write p directly to the underlying io.Writer – which in turn
|
|||
|
// will lead to p escape.
|
|||
|
httpWriteHeader(bw, headerSecKey, btsToString(nonce))
|
|||
|
|
|||
|
if len(protocols) > 0 {
|
|||
|
httpWriteHeaderKey(bw, headerSecProtocol)
|
|||
|
for i, p := range protocols {
|
|||
|
if i > 0 {
|
|||
|
bw.WriteString(commaAndSpace)
|
|||
|
}
|
|||
|
bw.WriteString(p)
|
|||
|
}
|
|||
|
bw.WriteString(crlf)
|
|||
|
}
|
|||
|
|
|||
|
if len(extensions) > 0 {
|
|||
|
httpWriteHeaderKey(bw, headerSecExtensions)
|
|||
|
httphead.WriteOptions(bw, extensions)
|
|||
|
bw.WriteString(crlf)
|
|||
|
}
|
|||
|
|
|||
|
if header != nil {
|
|||
|
header.WriteTo(bw)
|
|||
|
}
|
|||
|
|
|||
|
bw.WriteString(crlf)
|
|||
|
}
|
|||
|
|
|||
|
func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
|
|||
|
bw.WriteString(textHeadUpgrade)
|
|||
|
|
|||
|
httpWriteHeaderKey(bw, headerSecAccept)
|
|||
|
writeAccept(bw, nonce)
|
|||
|
bw.WriteString(crlf)
|
|||
|
|
|||
|
if hs.Protocol != "" {
|
|||
|
httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
|
|||
|
}
|
|||
|
if len(hs.Extensions) > 0 {
|
|||
|
httpWriteHeaderKey(bw, headerSecExtensions)
|
|||
|
httphead.WriteOptions(bw, hs.Extensions)
|
|||
|
bw.WriteString(crlf)
|
|||
|
}
|
|||
|
if header != nil {
|
|||
|
header(bw)
|
|||
|
}
|
|||
|
|
|||
|
bw.WriteString(crlf)
|
|||
|
}
|
|||
|
|
|||
|
func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
|
|||
|
switch code {
|
|||
|
case http.StatusBadRequest:
|
|||
|
bw.WriteString(textHeadBadRequest)
|
|||
|
case http.StatusInternalServerError:
|
|||
|
bw.WriteString(textHeadInternalServerError)
|
|||
|
case http.StatusUpgradeRequired:
|
|||
|
bw.WriteString(textHeadUpgradeRequired)
|
|||
|
default:
|
|||
|
writeStatusText(bw, code)
|
|||
|
}
|
|||
|
|
|||
|
// Write custom headers.
|
|||
|
if header != nil {
|
|||
|
header(bw)
|
|||
|
}
|
|||
|
|
|||
|
switch err {
|
|||
|
case ErrHandshakeBadProtocol:
|
|||
|
bw.WriteString(textTailErrHandshakeBadProtocol)
|
|||
|
case ErrHandshakeBadMethod:
|
|||
|
bw.WriteString(textTailErrHandshakeBadMethod)
|
|||
|
case ErrHandshakeBadHost:
|
|||
|
bw.WriteString(textTailErrHandshakeBadHost)
|
|||
|
case ErrHandshakeBadUpgrade:
|
|||
|
bw.WriteString(textTailErrHandshakeBadUpgrade)
|
|||
|
case ErrHandshakeBadConnection:
|
|||
|
bw.WriteString(textTailErrHandshakeBadConnection)
|
|||
|
case ErrHandshakeBadSecAccept:
|
|||
|
bw.WriteString(textTailErrHandshakeBadSecAccept)
|
|||
|
case ErrHandshakeBadSecKey:
|
|||
|
bw.WriteString(textTailErrHandshakeBadSecKey)
|
|||
|
case ErrHandshakeBadSecVersion:
|
|||
|
bw.WriteString(textTailErrHandshakeBadSecVersion)
|
|||
|
case ErrHandshakeUpgradeRequired:
|
|||
|
bw.WriteString(textTailErrUpgradeRequired)
|
|||
|
case nil:
|
|||
|
bw.WriteString(crlf)
|
|||
|
default:
|
|||
|
writeErrorText(bw, err)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
func writeStatusText(bw *bufio.Writer, code int) {
|
|||
|
bw.WriteString("HTTP/1.1 ")
|
|||
|
bw.WriteString(strconv.Itoa(code))
|
|||
|
bw.WriteByte(' ')
|
|||
|
bw.WriteString(http.StatusText(code))
|
|||
|
bw.WriteString(crlf)
|
|||
|
bw.WriteString("Content-Type: text/plain; charset=utf-8")
|
|||
|
bw.WriteString(crlf)
|
|||
|
}
|
|||
|
|
|||
|
func writeErrorText(bw *bufio.Writer, err error) {
|
|||
|
body := err.Error()
|
|||
|
bw.WriteString("Content-Length: ")
|
|||
|
bw.WriteString(strconv.Itoa(len(body)))
|
|||
|
bw.WriteString(crlf)
|
|||
|
bw.WriteString(crlf)
|
|||
|
bw.WriteString(body)
|
|||
|
}
|
|||
|
|
|||
|
// httpError replies to an HTTP request with body as a plain-text response and
// the given status code. It is similar to http.Error, but it sets an explicit
// Content-Length and writes body exactly as given.
//
// Write errors are ignored because the function has no error result.
func httpError(w http.ResponseWriter, body string, code int) {
	h := w.Header()
	h.Set("Content-Type", "text/plain; charset=utf-8")
	h.Set("Content-Length", strconv.Itoa(len(body)))
	w.WriteHeader(code)
	// io.WriteString uses w's WriteString method when it has one, avoiding
	// the copy that []byte(body) would make.
	_, _ = io.WriteString(w, body)
}
|
|||
|
|
|||
|
// statusText is a non-performant status text generator.
|
|||
|
// NOTE: Used only to generate constants.
|
|||
|
func statusText(code int) string {
|
|||
|
var buf bytes.Buffer
|
|||
|
bw := bufio.NewWriter(&buf)
|
|||
|
writeStatusText(bw, code)
|
|||
|
bw.Flush()
|
|||
|
return buf.String()
|
|||
|
}
|
|||
|
|
|||
|
// errorText is a non-performant error text generator.
|
|||
|
// NOTE: Used only to generate constants.
|
|||
|
func errorText(err error) string {
|
|||
|
var buf bytes.Buffer
|
|||
|
bw := bufio.NewWriter(&buf)
|
|||
|
writeErrorText(bw, err)
|
|||
|
bw.Flush()
|
|||
|
return buf.String()
|
|||
|
}
|
|||
|
|
|||
|
// HandshakeHeader is implemented by values that write extra headers, for
// either the upgrade request or the response, into a given io.Writer. Its
// method set is exactly that of io.WriterTo.
type HandshakeHeader interface {
	WriteTo(w io.Writer) (n int64, err error)
}
|
|||
|
|
|||
|
// HandshakeHeaderString lets a preformatted string of header lines be used as
// a HandshakeHeader.
type HandshakeHeaderString string

// WriteTo implements HandshakeHeader (and io.WriterTo) interface. It writes s
// unchanged to w and reports how many bytes w accepted.
func (s HandshakeHeaderString) WriteTo(w io.Writer) (n int64, err error) {
	var written int
	written, err = io.WriteString(w, string(s))
	n = int64(written)
	return n, err
}
|
|||
|
|
|||
|
// HandshakeHeaderBytes lets a preformatted byte slice of header lines be used
// as a HandshakeHeader.
type HandshakeHeaderBytes []byte

// WriteTo implements HandshakeHeader (and io.WriterTo) interface. It passes b
// to w.Write in a single call and reports how many bytes w accepted.
func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (n int64, err error) {
	var written int
	written, err = w.Write(b)
	n = int64(written)
	return n, err
}
|
|||
|
|
|||
|
// HandshakeHeaderFunc lets an ordinary function that writes headers be used as
// a HandshakeHeader.
type HandshakeHeaderFunc func(io.Writer) (int64, error)

// WriteTo implements HandshakeHeader (and io.WriterTo) interface. It calls f
// with w and passes f's results through unchanged.
func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
	n, err := f(w)
	return n, err
}
|
|||
|
|
|||
|
// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
|
|||
|
// HandshakeHeader.
|
|||
|
type HandshakeHeaderHTTP http.Header
|
|||
|
|
|||
|
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
|||
|
func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
|
|||
|
wr := writer{w: w}
|
|||
|
err := http.Header(h).Write(&wr)
|
|||
|
return wr.n, err
|
|||
|
}
|
|||
|
|
|||
|
// writer wraps an io.Writer and accumulates in n the byte counts reported by
// the underlying writer. It implements both io.Writer and io.StringWriter, so
// callers that check for io.StringWriter (such as http.Header.Write) can write
// strings without converting them.
type writer struct {
	n int64     // running total of bytes reported written
	w io.Writer // destination
}

// Write forwards p to the underlying writer and adds the reported count to n.
func (w *writer) Write(p []byte) (int, error) {
	written, err := w.w.Write(p)
	w.n += int64(written)
	return written, err
}

// WriteString forwards s to the underlying writer via io.WriteString and adds
// the reported count to n.
func (w *writer) WriteString(s string) (int, error) {
	written, err := io.WriteString(w.w, s)
	w.n += int64(written)
	return written, err
}
|