diff --git a/connection/header.go b/connection/header.go index f53e4483..77b69ab8 100644 --- a/connection/header.go +++ b/connection/header.go @@ -8,16 +8,14 @@ import ( ) const ( - responseMetaHeaderField = "cf-cloudflared-response-meta" - responseSourceCloudflared = "cloudflared" - responseSourceOrigin = "origin" + responseMetaHeaderField = "cf-cloudflared-response-meta" ) var ( canonicalResponseUserHeadersField = http.CanonicalHeaderKey(h2mux.ResponseUserHeadersField) canonicalResponseMetaHeaderField = http.CanonicalHeaderKey(responseMetaHeaderField) - responseMetaHeaderCfd = mustInitRespMetaHeader(responseSourceCloudflared) - responseMetaHeaderOrigin = mustInitRespMetaHeader(responseSourceOrigin) + responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared") + responseMetaHeaderOrigin = mustInitRespMetaHeader("origin") ) type responseMetaHeader struct { diff --git a/go.mod b/go.mod index 8b94040c..9a1b780f 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,9 @@ require ( github.com/getsentry/raven-go v0.0.0-20180517221441-ed7bcb39ff10 github.com/gliderlabs/ssh v0.0.0-20191009160644-63518b5243e0 github.com/go-sql-driver/mysql v1.5.0 + github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 // indirect + github.com/gobwas/pool v0.2.1 // indirect + github.com/gobwas/ws v1.0.4 github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 github.com/google/go-cmp v0.5.2 // indirect github.com/google/uuid v1.1.2 diff --git a/go.sum b/go.sum index c3663f73..cd99d28e 100644 --- a/go.sum +++ b/go.sum @@ -233,6 +233,12 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ= +github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs= +github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= diff --git a/hello/hello.go b/hello/hello.go index a2b12798..63d1830e 100644 --- a/hello/hello.go +++ b/hello/hello.go @@ -18,6 +18,11 @@ import ( "github.com/cloudflare/cloudflared/tlsconfig" ) +const ( + UptimeRoute = "/uptime" + WSRoute = "/ws" +) + type templateData struct { ServerName string Request *http.Request @@ -104,8 +109,8 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow } muxer := http.NewServeMux() - muxer.HandleFunc("/uptime", uptimeHandler(time.Now())) - muxer.HandleFunc("/ws", websocketHandler(logger, upgrader)) + muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now())) + muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader)) muxer.HandleFunc("/", rootHandler(serverName)) httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer} go func() { diff --git a/origin/proxy.go b/origin/proxy.go index 638aee92..6042779e 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -2,6 +2,7 @@ package origin import ( "bufio" + "context" "crypto/tls" "io" "net/http" @@ -112,12 +113,17 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) { c.setHostHeader(req) - conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig) if err != nil { return nil, err } - defer conn.Close() + + serveCtx, cancel := context.WithCancel(req.Context()) + defer cancel() + go func() { + <-serveCtx.Done() + conn.Close() + }() err = w.WriteRespHeaders(resp) if err != nil { return nil, errors.Wrap(err, "Error writing response header") @@ -125,7 +131,6 @@ func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) // Copy to/from stream to the undelying connection. Use the underlying // connection because cloudflared doesn't operate on the message themselves websocket.Stream(conn.UnderlyingConn(), w) - return resp, nil } diff --git a/origin/proxy_test.go b/origin/proxy_test.go new file mode 100644 index 00000000..3d43e997 --- /dev/null +++ b/origin/proxy_test.go @@ -0,0 +1,169 @@ +package origin + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/hello" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/tlsconfig" + + "github.com/gobwas/ws/wsutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockHTTPRespWriter struct { + *httptest.ResponseRecorder +} + +func newMockHTTPRespWriter() *mockHTTPRespWriter { + return &mockHTTPRespWriter{ + httptest.NewRecorder(), + } +} + +func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error { + w.WriteHeader(resp.StatusCode) + for header, val := range resp.Header { + w.Header()[header] = val + } + return nil +} + +func (w *mockHTTPRespWriter) WriteErrorResponse(err error) { + w.WriteHeader(http.StatusBadGateway) +} + +func (w *mockHTTPRespWriter) Read(data []byte) (int, error) { + return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader") +} + +type mockWSRespWriter struct { + *mockHTTPRespWriter + writeNotification chan []byte + reader io.Reader +} + +func newMockWSRespWriter(httpRespWriter *mockHTTPRespWriter, reader io.Reader) *mockWSRespWriter { + return &mockWSRespWriter{ + httpRespWriter, + make(chan []byte), + reader, + } +} + +func (w *mockWSRespWriter) Write(data []byte) (int, error) { + w.writeNotification <- data + return len(data), nil +} + +func (w *mockWSRespWriter) respBody() io.ReadWriter { + data := <-w.writeNotification + return bytes.NewBuffer(data) +} + +func (w *mockWSRespWriter) Read(data []byte) (int, error) { + return w.reader.Read(data) +} + +func TestProxy(t *testing.T) { + logger, err := logger.New() + require.NoError(t, err) + // let runtime pick an available port + listener, err := hello.CreateTLSListener("127.0.0.1:0") + require.NoError(t, err) + + originURL := &url.URL{ + Scheme: "https", + Host: listener.Addr().String(), + } + originCA := x509.NewCertPool() + helloCert, err := tlsconfig.GetHelloCertificateX509() + require.NoError(t, err) + originCA.AddCert(helloCert) + clientTLS := &tls.Config{ + RootCAs: originCA, + } + proxyConfig := &ProxyConfig{ + Client: &http.Transport{ + TLSClientConfig: clientTLS, + }, + URL: originURL, + TLSConfig: clientTLS, + } + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + hello.StartHelloWorldServer(logger, listener, ctx.Done()) + }() + + client := NewClient(proxyConfig, logger) + t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL)) + t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS)) + cancel() +} + +func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) { + return func(t *testing.T) { + respWriter := newMockHTTPRespWriter() + req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) + require.NoError(t, err) + + err = client.Proxy(respWriter, req, false) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, respWriter.Code) + } +} + +func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL *url.URL, tlsConfig *tls.Config) func(t *testing.T) { + return func(t *testing.T) { + // WSRoute is a websocket echo handler + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil) + + readPipe, writePipe := io.Pipe() + respWriter := newMockWSRespWriter(newMockHTTPRespWriter(), readPipe) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err = client.Proxy(respWriter, req, true) + require.NoError(t, err) + + require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code) + }() + + msg := []byte("test websocket") + err = wsutil.WriteClientText(writePipe, msg) + require.NoError(t, err) + + // ReadServerText reads next data message from rw, considering that caller represents client side. + returnedMsg, err := wsutil.ReadServerText(respWriter.respBody()) + require.NoError(t, err) + require.Equal(t, msg, returnedMsg) + + err = wsutil.WriteClientBinary(writePipe, msg) + require.NoError(t, err) + + returnedMsg, err = wsutil.ReadServerBinary(respWriter.respBody()) + require.NoError(t, err) + require.Equal(t, msg, returnedMsg) + + cancel() + wg.Wait() + } +} diff --git a/vendor/github.com/gobwas/httphead/LICENSE b/vendor/github.com/gobwas/httphead/LICENSE new file mode 100644 index 00000000..27443176 --- /dev/null +++ b/vendor/github.com/gobwas/httphead/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017 Sergey Kamardin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/gobwas/httphead/README.md b/vendor/github.com/gobwas/httphead/README.md new file mode 100644 index 00000000..67a97fdb --- /dev/null +++ b/vendor/github.com/gobwas/httphead/README.md @@ -0,0 +1,63 @@ +# httphead.[go](https://golang.org) + +[![GoDoc][godoc-image]][godoc-url] + +> Tiny HTTP header value parsing library in go. + +## Overview + +This library contains low-level functions for scanning HTTP RFC2616 compatible header value grammars. + +## Install + +```shell + go get github.com/gobwas/httphead +``` + +## Example + +The example below shows how multiple-choise HTTP header value could be parsed with this library: + +```go + options, ok := httphead.ParseOptions([]byte(`foo;bar=1,baz`), nil) + fmt.Println(options, ok) + // Output: [{foo map[bar:1]} {baz map[]}] true +``` + +The low-level example below shows how to optimize keys skipping and selection +of some key: + +```go + // The right part of full header line like: + // X-My-Header: key;foo=bar;baz,key;baz + header := []byte(`foo;a=0,foo;a=1,foo;a=2,foo;a=3`) + + // We want to search key "foo" with an "a" parameter that equal to "2". + var ( + foo = []byte(`foo`) + a = []byte(`a`) + v = []byte(`2`) + ) + var found bool + httphead.ScanOptions(header, func(i int, key, param, value []byte) Control { + if !bytes.Equal(key, foo) { + return ControlSkip + } + if !bytes.Equal(param, a) { + if bytes.Equal(value, v) { + // Found it! + found = true + return ControlBreak + } + return ControlSkip + } + return ControlContinue + }) +``` + +For more usage examples please see [docs][godoc-url] or package tests. + +[godoc-image]: https://godoc.org/github.com/gobwas/httphead?status.svg +[godoc-url]: https://godoc.org/github.com/gobwas/httphead +[travis-image]: https://travis-ci.org/gobwas/httphead.svg?branch=master +[travis-url]: https://travis-ci.org/gobwas/httphead diff --git a/vendor/github.com/gobwas/httphead/cookie.go b/vendor/github.com/gobwas/httphead/cookie.go new file mode 100644 index 00000000..05c9a1fb --- /dev/null +++ b/vendor/github.com/gobwas/httphead/cookie.go @@ -0,0 +1,200 @@ +package httphead + +import ( + "bytes" +) + +// ScanCookie scans cookie pairs from data using DefaultCookieScanner.Scan() +// method. +func ScanCookie(data []byte, it func(key, value []byte) bool) bool { + return DefaultCookieScanner.Scan(data, it) +} + +// DefaultCookieScanner is a CookieScanner which is used by ScanCookie(). +// Note that it is intended to have the same behavior as http.Request.Cookies() +// has. +var DefaultCookieScanner = CookieScanner{} + +// CookieScanner contains options for scanning cookie pairs. +// See https://tools.ietf.org/html/rfc6265#section-4.1.1 +type CookieScanner struct { + // DisableNameValidation disables name validation of a cookie. If false, + // only RFC2616 "tokens" are accepted. + DisableNameValidation bool + + // DisableValueValidation disables value validation of a cookie. If false, + // only RFC6265 "cookie-octet" characters are accepted. + // + // Note that Strict option also affects validation of a value. + // + // If Strict is false, then scanner begins to allow space and comma + // characters inside the value for better compatibility with non standard + // cookies implementations. + DisableValueValidation bool + + // BreakOnPairError sets scanner to immediately return after first pair syntax + // validation error. + // If false, scanner will try to skip invalid pair bytes and go ahead. + BreakOnPairError bool + + // Strict enables strict RFC6265 mode scanning. It affects name and value + // validation, as also some other rules. + // If false, it is intended to bring the same behavior as + // http.Request.Cookies(). + Strict bool +} + +// Scan maps data to name and value pairs. Usually data represents value of the +// Cookie header. +func (c CookieScanner) Scan(data []byte, it func(name, value []byte) bool) bool { + lexer := &Scanner{data: data} + + const ( + statePair = iota + stateBefore + ) + + state := statePair + + for lexer.Buffered() > 0 { + switch state { + case stateBefore: + // Pairs separated by ";" and space, according to the RFC6265: + // cookie-pair *( ";" SP cookie-pair ) + // + // Cookie pairs MUST be separated by (";" SP). So our only option + // here is to fail as syntax error. + a, b := lexer.Peek2() + if a != ';' { + return false + } + + state = statePair + + advance := 1 + if b == ' ' { + advance++ + } else if c.Strict { + return false + } + + lexer.Advance(advance) + + case statePair: + if !lexer.FetchUntil(';') { + return false + } + + var value []byte + name := lexer.Bytes() + if i := bytes.IndexByte(name, '='); i != -1 { + value = name[i+1:] + name = name[:i] + } else if c.Strict { + if !c.BreakOnPairError { + goto nextPair + } + return false + } + + if !c.Strict { + trimLeft(name) + } + if !c.DisableNameValidation && !ValidCookieName(name) { + if !c.BreakOnPairError { + goto nextPair + } + return false + } + + if !c.Strict { + value = trimRight(value) + } + value = stripQuotes(value) + if !c.DisableValueValidation && !ValidCookieValue(value, c.Strict) { + if !c.BreakOnPairError { + goto nextPair + } + return false + } + + if !it(name, value) { + return true + } + + nextPair: + state = stateBefore + } + } + + return true +} + +// ValidCookieValue reports whether given value is a valid RFC6265 +// "cookie-octet" bytes. +// +// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E +// ; US-ASCII characters excluding CTLs, +// ; whitespace DQUOTE, comma, semicolon, +// ; and backslash +// +// Note that the false strict parameter disables errors on space 0x20 and comma +// 0x2c. This could be useful to bring some compatibility with non-compliant +// clients/servers in the real world. +// It acts the same as standard library cookie parser if strict is false. +func ValidCookieValue(value []byte, strict bool) bool { + if len(value) == 0 { + return true + } + for _, c := range value { + switch c { + case '"', ';', '\\': + return false + case ',', ' ': + if strict { + return false + } + default: + if c <= 0x20 { + return false + } + if c >= 0x7f { + return false + } + } + } + return true +} + +// ValidCookieName reports wheter given bytes is a valid RFC2616 "token" bytes. +func ValidCookieName(name []byte) bool { + for _, c := range name { + if !OctetTypes[c].IsToken() { + return false + } + } + return true +} + +func stripQuotes(bts []byte) []byte { + if last := len(bts) - 1; last > 0 && bts[0] == '"' && bts[last] == '"' { + return bts[1:last] + } + return bts +} + +func trimLeft(p []byte) []byte { + var i int + for i < len(p) && OctetTypes[p[i]].IsSpace() { + i++ + } + return p[i:] +} + +func trimRight(p []byte) []byte { + j := len(p) + for j > 0 && OctetTypes[p[j-1]].IsSpace() { + j-- + } + return p[:j] +} diff --git a/vendor/github.com/gobwas/httphead/head.go b/vendor/github.com/gobwas/httphead/head.go new file mode 100644 index 00000000..a50e907d --- /dev/null +++ b/vendor/github.com/gobwas/httphead/head.go @@ -0,0 +1,275 @@ +package httphead + +import ( + "bufio" + "bytes" +) + +// Version contains protocol major and minor version. +type Version struct { + Major int + Minor int +} + +// RequestLine contains parameters parsed from the first request line. +type RequestLine struct { + Method []byte + URI []byte + Version Version +} + +// ResponseLine contains parameters parsed from the first response line. +type ResponseLine struct { + Version Version + Status int + Reason []byte +} + +// SplitRequestLine splits given slice of bytes into three chunks without +// parsing. +func SplitRequestLine(line []byte) (method, uri, version []byte) { + return split3(line, ' ') +} + +// ParseRequestLine parses http request line like "GET / HTTP/1.0". +func ParseRequestLine(line []byte) (r RequestLine, ok bool) { + var i int + for i = 0; i < len(line); i++ { + c := line[i] + if !OctetTypes[c].IsToken() { + if i > 0 && c == ' ' { + break + } + return + } + } + if i == len(line) { + return + } + + var proto []byte + r.Method = line[:i] + r.URI, proto = split2(line[i+1:], ' ') + if len(r.URI) == 0 { + return + } + if major, minor, ok := ParseVersion(proto); ok { + r.Version.Major = major + r.Version.Minor = minor + return r, true + } + + return r, false +} + +// SplitResponseLine splits given slice of bytes into three chunks without +// parsing. +func SplitResponseLine(line []byte) (version, status, reason []byte) { + return split3(line, ' ') +} + +// ParseResponseLine parses first response line into ResponseLine struct. +func ParseResponseLine(line []byte) (r ResponseLine, ok bool) { + var ( + proto []byte + status []byte + ) + proto, status, r.Reason = split3(line, ' ') + if major, minor, ok := ParseVersion(proto); ok { + r.Version.Major = major + r.Version.Minor = minor + } else { + return r, false + } + if n, ok := IntFromASCII(status); ok { + r.Status = n + } else { + return r, false + } + // TODO(gobwas): parse here r.Reason fot TEXT rule: + // TEXT = + return r, true +} + +var ( + httpVersion10 = []byte("HTTP/1.0") + httpVersion11 = []byte("HTTP/1.1") + httpVersionPrefix = []byte("HTTP/") +) + +// ParseVersion parses major and minor version of HTTP protocol. +// It returns parsed values and true if parse is ok. +func ParseVersion(bts []byte) (major, minor int, ok bool) { + switch { + case bytes.Equal(bts, httpVersion11): + return 1, 1, true + case bytes.Equal(bts, httpVersion10): + return 1, 0, true + case len(bts) < 8: + return + case !bytes.Equal(bts[:5], httpVersionPrefix): + return + } + + bts = bts[5:] + + dot := bytes.IndexByte(bts, '.') + if dot == -1 { + return + } + major, ok = IntFromASCII(bts[:dot]) + if !ok { + return + } + minor, ok = IntFromASCII(bts[dot+1:]) + if !ok { + return + } + + return major, minor, true +} + +// ReadLine reads line from br. It reads until '\n' and returns bytes without +// '\n' or '\r\n' at the end. +// It returns err if and only if line does not end in '\n'. Note that read +// bytes returned in any case of error. +// +// It is much like the textproto/Reader.ReadLine() except the thing that it +// returns raw bytes, instead of string. That is, it avoids copying bytes read +// from br. +// +// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be +// safe with future I/O operations on br. +// +// We could control I/O operations on br and do not need to make additional +// copy for safety. +func ReadLine(br *bufio.Reader) ([]byte, error) { + var line []byte + for { + bts, err := br.ReadSlice('\n') + if err == bufio.ErrBufferFull { + // Copy bytes because next read will discard them. + line = append(line, bts...) + continue + } + // Avoid copy of single read. + if line == nil { + line = bts + } else { + line = append(line, bts...) + } + if err != nil { + return line, err + } + // Size of line is at least 1. + // In other case bufio.ReadSlice() returns error. + n := len(line) + // Cut '\n' or '\r\n'. + if n > 1 && line[n-2] == '\r' { + line = line[:n-2] + } else { + line = line[:n-1] + } + return line, nil + } +} + +// ParseHeaderLine parses HTTP header as key-value pair. It returns parsed +// values and true if parse is ok. +func ParseHeaderLine(line []byte) (k, v []byte, ok bool) { + colon := bytes.IndexByte(line, ':') + if colon == -1 { + return + } + k = trim(line[:colon]) + for _, c := range k { + if !OctetTypes[c].IsToken() { + return nil, nil, false + } + } + v = trim(line[colon+1:]) + return k, v, true +} + +// IntFromASCII converts ascii encoded decimal numeric value from HTTP entities +// to an integer. +func IntFromASCII(bts []byte) (ret int, ok bool) { + // ASCII numbers all start with the high-order bits 0011. + // If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those + // bits and interpret them directly as an integer. + var n int + if n = len(bts); n < 1 { + return 0, false + } + for i := 0; i < n; i++ { + if bts[i]&0xf0 != 0x30 { + return 0, false + } + ret += int(bts[i]&0xf) * pow(10, n-i-1) + } + return ret, true +} + +const ( + toLower = 'a' - 'A' // for use with OR. + toUpper = ^byte(toLower) // for use with AND. +) + +// CanonicalizeHeaderKey is like standard textproto/CanonicalMIMEHeaderKey, +// except that it operates with slice of bytes and modifies it inplace without +// copying. +func CanonicalizeHeaderKey(k []byte) { + upper := true + for i, c := range k { + if upper && 'a' <= c && c <= 'z' { + k[i] &= toUpper + } else if !upper && 'A' <= c && c <= 'Z' { + k[i] |= toLower + } + upper = c == '-' + } +} + +// pow for integers implementation. +// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3 +func pow(a, b int) int { + p := 1 + for b > 0 { + if b&1 != 0 { + p *= a + } + b >>= 1 + a *= a + } + return p +} + +func split3(p []byte, sep byte) (p1, p2, p3 []byte) { + a := bytes.IndexByte(p, sep) + b := bytes.IndexByte(p[a+1:], sep) + if a == -1 || b == -1 { + return p, nil, nil + } + b += a + 1 + return p[:a], p[a+1 : b], p[b+1:] +} + +func split2(p []byte, sep byte) (p1, p2 []byte) { + i := bytes.IndexByte(p, sep) + if i == -1 { + return p, nil + } + return p[:i], p[i+1:] +} + +func trim(p []byte) []byte { + var i, j int + for i = 0; i < len(p) && (p[i] == ' ' || p[i] == '\t'); { + i++ + } + for j = len(p); j > i && (p[j-1] == ' ' || p[j-1] == '\t'); { + j-- + } + return p[i:j] +} diff --git a/vendor/github.com/gobwas/httphead/httphead.go b/vendor/github.com/gobwas/httphead/httphead.go new file mode 100644 index 00000000..2387e803 --- /dev/null +++ b/vendor/github.com/gobwas/httphead/httphead.go @@ -0,0 +1,331 @@ +// Package httphead contains utils for parsing HTTP and HTTP-grammar compatible +// text protocols headers. +// +// That is, this package first aim is to bring ability to easily parse +// constructions, described here https://tools.ietf.org/html/rfc2616#section-2 +package httphead + +import ( + "bytes" + "strings" +) + +// ScanTokens parses data in this form: +// +// list = 1#token +// +// It returns false if data is malformed. +func ScanTokens(data []byte, it func([]byte) bool) bool { + lexer := &Scanner{data: data} + + var ok bool + for lexer.Next() { + switch lexer.Type() { + case ItemToken: + ok = true + if !it(lexer.Bytes()) { + return true + } + case ItemSeparator: + if !isComma(lexer.Bytes()) { + return false + } + default: + return false + } + } + + return ok && !lexer.err +} + +// ParseOptions parses all header options and appends it to given slice of +// Option. It returns flag of successful (wellformed input) parsing. +// +// Note that appended options are all consist of subslices of data. That is, +// mutation of data will mutate appended options. +func ParseOptions(data []byte, options []Option) ([]Option, bool) { + var i int + index := -1 + return options, ScanOptions(data, func(idx int, name, attr, val []byte) Control { + if idx != index { + index = idx + i = len(options) + options = append(options, Option{Name: name}) + } + if attr != nil { + options[i].Parameters.Set(attr, val) + } + return ControlContinue + }) +} + +// SelectFlag encodes way of options selection. +type SelectFlag byte + +// String represetns flag as string. +func (f SelectFlag) String() string { + var flags [2]string + var n int + if f&SelectCopy != 0 { + flags[n] = "copy" + n++ + } + if f&SelectUnique != 0 { + flags[n] = "unique" + n++ + } + return "[" + strings.Join(flags[:n], "|") + "]" +} + +const ( + // SelectCopy causes selector to copy selected option before appending it + // to resulting slice. + // If SelectCopy flag is not passed to selector, then appended options will + // contain sub-slices of the initial data. + SelectCopy SelectFlag = 1 << iota + + // SelectUnique causes selector to append only not yet existing option to + // resulting slice. Unique is checked by comparing option names. + SelectUnique +) + +// OptionSelector contains configuration for selecting Options from header value. +type OptionSelector struct { + // Check is a filter function that applied to every Option that possibly + // could be selected. + // If Check is nil all options will be selected. + Check func(Option) bool + + // Flags contains flags for options selection. + Flags SelectFlag + + // Alloc used to allocate slice of bytes when selector is configured with + // SelectCopy flag. It will be called with number of bytes needed for copy + // of single Option. + // If Alloc is nil make is used. + Alloc func(n int) []byte +} + +// Select parses header data and appends it to given slice of Option. +// It also returns flag of successful (wellformed input) parsing. +func (s OptionSelector) Select(data []byte, options []Option) ([]Option, bool) { + var current Option + var has bool + index := -1 + + alloc := s.Alloc + if alloc == nil { + alloc = defaultAlloc + } + check := s.Check + if check == nil { + check = defaultCheck + } + + ok := ScanOptions(data, func(idx int, name, attr, val []byte) Control { + if idx != index { + if has && check(current) { + if s.Flags&SelectCopy != 0 { + current = current.Copy(alloc(current.Size())) + } + options = append(options, current) + has = false + } + if s.Flags&SelectUnique != 0 { + for i := len(options) - 1; i >= 0; i-- { + if bytes.Equal(options[i].Name, name) { + return ControlSkip + } + } + } + index = idx + current = Option{Name: name} + has = true + } + if attr != nil { + current.Parameters.Set(attr, val) + } + + return ControlContinue + }) + if has && check(current) { + if s.Flags&SelectCopy != 0 { + current = current.Copy(alloc(current.Size())) + } + options = append(options, current) + } + + return options, ok +} + +func defaultAlloc(n int) []byte { return make([]byte, n) } +func defaultCheck(Option) bool { return true } + +// Control represents operation that scanner should perform. +type Control byte + +const ( + // ControlContinue causes scanner to continue scan tokens. + ControlContinue Control = iota + // ControlBreak causes scanner to stop scan tokens. + ControlBreak + // ControlSkip causes scanner to skip current entity. + ControlSkip +) + +// ScanOptions parses data in this form: +// +// values = 1#value +// value = token *( ";" param ) +// param = token [ "=" (token | quoted-string) ] +// +// It calls given callback with the index of the option, option itself and its +// parameter (attribute and its value, both could be nil). Index is useful when +// header contains multiple choises for the same named option. +// +// Given callback should return one of the defined Control* values. +// ControlSkip means that passed key is not in caller's interest. That is, all +// parameters of that key will be skipped. +// ControlBreak means that no more keys and parameters should be parsed. That +// is, it must break parsing immediately. +// ControlContinue means that caller want to receive next parameter and its +// value or the next key. +// +// It returns false if data is malformed. +func ScanOptions(data []byte, it func(index int, option, attribute, value []byte) Control) bool { + lexer := &Scanner{data: data} + + var ok bool + var state int + const ( + stateKey = iota + stateParamBeforeName + stateParamName + stateParamBeforeValue + stateParamValue + ) + + var ( + index int + key, param, value []byte + mustCall bool + ) + for lexer.Next() { + var ( + call bool + growIndex int + ) + + t := lexer.Type() + v := lexer.Bytes() + + switch t { + case ItemToken: + switch state { + case stateKey, stateParamBeforeName: + key = v + state = stateParamBeforeName + mustCall = true + case stateParamName: + param = v + state = stateParamBeforeValue + mustCall = true + case stateParamValue: + value = v + state = stateParamBeforeName + call = true + default: + return false + } + + case ItemString: + if state != stateParamValue { + return false + } + value = v + state = stateParamBeforeName + call = true + + case ItemSeparator: + switch { + case isComma(v) && state == stateKey: + // Nothing to do. + + case isComma(v) && state == stateParamBeforeName: + state = stateKey + // Make call only if we have not called this key yet. + call = mustCall + if !call { + // If we have already called callback with the key + // that just ended. + index++ + } else { + // Else grow the index after calling callback. + growIndex = 1 + } + + case isComma(v) && state == stateParamBeforeValue: + state = stateKey + growIndex = 1 + call = true + + case isSemicolon(v) && state == stateParamBeforeName: + state = stateParamName + + case isSemicolon(v) && state == stateParamBeforeValue: + state = stateParamName + call = true + + case isEquality(v) && state == stateParamBeforeValue: + state = stateParamValue + + default: + return false + } + + default: + return false + } + + if call { + switch it(index, key, param, value) { + case ControlBreak: + // User want to stop to parsing parameters. + return true + + case ControlSkip: + // User want to skip current param. + state = stateKey + lexer.SkipEscaped(',') + + case ControlContinue: + // User is interested in rest of parameters. + // Nothing to do. + + default: + panic("unexpected control value") + } + ok = true + param = nil + value = nil + mustCall = false + index += growIndex + } + } + if mustCall { + ok = true + it(index, key, param, value) + } + + return ok && !lexer.err +} + +func isComma(b []byte) bool { + return len(b) == 1 && b[0] == ',' +} +func isSemicolon(b []byte) bool { + return len(b) == 1 && b[0] == ';' +} +func isEquality(b []byte) bool { + return len(b) == 1 && b[0] == '=' +} diff --git a/vendor/github.com/gobwas/httphead/lexer.go b/vendor/github.com/gobwas/httphead/lexer.go new file mode 100644 index 00000000..729855ed --- /dev/null +++ b/vendor/github.com/gobwas/httphead/lexer.go @@ -0,0 +1,360 @@ +package httphead + +import ( + "bytes" +) + +// ItemType encodes type of the lexing token. +type ItemType int + +const ( + // ItemUndef reports that token is undefined. + ItemUndef ItemType = iota + // ItemToken reports that token is RFC2616 token. + ItemToken + // ItemSeparator reports that token is RFC2616 separator. + ItemSeparator + // ItemString reports that token is RFC2616 quouted string. + ItemString + // ItemComment reports that token is RFC2616 comment. + ItemComment + // ItemOctet reports that token is octet slice. + ItemOctet +) + +// Scanner represents header tokens scanner. +// See https://tools.ietf.org/html/rfc2616#section-2 +type Scanner struct { + data []byte + pos int + + itemType ItemType + itemBytes []byte + + err bool +} + +// NewScanner creates new RFC2616 data scanner. +func NewScanner(data []byte) *Scanner { + return &Scanner{data: data} +} + +// Next scans for next token. It returns true on successful scanning, and false +// on error or EOF. +func (l *Scanner) Next() bool { + c, ok := l.nextChar() + if !ok { + return false + } + switch c { + case '"': // quoted-string; + return l.fetchQuotedString() + + case '(': // comment; + return l.fetchComment() + + case '\\', ')': // unexpected chars; + l.err = true + return false + + default: + return l.fetchToken() + } +} + +// FetchUntil fetches ItemOctet from current scanner position to first +// occurence of the c or to the end of the underlying data. +func (l *Scanner) FetchUntil(c byte) bool { + l.resetItem() + if l.pos == len(l.data) { + return false + } + return l.fetchOctet(c) +} + +// Peek reads byte at current position without advancing it. On end of data it +// returns 0. +func (l *Scanner) Peek() byte { + if l.pos == len(l.data) { + return 0 + } + return l.data[l.pos] +} + +// Peek2 reads two first bytes at current position without advancing it. +// If there not enough data it returs 0. +func (l *Scanner) Peek2() (a, b byte) { + if l.pos == len(l.data) { + return 0, 0 + } + if l.pos+1 == len(l.data) { + return l.data[l.pos], 0 + } + return l.data[l.pos], l.data[l.pos+1] +} + +// Buffered reporst how many bytes there are left to scan. +func (l *Scanner) Buffered() int { + return len(l.data) - l.pos +} + +// Advance moves current position index at n bytes. It returns true on +// successful move. +func (l *Scanner) Advance(n int) bool { + l.pos += n + if l.pos > len(l.data) { + l.pos = len(l.data) + return false + } + return true +} + +// Skip skips all bytes until first occurence of c. +func (l *Scanner) Skip(c byte) { + if l.err { + return + } + // Reset scanner state. + l.resetItem() + + if i := bytes.IndexByte(l.data[l.pos:], c); i == -1 { + // Reached the end of data. + l.pos = len(l.data) + } else { + l.pos += i + 1 + } +} + +// SkipEscaped skips all bytes until first occurence of non-escaped c. +func (l *Scanner) SkipEscaped(c byte) { + if l.err { + return + } + // Reset scanner state. + l.resetItem() + + if i := ScanUntil(l.data[l.pos:], c); i == -1 { + // Reached the end of data. + l.pos = len(l.data) + } else { + l.pos += i + 1 + } +} + +// Type reports current token type. +func (l *Scanner) Type() ItemType { + return l.itemType +} + +// Bytes returns current token bytes. +func (l *Scanner) Bytes() []byte { + return l.itemBytes +} + +func (l *Scanner) nextChar() (byte, bool) { + // Reset scanner state. + l.resetItem() + + if l.err { + return 0, false + } + l.pos += SkipSpace(l.data[l.pos:]) + if l.pos == len(l.data) { + return 0, false + } + return l.data[l.pos], true +} + +func (l *Scanner) resetItem() { + l.itemType = ItemUndef + l.itemBytes = nil +} + +func (l *Scanner) fetchOctet(c byte) bool { + i := l.pos + if j := bytes.IndexByte(l.data[l.pos:], c); j == -1 { + // Reached the end of data. + l.pos = len(l.data) + } else { + l.pos += j + } + + l.itemType = ItemOctet + l.itemBytes = l.data[i:l.pos] + + return true +} + +func (l *Scanner) fetchToken() bool { + n, t := ScanToken(l.data[l.pos:]) + if n == -1 { + l.err = true + return false + } + + l.itemType = t + l.itemBytes = l.data[l.pos : l.pos+n] + l.pos += n + + return true +} + +func (l *Scanner) fetchQuotedString() (ok bool) { + l.pos++ + + n := ScanUntil(l.data[l.pos:], '"') + if n == -1 { + l.err = true + return false + } + + l.itemType = ItemString + l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\') + l.pos += n + 1 + + return true +} + +func (l *Scanner) fetchComment() (ok bool) { + l.pos++ + + n := ScanPairGreedy(l.data[l.pos:], '(', ')') + if n == -1 { + l.err = true + return false + } + + l.itemType = ItemComment + l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\') + l.pos += n + 1 + + return true +} + +// ScanUntil scans for first non-escaped character c in given data. +// It returns index of matched c and -1 if c is not found. +func ScanUntil(data []byte, c byte) (n int) { + for { + i := bytes.IndexByte(data[n:], c) + if i == -1 { + return -1 + } + n += i + if n == 0 || data[n-1] != '\\' { + break + } + n++ + } + return +} + +// ScanPairGreedy scans for complete pair of opening and closing chars in greedy manner. +// Note that first opening byte must not be present in data. +func ScanPairGreedy(data []byte, open, close byte) (n int) { + var m int + opened := 1 + for { + i := bytes.IndexByte(data[n:], close) + if i == -1 { + return -1 + } + n += i + // If found index is not escaped then it is the end. + if n == 0 || data[n-1] != '\\' { + opened-- + } + + for m < i { + j := bytes.IndexByte(data[m:i], open) + if j == -1 { + break + } + m += j + 1 + opened++ + } + + if opened == 0 { + break + } + + n++ + m = n + } + return +} + +// RemoveByte returns data without c. If c is not present in data it returns +// the same slice. If not, it copies data without c. +func RemoveByte(data []byte, c byte) []byte { + j := bytes.IndexByte(data, c) + if j == -1 { + return data + } + + n := len(data) - 1 + + // If character is present, than allocate slice with n-1 capacity. That is, + // resulting bytes could be at most n-1 length. + result := make([]byte, n) + k := copy(result, data[:j]) + + for i := j + 1; i < n; { + j = bytes.IndexByte(data[i:], c) + if j != -1 { + k += copy(result[k:], data[i:i+j]) + i = i + j + 1 + } else { + k += copy(result[k:], data[i:]) + break + } + } + + return result[:k] +} + +// SkipSpace skips spaces and lws-sequences from p. +// It returns number ob bytes skipped. +func SkipSpace(p []byte) (n int) { + for len(p) > 0 { + switch { + case len(p) >= 3 && + p[0] == '\r' && + p[1] == '\n' && + OctetTypes[p[2]].IsSpace(): + p = p[3:] + n += 3 + case OctetTypes[p[0]].IsSpace(): + p = p[1:] + n++ + default: + return + } + } + return +} + +// ScanToken scan for next token in p. It returns length of the token and its +// type. It do not trim p. +func ScanToken(p []byte) (n int, t ItemType) { + if len(p) == 0 { + return 0, ItemUndef + } + + c := p[0] + switch { + case OctetTypes[c].IsSeparator(): + return 1, ItemSeparator + + case OctetTypes[c].IsToken(): + for n = 1; n < len(p); n++ { + c := p[n] + if !OctetTypes[c].IsToken() { + break + } + } + return n, ItemToken + + default: + return -1, ItemUndef + } +} diff --git a/vendor/github.com/gobwas/httphead/octet.go b/vendor/github.com/gobwas/httphead/octet.go new file mode 100644 index 00000000..2a04cdd0 --- /dev/null +++ b/vendor/github.com/gobwas/httphead/octet.go @@ -0,0 +1,83 @@ +package httphead + +// OctetType desribes character type. +// +// From the "Basic Rules" chapter of RFC2616 +// See https://tools.ietf.org/html/rfc2616#section-2.2 +// +// OCTET = +// CHAR = +// UPALPHA = +// LOALPHA = +// ALPHA = UPALPHA | LOALPHA +// DIGIT = +// CTL = +// CR = +// LF = +// SP = +// HT = +// <"> = +// CRLF = CR LF +// LWS = [CRLF] 1*( SP | HT ) +// +// Many HTTP/1.1 header field values consist of words separated by LWS +// or special characters. These special characters MUST be in a quoted +// string to be used within a parameter value (as defined in section +// 3.6). +// +// token = 1* +// separators = "(" | ")" | "<" | ">" | "@" +// | "," | ";" | ":" | "\" | <"> +// | "/" | "[" | "]" | "?" | "=" +// | "{" | "}" | SP | HT +type OctetType byte + +// IsChar reports whether octet is CHAR. +func (t OctetType) IsChar() bool { return t&octetChar != 0 } + +// IsControl reports whether octet is CTL. +func (t OctetType) IsControl() bool { return t&octetControl != 0 } + +// IsSeparator reports whether octet is separator. +func (t OctetType) IsSeparator() bool { return t&octetSeparator != 0 } + +// IsSpace reports whether octet is space (SP or HT). +func (t OctetType) IsSpace() bool { return t&octetSpace != 0 } + +// IsToken reports whether octet is token. +func (t OctetType) IsToken() bool { return t&octetToken != 0 } + +const ( + octetChar OctetType = 1 << iota + octetControl + octetSpace + octetSeparator + octetToken +) + +// OctetTypes is a table of octets. +var OctetTypes [256]OctetType + +func init() { + for c := 32; c < 256; c++ { + var t OctetType + if c <= 127 { + t |= octetChar + } + if 0 <= c && c <= 31 || c == 127 { + t |= octetControl + } + switch c { + case '(', ')', '<', '>', '@', ',', ';', ':', '"', '/', '[', ']', '?', '=', '{', '}', '\\': + t |= octetSeparator + case ' ', '\t': + t |= octetSpace | octetSeparator + } + + if t.IsChar() && !t.IsControl() && !t.IsSeparator() && !t.IsSpace() { + t |= octetToken + } + + OctetTypes[c] = t + } +} diff --git a/vendor/github.com/gobwas/httphead/option.go b/vendor/github.com/gobwas/httphead/option.go new file mode 100644 index 00000000..0a18c7c7 --- /dev/null +++ b/vendor/github.com/gobwas/httphead/option.go @@ -0,0 +1,193 @@ +package httphead + +import ( + "bytes" + "sort" +) + +// Option represents a header option. +type Option struct { + Name []byte + Parameters Parameters +} + +// Size returns number of bytes need to be allocated for use in opt.Copy. +func (opt Option) Size() int { + return len(opt.Name) + opt.Parameters.bytes +} + +// Copy copies all underlying []byte slices into p and returns new Option. +// Note that p must be at least of opt.Size() length. +func (opt Option) Copy(p []byte) Option { + n := copy(p, opt.Name) + opt.Name = p[:n] + opt.Parameters, p = opt.Parameters.Copy(p[n:]) + return opt +} + +// Clone is a shorthand for making slice of opt.Size() sequenced with Copy() +// call. +func (opt Option) Clone() Option { + return opt.Copy(make([]byte, opt.Size())) +} + +// String represents option as a string. +func (opt Option) String() string { + return "{" + string(opt.Name) + " " + opt.Parameters.String() + "}" +} + +// NewOption creates named option with given parameters. +func NewOption(name string, params map[string]string) Option { + p := Parameters{} + for k, v := range params { + p.Set([]byte(k), []byte(v)) + } + return Option{ + Name: []byte(name), + Parameters: p, + } +} + +// Equal reports whether option is equal to b. +func (opt Option) Equal(b Option) bool { + if bytes.Equal(opt.Name, b.Name) { + return opt.Parameters.Equal(b.Parameters) + } + return false +} + +// Parameters represents option's parameters. +type Parameters struct { + pos int + bytes int + arr [8]pair + dyn []pair +} + +// Equal reports whether a equal to b. +func (p Parameters) Equal(b Parameters) bool { + switch { + case p.dyn == nil && b.dyn == nil: + case p.dyn != nil && b.dyn != nil: + default: + return false + } + + ad, bd := p.data(), b.data() + if len(ad) != len(bd) { + return false + } + + sort.Sort(pairs(ad)) + sort.Sort(pairs(bd)) + + for i := 0; i < len(ad); i++ { + av, bv := ad[i], bd[i] + if !bytes.Equal(av.key, bv.key) || !bytes.Equal(av.value, bv.value) { + return false + } + } + return true +} + +// Size returns number of bytes that needed to copy p. +func (p *Parameters) Size() int { + return p.bytes +} + +// Copy copies all underlying []byte slices into dst and returns new +// Parameters. +// Note that dst must be at least of p.Size() length. +func (p *Parameters) Copy(dst []byte) (Parameters, []byte) { + ret := Parameters{ + pos: p.pos, + bytes: p.bytes, + } + if p.dyn != nil { + ret.dyn = make([]pair, len(p.dyn)) + for i, v := range p.dyn { + ret.dyn[i], dst = v.copy(dst) + } + } else { + for i, p := range p.arr { + ret.arr[i], dst = p.copy(dst) + } + } + return ret, dst +} + +// Get returns value by key and flag about existence such value. +func (p *Parameters) Get(key string) (value []byte, ok bool) { + for _, v := range p.data() { + if string(v.key) == key { + return v.value, true + } + } + return nil, false +} + +// Set sets value by key. +func (p *Parameters) Set(key, value []byte) { + p.bytes += len(key) + len(value) + + if p.pos < len(p.arr) { + p.arr[p.pos] = pair{key, value} + p.pos++ + return + } + + if p.dyn == nil { + p.dyn = make([]pair, len(p.arr), len(p.arr)+1) + copy(p.dyn, p.arr[:]) + } + p.dyn = append(p.dyn, pair{key, value}) +} + +// ForEach iterates over parameters key-value pairs and calls cb for each one. +func (p *Parameters) ForEach(cb func(k, v []byte) bool) { + for _, v := range p.data() { + if !cb(v.key, v.value) { + break + } + } +} + +// String represents parameters as a string. +func (p *Parameters) String() (ret string) { + ret = "[" + for i, v := range p.data() { + if i > 0 { + ret += " " + } + ret += string(v.key) + ":" + string(v.value) + } + return ret + "]" +} + +func (p *Parameters) data() []pair { + if p.dyn != nil { + return p.dyn + } + return p.arr[:p.pos] +} + +type pair struct { + key, value []byte +} + +func (p pair) copy(dst []byte) (pair, []byte) { + n := copy(dst, p.key) + p.key = dst[:n] + m := n + copy(dst[n:], p.value) + p.value = dst[n:m] + + dst = dst[m:] + + return p, dst +} + +type pairs []pair + +func (p pairs) Len() int { return len(p) } +func (p pairs) Less(a, b int) bool { return bytes.Compare(p[a].key, p[b].key) == -1 } +func (p pairs) Swap(a, b int) { p[a], p[b] = p[b], p[a] } diff --git a/vendor/github.com/gobwas/httphead/writer.go b/vendor/github.com/gobwas/httphead/writer.go new file mode 100644 index 00000000..e5df3ddf --- /dev/null +++ b/vendor/github.com/gobwas/httphead/writer.go @@ -0,0 +1,101 @@ +package httphead + +import "io" + +var ( + comma = []byte{','} + equality = []byte{'='} + semicolon = []byte{';'} + quote = []byte{'"'} + escape = []byte{'\\'} +) + +// WriteOptions write options list to the dest. +// It uses the same form as {Scan,Parse}Options functions: +// values = 1#value +// value = token *( ";" param ) +// param = token [ "=" (token | quoted-string) ] +// +// It wraps valuse into the quoted-string sequence if it contains any +// non-token characters. +func WriteOptions(dest io.Writer, options []Option) (n int, err error) { + w := writer{w: dest} + for i, opt := range options { + if i > 0 { + w.write(comma) + } + + writeTokenSanitized(&w, opt.Name) + + for _, p := range opt.Parameters.data() { + w.write(semicolon) + writeTokenSanitized(&w, p.key) + if len(p.value) != 0 { + w.write(equality) + writeTokenSanitized(&w, p.value) + } + } + } + return w.result() +} + +// writeTokenSanitized writes token as is or as quouted string if it contains +// non-token characters. +// +// Note that is is not expects LWS sequnces be in s, cause LWS is used only as +// header field continuation: +// "A CRLF is allowed in the definition of TEXT only as part of a header field +// continuation. It is expected that the folding LWS will be replaced with a +// single SP before interpretation of the TEXT value." +// See https://tools.ietf.org/html/rfc2616#section-2 +// +// That is we sanitizing s for writing, so there could not be any header field +// continuation. +// That is any CRLF will be escaped as any other control characters not allowd in TEXT. +func writeTokenSanitized(bw *writer, bts []byte) { + var qt bool + var pos int + for i := 0; i < len(bts); i++ { + c := bts[i] + if !OctetTypes[c].IsToken() && !qt { + qt = true + bw.write(quote) + } + if OctetTypes[c].IsControl() || c == '"' { + if !qt { + qt = true + bw.write(quote) + } + bw.write(bts[pos:i]) + bw.write(escape) + bw.write(bts[i : i+1]) + pos = i + 1 + } + } + if !qt { + bw.write(bts) + } else { + bw.write(bts[pos:]) + bw.write(quote) + } +} + +type writer struct { + w io.Writer + n int + err error +} + +func (w *writer) write(p []byte) { + if w.err != nil { + return + } + var n int + n, w.err = w.w.Write(p) + w.n += n + return +} + +func (w *writer) result() (int, error) { + return w.n, w.err +} diff --git a/vendor/github.com/gobwas/pool/LICENSE b/vendor/github.com/gobwas/pool/LICENSE new file mode 100644 index 00000000..c41ffde6 --- /dev/null +++ b/vendor/github.com/gobwas/pool/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017-2019 Sergey Kamardin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/gobwas/pool/README.md b/vendor/github.com/gobwas/pool/README.md new file mode 100644 index 00000000..45685581 --- /dev/null +++ b/vendor/github.com/gobwas/pool/README.md @@ -0,0 +1,107 @@ +# pool + +[![GoDoc][godoc-image]][godoc-url] + +> Tiny memory reuse helpers for Go. + +## generic + +Without use of subpackages, `pool` allows to reuse any struct distinguishable +by size in generic way: + +```go +package main + +import "github.com/gobwas/pool" + +func main() { + x, n := pool.Get(100) // Returns object with size 128 or nil. + if x == nil { + // Create x somehow with knowledge that n is 128. + } + defer pool.Put(x, n) + + // Work with x. +} +``` + +Pool allows you to pass specific options for constructing custom pool: + +```go +package main + +import "github.com/gobwas/pool" + +func main() { + p := pool.Custom( + pool.WithLogSizeMapping(), // Will ceil size n passed to Get(n) to nearest power of two. + pool.WithLogSizeRange(64, 512), // Will reuse objects in logarithmic range [64, 512]. + pool.WithSize(65536), // Will reuse object with size 65536. + ) + x, n := p.Get(1000) // Returns nil and 1000 because mapped size 1000 => 1024 is not reusing by the pool. + defer pool.Put(x, n) // Will not reuse x. + + // Work with x. +} +``` + +Note that there are few non-generic pooling implementations inside subpackages. + +## pbytes + +Subpackage `pbytes` is intended for `[]byte` reuse. + +```go +package main + +import "github.com/gobwas/pool/pbytes" + +func main() { + bts := pbytes.GetCap(100) // Returns make([]byte, 0, 128). + defer pbytes.Put(bts) + + // Work with bts. +} +``` + +You can also create your own range for pooling: + +```go +package main + +import "github.com/gobwas/pool/pbytes" + +func main() { + // Reuse only slices whose capacity is 128, 256, 512 or 1024. + pool := pbytes.New(128, 1024) + + bts := pool.GetCap(100) // Returns make([]byte, 0, 128). + defer pool.Put(bts) + + // Work with bts. +} +``` + +## pbufio + +Subpackage `pbufio` is intended for `*bufio.{Reader, Writer}` reuse. + +```go +package main + +import "github.com/gobwas/pool/pbufio" + +func main() { + bw := pbufio.GetWriter(os.Stdout, 100) // Returns bufio.NewWriterSize(128). + defer pbufio.PutWriter(bw) + + // Work with bw. +} +``` + +Like with `pbytes`, you can also create pool with custom reuse bounds. + + + +[godoc-image]: https://godoc.org/github.com/gobwas/pool?status.svg +[godoc-url]: https://godoc.org/github.com/gobwas/pool diff --git a/vendor/github.com/gobwas/pool/generic.go b/vendor/github.com/gobwas/pool/generic.go new file mode 100644 index 00000000..d40b3624 --- /dev/null +++ b/vendor/github.com/gobwas/pool/generic.go @@ -0,0 +1,87 @@ +package pool + +import ( + "sync" + + "github.com/gobwas/pool/internal/pmath" +) + +var DefaultPool = New(128, 65536) + +// Get pulls object whose generic size is at least of given size. It also +// returns a real size of x for further pass to Put(). It returns -1 as real +// size for nil x. Size >-1 does not mean that x is non-nil, so checks must be +// done. +// +// Note that size could be ceiled to the next power of two. +// +// Get is a wrapper around DefaultPool.Get(). +func Get(size int) (interface{}, int) { return DefaultPool.Get(size) } + +// Put takes x and its size for future reuse. +// Put is a wrapper around DefaultPool.Put(). +func Put(x interface{}, size int) { DefaultPool.Put(x, size) } + +// Pool contains logic of reusing objects distinguishable by size in generic +// way. +type Pool struct { + pool map[int]*sync.Pool + size func(int) int +} + +// New creates new Pool that reuses objects which size is in logarithmic range +// [min, max]. +// +// Note that it is a shortcut for Custom() constructor with Options provided by +// WithLogSizeMapping() and WithLogSizeRange(min, max) calls. +func New(min, max int) *Pool { + return Custom( + WithLogSizeMapping(), + WithLogSizeRange(min, max), + ) +} + +// Custom creates new Pool with given options. +func Custom(opts ...Option) *Pool { + p := &Pool{ + pool: make(map[int]*sync.Pool), + size: pmath.Identity, + } + + c := (*poolConfig)(p) + for _, opt := range opts { + opt(c) + } + + return p +} + +// Get pulls object whose generic size is at least of given size. +// It also returns a real size of x for further pass to Put() even if x is nil. +// Note that size could be ceiled to the next power of two. +func (p *Pool) Get(size int) (interface{}, int) { + n := p.size(size) + if pool := p.pool[n]; pool != nil { + return pool.Get(), n + } + return nil, size +} + +// Put takes x and its size for future reuse. +func (p *Pool) Put(x interface{}, size int) { + if pool := p.pool[size]; pool != nil { + pool.Put(x) + } +} + +type poolConfig Pool + +// AddSize adds size n to the map. +func (p *poolConfig) AddSize(n int) { + p.pool[n] = new(sync.Pool) +} + +// SetSizeMapping sets up incoming size mapping function. +func (p *poolConfig) SetSizeMapping(size func(int) int) { + p.size = size +} diff --git a/vendor/github.com/gobwas/pool/internal/pmath/pmath.go b/vendor/github.com/gobwas/pool/internal/pmath/pmath.go new file mode 100644 index 00000000..df152ed1 --- /dev/null +++ b/vendor/github.com/gobwas/pool/internal/pmath/pmath.go @@ -0,0 +1,65 @@ +package pmath + +const ( + bitsize = 32 << (^uint(0) >> 63) + maxint = int(1<<(bitsize-1) - 1) + maxintHeadBit = 1 << (bitsize - 2) +) + +// LogarithmicRange iterates from ceiled to power of two min to max, +// calling cb on each iteration. +func LogarithmicRange(min, max int, cb func(int)) { + if min == 0 { + min = 1 + } + for n := CeilToPowerOfTwo(min); n <= max; n <<= 1 { + cb(n) + } +} + +// IsPowerOfTwo reports whether given integer is a power of two. +func IsPowerOfTwo(n int) bool { + return n&(n-1) == 0 +} + +// Identity is identity. +func Identity(n int) int { + return n +} + +// CeilToPowerOfTwo returns the least power of two integer value greater than +// or equal to n. +func CeilToPowerOfTwo(n int) int { + if n&maxintHeadBit != 0 && n > maxintHeadBit { + panic("argument is too large") + } + if n <= 2 { + return n + } + n-- + n = fillBits(n) + n++ + return n +} + +// FloorToPowerOfTwo returns the greatest power of two integer value less than +// or equal to n. +func FloorToPowerOfTwo(n int) int { + if n <= 2 { + return n + } + n = fillBits(n) + n >>= 1 + n++ + return n +} + +func fillBits(n int) int { + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + return n +} diff --git a/vendor/github.com/gobwas/pool/option.go b/vendor/github.com/gobwas/pool/option.go new file mode 100644 index 00000000..d6e42b70 --- /dev/null +++ b/vendor/github.com/gobwas/pool/option.go @@ -0,0 +1,43 @@ +package pool + +import "github.com/gobwas/pool/internal/pmath" + +// Option configures pool. +type Option func(Config) + +// Config describes generic pool configuration. +type Config interface { + AddSize(n int) + SetSizeMapping(func(int) int) +} + +// WithSizeLogRange returns an Option that will add logarithmic range of +// pooling sizes containing [min, max] values. +func WithLogSizeRange(min, max int) Option { + return func(c Config) { + pmath.LogarithmicRange(min, max, func(n int) { + c.AddSize(n) + }) + } +} + +// WithSize returns an Option that will add given pooling size to the pool. +func WithSize(n int) Option { + return func(c Config) { + c.AddSize(n) + } +} + +func WithSizeMapping(sz func(int) int) Option { + return func(c Config) { + c.SetSizeMapping(sz) + } +} + +func WithLogSizeMapping() Option { + return WithSizeMapping(pmath.CeilToPowerOfTwo) +} + +func WithIdentitySizeMapping() Option { + return WithSizeMapping(pmath.Identity) +} diff --git a/vendor/github.com/gobwas/pool/pbufio/pbufio.go b/vendor/github.com/gobwas/pool/pbufio/pbufio.go new file mode 100644 index 00000000..d526bd80 --- /dev/null +++ b/vendor/github.com/gobwas/pool/pbufio/pbufio.go @@ -0,0 +1,106 @@ +// Package pbufio contains tools for pooling bufio.Reader and bufio.Writers. +package pbufio + +import ( + "bufio" + "io" + + "github.com/gobwas/pool" +) + +var ( + DefaultWriterPool = NewWriterPool(256, 65536) + DefaultReaderPool = NewReaderPool(256, 65536) +) + +// GetWriter returns bufio.Writer whose buffer has at least size bytes. +// Note that size could be ceiled to the next power of two. +// GetWriter is a wrapper around DefaultWriterPool.Get(). +func GetWriter(w io.Writer, size int) *bufio.Writer { return DefaultWriterPool.Get(w, size) } + +// PutWriter takes bufio.Writer for future reuse. +// It does not reuse bufio.Writer which underlying buffer size is not power of +// PutWriter is a wrapper around DefaultWriterPool.Put(). +func PutWriter(bw *bufio.Writer) { DefaultWriterPool.Put(bw) } + +// GetReader returns bufio.Reader whose buffer has at least size bytes. It returns +// its capacity for further pass to Put(). +// Note that size could be ceiled to the next power of two. +// GetReader is a wrapper around DefaultReaderPool.Get(). +func GetReader(w io.Reader, size int) *bufio.Reader { return DefaultReaderPool.Get(w, size) } + +// PutReader takes bufio.Reader and its size for future reuse. +// It does not reuse bufio.Reader if size is not power of two or is out of pool +// min/max range. +// PutReader is a wrapper around DefaultReaderPool.Put(). +func PutReader(bw *bufio.Reader) { DefaultReaderPool.Put(bw) } + +// WriterPool contains logic of *bufio.Writer reuse with various size. +type WriterPool struct { + pool *pool.Pool +} + +// NewWriterPool creates new WriterPool that reuses writers which size is in +// logarithmic range [min, max]. +func NewWriterPool(min, max int) *WriterPool { + return &WriterPool{pool.New(min, max)} +} + +// CustomWriterPool creates new WriterPool with given options. +func CustomWriterPool(opts ...pool.Option) *WriterPool { + return &WriterPool{pool.Custom(opts...)} +} + +// Get returns bufio.Writer whose buffer has at least size bytes. +func (wp *WriterPool) Get(w io.Writer, size int) *bufio.Writer { + v, n := wp.pool.Get(size) + if v != nil { + bw := v.(*bufio.Writer) + bw.Reset(w) + return bw + } + return bufio.NewWriterSize(w, n) +} + +// Put takes ownership of bufio.Writer for further reuse. +func (wp *WriterPool) Put(bw *bufio.Writer) { + // Should reset even if we do Reset() inside Get(). + // This is done to prevent locking underlying io.Writer from GC. + bw.Reset(nil) + wp.pool.Put(bw, writerSize(bw)) +} + +// ReaderPool contains logic of *bufio.Reader reuse with various size. +type ReaderPool struct { + pool *pool.Pool +} + +// NewReaderPool creates new ReaderPool that reuses writers which size is in +// logarithmic range [min, max]. +func NewReaderPool(min, max int) *ReaderPool { + return &ReaderPool{pool.New(min, max)} +} + +// CustomReaderPool creates new ReaderPool with given options. +func CustomReaderPool(opts ...pool.Option) *ReaderPool { + return &ReaderPool{pool.Custom(opts...)} +} + +// Get returns bufio.Reader whose buffer has at least size bytes. +func (rp *ReaderPool) Get(r io.Reader, size int) *bufio.Reader { + v, n := rp.pool.Get(size) + if v != nil { + br := v.(*bufio.Reader) + br.Reset(r) + return br + } + return bufio.NewReaderSize(r, n) +} + +// Put takes ownership of bufio.Reader for further reuse. +func (rp *ReaderPool) Put(br *bufio.Reader) { + // Should reset even if we do Reset() inside Get(). + // This is done to prevent locking underlying io.Reader from GC. + br.Reset(nil) + rp.pool.Put(br, readerSize(br)) +} diff --git a/vendor/github.com/gobwas/pool/pbufio/pbufio_go110.go b/vendor/github.com/gobwas/pool/pbufio/pbufio_go110.go new file mode 100644 index 00000000..c736ae56 --- /dev/null +++ b/vendor/github.com/gobwas/pool/pbufio/pbufio_go110.go @@ -0,0 +1,13 @@ +// +build go1.10 + +package pbufio + +import "bufio" + +func writerSize(bw *bufio.Writer) int { + return bw.Size() +} + +func readerSize(br *bufio.Reader) int { + return br.Size() +} diff --git a/vendor/github.com/gobwas/pool/pbufio/pbufio_go19.go b/vendor/github.com/gobwas/pool/pbufio/pbufio_go19.go new file mode 100644 index 00000000..e71dd447 --- /dev/null +++ b/vendor/github.com/gobwas/pool/pbufio/pbufio_go19.go @@ -0,0 +1,27 @@ +// +build !go1.10 + +package pbufio + +import "bufio" + +func writerSize(bw *bufio.Writer) int { + return bw.Available() + bw.Buffered() +} + +// readerSize returns buffer size of the given buffered reader. +// NOTE: current workaround implementation resets underlying io.Reader. +func readerSize(br *bufio.Reader) int { + br.Reset(sizeReader) + br.ReadByte() + n := br.Buffered() + 1 + br.Reset(nil) + return n +} + +var sizeReader optimisticReader + +type optimisticReader struct{} + +func (optimisticReader) Read(p []byte) (int, error) { + return len(p), nil +} diff --git a/vendor/github.com/gobwas/pool/pbytes/pbytes.go b/vendor/github.com/gobwas/pool/pbytes/pbytes.go new file mode 100644 index 00000000..919705b1 --- /dev/null +++ b/vendor/github.com/gobwas/pool/pbytes/pbytes.go @@ -0,0 +1,24 @@ +// Package pbytes contains tools for pooling byte pool. +// Note that by default it reuse slices with capacity from 128 to 65536 bytes. +package pbytes + +// DefaultPool is used by pacakge level functions. +var DefaultPool = New(128, 65536) + +// Get returns probably reused slice of bytes with at least capacity of c and +// exactly len of n. +// Get is a wrapper around DefaultPool.Get(). +func Get(n, c int) []byte { return DefaultPool.Get(n, c) } + +// GetCap returns probably reused slice of bytes with at least capacity of n. +// GetCap is a wrapper around DefaultPool.GetCap(). +func GetCap(c int) []byte { return DefaultPool.GetCap(c) } + +// GetLen returns probably reused slice of bytes with at least capacity of n +// and exactly len of n. +// GetLen is a wrapper around DefaultPool.GetLen(). +func GetLen(n int) []byte { return DefaultPool.GetLen(n) } + +// Put returns given slice to reuse pool. +// Put is a wrapper around DefaultPool.Put(). +func Put(p []byte) { DefaultPool.Put(p) } diff --git a/vendor/github.com/gobwas/pool/pbytes/pool.go b/vendor/github.com/gobwas/pool/pbytes/pool.go new file mode 100644 index 00000000..1dde225f --- /dev/null +++ b/vendor/github.com/gobwas/pool/pbytes/pool.go @@ -0,0 +1,59 @@ +// +build !pool_sanitize + +package pbytes + +import "github.com/gobwas/pool" + +// Pool contains logic of reusing byte slices of various size. +type Pool struct { + pool *pool.Pool +} + +// New creates new Pool that reuses slices which size is in logarithmic range +// [min, max]. +// +// Note that it is a shortcut for Custom() constructor with Options provided by +// pool.WithLogSizeMapping() and pool.WithLogSizeRange(min, max) calls. +func New(min, max int) *Pool { + return &Pool{pool.New(min, max)} +} + +// New creates new Pool with given options. +func Custom(opts ...pool.Option) *Pool { + return &Pool{pool.Custom(opts...)} +} + +// Get returns probably reused slice of bytes with at least capacity of c and +// exactly len of n. +func (p *Pool) Get(n, c int) []byte { + if n > c { + panic("requested length is greater than capacity") + } + + v, x := p.pool.Get(c) + if v != nil { + bts := v.([]byte) + bts = bts[:n] + return bts + } + + return make([]byte, n, x) +} + +// Put returns given slice to reuse pool. +// It does not reuse bytes whose size is not power of two or is out of pool +// min/max range. +func (p *Pool) Put(bts []byte) { + p.pool.Put(bts, cap(bts)) +} + +// GetCap returns probably reused slice of bytes with at least capacity of n. +func (p *Pool) GetCap(c int) []byte { + return p.Get(0, c) +} + +// GetLen returns probably reused slice of bytes with at least capacity of n +// and exactly len of n. +func (p *Pool) GetLen(n int) []byte { + return p.Get(n, n) +} diff --git a/vendor/github.com/gobwas/pool/pbytes/pool_sanitize.go b/vendor/github.com/gobwas/pool/pbytes/pool_sanitize.go new file mode 100644 index 00000000..fae9af49 --- /dev/null +++ b/vendor/github.com/gobwas/pool/pbytes/pool_sanitize.go @@ -0,0 +1,121 @@ +// +build pool_sanitize + +package pbytes + +import ( + "reflect" + "runtime" + "sync/atomic" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +const magic = uint64(0x777742) + +type guard struct { + magic uint64 + size int + owners int32 +} + +const guardSize = int(unsafe.Sizeof(guard{})) + +type Pool struct { + min, max int +} + +func New(min, max int) *Pool { + return &Pool{min, max} +} + +// Get returns probably reused slice of bytes with at least capacity of c and +// exactly len of n. +func (p *Pool) Get(n, c int) []byte { + if n > c { + panic("requested length is greater than capacity") + } + + pageSize := syscall.Getpagesize() + pages := (c+guardSize)/pageSize + 1 + size := pages * pageSize + + bts := alloc(size) + + g := (*guard)(unsafe.Pointer(&bts[0])) + *g = guard{ + magic: magic, + size: size, + owners: 1, + } + + return bts[guardSize : guardSize+n] +} + +func (p *Pool) GetCap(c int) []byte { return p.Get(0, c) } +func (p *Pool) GetLen(n int) []byte { return Get(n, n) } + +// Put returns given slice to reuse pool. +func (p *Pool) Put(bts []byte) { + hdr := *(*reflect.SliceHeader)(unsafe.Pointer(&bts)) + ptr := hdr.Data - uintptr(guardSize) + + g := (*guard)(unsafe.Pointer(ptr)) + if g.magic != magic { + panic("unknown slice returned to the pool") + } + if n := atomic.AddInt32(&g.owners, -1); n < 0 { + panic("multiple Put() detected") + } + + // Disable read and write on bytes memory pages. This will cause panic on + // incorrect access to returned slice. + mprotect(ptr, false, false, g.size) + + runtime.SetFinalizer(&bts, func(b *[]byte) { + mprotect(ptr, true, true, g.size) + free(*(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{ + Data: ptr, + Len: g.size, + Cap: g.size, + }))) + }) +} + +func alloc(n int) []byte { + b, err := unix.Mmap(-1, 0, n, unix.PROT_READ|unix.PROT_WRITE|unix.PROT_EXEC, unix.MAP_SHARED|unix.MAP_ANONYMOUS) + if err != nil { + panic(err.Error()) + } + return b +} + +func free(b []byte) { + if err := unix.Munmap(b); err != nil { + panic(err.Error()) + } +} + +func mprotect(ptr uintptr, r, w bool, size int) { + // Need to avoid "EINVAL addr is not a valid pointer, + // or not a multiple of PAGESIZE." + start := ptr & ^(uintptr(syscall.Getpagesize() - 1)) + + prot := uintptr(syscall.PROT_EXEC) + switch { + case r && w: + prot |= syscall.PROT_READ | syscall.PROT_WRITE + case r: + prot |= syscall.PROT_READ + case w: + prot |= syscall.PROT_WRITE + } + + _, _, err := syscall.Syscall(syscall.SYS_MPROTECT, + start, uintptr(size), prot, + ) + if err != 0 { + panic(err.Error()) + } +} diff --git a/vendor/github.com/gobwas/pool/pool.go b/vendor/github.com/gobwas/pool/pool.go new file mode 100644 index 00000000..1fe9e602 --- /dev/null +++ b/vendor/github.com/gobwas/pool/pool.go @@ -0,0 +1,25 @@ +// Package pool contains helpers for pooling structures distinguishable by +// size. +// +// Quick example: +// +// import "github.com/gobwas/pool" +// +// func main() { +// // Reuse objects in logarithmic range from 0 to 64 (0,1,2,4,6,8,16,32,64). +// p := pool.New(0, 64) +// +// buf, n := p.Get(10) // Returns buffer with 16 capacity. +// if buf == nil { +// buf = bytes.NewBuffer(make([]byte, n)) +// } +// defer p.Put(buf, n) +// +// // Work with buf. +// } +// +// There are non-generic implementations for pooling: +// - pool/pbytes for []byte reuse; +// - pool/pbufio for *bufio.Reader and *bufio.Writer reuse; +// +package pool diff --git a/vendor/github.com/gobwas/ws/.gitignore b/vendor/github.com/gobwas/ws/.gitignore new file mode 100644 index 00000000..e3e2b108 --- /dev/null +++ b/vendor/github.com/gobwas/ws/.gitignore @@ -0,0 +1,5 @@ +bin/ +reports/ +cpu.out +mem.out +ws.test diff --git a/vendor/github.com/gobwas/ws/.travis.yml b/vendor/github.com/gobwas/ws/.travis.yml new file mode 100644 index 00000000..cf74f1be --- /dev/null +++ b/vendor/github.com/gobwas/ws/.travis.yml @@ -0,0 +1,25 @@ +sudo: required + +language: go + +services: + - docker + +os: + - linux + - windows + +go: + - 1.8.x + - 1.9.x + - 1.10.x + - 1.11.x + - 1.x + +install: + - go get github.com/gobwas/pool + - go get github.com/gobwas/httphead + +script: + - if [ "$TRAVIS_OS_NAME" = "windows" ]; then go test ./...; fi + - if [ "$TRAVIS_OS_NAME" = "linux" ]; then make test autobahn; fi diff --git a/vendor/github.com/gobwas/ws/LICENSE b/vendor/github.com/gobwas/ws/LICENSE new file mode 100644 index 00000000..d2611fdd --- /dev/null +++ b/vendor/github.com/gobwas/ws/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2017-2018 Sergey Kamardin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/gobwas/ws/Makefile b/vendor/github.com/gobwas/ws/Makefile new file mode 100644 index 00000000..075e83c7 --- /dev/null +++ b/vendor/github.com/gobwas/ws/Makefile @@ -0,0 +1,47 @@ +BENCH ?=. +BENCH_BASE?=master + +clean: + rm -f bin/reporter + rm -fr autobahn/report/* + +bin/reporter: + go build -o bin/reporter ./autobahn + +bin/gocovmerge: + go build -o bin/gocovmerge github.com/wadey/gocovmerge + +.PHONY: autobahn +autobahn: clean bin/reporter + ./autobahn/script/test.sh --build + bin/reporter $(PWD)/autobahn/report/index.json + +test: + go test -coverprofile=ws.coverage . + go test -coverprofile=wsutil.coverage ./wsutil + +cover: bin/gocovmerge test autobahn + bin/gocovmerge ws.coverage wsutil.coverage autobahn/report/server.coverage > total.coverage + +benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD) +benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX) +benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX) +benchcmp: + if [ ! -z "$(shell git status -s)" ]; then\ + echo "could not compare with $(BENCH_BASE) – found unstaged changes";\ + exit 1;\ + fi;\ + if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\ + echo "comparing the same branches";\ + exit 1;\ + fi;\ + echo "benchmarking $(BENCH_BRANCH)...";\ + go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\ + echo "benchmarking $(BENCH_BASE)...";\ + git checkout -q $(BENCH_BASE);\ + go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\ + git checkout -q $(BENCH_BRANCH);\ + echo "\nresults:";\ + echo "========\n";\ + benchcmp $(BENCH_OLD) $(BENCH_NEW);\ + diff --git a/vendor/github.com/gobwas/ws/README.md b/vendor/github.com/gobwas/ws/README.md new file mode 100644 index 00000000..74acd78b --- /dev/null +++ b/vendor/github.com/gobwas/ws/README.md @@ -0,0 +1,360 @@ +# ws + +[![GoDoc][godoc-image]][godoc-url] +[![Travis][travis-image]][travis-url] + +> [RFC6455][rfc-url] WebSocket implementation in Go. + +# Features + +- Zero-copy upgrade +- No intermediate allocations during I/O +- Low-level API which allows to build your own logic of packet handling and + buffers reuse +- High-level wrappers and helpers around API in `wsutil` package, which allow + to start fast without digging the protocol internals + +# Documentation + +[GoDoc][godoc-url]. + +# Why + +Existing WebSocket implementations do not allow users to reuse I/O buffers +between connections in clear way. This library aims to export efficient +low-level interface for working with the protocol without forcing only one way +it could be used. + +By the way, if you want get the higher-level tools, you can use `wsutil` +package. + +# Status + +Library is tagged as `v1*` so its API must not be broken during some +improvements or refactoring. + +This implementation of RFC6455 passes [Autobahn Test +Suite](https://github.com/crossbario/autobahn-testsuite) and currently has +about 78% coverage. + +# Examples + +Example applications using `ws` are developed in separate repository +[ws-examples](https://github.com/gobwas/ws-examples). + +# Usage + +The higher-level example of WebSocket echo server: + +```go +package main + +import ( + "net/http" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" +) + +func main() { + http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, _, _, err := ws.UpgradeHTTP(r, w) + if err != nil { + // handle error + } + go func() { + defer conn.Close() + + for { + msg, op, err := wsutil.ReadClientData(conn) + if err != nil { + // handle error + } + err = wsutil.WriteServerMessage(conn, op, msg) + if err != nil { + // handle error + } + } + }() + })) +} +``` + +Lower-level, but still high-level example: + + +```go +import ( + "net/http" + "io" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" +) + +func main() { + http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, _, _, err := ws.UpgradeHTTP(r, w) + if err != nil { + // handle error + } + go func() { + defer conn.Close() + + var ( + state = ws.StateServerSide + reader = wsutil.NewReader(conn, state) + writer = wsutil.NewWriter(conn, state, ws.OpText) + ) + for { + header, err := reader.NextFrame() + if err != nil { + // handle error + } + + // Reset writer to write frame with right operation code. + writer.Reset(conn, state, header.OpCode) + + if _, err = io.Copy(writer, reader); err != nil { + // handle error + } + if err = writer.Flush(); err != nil { + // handle error + } + } + }() + })) +} +``` + +We can apply the same pattern to read and write structured responses through a JSON encoder and decoder.: + +```go + ... + var ( + r = wsutil.NewReader(conn, ws.StateServerSide) + w = wsutil.NewWriter(conn, ws.StateServerSide, ws.OpText) + decoder = json.NewDecoder(r) + encoder = json.NewEncoder(w) + ) + for { + hdr, err = r.NextFrame() + if err != nil { + return err + } + if hdr.OpCode == ws.OpClose { + return io.EOF + } + var req Request + if err := decoder.Decode(&req); err != nil { + return err + } + var resp Response + if err := encoder.Encode(&resp); err != nil { + return err + } + if err = w.Flush(); err != nil { + return err + } + } + ... +``` + +The lower-level example without `wsutil`: + +```go +package main + +import ( + "net" + "io" + + "github.com/gobwas/ws" +) + +func main() { + ln, err := net.Listen("tcp", "localhost:8080") + if err != nil { + log.Fatal(err) + } + + for { + conn, err := ln.Accept() + if err != nil { + // handle error + } + _, err = ws.Upgrade(conn) + if err != nil { + // handle error + } + + go func() { + defer conn.Close() + + for { + header, err := ws.ReadHeader(conn) + if err != nil { + // handle error + } + + payload := make([]byte, header.Length) + _, err = io.ReadFull(conn, payload) + if err != nil { + // handle error + } + if header.Masked { + ws.Cipher(payload, header.Mask, 0) + } + + // Reset the Masked flag, server frames must not be masked as + // RFC6455 says. + header.Masked = false + + if err := ws.WriteHeader(conn, header); err != nil { + // handle error + } + if _, err := conn.Write(payload); err != nil { + // handle error + } + + if header.OpCode == ws.OpClose { + return + } + } + }() + } +} +``` + +# Zero-copy upgrade + +Zero-copy upgrade helps to avoid unnecessary allocations and copying while +handling HTTP Upgrade request. + +Processing of all non-websocket headers is made in place with use of registered +user callbacks whose arguments are only valid until callback returns. + +The simple example looks like this: + +```go +package main + +import ( + "net" + "log" + + "github.com/gobwas/ws" +) + +func main() { + ln, err := net.Listen("tcp", "localhost:8080") + if err != nil { + log.Fatal(err) + } + u := ws.Upgrader{ + OnHeader: func(key, value []byte) (err error) { + log.Printf("non-websocket header: %q=%q", key, value) + return + }, + } + for { + conn, err := ln.Accept() + if err != nil { + // handle error + } + + _, err = u.Upgrade(conn) + if err != nil { + // handle error + } + } +} +``` + +Usage of `ws.Upgrader` here brings ability to control incoming connections on +tcp level and simply not to accept them by some logic. + +Zero-copy upgrade is for high-load services which have to control many +resources such as connections buffers. + +The real life example could be like this: + +```go +package main + +import ( + "fmt" + "io" + "log" + "net" + "net/http" + "runtime" + + "github.com/gobwas/httphead" + "github.com/gobwas/ws" +) + +func main() { + ln, err := net.Listen("tcp", "localhost:8080") + if err != nil { + // handle error + } + + // Prepare handshake header writer from http.Header mapping. + header := ws.HandshakeHeaderHTTP(http.Header{ + "X-Go-Version": []string{runtime.Version()}, + }) + + u := ws.Upgrader{ + OnHost: func(host []byte) error { + if string(host) == "github.com" { + return nil + } + return ws.RejectConnectionError( + ws.RejectionStatus(403), + ws.RejectionHeader(ws.HandshakeHeaderString( + "X-Want-Host: github.com\r\n", + )), + ) + }, + OnHeader: func(key, value []byte) error { + if string(key) != "Cookie" { + return nil + } + ok := httphead.ScanCookie(value, func(key, value []byte) bool { + // Check session here or do some other stuff with cookies. + // Maybe copy some values for future use. + return true + }) + if ok { + return nil + } + return ws.RejectConnectionError( + ws.RejectionReason("bad cookie"), + ws.RejectionStatus(400), + ) + }, + OnBeforeUpgrade: func() (ws.HandshakeHeader, error) { + return header, nil + }, + } + for { + conn, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + _, err = u.Upgrade(conn) + if err != nil { + log.Printf("upgrade error: %s", err) + } + } +} +``` + + + +[rfc-url]: https://tools.ietf.org/html/rfc6455 +[godoc-image]: https://godoc.org/github.com/gobwas/ws?status.svg +[godoc-url]: https://godoc.org/github.com/gobwas/ws +[travis-image]: https://travis-ci.org/gobwas/ws.svg?branch=master +[travis-url]: https://travis-ci.org/gobwas/ws diff --git a/vendor/github.com/gobwas/ws/check.go b/vendor/github.com/gobwas/ws/check.go new file mode 100644 index 00000000..8aa0df8c --- /dev/null +++ b/vendor/github.com/gobwas/ws/check.go @@ -0,0 +1,145 @@ +package ws + +import "unicode/utf8" + +// State represents state of websocket endpoint. +// It used by some functions to be more strict when checking compatibility with RFC6455. +type State uint8 + +const ( + // StateServerSide means that endpoint (caller) is a server. + StateServerSide State = 0x1 << iota + // StateClientSide means that endpoint (caller) is a client. + StateClientSide + // StateExtended means that extension was negotiated during handshake. + StateExtended + // StateFragmented means that endpoint (caller) has received fragmented + // frame and waits for continuation parts. + StateFragmented +) + +// Is checks whether the s has v enabled. +func (s State) Is(v State) bool { + return uint8(s)&uint8(v) != 0 +} + +// Set enables v state on s. +func (s State) Set(v State) State { + return s | v +} + +// Clear disables v state on s. +func (s State) Clear(v State) State { + return s & (^v) +} + +// ServerSide reports whether states represents server side. +func (s State) ServerSide() bool { return s.Is(StateServerSide) } + +// ClientSide reports whether state represents client side. +func (s State) ClientSide() bool { return s.Is(StateClientSide) } + +// Extended reports whether state is extended. +func (s State) Extended() bool { return s.Is(StateExtended) } + +// Fragmented reports whether state is fragmented. +func (s State) Fragmented() bool { return s.Is(StateFragmented) } + +// ProtocolError describes error during checking/parsing websocket frames or +// headers. +type ProtocolError string + +// Error implements error interface. +func (p ProtocolError) Error() string { return string(p) } + +// Errors used by the protocol checkers. +var ( + ErrProtocolOpCodeReserved = ProtocolError("use of reserved op code") + ErrProtocolControlPayloadOverflow = ProtocolError("control frame payload limit exceeded") + ErrProtocolControlNotFinal = ProtocolError("control frame is not final") + ErrProtocolNonZeroRsv = ProtocolError("non-zero rsv bits with no extension negotiated") + ErrProtocolMaskRequired = ProtocolError("frames from client to server must be masked") + ErrProtocolMaskUnexpected = ProtocolError("frames from server to client must be not masked") + ErrProtocolContinuationExpected = ProtocolError("unexpected non-continuation data frame") + ErrProtocolContinuationUnexpected = ProtocolError("unexpected continuation data frame") + ErrProtocolStatusCodeNotInUse = ProtocolError("status code is not in use") + ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level") + ErrProtocolStatusCodeNoMeaning = ProtocolError("status code has no meaning yet") + ErrProtocolStatusCodeUnknown = ProtocolError("status code is not defined in spec") + ErrProtocolInvalidUTF8 = ProtocolError("invalid utf8 sequence in close reason") +) + +// CheckHeader checks h to contain valid header data for given state s. +// +// Note that zero state (0) means that state is clean, +// neither server or client side, nor fragmented, nor extended. +func CheckHeader(h Header, s State) error { + if h.OpCode.IsReserved() { + return ErrProtocolOpCodeReserved + } + if h.OpCode.IsControl() { + if h.Length > MaxControlFramePayloadSize { + return ErrProtocolControlPayloadOverflow + } + if !h.Fin { + return ErrProtocolControlNotFinal + } + } + + switch { + // [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for + // non-zero values. If a nonzero value is received and none of the + // negotiated extensions defines the meaning of such a nonzero value, the + // receiving endpoint MUST _Fail the WebSocket Connection_. + case h.Rsv != 0 && !s.Extended(): + return ErrProtocolNonZeroRsv + + // [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked. + // In this case, a server MAY send a Close frame with a status code of 1002 (protocol error) + // as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client. + // A client MUST close a connection if it detects a masked frame. In this case, it MAY use the + // status code 1002 (protocol error) as defined in Section 7.4.1. + case s.ServerSide() && !h.Masked: + return ErrProtocolMaskRequired + case s.ClientSide() && h.Masked: + return ErrProtocolMaskUnexpected + + // [RFC6455]: See detailed explanation in 5.4 section. + case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation: + return ErrProtocolContinuationExpected + case !s.Fragmented() && h.OpCode == OpContinuation: + return ErrProtocolContinuationUnexpected + + default: + return nil + } +} + +// CheckCloseFrameData checks received close information +// to be valid RFC6455 compatible close info. +// +// Note that code.Empty() or code.IsAppLevel() will raise error. +// +// If endpoint sends close frame without status code (with frame.Length = 0), +// application should not check its payload. +func CheckCloseFrameData(code StatusCode, reason string) error { + switch { + case code.IsNotUsed(): + return ErrProtocolStatusCodeNotInUse + + case code.IsProtocolReserved(): + return ErrProtocolStatusCodeApplicationLevel + + case code == StatusNoMeaningYet: + return ErrProtocolStatusCodeNoMeaning + + case code.IsProtocolSpec() && !code.IsProtocolDefined(): + return ErrProtocolStatusCodeUnknown + + case !utf8.ValidString(reason): + return ErrProtocolInvalidUTF8 + + default: + return nil + } +} diff --git a/vendor/github.com/gobwas/ws/cipher.go b/vendor/github.com/gobwas/ws/cipher.go new file mode 100644 index 00000000..026f4fd0 --- /dev/null +++ b/vendor/github.com/gobwas/ws/cipher.go @@ -0,0 +1,61 @@ +package ws + +import ( + "encoding/binary" +) + +// Cipher applies XOR cipher to the payload using mask. +// Offset is used to cipher chunked data (e.g. in io.Reader implementations). +// +// To convert masked data into unmasked data, or vice versa, the following +// algorithm is applied. The same algorithm applies regardless of the +// direction of the translation, e.g., the same steps are applied to +// mask the data as to unmask the data. +func Cipher(payload []byte, mask [4]byte, offset int) { + n := len(payload) + if n < 8 { + for i := 0; i < n; i++ { + payload[i] ^= mask[(offset+i)%4] + } + return + } + + // Calculate position in mask due to previously processed bytes number. + mpos := offset % 4 + // Count number of bytes will processed one by one from the beginning of payload. + ln := remain[mpos] + // Count number of bytes will processed one by one from the end of payload. + // This is done to process payload by 8 bytes in each iteration of main loop. + rn := (n - ln) % 8 + + for i := 0; i < ln; i++ { + payload[i] ^= mask[(mpos+i)%4] + } + for i := n - rn; i < n; i++ { + payload[i] ^= mask[(mpos+i)%4] + } + + // NOTE: we use here binary.LittleEndian regardless of what is real + // endianess on machine is. To do so, we have to use binary.LittleEndian in + // the masking loop below as well. + var ( + m = binary.LittleEndian.Uint32(mask[:]) + m2 = uint64(m)<<32 | uint64(m) + ) + // Skip already processed right part. + // Get number of uint64 parts remaining to process. + n = (n - ln - rn) >> 3 + for i := 0; i < n; i++ { + var ( + j = ln + (i << 3) + chunk = payload[j : j+8] + ) + p := binary.LittleEndian.Uint64(chunk) + p = p ^ m2 + binary.LittleEndian.PutUint64(chunk, p) + } +} + +// remain maps position in masking key [0,4) to number +// of bytes that need to be processed manually inside Cipher(). +var remain = [4]int{0, 3, 2, 1} diff --git a/vendor/github.com/gobwas/ws/dialer.go b/vendor/github.com/gobwas/ws/dialer.go new file mode 100644 index 00000000..4357be21 --- /dev/null +++ b/vendor/github.com/gobwas/ws/dialer.go @@ -0,0 +1,556 @@ +package ws + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/gobwas/httphead" + "github.com/gobwas/pool/pbufio" +) + +// Constants used by Dialer. +const ( + DefaultClientReadBufferSize = 4096 + DefaultClientWriteBufferSize = 4096 +) + +// Handshake represents handshake result. +type Handshake struct { + // Protocol is the subprotocol selected during handshake. + Protocol string + + // Extensions is the list of negotiated extensions. + Extensions []httphead.Option +} + +// Errors used by the websocket client. +var ( + ErrHandshakeBadStatus = fmt.Errorf("unexpected http status") + ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol) + ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol) +) + +// DefaultDialer is dialer that holds no options and is used by Dial function. +var DefaultDialer Dialer + +// Dial is like Dialer{}.Dial(). +func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) { + return DefaultDialer.Dial(ctx, urlstr) +} + +// Dialer contains options for establishing websocket connection to an url. +type Dialer struct { + // ReadBufferSize and WriteBufferSize is an I/O buffer sizes. + // They used to read and write http data while upgrading to WebSocket. + // Allocated buffers are pooled with sync.Pool to avoid extra allocations. + // + // If a size is zero then default value is used. + ReadBufferSize, WriteBufferSize int + + // Timeout is the maximum amount of time a Dial() will wait for a connect + // and an handshake to complete. + // + // The default is no timeout. + Timeout time.Duration + + // Protocols is the list of subprotocols that the client wants to speak, + // ordered by preference. + // + // See https://tools.ietf.org/html/rfc6455#section-4.1 + Protocols []string + + // Extensions is the list of extensions that client wants to speak. + // + // Note that if server decides to use some of this extensions, Dial() will + // return Handshake struct containing a slice of items, which are the + // shallow copies of the items from this list. That is, internals of + // Extensions items are shared during Dial(). + // + // See https://tools.ietf.org/html/rfc6455#section-4.1 + // See https://tools.ietf.org/html/rfc6455#section-9.1 + Extensions []httphead.Option + + // Header is an optional HandshakeHeader instance that could be used to + // write additional headers to the handshake request. + // + // It used instead of any key-value mappings to avoid allocations in user + // land. + Header HandshakeHeader + + // OnStatusError is the callback that will be called after receiving non + // "101 Continue" HTTP response status. It receives an io.Reader object + // representing server response bytes. That is, it gives ability to parse + // HTTP response somehow (probably with http.ReadResponse call) and make a + // decision of further logic. + // + // The arguments are only valid until the callback returns. + OnStatusError func(status int, reason []byte, resp io.Reader) + + // OnHeader is the callback that will be called after successful parsing of + // header, that is not used during WebSocket handshake procedure. That is, + // it will be called with non-websocket headers, which could be relevant + // for application-level logic. + // + // The arguments are only valid until the callback returns. + // + // Returned value could be used to prevent processing response. + OnHeader func(key, value []byte) (err error) + + // NetDial is the function that is used to get plain tcp connection. + // If it is not nil, then it is used instead of net.Dialer. + NetDial func(ctx context.Context, network, addr string) (net.Conn, error) + + // TLSClient is the callback that will be called after successful dial with + // received connection and its remote host name. If it is nil, then the + // default tls.Client() will be used. + // If it is not nil, then TLSConfig field is ignored. + TLSClient func(conn net.Conn, hostname string) net.Conn + + // TLSConfig is passed to tls.Client() to start TLS over established + // connection. If TLSClient is not nil, then it is ignored. If TLSConfig is + // non-nil and its ServerName is empty, then for every Dial() it will be + // cloned and appropriate ServerName will be set. + TLSConfig *tls.Config + + // WrapConn is the optional callback that will be called when connection is + // ready for an i/o. That is, it will be called after successful dial and + // TLS initialization (for "wss" schemes). It may be helpful for different + // user land purposes such as end to end encryption. + // + // Note that for debugging purposes of an http handshake (e.g. sent request + // and received response), there is an wsutil.DebugDialer struct. + WrapConn func(conn net.Conn) net.Conn +} + +// Dial connects to the url host and upgrades connection to WebSocket. +// +// If server has sent frames right after successful handshake then returned +// buffer will be non-nil. In other cases buffer is always nil. For better +// memory efficiency received non-nil bufio.Reader should be returned to the +// inner pool with PutReader() function after use. +// +// Note that Dialer does not implement IDNA (RFC5895) logic as net/http does. +// If you want to dial non-ascii host name, take care of its name serialization +// avoiding bad request issues. For more info see net/http Request.Write() +// implementation, especially cleanHost() function. +func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) { + u, err := url.ParseRequestURI(urlstr) + if err != nil { + return + } + + // Prepare context to dial with. Initially it is the same as original, but + // if d.Timeout is non-zero and points to time that is before ctx.Deadline, + // we use more shorter context for dial. + dialctx := ctx + + var deadline time.Time + if t := d.Timeout; t != 0 { + deadline = time.Now().Add(t) + if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { + var cancel context.CancelFunc + dialctx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() + } + } + if conn, err = d.dial(dialctx, u); err != nil { + return + } + defer func() { + if err != nil { + conn.Close() + } + }() + if ctx == context.Background() { + // No need to start I/O interrupter goroutine which is not zero-cost. + conn.SetDeadline(deadline) + defer conn.SetDeadline(noDeadline) + } else { + // Context could be canceled or its deadline could be exceeded. + // Start the interrupter goroutine to handle context cancelation. + done := setupContextDeadliner(ctx, conn) + defer func() { + // Map Upgrade() error to a possible context expiration error. That + // is, even if Upgrade() err is nil, context could be already + // expired and connection be "poisoned" by SetDeadline() call. + // In that case we must not return ctx.Err() error. + done(&err) + }() + } + + br, hs, err = d.Upgrade(conn, u) + + return +} + +var ( + // netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if + // Dialer.NetDial is not provided. + netEmptyDialer net.Dialer + // tlsEmptyConfig is an empty tls.Config used as default one. + tlsEmptyConfig tls.Config +) + +func tlsDefaultConfig() *tls.Config { + return &tlsEmptyConfig +} + +func hostport(host string, defaultPort string) (hostname, addr string) { + var ( + colon = strings.LastIndexByte(host, ':') + bracket = strings.IndexByte(host, ']') + ) + if colon > bracket { + return host[:colon], host + } + return host, host + defaultPort +} + +func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) { + dial := d.NetDial + if dial == nil { + dial = netEmptyDialer.DialContext + } + switch u.Scheme { + case "ws": + _, addr := hostport(u.Host, ":80") + conn, err = dial(ctx, "tcp", addr) + case "wss": + hostname, addr := hostport(u.Host, ":443") + conn, err = dial(ctx, "tcp", addr) + if err != nil { + return + } + tlsClient := d.TLSClient + if tlsClient == nil { + tlsClient = d.tlsClient + } + conn = tlsClient(conn, hostname) + default: + return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme) + } + if wrap := d.WrapConn; wrap != nil { + conn = wrap(conn) + } + return +} + +func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn { + config := d.TLSConfig + if config == nil { + config = tlsDefaultConfig() + } + if config.ServerName == "" { + config = tlsCloneConfig(config) + config.ServerName = hostname + } + // Do not make conn.Handshake() here because downstairs we will prepare + // i/o on this conn with proper context's timeout handling. + return tls.Client(conn, config) +} + +var ( + // This variables are set like in net/net.go. + // noDeadline is just zero value for readability. + noDeadline = time.Time{} + // aLongTimeAgo is a non-zero time, far in the past, used for immediate + // cancelation of dials. + aLongTimeAgo = time.Unix(42, 0) +) + +// Upgrade writes an upgrade request to the given io.ReadWriter conn at given +// url u and reads a response from it. +// +// It is a caller responsibility to manage I/O deadlines on conn. +// +// It returns handshake info and some bytes which could be written by the peer +// right after response and be caught by us during buffered read. +func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) { + // headerSeen constants helps to report whether or not some header was seen + // during reading request bytes. + const ( + headerSeenUpgrade = 1 << iota + headerSeenConnection + headerSeenSecAccept + + // headerSeenAll is the value that we expect to receive at the end of + // headers read/parse loop. + headerSeenAll = 0 | + headerSeenUpgrade | + headerSeenConnection | + headerSeenSecAccept + ) + + br = pbufio.GetReader(conn, + nonZero(d.ReadBufferSize, DefaultClientReadBufferSize), + ) + bw := pbufio.GetWriter(conn, + nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize), + ) + defer func() { + pbufio.PutWriter(bw) + if br.Buffered() == 0 || err != nil { + // Server does not wrote additional bytes to the connection or + // error occurred. That is, no reason to return buffer. + pbufio.PutReader(br) + br = nil + } + }() + + nonce := make([]byte, nonceSize) + initNonce(nonce) + + httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header) + if err = bw.Flush(); err != nil { + return + } + + // Read HTTP status line like "HTTP/1.1 101 Switching Protocols". + sl, err := readLine(br) + if err != nil { + return + } + // Begin validation of the response. + // See https://tools.ietf.org/html/rfc6455#section-4.2.2 + // Parse request line data like HTTP version, uri and method. + resp, err := httpParseResponseLine(sl) + if err != nil { + return + } + // Even if RFC says "1.1 or higher" without mentioning the part of the + // version, we apply it only to minor part. + if resp.major != 1 || resp.minor < 1 { + err = ErrHandshakeBadProtocol + return + } + if resp.status != 101 { + err = StatusError(resp.status) + if onStatusError := d.OnStatusError; onStatusError != nil { + // Invoke callback with multireader of status-line bytes br. + onStatusError(resp.status, resp.reason, + io.MultiReader( + bytes.NewReader(sl), + strings.NewReader(crlf), + br, + ), + ) + } + return + } + // If response status is 101 then we expect all technical headers to be + // valid. If not, then we stop processing response without giving user + // ability to read non-technical headers. That is, we do not distinguish + // technical errors (such as parsing error) and protocol errors. + var headerSeen byte + for { + line, e := readLine(br) + if e != nil { + err = e + return + } + if len(line) == 0 { + // Blank line, no more lines to read. + break + } + + k, v, ok := httpParseHeaderLine(line) + if !ok { + err = ErrMalformedResponse + return + } + + switch btsToString(k) { + case headerUpgradeCanonical: + headerSeen |= headerSeenUpgrade + if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) { + err = ErrHandshakeBadUpgrade + return + } + + case headerConnectionCanonical: + headerSeen |= headerSeenConnection + // Note that as RFC6455 says: + // > A |Connection| header field with value "Upgrade". + // That is, in server side, "Connection" header could contain + // multiple token. But in response it must contains exactly one. + if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) { + err = ErrHandshakeBadConnection + return + } + + case headerSecAcceptCanonical: + headerSeen |= headerSeenSecAccept + if !checkAcceptFromNonce(v, nonce) { + err = ErrHandshakeBadSecAccept + return + } + + case headerSecProtocolCanonical: + // RFC6455 1.3: + // "The server selects one or none of the acceptable protocols + // and echoes that value in its handshake to indicate that it has + // selected that protocol." + for _, want := range d.Protocols { + if string(v) == want { + hs.Protocol = want + break + } + } + if hs.Protocol == "" { + // Server echoed subprotocol that is not present in client + // requested protocols. + err = ErrHandshakeBadSubProtocol + return + } + + case headerSecExtensionsCanonical: + hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions) + if err != nil { + return + } + + default: + if onHeader := d.OnHeader; onHeader != nil { + if e := onHeader(k, v); e != nil { + err = e + return + } + } + } + } + if err == nil && headerSeen != headerSeenAll { + switch { + case headerSeen&headerSeenUpgrade == 0: + err = ErrHandshakeBadUpgrade + case headerSeen&headerSeenConnection == 0: + err = ErrHandshakeBadConnection + case headerSeen&headerSeenSecAccept == 0: + err = ErrHandshakeBadSecAccept + default: + panic("unknown headers state") + } + } + return +} + +// PutReader returns bufio.Reader instance to the inner reuse pool. +// It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which +// contains unprocessed buffered data, that was sent by the server quickly +// right after handshake. +func PutReader(br *bufio.Reader) { + pbufio.PutReader(br) +} + +// StatusError contains an unexpected status-line code from the server. +type StatusError int + +func (s StatusError) Error() string { + return "unexpected HTTP response status: " + strconv.Itoa(int(s)) +} + +func isTimeoutError(err error) bool { + t, ok := err.(net.Error) + return ok && t.Timeout() +} + +func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) { + if len(selected) == 0 { + return received, nil + } + var ( + index int + option httphead.Option + err error + ) + index = -1 + match := func() (ok bool) { + for _, want := range wanted { + if option.Equal(want) { + // Check parsed extension to be present in client + // requested extensions. We move matched extension + // from client list to avoid allocation. + received = append(received, want) + return true + } + } + return false + } + ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control { + if i != index { + // Met next option. + index = i + if i != 0 && !match() { + // Server returned non-requested extension. + err = ErrHandshakeBadExtensions + return httphead.ControlBreak + } + option = httphead.Option{Name: name} + } + if attr != nil { + option.Parameters.Set(attr, val) + } + return httphead.ControlContinue + }) + if !ok { + err = ErrMalformedResponse + return received, err + } + if !match() { + return received, ErrHandshakeBadExtensions + } + return received, err +} + +// setupContextDeadliner is a helper function that starts connection I/O +// interrupter goroutine. +// +// Started goroutine calls SetDeadline() with long time ago value when context +// become expired to make any I/O operations failed. It returns done function +// that stops started goroutine and maps error received from conn I/O methods +// to possible context expiration error. +// +// In concern with possible SetDeadline() call inside interrupter goroutine, +// caller passes pointer to its I/O error (even if it is nil) to done(&err). +// That is, even if I/O error is nil, context could be already expired and +// connection "poisoned" by SetDeadline() call. In that case done(&err) will +// store at *err ctx.Err() result. If err is caused not by timeout, it will +// leaved untouched. +func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) { + var ( + quit = make(chan struct{}) + interrupt = make(chan error, 1) + ) + go func() { + select { + case <-quit: + interrupt <- nil + case <-ctx.Done(): + // Cancel i/o immediately. + conn.SetDeadline(aLongTimeAgo) + interrupt <- ctx.Err() + } + }() + return func(err *error) { + close(quit) + // If ctx.Err() is non-nil and the original err is net.Error with + // Timeout() == true, then it means that I/O was canceled by us by + // SetDeadline(aLongTimeAgo) call, or by somebody else previously + // by conn.SetDeadline(x). + // + // Even on race condition when both deadlines are expired + // (SetDeadline() made not by us and context's), we prefer ctx.Err() to + // be returned. + if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) { + *err = ctxErr + } + } +} diff --git a/vendor/github.com/gobwas/ws/dialer_tls_go17.go b/vendor/github.com/gobwas/ws/dialer_tls_go17.go new file mode 100644 index 00000000..b606e0ad --- /dev/null +++ b/vendor/github.com/gobwas/ws/dialer_tls_go17.go @@ -0,0 +1,35 @@ +// +build !go1.8 + +package ws + +import "crypto/tls" + +func tlsCloneConfig(c *tls.Config) *tls.Config { + // NOTE: we copying SessionTicketsDisabled and SessionTicketKey here + // without calling inner c.initOnceServer somehow because we only could get + // here from the ws.Dialer code, which is obviously a client and makes + // tls.Client() when it gets new net.Conn. + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + } +} diff --git a/vendor/github.com/gobwas/ws/dialer_tls_go18.go b/vendor/github.com/gobwas/ws/dialer_tls_go18.go new file mode 100644 index 00000000..a6704d51 --- /dev/null +++ b/vendor/github.com/gobwas/ws/dialer_tls_go18.go @@ -0,0 +1,9 @@ +// +build go1.8 + +package ws + +import "crypto/tls" + +func tlsCloneConfig(c *tls.Config) *tls.Config { + return c.Clone() +} diff --git a/vendor/github.com/gobwas/ws/doc.go b/vendor/github.com/gobwas/ws/doc.go new file mode 100644 index 00000000..c9d57915 --- /dev/null +++ b/vendor/github.com/gobwas/ws/doc.go @@ -0,0 +1,81 @@ +/* +Package ws implements a client and server for the WebSocket protocol as +specified in RFC 6455. + +The main purpose of this package is to provide simple low-level API for +efficient work with protocol. + +Overview. + +Upgrade to WebSocket (or WebSocket handshake) can be done in two ways. + +The first way is to use `net/http` server: + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + conn, _, _, err := ws.UpgradeHTTP(r, w) + }) + +The second and much more efficient way is so-called "zero-copy upgrade". It +avoids redundant allocations and copying of not used headers or other request +data. User decides by himself which data should be copied. + + ln, err := net.Listen("tcp", ":8080") + if err != nil { + // handle error + } + + conn, err := ln.Accept() + if err != nil { + // handle error + } + + handshake, err := ws.Upgrade(conn) + if err != nil { + // handle error + } + +For customization details see `ws.Upgrader` documentation. + +After WebSocket handshake you can work with connection in multiple ways. +That is, `ws` does not force the only one way of how to work with WebSocket: + + header, err := ws.ReadHeader(conn) + if err != nil { + // handle err + } + + buf := make([]byte, header.Length) + _, err := io.ReadFull(conn, buf) + if err != nil { + // handle err + } + + resp := ws.NewBinaryFrame([]byte("hello, world!")) + if err := ws.WriteFrame(conn, frame); err != nil { + // handle err + } + +As you can see, it stream friendly: + + const N = 42 + + ws.WriteHeader(ws.Header{ + Fin: true, + Length: N, + OpCode: ws.OpBinary, + }) + + io.CopyN(conn, rand.Reader, N) + +Or: + + header, err := ws.ReadHeader(conn) + if err != nil { + // handle err + } + + io.CopyN(ioutil.Discard, conn, header.Length) + +For more info see the documentation. +*/ +package ws diff --git a/vendor/github.com/gobwas/ws/errors.go b/vendor/github.com/gobwas/ws/errors.go new file mode 100644 index 00000000..48fce3b7 --- /dev/null +++ b/vendor/github.com/gobwas/ws/errors.go @@ -0,0 +1,54 @@ +package ws + +// RejectOption represents an option used to control the way connection is +// rejected. +type RejectOption func(*rejectConnectionError) + +// RejectionReason returns an option that makes connection to be rejected with +// given reason. +func RejectionReason(reason string) RejectOption { + return func(err *rejectConnectionError) { + err.reason = reason + } +} + +// RejectionStatus returns an option that makes connection to be rejected with +// given HTTP status code. +func RejectionStatus(code int) RejectOption { + return func(err *rejectConnectionError) { + err.code = code + } +} + +// RejectionHeader returns an option that makes connection to be rejected with +// given HTTP headers. +func RejectionHeader(h HandshakeHeader) RejectOption { + return func(err *rejectConnectionError) { + err.header = h + } +} + +// RejectConnectionError constructs an error that could be used to control the way +// handshake is rejected by Upgrader. +func RejectConnectionError(options ...RejectOption) error { + err := new(rejectConnectionError) + for _, opt := range options { + opt(err) + } + return err +} + +// rejectConnectionError represents a rejection of upgrade error. +// +// It can be returned by Upgrader's On* hooks to control the way WebSocket +// handshake is rejected. +type rejectConnectionError struct { + reason string + code int + header HandshakeHeader +} + +// Error implements error interface. +func (r *rejectConnectionError) Error() string { + return r.reason +} diff --git a/vendor/github.com/gobwas/ws/frame.go b/vendor/github.com/gobwas/ws/frame.go new file mode 100644 index 00000000..f157ee3e --- /dev/null +++ b/vendor/github.com/gobwas/ws/frame.go @@ -0,0 +1,389 @@ +package ws + +import ( + "bytes" + "encoding/binary" + "math/rand" +) + +// Constants defined by specification. +const ( + // All control frames MUST have a payload length of 125 bytes or less and MUST NOT be fragmented. + MaxControlFramePayloadSize = 125 +) + +// OpCode represents operation code. +type OpCode byte + +// Operation codes defined by specification. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +const ( + OpContinuation OpCode = 0x0 + OpText OpCode = 0x1 + OpBinary OpCode = 0x2 + OpClose OpCode = 0x8 + OpPing OpCode = 0x9 + OpPong OpCode = 0xa +) + +// IsControl checks whether the c is control operation code. +// See https://tools.ietf.org/html/rfc6455#section-5.5 +func (c OpCode) IsControl() bool { + // RFC6455: Control frames are identified by opcodes where + // the most significant bit of the opcode is 1. + // + // Note that OpCode is only 4 bit length. + return c&0x8 != 0 +} + +// IsData checks whether the c is data operation code. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +func (c OpCode) IsData() bool { + // RFC6455: Data frames (e.g., non-control frames) are identified by opcodes + // where the most significant bit of the opcode is 0. + // + // Note that OpCode is only 4 bit length. + return c&0x8 == 0 +} + +// IsReserved checks whether the c is reserved operation code. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func (c OpCode) IsReserved() bool { + // RFC6455: + // %x3-7 are reserved for further non-control frames + // %xB-F are reserved for further control frames + return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf) +} + +// StatusCode represents the encoded reason for closure of websocket connection. +// +// There are few helper methods on StatusCode that helps to define a range in +// which given code is lay in. accordingly to ranges defined in specification. +// +// See https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode uint16 + +// StatusCodeRange describes range of StatusCode values. +type StatusCodeRange struct { + Min, Max StatusCode +} + +// Status code ranges defined by specification. +// See https://tools.ietf.org/html/rfc6455#section-7.4.2 +var ( + StatusRangeNotInUse = StatusCodeRange{0, 999} + StatusRangeProtocol = StatusCodeRange{1000, 2999} + StatusRangeApplication = StatusCodeRange{3000, 3999} + StatusRangePrivate = StatusCodeRange{4000, 4999} +) + +// Status codes defined by specification. +// See https://tools.ietf.org/html/rfc6455#section-7.4.1 +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + StatusNoMeaningYet StatusCode = 1004 + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExt StatusCode = 1010 + StatusInternalServerError StatusCode = 1011 + StatusTLSHandshake StatusCode = 1015 + + // StatusAbnormalClosure is a special code designated for use in + // applications. + StatusAbnormalClosure StatusCode = 1006 + + // StatusNoStatusRcvd is a special code designated for use in applications. + StatusNoStatusRcvd StatusCode = 1005 +) + +// In reports whether the code is defined in given range. +func (s StatusCode) In(r StatusCodeRange) bool { + return r.Min <= s && s <= r.Max +} + +// Empty reports whether the code is empty. +// Empty code has no any meaning neither app level codes nor other. +// This method is useful just to check that code is golang default value 0. +func (s StatusCode) Empty() bool { + return s == 0 +} + +// IsNotUsed reports whether the code is predefined in not used range. +func (s StatusCode) IsNotUsed() bool { + return s.In(StatusRangeNotInUse) +} + +// IsApplicationSpec reports whether the code should be defined by +// application, framework or libraries specification. +func (s StatusCode) IsApplicationSpec() bool { + return s.In(StatusRangeApplication) +} + +// IsPrivateSpec reports whether the code should be defined privately. +func (s StatusCode) IsPrivateSpec() bool { + return s.In(StatusRangePrivate) +} + +// IsProtocolSpec reports whether the code should be defined by protocol specification. +func (s StatusCode) IsProtocolSpec() bool { + return s.In(StatusRangeProtocol) +} + +// IsProtocolDefined reports whether the code is already defined by protocol specification. +func (s StatusCode) IsProtocolDefined() bool { + switch s { + case StatusNormalClosure, + StatusGoingAway, + StatusProtocolError, + StatusUnsupportedData, + StatusInvalidFramePayloadData, + StatusPolicyViolation, + StatusMessageTooBig, + StatusMandatoryExt, + StatusInternalServerError, + StatusNoStatusRcvd, + StatusAbnormalClosure, + StatusTLSHandshake: + return true + } + return false +} + +// IsProtocolReserved reports whether the code is defined by protocol specification +// to be reserved only for application usage purpose. +func (s StatusCode) IsProtocolReserved() bool { + switch s { + // [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a + // Close control frame by an endpoint. + case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return true + default: + return false + } +} + +// Compiled control frames for common use cases. +// For construct-serialize optimizations. +var ( + CompiledPing = MustCompileFrame(NewPingFrame(nil)) + CompiledPong = MustCompileFrame(NewPongFrame(nil)) + CompiledClose = MustCompileFrame(NewCloseFrame(nil)) + + CompiledCloseNormalClosure = MustCompileFrame(closeFrameNormalClosure) + CompiledCloseGoingAway = MustCompileFrame(closeFrameGoingAway) + CompiledCloseProtocolError = MustCompileFrame(closeFrameProtocolError) + CompiledCloseUnsupportedData = MustCompileFrame(closeFrameUnsupportedData) + CompiledCloseNoMeaningYet = MustCompileFrame(closeFrameNoMeaningYet) + CompiledCloseInvalidFramePayloadData = MustCompileFrame(closeFrameInvalidFramePayloadData) + CompiledClosePolicyViolation = MustCompileFrame(closeFramePolicyViolation) + CompiledCloseMessageTooBig = MustCompileFrame(closeFrameMessageTooBig) + CompiledCloseMandatoryExt = MustCompileFrame(closeFrameMandatoryExt) + CompiledCloseInternalServerError = MustCompileFrame(closeFrameInternalServerError) + CompiledCloseTLSHandshake = MustCompileFrame(closeFrameTLSHandshake) +) + +// Header represents websocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type Header struct { + Fin bool + Rsv byte + OpCode OpCode + Masked bool + Mask [4]byte + Length int64 +} + +// Rsv1 reports whether the header has first rsv bit set. +func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 } + +// Rsv2 reports whether the header has second rsv bit set. +func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 } + +// Rsv3 reports whether the header has third rsv bit set. +func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 } + +// Frame represents websocket frame. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type Frame struct { + Header Header + Payload []byte +} + +// NewFrame creates frame with given operation code, +// flag of completeness and payload bytes. +func NewFrame(op OpCode, fin bool, p []byte) Frame { + return Frame{ + Header: Header{ + Fin: fin, + OpCode: op, + Length: int64(len(p)), + }, + Payload: p, + } +} + +// NewTextFrame creates text frame with p as payload. +// Note that p is not copied. +func NewTextFrame(p []byte) Frame { + return NewFrame(OpText, true, p) +} + +// NewBinaryFrame creates binary frame with p as payload. +// Note that p is not copied. +func NewBinaryFrame(p []byte) Frame { + return NewFrame(OpBinary, true, p) +} + +// NewPingFrame creates ping frame with p as payload. +// Note that p is not copied. +// Note that p must have length of MaxControlFramePayloadSize bytes or less due +// to RFC. +func NewPingFrame(p []byte) Frame { + return NewFrame(OpPing, true, p) +} + +// NewPongFrame creates pong frame with p as payload. +// Note that p is not copied. +// Note that p must have length of MaxControlFramePayloadSize bytes or less due +// to RFC. +func NewPongFrame(p []byte) Frame { + return NewFrame(OpPong, true, p) +} + +// NewCloseFrame creates close frame with given close body. +// Note that p is not copied. +// Note that p must have length of MaxControlFramePayloadSize bytes or less due +// to RFC. +func NewCloseFrame(p []byte) Frame { + return NewFrame(OpClose, true, p) +} + +// NewCloseFrameBody encodes a closure code and a reason into a binary +// representation. +// +// It returns slice which is at most MaxControlFramePayloadSize bytes length. +// If the reason is too big it will be cropped to fit the limit defined by the +// spec. +// +// See https://tools.ietf.org/html/rfc6455#section-5.5 +func NewCloseFrameBody(code StatusCode, reason string) []byte { + n := min(2+len(reason), MaxControlFramePayloadSize) + p := make([]byte, n) + + crop := min(MaxControlFramePayloadSize-2, len(reason)) + PutCloseFrameBody(p, code, reason[:crop]) + + return p +} + +// PutCloseFrameBody encodes code and reason into buf. +// +// It will panic if the buffer is too small to accommodate a code or a reason. +// +// PutCloseFrameBody does not check buffer to be RFC compliant, but note that +// by RFC it must be at most MaxControlFramePayloadSize. +func PutCloseFrameBody(p []byte, code StatusCode, reason string) { + _ = p[1+len(reason)] + binary.BigEndian.PutUint16(p, uint16(code)) + copy(p[2:], reason) +} + +// MaskFrame masks frame and returns frame with masked payload and Mask header's field set. +// Note that it copies f payload to prevent collisions. +// For less allocations you could use MaskFrameInPlace or construct frame manually. +func MaskFrame(f Frame) Frame { + return MaskFrameWith(f, NewMask()) +} + +// MaskFrameWith masks frame with given mask and returns frame +// with masked payload and Mask header's field set. +// Note that it copies f payload to prevent collisions. +// For less allocations you could use MaskFrameInPlaceWith or construct frame manually. +func MaskFrameWith(f Frame, mask [4]byte) Frame { + // TODO(gobwas): check CopyCipher ws copy() Cipher(). + p := make([]byte, len(f.Payload)) + copy(p, f.Payload) + f.Payload = p + return MaskFrameInPlaceWith(f, mask) +} + +// MaskFrameInPlace masks frame and returns frame with masked payload and Mask +// header's field set. +// Note that it applies xor cipher to f.Payload without copying, that is, it +// modifies f.Payload inplace. +func MaskFrameInPlace(f Frame) Frame { + return MaskFrameInPlaceWith(f, NewMask()) +} + +// MaskFrameInPlaceWith masks frame with given mask and returns frame +// with masked payload and Mask header's field set. +// Note that it applies xor cipher to f.Payload without copying, that is, it +// modifies f.Payload inplace. +func MaskFrameInPlaceWith(f Frame, m [4]byte) Frame { + f.Header.Masked = true + f.Header.Mask = m + Cipher(f.Payload, m, 0) + return f +} + +// NewMask creates new random mask. +func NewMask() (ret [4]byte) { + binary.BigEndian.PutUint32(ret[:], rand.Uint32()) + return +} + +// CompileFrame returns byte representation of given frame. +// In terms of memory consumption it is useful to precompile static frames +// which are often used. +func CompileFrame(f Frame) (bts []byte, err error) { + buf := bytes.NewBuffer(make([]byte, 0, 16)) + err = WriteFrame(buf, f) + bts = buf.Bytes() + return +} + +// MustCompileFrame is like CompileFrame but panics if frame can not be +// encoded. +func MustCompileFrame(f Frame) []byte { + bts, err := CompileFrame(f) + if err != nil { + panic(err) + } + return bts +} + +// Rsv creates rsv byte representation. +func Rsv(r1, r2, r3 bool) (rsv byte) { + if r1 { + rsv |= bit5 + } + if r2 { + rsv |= bit6 + } + if r3 { + rsv |= bit7 + } + return rsv +} + +func makeCloseFrame(code StatusCode) Frame { + return NewCloseFrame(NewCloseFrameBody(code, "")) +} + +var ( + closeFrameNormalClosure = makeCloseFrame(StatusNormalClosure) + closeFrameGoingAway = makeCloseFrame(StatusGoingAway) + closeFrameProtocolError = makeCloseFrame(StatusProtocolError) + closeFrameUnsupportedData = makeCloseFrame(StatusUnsupportedData) + closeFrameNoMeaningYet = makeCloseFrame(StatusNoMeaningYet) + closeFrameInvalidFramePayloadData = makeCloseFrame(StatusInvalidFramePayloadData) + closeFramePolicyViolation = makeCloseFrame(StatusPolicyViolation) + closeFrameMessageTooBig = makeCloseFrame(StatusMessageTooBig) + closeFrameMandatoryExt = makeCloseFrame(StatusMandatoryExt) + closeFrameInternalServerError = makeCloseFrame(StatusInternalServerError) + closeFrameTLSHandshake = makeCloseFrame(StatusTLSHandshake) +) diff --git a/vendor/github.com/gobwas/ws/http.go b/vendor/github.com/gobwas/ws/http.go new file mode 100644 index 00000000..e18df441 --- /dev/null +++ b/vendor/github.com/gobwas/ws/http.go @@ -0,0 +1,468 @@ +package ws + +import ( + "bufio" + "bytes" + "io" + "net/http" + "net/textproto" + "net/url" + "strconv" + + "github.com/gobwas/httphead" +) + +const ( + crlf = "\r\n" + colonAndSpace = ": " + commaAndSpace = ", " +) + +const ( + textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n" +) + +var ( + textHeadBadRequest = statusText(http.StatusBadRequest) + textHeadInternalServerError = statusText(http.StatusInternalServerError) + textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired) + + textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol) + textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod) + textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost) + textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade) + textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection) + textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept) + textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey) + textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion) + textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired) +) + +var ( + headerHost = "Host" + headerUpgrade = "Upgrade" + headerConnection = "Connection" + headerSecVersion = "Sec-WebSocket-Version" + headerSecProtocol = "Sec-WebSocket-Protocol" + headerSecExtensions = "Sec-WebSocket-Extensions" + headerSecKey = "Sec-WebSocket-Key" + headerSecAccept = "Sec-WebSocket-Accept" + + headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost) + headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade) + headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection) + headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion) + headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol) + headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions) + headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey) + headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept) +) + +var ( + specHeaderValueUpgrade = []byte("websocket") + specHeaderValueConnection = []byte("Upgrade") + specHeaderValueConnectionLower = []byte("upgrade") + specHeaderValueSecVersion = []byte("13") +) + +var ( + httpVersion1_0 = []byte("HTTP/1.0") + httpVersion1_1 = []byte("HTTP/1.1") + httpVersionPrefix = []byte("HTTP/") +) + +type httpRequestLine struct { + method, uri []byte + major, minor int +} + +type httpResponseLine struct { + major, minor int + status int + reason []byte +} + +// httpParseRequestLine parses http request line like "GET / HTTP/1.0". +func httpParseRequestLine(line []byte) (req httpRequestLine, err error) { + var proto []byte + req.method, req.uri, proto = bsplit3(line, ' ') + + var ok bool + req.major, req.minor, ok = httpParseVersion(proto) + if !ok { + err = ErrMalformedRequest + return + } + + return +} + +func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) { + var ( + proto []byte + status []byte + ) + proto, status, resp.reason = bsplit3(line, ' ') + + var ok bool + resp.major, resp.minor, ok = httpParseVersion(proto) + if !ok { + return resp, ErrMalformedResponse + } + + var convErr error + resp.status, convErr = asciiToInt(status) + if convErr != nil { + return resp, ErrMalformedResponse + } + + return resp, nil +} + +// httpParseVersion parses major and minor version of HTTP protocol. It returns +// parsed values and true if parse is ok. +func httpParseVersion(bts []byte) (major, minor int, ok bool) { + switch { + case bytes.Equal(bts, httpVersion1_0): + return 1, 0, true + case bytes.Equal(bts, httpVersion1_1): + return 1, 1, true + case len(bts) < 8: + return + case !bytes.Equal(bts[:5], httpVersionPrefix): + return + } + + bts = bts[5:] + + dot := bytes.IndexByte(bts, '.') + if dot == -1 { + return + } + var err error + major, err = asciiToInt(bts[:dot]) + if err != nil { + return + } + minor, err = asciiToInt(bts[dot+1:]) + if err != nil { + return + } + + return major, minor, true +} + +// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed +// values and true if parse is ok. +func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) { + colon := bytes.IndexByte(line, ':') + if colon == -1 { + return + } + + k = btrim(line[:colon]) + // TODO(gobwas): maybe use just lower here? + canonicalizeHeaderKey(k) + + v = btrim(line[colon+1:]) + + return k, v, true +} + +// httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing, +// that key is already canonical. This helps to increase performance. +func httpGetHeader(h http.Header, key string) string { + if h == nil { + return "" + } + v := h[key] + if len(v) == 0 { + return "" + } + return v[0] +} + +// The request MAY include a header field with the name +// |Sec-WebSocket-Protocol|. If present, this value indicates one or more +// comma-separated subprotocol the client wishes to speak, ordered by +// preference. The elements that comprise this value MUST be non-empty strings +// with characters in the range U+0021 to U+007E not including separator +// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF +// for the value of this header field is 1#token, where the definitions of +// constructs and rules are as given in [RFC2616]. +func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) { + ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool { + if check(btsToString(v)) { + ret = string(v) + return false + } + return true + }) + return +} +func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) { + var selected []byte + ok = httphead.ScanTokens(h, func(v []byte) bool { + if check(v) { + selected = v + return false + } + return true + }) + if ok && selected != nil { + return string(selected), true + } + return +} + +func strSelectExtensions(h string, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) { + return btsSelectExtensions(strToBytes(h), selected, check) +} + +func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) { + s := httphead.OptionSelector{ + Flags: httphead.SelectUnique | httphead.SelectCopy, + Check: check, + } + return s.Select(h, selected) +} + +func httpWriteHeader(bw *bufio.Writer, key, value string) { + httpWriteHeaderKey(bw, key) + bw.WriteString(value) + bw.WriteString(crlf) +} + +func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) { + httpWriteHeaderKey(bw, key) + bw.Write(value) + bw.WriteString(crlf) +} + +func httpWriteHeaderKey(bw *bufio.Writer, key string) { + bw.WriteString(key) + bw.WriteString(colonAndSpace) +} + +func httpWriteUpgradeRequest( + bw *bufio.Writer, + u *url.URL, + nonce []byte, + protocols []string, + extensions []httphead.Option, + header HandshakeHeader, +) { + bw.WriteString("GET ") + bw.WriteString(u.RequestURI()) + bw.WriteString(" HTTP/1.1\r\n") + + httpWriteHeader(bw, headerHost, u.Host) + + httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade) + httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection) + httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion) + + // NOTE: write nonce bytes as a string to prevent heap allocation – + // WriteString() copy given string into its inner buffer, unlike Write() + // which may write p directly to the underlying io.Writer – which in turn + // will lead to p escape. + httpWriteHeader(bw, headerSecKey, btsToString(nonce)) + + if len(protocols) > 0 { + httpWriteHeaderKey(bw, headerSecProtocol) + for i, p := range protocols { + if i > 0 { + bw.WriteString(commaAndSpace) + } + bw.WriteString(p) + } + bw.WriteString(crlf) + } + + if len(extensions) > 0 { + httpWriteHeaderKey(bw, headerSecExtensions) + httphead.WriteOptions(bw, extensions) + bw.WriteString(crlf) + } + + if header != nil { + header.WriteTo(bw) + } + + bw.WriteString(crlf) +} + +func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) { + bw.WriteString(textHeadUpgrade) + + httpWriteHeaderKey(bw, headerSecAccept) + writeAccept(bw, nonce) + bw.WriteString(crlf) + + if hs.Protocol != "" { + httpWriteHeader(bw, headerSecProtocol, hs.Protocol) + } + if len(hs.Extensions) > 0 { + httpWriteHeaderKey(bw, headerSecExtensions) + httphead.WriteOptions(bw, hs.Extensions) + bw.WriteString(crlf) + } + if header != nil { + header(bw) + } + + bw.WriteString(crlf) +} + +func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) { + switch code { + case http.StatusBadRequest: + bw.WriteString(textHeadBadRequest) + case http.StatusInternalServerError: + bw.WriteString(textHeadInternalServerError) + case http.StatusUpgradeRequired: + bw.WriteString(textHeadUpgradeRequired) + default: + writeStatusText(bw, code) + } + + // Write custom headers. + if header != nil { + header(bw) + } + + switch err { + case ErrHandshakeBadProtocol: + bw.WriteString(textTailErrHandshakeBadProtocol) + case ErrHandshakeBadMethod: + bw.WriteString(textTailErrHandshakeBadMethod) + case ErrHandshakeBadHost: + bw.WriteString(textTailErrHandshakeBadHost) + case ErrHandshakeBadUpgrade: + bw.WriteString(textTailErrHandshakeBadUpgrade) + case ErrHandshakeBadConnection: + bw.WriteString(textTailErrHandshakeBadConnection) + case ErrHandshakeBadSecAccept: + bw.WriteString(textTailErrHandshakeBadSecAccept) + case ErrHandshakeBadSecKey: + bw.WriteString(textTailErrHandshakeBadSecKey) + case ErrHandshakeBadSecVersion: + bw.WriteString(textTailErrHandshakeBadSecVersion) + case ErrHandshakeUpgradeRequired: + bw.WriteString(textTailErrUpgradeRequired) + case nil: + bw.WriteString(crlf) + default: + writeErrorText(bw, err) + } +} + +func writeStatusText(bw *bufio.Writer, code int) { + bw.WriteString("HTTP/1.1 ") + bw.WriteString(strconv.Itoa(code)) + bw.WriteByte(' ') + bw.WriteString(http.StatusText(code)) + bw.WriteString(crlf) + bw.WriteString("Content-Type: text/plain; charset=utf-8") + bw.WriteString(crlf) +} + +func writeErrorText(bw *bufio.Writer, err error) { + body := err.Error() + bw.WriteString("Content-Length: ") + bw.WriteString(strconv.Itoa(len(body))) + bw.WriteString(crlf) + bw.WriteString(crlf) + bw.WriteString(body) +} + +// httpError is like the http.Error with WebSocket context exception. +func httpError(w http.ResponseWriter, body string, code int) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + w.WriteHeader(code) + w.Write([]byte(body)) +} + +// statusText is a non-performant status text generator. +// NOTE: Used only to generate constants. +func statusText(code int) string { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + writeStatusText(bw, code) + bw.Flush() + return buf.String() +} + +// errorText is a non-performant error text generator. +// NOTE: Used only to generate constants. +func errorText(err error) string { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + writeErrorText(bw, err) + bw.Flush() + return buf.String() +} + +// HandshakeHeader is the interface that writes both upgrade request or +// response headers into a given io.Writer. +type HandshakeHeader interface { + io.WriterTo +} + +// HandshakeHeaderString is an adapter to allow the use of headers represented +// by ordinary string as HandshakeHeader. +type HandshakeHeaderString string + +// WriteTo implements HandshakeHeader (and io.WriterTo) interface. +func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) { + n, err := io.WriteString(w, string(s)) + return int64(n), err +} + +// HandshakeHeaderBytes is an adapter to allow the use of headers represented +// by ordinary slice of bytes as HandshakeHeader. +type HandshakeHeaderBytes []byte + +// WriteTo implements HandshakeHeader (and io.WriterTo) interface. +func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(b) + return int64(n), err +} + +// HandshakeHeaderFunc is an adapter to allow the use of headers represented by +// ordinary function as HandshakeHeader. +type HandshakeHeaderFunc func(io.Writer) (int64, error) + +// WriteTo implements HandshakeHeader (and io.WriterTo) interface. +func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) { + return f(w) +} + +// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as +// HandshakeHeader. +type HandshakeHeaderHTTP http.Header + +// WriteTo implements HandshakeHeader (and io.WriterTo) interface. +func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) { + wr := writer{w: w} + err := http.Header(h).Write(&wr) + return wr.n, err +} + +type writer struct { + n int64 + w io.Writer +} + +func (w *writer) WriteString(s string) (int, error) { + n, err := io.WriteString(w.w, s) + w.n += int64(n) + return n, err +} + +func (w *writer) Write(p []byte) (int, error) { + n, err := w.w.Write(p) + w.n += int64(n) + return n, err +} diff --git a/vendor/github.com/gobwas/ws/nonce.go b/vendor/github.com/gobwas/ws/nonce.go new file mode 100644 index 00000000..e694da7c --- /dev/null +++ b/vendor/github.com/gobwas/ws/nonce.go @@ -0,0 +1,80 @@ +package ws + +import ( + "bufio" + "bytes" + "crypto/sha1" + "encoding/base64" + "fmt" + "math/rand" +) + +const ( + // RFC6455: The value of this header field MUST be a nonce consisting of a + // randomly selected 16-byte value that has been base64-encoded (see + // Section 4 of [RFC4648]). The nonce MUST be selected randomly for each + // connection. + nonceKeySize = 16 + nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) + + // RFC6455: The value of this header field is constructed by concatenating + // /key/, defined above in step 4 in Section 4.2.2, with the string + // "258EAFA5- E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this + // concatenated value to obtain a 20-byte value and base64- encoding (see + // Section 4 of [RFC4648]) this 20-byte hash. + acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size) +) + +// initNonce fills given slice with random base64-encoded nonce bytes. +func initNonce(dst []byte) { + // NOTE: bts does not escape. + bts := make([]byte, nonceKeySize) + if _, err := rand.Read(bts); err != nil { + panic(fmt.Sprintf("rand read error: %s", err)) + } + base64.StdEncoding.Encode(dst, bts) +} + +// checkAcceptFromNonce reports whether given accept bytes are valid for given +// nonce bytes. +func checkAcceptFromNonce(accept, nonce []byte) bool { + if len(accept) != acceptSize { + return false + } + // NOTE: expect does not escape. + expect := make([]byte, acceptSize) + initAcceptFromNonce(expect, nonce) + return bytes.Equal(expect, accept) +} + +// initAcceptFromNonce fills given slice with accept bytes generated from given +// nonce bytes. Given buffer should be exactly acceptSize bytes. +func initAcceptFromNonce(accept, nonce []byte) { + const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + if len(accept) != acceptSize { + panic("accept buffer is invalid") + } + if len(nonce) != nonceSize { + panic("nonce is invalid") + } + + p := make([]byte, nonceSize+len(magic)) + copy(p[:nonceSize], nonce) + copy(p[nonceSize:], magic) + + sum := sha1.Sum(p) + base64.StdEncoding.Encode(accept, sum[:]) + + return +} + +func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) { + accept := make([]byte, acceptSize) + initAcceptFromNonce(accept, nonce) + // NOTE: write accept bytes as a string to prevent heap allocation – + // WriteString() copy given string into its inner buffer, unlike Write() + // which may write p directly to the underlying io.Writer – which in turn + // will lead to p escape. + return bw.WriteString(btsToString(accept)) +} diff --git a/vendor/github.com/gobwas/ws/read.go b/vendor/github.com/gobwas/ws/read.go new file mode 100644 index 00000000..bc653e46 --- /dev/null +++ b/vendor/github.com/gobwas/ws/read.go @@ -0,0 +1,147 @@ +package ws + +import ( + "encoding/binary" + "fmt" + "io" +) + +// Errors used by frame reader. +var ( + ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0") + ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits") +) + +// ReadHeader reads a frame header from r. +func ReadHeader(r io.Reader) (h Header, err error) { + // Make slice of bytes with capacity 12 that could hold any header. + // + // The maximum header size is 14, but due to the 2 hop reads, + // after first hop that reads first 2 constant bytes, we could reuse 2 bytes. + // So 14 - 2 = 12. + bts := make([]byte, 2, MaxHeaderSize-2) + + // Prepare to hold first 2 bytes to choose size of next read. + _, err = io.ReadFull(r, bts) + if err != nil { + return + } + + h.Fin = bts[0]&bit0 != 0 + h.Rsv = (bts[0] & 0x70) >> 4 + h.OpCode = OpCode(bts[0] & 0x0f) + + var extra int + + if bts[1]&bit0 != 0 { + h.Masked = true + extra += 4 + } + + length := bts[1] & 0x7f + switch { + case length < 126: + h.Length = int64(length) + + case length == 126: + extra += 2 + + case length == 127: + extra += 8 + + default: + err = ErrHeaderLengthUnexpected + return + } + + if extra == 0 { + return + } + + // Increase len of bts to extra bytes need to read. + // Overwrite first 2 bytes that was read before. + bts = bts[:extra] + _, err = io.ReadFull(r, bts) + if err != nil { + return + } + + switch { + case length == 126: + h.Length = int64(binary.BigEndian.Uint16(bts[:2])) + bts = bts[2:] + + case length == 127: + if bts[0]&0x80 != 0 { + err = ErrHeaderLengthMSB + return + } + h.Length = int64(binary.BigEndian.Uint64(bts[:8])) + bts = bts[8:] + } + + if h.Masked { + copy(h.Mask[:], bts) + } + + return +} + +// ReadFrame reads a frame from r. +// It is not designed for high optimized use case cause it makes allocation +// for frame.Header.Length size inside to read frame payload into. +// +// Note that ReadFrame does not unmask payload. +func ReadFrame(r io.Reader) (f Frame, err error) { + f.Header, err = ReadHeader(r) + if err != nil { + return + } + + if f.Header.Length > 0 { + // int(f.Header.Length) is safe here cause we have + // checked it for overflow above in ReadHeader. + f.Payload = make([]byte, int(f.Header.Length)) + _, err = io.ReadFull(r, f.Payload) + } + + return +} + +// MustReadFrame is like ReadFrame but panics if frame can not be read. +func MustReadFrame(r io.Reader) Frame { + f, err := ReadFrame(r) + if err != nil { + panic(err) + } + return f +} + +// ParseCloseFrameData parses close frame status code and closure reason if any provided. +// If there is no status code in the payload +// the empty status code is returned (code.Empty()) with empty string as a reason. +func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) { + if len(payload) < 2 { + // We returning empty StatusCode here, preventing the situation + // when endpoint really sent code 1005 and we should return ProtocolError on that. + // + // In other words, we ignoring this rule [RFC6455:7.1.5]: + // If this Close control frame contains no status code, _The WebSocket + // Connection Close Code_ is considered to be 1005. + return + } + code = StatusCode(binary.BigEndian.Uint16(payload)) + reason = string(payload[2:]) + return +} + +// ParseCloseFrameDataUnsafe is like ParseCloseFrameData except the thing +// that it does not copies payload bytes into reason, but prepares unsafe cast. +func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) { + if len(payload) < 2 { + return + } + code = StatusCode(binary.BigEndian.Uint16(payload)) + reason = btsToString(payload[2:]) + return +} diff --git a/vendor/github.com/gobwas/ws/server.go b/vendor/github.com/gobwas/ws/server.go new file mode 100644 index 00000000..62ad9c7f --- /dev/null +++ b/vendor/github.com/gobwas/ws/server.go @@ -0,0 +1,607 @@ +package ws + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" + + "github.com/gobwas/httphead" + "github.com/gobwas/pool/pbufio" +) + +// Constants used by ConnUpgrader. +const ( + DefaultServerReadBufferSize = 4096 + DefaultServerWriteBufferSize = 512 +) + +// Errors used by both client and server when preparing WebSocket handshake. +var ( + ErrHandshakeBadProtocol = RejectConnectionError( + RejectionStatus(http.StatusHTTPVersionNotSupported), + RejectionReason(fmt.Sprintf("handshake error: bad HTTP protocol version")), + ) + ErrHandshakeBadMethod = RejectConnectionError( + RejectionStatus(http.StatusMethodNotAllowed), + RejectionReason(fmt.Sprintf("handshake error: bad HTTP request method")), + ) + ErrHandshakeBadHost = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)), + ) + ErrHandshakeBadUpgrade = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)), + ) + ErrHandshakeBadConnection = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)), + ) + ErrHandshakeBadSecAccept = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)), + ) + ErrHandshakeBadSecKey = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)), + ) + ErrHandshakeBadSecVersion = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)), + ) +) + +// ErrMalformedResponse is returned by Dialer to indicate that server response +// can not be parsed. +var ErrMalformedResponse = fmt.Errorf("malformed HTTP response") + +// ErrMalformedRequest is returned when HTTP request can not be parsed. +var ErrMalformedRequest = RejectConnectionError( + RejectionStatus(http.StatusBadRequest), + RejectionReason("malformed HTTP request"), +) + +// ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that +// connection is rejected because given WebSocket version is malformed. +// +// According to RFC6455: +// If this version does not match a version understood by the server, the +// server MUST abort the WebSocket handshake described in this section and +// instead send an appropriate HTTP error code (such as 426 Upgrade Required) +// and a |Sec-WebSocket-Version| header field indicating the version(s) the +// server is capable of understanding. +var ErrHandshakeUpgradeRequired = RejectConnectionError( + RejectionStatus(http.StatusUpgradeRequired), + RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")), + RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)), +) + +// ErrNotHijacker is an error returned when http.ResponseWriter does not +// implement http.Hijacker interface. +var ErrNotHijacker = RejectConnectionError( + RejectionStatus(http.StatusInternalServerError), + RejectionReason("given http.ResponseWriter is not a http.Hijacker"), +) + +// DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by +// UpgradeHTTP function. +var DefaultHTTPUpgrader HTTPUpgrader + +// UpgradeHTTP is like HTTPUpgrader{}.Upgrade(). +func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) { + return DefaultHTTPUpgrader.Upgrade(r, w) +} + +// DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade +// function. +var DefaultUpgrader Upgrader + +// Upgrade is like Upgrader{}.Upgrade(). +func Upgrade(conn io.ReadWriter) (Handshake, error) { + return DefaultUpgrader.Upgrade(conn) +} + +// HTTPUpgrader contains options for upgrading connection to websocket from +// net/http Handler arguments. +type HTTPUpgrader struct { + // Timeout is the maximum amount of time an Upgrade() will spent while + // writing handshake response. + // + // The default is no timeout. + Timeout time.Duration + + // Header is an optional http.Header mapping that could be used to + // write additional headers to the handshake response. + // + // Note that if present, it will be written in any result of handshake. + Header http.Header + + // Protocol is the select function that is used to select subprotocol from + // list requested by client. If this field is set, then the first matched + // protocol is sent to a client as negotiated. + Protocol func(string) bool + + // Extension is the select function that is used to select extensions from + // list requested by client. If this field is set, then the all matched + // extensions are sent to a client as negotiated. + Extension func(httphead.Option) bool +} + +// Upgrade upgrades http connection to the websocket connection. +// +// It hijacks net.Conn from w and returns received net.Conn and +// bufio.ReadWriter. On successful handshake it returns Handshake struct +// describing handshake info. +func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) { + // Hijack connection first to get the ability to write rejection errors the + // same way as in Upgrader. + hj, ok := w.(http.Hijacker) + if ok { + conn, rw, err = hj.Hijack() + } else { + err = ErrNotHijacker + } + if err != nil { + httpError(w, err.Error(), http.StatusInternalServerError) + return + } + + // See https://tools.ietf.org/html/rfc6455#section-4.1 + // The method of the request MUST be GET, and the HTTP version MUST be at least 1.1. + var nonce string + if r.Method != http.MethodGet { + err = ErrHandshakeBadMethod + } else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) { + err = ErrHandshakeBadProtocol + } else if r.Host == "" { + err = ErrHandshakeBadHost + } else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") { + err = ErrHandshakeBadUpgrade + } else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") { + err = ErrHandshakeBadConnection + } else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize { + err = ErrHandshakeBadSecKey + } else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" { + // According to RFC6455: + // + // If this version does not match a version understood by the server, + // the server MUST abort the WebSocket handshake described in this + // section and instead send an appropriate HTTP error code (such as 426 + // Upgrade Required) and a |Sec-WebSocket-Version| header field + // indicating the version(s) the server is capable of understanding. + // + // So we branching here cause empty or not present version does not + // meet the ABNF rules of RFC6455: + // + // version = DIGIT | (NZDIGIT DIGIT) | + // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT) + // ; Limited to 0-255 range, with no leading zeros + // + // That is, if version is really invalid – we sent 426 status, if it + // not present or empty – it is 400. + if v != "" { + err = ErrHandshakeUpgradeRequired + } else { + err = ErrHandshakeBadSecVersion + } + } + if check := u.Protocol; err == nil && check != nil { + ps := r.Header[headerSecProtocolCanonical] + for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ { + var ok bool + hs.Protocol, ok = strSelectProtocol(ps[i], check) + if !ok { + err = ErrMalformedRequest + } + } + } + if check := u.Extension; err == nil && check != nil { + xs := r.Header[headerSecExtensionsCanonical] + for i := 0; i < len(xs) && err == nil; i++ { + var ok bool + hs.Extensions, ok = strSelectExtensions(xs[i], hs.Extensions, check) + if !ok { + err = ErrMalformedRequest + } + } + } + + // Clear deadlines set by server. + conn.SetDeadline(noDeadline) + if t := u.Timeout; t != 0 { + conn.SetWriteDeadline(time.Now().Add(t)) + defer conn.SetWriteDeadline(noDeadline) + } + + var header handshakeHeader + if h := u.Header; h != nil { + header[0] = HandshakeHeaderHTTP(h) + } + if err == nil { + httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo) + err = rw.Writer.Flush() + } else { + var code int + if rej, ok := err.(*rejectConnectionError); ok { + code = rej.code + header[1] = rej.header + } + if code == 0 { + code = http.StatusInternalServerError + } + httpWriteResponseError(rw.Writer, err, code, header.WriteTo) + // Do not store Flush() error to not override already existing one. + rw.Writer.Flush() + } + return +} + +// Upgrader contains options for upgrading connection to websocket. +type Upgrader struct { + // ReadBufferSize and WriteBufferSize is an I/O buffer sizes. + // They used to read and write http data while upgrading to WebSocket. + // Allocated buffers are pooled with sync.Pool to avoid extra allocations. + // + // If a size is zero then default value is used. + // + // Usually it is useful to set read buffer size bigger than write buffer + // size because incoming request could contain long header values, such as + // Cookie. Response, in other way, could be big only if user write multiple + // custom headers. Usually response takes less than 256 bytes. + ReadBufferSize, WriteBufferSize int + + // Protocol is a select function that is used to select subprotocol + // from list requested by client. If this field is set, then the first matched + // protocol is sent to a client as negotiated. + // + // The argument is only valid until the callback returns. + Protocol func([]byte) bool + + // ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually. + // Note that returned bytes must be valid until Upgrade returns. + // If ProtocolCustom is set, it used instead of Protocol function. + ProtocolCustom func([]byte) (string, bool) + + // Extension is a select function that is used to select extensions + // from list requested by client. If this field is set, then the all matched + // extensions are sent to a client as negotiated. + // + // The argument is only valid until the callback returns. + // + // According to the RFC6455 order of extensions passed by a client is + // significant. That is, returning true from this function means that no + // other extension with the same name should be checked because server + // accepted the most preferable extension right now: + // "Note that the order of extensions is significant. Any interactions between + // multiple extensions MAY be defined in the documents defining the extensions. + // In the absence of such definitions, the interpretation is that the header + // fields listed by the client in its request represent a preference of the + // header fields it wishes to use, with the first options listed being most + // preferable." + Extension func(httphead.Option) bool + + // ExtensionCustom allow user to parse Sec-WebSocket-Extensions header manually. + // Note that returned options should be valid until Upgrade returns. + // If ExtensionCustom is set, it used instead of Extension function. + ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool) + + // Header is an optional HandshakeHeader instance that could be used to + // write additional headers to the handshake response. + // + // It used instead of any key-value mappings to avoid allocations in user + // land. + // + // Note that if present, it will be written in any result of handshake. + Header HandshakeHeader + + // OnRequest is a callback that will be called after request line + // successful parsing. + // + // The arguments are only valid until the callback returns. + // + // If returned error is non-nil then connection is rejected and response is + // sent with appropriate HTTP error code and body set to error message. + // + // RejectConnectionError could be used to get more control on response. + OnRequest func(uri []byte) error + + // OnHost is a callback that will be called after "Host" header successful + // parsing. + // + // It is separated from OnHeader callback because the Host header must be + // present in each request since HTTP/1.1. Thus Host header is non-optional + // and required for every WebSocket handshake. + // + // The arguments are only valid until the callback returns. + // + // If returned error is non-nil then connection is rejected and response is + // sent with appropriate HTTP error code and body set to error message. + // + // RejectConnectionError could be used to get more control on response. + OnHost func(host []byte) error + + // OnHeader is a callback that will be called after successful parsing of + // header, that is not used during WebSocket handshake procedure. That is, + // it will be called with non-websocket headers, which could be relevant + // for application-level logic. + // + // The arguments are only valid until the callback returns. + // + // If returned error is non-nil then connection is rejected and response is + // sent with appropriate HTTP error code and body set to error message. + // + // RejectConnectionError could be used to get more control on response. + OnHeader func(key, value []byte) error + + // OnBeforeUpgrade is a callback that will be called before sending + // successful upgrade response. + // + // Setting OnBeforeUpgrade allows user to make final application-level + // checks and decide whether this connection is allowed to successfully + // upgrade to WebSocket. + // + // It must return non-nil either HandshakeHeader or error and never both. + // + // If returned error is non-nil then connection is rejected and response is + // sent with appropriate HTTP error code and body set to error message. + // + // RejectConnectionError could be used to get more control on response. + OnBeforeUpgrade func() (header HandshakeHeader, err error) +} + +// Upgrade zero-copy upgrades connection to WebSocket. It interprets given conn +// as connection with incoming HTTP Upgrade request. +// +// It is a caller responsibility to manage i/o timeouts on conn. +// +// Non-nil error means that request for the WebSocket upgrade is invalid or +// malformed and usually connection should be closed. +// Even when error is non-nil Upgrade will write appropriate response into +// connection in compliance with RFC. +func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) { + // headerSeen constants helps to report whether or not some header was seen + // during reading request bytes. + const ( + headerSeenHost = 1 << iota + headerSeenUpgrade + headerSeenConnection + headerSeenSecVersion + headerSeenSecKey + + // headerSeenAll is the value that we expect to receive at the end of + // headers read/parse loop. + headerSeenAll = 0 | + headerSeenHost | + headerSeenUpgrade | + headerSeenConnection | + headerSeenSecVersion | + headerSeenSecKey + ) + + // Prepare I/O buffers. + // TODO(gobwas): make it configurable. + br := pbufio.GetReader(conn, + nonZero(u.ReadBufferSize, DefaultServerReadBufferSize), + ) + bw := pbufio.GetWriter(conn, + nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize), + ) + defer func() { + pbufio.PutReader(br) + pbufio.PutWriter(bw) + }() + + // Read HTTP request line like "GET /ws HTTP/1.1". + rl, err := readLine(br) + if err != nil { + return + } + // Parse request line data like HTTP version, uri and method. + req, err := httpParseRequestLine(rl) + if err != nil { + return + } + + // Prepare stack-based handshake header list. + header := handshakeHeader{ + 0: u.Header, + } + + // Parse and check HTTP request. + // As RFC6455 says: + // The client's opening handshake consists of the following parts. If the + // server, while reading the handshake, finds that the client did not + // send a handshake that matches the description below (note that as per + // [RFC2616], the order of the header fields is not important), including + // but not limited to any violations of the ABNF grammar specified for + // the components of the handshake, the server MUST stop processing the + // client's handshake and return an HTTP response with an appropriate + // error code (such as 400 Bad Request). + // + // See https://tools.ietf.org/html/rfc6455#section-4.2.1 + + // An HTTP/1.1 or higher GET request, including a "Request-URI". + // + // Even if RFC says "1.1 or higher" without mentioning the part of the + // version, we apply it only to minor part. + switch { + case req.major != 1 || req.minor < 1: + // Abort processing the whole request because we do not even know how + // to actually parse it. + err = ErrHandshakeBadProtocol + + case btsToString(req.method) != http.MethodGet: + err = ErrHandshakeBadMethod + + default: + if onRequest := u.OnRequest; onRequest != nil { + err = onRequest(req.uri) + } + } + // Start headers read/parse loop. + var ( + // headerSeen reports which header was seen by setting corresponding + // bit on. + headerSeen byte + + nonce = make([]byte, nonceSize) + ) + for err == nil { + line, e := readLine(br) + if e != nil { + return hs, e + } + if len(line) == 0 { + // Blank line, no more lines to read. + break + } + + k, v, ok := httpParseHeaderLine(line) + if !ok { + err = ErrMalformedRequest + break + } + + switch btsToString(k) { + case headerHostCanonical: + headerSeen |= headerSeenHost + if onHost := u.OnHost; onHost != nil { + err = onHost(v) + } + + case headerUpgradeCanonical: + headerSeen |= headerSeenUpgrade + if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) { + err = ErrHandshakeBadUpgrade + } + + case headerConnectionCanonical: + headerSeen |= headerSeenConnection + if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) { + err = ErrHandshakeBadConnection + } + + case headerSecVersionCanonical: + headerSeen |= headerSeenSecVersion + if !bytes.Equal(v, specHeaderValueSecVersion) { + err = ErrHandshakeUpgradeRequired + } + + case headerSecKeyCanonical: + headerSeen |= headerSeenSecKey + if len(v) != nonceSize { + err = ErrHandshakeBadSecKey + } else { + copy(nonce[:], v) + } + + case headerSecProtocolCanonical: + if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) { + var ok bool + if custom != nil { + hs.Protocol, ok = custom(v) + } else { + hs.Protocol, ok = btsSelectProtocol(v, check) + } + if !ok { + err = ErrMalformedRequest + } + } + + case headerSecExtensionsCanonical: + if custom, check := u.ExtensionCustom, u.Extension; custom != nil || check != nil { + var ok bool + if custom != nil { + hs.Extensions, ok = custom(v, hs.Extensions) + } else { + hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check) + } + if !ok { + err = ErrMalformedRequest + } + } + + default: + if onHeader := u.OnHeader; onHeader != nil { + err = onHeader(k, v) + } + } + } + switch { + case err == nil && headerSeen != headerSeenAll: + switch { + case headerSeen&headerSeenHost == 0: + // As RFC2616 says: + // A client MUST include a Host header field in all HTTP/1.1 + // request messages. If the requested URI does not include an + // Internet host name for the service being requested, then the + // Host header field MUST be given with an empty value. An + // HTTP/1.1 proxy MUST ensure that any request message it + // forwards does contain an appropriate Host header field that + // identifies the service being requested by the proxy. All + // Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad + // Request) status code to any HTTP/1.1 request message which + // lacks a Host header field. + err = ErrHandshakeBadHost + case headerSeen&headerSeenUpgrade == 0: + err = ErrHandshakeBadUpgrade + case headerSeen&headerSeenConnection == 0: + err = ErrHandshakeBadConnection + case headerSeen&headerSeenSecVersion == 0: + // In case of empty or not present version we do not send 426 status, + // because it does not meet the ABNF rules of RFC6455: + // + // version = DIGIT | (NZDIGIT DIGIT) | + // ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT) + // ; Limited to 0-255 range, with no leading zeros + // + // That is, if version is really invalid – we sent 426 status as above, if it + // not present – it is 400. + err = ErrHandshakeBadSecVersion + case headerSeen&headerSeenSecKey == 0: + err = ErrHandshakeBadSecKey + default: + panic("unknown headers state") + } + + case err == nil && u.OnBeforeUpgrade != nil: + header[1], err = u.OnBeforeUpgrade() + } + if err != nil { + var code int + if rej, ok := err.(*rejectConnectionError); ok { + code = rej.code + header[1] = rej.header + } + if code == 0 { + code = http.StatusInternalServerError + } + httpWriteResponseError(bw, err, code, header.WriteTo) + // Do not store Flush() error to not override already existing one. + bw.Flush() + return + } + + httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo) + err = bw.Flush() + + return +} + +type handshakeHeader [2]HandshakeHeader + +func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) { + for i := 0; i < len(hs) && err == nil; i++ { + if h := hs[i]; h != nil { + var m int64 + m, err = h.WriteTo(w) + n += m + } + } + return n, err +} diff --git a/vendor/github.com/gobwas/ws/server_test.s b/vendor/github.com/gobwas/ws/server_test.s new file mode 100644 index 00000000..e69de29b diff --git a/vendor/github.com/gobwas/ws/util.go b/vendor/github.com/gobwas/ws/util.go new file mode 100644 index 00000000..67ad906e --- /dev/null +++ b/vendor/github.com/gobwas/ws/util.go @@ -0,0 +1,214 @@ +package ws + +import ( + "bufio" + "bytes" + "fmt" + "reflect" + "unsafe" + + "github.com/gobwas/httphead" +) + +// SelectFromSlice creates accept function that could be used as Protocol/Extension +// select during upgrade. +func SelectFromSlice(accept []string) func(string) bool { + if len(accept) > 16 { + mp := make(map[string]struct{}, len(accept)) + for _, p := range accept { + mp[p] = struct{}{} + } + return func(p string) bool { + _, ok := mp[p] + return ok + } + } + return func(p string) bool { + for _, ok := range accept { + if p == ok { + return true + } + } + return false + } +} + +// SelectEqual creates accept function that could be used as Protocol/Extension +// select during upgrade. +func SelectEqual(v string) func(string) bool { + return func(p string) bool { + return v == p + } +} + +func strToBytes(str string) (bts []byte) { + s := (*reflect.StringHeader)(unsafe.Pointer(&str)) + b := (*reflect.SliceHeader)(unsafe.Pointer(&bts)) + b.Data = s.Data + b.Len = s.Len + b.Cap = s.Len + return +} + +func btsToString(bts []byte) (str string) { + return *(*string)(unsafe.Pointer(&bts)) +} + +// asciiToInt converts bytes to int. +func asciiToInt(bts []byte) (ret int, err error) { + // ASCII numbers all start with the high-order bits 0011. + // If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those + // bits and interpret them directly as an integer. + var n int + if n = len(bts); n < 1 { + return 0, fmt.Errorf("converting empty bytes to int") + } + for i := 0; i < n; i++ { + if bts[i]&0xf0 != 0x30 { + return 0, fmt.Errorf("%s is not a numeric character", string(bts[i])) + } + ret += int(bts[i]&0xf) * pow(10, n-i-1) + } + return ret, nil +} + +// pow for integers implementation. +// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3 +func pow(a, b int) int { + p := 1 + for b > 0 { + if b&1 != 0 { + p *= a + } + b >>= 1 + a *= a + } + return p +} + +func bsplit3(bts []byte, sep byte) (b1, b2, b3 []byte) { + a := bytes.IndexByte(bts, sep) + b := bytes.IndexByte(bts[a+1:], sep) + if a == -1 || b == -1 { + return bts, nil, nil + } + b += a + 1 + return bts[:a], bts[a+1 : b], bts[b+1:] +} + +func btrim(bts []byte) []byte { + var i, j int + for i = 0; i < len(bts) && (bts[i] == ' ' || bts[i] == '\t'); { + i++ + } + for j = len(bts); j > i && (bts[j-1] == ' ' || bts[j-1] == '\t'); { + j-- + } + return bts[i:j] +} + +func strHasToken(header, token string) (has bool) { + return btsHasToken(strToBytes(header), strToBytes(token)) +} + +func btsHasToken(header, token []byte) (has bool) { + httphead.ScanTokens(header, func(v []byte) bool { + has = bytes.EqualFold(v, token) + return !has + }) + return +} + +const ( + toLower = 'a' - 'A' // for use with OR. + toUpper = ^byte(toLower) // for use with AND. + toLower8 = uint64(toLower) | + uint64(toLower)<<8 | + uint64(toLower)<<16 | + uint64(toLower)<<24 | + uint64(toLower)<<32 | + uint64(toLower)<<40 | + uint64(toLower)<<48 | + uint64(toLower)<<56 +) + +// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except +// that it operates with slice of bytes and modifies it inplace without copying. +func canonicalizeHeaderKey(k []byte) { + upper := true + for i, c := range k { + if upper && 'a' <= c && c <= 'z' { + k[i] &= toUpper + } else if !upper && 'A' <= c && c <= 'Z' { + k[i] |= toLower + } + upper = c == '-' + } +} + +// readLine reads line from br. It reads until '\n' and returns bytes without +// '\n' or '\r\n' at the end. +// It returns err if and only if line does not end in '\n'. Note that read +// bytes returned in any case of error. +// +// It is much like the textproto/Reader.ReadLine() except the thing that it +// returns raw bytes, instead of string. That is, it avoids copying bytes read +// from br. +// +// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be +// safe with future I/O operations on br. +// +// We could control I/O operations on br and do not need to make additional +// copy for safety. +// +// NOTE: it may return copied flag to notify that returned buffer is safe to +// use. +func readLine(br *bufio.Reader) ([]byte, error) { + var line []byte + for { + bts, err := br.ReadSlice('\n') + if err == bufio.ErrBufferFull { + // Copy bytes because next read will discard them. + line = append(line, bts...) + continue + } + + // Avoid copy of single read. + if line == nil { + line = bts + } else { + line = append(line, bts...) + } + + if err != nil { + return line, err + } + + // Size of line is at least 1. + // In other case bufio.ReadSlice() returns error. + n := len(line) + + // Cut '\n' or '\r\n'. + if n > 1 && line[n-2] == '\r' { + line = line[:n-2] + } else { + line = line[:n-1] + } + + return line, nil + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func nonZero(a, b int) int { + if a != 0 { + return a + } + return b +} diff --git a/vendor/github.com/gobwas/ws/write.go b/vendor/github.com/gobwas/ws/write.go new file mode 100644 index 00000000..94557c69 --- /dev/null +++ b/vendor/github.com/gobwas/ws/write.go @@ -0,0 +1,104 @@ +package ws + +import ( + "encoding/binary" + "io" +) + +// Header size length bounds in bytes. +const ( + MaxHeaderSize = 14 + MinHeaderSize = 2 +) + +const ( + bit0 = 0x80 + bit1 = 0x40 + bit2 = 0x20 + bit3 = 0x10 + bit4 = 0x08 + bit5 = 0x04 + bit6 = 0x02 + bit7 = 0x01 + + len7 = int64(125) + len16 = int64(^(uint16(0))) + len64 = int64(^(uint64(0)) >> 1) +) + +// HeaderSize returns number of bytes that are needed to encode given header. +// It returns -1 if header is malformed. +func HeaderSize(h Header) (n int) { + switch { + case h.Length < 126: + n = 2 + case h.Length <= len16: + n = 4 + case h.Length <= len64: + n = 10 + default: + return -1 + } + if h.Masked { + n += len(h.Mask) + } + return n +} + +// WriteHeader writes header binary representation into w. +func WriteHeader(w io.Writer, h Header) error { + // Make slice of bytes with capacity 14 that could hold any header. + bts := make([]byte, MaxHeaderSize) + + if h.Fin { + bts[0] |= bit0 + } + bts[0] |= h.Rsv << 4 + bts[0] |= byte(h.OpCode) + + var n int + switch { + case h.Length <= len7: + bts[1] = byte(h.Length) + n = 2 + + case h.Length <= len16: + bts[1] = 126 + binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length)) + n = 4 + + case h.Length <= len64: + bts[1] = 127 + binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length)) + n = 10 + + default: + return ErrHeaderLengthUnexpected + } + + if h.Masked { + bts[1] |= bit0 + n += copy(bts[n:], h.Mask[:]) + } + + _, err := w.Write(bts[:n]) + + return err +} + +// WriteFrame writes frame binary representation into w. +func WriteFrame(w io.Writer, f Frame) error { + err := WriteHeader(w, f.Header) + if err != nil { + return err + } + _, err = w.Write(f.Payload) + return err +} + +// MustWriteFrame is like WriteFrame but panics if frame can not be read. +func MustWriteFrame(w io.Writer, f Frame) { + if err := WriteFrame(w, f); err != nil { + panic(err) + } +} diff --git a/vendor/github.com/gobwas/ws/wsutil/cipher.go b/vendor/github.com/gobwas/ws/wsutil/cipher.go new file mode 100644 index 00000000..f234be73 --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/cipher.go @@ -0,0 +1,72 @@ +package wsutil + +import ( + "io" + + "github.com/gobwas/pool/pbytes" + "github.com/gobwas/ws" +) + +// CipherReader implements io.Reader that applies xor-cipher to the bytes read +// from source. +// It could help to unmask WebSocket frame payload on the fly. +type CipherReader struct { + r io.Reader + mask [4]byte + pos int +} + +// NewCipherReader creates xor-cipher reader from r with given mask. +func NewCipherReader(r io.Reader, mask [4]byte) *CipherReader { + return &CipherReader{r, mask, 0} +} + +// Reset resets CipherReader to read from r with given mask. +func (c *CipherReader) Reset(r io.Reader, mask [4]byte) { + c.r = r + c.mask = mask + c.pos = 0 +} + +// Read implements io.Reader interface. It applies mask given during +// initialization to every read byte. +func (c *CipherReader) Read(p []byte) (n int, err error) { + n, err = c.r.Read(p) + ws.Cipher(p[:n], c.mask, c.pos) + c.pos += n + return +} + +// CipherWriter implements io.Writer that applies xor-cipher to the bytes +// written to the destination writer. It does not modify the original bytes. +type CipherWriter struct { + w io.Writer + mask [4]byte + pos int +} + +// NewCipherWriter creates xor-cipher writer to w with given mask. +func NewCipherWriter(w io.Writer, mask [4]byte) *CipherWriter { + return &CipherWriter{w, mask, 0} +} + +// Reset reset CipherWriter to write to w with given mask. +func (c *CipherWriter) Reset(w io.Writer, mask [4]byte) { + c.w = w + c.mask = mask + c.pos = 0 +} + +// Write implements io.Writer interface. It applies masking during +// initialization to every sent byte. It does not modify original slice. +func (c *CipherWriter) Write(p []byte) (n int, err error) { + cp := pbytes.GetLen(len(p)) + defer pbytes.Put(cp) + + copy(cp, p) + ws.Cipher(cp, c.mask, c.pos) + n, err = c.w.Write(cp) + c.pos += n + + return +} diff --git a/vendor/github.com/gobwas/ws/wsutil/dialer.go b/vendor/github.com/gobwas/ws/wsutil/dialer.go new file mode 100644 index 00000000..91c03d51 --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/dialer.go @@ -0,0 +1,146 @@ +package wsutil + +import ( + "bufio" + "bytes" + "context" + "io" + "io/ioutil" + "net" + "net/http" + + "github.com/gobwas/ws" +) + +// DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket +// handshake. That is, it gives ability to receive copied HTTP request and +// response bytes that made inside Dialer.Dial(). +// +// Note that it must not be used in production applications that requires +// Dial() to be efficient. +type DebugDialer struct { + // Dialer contains WebSocket connection establishment options. + Dialer ws.Dialer + + // OnRequest and OnResponse are the callbacks that will be called with the + // HTTP request and response respectively. + OnRequest, OnResponse func([]byte) +} + +// Dial connects to the url host and upgrades connection to WebSocket. It makes +// it by calling d.Dialer.Dial(). +func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) { + // Need to copy Dialer to prevent original object mutation. + dialer := d.Dialer + var ( + reqBuf bytes.Buffer + resBuf bytes.Buffer + + resContentLength int64 + ) + userWrap := dialer.WrapConn + dialer.WrapConn = func(c net.Conn) net.Conn { + if userWrap != nil { + c = userWrap(c) + } + + // Save the pointer to the raw connection. + conn = c + + var ( + r io.Reader = conn + w io.Writer = conn + ) + if d.OnResponse != nil { + r = &prefetchResponseReader{ + source: conn, + buffer: &resBuf, + contentLength: &resContentLength, + } + } + if d.OnRequest != nil { + w = io.MultiWriter(conn, &reqBuf) + } + return rwConn{conn, r, w} + } + + _, br, hs, err = dialer.Dial(ctx, urlstr) + + if onRequest := d.OnRequest; onRequest != nil { + onRequest(reqBuf.Bytes()) + } + if onResponse := d.OnResponse; onResponse != nil { + // We must split response inside buffered bytes from other received + // bytes from server. + p := resBuf.Bytes() + n := bytes.Index(p, headEnd) + h := n + len(headEnd) // Head end index. + n = h + int(resContentLength) // Body end index. + + onResponse(p[:n]) + + if br != nil { + // If br is non-nil, then it mean two things. First is that + // handshake is OK and server has sent additional bytes – probably + // immediate sent frames (or weird but possible response body). + // Second, the bad one, is that br buffer's source is now rwConn + // instance from above WrapConn call. It is incorrect, so we must + // fix it. + var r io.Reader = conn + if len(p) > h { + // Buffer contains more than just HTTP headers bytes. + r = io.MultiReader( + bytes.NewReader(p[h:]), + conn, + ) + } + br.Reset(r) + // Must make br.Buffered() to be non-zero. + br.Peek(len(p[h:])) + } + } + + return conn, br, hs, err +} + +type rwConn struct { + net.Conn + + r io.Reader + w io.Writer +} + +func (rwc rwConn) Read(p []byte) (int, error) { + return rwc.r.Read(p) +} +func (rwc rwConn) Write(p []byte) (int, error) { + return rwc.w.Write(p) +} + +var headEnd = []byte("\r\n\r\n") + +type prefetchResponseReader struct { + source io.Reader // Original connection source. + reader io.Reader // Wrapped reader used to read from by clients. + buffer *bytes.Buffer + + contentLength *int64 +} + +func (r *prefetchResponseReader) Read(p []byte) (int, error) { + if r.reader == nil { + resp, err := http.ReadResponse(bufio.NewReader( + io.TeeReader(r.source, r.buffer), + ), nil) + if err == nil { + *r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() + } + bts := r.buffer.Bytes() + r.reader = io.MultiReader( + bytes.NewReader(bts), + r.source, + ) + } + return r.reader.Read(p) +} diff --git a/vendor/github.com/gobwas/ws/wsutil/handler.go b/vendor/github.com/gobwas/ws/wsutil/handler.go new file mode 100644 index 00000000..abb7cb73 --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/handler.go @@ -0,0 +1,219 @@ +package wsutil + +import ( + "errors" + "io" + "io/ioutil" + "strconv" + + "github.com/gobwas/pool/pbytes" + "github.com/gobwas/ws" +) + +// ClosedError returned when peer has closed the connection with appropriate +// code and a textual reason. +type ClosedError struct { + Code ws.StatusCode + Reason string +} + +// Error implements error interface. +func (err ClosedError) Error() string { + return "ws closed: " + strconv.FormatUint(uint64(err.Code), 10) + " " + err.Reason +} + +// ControlHandler contains logic of handling control frames. +// +// The intentional way to use it is to read the next frame header from the +// connection, optionally check its validity via ws.CheckHeader() and if it is +// not a ws.OpText of ws.OpBinary (or ws.OpContinuation) – pass it to Handle() +// method. +// +// That is, passed header should be checked to get rid of unexpected errors. +// +// The Handle() method will read out all control frame payload (if any) and +// write necessary bytes as a rfc compatible response. +type ControlHandler struct { + Src io.Reader + Dst io.Writer + State ws.State + + // DisableSrcCiphering disables unmasking payload data read from Src. + // It is useful when wsutil.Reader is used or when frame payload already + // pulled and ciphered out from the connection (and introduced by + // bytes.Reader, for example). + DisableSrcCiphering bool +} + +// ErrNotControlFrame is returned by ControlHandler to indicate that given +// header could not be handled. +var ErrNotControlFrame = errors.New("not a control frame") + +// Handle handles control frames regarding to the c.State and writes responses +// to the c.Dst when needed. +// +// It returns ErrNotControlFrame when given header is not of ws.OpClose, +// ws.OpPing or ws.OpPong operation code. +func (c ControlHandler) Handle(h ws.Header) error { + switch h.OpCode { + case ws.OpPing: + return c.HandlePing(h) + case ws.OpPong: + return c.HandlePong(h) + case ws.OpClose: + return c.HandleClose(h) + } + return ErrNotControlFrame +} + +// HandlePing handles ping frame and writes specification compatible response +// to the c.Dst. +func (c ControlHandler) HandlePing(h ws.Header) error { + if h.Length == 0 { + // The most common case when ping is empty. + // Note that when sending masked frame the mask for empty payload is + // just four zero bytes. + return ws.WriteHeader(c.Dst, ws.Header{ + Fin: true, + OpCode: ws.OpPong, + Masked: c.State.ClientSide(), + }) + } + + // In other way reply with Pong frame with copied payload. + p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{ + Length: h.Length, + Masked: c.State.ClientSide(), + })) + defer pbytes.Put(p) + + // Deal with ciphering i/o: + // Masking key is used to mask the "Payload data" defined in the same + // section as frame-payload-data, which includes "Extension data" and + // "Application data". + // + // See https://tools.ietf.org/html/rfc6455#section-5.3 + // + // NOTE: We prefer ControlWriter with preallocated buffer to + // ws.WriteHeader because it performs one syscall instead of two. + w := NewControlWriterBuffer(c.Dst, c.State, ws.OpPong, p) + r := c.Src + if c.State.ServerSide() && !c.DisableSrcCiphering { + r = NewCipherReader(r, h.Mask) + } + + _, err := io.Copy(w, r) + if err == nil { + err = w.Flush() + } + + return err +} + +// HandlePong handles pong frame by discarding it. +func (c ControlHandler) HandlePong(h ws.Header) error { + if h.Length == 0 { + return nil + } + + buf := pbytes.GetLen(int(h.Length)) + defer pbytes.Put(buf) + + // Discard pong message according to the RFC6455: + // A Pong frame MAY be sent unsolicited. This serves as a + // unidirectional heartbeat. A response to an unsolicited Pong frame + // is not expected. + _, err := io.CopyBuffer(ioutil.Discard, c.Src, buf) + + return err +} + +// HandleClose handles close frame, makes protocol validity checks and writes +// specification compatible response to the c.Dst. +func (c ControlHandler) HandleClose(h ws.Header) error { + if h.Length == 0 { + err := ws.WriteHeader(c.Dst, ws.Header{ + Fin: true, + OpCode: ws.OpClose, + Masked: c.State.ClientSide(), + }) + if err != nil { + return err + } + + // Due to RFC, we should interpret the code as no status code + // received: + // If this Close control frame contains no status code, _The WebSocket + // Connection Close Code_ is considered to be 1005. + // + // See https://tools.ietf.org/html/rfc6455#section-7.1.5 + return ClosedError{ + Code: ws.StatusNoStatusRcvd, + } + } + + // Prepare bytes both for reading reason and sending response. + p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{ + Length: h.Length, + Masked: c.State.ClientSide(), + })) + defer pbytes.Put(p) + + // Get the subslice to read the frame payload out. + subp := p[:h.Length] + + r := c.Src + if c.State.ServerSide() && !c.DisableSrcCiphering { + r = NewCipherReader(r, h.Mask) + } + if _, err := io.ReadFull(r, subp); err != nil { + return err + } + + code, reason := ws.ParseCloseFrameData(subp) + if err := ws.CheckCloseFrameData(code, reason); err != nil { + // Here we could not use the prepared bytes because there is no + // guarantee that it may fit our protocol error closure code and a + // reason. + c.closeWithProtocolError(err) + return err + } + + // Deal with ciphering i/o: + // Masking key is used to mask the "Payload data" defined in the same + // section as frame-payload-data, which includes "Extension data" and + // "Application data". + // + // See https://tools.ietf.org/html/rfc6455#section-5.3 + // + // NOTE: We prefer ControlWriter with preallocated buffer to + // ws.WriteHeader because it performs one syscall instead of two. + w := NewControlWriterBuffer(c.Dst, c.State, ws.OpClose, p) + + // RFC6455#5.5.1: + // If an endpoint receives a Close frame and did not previously + // send a Close frame, the endpoint MUST send a Close frame in + // response. (When sending a Close frame in response, the endpoint + // typically echoes the status code it received.) + _, err := w.Write(p[:2]) + if err != nil { + return err + } + if err = w.Flush(); err != nil { + return err + } + return ClosedError{ + Code: code, + Reason: reason, + } +} + +func (c ControlHandler) closeWithProtocolError(reason error) error { + f := ws.NewCloseFrame(ws.NewCloseFrameBody( + ws.StatusProtocolError, reason.Error(), + )) + if c.State.ClientSide() { + ws.MaskFrameInPlace(f) + } + return ws.WriteFrame(c.Dst, f) +} diff --git a/vendor/github.com/gobwas/ws/wsutil/helper.go b/vendor/github.com/gobwas/ws/wsutil/helper.go new file mode 100644 index 00000000..001e9d9e --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/helper.go @@ -0,0 +1,279 @@ +package wsutil + +import ( + "bytes" + "io" + "io/ioutil" + + "github.com/gobwas/ws" +) + +// Message represents a message from peer, that could be presented in one or +// more frames. That is, it contains payload of all message fragments and +// operation code of initial frame for this message. +type Message struct { + OpCode ws.OpCode + Payload []byte +} + +// ReadMessage is a helper function that reads next message from r. It appends +// received message(s) to the third argument and returns the result of it and +// an error if some failure happened. That is, it probably could receive more +// than one message when peer sending fragmented message in multiple frames and +// want to send some control frame between fragments. Then returned slice will +// contain those control frames at first, and then result of gluing fragments. +// +// TODO(gobwas): add DefaultReader with buffer size options. +func ReadMessage(r io.Reader, s ws.State, m []Message) ([]Message, error) { + rd := Reader{ + Source: r, + State: s, + CheckUTF8: true, + OnIntermediate: func(hdr ws.Header, src io.Reader) error { + bts, err := ioutil.ReadAll(src) + if err != nil { + return err + } + m = append(m, Message{hdr.OpCode, bts}) + return nil + }, + } + h, err := rd.NextFrame() + if err != nil { + return m, err + } + var p []byte + if h.Fin { + // No more frames will be read. Use fixed sized buffer to read payload. + p = make([]byte, h.Length) + // It is not possible to receive io.EOF here because Reader does not + // return EOF if frame payload was successfully fetched. + // Thus we consistent here with io.Reader behavior. + _, err = io.ReadFull(&rd, p) + } else { + // Frame is fragmented, thus use ioutil.ReadAll behavior. + var buf bytes.Buffer + _, err = buf.ReadFrom(&rd) + p = buf.Bytes() + } + if err != nil { + return m, err + } + return append(m, Message{h.OpCode, p}), nil +} + +// ReadClientMessage reads next message from r, considering that caller +// represents server side. +// It is a shortcut for ReadMessage(r, ws.StateServerSide, m) +func ReadClientMessage(r io.Reader, m []Message) ([]Message, error) { + return ReadMessage(r, ws.StateServerSide, m) +} + +// ReadServerMessage reads next message from r, considering that caller +// represents client side. +// It is a shortcut for ReadMessage(r, ws.StateClientSide, m) +func ReadServerMessage(r io.Reader, m []Message) ([]Message, error) { + return ReadMessage(r, ws.StateClientSide, m) +} + +// ReadData is a helper function that reads next data (non-control) message +// from rw. +// It takes care on handling all control frames. It will write response on +// control frames to the write part of rw. It blocks until some data frame +// will be received. +// +// Note this may handle and write control frames into the writer part of a +// given io.ReadWriter. +func ReadData(rw io.ReadWriter, s ws.State) ([]byte, ws.OpCode, error) { + return readData(rw, s, ws.OpText|ws.OpBinary) +} + +// ReadClientData reads next data message from rw, considering that caller +// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide). +// +// Note this may handle and write control frames into the writer part of a +// given io.ReadWriter. +func ReadClientData(rw io.ReadWriter) ([]byte, ws.OpCode, error) { + return ReadData(rw, ws.StateServerSide) +} + +// ReadClientText reads next text message from rw, considering that caller +// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide). +// It discards received binary messages. +// +// Note this may handle and write control frames into the writer part of a +// given io.ReadWriter. +func ReadClientText(rw io.ReadWriter) ([]byte, error) { + p, _, err := readData(rw, ws.StateServerSide, ws.OpText) + return p, err +} + +// ReadClientBinary reads next binary message from rw, considering that caller +// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide). +// It discards received text messages. +// +// Note this may handle and write control frames into the writer part of a given +// io.ReadWriter. +func ReadClientBinary(rw io.ReadWriter) ([]byte, error) { + p, _, err := readData(rw, ws.StateServerSide, ws.OpBinary) + return p, err +} + +// ReadServerData reads next data message from rw, considering that caller +// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide). +// +// Note this may handle and write control frames into the writer part of a +// given io.ReadWriter. +func ReadServerData(rw io.ReadWriter) ([]byte, ws.OpCode, error) { + return ReadData(rw, ws.StateClientSide) +} + +// ReadServerText reads next text message from rw, considering that caller +// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide). +// It discards received binary messages. +// +// Note this may handle and write control frames into the writer part of a given +// io.ReadWriter. +func ReadServerText(rw io.ReadWriter) ([]byte, error) { + p, _, err := readData(rw, ws.StateClientSide, ws.OpText) + return p, err +} + +// ReadServerBinary reads next binary message from rw, considering that caller +// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide). +// It discards received text messages. +// +// Note this may handle and write control frames into the writer part of a +// given io.ReadWriter. +func ReadServerBinary(rw io.ReadWriter) ([]byte, error) { + p, _, err := readData(rw, ws.StateClientSide, ws.OpBinary) + return p, err +} + +// WriteMessage is a helper function that writes message to the w. It +// constructs single frame with given operation code and payload. +// It uses given state to prepare side-dependent things, like cipher +// payload bytes from client to server. It will not mutate p bytes if +// cipher must be made. +// +// If you want to write message in fragmented frames, use Writer instead. +func WriteMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error { + return writeFrame(w, s, op, true, p) +} + +// WriteServerMessage writes message to w, considering that caller +// represents server side. +func WriteServerMessage(w io.Writer, op ws.OpCode, p []byte) error { + return WriteMessage(w, ws.StateServerSide, op, p) +} + +// WriteServerText is the same as WriteServerMessage with +// ws.OpText. +func WriteServerText(w io.Writer, p []byte) error { + return WriteServerMessage(w, ws.OpText, p) +} + +// WriteServerBinary is the same as WriteServerMessage with +// ws.OpBinary. +func WriteServerBinary(w io.Writer, p []byte) error { + return WriteServerMessage(w, ws.OpBinary, p) +} + +// WriteClientMessage writes message to w, considering that caller +// represents client side. +func WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error { + return WriteMessage(w, ws.StateClientSide, op, p) +} + +// WriteClientText is the same as WriteClientMessage with +// ws.OpText. +func WriteClientText(w io.Writer, p []byte) error { + return WriteClientMessage(w, ws.OpText, p) +} + +// WriteClientBinary is the same as WriteClientMessage with +// ws.OpBinary. +func WriteClientBinary(w io.Writer, p []byte) error { + return WriteClientMessage(w, ws.OpBinary, p) +} + +// HandleClientControlMessage handles control frame from conn and writes +// response when needed. +// +// It considers that caller represents server side. +func HandleClientControlMessage(conn io.Writer, msg Message) error { + return HandleControlMessage(conn, ws.StateServerSide, msg) +} + +// HandleServerControlMessage handles control frame from conn and writes +// response when needed. +// +// It considers that caller represents client side. +func HandleServerControlMessage(conn io.Writer, msg Message) error { + return HandleControlMessage(conn, ws.StateClientSide, msg) +} + +// HandleControlMessage handles message which was read by ReadMessage() +// functions. +// +// That is, it is expected, that payload is already unmasked and frame header +// were checked by ws.CheckHeader() call. +func HandleControlMessage(conn io.Writer, state ws.State, msg Message) error { + return (ControlHandler{ + DisableSrcCiphering: true, + Src: bytes.NewReader(msg.Payload), + Dst: conn, + State: state, + }).Handle(ws.Header{ + Length: int64(len(msg.Payload)), + OpCode: msg.OpCode, + Fin: true, + Masked: state.ServerSide(), + }) +} + +// ControlFrameHandler returns FrameHandlerFunc for handling control frames. +// For more info see ControlHandler docs. +func ControlFrameHandler(w io.Writer, state ws.State) FrameHandlerFunc { + return func(h ws.Header, r io.Reader) error { + return (ControlHandler{ + DisableSrcCiphering: true, + Src: r, + Dst: w, + State: state, + }).Handle(h) + } +} + +func readData(rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, ws.OpCode, error) { + controlHandler := ControlFrameHandler(rw, s) + rd := Reader{ + Source: rw, + State: s, + CheckUTF8: true, + SkipHeaderCheck: false, + OnIntermediate: controlHandler, + } + for { + hdr, err := rd.NextFrame() + if err != nil { + return nil, 0, err + } + if hdr.OpCode.IsControl() { + if err := controlHandler(hdr, &rd); err != nil { + return nil, 0, err + } + continue + } + if hdr.OpCode&want == 0 { + if err := rd.Discard(); err != nil { + return nil, 0, err + } + continue + } + + bts, err := ioutil.ReadAll(&rd) + + return bts, hdr.OpCode, err + } +} diff --git a/vendor/github.com/gobwas/ws/wsutil/reader.go b/vendor/github.com/gobwas/ws/wsutil/reader.go new file mode 100644 index 00000000..5f64c632 --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/reader.go @@ -0,0 +1,257 @@ +package wsutil + +import ( + "errors" + "io" + "io/ioutil" + + "github.com/gobwas/ws" +) + +// ErrNoFrameAdvance means that Reader's Read() method was called without +// preceding NextFrame() call. +var ErrNoFrameAdvance = errors.New("no frame advance") + +// FrameHandlerFunc handles parsed frame header and its body represented by +// io.Reader. +// +// Note that reader represents already unmasked body. +type FrameHandlerFunc func(ws.Header, io.Reader) error + +// Reader is a wrapper around source io.Reader which represents WebSocket +// connection. It contains options for reading messages from source. +// +// Reader implements io.Reader, which Read() method reads payload of incoming +// WebSocket frames. It also takes care on fragmented frames and possibly +// intermediate control frames between them. +// +// Note that Reader's methods are not goroutine safe. +type Reader struct { + Source io.Reader + State ws.State + + // SkipHeaderCheck disables checking header bits to be RFC6455 compliant. + SkipHeaderCheck bool + + // CheckUTF8 enables UTF-8 checks for text frames payload. If incoming + // bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned. + CheckUTF8 bool + + // TODO(gobwas): add max frame size limit here. + + OnContinuation FrameHandlerFunc + OnIntermediate FrameHandlerFunc + + opCode ws.OpCode // Used to store message op code on fragmentation. + frame io.Reader // Used to as frame reader. + raw io.LimitedReader // Used to discard frames without cipher. + utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true. +} + +// NewReader creates new frame reader that reads from r keeping given state to +// make some protocol validity checks when it needed. +func NewReader(r io.Reader, s ws.State) *Reader { + return &Reader{ + Source: r, + State: s, + } +} + +// NewClientSideReader is a helper function that calls NewReader with r and +// ws.StateClientSide. +func NewClientSideReader(r io.Reader) *Reader { + return NewReader(r, ws.StateClientSide) +} + +// NewServerSideReader is a helper function that calls NewReader with r and +// ws.StateServerSide. +func NewServerSideReader(r io.Reader) *Reader { + return NewReader(r, ws.StateServerSide) +} + +// Read implements io.Reader. It reads the next message payload into p. +// It takes care on fragmented messages. +// +// The error is io.EOF only if all of message bytes were read. +// If an io.EOF happens during reading some but not all the message bytes +// Read() returns io.ErrUnexpectedEOF. +// +// The error is ErrNoFrameAdvance if no NextFrame() call was made before +// reading next message bytes. +func (r *Reader) Read(p []byte) (n int, err error) { + if r.frame == nil { + if !r.fragmented() { + // Every new Read() must be preceded by NextFrame() call. + return 0, ErrNoFrameAdvance + } + // Read next continuation or intermediate control frame. + _, err := r.NextFrame() + if err != nil { + return 0, err + } + if r.frame == nil { + // We handled intermediate control and now got nothing to read. + return 0, nil + } + } + + n, err = r.frame.Read(p) + if err != nil && err != io.EOF { + return + } + if err == nil && r.raw.N != 0 { + return + } + + switch { + case r.raw.N != 0: + err = io.ErrUnexpectedEOF + + case r.fragmented(): + err = nil + r.resetFragment() + + case r.CheckUTF8 && !r.utf8.Valid(): + n = r.utf8.Accepted() + err = ErrInvalidUTF8 + + default: + r.reset() + err = io.EOF + } + + return +} + +// Discard discards current message unread bytes. +// It discards all frames of fragmented message. +func (r *Reader) Discard() (err error) { + for { + _, err = io.Copy(ioutil.Discard, &r.raw) + if err != nil { + break + } + if !r.fragmented() { + break + } + if _, err = r.NextFrame(); err != nil { + break + } + } + r.reset() + return err +} + +// NextFrame prepares r to read next message. It returns received frame header +// and non-nil error on failure. +// +// Note that next NextFrame() call must be done after receiving or discarding +// all current message bytes. +func (r *Reader) NextFrame() (hdr ws.Header, err error) { + hdr, err = ws.ReadHeader(r.Source) + if err == io.EOF && r.fragmented() { + // If we are in fragmented state EOF means that is was totally + // unexpected. + // + // NOTE: This is necessary to prevent callers such that + // ioutil.ReadAll to receive some amount of bytes without an error. + // ReadAll() ignores an io.EOF error, thus caller may think that + // whole message fetched, but actually only part of it. + err = io.ErrUnexpectedEOF + } + if err == nil && !r.SkipHeaderCheck { + err = ws.CheckHeader(hdr, r.State) + } + if err != nil { + return hdr, err + } + + // Save raw reader to use it on discarding frame without ciphering and + // other streaming checks. + r.raw = io.LimitedReader{r.Source, hdr.Length} + + frame := io.Reader(&r.raw) + if hdr.Masked { + frame = NewCipherReader(frame, hdr.Mask) + } + if r.fragmented() { + if hdr.OpCode.IsControl() { + if cb := r.OnIntermediate; cb != nil { + err = cb(hdr, frame) + } + if err == nil { + // Ensure that src is empty. + _, err = io.Copy(ioutil.Discard, &r.raw) + } + return + } + } else { + r.opCode = hdr.OpCode + } + if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) { + r.utf8.Source = frame + frame = &r.utf8 + } + + // Save reader with ciphering and other streaming checks. + r.frame = frame + + if hdr.OpCode == ws.OpContinuation { + if cb := r.OnContinuation; cb != nil { + err = cb(hdr, frame) + } + } + + if hdr.Fin { + r.State = r.State.Clear(ws.StateFragmented) + } else { + r.State = r.State.Set(ws.StateFragmented) + } + + return +} + +func (r *Reader) fragmented() bool { + return r.State.Fragmented() +} + +func (r *Reader) resetFragment() { + r.raw = io.LimitedReader{} + r.frame = nil + // Reset source of the UTF8Reader, but not the state. + r.utf8.Source = nil +} + +func (r *Reader) reset() { + r.raw = io.LimitedReader{} + r.frame = nil + r.utf8 = UTF8Reader{} + r.opCode = 0 +} + +// NextReader prepares next message read from r. It returns header that +// describes the message and io.Reader to read message's payload. It returns +// non-nil error when it is not possible to read message's initial frame. +// +// Note that next NextReader() on the same r should be done after reading all +// bytes from previously returned io.Reader. For more performant way to discard +// message use Reader and its Discard() method. +// +// Note that it will not handle any "intermediate" frames, that possibly could +// be received between text/binary continuation frames. That is, if peer sent +// text/binary frame with fin flag "false", then it could send ping frame, and +// eventually remaining part of text/binary frame with fin "true" – with +// NextReader() the ping frame will be dropped without any notice. To handle +// this rare, but possible situation (and if you do not know exactly which +// frames peer could send), you could use Reader with OnIntermediate field set. +func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) { + rd := &Reader{ + Source: r, + State: s, + } + header, err := rd.NextFrame() + if err != nil { + return header, nil, err + } + return header, rd, nil +} diff --git a/vendor/github.com/gobwas/ws/wsutil/upgrader.go b/vendor/github.com/gobwas/ws/wsutil/upgrader.go new file mode 100644 index 00000000..2ed351e0 --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/upgrader.go @@ -0,0 +1,68 @@ +package wsutil + +import ( + "bufio" + "bytes" + "io" + "io/ioutil" + "net/http" + + "github.com/gobwas/ws" +) + +// DebugUpgrader is a wrapper around ws.Upgrader. It tracks I/O of a +// WebSocket handshake. +// +// Note that it must not be used in production applications that requires +// Upgrade() to be efficient. +type DebugUpgrader struct { + // Upgrader contains upgrade to WebSocket options. + Upgrader ws.Upgrader + + // OnRequest and OnResponse are the callbacks that will be called with the + // HTTP request and response respectively. + OnRequest, OnResponse func([]byte) +} + +// Upgrade calls Upgrade() on underlying ws.Upgrader and tracks I/O on conn. +func (d *DebugUpgrader) Upgrade(conn io.ReadWriter) (hs ws.Handshake, err error) { + var ( + // Take the Reader and Writer parts from conn to be probably replaced + // below. + r io.Reader = conn + w io.Writer = conn + ) + if onRequest := d.OnRequest; onRequest != nil { + var buf bytes.Buffer + // First, we must read the entire request. + req, err := http.ReadRequest(bufio.NewReader( + io.TeeReader(conn, &buf), + )) + if err == nil { + // Fulfill the buffer with the response body. + io.Copy(ioutil.Discard, req.Body) + req.Body.Close() + } + onRequest(buf.Bytes()) + + r = io.MultiReader( + &buf, conn, + ) + } + + if onResponse := d.OnResponse; onResponse != nil { + var buf bytes.Buffer + // Intercept the response stream written by the Upgrade(). + w = io.MultiWriter( + conn, &buf, + ) + defer func() { + onResponse(buf.Bytes()) + }() + } + + return d.Upgrader.Upgrade(struct { + io.Reader + io.Writer + }{r, w}) +} diff --git a/vendor/github.com/gobwas/ws/wsutil/utf8.go b/vendor/github.com/gobwas/ws/wsutil/utf8.go new file mode 100644 index 00000000..d877be0b --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/utf8.go @@ -0,0 +1,140 @@ +package wsutil + +import ( + "fmt" + "io" +) + +// ErrInvalidUTF8 is returned by UTF8 reader on invalid utf8 sequence. +var ErrInvalidUTF8 = fmt.Errorf("invalid utf8") + +// UTF8Reader implements io.Reader that calculates utf8 validity state after +// every read byte from Source. +// +// Note that in some cases client must call r.Valid() after all bytes are read +// to ensure that all of them are valid utf8 sequences. That is, some io helper +// functions such io.ReadAtLeast or io.ReadFull could discard the error +// information returned by the reader when they receive all of requested bytes. +// For example, the last read sequence is invalid and UTF8Reader returns number +// of bytes read and an error. But helper function decides to discard received +// error due to all requested bytes are completely read from the source. +// +// Another possible case is when some valid sequence become split by the read +// bound. Then UTF8Reader can not make decision about validity of the last +// sequence cause it is not fully read yet. And if the read stops, Valid() will +// return false, even if Read() by itself dit not. +type UTF8Reader struct { + Source io.Reader + + accepted int + + state uint32 + codep uint32 +} + +// NewUTF8Reader creates utf8 reader that reads from r. +func NewUTF8Reader(r io.Reader) *UTF8Reader { + return &UTF8Reader{ + Source: r, + } +} + +// Reset resets utf8 reader to read from r. +func (u *UTF8Reader) Reset(r io.Reader) { + u.Source = r + u.state = 0 + u.codep = 0 +} + +// Read implements io.Reader. +func (u *UTF8Reader) Read(p []byte) (n int, err error) { + n, err = u.Source.Read(p) + + accepted := 0 + s, c := u.state, u.codep + for i := 0; i < n; i++ { + c, s = decode(s, c, p[i]) + if s == utf8Reject { + u.state = s + return accepted, ErrInvalidUTF8 + } + if s == utf8Accept { + accepted = i + 1 + } + } + u.state, u.codep = s, c + u.accepted = accepted + + return +} + +// Valid checks current reader state. It returns true if all read bytes are +// valid UTF-8 sequences, and false if not. +func (u *UTF8Reader) Valid() bool { + return u.state == utf8Accept +} + +// Accepted returns number of valid bytes in last Read(). +func (u *UTF8Reader) Accepted() int { + return u.accepted +} + +// Below is port of UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ +// +// Copyright (c) 2008-2009 Bjoern Hoehrmann +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. + +const ( + utf8Accept = 0 + utf8Reject = 12 +) + +var utf8d = [...]byte{ + // The first part of the table maps bytes to character classes that + // to reduce the size of the transition table and create bitmasks. + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + + // The second part is a transition table that maps a combination + // of a state of the automaton and a character class to a state. + 0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, + 12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, +} + +func decode(state, codep uint32, b byte) (uint32, uint32) { + t := uint32(utf8d[b]) + + if state != utf8Accept { + codep = (uint32(b) & 0x3f) | (codep << 6) + } else { + codep = (0xff >> t) & uint32(b) + } + + return codep, uint32(utf8d[256+state+t]) +} diff --git a/vendor/github.com/gobwas/ws/wsutil/writer.go b/vendor/github.com/gobwas/ws/wsutil/writer.go new file mode 100644 index 00000000..c76b0b42 --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/writer.go @@ -0,0 +1,450 @@ +package wsutil + +import ( + "fmt" + "io" + + "github.com/gobwas/pool" + "github.com/gobwas/pool/pbytes" + "github.com/gobwas/ws" +) + +// DefaultWriteBuffer contains size of Writer's default buffer. It used by +// Writer constructor functions. +var DefaultWriteBuffer = 4096 + +var ( + // ErrNotEmpty is returned by Writer.WriteThrough() to indicate that buffer is + // not empty and write through could not be done. That is, caller should call + // Writer.FlushFragment() to make buffer empty. + ErrNotEmpty = fmt.Errorf("writer not empty") + + // ErrControlOverflow is returned by ControlWriter.Write() to indicate that + // no more data could be written to the underlying io.Writer because + // MaxControlFramePayloadSize limit is reached. + ErrControlOverflow = fmt.Errorf("control frame payload overflow") +) + +// Constants which are represent frame length ranges. +const ( + len7 = int64(125) // 126 and 127 are reserved values + len16 = int64(^uint16(0)) + len64 = int64((^uint64(0)) >> 1) +) + +// ControlWriter is a wrapper around Writer that contains some guards for +// buffered writes of control frames. +type ControlWriter struct { + w *Writer + limit int + n int +} + +// NewControlWriter contains ControlWriter with Writer inside whose buffer size +// is at most ws.MaxControlFramePayloadSize + ws.MaxHeaderSize. +func NewControlWriter(dest io.Writer, state ws.State, op ws.OpCode) *ControlWriter { + return &ControlWriter{ + w: NewWriterSize(dest, state, op, ws.MaxControlFramePayloadSize), + limit: ws.MaxControlFramePayloadSize, + } +} + +// NewControlWriterBuffer returns a new ControlWriter with buf as a buffer. +// +// Note that it reserves x bytes of buf for header data, where x could be +// ws.MinHeaderSize or ws.MinHeaderSize+4 (depending on state). At most +// (ws.MaxControlFramePayloadSize + x) bytes of buf will be used. +// +// It panics if len(buf) <= ws.MinHeaderSize + x. +func NewControlWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *ControlWriter { + max := ws.MaxControlFramePayloadSize + headerSize(state, ws.MaxControlFramePayloadSize) + if len(buf) > max { + buf = buf[:max] + } + + w := NewWriterBuffer(dest, state, op, buf) + + return &ControlWriter{ + w: w, + limit: len(w.buf), + } +} + +// Write implements io.Writer. It writes to the underlying Writer until it +// returns error or until ControlWriter write limit will be exceeded. +func (c *ControlWriter) Write(p []byte) (n int, err error) { + if c.n+len(p) > c.limit { + return 0, ErrControlOverflow + } + return c.w.Write(p) +} + +// Flush flushes all buffered data to the underlying io.Writer. +func (c *ControlWriter) Flush() error { + return c.w.Flush() +} + +// Writer contains logic of buffering output data into a WebSocket fragments. +// It is much the same as bufio.Writer, except the thing that it works with +// WebSocket frames, not the raw data. +// +// Writer writes frames with specified OpCode. +// It uses ws.State to decide whether the output frames must be masked. +// +// Note that it does not check control frame size or other RFC rules. +// That is, it must be used with special care to write control frames without +// violation of RFC. You could use ControlWriter that wraps Writer and contains +// some guards for writing control frames. +// +// If an error occurs writing to a Writer, no more data will be accepted and +// all subsequent writes will return the error. +// After all data has been written, the client should call the Flush() method +// to guarantee all data has been forwarded to the underlying io.Writer. +type Writer struct { + dest io.Writer + + n int // Buffered bytes counter. + raw []byte // Raw representation of buffer, including reserved header bytes. + buf []byte // Writeable part of buffer, without reserved header bytes. + + op ws.OpCode + state ws.State + + dirty bool + fragmented bool + + err error +} + +var writers = pool.New(128, 65536) + +// GetWriter tries to reuse Writer getting it from the pool. +// +// This function is intended for memory consumption optimizations, because +// NewWriter*() functions make allocations for inner buffer. +// +// Note the it ceils n to the power of two. +// +// If you have your own bytes buffer pool you could use NewWriterBuffer to use +// pooled bytes in writer. +func GetWriter(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer { + x, m := writers.Get(n) + if x != nil { + w := x.(*Writer) + w.Reset(dest, state, op) + return w + } + // NOTE: we use m instead of n, because m is an attempt to reuse w of such + // size in the future. + return NewWriterBufferSize(dest, state, op, m) +} + +// PutWriter puts w for future reuse by GetWriter(). +func PutWriter(w *Writer) { + w.Reset(nil, 0, 0) + writers.Put(w, w.Size()) +} + +// NewWriter returns a new Writer whose buffer has the DefaultWriteBuffer size. +func NewWriter(dest io.Writer, state ws.State, op ws.OpCode) *Writer { + return NewWriterBufferSize(dest, state, op, 0) +} + +// NewWriterSize returns a new Writer whose buffer size is at most n + ws.MaxHeaderSize. +// That is, output frames payload length could be up to n, except the case when +// Write() is called on empty Writer with len(p) > n. +// +// If n <= 0 then the default buffer size is used as Writer's buffer size. +func NewWriterSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer { + if n > 0 { + n += headerSize(state, n) + } + return NewWriterBufferSize(dest, state, op, n) +} + +// NewWriterBufferSize returns a new Writer whose buffer size is equal to n. +// If n <= ws.MinHeaderSize then the default buffer size is used. +// +// Note that Writer will reserve x bytes for header data, where x is in range +// [ws.MinHeaderSize,ws.MaxHeaderSize]. That is, frames flushed by Writer +// will not have payload length equal to n, except the case when Write() is +// called on empty Writer with len(p) > n. +func NewWriterBufferSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer { + if n <= ws.MinHeaderSize { + n = DefaultWriteBuffer + } + return NewWriterBuffer(dest, state, op, make([]byte, n)) +} + +// NewWriterBuffer returns a new Writer with buf as a buffer. +// +// Note that it reserves x bytes of buf for header data, where x is in range +// [ws.MinHeaderSize,ws.MaxHeaderSize] (depending on state and buf size). +// +// You could use ws.HeaderSize() to calculate number of bytes needed to store +// header data. +// +// It panics if len(buf) is too small to fit header and payload data. +func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *Writer { + offset := reserve(state, len(buf)) + if len(buf) <= offset { + panic("buffer too small") + } + + return &Writer{ + dest: dest, + raw: buf, + buf: buf[offset:], + state: state, + op: op, + } +} + +func reserve(state ws.State, n int) (offset int) { + var mask int + if state.ClientSide() { + mask = 4 + } + + switch { + case n <= int(len7)+mask+2: + return mask + 2 + case n <= int(len16)+mask+4: + return mask + 4 + default: + return mask + 10 + } +} + +// headerSize returns number of bytes needed to encode header of a frame with +// given state and length. +func headerSize(s ws.State, n int) int { + return ws.HeaderSize(ws.Header{ + Length: int64(n), + Masked: s.ClientSide(), + }) +} + +// Reset discards any buffered data, clears error, and resets w to have given +// state and write frames with given OpCode to dest. +func (w *Writer) Reset(dest io.Writer, state ws.State, op ws.OpCode) { + w.n = 0 + w.dirty = false + w.fragmented = false + w.dest = dest + w.state = state + w.op = op +} + +// Size returns the size of the underlying buffer in bytes. +func (w *Writer) Size() int { + return len(w.buf) +} + +// Available returns how many bytes are unused in the buffer. +func (w *Writer) Available() int { + return len(w.buf) - w.n +} + +// Buffered returns the number of bytes that have been written into the current +// buffer. +func (w *Writer) Buffered() int { + return w.n +} + +// Write implements io.Writer. +// +// Note that even if the Writer was created to have N-sized buffer, Write() +// with payload of N bytes will not fit into that buffer. Writer reserves some +// space to fit WebSocket header data. +func (w *Writer) Write(p []byte) (n int, err error) { + // Even empty p may make a sense. + w.dirty = true + + var nn int + for len(p) > w.Available() && w.err == nil { + if w.Buffered() == 0 { + // Large write, empty buffer. Write directly from p to avoid copy. + // Trade off here is that we make additional Write() to underlying + // io.Writer when writing frame header. + // + // On large buffers additional write is better than copying. + nn, _ = w.WriteThrough(p) + } else { + nn = copy(w.buf[w.n:], p) + w.n += nn + w.FlushFragment() + } + n += nn + p = p[nn:] + } + if w.err != nil { + return n, w.err + } + nn = copy(w.buf[w.n:], p) + w.n += nn + n += nn + + // Even if w.Available() == 0 we will not flush buffer preventively because + // this could bring unwanted fragmentation. That is, user could create + // buffer with size that fits exactly all further Write() call, and then + // call Flush(), excepting that single and not fragmented frame will be + // sent. With preemptive flush this case will produce two frames – last one + // will be empty and just to set fin = true. + + return n, w.err +} + +// WriteThrough writes data bypassing the buffer. +// Note that Writer's buffer must be empty before calling WriteThrough(). +func (w *Writer) WriteThrough(p []byte) (n int, err error) { + if w.err != nil { + return 0, w.err + } + if w.Buffered() != 0 { + return 0, ErrNotEmpty + } + + w.err = writeFrame(w.dest, w.state, w.opCode(), false, p) + if w.err == nil { + n = len(p) + } + + w.dirty = true + w.fragmented = true + + return n, w.err +} + +// ReadFrom implements io.ReaderFrom. +func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) { + var nn int + for err == nil { + if w.Available() == 0 { + err = w.FlushFragment() + continue + } + + // We copy the behavior of bufio.Writer here. + // Also, from the docs on io.ReaderFrom: + // ReadFrom reads data from r until EOF or error. + // + // See https://codereview.appspot.com/76400048/#ps1 + const maxEmptyReads = 100 + var nr int + for nr < maxEmptyReads { + nn, err = src.Read(w.buf[w.n:]) + if nn != 0 || err != nil { + break + } + nr++ + } + if nr == maxEmptyReads { + return n, io.ErrNoProgress + } + + w.n += nn + n += int64(nn) + } + if err == io.EOF { + // NOTE: Do not flush preemptively. + // See the Write() sources for more info. + err = nil + w.dirty = true + } + return n, err +} + +// Flush writes any buffered data to the underlying io.Writer. +// It sends the frame with "fin" flag set to true. +// +// If no Write() or ReadFrom() was made, then Flush() does nothing. +func (w *Writer) Flush() error { + if (!w.dirty && w.Buffered() == 0) || w.err != nil { + return w.err + } + + w.err = w.flushFragment(true) + w.n = 0 + w.dirty = false + w.fragmented = false + + return w.err +} + +// FlushFragment writes any buffered data to the underlying io.Writer. +// It sends the frame with "fin" flag set to false. +func (w *Writer) FlushFragment() error { + if w.Buffered() == 0 || w.err != nil { + return w.err + } + + w.err = w.flushFragment(false) + w.n = 0 + w.fragmented = true + + return w.err +} + +func (w *Writer) flushFragment(fin bool) error { + frame := ws.NewFrame(w.opCode(), fin, w.buf[:w.n]) + if w.state.ClientSide() { + frame = ws.MaskFrameInPlace(frame) + } + + // Write header to the header segment of the raw buffer. + head := len(w.raw) - len(w.buf) + offset := head - ws.HeaderSize(frame.Header) + buf := bytesWriter{ + buf: w.raw[offset:head], + } + if err := ws.WriteHeader(&buf, frame.Header); err != nil { + // Must never be reached. + panic("dump header error: " + err.Error()) + } + + _, err := w.dest.Write(w.raw[offset : head+w.n]) + + return err +} + +func (w *Writer) opCode() ws.OpCode { + if w.fragmented { + return ws.OpContinuation + } + return w.op +} + +var errNoSpace = fmt.Errorf("not enough buffer space") + +type bytesWriter struct { + buf []byte + pos int +} + +func (w *bytesWriter) Write(p []byte) (int, error) { + n := copy(w.buf[w.pos:], p) + w.pos += n + if n != len(p) { + return n, errNoSpace + } + return n, nil +} + +func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error { + var frame ws.Frame + if s.ClientSide() { + // Should copy bytes to prevent corruption of caller data. + payload := pbytes.GetLen(len(p)) + defer pbytes.Put(payload) + + copy(payload, p) + + frame = ws.NewFrame(op, fin, payload) + frame = ws.MaskFrameInPlace(frame) + } else { + frame = ws.NewFrame(op, fin, p) + } + + return ws.WriteFrame(w, frame) +} diff --git a/vendor/github.com/gobwas/ws/wsutil/wsutil.go b/vendor/github.com/gobwas/ws/wsutil/wsutil.go new file mode 100644 index 00000000..ffd43367 --- /dev/null +++ b/vendor/github.com/gobwas/ws/wsutil/wsutil.go @@ -0,0 +1,57 @@ +/* +Package wsutil provides utilities for working with WebSocket protocol. + +Overview: + + // Read masked text message from peer and check utf8 encoding. + header, err := ws.ReadHeader(conn) + if err != nil { + // handle err + } + + // Prepare to read payload. + r := io.LimitReader(conn, header.Length) + r = wsutil.NewCipherReader(r, header.Mask) + r = wsutil.NewUTF8Reader(r) + + payload, err := ioutil.ReadAll(r) + if err != nil { + // handle err + } + +You could get the same behavior using just `wsutil.Reader`: + + r := wsutil.Reader{ + Source: conn, + CheckUTF8: true, + } + + payload, err := ioutil.ReadAll(r) + if err != nil { + // handle err + } + +Or even simplest: + + payload, err := wsutil.ReadClientText(conn) + if err != nil { + // handle err + } + +Package is also exports tools for buffered writing: + + // Create buffered writer, that will buffer output bytes and send them as + // 128-length fragments (with exception on large writes, see the doc). + writer := wsutil.NewWriterSize(conn, ws.StateServerSide, ws.OpText, 128) + + _, err := io.CopyN(writer, rand.Reader, 100) + if err == nil { + err = writer.Flush() + } + if err != nil { + // handle error + } + +For more utils and helpers see the documentation. +*/ +package wsutil diff --git a/vendor/modules.txt b/vendor/modules.txt index dc2b5198..dd2f9a53 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -207,6 +207,19 @@ github.com/gliderlabs/ssh # github.com/go-sql-driver/mysql v1.5.0 ## explicit github.com/go-sql-driver/mysql +# github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 +## explicit +github.com/gobwas/httphead +# github.com/gobwas/pool v0.2.1 +## explicit +github.com/gobwas/pool +github.com/gobwas/pool/internal/pmath +github.com/gobwas/pool/pbufio +github.com/gobwas/pool/pbytes +# github.com/gobwas/ws v1.0.4 +## explicit +github.com/gobwas/ws +github.com/gobwas/ws/wsutil # github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 ## explicit github.com/golang-collections/collections/queue