TUN-3403: Unit test for origin/proxy to test serving HTTP and Websocket
This commit is contained in:
parent
a490443630
commit
6b86f81c4a
|
@ -8,16 +8,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
responseMetaHeaderField = "cf-cloudflared-response-meta"
|
responseMetaHeaderField = "cf-cloudflared-response-meta"
|
||||||
responseSourceCloudflared = "cloudflared"
|
|
||||||
responseSourceOrigin = "origin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
canonicalResponseUserHeadersField = http.CanonicalHeaderKey(h2mux.ResponseUserHeadersField)
|
canonicalResponseUserHeadersField = http.CanonicalHeaderKey(h2mux.ResponseUserHeadersField)
|
||||||
canonicalResponseMetaHeaderField = http.CanonicalHeaderKey(responseMetaHeaderField)
|
canonicalResponseMetaHeaderField = http.CanonicalHeaderKey(responseMetaHeaderField)
|
||||||
responseMetaHeaderCfd = mustInitRespMetaHeader(responseSourceCloudflared)
|
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
|
||||||
responseMetaHeaderOrigin = mustInitRespMetaHeader(responseSourceOrigin)
|
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
||||||
)
|
)
|
||||||
|
|
||||||
type responseMetaHeader struct {
|
type responseMetaHeader struct {
|
||||||
|
|
3
go.mod
3
go.mod
|
@ -27,6 +27,9 @@ require (
|
||||||
github.com/getsentry/raven-go v0.0.0-20180517221441-ed7bcb39ff10
|
github.com/getsentry/raven-go v0.0.0-20180517221441-ed7bcb39ff10
|
||||||
github.com/gliderlabs/ssh v0.0.0-20191009160644-63518b5243e0
|
github.com/gliderlabs/ssh v0.0.0-20191009160644-63518b5243e0
|
||||||
github.com/go-sql-driver/mysql v1.5.0
|
github.com/go-sql-driver/mysql v1.5.0
|
||||||
|
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 // indirect
|
||||||
|
github.com/gobwas/pool v0.2.1 // indirect
|
||||||
|
github.com/gobwas/ws v1.0.4
|
||||||
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
|
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
|
||||||
github.com/google/go-cmp v0.5.2 // indirect
|
github.com/google/go-cmp v0.5.2 // indirect
|
||||||
github.com/google/uuid v1.1.2
|
github.com/google/uuid v1.1.2
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -233,6 +233,12 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG
|
||||||
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
|
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
|
||||||
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
|
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ=
|
||||||
|
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||||
|
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||||
|
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||||
|
github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs=
|
||||||
|
github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||||
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||||
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||||
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
|
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
|
||||||
|
|
|
@ -18,6 +18,11 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
UptimeRoute = "/uptime"
|
||||||
|
WSRoute = "/ws"
|
||||||
|
)
|
||||||
|
|
||||||
type templateData struct {
|
type templateData struct {
|
||||||
ServerName string
|
ServerName string
|
||||||
Request *http.Request
|
Request *http.Request
|
||||||
|
@ -104,8 +109,8 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow
|
||||||
}
|
}
|
||||||
|
|
||||||
muxer := http.NewServeMux()
|
muxer := http.NewServeMux()
|
||||||
muxer.HandleFunc("/uptime", uptimeHandler(time.Now()))
|
muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
|
||||||
muxer.HandleFunc("/ws", websocketHandler(logger, upgrader))
|
muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader))
|
||||||
muxer.HandleFunc("/", rootHandler(serverName))
|
muxer.HandleFunc("/", rootHandler(serverName))
|
||||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
|
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
|
||||||
go func() {
|
go func() {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package origin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -112,12 +113,17 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt
|
||||||
|
|
||||||
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
|
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
|
||||||
c.setHostHeader(req)
|
c.setHostHeader(req)
|
||||||
|
|
||||||
conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig)
|
conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
|
||||||
|
serveCtx, cancel := context.WithCancel(req.Context())
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
<-serveCtx.Done()
|
||||||
|
conn.Close()
|
||||||
|
}()
|
||||||
err = w.WriteRespHeaders(resp)
|
err = w.WriteRespHeaders(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error writing response header")
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
|
@ -125,7 +131,6 @@ func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request)
|
||||||
// Copy to/from stream to the undelying connection. Use the underlying
|
// Copy to/from stream to the undelying connection. Use the underlying
|
||||||
// connection because cloudflared doesn't operate on the message themselves
|
// connection because cloudflared doesn't operate on the message themselves
|
||||||
websocket.Stream(conn.UnderlyingConn(), w)
|
websocket.Stream(conn.UnderlyingConn(), w)
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,169 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockHTTPRespWriter struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockHTTPRespWriter() *mockHTTPRespWriter {
|
||||||
|
return &mockHTTPRespWriter{
|
||||||
|
httptest.NewRecorder(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
for header, val := range resp.Header {
|
||||||
|
w.Header()[header] = val
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockHTTPRespWriter) WriteErrorResponse(err error) {
|
||||||
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
||||||
|
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockWSRespWriter struct {
|
||||||
|
*mockHTTPRespWriter
|
||||||
|
writeNotification chan []byte
|
||||||
|
reader io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockWSRespWriter(httpRespWriter *mockHTTPRespWriter, reader io.Reader) *mockWSRespWriter {
|
||||||
|
return &mockWSRespWriter{
|
||||||
|
httpRespWriter,
|
||||||
|
make(chan []byte),
|
||||||
|
reader,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockWSRespWriter) Write(data []byte) (int, error) {
|
||||||
|
w.writeNotification <- data
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockWSRespWriter) respBody() io.ReadWriter {
|
||||||
|
data := <-w.writeNotification
|
||||||
|
return bytes.NewBuffer(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
||||||
|
return w.reader.Read(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxy(t *testing.T) {
|
||||||
|
logger, err := logger.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
// let runtime pick an available port
|
||||||
|
listener, err := hello.CreateTLSListener("127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
originURL := &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: listener.Addr().String(),
|
||||||
|
}
|
||||||
|
originCA := x509.NewCertPool()
|
||||||
|
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||||
|
require.NoError(t, err)
|
||||||
|
originCA.AddCert(helloCert)
|
||||||
|
clientTLS := &tls.Config{
|
||||||
|
RootCAs: originCA,
|
||||||
|
}
|
||||||
|
proxyConfig := &ProxyConfig{
|
||||||
|
Client: &http.Transport{
|
||||||
|
TLSClientConfig: clientTLS,
|
||||||
|
},
|
||||||
|
URL: originURL,
|
||||||
|
TLSConfig: clientTLS,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
hello.StartHelloWorldServer(logger, listener, ctx.Done())
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := NewClient(proxyConfig, logger)
|
||||||
|
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
|
||||||
|
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
respWriter := newMockHTTPRespWriter()
|
||||||
|
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = client.Proxy(respWriter, req, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, respWriter.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL *url.URL, tlsConfig *tls.Config) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
// WSRoute is a websocket echo handler
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
|
||||||
|
|
||||||
|
readPipe, writePipe := io.Pipe()
|
||||||
|
respWriter := newMockWSRespWriter(newMockHTTPRespWriter(), readPipe)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err = client.Proxy(respWriter, req, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
||||||
|
}()
|
||||||
|
|
||||||
|
msg := []byte("test websocket")
|
||||||
|
err = wsutil.WriteClientText(writePipe, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// ReadServerText reads next data message from rw, considering that caller represents client side.
|
||||||
|
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, msg, returnedMsg)
|
||||||
|
|
||||||
|
err = wsutil.WriteClientBinary(writePipe, msg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
returnedMsg, err = wsutil.ReadServerBinary(respWriter.respBody())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, msg, returnedMsg)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2017 Sergey Kamardin
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,63 @@
|
||||||
|
# httphead.[go](https://golang.org)
|
||||||
|
|
||||||
|
[![GoDoc][godoc-image]][godoc-url]
|
||||||
|
|
||||||
|
> Tiny HTTP header value parsing library in go.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This library contains low-level functions for scanning HTTP RFC2616 compatible header value grammars.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```shell
|
||||||
|
go get github.com/gobwas/httphead
|
||||||
|
```
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
The example below shows how multiple-choise HTTP header value could be parsed with this library:
|
||||||
|
|
||||||
|
```go
|
||||||
|
options, ok := httphead.ParseOptions([]byte(`foo;bar=1,baz`), nil)
|
||||||
|
fmt.Println(options, ok)
|
||||||
|
// Output: [{foo map[bar:1]} {baz map[]}] true
|
||||||
|
```
|
||||||
|
|
||||||
|
The low-level example below shows how to optimize keys skipping and selection
|
||||||
|
of some key:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// The right part of full header line like:
|
||||||
|
// X-My-Header: key;foo=bar;baz,key;baz
|
||||||
|
header := []byte(`foo;a=0,foo;a=1,foo;a=2,foo;a=3`)
|
||||||
|
|
||||||
|
// We want to search key "foo" with an "a" parameter that equal to "2".
|
||||||
|
var (
|
||||||
|
foo = []byte(`foo`)
|
||||||
|
a = []byte(`a`)
|
||||||
|
v = []byte(`2`)
|
||||||
|
)
|
||||||
|
var found bool
|
||||||
|
httphead.ScanOptions(header, func(i int, key, param, value []byte) Control {
|
||||||
|
if !bytes.Equal(key, foo) {
|
||||||
|
return ControlSkip
|
||||||
|
}
|
||||||
|
if !bytes.Equal(param, a) {
|
||||||
|
if bytes.Equal(value, v) {
|
||||||
|
// Found it!
|
||||||
|
found = true
|
||||||
|
return ControlBreak
|
||||||
|
}
|
||||||
|
return ControlSkip
|
||||||
|
}
|
||||||
|
return ControlContinue
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
For more usage examples please see [docs][godoc-url] or package tests.
|
||||||
|
|
||||||
|
[godoc-image]: https://godoc.org/github.com/gobwas/httphead?status.svg
|
||||||
|
[godoc-url]: https://godoc.org/github.com/gobwas/httphead
|
||||||
|
[travis-image]: https://travis-ci.org/gobwas/httphead.svg?branch=master
|
||||||
|
[travis-url]: https://travis-ci.org/gobwas/httphead
|
|
@ -0,0 +1,200 @@
|
||||||
|
package httphead
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScanCookie scans cookie pairs from data using DefaultCookieScanner.Scan()
|
||||||
|
// method.
|
||||||
|
func ScanCookie(data []byte, it func(key, value []byte) bool) bool {
|
||||||
|
return DefaultCookieScanner.Scan(data, it)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCookieScanner is a CookieScanner which is used by ScanCookie().
|
||||||
|
// Note that it is intended to have the same behavior as http.Request.Cookies()
|
||||||
|
// has.
|
||||||
|
var DefaultCookieScanner = CookieScanner{}
|
||||||
|
|
||||||
|
// CookieScanner contains options for scanning cookie pairs.
|
||||||
|
// See https://tools.ietf.org/html/rfc6265#section-4.1.1
|
||||||
|
type CookieScanner struct {
|
||||||
|
// DisableNameValidation disables name validation of a cookie. If false,
|
||||||
|
// only RFC2616 "tokens" are accepted.
|
||||||
|
DisableNameValidation bool
|
||||||
|
|
||||||
|
// DisableValueValidation disables value validation of a cookie. If false,
|
||||||
|
// only RFC6265 "cookie-octet" characters are accepted.
|
||||||
|
//
|
||||||
|
// Note that Strict option also affects validation of a value.
|
||||||
|
//
|
||||||
|
// If Strict is false, then scanner begins to allow space and comma
|
||||||
|
// characters inside the value for better compatibility with non standard
|
||||||
|
// cookies implementations.
|
||||||
|
DisableValueValidation bool
|
||||||
|
|
||||||
|
// BreakOnPairError sets scanner to immediately return after first pair syntax
|
||||||
|
// validation error.
|
||||||
|
// If false, scanner will try to skip invalid pair bytes and go ahead.
|
||||||
|
BreakOnPairError bool
|
||||||
|
|
||||||
|
// Strict enables strict RFC6265 mode scanning. It affects name and value
|
||||||
|
// validation, as also some other rules.
|
||||||
|
// If false, it is intended to bring the same behavior as
|
||||||
|
// http.Request.Cookies().
|
||||||
|
Strict bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan maps data to name and value pairs. Usually data represents value of the
|
||||||
|
// Cookie header.
|
||||||
|
func (c CookieScanner) Scan(data []byte, it func(name, value []byte) bool) bool {
|
||||||
|
lexer := &Scanner{data: data}
|
||||||
|
|
||||||
|
const (
|
||||||
|
statePair = iota
|
||||||
|
stateBefore
|
||||||
|
)
|
||||||
|
|
||||||
|
state := statePair
|
||||||
|
|
||||||
|
for lexer.Buffered() > 0 {
|
||||||
|
switch state {
|
||||||
|
case stateBefore:
|
||||||
|
// Pairs separated by ";" and space, according to the RFC6265:
|
||||||
|
// cookie-pair *( ";" SP cookie-pair )
|
||||||
|
//
|
||||||
|
// Cookie pairs MUST be separated by (";" SP). So our only option
|
||||||
|
// here is to fail as syntax error.
|
||||||
|
a, b := lexer.Peek2()
|
||||||
|
if a != ';' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
state = statePair
|
||||||
|
|
||||||
|
advance := 1
|
||||||
|
if b == ' ' {
|
||||||
|
advance++
|
||||||
|
} else if c.Strict {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
lexer.Advance(advance)
|
||||||
|
|
||||||
|
case statePair:
|
||||||
|
if !lexer.FetchUntil(';') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var value []byte
|
||||||
|
name := lexer.Bytes()
|
||||||
|
if i := bytes.IndexByte(name, '='); i != -1 {
|
||||||
|
value = name[i+1:]
|
||||||
|
name = name[:i]
|
||||||
|
} else if c.Strict {
|
||||||
|
if !c.BreakOnPairError {
|
||||||
|
goto nextPair
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.Strict {
|
||||||
|
trimLeft(name)
|
||||||
|
}
|
||||||
|
if !c.DisableNameValidation && !ValidCookieName(name) {
|
||||||
|
if !c.BreakOnPairError {
|
||||||
|
goto nextPair
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.Strict {
|
||||||
|
value = trimRight(value)
|
||||||
|
}
|
||||||
|
value = stripQuotes(value)
|
||||||
|
if !c.DisableValueValidation && !ValidCookieValue(value, c.Strict) {
|
||||||
|
if !c.BreakOnPairError {
|
||||||
|
goto nextPair
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !it(name, value) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
nextPair:
|
||||||
|
state = stateBefore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidCookieValue reports whether given value is a valid RFC6265
|
||||||
|
// "cookie-octet" bytes.
|
||||||
|
//
|
||||||
|
// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
|
||||||
|
// ; US-ASCII characters excluding CTLs,
|
||||||
|
// ; whitespace DQUOTE, comma, semicolon,
|
||||||
|
// ; and backslash
|
||||||
|
//
|
||||||
|
// Note that the false strict parameter disables errors on space 0x20 and comma
|
||||||
|
// 0x2c. This could be useful to bring some compatibility with non-compliant
|
||||||
|
// clients/servers in the real world.
|
||||||
|
// It acts the same as standard library cookie parser if strict is false.
|
||||||
|
func ValidCookieValue(value []byte, strict bool) bool {
|
||||||
|
if len(value) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, c := range value {
|
||||||
|
switch c {
|
||||||
|
case '"', ';', '\\':
|
||||||
|
return false
|
||||||
|
case ',', ' ':
|
||||||
|
if strict {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if c <= 0x20 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c >= 0x7f {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidCookieName reports wheter given bytes is a valid RFC2616 "token" bytes.
|
||||||
|
func ValidCookieName(name []byte) bool {
|
||||||
|
for _, c := range name {
|
||||||
|
if !OctetTypes[c].IsToken() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripQuotes(bts []byte) []byte {
|
||||||
|
if last := len(bts) - 1; last > 0 && bts[0] == '"' && bts[last] == '"' {
|
||||||
|
return bts[1:last]
|
||||||
|
}
|
||||||
|
return bts
|
||||||
|
}
|
||||||
|
|
||||||
|
func trimLeft(p []byte) []byte {
|
||||||
|
var i int
|
||||||
|
for i < len(p) && OctetTypes[p[i]].IsSpace() {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return p[i:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func trimRight(p []byte) []byte {
|
||||||
|
j := len(p)
|
||||||
|
for j > 0 && OctetTypes[p[j-1]].IsSpace() {
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
return p[:j]
|
||||||
|
}
|
|
@ -0,0 +1,275 @@
|
||||||
|
package httphead
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Version contains protocol major and minor version.
|
||||||
|
type Version struct {
|
||||||
|
Major int
|
||||||
|
Minor int
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestLine contains parameters parsed from the first request line.
|
||||||
|
type RequestLine struct {
|
||||||
|
Method []byte
|
||||||
|
URI []byte
|
||||||
|
Version Version
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseLine contains parameters parsed from the first response line.
|
||||||
|
type ResponseLine struct {
|
||||||
|
Version Version
|
||||||
|
Status int
|
||||||
|
Reason []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// SplitRequestLine splits given slice of bytes into three chunks without
|
||||||
|
// parsing.
|
||||||
|
func SplitRequestLine(line []byte) (method, uri, version []byte) {
|
||||||
|
return split3(line, ' ')
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseRequestLine parses http request line like "GET / HTTP/1.0".
|
||||||
|
func ParseRequestLine(line []byte) (r RequestLine, ok bool) {
|
||||||
|
var i int
|
||||||
|
for i = 0; i < len(line); i++ {
|
||||||
|
c := line[i]
|
||||||
|
if !OctetTypes[c].IsToken() {
|
||||||
|
if i > 0 && c == ' ' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i == len(line) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var proto []byte
|
||||||
|
r.Method = line[:i]
|
||||||
|
r.URI, proto = split2(line[i+1:], ' ')
|
||||||
|
if len(r.URI) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if major, minor, ok := ParseVersion(proto); ok {
|
||||||
|
r.Version.Major = major
|
||||||
|
r.Version.Minor = minor
|
||||||
|
return r, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SplitResponseLine splits given slice of bytes into three chunks without
|
||||||
|
// parsing.
|
||||||
|
func SplitResponseLine(line []byte) (version, status, reason []byte) {
|
||||||
|
return split3(line, ' ')
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseResponseLine parses first response line into ResponseLine struct.
|
||||||
|
func ParseResponseLine(line []byte) (r ResponseLine, ok bool) {
|
||||||
|
var (
|
||||||
|
proto []byte
|
||||||
|
status []byte
|
||||||
|
)
|
||||||
|
proto, status, r.Reason = split3(line, ' ')
|
||||||
|
if major, minor, ok := ParseVersion(proto); ok {
|
||||||
|
r.Version.Major = major
|
||||||
|
r.Version.Minor = minor
|
||||||
|
} else {
|
||||||
|
return r, false
|
||||||
|
}
|
||||||
|
if n, ok := IntFromASCII(status); ok {
|
||||||
|
r.Status = n
|
||||||
|
} else {
|
||||||
|
return r, false
|
||||||
|
}
|
||||||
|
// TODO(gobwas): parse here r.Reason fot TEXT rule:
|
||||||
|
// TEXT = <any OCTET except CTLs,
|
||||||
|
// but including LWS>
|
||||||
|
return r, true
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
httpVersion10 = []byte("HTTP/1.0")
|
||||||
|
httpVersion11 = []byte("HTTP/1.1")
|
||||||
|
httpVersionPrefix = []byte("HTTP/")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseVersion parses major and minor version of HTTP protocol.
|
||||||
|
// It returns parsed values and true if parse is ok.
|
||||||
|
func ParseVersion(bts []byte) (major, minor int, ok bool) {
|
||||||
|
switch {
|
||||||
|
case bytes.Equal(bts, httpVersion11):
|
||||||
|
return 1, 1, true
|
||||||
|
case bytes.Equal(bts, httpVersion10):
|
||||||
|
return 1, 0, true
|
||||||
|
case len(bts) < 8:
|
||||||
|
return
|
||||||
|
case !bytes.Equal(bts[:5], httpVersionPrefix):
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bts = bts[5:]
|
||||||
|
|
||||||
|
dot := bytes.IndexByte(bts, '.')
|
||||||
|
if dot == -1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
major, ok = IntFromASCII(bts[:dot])
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
minor, ok = IntFromASCII(bts[dot+1:])
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return major, minor, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadLine reads line from br. It reads until '\n' and returns bytes without
|
||||||
|
// '\n' or '\r\n' at the end.
|
||||||
|
// It returns err if and only if line does not end in '\n'. Note that read
|
||||||
|
// bytes returned in any case of error.
|
||||||
|
//
|
||||||
|
// It is much like the textproto/Reader.ReadLine() except the thing that it
|
||||||
|
// returns raw bytes, instead of string. That is, it avoids copying bytes read
|
||||||
|
// from br.
|
||||||
|
//
|
||||||
|
// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be
|
||||||
|
// safe with future I/O operations on br.
|
||||||
|
//
|
||||||
|
// We could control I/O operations on br and do not need to make additional
|
||||||
|
// copy for safety.
|
||||||
|
func ReadLine(br *bufio.Reader) ([]byte, error) {
|
||||||
|
var line []byte
|
||||||
|
for {
|
||||||
|
bts, err := br.ReadSlice('\n')
|
||||||
|
if err == bufio.ErrBufferFull {
|
||||||
|
// Copy bytes because next read will discard them.
|
||||||
|
line = append(line, bts...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Avoid copy of single read.
|
||||||
|
if line == nil {
|
||||||
|
line = bts
|
||||||
|
} else {
|
||||||
|
line = append(line, bts...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return line, err
|
||||||
|
}
|
||||||
|
// Size of line is at least 1.
|
||||||
|
// In other case bufio.ReadSlice() returns error.
|
||||||
|
n := len(line)
|
||||||
|
// Cut '\n' or '\r\n'.
|
||||||
|
if n > 1 && line[n-2] == '\r' {
|
||||||
|
line = line[:n-2]
|
||||||
|
} else {
|
||||||
|
line = line[:n-1]
|
||||||
|
}
|
||||||
|
return line, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseHeaderLine parses HTTP header as key-value pair. It returns parsed
|
||||||
|
// values and true if parse is ok.
|
||||||
|
func ParseHeaderLine(line []byte) (k, v []byte, ok bool) {
|
||||||
|
colon := bytes.IndexByte(line, ':')
|
||||||
|
if colon == -1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
k = trim(line[:colon])
|
||||||
|
for _, c := range k {
|
||||||
|
if !OctetTypes[c].IsToken() {
|
||||||
|
return nil, nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v = trim(line[colon+1:])
|
||||||
|
return k, v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntFromASCII converts ascii encoded decimal numeric value from HTTP entities
|
||||||
|
// to an integer.
|
||||||
|
func IntFromASCII(bts []byte) (ret int, ok bool) {
|
||||||
|
// ASCII numbers all start with the high-order bits 0011.
|
||||||
|
// If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those
|
||||||
|
// bits and interpret them directly as an integer.
|
||||||
|
var n int
|
||||||
|
if n = len(bts); n < 1 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if bts[i]&0xf0 != 0x30 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
ret += int(bts[i]&0xf) * pow(10, n-i-1)
|
||||||
|
}
|
||||||
|
return ret, true
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
toLower = 'a' - 'A' // for use with OR.
|
||||||
|
toUpper = ^byte(toLower) // for use with AND.
|
||||||
|
)
|
||||||
|
|
||||||
|
// CanonicalizeHeaderKey is like standard textproto/CanonicalMIMEHeaderKey,
|
||||||
|
// except that it operates with slice of bytes and modifies it inplace without
|
||||||
|
// copying.
|
||||||
|
func CanonicalizeHeaderKey(k []byte) {
|
||||||
|
upper := true
|
||||||
|
for i, c := range k {
|
||||||
|
if upper && 'a' <= c && c <= 'z' {
|
||||||
|
k[i] &= toUpper
|
||||||
|
} else if !upper && 'A' <= c && c <= 'Z' {
|
||||||
|
k[i] |= toLower
|
||||||
|
}
|
||||||
|
upper = c == '-'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pow for integers implementation.
|
||||||
|
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3
|
||||||
|
func pow(a, b int) int {
|
||||||
|
p := 1
|
||||||
|
for b > 0 {
|
||||||
|
if b&1 != 0 {
|
||||||
|
p *= a
|
||||||
|
}
|
||||||
|
b >>= 1
|
||||||
|
a *= a
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func split3(p []byte, sep byte) (p1, p2, p3 []byte) {
|
||||||
|
a := bytes.IndexByte(p, sep)
|
||||||
|
b := bytes.IndexByte(p[a+1:], sep)
|
||||||
|
if a == -1 || b == -1 {
|
||||||
|
return p, nil, nil
|
||||||
|
}
|
||||||
|
b += a + 1
|
||||||
|
return p[:a], p[a+1 : b], p[b+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func split2(p []byte, sep byte) (p1, p2 []byte) {
|
||||||
|
i := bytes.IndexByte(p, sep)
|
||||||
|
if i == -1 {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
return p[:i], p[i+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func trim(p []byte) []byte {
|
||||||
|
var i, j int
|
||||||
|
for i = 0; i < len(p) && (p[i] == ' ' || p[i] == '\t'); {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
for j = len(p); j > i && (p[j-1] == ' ' || p[j-1] == '\t'); {
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
return p[i:j]
|
||||||
|
}
|
|
@ -0,0 +1,331 @@
|
||||||
|
// Package httphead contains utils for parsing HTTP and HTTP-grammar compatible
|
||||||
|
// text protocols headers.
|
||||||
|
//
|
||||||
|
// That is, this package first aim is to bring ability to easily parse
|
||||||
|
// constructions, described here https://tools.ietf.org/html/rfc2616#section-2
|
||||||
|
package httphead
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScanTokens parses data in this form:
|
||||||
|
//
|
||||||
|
// list = 1#token
|
||||||
|
//
|
||||||
|
// It returns false if data is malformed.
|
||||||
|
func ScanTokens(data []byte, it func([]byte) bool) bool {
|
||||||
|
lexer := &Scanner{data: data}
|
||||||
|
|
||||||
|
var ok bool
|
||||||
|
for lexer.Next() {
|
||||||
|
switch lexer.Type() {
|
||||||
|
case ItemToken:
|
||||||
|
ok = true
|
||||||
|
if !it(lexer.Bytes()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
case ItemSeparator:
|
||||||
|
if !isComma(lexer.Bytes()) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok && !lexer.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseOptions parses all header options and appends it to given slice of
|
||||||
|
// Option. It returns flag of successful (wellformed input) parsing.
|
||||||
|
//
|
||||||
|
// Note that appended options are all consist of subslices of data. That is,
|
||||||
|
// mutation of data will mutate appended options.
|
||||||
|
func ParseOptions(data []byte, options []Option) ([]Option, bool) {
|
||||||
|
var i int
|
||||||
|
index := -1
|
||||||
|
return options, ScanOptions(data, func(idx int, name, attr, val []byte) Control {
|
||||||
|
if idx != index {
|
||||||
|
index = idx
|
||||||
|
i = len(options)
|
||||||
|
options = append(options, Option{Name: name})
|
||||||
|
}
|
||||||
|
if attr != nil {
|
||||||
|
options[i].Parameters.Set(attr, val)
|
||||||
|
}
|
||||||
|
return ControlContinue
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectFlag encodes way of options selection.
|
||||||
|
type SelectFlag byte
|
||||||
|
|
||||||
|
// String represetns flag as string.
|
||||||
|
func (f SelectFlag) String() string {
|
||||||
|
var flags [2]string
|
||||||
|
var n int
|
||||||
|
if f&SelectCopy != 0 {
|
||||||
|
flags[n] = "copy"
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
if f&SelectUnique != 0 {
|
||||||
|
flags[n] = "unique"
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
return "[" + strings.Join(flags[:n], "|") + "]"
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SelectCopy causes selector to copy selected option before appending it
|
||||||
|
// to resulting slice.
|
||||||
|
// If SelectCopy flag is not passed to selector, then appended options will
|
||||||
|
// contain sub-slices of the initial data.
|
||||||
|
SelectCopy SelectFlag = 1 << iota
|
||||||
|
|
||||||
|
// SelectUnique causes selector to append only not yet existing option to
|
||||||
|
// resulting slice. Unique is checked by comparing option names.
|
||||||
|
SelectUnique
|
||||||
|
)
|
||||||
|
|
||||||
|
// OptionSelector contains configuration for selecting Options from header value.
|
||||||
|
type OptionSelector struct {
|
||||||
|
// Check is a filter function that applied to every Option that possibly
|
||||||
|
// could be selected.
|
||||||
|
// If Check is nil all options will be selected.
|
||||||
|
Check func(Option) bool
|
||||||
|
|
||||||
|
// Flags contains flags for options selection.
|
||||||
|
Flags SelectFlag
|
||||||
|
|
||||||
|
// Alloc used to allocate slice of bytes when selector is configured with
|
||||||
|
// SelectCopy flag. It will be called with number of bytes needed for copy
|
||||||
|
// of single Option.
|
||||||
|
// If Alloc is nil make is used.
|
||||||
|
Alloc func(n int) []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select parses header data and appends it to given slice of Option.
|
||||||
|
// It also returns flag of successful (wellformed input) parsing.
|
||||||
|
func (s OptionSelector) Select(data []byte, options []Option) ([]Option, bool) {
|
||||||
|
var current Option
|
||||||
|
var has bool
|
||||||
|
index := -1
|
||||||
|
|
||||||
|
alloc := s.Alloc
|
||||||
|
if alloc == nil {
|
||||||
|
alloc = defaultAlloc
|
||||||
|
}
|
||||||
|
check := s.Check
|
||||||
|
if check == nil {
|
||||||
|
check = defaultCheck
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := ScanOptions(data, func(idx int, name, attr, val []byte) Control {
|
||||||
|
if idx != index {
|
||||||
|
if has && check(current) {
|
||||||
|
if s.Flags&SelectCopy != 0 {
|
||||||
|
current = current.Copy(alloc(current.Size()))
|
||||||
|
}
|
||||||
|
options = append(options, current)
|
||||||
|
has = false
|
||||||
|
}
|
||||||
|
if s.Flags&SelectUnique != 0 {
|
||||||
|
for i := len(options) - 1; i >= 0; i-- {
|
||||||
|
if bytes.Equal(options[i].Name, name) {
|
||||||
|
return ControlSkip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
index = idx
|
||||||
|
current = Option{Name: name}
|
||||||
|
has = true
|
||||||
|
}
|
||||||
|
if attr != nil {
|
||||||
|
current.Parameters.Set(attr, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ControlContinue
|
||||||
|
})
|
||||||
|
if has && check(current) {
|
||||||
|
if s.Flags&SelectCopy != 0 {
|
||||||
|
current = current.Copy(alloc(current.Size()))
|
||||||
|
}
|
||||||
|
options = append(options, current)
|
||||||
|
}
|
||||||
|
|
||||||
|
return options, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultAlloc(n int) []byte { return make([]byte, n) }
|
||||||
|
func defaultCheck(Option) bool { return true }
|
||||||
|
|
||||||
|
// Control represents operation that scanner should perform.
|
||||||
|
type Control byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ControlContinue causes scanner to continue scan tokens.
|
||||||
|
ControlContinue Control = iota
|
||||||
|
// ControlBreak causes scanner to stop scan tokens.
|
||||||
|
ControlBreak
|
||||||
|
// ControlSkip causes scanner to skip current entity.
|
||||||
|
ControlSkip
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScanOptions parses data in this form:
|
||||||
|
//
|
||||||
|
// values = 1#value
|
||||||
|
// value = token *( ";" param )
|
||||||
|
// param = token [ "=" (token | quoted-string) ]
|
||||||
|
//
|
||||||
|
// It calls given callback with the index of the option, option itself and its
|
||||||
|
// parameter (attribute and its value, both could be nil). Index is useful when
|
||||||
|
// header contains multiple choises for the same named option.
|
||||||
|
//
|
||||||
|
// Given callback should return one of the defined Control* values.
|
||||||
|
// ControlSkip means that passed key is not in caller's interest. That is, all
|
||||||
|
// parameters of that key will be skipped.
|
||||||
|
// ControlBreak means that no more keys and parameters should be parsed. That
|
||||||
|
// is, it must break parsing immediately.
|
||||||
|
// ControlContinue means that caller want to receive next parameter and its
|
||||||
|
// value or the next key.
|
||||||
|
//
|
||||||
|
// It returns false if data is malformed.
|
||||||
|
func ScanOptions(data []byte, it func(index int, option, attribute, value []byte) Control) bool {
|
||||||
|
lexer := &Scanner{data: data}
|
||||||
|
|
||||||
|
var ok bool
|
||||||
|
var state int
|
||||||
|
const (
|
||||||
|
stateKey = iota
|
||||||
|
stateParamBeforeName
|
||||||
|
stateParamName
|
||||||
|
stateParamBeforeValue
|
||||||
|
stateParamValue
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
index int
|
||||||
|
key, param, value []byte
|
||||||
|
mustCall bool
|
||||||
|
)
|
||||||
|
for lexer.Next() {
|
||||||
|
var (
|
||||||
|
call bool
|
||||||
|
growIndex int
|
||||||
|
)
|
||||||
|
|
||||||
|
t := lexer.Type()
|
||||||
|
v := lexer.Bytes()
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case ItemToken:
|
||||||
|
switch state {
|
||||||
|
case stateKey, stateParamBeforeName:
|
||||||
|
key = v
|
||||||
|
state = stateParamBeforeName
|
||||||
|
mustCall = true
|
||||||
|
case stateParamName:
|
||||||
|
param = v
|
||||||
|
state = stateParamBeforeValue
|
||||||
|
mustCall = true
|
||||||
|
case stateParamValue:
|
||||||
|
value = v
|
||||||
|
state = stateParamBeforeName
|
||||||
|
call = true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
case ItemString:
|
||||||
|
if state != stateParamValue {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
value = v
|
||||||
|
state = stateParamBeforeName
|
||||||
|
call = true
|
||||||
|
|
||||||
|
case ItemSeparator:
|
||||||
|
switch {
|
||||||
|
case isComma(v) && state == stateKey:
|
||||||
|
// Nothing to do.
|
||||||
|
|
||||||
|
case isComma(v) && state == stateParamBeforeName:
|
||||||
|
state = stateKey
|
||||||
|
// Make call only if we have not called this key yet.
|
||||||
|
call = mustCall
|
||||||
|
if !call {
|
||||||
|
// If we have already called callback with the key
|
||||||
|
// that just ended.
|
||||||
|
index++
|
||||||
|
} else {
|
||||||
|
// Else grow the index after calling callback.
|
||||||
|
growIndex = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
case isComma(v) && state == stateParamBeforeValue:
|
||||||
|
state = stateKey
|
||||||
|
growIndex = 1
|
||||||
|
call = true
|
||||||
|
|
||||||
|
case isSemicolon(v) && state == stateParamBeforeName:
|
||||||
|
state = stateParamName
|
||||||
|
|
||||||
|
case isSemicolon(v) && state == stateParamBeforeValue:
|
||||||
|
state = stateParamName
|
||||||
|
call = true
|
||||||
|
|
||||||
|
case isEquality(v) && state == stateParamBeforeValue:
|
||||||
|
state = stateParamValue
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if call {
|
||||||
|
switch it(index, key, param, value) {
|
||||||
|
case ControlBreak:
|
||||||
|
// User want to stop to parsing parameters.
|
||||||
|
return true
|
||||||
|
|
||||||
|
case ControlSkip:
|
||||||
|
// User want to skip current param.
|
||||||
|
state = stateKey
|
||||||
|
lexer.SkipEscaped(',')
|
||||||
|
|
||||||
|
case ControlContinue:
|
||||||
|
// User is interested in rest of parameters.
|
||||||
|
// Nothing to do.
|
||||||
|
|
||||||
|
default:
|
||||||
|
panic("unexpected control value")
|
||||||
|
}
|
||||||
|
ok = true
|
||||||
|
param = nil
|
||||||
|
value = nil
|
||||||
|
mustCall = false
|
||||||
|
index += growIndex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mustCall {
|
||||||
|
ok = true
|
||||||
|
it(index, key, param, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok && !lexer.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func isComma(b []byte) bool {
|
||||||
|
return len(b) == 1 && b[0] == ','
|
||||||
|
}
|
||||||
|
func isSemicolon(b []byte) bool {
|
||||||
|
return len(b) == 1 && b[0] == ';'
|
||||||
|
}
|
||||||
|
func isEquality(b []byte) bool {
|
||||||
|
return len(b) == 1 && b[0] == '='
|
||||||
|
}
|
|
@ -0,0 +1,360 @@
|
||||||
|
package httphead
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ItemType encodes type of the lexing token.
|
||||||
|
type ItemType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ItemUndef reports that token is undefined.
|
||||||
|
ItemUndef ItemType = iota
|
||||||
|
// ItemToken reports that token is RFC2616 token.
|
||||||
|
ItemToken
|
||||||
|
// ItemSeparator reports that token is RFC2616 separator.
|
||||||
|
ItemSeparator
|
||||||
|
// ItemString reports that token is RFC2616 quouted string.
|
||||||
|
ItemString
|
||||||
|
// ItemComment reports that token is RFC2616 comment.
|
||||||
|
ItemComment
|
||||||
|
// ItemOctet reports that token is octet slice.
|
||||||
|
ItemOctet
|
||||||
|
)
|
||||||
|
|
||||||
|
// Scanner represents header tokens scanner.
|
||||||
|
// See https://tools.ietf.org/html/rfc2616#section-2
|
||||||
|
type Scanner struct {
|
||||||
|
data []byte
|
||||||
|
pos int
|
||||||
|
|
||||||
|
itemType ItemType
|
||||||
|
itemBytes []byte
|
||||||
|
|
||||||
|
err bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScanner creates new RFC2616 data scanner.
|
||||||
|
func NewScanner(data []byte) *Scanner {
|
||||||
|
return &Scanner{data: data}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next scans for next token. It returns true on successful scanning, and false
|
||||||
|
// on error or EOF.
|
||||||
|
func (l *Scanner) Next() bool {
|
||||||
|
c, ok := l.nextChar()
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch c {
|
||||||
|
case '"': // quoted-string;
|
||||||
|
return l.fetchQuotedString()
|
||||||
|
|
||||||
|
case '(': // comment;
|
||||||
|
return l.fetchComment()
|
||||||
|
|
||||||
|
case '\\', ')': // unexpected chars;
|
||||||
|
l.err = true
|
||||||
|
return false
|
||||||
|
|
||||||
|
default:
|
||||||
|
return l.fetchToken()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUntil fetches ItemOctet from current scanner position to first
|
||||||
|
// occurence of the c or to the end of the underlying data.
|
||||||
|
func (l *Scanner) FetchUntil(c byte) bool {
|
||||||
|
l.resetItem()
|
||||||
|
if l.pos == len(l.data) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return l.fetchOctet(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek reads byte at current position without advancing it. On end of data it
|
||||||
|
// returns 0.
|
||||||
|
func (l *Scanner) Peek() byte {
|
||||||
|
if l.pos == len(l.data) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return l.data[l.pos]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek2 reads two first bytes at current position without advancing it.
|
||||||
|
// If there not enough data it returs 0.
|
||||||
|
func (l *Scanner) Peek2() (a, b byte) {
|
||||||
|
if l.pos == len(l.data) {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
if l.pos+1 == len(l.data) {
|
||||||
|
return l.data[l.pos], 0
|
||||||
|
}
|
||||||
|
return l.data[l.pos], l.data[l.pos+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffered reporst how many bytes there are left to scan.
|
||||||
|
func (l *Scanner) Buffered() int {
|
||||||
|
return len(l.data) - l.pos
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance moves current position index at n bytes. It returns true on
|
||||||
|
// successful move.
|
||||||
|
func (l *Scanner) Advance(n int) bool {
|
||||||
|
l.pos += n
|
||||||
|
if l.pos > len(l.data) {
|
||||||
|
l.pos = len(l.data)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip skips all bytes until first occurence of c.
|
||||||
|
func (l *Scanner) Skip(c byte) {
|
||||||
|
if l.err {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Reset scanner state.
|
||||||
|
l.resetItem()
|
||||||
|
|
||||||
|
if i := bytes.IndexByte(l.data[l.pos:], c); i == -1 {
|
||||||
|
// Reached the end of data.
|
||||||
|
l.pos = len(l.data)
|
||||||
|
} else {
|
||||||
|
l.pos += i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipEscaped skips all bytes until first occurence of non-escaped c.
|
||||||
|
func (l *Scanner) SkipEscaped(c byte) {
|
||||||
|
if l.err {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Reset scanner state.
|
||||||
|
l.resetItem()
|
||||||
|
|
||||||
|
if i := ScanUntil(l.data[l.pos:], c); i == -1 {
|
||||||
|
// Reached the end of data.
|
||||||
|
l.pos = len(l.data)
|
||||||
|
} else {
|
||||||
|
l.pos += i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type reports current token type.
|
||||||
|
func (l *Scanner) Type() ItemType {
|
||||||
|
return l.itemType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes returns current token bytes.
|
||||||
|
func (l *Scanner) Bytes() []byte {
|
||||||
|
return l.itemBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Scanner) nextChar() (byte, bool) {
|
||||||
|
// Reset scanner state.
|
||||||
|
l.resetItem()
|
||||||
|
|
||||||
|
if l.err {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
l.pos += SkipSpace(l.data[l.pos:])
|
||||||
|
if l.pos == len(l.data) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return l.data[l.pos], true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Scanner) resetItem() {
|
||||||
|
l.itemType = ItemUndef
|
||||||
|
l.itemBytes = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Scanner) fetchOctet(c byte) bool {
|
||||||
|
i := l.pos
|
||||||
|
if j := bytes.IndexByte(l.data[l.pos:], c); j == -1 {
|
||||||
|
// Reached the end of data.
|
||||||
|
l.pos = len(l.data)
|
||||||
|
} else {
|
||||||
|
l.pos += j
|
||||||
|
}
|
||||||
|
|
||||||
|
l.itemType = ItemOctet
|
||||||
|
l.itemBytes = l.data[i:l.pos]
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Scanner) fetchToken() bool {
|
||||||
|
n, t := ScanToken(l.data[l.pos:])
|
||||||
|
if n == -1 {
|
||||||
|
l.err = true
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
l.itemType = t
|
||||||
|
l.itemBytes = l.data[l.pos : l.pos+n]
|
||||||
|
l.pos += n
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Scanner) fetchQuotedString() (ok bool) {
|
||||||
|
l.pos++
|
||||||
|
|
||||||
|
n := ScanUntil(l.data[l.pos:], '"')
|
||||||
|
if n == -1 {
|
||||||
|
l.err = true
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
l.itemType = ItemString
|
||||||
|
l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\')
|
||||||
|
l.pos += n + 1
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Scanner) fetchComment() (ok bool) {
|
||||||
|
l.pos++
|
||||||
|
|
||||||
|
n := ScanPairGreedy(l.data[l.pos:], '(', ')')
|
||||||
|
if n == -1 {
|
||||||
|
l.err = true
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
l.itemType = ItemComment
|
||||||
|
l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\')
|
||||||
|
l.pos += n + 1
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanUntil scans for first non-escaped character c in given data.
|
||||||
|
// It returns index of matched c and -1 if c is not found.
|
||||||
|
func ScanUntil(data []byte, c byte) (n int) {
|
||||||
|
for {
|
||||||
|
i := bytes.IndexByte(data[n:], c)
|
||||||
|
if i == -1 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
n += i
|
||||||
|
if n == 0 || data[n-1] != '\\' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanPairGreedy scans for complete pair of opening and closing chars in greedy manner.
|
||||||
|
// Note that first opening byte must not be present in data.
|
||||||
|
func ScanPairGreedy(data []byte, open, close byte) (n int) {
|
||||||
|
var m int
|
||||||
|
opened := 1
|
||||||
|
for {
|
||||||
|
i := bytes.IndexByte(data[n:], close)
|
||||||
|
if i == -1 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
n += i
|
||||||
|
// If found index is not escaped then it is the end.
|
||||||
|
if n == 0 || data[n-1] != '\\' {
|
||||||
|
opened--
|
||||||
|
}
|
||||||
|
|
||||||
|
for m < i {
|
||||||
|
j := bytes.IndexByte(data[m:i], open)
|
||||||
|
if j == -1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
m += j + 1
|
||||||
|
opened++
|
||||||
|
}
|
||||||
|
|
||||||
|
if opened == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
n++
|
||||||
|
m = n
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveByte returns data without c. If c is not present in data it returns
|
||||||
|
// the same slice. If not, it copies data without c.
|
||||||
|
func RemoveByte(data []byte, c byte) []byte {
|
||||||
|
j := bytes.IndexByte(data, c)
|
||||||
|
if j == -1 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
n := len(data) - 1
|
||||||
|
|
||||||
|
// If character is present, than allocate slice with n-1 capacity. That is,
|
||||||
|
// resulting bytes could be at most n-1 length.
|
||||||
|
result := make([]byte, n)
|
||||||
|
k := copy(result, data[:j])
|
||||||
|
|
||||||
|
for i := j + 1; i < n; {
|
||||||
|
j = bytes.IndexByte(data[i:], c)
|
||||||
|
if j != -1 {
|
||||||
|
k += copy(result[k:], data[i:i+j])
|
||||||
|
i = i + j + 1
|
||||||
|
} else {
|
||||||
|
k += copy(result[k:], data[i:])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result[:k]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipSpace skips spaces and lws-sequences from p.
|
||||||
|
// It returns number ob bytes skipped.
|
||||||
|
func SkipSpace(p []byte) (n int) {
|
||||||
|
for len(p) > 0 {
|
||||||
|
switch {
|
||||||
|
case len(p) >= 3 &&
|
||||||
|
p[0] == '\r' &&
|
||||||
|
p[1] == '\n' &&
|
||||||
|
OctetTypes[p[2]].IsSpace():
|
||||||
|
p = p[3:]
|
||||||
|
n += 3
|
||||||
|
case OctetTypes[p[0]].IsSpace():
|
||||||
|
p = p[1:]
|
||||||
|
n++
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanToken scan for next token in p. It returns length of the token and its
|
||||||
|
// type. It do not trim p.
|
||||||
|
func ScanToken(p []byte) (n int, t ItemType) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, ItemUndef
|
||||||
|
}
|
||||||
|
|
||||||
|
c := p[0]
|
||||||
|
switch {
|
||||||
|
case OctetTypes[c].IsSeparator():
|
||||||
|
return 1, ItemSeparator
|
||||||
|
|
||||||
|
case OctetTypes[c].IsToken():
|
||||||
|
for n = 1; n < len(p); n++ {
|
||||||
|
c := p[n]
|
||||||
|
if !OctetTypes[c].IsToken() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, ItemToken
|
||||||
|
|
||||||
|
default:
|
||||||
|
return -1, ItemUndef
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,83 @@
|
||||||
|
package httphead
|
||||||
|
|
||||||
|
// OctetType desribes character type.
|
||||||
|
//
|
||||||
|
// From the "Basic Rules" chapter of RFC2616
|
||||||
|
// See https://tools.ietf.org/html/rfc2616#section-2.2
|
||||||
|
//
|
||||||
|
// OCTET = <any 8-bit sequence of data>
|
||||||
|
// CHAR = <any US-ASCII character (octets 0 - 127)>
|
||||||
|
// UPALPHA = <any US-ASCII uppercase letter "A".."Z">
|
||||||
|
// LOALPHA = <any US-ASCII lowercase letter "a".."z">
|
||||||
|
// ALPHA = UPALPHA | LOALPHA
|
||||||
|
// DIGIT = <any US-ASCII digit "0".."9">
|
||||||
|
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
|
||||||
|
// CR = <US-ASCII CR, carriage return (13)>
|
||||||
|
// LF = <US-ASCII LF, linefeed (10)>
|
||||||
|
// SP = <US-ASCII SP, space (32)>
|
||||||
|
// HT = <US-ASCII HT, horizontal-tab (9)>
|
||||||
|
// <"> = <US-ASCII double-quote mark (34)>
|
||||||
|
// CRLF = CR LF
|
||||||
|
// LWS = [CRLF] 1*( SP | HT )
|
||||||
|
//
|
||||||
|
// Many HTTP/1.1 header field values consist of words separated by LWS
|
||||||
|
// or special characters. These special characters MUST be in a quoted
|
||||||
|
// string to be used within a parameter value (as defined in section
|
||||||
|
// 3.6).
|
||||||
|
//
|
||||||
|
// token = 1*<any CHAR except CTLs or separators>
|
||||||
|
// separators = "(" | ")" | "<" | ">" | "@"
|
||||||
|
// | "," | ";" | ":" | "\" | <">
|
||||||
|
// | "/" | "[" | "]" | "?" | "="
|
||||||
|
// | "{" | "}" | SP | HT
|
||||||
|
type OctetType byte
|
||||||
|
|
||||||
|
// IsChar reports whether octet is CHAR.
|
||||||
|
func (t OctetType) IsChar() bool { return t&octetChar != 0 }
|
||||||
|
|
||||||
|
// IsControl reports whether octet is CTL.
|
||||||
|
func (t OctetType) IsControl() bool { return t&octetControl != 0 }
|
||||||
|
|
||||||
|
// IsSeparator reports whether octet is separator.
|
||||||
|
func (t OctetType) IsSeparator() bool { return t&octetSeparator != 0 }
|
||||||
|
|
||||||
|
// IsSpace reports whether octet is space (SP or HT).
|
||||||
|
func (t OctetType) IsSpace() bool { return t&octetSpace != 0 }
|
||||||
|
|
||||||
|
// IsToken reports whether octet is token.
|
||||||
|
func (t OctetType) IsToken() bool { return t&octetToken != 0 }
|
||||||
|
|
||||||
|
const (
|
||||||
|
octetChar OctetType = 1 << iota
|
||||||
|
octetControl
|
||||||
|
octetSpace
|
||||||
|
octetSeparator
|
||||||
|
octetToken
|
||||||
|
)
|
||||||
|
|
||||||
|
// OctetTypes is a table of octets.
|
||||||
|
var OctetTypes [256]OctetType
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
for c := 32; c < 256; c++ {
|
||||||
|
var t OctetType
|
||||||
|
if c <= 127 {
|
||||||
|
t |= octetChar
|
||||||
|
}
|
||||||
|
if 0 <= c && c <= 31 || c == 127 {
|
||||||
|
t |= octetControl
|
||||||
|
}
|
||||||
|
switch c {
|
||||||
|
case '(', ')', '<', '>', '@', ',', ';', ':', '"', '/', '[', ']', '?', '=', '{', '}', '\\':
|
||||||
|
t |= octetSeparator
|
||||||
|
case ' ', '\t':
|
||||||
|
t |= octetSpace | octetSeparator
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.IsChar() && !t.IsControl() && !t.IsSeparator() && !t.IsSpace() {
|
||||||
|
t |= octetToken
|
||||||
|
}
|
||||||
|
|
||||||
|
OctetTypes[c] = t
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,193 @@
|
||||||
|
package httphead
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Option represents a header option.
|
||||||
|
type Option struct {
|
||||||
|
Name []byte
|
||||||
|
Parameters Parameters
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns number of bytes need to be allocated for use in opt.Copy.
|
||||||
|
func (opt Option) Size() int {
|
||||||
|
return len(opt.Name) + opt.Parameters.bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy copies all underlying []byte slices into p and returns new Option.
|
||||||
|
// Note that p must be at least of opt.Size() length.
|
||||||
|
func (opt Option) Copy(p []byte) Option {
|
||||||
|
n := copy(p, opt.Name)
|
||||||
|
opt.Name = p[:n]
|
||||||
|
opt.Parameters, p = opt.Parameters.Copy(p[n:])
|
||||||
|
return opt
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone is a shorthand for making slice of opt.Size() sequenced with Copy()
|
||||||
|
// call.
|
||||||
|
func (opt Option) Clone() Option {
|
||||||
|
return opt.Copy(make([]byte, opt.Size()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// String represents option as a string.
|
||||||
|
func (opt Option) String() string {
|
||||||
|
return "{" + string(opt.Name) + " " + opt.Parameters.String() + "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOption creates named option with given parameters.
|
||||||
|
func NewOption(name string, params map[string]string) Option {
|
||||||
|
p := Parameters{}
|
||||||
|
for k, v := range params {
|
||||||
|
p.Set([]byte(k), []byte(v))
|
||||||
|
}
|
||||||
|
return Option{
|
||||||
|
Name: []byte(name),
|
||||||
|
Parameters: p,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equal reports whether option is equal to b.
|
||||||
|
func (opt Option) Equal(b Option) bool {
|
||||||
|
if bytes.Equal(opt.Name, b.Name) {
|
||||||
|
return opt.Parameters.Equal(b.Parameters)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parameters represents option's parameters.
|
||||||
|
type Parameters struct {
|
||||||
|
pos int
|
||||||
|
bytes int
|
||||||
|
arr [8]pair
|
||||||
|
dyn []pair
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equal reports whether a equal to b.
|
||||||
|
func (p Parameters) Equal(b Parameters) bool {
|
||||||
|
switch {
|
||||||
|
case p.dyn == nil && b.dyn == nil:
|
||||||
|
case p.dyn != nil && b.dyn != nil:
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ad, bd := p.data(), b.data()
|
||||||
|
if len(ad) != len(bd) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Sort(pairs(ad))
|
||||||
|
sort.Sort(pairs(bd))
|
||||||
|
|
||||||
|
for i := 0; i < len(ad); i++ {
|
||||||
|
av, bv := ad[i], bd[i]
|
||||||
|
if !bytes.Equal(av.key, bv.key) || !bytes.Equal(av.value, bv.value) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns number of bytes that needed to copy p.
|
||||||
|
func (p *Parameters) Size() int {
|
||||||
|
return p.bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy copies all underlying []byte slices into dst and returns new
|
||||||
|
// Parameters.
|
||||||
|
// Note that dst must be at least of p.Size() length.
|
||||||
|
func (p *Parameters) Copy(dst []byte) (Parameters, []byte) {
|
||||||
|
ret := Parameters{
|
||||||
|
pos: p.pos,
|
||||||
|
bytes: p.bytes,
|
||||||
|
}
|
||||||
|
if p.dyn != nil {
|
||||||
|
ret.dyn = make([]pair, len(p.dyn))
|
||||||
|
for i, v := range p.dyn {
|
||||||
|
ret.dyn[i], dst = v.copy(dst)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i, p := range p.arr {
|
||||||
|
ret.arr[i], dst = p.copy(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret, dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns value by key and flag about existence such value.
|
||||||
|
func (p *Parameters) Get(key string) (value []byte, ok bool) {
|
||||||
|
for _, v := range p.data() {
|
||||||
|
if string(v.key) == key {
|
||||||
|
return v.value, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets value by key.
|
||||||
|
func (p *Parameters) Set(key, value []byte) {
|
||||||
|
p.bytes += len(key) + len(value)
|
||||||
|
|
||||||
|
if p.pos < len(p.arr) {
|
||||||
|
p.arr[p.pos] = pair{key, value}
|
||||||
|
p.pos++
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.dyn == nil {
|
||||||
|
p.dyn = make([]pair, len(p.arr), len(p.arr)+1)
|
||||||
|
copy(p.dyn, p.arr[:])
|
||||||
|
}
|
||||||
|
p.dyn = append(p.dyn, pair{key, value})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForEach iterates over parameters key-value pairs and calls cb for each one.
|
||||||
|
func (p *Parameters) ForEach(cb func(k, v []byte) bool) {
|
||||||
|
for _, v := range p.data() {
|
||||||
|
if !cb(v.key, v.value) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String represents parameters as a string.
|
||||||
|
func (p *Parameters) String() (ret string) {
|
||||||
|
ret = "["
|
||||||
|
for i, v := range p.data() {
|
||||||
|
if i > 0 {
|
||||||
|
ret += " "
|
||||||
|
}
|
||||||
|
ret += string(v.key) + ":" + string(v.value)
|
||||||
|
}
|
||||||
|
return ret + "]"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Parameters) data() []pair {
|
||||||
|
if p.dyn != nil {
|
||||||
|
return p.dyn
|
||||||
|
}
|
||||||
|
return p.arr[:p.pos]
|
||||||
|
}
|
||||||
|
|
||||||
|
type pair struct {
|
||||||
|
key, value []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p pair) copy(dst []byte) (pair, []byte) {
|
||||||
|
n := copy(dst, p.key)
|
||||||
|
p.key = dst[:n]
|
||||||
|
m := n + copy(dst[n:], p.value)
|
||||||
|
p.value = dst[n:m]
|
||||||
|
|
||||||
|
dst = dst[m:]
|
||||||
|
|
||||||
|
return p, dst
|
||||||
|
}
|
||||||
|
|
||||||
|
type pairs []pair
|
||||||
|
|
||||||
|
func (p pairs) Len() int { return len(p) }
|
||||||
|
func (p pairs) Less(a, b int) bool { return bytes.Compare(p[a].key, p[b].key) == -1 }
|
||||||
|
func (p pairs) Swap(a, b int) { p[a], p[b] = p[b], p[a] }
|
|
@ -0,0 +1,101 @@
|
||||||
|
package httphead
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
var (
|
||||||
|
comma = []byte{','}
|
||||||
|
equality = []byte{'='}
|
||||||
|
semicolon = []byte{';'}
|
||||||
|
quote = []byte{'"'}
|
||||||
|
escape = []byte{'\\'}
|
||||||
|
)
|
||||||
|
|
||||||
|
// WriteOptions write options list to the dest.
|
||||||
|
// It uses the same form as {Scan,Parse}Options functions:
|
||||||
|
// values = 1#value
|
||||||
|
// value = token *( ";" param )
|
||||||
|
// param = token [ "=" (token | quoted-string) ]
|
||||||
|
//
|
||||||
|
// It wraps valuse into the quoted-string sequence if it contains any
|
||||||
|
// non-token characters.
|
||||||
|
func WriteOptions(dest io.Writer, options []Option) (n int, err error) {
|
||||||
|
w := writer{w: dest}
|
||||||
|
for i, opt := range options {
|
||||||
|
if i > 0 {
|
||||||
|
w.write(comma)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeTokenSanitized(&w, opt.Name)
|
||||||
|
|
||||||
|
for _, p := range opt.Parameters.data() {
|
||||||
|
w.write(semicolon)
|
||||||
|
writeTokenSanitized(&w, p.key)
|
||||||
|
if len(p.value) != 0 {
|
||||||
|
w.write(equality)
|
||||||
|
writeTokenSanitized(&w, p.value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return w.result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeTokenSanitized writes token as is or as quouted string if it contains
|
||||||
|
// non-token characters.
|
||||||
|
//
|
||||||
|
// Note that is is not expects LWS sequnces be in s, cause LWS is used only as
|
||||||
|
// header field continuation:
|
||||||
|
// "A CRLF is allowed in the definition of TEXT only as part of a header field
|
||||||
|
// continuation. It is expected that the folding LWS will be replaced with a
|
||||||
|
// single SP before interpretation of the TEXT value."
|
||||||
|
// See https://tools.ietf.org/html/rfc2616#section-2
|
||||||
|
//
|
||||||
|
// That is we sanitizing s for writing, so there could not be any header field
|
||||||
|
// continuation.
|
||||||
|
// That is any CRLF will be escaped as any other control characters not allowd in TEXT.
|
||||||
|
func writeTokenSanitized(bw *writer, bts []byte) {
|
||||||
|
var qt bool
|
||||||
|
var pos int
|
||||||
|
for i := 0; i < len(bts); i++ {
|
||||||
|
c := bts[i]
|
||||||
|
if !OctetTypes[c].IsToken() && !qt {
|
||||||
|
qt = true
|
||||||
|
bw.write(quote)
|
||||||
|
}
|
||||||
|
if OctetTypes[c].IsControl() || c == '"' {
|
||||||
|
if !qt {
|
||||||
|
qt = true
|
||||||
|
bw.write(quote)
|
||||||
|
}
|
||||||
|
bw.write(bts[pos:i])
|
||||||
|
bw.write(escape)
|
||||||
|
bw.write(bts[i : i+1])
|
||||||
|
pos = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !qt {
|
||||||
|
bw.write(bts)
|
||||||
|
} else {
|
||||||
|
bw.write(bts[pos:])
|
||||||
|
bw.write(quote)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type writer struct {
|
||||||
|
w io.Writer
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writer) write(p []byte) {
|
||||||
|
if w.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var n int
|
||||||
|
n, w.err = w.w.Write(p)
|
||||||
|
w.n += n
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writer) result() (int, error) {
|
||||||
|
return w.n, w.err
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2017-2019 Sergey Kamardin <gobwas@gmail.com>
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,107 @@
|
||||||
|
# pool
|
||||||
|
|
||||||
|
[![GoDoc][godoc-image]][godoc-url]
|
||||||
|
|
||||||
|
> Tiny memory reuse helpers for Go.
|
||||||
|
|
||||||
|
## generic
|
||||||
|
|
||||||
|
Without use of subpackages, `pool` allows to reuse any struct distinguishable
|
||||||
|
by size in generic way:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "github.com/gobwas/pool"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
x, n := pool.Get(100) // Returns object with size 128 or nil.
|
||||||
|
if x == nil {
|
||||||
|
// Create x somehow with knowledge that n is 128.
|
||||||
|
}
|
||||||
|
defer pool.Put(x, n)
|
||||||
|
|
||||||
|
// Work with x.
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Pool allows you to pass specific options for constructing custom pool:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "github.com/gobwas/pool"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
p := pool.Custom(
|
||||||
|
pool.WithLogSizeMapping(), // Will ceil size n passed to Get(n) to nearest power of two.
|
||||||
|
pool.WithLogSizeRange(64, 512), // Will reuse objects in logarithmic range [64, 512].
|
||||||
|
pool.WithSize(65536), // Will reuse object with size 65536.
|
||||||
|
)
|
||||||
|
x, n := p.Get(1000) // Returns nil and 1000 because mapped size 1000 => 1024 is not reusing by the pool.
|
||||||
|
defer pool.Put(x, n) // Will not reuse x.
|
||||||
|
|
||||||
|
// Work with x.
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that there are few non-generic pooling implementations inside subpackages.
|
||||||
|
|
||||||
|
## pbytes
|
||||||
|
|
||||||
|
Subpackage `pbytes` is intended for `[]byte` reuse.
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "github.com/gobwas/pool/pbytes"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
bts := pbytes.GetCap(100) // Returns make([]byte, 0, 128).
|
||||||
|
defer pbytes.Put(bts)
|
||||||
|
|
||||||
|
// Work with bts.
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also create your own range for pooling:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "github.com/gobwas/pool/pbytes"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Reuse only slices whose capacity is 128, 256, 512 or 1024.
|
||||||
|
pool := pbytes.New(128, 1024)
|
||||||
|
|
||||||
|
bts := pool.GetCap(100) // Returns make([]byte, 0, 128).
|
||||||
|
defer pool.Put(bts)
|
||||||
|
|
||||||
|
// Work with bts.
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## pbufio
|
||||||
|
|
||||||
|
Subpackage `pbufio` is intended for `*bufio.{Reader, Writer}` reuse.
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "github.com/gobwas/pool/pbufio"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
bw := pbufio.GetWriter(os.Stdout, 100) // Returns bufio.NewWriterSize(128).
|
||||||
|
defer pbufio.PutWriter(bw)
|
||||||
|
|
||||||
|
// Work with bw.
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Like with `pbytes`, you can also create pool with custom reuse bounds.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[godoc-image]: https://godoc.org/github.com/gobwas/pool?status.svg
|
||||||
|
[godoc-url]: https://godoc.org/github.com/gobwas/pool
|
|
@ -0,0 +1,87 @@
|
||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gobwas/pool/internal/pmath"
|
||||||
|
)
|
||||||
|
|
||||||
|
var DefaultPool = New(128, 65536)
|
||||||
|
|
||||||
|
// Get pulls object whose generic size is at least of given size. It also
|
||||||
|
// returns a real size of x for further pass to Put(). It returns -1 as real
|
||||||
|
// size for nil x. Size >-1 does not mean that x is non-nil, so checks must be
|
||||||
|
// done.
|
||||||
|
//
|
||||||
|
// Note that size could be ceiled to the next power of two.
|
||||||
|
//
|
||||||
|
// Get is a wrapper around DefaultPool.Get().
|
||||||
|
func Get(size int) (interface{}, int) { return DefaultPool.Get(size) }
|
||||||
|
|
||||||
|
// Put takes x and its size for future reuse.
|
||||||
|
// Put is a wrapper around DefaultPool.Put().
|
||||||
|
func Put(x interface{}, size int) { DefaultPool.Put(x, size) }
|
||||||
|
|
||||||
|
// Pool contains logic of reusing objects distinguishable by size in generic
|
||||||
|
// way.
|
||||||
|
type Pool struct {
|
||||||
|
pool map[int]*sync.Pool
|
||||||
|
size func(int) int
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates new Pool that reuses objects which size is in logarithmic range
|
||||||
|
// [min, max].
|
||||||
|
//
|
||||||
|
// Note that it is a shortcut for Custom() constructor with Options provided by
|
||||||
|
// WithLogSizeMapping() and WithLogSizeRange(min, max) calls.
|
||||||
|
func New(min, max int) *Pool {
|
||||||
|
return Custom(
|
||||||
|
WithLogSizeMapping(),
|
||||||
|
WithLogSizeRange(min, max),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom creates new Pool with given options.
|
||||||
|
func Custom(opts ...Option) *Pool {
|
||||||
|
p := &Pool{
|
||||||
|
pool: make(map[int]*sync.Pool),
|
||||||
|
size: pmath.Identity,
|
||||||
|
}
|
||||||
|
|
||||||
|
c := (*poolConfig)(p)
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get pulls object whose generic size is at least of given size.
|
||||||
|
// It also returns a real size of x for further pass to Put() even if x is nil.
|
||||||
|
// Note that size could be ceiled to the next power of two.
|
||||||
|
func (p *Pool) Get(size int) (interface{}, int) {
|
||||||
|
n := p.size(size)
|
||||||
|
if pool := p.pool[n]; pool != nil {
|
||||||
|
return pool.Get(), n
|
||||||
|
}
|
||||||
|
return nil, size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put takes x and its size for future reuse.
|
||||||
|
func (p *Pool) Put(x interface{}, size int) {
|
||||||
|
if pool := p.pool[size]; pool != nil {
|
||||||
|
pool.Put(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type poolConfig Pool
|
||||||
|
|
||||||
|
// AddSize adds size n to the map.
|
||||||
|
func (p *poolConfig) AddSize(n int) {
|
||||||
|
p.pool[n] = new(sync.Pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSizeMapping sets up incoming size mapping function.
|
||||||
|
func (p *poolConfig) SetSizeMapping(size func(int) int) {
|
||||||
|
p.size = size
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package pmath
|
||||||
|
|
||||||
|
const (
|
||||||
|
bitsize = 32 << (^uint(0) >> 63)
|
||||||
|
maxint = int(1<<(bitsize-1) - 1)
|
||||||
|
maxintHeadBit = 1 << (bitsize - 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogarithmicRange iterates from ceiled to power of two min to max,
|
||||||
|
// calling cb on each iteration.
|
||||||
|
func LogarithmicRange(min, max int, cb func(int)) {
|
||||||
|
if min == 0 {
|
||||||
|
min = 1
|
||||||
|
}
|
||||||
|
for n := CeilToPowerOfTwo(min); n <= max; n <<= 1 {
|
||||||
|
cb(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPowerOfTwo reports whether given integer is a power of two.
|
||||||
|
func IsPowerOfTwo(n int) bool {
|
||||||
|
return n&(n-1) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identity is identity.
|
||||||
|
func Identity(n int) int {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// CeilToPowerOfTwo returns the least power of two integer value greater than
|
||||||
|
// or equal to n.
|
||||||
|
func CeilToPowerOfTwo(n int) int {
|
||||||
|
if n&maxintHeadBit != 0 && n > maxintHeadBit {
|
||||||
|
panic("argument is too large")
|
||||||
|
}
|
||||||
|
if n <= 2 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
n--
|
||||||
|
n = fillBits(n)
|
||||||
|
n++
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// FloorToPowerOfTwo returns the greatest power of two integer value less than
|
||||||
|
// or equal to n.
|
||||||
|
func FloorToPowerOfTwo(n int) int {
|
||||||
|
if n <= 2 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
n = fillBits(n)
|
||||||
|
n >>= 1
|
||||||
|
n++
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillBits(n int) int {
|
||||||
|
n |= n >> 1
|
||||||
|
n |= n >> 2
|
||||||
|
n |= n >> 4
|
||||||
|
n |= n >> 8
|
||||||
|
n |= n >> 16
|
||||||
|
n |= n >> 32
|
||||||
|
return n
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package pool
|
||||||
|
|
||||||
|
import "github.com/gobwas/pool/internal/pmath"
|
||||||
|
|
||||||
|
// Option configures pool.
|
||||||
|
type Option func(Config)
|
||||||
|
|
||||||
|
// Config describes generic pool configuration.
|
||||||
|
type Config interface {
|
||||||
|
AddSize(n int)
|
||||||
|
SetSizeMapping(func(int) int)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSizeLogRange returns an Option that will add logarithmic range of
|
||||||
|
// pooling sizes containing [min, max] values.
|
||||||
|
func WithLogSizeRange(min, max int) Option {
|
||||||
|
return func(c Config) {
|
||||||
|
pmath.LogarithmicRange(min, max, func(n int) {
|
||||||
|
c.AddSize(n)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSize returns an Option that will add given pooling size to the pool.
|
||||||
|
func WithSize(n int) Option {
|
||||||
|
return func(c Config) {
|
||||||
|
c.AddSize(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithSizeMapping(sz func(int) int) Option {
|
||||||
|
return func(c Config) {
|
||||||
|
c.SetSizeMapping(sz)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithLogSizeMapping() Option {
|
||||||
|
return WithSizeMapping(pmath.CeilToPowerOfTwo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithIdentitySizeMapping() Option {
|
||||||
|
return WithSizeMapping(pmath.Identity)
|
||||||
|
}
|
|
@ -0,0 +1,106 @@
|
||||||
|
// Package pbufio contains tools for pooling bufio.Reader and bufio.Writers.
|
||||||
|
package pbufio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/gobwas/pool"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
DefaultWriterPool = NewWriterPool(256, 65536)
|
||||||
|
DefaultReaderPool = NewReaderPool(256, 65536)
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetWriter returns bufio.Writer whose buffer has at least size bytes.
|
||||||
|
// Note that size could be ceiled to the next power of two.
|
||||||
|
// GetWriter is a wrapper around DefaultWriterPool.Get().
|
||||||
|
func GetWriter(w io.Writer, size int) *bufio.Writer { return DefaultWriterPool.Get(w, size) }
|
||||||
|
|
||||||
|
// PutWriter takes bufio.Writer for future reuse.
|
||||||
|
// It does not reuse bufio.Writer which underlying buffer size is not power of
|
||||||
|
// PutWriter is a wrapper around DefaultWriterPool.Put().
|
||||||
|
func PutWriter(bw *bufio.Writer) { DefaultWriterPool.Put(bw) }
|
||||||
|
|
||||||
|
// GetReader returns bufio.Reader whose buffer has at least size bytes. It returns
|
||||||
|
// its capacity for further pass to Put().
|
||||||
|
// Note that size could be ceiled to the next power of two.
|
||||||
|
// GetReader is a wrapper around DefaultReaderPool.Get().
|
||||||
|
func GetReader(w io.Reader, size int) *bufio.Reader { return DefaultReaderPool.Get(w, size) }
|
||||||
|
|
||||||
|
// PutReader takes bufio.Reader and its size for future reuse.
|
||||||
|
// It does not reuse bufio.Reader if size is not power of two or is out of pool
|
||||||
|
// min/max range.
|
||||||
|
// PutReader is a wrapper around DefaultReaderPool.Put().
|
||||||
|
func PutReader(bw *bufio.Reader) { DefaultReaderPool.Put(bw) }
|
||||||
|
|
||||||
|
// WriterPool contains logic of *bufio.Writer reuse with various size.
|
||||||
|
type WriterPool struct {
|
||||||
|
pool *pool.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriterPool creates new WriterPool that reuses writers which size is in
|
||||||
|
// logarithmic range [min, max].
|
||||||
|
func NewWriterPool(min, max int) *WriterPool {
|
||||||
|
return &WriterPool{pool.New(min, max)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CustomWriterPool creates new WriterPool with given options.
|
||||||
|
func CustomWriterPool(opts ...pool.Option) *WriterPool {
|
||||||
|
return &WriterPool{pool.Custom(opts...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns bufio.Writer whose buffer has at least size bytes.
|
||||||
|
func (wp *WriterPool) Get(w io.Writer, size int) *bufio.Writer {
|
||||||
|
v, n := wp.pool.Get(size)
|
||||||
|
if v != nil {
|
||||||
|
bw := v.(*bufio.Writer)
|
||||||
|
bw.Reset(w)
|
||||||
|
return bw
|
||||||
|
}
|
||||||
|
return bufio.NewWriterSize(w, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put takes ownership of bufio.Writer for further reuse.
|
||||||
|
func (wp *WriterPool) Put(bw *bufio.Writer) {
|
||||||
|
// Should reset even if we do Reset() inside Get().
|
||||||
|
// This is done to prevent locking underlying io.Writer from GC.
|
||||||
|
bw.Reset(nil)
|
||||||
|
wp.pool.Put(bw, writerSize(bw))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReaderPool contains logic of *bufio.Reader reuse with various size.
|
||||||
|
type ReaderPool struct {
|
||||||
|
pool *pool.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReaderPool creates new ReaderPool that reuses writers which size is in
|
||||||
|
// logarithmic range [min, max].
|
||||||
|
func NewReaderPool(min, max int) *ReaderPool {
|
||||||
|
return &ReaderPool{pool.New(min, max)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CustomReaderPool creates new ReaderPool with given options.
|
||||||
|
func CustomReaderPool(opts ...pool.Option) *ReaderPool {
|
||||||
|
return &ReaderPool{pool.Custom(opts...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns bufio.Reader whose buffer has at least size bytes.
|
||||||
|
func (rp *ReaderPool) Get(r io.Reader, size int) *bufio.Reader {
|
||||||
|
v, n := rp.pool.Get(size)
|
||||||
|
if v != nil {
|
||||||
|
br := v.(*bufio.Reader)
|
||||||
|
br.Reset(r)
|
||||||
|
return br
|
||||||
|
}
|
||||||
|
return bufio.NewReaderSize(r, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put takes ownership of bufio.Reader for further reuse.
|
||||||
|
func (rp *ReaderPool) Put(br *bufio.Reader) {
|
||||||
|
// Should reset even if we do Reset() inside Get().
|
||||||
|
// This is done to prevent locking underlying io.Reader from GC.
|
||||||
|
br.Reset(nil)
|
||||||
|
rp.pool.Put(br, readerSize(br))
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
// +build go1.10
|
||||||
|
|
||||||
|
package pbufio
|
||||||
|
|
||||||
|
import "bufio"
|
||||||
|
|
||||||
|
func writerSize(bw *bufio.Writer) int {
|
||||||
|
return bw.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
func readerSize(br *bufio.Reader) int {
|
||||||
|
return br.Size()
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
// +build !go1.10
|
||||||
|
|
||||||
|
package pbufio
|
||||||
|
|
||||||
|
import "bufio"
|
||||||
|
|
||||||
|
func writerSize(bw *bufio.Writer) int {
|
||||||
|
return bw.Available() + bw.Buffered()
|
||||||
|
}
|
||||||
|
|
||||||
|
// readerSize returns buffer size of the given buffered reader.
|
||||||
|
// NOTE: current workaround implementation resets underlying io.Reader.
|
||||||
|
func readerSize(br *bufio.Reader) int {
|
||||||
|
br.Reset(sizeReader)
|
||||||
|
br.ReadByte()
|
||||||
|
n := br.Buffered() + 1
|
||||||
|
br.Reset(nil)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
var sizeReader optimisticReader
|
||||||
|
|
||||||
|
type optimisticReader struct{}
|
||||||
|
|
||||||
|
func (optimisticReader) Read(p []byte) (int, error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
// Package pbytes contains tools for pooling byte pool.
|
||||||
|
// Note that by default it reuse slices with capacity from 128 to 65536 bytes.
|
||||||
|
package pbytes
|
||||||
|
|
||||||
|
// DefaultPool is used by pacakge level functions.
|
||||||
|
var DefaultPool = New(128, 65536)
|
||||||
|
|
||||||
|
// Get returns probably reused slice of bytes with at least capacity of c and
|
||||||
|
// exactly len of n.
|
||||||
|
// Get is a wrapper around DefaultPool.Get().
|
||||||
|
func Get(n, c int) []byte { return DefaultPool.Get(n, c) }
|
||||||
|
|
||||||
|
// GetCap returns probably reused slice of bytes with at least capacity of n.
|
||||||
|
// GetCap is a wrapper around DefaultPool.GetCap().
|
||||||
|
func GetCap(c int) []byte { return DefaultPool.GetCap(c) }
|
||||||
|
|
||||||
|
// GetLen returns probably reused slice of bytes with at least capacity of n
|
||||||
|
// and exactly len of n.
|
||||||
|
// GetLen is a wrapper around DefaultPool.GetLen().
|
||||||
|
func GetLen(n int) []byte { return DefaultPool.GetLen(n) }
|
||||||
|
|
||||||
|
// Put returns given slice to reuse pool.
|
||||||
|
// Put is a wrapper around DefaultPool.Put().
|
||||||
|
func Put(p []byte) { DefaultPool.Put(p) }
|
|
@ -0,0 +1,59 @@
|
||||||
|
// +build !pool_sanitize
|
||||||
|
|
||||||
|
package pbytes
|
||||||
|
|
||||||
|
import "github.com/gobwas/pool"
|
||||||
|
|
||||||
|
// Pool contains logic of reusing byte slices of various size.
|
||||||
|
type Pool struct {
|
||||||
|
pool *pool.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates new Pool that reuses slices which size is in logarithmic range
|
||||||
|
// [min, max].
|
||||||
|
//
|
||||||
|
// Note that it is a shortcut for Custom() constructor with Options provided by
|
||||||
|
// pool.WithLogSizeMapping() and pool.WithLogSizeRange(min, max) calls.
|
||||||
|
func New(min, max int) *Pool {
|
||||||
|
return &Pool{pool.New(min, max)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates new Pool with given options.
|
||||||
|
func Custom(opts ...pool.Option) *Pool {
|
||||||
|
return &Pool{pool.Custom(opts...)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns probably reused slice of bytes with at least capacity of c and
|
||||||
|
// exactly len of n.
|
||||||
|
func (p *Pool) Get(n, c int) []byte {
|
||||||
|
if n > c {
|
||||||
|
panic("requested length is greater than capacity")
|
||||||
|
}
|
||||||
|
|
||||||
|
v, x := p.pool.Get(c)
|
||||||
|
if v != nil {
|
||||||
|
bts := v.([]byte)
|
||||||
|
bts = bts[:n]
|
||||||
|
return bts
|
||||||
|
}
|
||||||
|
|
||||||
|
return make([]byte, n, x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put returns given slice to reuse pool.
|
||||||
|
// It does not reuse bytes whose size is not power of two or is out of pool
|
||||||
|
// min/max range.
|
||||||
|
func (p *Pool) Put(bts []byte) {
|
||||||
|
p.pool.Put(bts, cap(bts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCap returns probably reused slice of bytes with at least capacity of n.
|
||||||
|
func (p *Pool) GetCap(c int) []byte {
|
||||||
|
return p.Get(0, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLen returns probably reused slice of bytes with at least capacity of n
|
||||||
|
// and exactly len of n.
|
||||||
|
func (p *Pool) GetLen(n int) []byte {
|
||||||
|
return p.Get(n, n)
|
||||||
|
}
|
|
@ -0,0 +1,121 @@
|
||||||
|
// +build pool_sanitize
|
||||||
|
|
||||||
|
package pbytes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const magic = uint64(0x777742)
|
||||||
|
|
||||||
|
type guard struct {
|
||||||
|
magic uint64
|
||||||
|
size int
|
||||||
|
owners int32
|
||||||
|
}
|
||||||
|
|
||||||
|
const guardSize = int(unsafe.Sizeof(guard{}))
|
||||||
|
|
||||||
|
type Pool struct {
|
||||||
|
min, max int
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(min, max int) *Pool {
|
||||||
|
return &Pool{min, max}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns probably reused slice of bytes with at least capacity of c and
|
||||||
|
// exactly len of n.
|
||||||
|
func (p *Pool) Get(n, c int) []byte {
|
||||||
|
if n > c {
|
||||||
|
panic("requested length is greater than capacity")
|
||||||
|
}
|
||||||
|
|
||||||
|
pageSize := syscall.Getpagesize()
|
||||||
|
pages := (c+guardSize)/pageSize + 1
|
||||||
|
size := pages * pageSize
|
||||||
|
|
||||||
|
bts := alloc(size)
|
||||||
|
|
||||||
|
g := (*guard)(unsafe.Pointer(&bts[0]))
|
||||||
|
*g = guard{
|
||||||
|
magic: magic,
|
||||||
|
size: size,
|
||||||
|
owners: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
return bts[guardSize : guardSize+n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pool) GetCap(c int) []byte { return p.Get(0, c) }
|
||||||
|
func (p *Pool) GetLen(n int) []byte { return Get(n, n) }
|
||||||
|
|
||||||
|
// Put returns given slice to reuse pool.
|
||||||
|
func (p *Pool) Put(bts []byte) {
|
||||||
|
hdr := *(*reflect.SliceHeader)(unsafe.Pointer(&bts))
|
||||||
|
ptr := hdr.Data - uintptr(guardSize)
|
||||||
|
|
||||||
|
g := (*guard)(unsafe.Pointer(ptr))
|
||||||
|
if g.magic != magic {
|
||||||
|
panic("unknown slice returned to the pool")
|
||||||
|
}
|
||||||
|
if n := atomic.AddInt32(&g.owners, -1); n < 0 {
|
||||||
|
panic("multiple Put() detected")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable read and write on bytes memory pages. This will cause panic on
|
||||||
|
// incorrect access to returned slice.
|
||||||
|
mprotect(ptr, false, false, g.size)
|
||||||
|
|
||||||
|
runtime.SetFinalizer(&bts, func(b *[]byte) {
|
||||||
|
mprotect(ptr, true, true, g.size)
|
||||||
|
free(*(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
|
||||||
|
Data: ptr,
|
||||||
|
Len: g.size,
|
||||||
|
Cap: g.size,
|
||||||
|
})))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func alloc(n int) []byte {
|
||||||
|
b, err := unix.Mmap(-1, 0, n, unix.PROT_READ|unix.PROT_WRITE|unix.PROT_EXEC, unix.MAP_SHARED|unix.MAP_ANONYMOUS)
|
||||||
|
if err != nil {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func free(b []byte) {
|
||||||
|
if err := unix.Munmap(b); err != nil {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mprotect(ptr uintptr, r, w bool, size int) {
|
||||||
|
// Need to avoid "EINVAL addr is not a valid pointer,
|
||||||
|
// or not a multiple of PAGESIZE."
|
||||||
|
start := ptr & ^(uintptr(syscall.Getpagesize() - 1))
|
||||||
|
|
||||||
|
prot := uintptr(syscall.PROT_EXEC)
|
||||||
|
switch {
|
||||||
|
case r && w:
|
||||||
|
prot |= syscall.PROT_READ | syscall.PROT_WRITE
|
||||||
|
case r:
|
||||||
|
prot |= syscall.PROT_READ
|
||||||
|
case w:
|
||||||
|
prot |= syscall.PROT_WRITE
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := syscall.Syscall(syscall.SYS_MPROTECT,
|
||||||
|
start, uintptr(size), prot,
|
||||||
|
)
|
||||||
|
if err != 0 {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
// Package pool contains helpers for pooling structures distinguishable by
|
||||||
|
// size.
|
||||||
|
//
|
||||||
|
// Quick example:
|
||||||
|
//
|
||||||
|
// import "github.com/gobwas/pool"
|
||||||
|
//
|
||||||
|
// func main() {
|
||||||
|
// // Reuse objects in logarithmic range from 0 to 64 (0,1,2,4,6,8,16,32,64).
|
||||||
|
// p := pool.New(0, 64)
|
||||||
|
//
|
||||||
|
// buf, n := p.Get(10) // Returns buffer with 16 capacity.
|
||||||
|
// if buf == nil {
|
||||||
|
// buf = bytes.NewBuffer(make([]byte, n))
|
||||||
|
// }
|
||||||
|
// defer p.Put(buf, n)
|
||||||
|
//
|
||||||
|
// // Work with buf.
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// There are non-generic implementations for pooling:
|
||||||
|
// - pool/pbytes for []byte reuse;
|
||||||
|
// - pool/pbufio for *bufio.Reader and *bufio.Writer reuse;
|
||||||
|
//
|
||||||
|
package pool
|
|
@ -0,0 +1,5 @@
|
||||||
|
bin/
|
||||||
|
reports/
|
||||||
|
cpu.out
|
||||||
|
mem.out
|
||||||
|
ws.test
|
|
@ -0,0 +1,25 @@
|
||||||
|
sudo: required
|
||||||
|
|
||||||
|
language: go
|
||||||
|
|
||||||
|
services:
|
||||||
|
- docker
|
||||||
|
|
||||||
|
os:
|
||||||
|
- linux
|
||||||
|
- windows
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.8.x
|
||||||
|
- 1.9.x
|
||||||
|
- 1.10.x
|
||||||
|
- 1.11.x
|
||||||
|
- 1.x
|
||||||
|
|
||||||
|
install:
|
||||||
|
- go get github.com/gobwas/pool
|
||||||
|
- go get github.com/gobwas/httphead
|
||||||
|
|
||||||
|
script:
|
||||||
|
- if [ "$TRAVIS_OS_NAME" = "windows" ]; then go test ./...; fi
|
||||||
|
- if [ "$TRAVIS_OS_NAME" = "linux" ]; then make test autobahn; fi
|
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2017-2018 Sergey Kamardin <gobwas@gmail.com>
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
|
@ -0,0 +1,47 @@
|
||||||
|
BENCH ?=.
|
||||||
|
BENCH_BASE?=master
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f bin/reporter
|
||||||
|
rm -fr autobahn/report/*
|
||||||
|
|
||||||
|
bin/reporter:
|
||||||
|
go build -o bin/reporter ./autobahn
|
||||||
|
|
||||||
|
bin/gocovmerge:
|
||||||
|
go build -o bin/gocovmerge github.com/wadey/gocovmerge
|
||||||
|
|
||||||
|
.PHONY: autobahn
|
||||||
|
autobahn: clean bin/reporter
|
||||||
|
./autobahn/script/test.sh --build
|
||||||
|
bin/reporter $(PWD)/autobahn/report/index.json
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test -coverprofile=ws.coverage .
|
||||||
|
go test -coverprofile=wsutil.coverage ./wsutil
|
||||||
|
|
||||||
|
cover: bin/gocovmerge test autobahn
|
||||||
|
bin/gocovmerge ws.coverage wsutil.coverage autobahn/report/server.coverage > total.coverage
|
||||||
|
|
||||||
|
benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD)
|
||||||
|
benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX)
|
||||||
|
benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX)
|
||||||
|
benchcmp:
|
||||||
|
if [ ! -z "$(shell git status -s)" ]; then\
|
||||||
|
echo "could not compare with $(BENCH_BASE) – found unstaged changes";\
|
||||||
|
exit 1;\
|
||||||
|
fi;\
|
||||||
|
if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\
|
||||||
|
echo "comparing the same branches";\
|
||||||
|
exit 1;\
|
||||||
|
fi;\
|
||||||
|
echo "benchmarking $(BENCH_BRANCH)...";\
|
||||||
|
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\
|
||||||
|
echo "benchmarking $(BENCH_BASE)...";\
|
||||||
|
git checkout -q $(BENCH_BASE);\
|
||||||
|
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\
|
||||||
|
git checkout -q $(BENCH_BRANCH);\
|
||||||
|
echo "\nresults:";\
|
||||||
|
echo "========\n";\
|
||||||
|
benchcmp $(BENCH_OLD) $(BENCH_NEW);\
|
||||||
|
|
|
@ -0,0 +1,360 @@
|
||||||
|
# ws
|
||||||
|
|
||||||
|
[![GoDoc][godoc-image]][godoc-url]
|
||||||
|
[![Travis][travis-image]][travis-url]
|
||||||
|
|
||||||
|
> [RFC6455][rfc-url] WebSocket implementation in Go.
|
||||||
|
|
||||||
|
# Features
|
||||||
|
|
||||||
|
- Zero-copy upgrade
|
||||||
|
- No intermediate allocations during I/O
|
||||||
|
- Low-level API which allows to build your own logic of packet handling and
|
||||||
|
buffers reuse
|
||||||
|
- High-level wrappers and helpers around API in `wsutil` package, which allow
|
||||||
|
to start fast without digging the protocol internals
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
|
||||||
|
[GoDoc][godoc-url].
|
||||||
|
|
||||||
|
# Why
|
||||||
|
|
||||||
|
Existing WebSocket implementations do not allow users to reuse I/O buffers
|
||||||
|
between connections in clear way. This library aims to export efficient
|
||||||
|
low-level interface for working with the protocol without forcing only one way
|
||||||
|
it could be used.
|
||||||
|
|
||||||
|
By the way, if you want get the higher-level tools, you can use `wsutil`
|
||||||
|
package.
|
||||||
|
|
||||||
|
# Status
|
||||||
|
|
||||||
|
Library is tagged as `v1*` so its API must not be broken during some
|
||||||
|
improvements or refactoring.
|
||||||
|
|
||||||
|
This implementation of RFC6455 passes [Autobahn Test
|
||||||
|
Suite](https://github.com/crossbario/autobahn-testsuite) and currently has
|
||||||
|
about 78% coverage.
|
||||||
|
|
||||||
|
# Examples
|
||||||
|
|
||||||
|
Example applications using `ws` are developed in separate repository
|
||||||
|
[ws-examples](https://github.com/gobwas/ws-examples).
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
|
||||||
|
The higher-level example of WebSocket echo server:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, _, _, err := ws.UpgradeHTTP(r, w)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, op, err := wsutil.ReadClientData(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
err = wsutil.WriteServerMessage(conn, op, msg)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Lower-level, but still high-level example:
|
||||||
|
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, _, _, err := ws.UpgradeHTTP(r, w)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
var (
|
||||||
|
state = ws.StateServerSide
|
||||||
|
reader = wsutil.NewReader(conn, state)
|
||||||
|
writer = wsutil.NewWriter(conn, state, ws.OpText)
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
header, err := reader.NextFrame()
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset writer to write frame with right operation code.
|
||||||
|
writer.Reset(conn, state, header.OpCode)
|
||||||
|
|
||||||
|
if _, err = io.Copy(writer, reader); err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
if err = writer.Flush(); err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
We can apply the same pattern to read and write structured responses through a JSON encoder and decoder.:
|
||||||
|
|
||||||
|
```go
|
||||||
|
...
|
||||||
|
var (
|
||||||
|
r = wsutil.NewReader(conn, ws.StateServerSide)
|
||||||
|
w = wsutil.NewWriter(conn, ws.StateServerSide, ws.OpText)
|
||||||
|
decoder = json.NewDecoder(r)
|
||||||
|
encoder = json.NewEncoder(w)
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
hdr, err = r.NextFrame()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if hdr.OpCode == ws.OpClose {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
var req Request
|
||||||
|
if err := decoder.Decode(&req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var resp Response
|
||||||
|
if err := encoder.Encode(&resp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = w.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
The lower-level example without `wsutil`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ln, err := net.Listen("tcp", "localhost:8080")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
_, err = ws.Upgrade(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
header, err := ws.ReadHeader(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, header.Length)
|
||||||
|
_, err = io.ReadFull(conn, payload)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
if header.Masked {
|
||||||
|
ws.Cipher(payload, header.Mask, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the Masked flag, server frames must not be masked as
|
||||||
|
// RFC6455 says.
|
||||||
|
header.Masked = false
|
||||||
|
|
||||||
|
if err := ws.WriteHeader(conn, header); err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
if _, err := conn.Write(payload); err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.OpCode == ws.OpClose {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
# Zero-copy upgrade
|
||||||
|
|
||||||
|
Zero-copy upgrade helps to avoid unnecessary allocations and copying while
|
||||||
|
handling HTTP Upgrade request.
|
||||||
|
|
||||||
|
Processing of all non-websocket headers is made in place with use of registered
|
||||||
|
user callbacks whose arguments are only valid until callback returns.
|
||||||
|
|
||||||
|
The simple example looks like this:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ln, err := net.Listen("tcp", "localhost:8080")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
u := ws.Upgrader{
|
||||||
|
OnHeader: func(key, value []byte) (err error) {
|
||||||
|
log.Printf("non-websocket header: %q=%q", key, value)
|
||||||
|
return
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = u.Upgrade(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Usage of `ws.Upgrader` here brings ability to control incoming connections on
|
||||||
|
tcp level and simply not to accept them by some logic.
|
||||||
|
|
||||||
|
Zero-copy upgrade is for high-load services which have to control many
|
||||||
|
resources such as connections buffers.
|
||||||
|
|
||||||
|
The real life example could be like this:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/gobwas/httphead"
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ln, err := net.Listen("tcp", "localhost:8080")
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare handshake header writer from http.Header mapping.
|
||||||
|
header := ws.HandshakeHeaderHTTP(http.Header{
|
||||||
|
"X-Go-Version": []string{runtime.Version()},
|
||||||
|
})
|
||||||
|
|
||||||
|
u := ws.Upgrader{
|
||||||
|
OnHost: func(host []byte) error {
|
||||||
|
if string(host) == "github.com" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ws.RejectConnectionError(
|
||||||
|
ws.RejectionStatus(403),
|
||||||
|
ws.RejectionHeader(ws.HandshakeHeaderString(
|
||||||
|
"X-Want-Host: github.com\r\n",
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
OnHeader: func(key, value []byte) error {
|
||||||
|
if string(key) != "Cookie" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ok := httphead.ScanCookie(value, func(key, value []byte) bool {
|
||||||
|
// Check session here or do some other stuff with cookies.
|
||||||
|
// Maybe copy some values for future use.
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ws.RejectConnectionError(
|
||||||
|
ws.RejectionReason("bad cookie"),
|
||||||
|
ws.RejectionStatus(400),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
OnBeforeUpgrade: func() (ws.HandshakeHeader, error) {
|
||||||
|
return header, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = u.Upgrade(conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("upgrade error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[rfc-url]: https://tools.ietf.org/html/rfc6455
|
||||||
|
[godoc-image]: https://godoc.org/github.com/gobwas/ws?status.svg
|
||||||
|
[godoc-url]: https://godoc.org/github.com/gobwas/ws
|
||||||
|
[travis-image]: https://travis-ci.org/gobwas/ws.svg?branch=master
|
||||||
|
[travis-url]: https://travis-ci.org/gobwas/ws
|
|
@ -0,0 +1,145 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import "unicode/utf8"
|
||||||
|
|
||||||
|
// State represents state of websocket endpoint.
|
||||||
|
// It used by some functions to be more strict when checking compatibility with RFC6455.
|
||||||
|
type State uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// StateServerSide means that endpoint (caller) is a server.
|
||||||
|
StateServerSide State = 0x1 << iota
|
||||||
|
// StateClientSide means that endpoint (caller) is a client.
|
||||||
|
StateClientSide
|
||||||
|
// StateExtended means that extension was negotiated during handshake.
|
||||||
|
StateExtended
|
||||||
|
// StateFragmented means that endpoint (caller) has received fragmented
|
||||||
|
// frame and waits for continuation parts.
|
||||||
|
StateFragmented
|
||||||
|
)
|
||||||
|
|
||||||
|
// Is checks whether the s has v enabled.
|
||||||
|
func (s State) Is(v State) bool {
|
||||||
|
return uint8(s)&uint8(v) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set enables v state on s.
|
||||||
|
func (s State) Set(v State) State {
|
||||||
|
return s | v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear disables v state on s.
|
||||||
|
func (s State) Clear(v State) State {
|
||||||
|
return s & (^v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerSide reports whether states represents server side.
|
||||||
|
func (s State) ServerSide() bool { return s.Is(StateServerSide) }
|
||||||
|
|
||||||
|
// ClientSide reports whether state represents client side.
|
||||||
|
func (s State) ClientSide() bool { return s.Is(StateClientSide) }
|
||||||
|
|
||||||
|
// Extended reports whether state is extended.
|
||||||
|
func (s State) Extended() bool { return s.Is(StateExtended) }
|
||||||
|
|
||||||
|
// Fragmented reports whether state is fragmented.
|
||||||
|
func (s State) Fragmented() bool { return s.Is(StateFragmented) }
|
||||||
|
|
||||||
|
// ProtocolError describes error during checking/parsing websocket frames or
|
||||||
|
// headers.
|
||||||
|
type ProtocolError string
|
||||||
|
|
||||||
|
// Error implements error interface.
|
||||||
|
func (p ProtocolError) Error() string { return string(p) }
|
||||||
|
|
||||||
|
// Errors used by the protocol checkers.
|
||||||
|
var (
|
||||||
|
ErrProtocolOpCodeReserved = ProtocolError("use of reserved op code")
|
||||||
|
ErrProtocolControlPayloadOverflow = ProtocolError("control frame payload limit exceeded")
|
||||||
|
ErrProtocolControlNotFinal = ProtocolError("control frame is not final")
|
||||||
|
ErrProtocolNonZeroRsv = ProtocolError("non-zero rsv bits with no extension negotiated")
|
||||||
|
ErrProtocolMaskRequired = ProtocolError("frames from client to server must be masked")
|
||||||
|
ErrProtocolMaskUnexpected = ProtocolError("frames from server to client must be not masked")
|
||||||
|
ErrProtocolContinuationExpected = ProtocolError("unexpected non-continuation data frame")
|
||||||
|
ErrProtocolContinuationUnexpected = ProtocolError("unexpected continuation data frame")
|
||||||
|
ErrProtocolStatusCodeNotInUse = ProtocolError("status code is not in use")
|
||||||
|
ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level")
|
||||||
|
ErrProtocolStatusCodeNoMeaning = ProtocolError("status code has no meaning yet")
|
||||||
|
ErrProtocolStatusCodeUnknown = ProtocolError("status code is not defined in spec")
|
||||||
|
ErrProtocolInvalidUTF8 = ProtocolError("invalid utf8 sequence in close reason")
|
||||||
|
)
|
||||||
|
|
||||||
|
// CheckHeader checks h to contain valid header data for given state s.
|
||||||
|
//
|
||||||
|
// Note that zero state (0) means that state is clean,
|
||||||
|
// neither server or client side, nor fragmented, nor extended.
|
||||||
|
func CheckHeader(h Header, s State) error {
|
||||||
|
if h.OpCode.IsReserved() {
|
||||||
|
return ErrProtocolOpCodeReserved
|
||||||
|
}
|
||||||
|
if h.OpCode.IsControl() {
|
||||||
|
if h.Length > MaxControlFramePayloadSize {
|
||||||
|
return ErrProtocolControlPayloadOverflow
|
||||||
|
}
|
||||||
|
if !h.Fin {
|
||||||
|
return ErrProtocolControlNotFinal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
// [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
|
||||||
|
// non-zero values. If a nonzero value is received and none of the
|
||||||
|
// negotiated extensions defines the meaning of such a nonzero value, the
|
||||||
|
// receiving endpoint MUST _Fail the WebSocket Connection_.
|
||||||
|
case h.Rsv != 0 && !s.Extended():
|
||||||
|
return ErrProtocolNonZeroRsv
|
||||||
|
|
||||||
|
// [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
|
||||||
|
// In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
|
||||||
|
// as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
|
||||||
|
// A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
|
||||||
|
// status code 1002 (protocol error) as defined in Section 7.4.1.
|
||||||
|
case s.ServerSide() && !h.Masked:
|
||||||
|
return ErrProtocolMaskRequired
|
||||||
|
case s.ClientSide() && h.Masked:
|
||||||
|
return ErrProtocolMaskUnexpected
|
||||||
|
|
||||||
|
// [RFC6455]: See detailed explanation in 5.4 section.
|
||||||
|
case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
|
||||||
|
return ErrProtocolContinuationExpected
|
||||||
|
case !s.Fragmented() && h.OpCode == OpContinuation:
|
||||||
|
return ErrProtocolContinuationUnexpected
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckCloseFrameData checks received close information
|
||||||
|
// to be valid RFC6455 compatible close info.
|
||||||
|
//
|
||||||
|
// Note that code.Empty() or code.IsAppLevel() will raise error.
|
||||||
|
//
|
||||||
|
// If endpoint sends close frame without status code (with frame.Length = 0),
|
||||||
|
// application should not check its payload.
|
||||||
|
func CheckCloseFrameData(code StatusCode, reason string) error {
|
||||||
|
switch {
|
||||||
|
case code.IsNotUsed():
|
||||||
|
return ErrProtocolStatusCodeNotInUse
|
||||||
|
|
||||||
|
case code.IsProtocolReserved():
|
||||||
|
return ErrProtocolStatusCodeApplicationLevel
|
||||||
|
|
||||||
|
case code == StatusNoMeaningYet:
|
||||||
|
return ErrProtocolStatusCodeNoMeaning
|
||||||
|
|
||||||
|
case code.IsProtocolSpec() && !code.IsProtocolDefined():
|
||||||
|
return ErrProtocolStatusCodeUnknown
|
||||||
|
|
||||||
|
case !utf8.ValidString(reason):
|
||||||
|
return ErrProtocolInvalidUTF8
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cipher applies XOR cipher to the payload using mask.
|
||||||
|
// Offset is used to cipher chunked data (e.g. in io.Reader implementations).
|
||||||
|
//
|
||||||
|
// To convert masked data into unmasked data, or vice versa, the following
|
||||||
|
// algorithm is applied. The same algorithm applies regardless of the
|
||||||
|
// direction of the translation, e.g., the same steps are applied to
|
||||||
|
// mask the data as to unmask the data.
|
||||||
|
func Cipher(payload []byte, mask [4]byte, offset int) {
|
||||||
|
n := len(payload)
|
||||||
|
if n < 8 {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
payload[i] ^= mask[(offset+i)%4]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate position in mask due to previously processed bytes number.
|
||||||
|
mpos := offset % 4
|
||||||
|
// Count number of bytes will processed one by one from the beginning of payload.
|
||||||
|
ln := remain[mpos]
|
||||||
|
// Count number of bytes will processed one by one from the end of payload.
|
||||||
|
// This is done to process payload by 8 bytes in each iteration of main loop.
|
||||||
|
rn := (n - ln) % 8
|
||||||
|
|
||||||
|
for i := 0; i < ln; i++ {
|
||||||
|
payload[i] ^= mask[(mpos+i)%4]
|
||||||
|
}
|
||||||
|
for i := n - rn; i < n; i++ {
|
||||||
|
payload[i] ^= mask[(mpos+i)%4]
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: we use here binary.LittleEndian regardless of what is real
|
||||||
|
// endianess on machine is. To do so, we have to use binary.LittleEndian in
|
||||||
|
// the masking loop below as well.
|
||||||
|
var (
|
||||||
|
m = binary.LittleEndian.Uint32(mask[:])
|
||||||
|
m2 = uint64(m)<<32 | uint64(m)
|
||||||
|
)
|
||||||
|
// Skip already processed right part.
|
||||||
|
// Get number of uint64 parts remaining to process.
|
||||||
|
n = (n - ln - rn) >> 3
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
var (
|
||||||
|
j = ln + (i << 3)
|
||||||
|
chunk = payload[j : j+8]
|
||||||
|
)
|
||||||
|
p := binary.LittleEndian.Uint64(chunk)
|
||||||
|
p = p ^ m2
|
||||||
|
binary.LittleEndian.PutUint64(chunk, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remain maps position in masking key [0,4) to number
|
||||||
|
// of bytes that need to be processed manually inside Cipher().
|
||||||
|
var remain = [4]int{0, 3, 2, 1}
|
|
@ -0,0 +1,556 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gobwas/httphead"
|
||||||
|
"github.com/gobwas/pool/pbufio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Constants used by Dialer.
|
||||||
|
const (
|
||||||
|
DefaultClientReadBufferSize = 4096
|
||||||
|
DefaultClientWriteBufferSize = 4096
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handshake represents handshake result.
|
||||||
|
type Handshake struct {
|
||||||
|
// Protocol is the subprotocol selected during handshake.
|
||||||
|
Protocol string
|
||||||
|
|
||||||
|
// Extensions is the list of negotiated extensions.
|
||||||
|
Extensions []httphead.Option
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errors used by the websocket client.
|
||||||
|
var (
|
||||||
|
ErrHandshakeBadStatus = fmt.Errorf("unexpected http status")
|
||||||
|
ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
|
||||||
|
ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultDialer is dialer that holds no options and is used by Dial function.
|
||||||
|
var DefaultDialer Dialer
|
||||||
|
|
||||||
|
// Dial is like Dialer{}.Dial().
|
||||||
|
func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
|
||||||
|
return DefaultDialer.Dial(ctx, urlstr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dialer contains options for establishing websocket connection to an url.
|
||||||
|
type Dialer struct {
|
||||||
|
// ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
|
||||||
|
// They used to read and write http data while upgrading to WebSocket.
|
||||||
|
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
|
||||||
|
//
|
||||||
|
// If a size is zero then default value is used.
|
||||||
|
ReadBufferSize, WriteBufferSize int
|
||||||
|
|
||||||
|
// Timeout is the maximum amount of time a Dial() will wait for a connect
|
||||||
|
// and an handshake to complete.
|
||||||
|
//
|
||||||
|
// The default is no timeout.
|
||||||
|
Timeout time.Duration
|
||||||
|
|
||||||
|
// Protocols is the list of subprotocols that the client wants to speak,
|
||||||
|
// ordered by preference.
|
||||||
|
//
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-4.1
|
||||||
|
Protocols []string
|
||||||
|
|
||||||
|
// Extensions is the list of extensions that client wants to speak.
|
||||||
|
//
|
||||||
|
// Note that if server decides to use some of this extensions, Dial() will
|
||||||
|
// return Handshake struct containing a slice of items, which are the
|
||||||
|
// shallow copies of the items from this list. That is, internals of
|
||||||
|
// Extensions items are shared during Dial().
|
||||||
|
//
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-4.1
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-9.1
|
||||||
|
Extensions []httphead.Option
|
||||||
|
|
||||||
|
// Header is an optional HandshakeHeader instance that could be used to
|
||||||
|
// write additional headers to the handshake request.
|
||||||
|
//
|
||||||
|
// It used instead of any key-value mappings to avoid allocations in user
|
||||||
|
// land.
|
||||||
|
Header HandshakeHeader
|
||||||
|
|
||||||
|
// OnStatusError is the callback that will be called after receiving non
|
||||||
|
// "101 Continue" HTTP response status. It receives an io.Reader object
|
||||||
|
// representing server response bytes. That is, it gives ability to parse
|
||||||
|
// HTTP response somehow (probably with http.ReadResponse call) and make a
|
||||||
|
// decision of further logic.
|
||||||
|
//
|
||||||
|
// The arguments are only valid until the callback returns.
|
||||||
|
OnStatusError func(status int, reason []byte, resp io.Reader)
|
||||||
|
|
||||||
|
// OnHeader is the callback that will be called after successful parsing of
|
||||||
|
// header, that is not used during WebSocket handshake procedure. That is,
|
||||||
|
// it will be called with non-websocket headers, which could be relevant
|
||||||
|
// for application-level logic.
|
||||||
|
//
|
||||||
|
// The arguments are only valid until the callback returns.
|
||||||
|
//
|
||||||
|
// Returned value could be used to prevent processing response.
|
||||||
|
OnHeader func(key, value []byte) (err error)
|
||||||
|
|
||||||
|
// NetDial is the function that is used to get plain tcp connection.
|
||||||
|
// If it is not nil, then it is used instead of net.Dialer.
|
||||||
|
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
|
// TLSClient is the callback that will be called after successful dial with
|
||||||
|
// received connection and its remote host name. If it is nil, then the
|
||||||
|
// default tls.Client() will be used.
|
||||||
|
// If it is not nil, then TLSConfig field is ignored.
|
||||||
|
TLSClient func(conn net.Conn, hostname string) net.Conn
|
||||||
|
|
||||||
|
// TLSConfig is passed to tls.Client() to start TLS over established
|
||||||
|
// connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
|
||||||
|
// non-nil and its ServerName is empty, then for every Dial() it will be
|
||||||
|
// cloned and appropriate ServerName will be set.
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
|
||||||
|
// WrapConn is the optional callback that will be called when connection is
|
||||||
|
// ready for an i/o. That is, it will be called after successful dial and
|
||||||
|
// TLS initialization (for "wss" schemes). It may be helpful for different
|
||||||
|
// user land purposes such as end to end encryption.
|
||||||
|
//
|
||||||
|
// Note that for debugging purposes of an http handshake (e.g. sent request
|
||||||
|
// and received response), there is an wsutil.DebugDialer struct.
|
||||||
|
WrapConn func(conn net.Conn) net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the url host and upgrades connection to WebSocket.
|
||||||
|
//
|
||||||
|
// If server has sent frames right after successful handshake then returned
|
||||||
|
// buffer will be non-nil. In other cases buffer is always nil. For better
|
||||||
|
// memory efficiency received non-nil bufio.Reader should be returned to the
|
||||||
|
// inner pool with PutReader() function after use.
|
||||||
|
//
|
||||||
|
// Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
|
||||||
|
// If you want to dial non-ascii host name, take care of its name serialization
|
||||||
|
// avoiding bad request issues. For more info see net/http Request.Write()
|
||||||
|
// implementation, especially cleanHost() function.
|
||||||
|
func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
|
||||||
|
u, err := url.ParseRequestURI(urlstr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare context to dial with. Initially it is the same as original, but
|
||||||
|
// if d.Timeout is non-zero and points to time that is before ctx.Deadline,
|
||||||
|
// we use more shorter context for dial.
|
||||||
|
dialctx := ctx
|
||||||
|
|
||||||
|
var deadline time.Time
|
||||||
|
if t := d.Timeout; t != 0 {
|
||||||
|
deadline = time.Now().Add(t)
|
||||||
|
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
dialctx, cancel = context.WithDeadline(ctx, deadline)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if conn, err = d.dial(dialctx, u); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if ctx == context.Background() {
|
||||||
|
// No need to start I/O interrupter goroutine which is not zero-cost.
|
||||||
|
conn.SetDeadline(deadline)
|
||||||
|
defer conn.SetDeadline(noDeadline)
|
||||||
|
} else {
|
||||||
|
// Context could be canceled or its deadline could be exceeded.
|
||||||
|
// Start the interrupter goroutine to handle context cancelation.
|
||||||
|
done := setupContextDeadliner(ctx, conn)
|
||||||
|
defer func() {
|
||||||
|
// Map Upgrade() error to a possible context expiration error. That
|
||||||
|
// is, even if Upgrade() err is nil, context could be already
|
||||||
|
// expired and connection be "poisoned" by SetDeadline() call.
|
||||||
|
// In that case we must not return ctx.Err() error.
|
||||||
|
done(&err)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
br, hs, err = d.Upgrade(conn, u)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
|
||||||
|
// Dialer.NetDial is not provided.
|
||||||
|
netEmptyDialer net.Dialer
|
||||||
|
// tlsEmptyConfig is an empty tls.Config used as default one.
|
||||||
|
tlsEmptyConfig tls.Config
|
||||||
|
)
|
||||||
|
|
||||||
|
func tlsDefaultConfig() *tls.Config {
|
||||||
|
return &tlsEmptyConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func hostport(host string, defaultPort string) (hostname, addr string) {
|
||||||
|
var (
|
||||||
|
colon = strings.LastIndexByte(host, ':')
|
||||||
|
bracket = strings.IndexByte(host, ']')
|
||||||
|
)
|
||||||
|
if colon > bracket {
|
||||||
|
return host[:colon], host
|
||||||
|
}
|
||||||
|
return host, host + defaultPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
|
||||||
|
dial := d.NetDial
|
||||||
|
if dial == nil {
|
||||||
|
dial = netEmptyDialer.DialContext
|
||||||
|
}
|
||||||
|
switch u.Scheme {
|
||||||
|
case "ws":
|
||||||
|
_, addr := hostport(u.Host, ":80")
|
||||||
|
conn, err = dial(ctx, "tcp", addr)
|
||||||
|
case "wss":
|
||||||
|
hostname, addr := hostport(u.Host, ":443")
|
||||||
|
conn, err = dial(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tlsClient := d.TLSClient
|
||||||
|
if tlsClient == nil {
|
||||||
|
tlsClient = d.tlsClient
|
||||||
|
}
|
||||||
|
conn = tlsClient(conn, hostname)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
|
||||||
|
}
|
||||||
|
if wrap := d.WrapConn; wrap != nil {
|
||||||
|
conn = wrap(conn)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
|
||||||
|
config := d.TLSConfig
|
||||||
|
if config == nil {
|
||||||
|
config = tlsDefaultConfig()
|
||||||
|
}
|
||||||
|
if config.ServerName == "" {
|
||||||
|
config = tlsCloneConfig(config)
|
||||||
|
config.ServerName = hostname
|
||||||
|
}
|
||||||
|
// Do not make conn.Handshake() here because downstairs we will prepare
|
||||||
|
// i/o on this conn with proper context's timeout handling.
|
||||||
|
return tls.Client(conn, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// This variables are set like in net/net.go.
|
||||||
|
// noDeadline is just zero value for readability.
|
||||||
|
noDeadline = time.Time{}
|
||||||
|
// aLongTimeAgo is a non-zero time, far in the past, used for immediate
|
||||||
|
// cancelation of dials.
|
||||||
|
aLongTimeAgo = time.Unix(42, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Upgrade writes an upgrade request to the given io.ReadWriter conn at given
|
||||||
|
// url u and reads a response from it.
|
||||||
|
//
|
||||||
|
// It is a caller responsibility to manage I/O deadlines on conn.
|
||||||
|
//
|
||||||
|
// It returns handshake info and some bytes which could be written by the peer
|
||||||
|
// right after response and be caught by us during buffered read.
|
||||||
|
func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
|
||||||
|
// headerSeen constants helps to report whether or not some header was seen
|
||||||
|
// during reading request bytes.
|
||||||
|
const (
|
||||||
|
headerSeenUpgrade = 1 << iota
|
||||||
|
headerSeenConnection
|
||||||
|
headerSeenSecAccept
|
||||||
|
|
||||||
|
// headerSeenAll is the value that we expect to receive at the end of
|
||||||
|
// headers read/parse loop.
|
||||||
|
headerSeenAll = 0 |
|
||||||
|
headerSeenUpgrade |
|
||||||
|
headerSeenConnection |
|
||||||
|
headerSeenSecAccept
|
||||||
|
)
|
||||||
|
|
||||||
|
br = pbufio.GetReader(conn,
|
||||||
|
nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
|
||||||
|
)
|
||||||
|
bw := pbufio.GetWriter(conn,
|
||||||
|
nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
pbufio.PutWriter(bw)
|
||||||
|
if br.Buffered() == 0 || err != nil {
|
||||||
|
// Server does not wrote additional bytes to the connection or
|
||||||
|
// error occurred. That is, no reason to return buffer.
|
||||||
|
pbufio.PutReader(br)
|
||||||
|
br = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
nonce := make([]byte, nonceSize)
|
||||||
|
initNonce(nonce)
|
||||||
|
|
||||||
|
httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
|
||||||
|
if err = bw.Flush(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
|
||||||
|
sl, err := readLine(br)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Begin validation of the response.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-4.2.2
|
||||||
|
// Parse request line data like HTTP version, uri and method.
|
||||||
|
resp, err := httpParseResponseLine(sl)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Even if RFC says "1.1 or higher" without mentioning the part of the
|
||||||
|
// version, we apply it only to minor part.
|
||||||
|
if resp.major != 1 || resp.minor < 1 {
|
||||||
|
err = ErrHandshakeBadProtocol
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp.status != 101 {
|
||||||
|
err = StatusError(resp.status)
|
||||||
|
if onStatusError := d.OnStatusError; onStatusError != nil {
|
||||||
|
// Invoke callback with multireader of status-line bytes br.
|
||||||
|
onStatusError(resp.status, resp.reason,
|
||||||
|
io.MultiReader(
|
||||||
|
bytes.NewReader(sl),
|
||||||
|
strings.NewReader(crlf),
|
||||||
|
br,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// If response status is 101 then we expect all technical headers to be
|
||||||
|
// valid. If not, then we stop processing response without giving user
|
||||||
|
// ability to read non-technical headers. That is, we do not distinguish
|
||||||
|
// technical errors (such as parsing error) and protocol errors.
|
||||||
|
var headerSeen byte
|
||||||
|
for {
|
||||||
|
line, e := readLine(br)
|
||||||
|
if e != nil {
|
||||||
|
err = e
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(line) == 0 {
|
||||||
|
// Blank line, no more lines to read.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
k, v, ok := httpParseHeaderLine(line)
|
||||||
|
if !ok {
|
||||||
|
err = ErrMalformedResponse
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch btsToString(k) {
|
||||||
|
case headerUpgradeCanonical:
|
||||||
|
headerSeen |= headerSeenUpgrade
|
||||||
|
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
|
||||||
|
err = ErrHandshakeBadUpgrade
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerConnectionCanonical:
|
||||||
|
headerSeen |= headerSeenConnection
|
||||||
|
// Note that as RFC6455 says:
|
||||||
|
// > A |Connection| header field with value "Upgrade".
|
||||||
|
// That is, in server side, "Connection" header could contain
|
||||||
|
// multiple token. But in response it must contains exactly one.
|
||||||
|
if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
|
||||||
|
err = ErrHandshakeBadConnection
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerSecAcceptCanonical:
|
||||||
|
headerSeen |= headerSeenSecAccept
|
||||||
|
if !checkAcceptFromNonce(v, nonce) {
|
||||||
|
err = ErrHandshakeBadSecAccept
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerSecProtocolCanonical:
|
||||||
|
// RFC6455 1.3:
|
||||||
|
// "The server selects one or none of the acceptable protocols
|
||||||
|
// and echoes that value in its handshake to indicate that it has
|
||||||
|
// selected that protocol."
|
||||||
|
for _, want := range d.Protocols {
|
||||||
|
if string(v) == want {
|
||||||
|
hs.Protocol = want
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hs.Protocol == "" {
|
||||||
|
// Server echoed subprotocol that is not present in client
|
||||||
|
// requested protocols.
|
||||||
|
err = ErrHandshakeBadSubProtocol
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerSecExtensionsCanonical:
|
||||||
|
hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
if onHeader := d.OnHeader; onHeader != nil {
|
||||||
|
if e := onHeader(k, v); e != nil {
|
||||||
|
err = e
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err == nil && headerSeen != headerSeenAll {
|
||||||
|
switch {
|
||||||
|
case headerSeen&headerSeenUpgrade == 0:
|
||||||
|
err = ErrHandshakeBadUpgrade
|
||||||
|
case headerSeen&headerSeenConnection == 0:
|
||||||
|
err = ErrHandshakeBadConnection
|
||||||
|
case headerSeen&headerSeenSecAccept == 0:
|
||||||
|
err = ErrHandshakeBadSecAccept
|
||||||
|
default:
|
||||||
|
panic("unknown headers state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutReader returns bufio.Reader instance to the inner reuse pool.
|
||||||
|
// It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which
|
||||||
|
// contains unprocessed buffered data, that was sent by the server quickly
|
||||||
|
// right after handshake.
|
||||||
|
func PutReader(br *bufio.Reader) {
|
||||||
|
pbufio.PutReader(br)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusError contains an unexpected status-line code from the server.
|
||||||
|
type StatusError int
|
||||||
|
|
||||||
|
func (s StatusError) Error() string {
|
||||||
|
return "unexpected HTTP response status: " + strconv.Itoa(int(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTimeoutError(err error) bool {
|
||||||
|
t, ok := err.(net.Error)
|
||||||
|
return ok && t.Timeout()
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
|
||||||
|
if len(selected) == 0 {
|
||||||
|
return received, nil
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
index int
|
||||||
|
option httphead.Option
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
index = -1
|
||||||
|
match := func() (ok bool) {
|
||||||
|
for _, want := range wanted {
|
||||||
|
if option.Equal(want) {
|
||||||
|
// Check parsed extension to be present in client
|
||||||
|
// requested extensions. We move matched extension
|
||||||
|
// from client list to avoid allocation.
|
||||||
|
received = append(received, want)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
|
||||||
|
if i != index {
|
||||||
|
// Met next option.
|
||||||
|
index = i
|
||||||
|
if i != 0 && !match() {
|
||||||
|
// Server returned non-requested extension.
|
||||||
|
err = ErrHandshakeBadExtensions
|
||||||
|
return httphead.ControlBreak
|
||||||
|
}
|
||||||
|
option = httphead.Option{Name: name}
|
||||||
|
}
|
||||||
|
if attr != nil {
|
||||||
|
option.Parameters.Set(attr, val)
|
||||||
|
}
|
||||||
|
return httphead.ControlContinue
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
err = ErrMalformedResponse
|
||||||
|
return received, err
|
||||||
|
}
|
||||||
|
if !match() {
|
||||||
|
return received, ErrHandshakeBadExtensions
|
||||||
|
}
|
||||||
|
return received, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupContextDeadliner is a helper function that starts connection I/O
|
||||||
|
// interrupter goroutine.
|
||||||
|
//
|
||||||
|
// Started goroutine calls SetDeadline() with long time ago value when context
|
||||||
|
// become expired to make any I/O operations failed. It returns done function
|
||||||
|
// that stops started goroutine and maps error received from conn I/O methods
|
||||||
|
// to possible context expiration error.
|
||||||
|
//
|
||||||
|
// In concern with possible SetDeadline() call inside interrupter goroutine,
|
||||||
|
// caller passes pointer to its I/O error (even if it is nil) to done(&err).
|
||||||
|
// That is, even if I/O error is nil, context could be already expired and
|
||||||
|
// connection "poisoned" by SetDeadline() call. In that case done(&err) will
|
||||||
|
// store at *err ctx.Err() result. If err is caused not by timeout, it will
|
||||||
|
// leaved untouched.
|
||||||
|
func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
|
||||||
|
var (
|
||||||
|
quit = make(chan struct{})
|
||||||
|
interrupt = make(chan error, 1)
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-quit:
|
||||||
|
interrupt <- nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Cancel i/o immediately.
|
||||||
|
conn.SetDeadline(aLongTimeAgo)
|
||||||
|
interrupt <- ctx.Err()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return func(err *error) {
|
||||||
|
close(quit)
|
||||||
|
// If ctx.Err() is non-nil and the original err is net.Error with
|
||||||
|
// Timeout() == true, then it means that I/O was canceled by us by
|
||||||
|
// SetDeadline(aLongTimeAgo) call, or by somebody else previously
|
||||||
|
// by conn.SetDeadline(x).
|
||||||
|
//
|
||||||
|
// Even on race condition when both deadlines are expired
|
||||||
|
// (SetDeadline() made not by us and context's), we prefer ctx.Err() to
|
||||||
|
// be returned.
|
||||||
|
if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
|
||||||
|
*err = ctxErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
// +build !go1.8
|
||||||
|
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
func tlsCloneConfig(c *tls.Config) *tls.Config {
|
||||||
|
// NOTE: we copying SessionTicketsDisabled and SessionTicketKey here
|
||||||
|
// without calling inner c.initOnceServer somehow because we only could get
|
||||||
|
// here from the ws.Dialer code, which is obviously a client and makes
|
||||||
|
// tls.Client() when it gets new net.Conn.
|
||||||
|
return &tls.Config{
|
||||||
|
Rand: c.Rand,
|
||||||
|
Time: c.Time,
|
||||||
|
Certificates: c.Certificates,
|
||||||
|
NameToCertificate: c.NameToCertificate,
|
||||||
|
GetCertificate: c.GetCertificate,
|
||||||
|
RootCAs: c.RootCAs,
|
||||||
|
NextProtos: c.NextProtos,
|
||||||
|
ServerName: c.ServerName,
|
||||||
|
ClientAuth: c.ClientAuth,
|
||||||
|
ClientCAs: c.ClientCAs,
|
||||||
|
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||||
|
CipherSuites: c.CipherSuites,
|
||||||
|
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||||
|
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||||
|
SessionTicketKey: c.SessionTicketKey,
|
||||||
|
ClientSessionCache: c.ClientSessionCache,
|
||||||
|
MinVersion: c.MinVersion,
|
||||||
|
MaxVersion: c.MaxVersion,
|
||||||
|
CurvePreferences: c.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||||
|
Renegotiation: c.Renegotiation,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
// +build go1.8
|
||||||
|
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
func tlsCloneConfig(c *tls.Config) *tls.Config {
|
||||||
|
return c.Clone()
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
/*
|
||||||
|
Package ws implements a client and server for the WebSocket protocol as
|
||||||
|
specified in RFC 6455.
|
||||||
|
|
||||||
|
The main purpose of this package is to provide simple low-level API for
|
||||||
|
efficient work with protocol.
|
||||||
|
|
||||||
|
Overview.
|
||||||
|
|
||||||
|
Upgrade to WebSocket (or WebSocket handshake) can be done in two ways.
|
||||||
|
|
||||||
|
The first way is to use `net/http` server:
|
||||||
|
|
||||||
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, _, _, err := ws.UpgradeHTTP(r, w)
|
||||||
|
})
|
||||||
|
|
||||||
|
The second and much more efficient way is so-called "zero-copy upgrade". It
|
||||||
|
avoids redundant allocations and copying of not used headers or other request
|
||||||
|
data. User decides by himself which data should be copied.
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", ":8080")
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
handshake, err := ws.Upgrade(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
For customization details see `ws.Upgrader` documentation.
|
||||||
|
|
||||||
|
After WebSocket handshake you can work with connection in multiple ways.
|
||||||
|
That is, `ws` does not force the only one way of how to work with WebSocket:
|
||||||
|
|
||||||
|
header, err := ws.ReadHeader(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, header.Length)
|
||||||
|
_, err := io.ReadFull(conn, buf)
|
||||||
|
if err != nil {
|
||||||
|
// handle err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := ws.NewBinaryFrame([]byte("hello, world!"))
|
||||||
|
if err := ws.WriteFrame(conn, frame); err != nil {
|
||||||
|
// handle err
|
||||||
|
}
|
||||||
|
|
||||||
|
As you can see, it stream friendly:
|
||||||
|
|
||||||
|
const N = 42
|
||||||
|
|
||||||
|
ws.WriteHeader(ws.Header{
|
||||||
|
Fin: true,
|
||||||
|
Length: N,
|
||||||
|
OpCode: ws.OpBinary,
|
||||||
|
})
|
||||||
|
|
||||||
|
io.CopyN(conn, rand.Reader, N)
|
||||||
|
|
||||||
|
Or:
|
||||||
|
|
||||||
|
header, err := ws.ReadHeader(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle err
|
||||||
|
}
|
||||||
|
|
||||||
|
io.CopyN(ioutil.Discard, conn, header.Length)
|
||||||
|
|
||||||
|
For more info see the documentation.
|
||||||
|
*/
|
||||||
|
package ws
|
|
@ -0,0 +1,54 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
// RejectOption represents an option used to control the way connection is
|
||||||
|
// rejected.
|
||||||
|
type RejectOption func(*rejectConnectionError)
|
||||||
|
|
||||||
|
// RejectionReason returns an option that makes connection to be rejected with
|
||||||
|
// given reason.
|
||||||
|
func RejectionReason(reason string) RejectOption {
|
||||||
|
return func(err *rejectConnectionError) {
|
||||||
|
err.reason = reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RejectionStatus returns an option that makes connection to be rejected with
|
||||||
|
// given HTTP status code.
|
||||||
|
func RejectionStatus(code int) RejectOption {
|
||||||
|
return func(err *rejectConnectionError) {
|
||||||
|
err.code = code
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RejectionHeader returns an option that makes connection to be rejected with
|
||||||
|
// given HTTP headers.
|
||||||
|
func RejectionHeader(h HandshakeHeader) RejectOption {
|
||||||
|
return func(err *rejectConnectionError) {
|
||||||
|
err.header = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RejectConnectionError constructs an error that could be used to control the way
|
||||||
|
// handshake is rejected by Upgrader.
|
||||||
|
func RejectConnectionError(options ...RejectOption) error {
|
||||||
|
err := new(rejectConnectionError)
|
||||||
|
for _, opt := range options {
|
||||||
|
opt(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// rejectConnectionError represents a rejection of upgrade error.
|
||||||
|
//
|
||||||
|
// It can be returned by Upgrader's On* hooks to control the way WebSocket
|
||||||
|
// handshake is rejected.
|
||||||
|
type rejectConnectionError struct {
|
||||||
|
reason string
|
||||||
|
code int
|
||||||
|
header HandshakeHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements error interface.
|
||||||
|
func (r *rejectConnectionError) Error() string {
|
||||||
|
return r.reason
|
||||||
|
}
|
|
@ -0,0 +1,389 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"math/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Constants defined by specification.
|
||||||
|
const (
|
||||||
|
// All control frames MUST have a payload length of 125 bytes or less and MUST NOT be fragmented.
|
||||||
|
MaxControlFramePayloadSize = 125
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpCode represents operation code.
|
||||||
|
type OpCode byte
|
||||||
|
|
||||||
|
// Operation codes defined by specification.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
const (
|
||||||
|
OpContinuation OpCode = 0x0
|
||||||
|
OpText OpCode = 0x1
|
||||||
|
OpBinary OpCode = 0x2
|
||||||
|
OpClose OpCode = 0x8
|
||||||
|
OpPing OpCode = 0x9
|
||||||
|
OpPong OpCode = 0xa
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsControl checks whether the c is control operation code.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.5
|
||||||
|
func (c OpCode) IsControl() bool {
|
||||||
|
// RFC6455: Control frames are identified by opcodes where
|
||||||
|
// the most significant bit of the opcode is 1.
|
||||||
|
//
|
||||||
|
// Note that OpCode is only 4 bit length.
|
||||||
|
return c&0x8 != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsData checks whether the c is data operation code.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.6
|
||||||
|
func (c OpCode) IsData() bool {
|
||||||
|
// RFC6455: Data frames (e.g., non-control frames) are identified by opcodes
|
||||||
|
// where the most significant bit of the opcode is 0.
|
||||||
|
//
|
||||||
|
// Note that OpCode is only 4 bit length.
|
||||||
|
return c&0x8 == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsReserved checks whether the c is reserved operation code.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
func (c OpCode) IsReserved() bool {
|
||||||
|
// RFC6455:
|
||||||
|
// %x3-7 are reserved for further non-control frames
|
||||||
|
// %xB-F are reserved for further control frames
|
||||||
|
return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StatusCode represents the encoded reason for closure of websocket connection.
|
||||||
|
//
|
||||||
|
// There are few helper methods on StatusCode that helps to define a range in
|
||||||
|
// which given code is lay in. accordingly to ranges defined in specification.
|
||||||
|
//
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-7.4
|
||||||
|
type StatusCode uint16
|
||||||
|
|
||||||
|
// StatusCodeRange describes range of StatusCode values.
|
||||||
|
type StatusCodeRange struct {
|
||||||
|
Min, Max StatusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status code ranges defined by specification.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-7.4.2
|
||||||
|
var (
|
||||||
|
StatusRangeNotInUse = StatusCodeRange{0, 999}
|
||||||
|
StatusRangeProtocol = StatusCodeRange{1000, 2999}
|
||||||
|
StatusRangeApplication = StatusCodeRange{3000, 3999}
|
||||||
|
StatusRangePrivate = StatusCodeRange{4000, 4999}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Status codes defined by specification.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-7.4.1
|
||||||
|
const (
|
||||||
|
StatusNormalClosure StatusCode = 1000
|
||||||
|
StatusGoingAway StatusCode = 1001
|
||||||
|
StatusProtocolError StatusCode = 1002
|
||||||
|
StatusUnsupportedData StatusCode = 1003
|
||||||
|
StatusNoMeaningYet StatusCode = 1004
|
||||||
|
StatusInvalidFramePayloadData StatusCode = 1007
|
||||||
|
StatusPolicyViolation StatusCode = 1008
|
||||||
|
StatusMessageTooBig StatusCode = 1009
|
||||||
|
StatusMandatoryExt StatusCode = 1010
|
||||||
|
StatusInternalServerError StatusCode = 1011
|
||||||
|
StatusTLSHandshake StatusCode = 1015
|
||||||
|
|
||||||
|
// StatusAbnormalClosure is a special code designated for use in
|
||||||
|
// applications.
|
||||||
|
StatusAbnormalClosure StatusCode = 1006
|
||||||
|
|
||||||
|
// StatusNoStatusRcvd is a special code designated for use in applications.
|
||||||
|
StatusNoStatusRcvd StatusCode = 1005
|
||||||
|
)
|
||||||
|
|
||||||
|
// In reports whether the code is defined in given range.
|
||||||
|
func (s StatusCode) In(r StatusCodeRange) bool {
|
||||||
|
return r.Min <= s && s <= r.Max
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty reports whether the code is empty.
|
||||||
|
// Empty code has no any meaning neither app level codes nor other.
|
||||||
|
// This method is useful just to check that code is golang default value 0.
|
||||||
|
func (s StatusCode) Empty() bool {
|
||||||
|
return s == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotUsed reports whether the code is predefined in not used range.
|
||||||
|
func (s StatusCode) IsNotUsed() bool {
|
||||||
|
return s.In(StatusRangeNotInUse)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsApplicationSpec reports whether the code should be defined by
|
||||||
|
// application, framework or libraries specification.
|
||||||
|
func (s StatusCode) IsApplicationSpec() bool {
|
||||||
|
return s.In(StatusRangeApplication)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPrivateSpec reports whether the code should be defined privately.
|
||||||
|
func (s StatusCode) IsPrivateSpec() bool {
|
||||||
|
return s.In(StatusRangePrivate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProtocolSpec reports whether the code should be defined by protocol specification.
|
||||||
|
func (s StatusCode) IsProtocolSpec() bool {
|
||||||
|
return s.In(StatusRangeProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProtocolDefined reports whether the code is already defined by protocol specification.
|
||||||
|
func (s StatusCode) IsProtocolDefined() bool {
|
||||||
|
switch s {
|
||||||
|
case StatusNormalClosure,
|
||||||
|
StatusGoingAway,
|
||||||
|
StatusProtocolError,
|
||||||
|
StatusUnsupportedData,
|
||||||
|
StatusInvalidFramePayloadData,
|
||||||
|
StatusPolicyViolation,
|
||||||
|
StatusMessageTooBig,
|
||||||
|
StatusMandatoryExt,
|
||||||
|
StatusInternalServerError,
|
||||||
|
StatusNoStatusRcvd,
|
||||||
|
StatusAbnormalClosure,
|
||||||
|
StatusTLSHandshake:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProtocolReserved reports whether the code is defined by protocol specification
|
||||||
|
// to be reserved only for application usage purpose.
|
||||||
|
func (s StatusCode) IsProtocolReserved() bool {
|
||||||
|
switch s {
|
||||||
|
// [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a
|
||||||
|
// Close control frame by an endpoint.
|
||||||
|
case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compiled control frames for common use cases.
|
||||||
|
// For construct-serialize optimizations.
|
||||||
|
var (
|
||||||
|
CompiledPing = MustCompileFrame(NewPingFrame(nil))
|
||||||
|
CompiledPong = MustCompileFrame(NewPongFrame(nil))
|
||||||
|
CompiledClose = MustCompileFrame(NewCloseFrame(nil))
|
||||||
|
|
||||||
|
CompiledCloseNormalClosure = MustCompileFrame(closeFrameNormalClosure)
|
||||||
|
CompiledCloseGoingAway = MustCompileFrame(closeFrameGoingAway)
|
||||||
|
CompiledCloseProtocolError = MustCompileFrame(closeFrameProtocolError)
|
||||||
|
CompiledCloseUnsupportedData = MustCompileFrame(closeFrameUnsupportedData)
|
||||||
|
CompiledCloseNoMeaningYet = MustCompileFrame(closeFrameNoMeaningYet)
|
||||||
|
CompiledCloseInvalidFramePayloadData = MustCompileFrame(closeFrameInvalidFramePayloadData)
|
||||||
|
CompiledClosePolicyViolation = MustCompileFrame(closeFramePolicyViolation)
|
||||||
|
CompiledCloseMessageTooBig = MustCompileFrame(closeFrameMessageTooBig)
|
||||||
|
CompiledCloseMandatoryExt = MustCompileFrame(closeFrameMandatoryExt)
|
||||||
|
CompiledCloseInternalServerError = MustCompileFrame(closeFrameInternalServerError)
|
||||||
|
CompiledCloseTLSHandshake = MustCompileFrame(closeFrameTLSHandshake)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Header represents websocket frame header.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
type Header struct {
|
||||||
|
Fin bool
|
||||||
|
Rsv byte
|
||||||
|
OpCode OpCode
|
||||||
|
Masked bool
|
||||||
|
Mask [4]byte
|
||||||
|
Length int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rsv1 reports whether the header has first rsv bit set.
|
||||||
|
func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 }
|
||||||
|
|
||||||
|
// Rsv2 reports whether the header has second rsv bit set.
|
||||||
|
func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 }
|
||||||
|
|
||||||
|
// Rsv3 reports whether the header has third rsv bit set.
|
||||||
|
func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 }
|
||||||
|
|
||||||
|
// Frame represents websocket frame.
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
type Frame struct {
|
||||||
|
Header Header
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFrame creates frame with given operation code,
|
||||||
|
// flag of completeness and payload bytes.
|
||||||
|
func NewFrame(op OpCode, fin bool, p []byte) Frame {
|
||||||
|
return Frame{
|
||||||
|
Header: Header{
|
||||||
|
Fin: fin,
|
||||||
|
OpCode: op,
|
||||||
|
Length: int64(len(p)),
|
||||||
|
},
|
||||||
|
Payload: p,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTextFrame creates text frame with p as payload.
|
||||||
|
// Note that p is not copied.
|
||||||
|
func NewTextFrame(p []byte) Frame {
|
||||||
|
return NewFrame(OpText, true, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBinaryFrame creates binary frame with p as payload.
|
||||||
|
// Note that p is not copied.
|
||||||
|
func NewBinaryFrame(p []byte) Frame {
|
||||||
|
return NewFrame(OpBinary, true, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPingFrame creates ping frame with p as payload.
|
||||||
|
// Note that p is not copied.
|
||||||
|
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
|
||||||
|
// to RFC.
|
||||||
|
func NewPingFrame(p []byte) Frame {
|
||||||
|
return NewFrame(OpPing, true, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPongFrame creates pong frame with p as payload.
|
||||||
|
// Note that p is not copied.
|
||||||
|
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
|
||||||
|
// to RFC.
|
||||||
|
func NewPongFrame(p []byte) Frame {
|
||||||
|
return NewFrame(OpPong, true, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCloseFrame creates close frame with given close body.
|
||||||
|
// Note that p is not copied.
|
||||||
|
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
|
||||||
|
// to RFC.
|
||||||
|
func NewCloseFrame(p []byte) Frame {
|
||||||
|
return NewFrame(OpClose, true, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCloseFrameBody encodes a closure code and a reason into a binary
|
||||||
|
// representation.
|
||||||
|
//
|
||||||
|
// It returns slice which is at most MaxControlFramePayloadSize bytes length.
|
||||||
|
// If the reason is too big it will be cropped to fit the limit defined by the
|
||||||
|
// spec.
|
||||||
|
//
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.5
|
||||||
|
func NewCloseFrameBody(code StatusCode, reason string) []byte {
|
||||||
|
n := min(2+len(reason), MaxControlFramePayloadSize)
|
||||||
|
p := make([]byte, n)
|
||||||
|
|
||||||
|
crop := min(MaxControlFramePayloadSize-2, len(reason))
|
||||||
|
PutCloseFrameBody(p, code, reason[:crop])
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutCloseFrameBody encodes code and reason into buf.
|
||||||
|
//
|
||||||
|
// It will panic if the buffer is too small to accommodate a code or a reason.
|
||||||
|
//
|
||||||
|
// PutCloseFrameBody does not check buffer to be RFC compliant, but note that
|
||||||
|
// by RFC it must be at most MaxControlFramePayloadSize.
|
||||||
|
func PutCloseFrameBody(p []byte, code StatusCode, reason string) {
|
||||||
|
_ = p[1+len(reason)]
|
||||||
|
binary.BigEndian.PutUint16(p, uint16(code))
|
||||||
|
copy(p[2:], reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskFrame masks frame and returns frame with masked payload and Mask header's field set.
|
||||||
|
// Note that it copies f payload to prevent collisions.
|
||||||
|
// For less allocations you could use MaskFrameInPlace or construct frame manually.
|
||||||
|
func MaskFrame(f Frame) Frame {
|
||||||
|
return MaskFrameWith(f, NewMask())
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskFrameWith masks frame with given mask and returns frame
|
||||||
|
// with masked payload and Mask header's field set.
|
||||||
|
// Note that it copies f payload to prevent collisions.
|
||||||
|
// For less allocations you could use MaskFrameInPlaceWith or construct frame manually.
|
||||||
|
func MaskFrameWith(f Frame, mask [4]byte) Frame {
|
||||||
|
// TODO(gobwas): check CopyCipher ws copy() Cipher().
|
||||||
|
p := make([]byte, len(f.Payload))
|
||||||
|
copy(p, f.Payload)
|
||||||
|
f.Payload = p
|
||||||
|
return MaskFrameInPlaceWith(f, mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskFrameInPlace masks frame and returns frame with masked payload and Mask
|
||||||
|
// header's field set.
|
||||||
|
// Note that it applies xor cipher to f.Payload without copying, that is, it
|
||||||
|
// modifies f.Payload inplace.
|
||||||
|
func MaskFrameInPlace(f Frame) Frame {
|
||||||
|
return MaskFrameInPlaceWith(f, NewMask())
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskFrameInPlaceWith masks frame with given mask and returns frame
|
||||||
|
// with masked payload and Mask header's field set.
|
||||||
|
// Note that it applies xor cipher to f.Payload without copying, that is, it
|
||||||
|
// modifies f.Payload inplace.
|
||||||
|
func MaskFrameInPlaceWith(f Frame, m [4]byte) Frame {
|
||||||
|
f.Header.Masked = true
|
||||||
|
f.Header.Mask = m
|
||||||
|
Cipher(f.Payload, m, 0)
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMask creates new random mask.
|
||||||
|
func NewMask() (ret [4]byte) {
|
||||||
|
binary.BigEndian.PutUint32(ret[:], rand.Uint32())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompileFrame returns byte representation of given frame.
|
||||||
|
// In terms of memory consumption it is useful to precompile static frames
|
||||||
|
// which are often used.
|
||||||
|
func CompileFrame(f Frame) (bts []byte, err error) {
|
||||||
|
buf := bytes.NewBuffer(make([]byte, 0, 16))
|
||||||
|
err = WriteFrame(buf, f)
|
||||||
|
bts = buf.Bytes()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustCompileFrame is like CompileFrame but panics if frame can not be
|
||||||
|
// encoded.
|
||||||
|
func MustCompileFrame(f Frame) []byte {
|
||||||
|
bts, err := CompileFrame(f)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return bts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rsv creates rsv byte representation.
|
||||||
|
func Rsv(r1, r2, r3 bool) (rsv byte) {
|
||||||
|
if r1 {
|
||||||
|
rsv |= bit5
|
||||||
|
}
|
||||||
|
if r2 {
|
||||||
|
rsv |= bit6
|
||||||
|
}
|
||||||
|
if r3 {
|
||||||
|
rsv |= bit7
|
||||||
|
}
|
||||||
|
return rsv
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeCloseFrame(code StatusCode) Frame {
|
||||||
|
return NewCloseFrame(NewCloseFrameBody(code, ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
closeFrameNormalClosure = makeCloseFrame(StatusNormalClosure)
|
||||||
|
closeFrameGoingAway = makeCloseFrame(StatusGoingAway)
|
||||||
|
closeFrameProtocolError = makeCloseFrame(StatusProtocolError)
|
||||||
|
closeFrameUnsupportedData = makeCloseFrame(StatusUnsupportedData)
|
||||||
|
closeFrameNoMeaningYet = makeCloseFrame(StatusNoMeaningYet)
|
||||||
|
closeFrameInvalidFramePayloadData = makeCloseFrame(StatusInvalidFramePayloadData)
|
||||||
|
closeFramePolicyViolation = makeCloseFrame(StatusPolicyViolation)
|
||||||
|
closeFrameMessageTooBig = makeCloseFrame(StatusMessageTooBig)
|
||||||
|
closeFrameMandatoryExt = makeCloseFrame(StatusMandatoryExt)
|
||||||
|
closeFrameInternalServerError = makeCloseFrame(StatusInternalServerError)
|
||||||
|
closeFrameTLSHandshake = makeCloseFrame(StatusTLSHandshake)
|
||||||
|
)
|
|
@ -0,0 +1,468 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/textproto"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gobwas/httphead"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
crlf = "\r\n"
|
||||||
|
colonAndSpace = ": "
|
||||||
|
commaAndSpace = ", "
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
textHeadBadRequest = statusText(http.StatusBadRequest)
|
||||||
|
textHeadInternalServerError = statusText(http.StatusInternalServerError)
|
||||||
|
textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
|
||||||
|
|
||||||
|
textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
|
||||||
|
textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
|
||||||
|
textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
|
||||||
|
textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
|
||||||
|
textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
|
||||||
|
textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
|
||||||
|
textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
|
||||||
|
textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
|
||||||
|
textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
headerHost = "Host"
|
||||||
|
headerUpgrade = "Upgrade"
|
||||||
|
headerConnection = "Connection"
|
||||||
|
headerSecVersion = "Sec-WebSocket-Version"
|
||||||
|
headerSecProtocol = "Sec-WebSocket-Protocol"
|
||||||
|
headerSecExtensions = "Sec-WebSocket-Extensions"
|
||||||
|
headerSecKey = "Sec-WebSocket-Key"
|
||||||
|
headerSecAccept = "Sec-WebSocket-Accept"
|
||||||
|
|
||||||
|
headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost)
|
||||||
|
headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
|
||||||
|
headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection)
|
||||||
|
headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
|
||||||
|
headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
|
||||||
|
headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
|
||||||
|
headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey)
|
||||||
|
headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
specHeaderValueUpgrade = []byte("websocket")
|
||||||
|
specHeaderValueConnection = []byte("Upgrade")
|
||||||
|
specHeaderValueConnectionLower = []byte("upgrade")
|
||||||
|
specHeaderValueSecVersion = []byte("13")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
httpVersion1_0 = []byte("HTTP/1.0")
|
||||||
|
httpVersion1_1 = []byte("HTTP/1.1")
|
||||||
|
httpVersionPrefix = []byte("HTTP/")
|
||||||
|
)
|
||||||
|
|
||||||
|
type httpRequestLine struct {
|
||||||
|
method, uri []byte
|
||||||
|
major, minor int
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpResponseLine struct {
|
||||||
|
major, minor int
|
||||||
|
status int
|
||||||
|
reason []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpParseRequestLine parses http request line like "GET / HTTP/1.0".
|
||||||
|
func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
|
||||||
|
var proto []byte
|
||||||
|
req.method, req.uri, proto = bsplit3(line, ' ')
|
||||||
|
|
||||||
|
var ok bool
|
||||||
|
req.major, req.minor, ok = httpParseVersion(proto)
|
||||||
|
if !ok {
|
||||||
|
err = ErrMalformedRequest
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
|
||||||
|
var (
|
||||||
|
proto []byte
|
||||||
|
status []byte
|
||||||
|
)
|
||||||
|
proto, status, resp.reason = bsplit3(line, ' ')
|
||||||
|
|
||||||
|
var ok bool
|
||||||
|
resp.major, resp.minor, ok = httpParseVersion(proto)
|
||||||
|
if !ok {
|
||||||
|
return resp, ErrMalformedResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
var convErr error
|
||||||
|
resp.status, convErr = asciiToInt(status)
|
||||||
|
if convErr != nil {
|
||||||
|
return resp, ErrMalformedResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpParseVersion parses major and minor version of HTTP protocol. It returns
|
||||||
|
// parsed values and true if parse is ok.
|
||||||
|
func httpParseVersion(bts []byte) (major, minor int, ok bool) {
|
||||||
|
switch {
|
||||||
|
case bytes.Equal(bts, httpVersion1_0):
|
||||||
|
return 1, 0, true
|
||||||
|
case bytes.Equal(bts, httpVersion1_1):
|
||||||
|
return 1, 1, true
|
||||||
|
case len(bts) < 8:
|
||||||
|
return
|
||||||
|
case !bytes.Equal(bts[:5], httpVersionPrefix):
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bts = bts[5:]
|
||||||
|
|
||||||
|
dot := bytes.IndexByte(bts, '.')
|
||||||
|
if dot == -1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
major, err = asciiToInt(bts[:dot])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
minor, err = asciiToInt(bts[dot+1:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return major, minor, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
|
||||||
|
// values and true if parse is ok.
|
||||||
|
func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
|
||||||
|
colon := bytes.IndexByte(line, ':')
|
||||||
|
if colon == -1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
k = btrim(line[:colon])
|
||||||
|
// TODO(gobwas): maybe use just lower here?
|
||||||
|
canonicalizeHeaderKey(k)
|
||||||
|
|
||||||
|
v = btrim(line[colon+1:])
|
||||||
|
|
||||||
|
return k, v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing,
|
||||||
|
// that key is already canonical. This helps to increase performance.
|
||||||
|
func httpGetHeader(h http.Header, key string) string {
|
||||||
|
if h == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
v := h[key]
|
||||||
|
if len(v) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return v[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// The request MAY include a header field with the name
|
||||||
|
// |Sec-WebSocket-Protocol|. If present, this value indicates one or more
|
||||||
|
// comma-separated subprotocol the client wishes to speak, ordered by
|
||||||
|
// preference. The elements that comprise this value MUST be non-empty strings
|
||||||
|
// with characters in the range U+0021 to U+007E not including separator
|
||||||
|
// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
|
||||||
|
// for the value of this header field is 1#token, where the definitions of
|
||||||
|
// constructs and rules are as given in [RFC2616].
|
||||||
|
func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
|
||||||
|
ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
|
||||||
|
if check(btsToString(v)) {
|
||||||
|
ret = string(v)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
|
||||||
|
var selected []byte
|
||||||
|
ok = httphead.ScanTokens(h, func(v []byte) bool {
|
||||||
|
if check(v) {
|
||||||
|
selected = v
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if ok && selected != nil {
|
||||||
|
return string(selected), true
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func strSelectExtensions(h string, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
|
||||||
|
return btsSelectExtensions(strToBytes(h), selected, check)
|
||||||
|
}
|
||||||
|
|
||||||
|
func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
|
||||||
|
s := httphead.OptionSelector{
|
||||||
|
Flags: httphead.SelectUnique | httphead.SelectCopy,
|
||||||
|
Check: check,
|
||||||
|
}
|
||||||
|
return s.Select(h, selected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpWriteHeader(bw *bufio.Writer, key, value string) {
|
||||||
|
httpWriteHeaderKey(bw, key)
|
||||||
|
bw.WriteString(value)
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
|
||||||
|
httpWriteHeaderKey(bw, key)
|
||||||
|
bw.Write(value)
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpWriteHeaderKey(bw *bufio.Writer, key string) {
|
||||||
|
bw.WriteString(key)
|
||||||
|
bw.WriteString(colonAndSpace)
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpWriteUpgradeRequest(
|
||||||
|
bw *bufio.Writer,
|
||||||
|
u *url.URL,
|
||||||
|
nonce []byte,
|
||||||
|
protocols []string,
|
||||||
|
extensions []httphead.Option,
|
||||||
|
header HandshakeHeader,
|
||||||
|
) {
|
||||||
|
bw.WriteString("GET ")
|
||||||
|
bw.WriteString(u.RequestURI())
|
||||||
|
bw.WriteString(" HTTP/1.1\r\n")
|
||||||
|
|
||||||
|
httpWriteHeader(bw, headerHost, u.Host)
|
||||||
|
|
||||||
|
httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
|
||||||
|
httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
|
||||||
|
httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
|
||||||
|
|
||||||
|
// NOTE: write nonce bytes as a string to prevent heap allocation –
|
||||||
|
// WriteString() copy given string into its inner buffer, unlike Write()
|
||||||
|
// which may write p directly to the underlying io.Writer – which in turn
|
||||||
|
// will lead to p escape.
|
||||||
|
httpWriteHeader(bw, headerSecKey, btsToString(nonce))
|
||||||
|
|
||||||
|
if len(protocols) > 0 {
|
||||||
|
httpWriteHeaderKey(bw, headerSecProtocol)
|
||||||
|
for i, p := range protocols {
|
||||||
|
if i > 0 {
|
||||||
|
bw.WriteString(commaAndSpace)
|
||||||
|
}
|
||||||
|
bw.WriteString(p)
|
||||||
|
}
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(extensions) > 0 {
|
||||||
|
httpWriteHeaderKey(bw, headerSecExtensions)
|
||||||
|
httphead.WriteOptions(bw, extensions)
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if header != nil {
|
||||||
|
header.WriteTo(bw)
|
||||||
|
}
|
||||||
|
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
|
||||||
|
bw.WriteString(textHeadUpgrade)
|
||||||
|
|
||||||
|
httpWriteHeaderKey(bw, headerSecAccept)
|
||||||
|
writeAccept(bw, nonce)
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
|
||||||
|
if hs.Protocol != "" {
|
||||||
|
httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
|
||||||
|
}
|
||||||
|
if len(hs.Extensions) > 0 {
|
||||||
|
httpWriteHeaderKey(bw, headerSecExtensions)
|
||||||
|
httphead.WriteOptions(bw, hs.Extensions)
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
}
|
||||||
|
if header != nil {
|
||||||
|
header(bw)
|
||||||
|
}
|
||||||
|
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
|
||||||
|
switch code {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
bw.WriteString(textHeadBadRequest)
|
||||||
|
case http.StatusInternalServerError:
|
||||||
|
bw.WriteString(textHeadInternalServerError)
|
||||||
|
case http.StatusUpgradeRequired:
|
||||||
|
bw.WriteString(textHeadUpgradeRequired)
|
||||||
|
default:
|
||||||
|
writeStatusText(bw, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write custom headers.
|
||||||
|
if header != nil {
|
||||||
|
header(bw)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch err {
|
||||||
|
case ErrHandshakeBadProtocol:
|
||||||
|
bw.WriteString(textTailErrHandshakeBadProtocol)
|
||||||
|
case ErrHandshakeBadMethod:
|
||||||
|
bw.WriteString(textTailErrHandshakeBadMethod)
|
||||||
|
case ErrHandshakeBadHost:
|
||||||
|
bw.WriteString(textTailErrHandshakeBadHost)
|
||||||
|
case ErrHandshakeBadUpgrade:
|
||||||
|
bw.WriteString(textTailErrHandshakeBadUpgrade)
|
||||||
|
case ErrHandshakeBadConnection:
|
||||||
|
bw.WriteString(textTailErrHandshakeBadConnection)
|
||||||
|
case ErrHandshakeBadSecAccept:
|
||||||
|
bw.WriteString(textTailErrHandshakeBadSecAccept)
|
||||||
|
case ErrHandshakeBadSecKey:
|
||||||
|
bw.WriteString(textTailErrHandshakeBadSecKey)
|
||||||
|
case ErrHandshakeBadSecVersion:
|
||||||
|
bw.WriteString(textTailErrHandshakeBadSecVersion)
|
||||||
|
case ErrHandshakeUpgradeRequired:
|
||||||
|
bw.WriteString(textTailErrUpgradeRequired)
|
||||||
|
case nil:
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
default:
|
||||||
|
writeErrorText(bw, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeStatusText(bw *bufio.Writer, code int) {
|
||||||
|
bw.WriteString("HTTP/1.1 ")
|
||||||
|
bw.WriteString(strconv.Itoa(code))
|
||||||
|
bw.WriteByte(' ')
|
||||||
|
bw.WriteString(http.StatusText(code))
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
bw.WriteString("Content-Type: text/plain; charset=utf-8")
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeErrorText(bw *bufio.Writer, err error) {
|
||||||
|
body := err.Error()
|
||||||
|
bw.WriteString("Content-Length: ")
|
||||||
|
bw.WriteString(strconv.Itoa(len(body)))
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
bw.WriteString(crlf)
|
||||||
|
bw.WriteString(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpError is like the http.Error with WebSocket context exception.
|
||||||
|
func httpError(w http.ResponseWriter, body string, code int) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||||
|
w.WriteHeader(code)
|
||||||
|
w.Write([]byte(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// statusText is a non-performant status text generator.
|
||||||
|
// NOTE: Used only to generate constants.
|
||||||
|
func statusText(code int) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
bw := bufio.NewWriter(&buf)
|
||||||
|
writeStatusText(bw, code)
|
||||||
|
bw.Flush()
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorText is a non-performant error text generator.
|
||||||
|
// NOTE: Used only to generate constants.
|
||||||
|
func errorText(err error) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
bw := bufio.NewWriter(&buf)
|
||||||
|
writeErrorText(bw, err)
|
||||||
|
bw.Flush()
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandshakeHeader is the interface that writes both upgrade request or
|
||||||
|
// response headers into a given io.Writer.
|
||||||
|
type HandshakeHeader interface {
|
||||||
|
io.WriterTo
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandshakeHeaderString is an adapter to allow the use of headers represented
|
||||||
|
// by ordinary string as HandshakeHeader.
|
||||||
|
type HandshakeHeaderString string
|
||||||
|
|
||||||
|
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||||||
|
func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
n, err := io.WriteString(w, string(s))
|
||||||
|
return int64(n), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandshakeHeaderBytes is an adapter to allow the use of headers represented
|
||||||
|
// by ordinary slice of bytes as HandshakeHeader.
|
||||||
|
type HandshakeHeaderBytes []byte
|
||||||
|
|
||||||
|
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||||||
|
func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
n, err := w.Write(b)
|
||||||
|
return int64(n), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandshakeHeaderFunc is an adapter to allow the use of headers represented by
|
||||||
|
// ordinary function as HandshakeHeader.
|
||||||
|
type HandshakeHeaderFunc func(io.Writer) (int64, error)
|
||||||
|
|
||||||
|
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||||||
|
func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
return f(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
|
||||||
|
// HandshakeHeader.
|
||||||
|
type HandshakeHeaderHTTP http.Header
|
||||||
|
|
||||||
|
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||||||
|
func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
wr := writer{w: w}
|
||||||
|
err := http.Header(h).Write(&wr)
|
||||||
|
return wr.n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type writer struct {
|
||||||
|
n int64
|
||||||
|
w io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writer) WriteString(s string) (int, error) {
|
||||||
|
n, err := io.WriteString(w.w, s)
|
||||||
|
w.n += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writer) Write(p []byte) (int, error) {
|
||||||
|
n, err := w.w.Write(p)
|
||||||
|
w.n += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
|
@ -0,0 +1,80 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RFC6455: The value of this header field MUST be a nonce consisting of a
|
||||||
|
// randomly selected 16-byte value that has been base64-encoded (see
|
||||||
|
// Section 4 of [RFC4648]). The nonce MUST be selected randomly for each
|
||||||
|
// connection.
|
||||||
|
nonceKeySize = 16
|
||||||
|
nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
|
||||||
|
|
||||||
|
// RFC6455: The value of this header field is constructed by concatenating
|
||||||
|
// /key/, defined above in step 4 in Section 4.2.2, with the string
|
||||||
|
// "258EAFA5- E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
|
||||||
|
// concatenated value to obtain a 20-byte value and base64- encoding (see
|
||||||
|
// Section 4 of [RFC4648]) this 20-byte hash.
|
||||||
|
acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
|
||||||
|
)
|
||||||
|
|
||||||
|
// initNonce fills given slice with random base64-encoded nonce bytes.
|
||||||
|
func initNonce(dst []byte) {
|
||||||
|
// NOTE: bts does not escape.
|
||||||
|
bts := make([]byte, nonceKeySize)
|
||||||
|
if _, err := rand.Read(bts); err != nil {
|
||||||
|
panic(fmt.Sprintf("rand read error: %s", err))
|
||||||
|
}
|
||||||
|
base64.StdEncoding.Encode(dst, bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAcceptFromNonce reports whether given accept bytes are valid for given
|
||||||
|
// nonce bytes.
|
||||||
|
func checkAcceptFromNonce(accept, nonce []byte) bool {
|
||||||
|
if len(accept) != acceptSize {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// NOTE: expect does not escape.
|
||||||
|
expect := make([]byte, acceptSize)
|
||||||
|
initAcceptFromNonce(expect, nonce)
|
||||||
|
return bytes.Equal(expect, accept)
|
||||||
|
}
|
||||||
|
|
||||||
|
// initAcceptFromNonce fills given slice with accept bytes generated from given
|
||||||
|
// nonce bytes. Given buffer should be exactly acceptSize bytes.
|
||||||
|
func initAcceptFromNonce(accept, nonce []byte) {
|
||||||
|
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
|
||||||
|
if len(accept) != acceptSize {
|
||||||
|
panic("accept buffer is invalid")
|
||||||
|
}
|
||||||
|
if len(nonce) != nonceSize {
|
||||||
|
panic("nonce is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
p := make([]byte, nonceSize+len(magic))
|
||||||
|
copy(p[:nonceSize], nonce)
|
||||||
|
copy(p[nonceSize:], magic)
|
||||||
|
|
||||||
|
sum := sha1.Sum(p)
|
||||||
|
base64.StdEncoding.Encode(accept, sum[:])
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) {
|
||||||
|
accept := make([]byte, acceptSize)
|
||||||
|
initAcceptFromNonce(accept, nonce)
|
||||||
|
// NOTE: write accept bytes as a string to prevent heap allocation –
|
||||||
|
// WriteString() copy given string into its inner buffer, unlike Write()
|
||||||
|
// which may write p directly to the underlying io.Writer – which in turn
|
||||||
|
// will lead to p escape.
|
||||||
|
return bw.WriteString(btsToString(accept))
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Errors used by frame reader.
|
||||||
|
var (
|
||||||
|
ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0")
|
||||||
|
ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadHeader reads a frame header from r.
|
||||||
|
func ReadHeader(r io.Reader) (h Header, err error) {
|
||||||
|
// Make slice of bytes with capacity 12 that could hold any header.
|
||||||
|
//
|
||||||
|
// The maximum header size is 14, but due to the 2 hop reads,
|
||||||
|
// after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
|
||||||
|
// So 14 - 2 = 12.
|
||||||
|
bts := make([]byte, 2, MaxHeaderSize-2)
|
||||||
|
|
||||||
|
// Prepare to hold first 2 bytes to choose size of next read.
|
||||||
|
_, err = io.ReadFull(r, bts)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Fin = bts[0]&bit0 != 0
|
||||||
|
h.Rsv = (bts[0] & 0x70) >> 4
|
||||||
|
h.OpCode = OpCode(bts[0] & 0x0f)
|
||||||
|
|
||||||
|
var extra int
|
||||||
|
|
||||||
|
if bts[1]&bit0 != 0 {
|
||||||
|
h.Masked = true
|
||||||
|
extra += 4
|
||||||
|
}
|
||||||
|
|
||||||
|
length := bts[1] & 0x7f
|
||||||
|
switch {
|
||||||
|
case length < 126:
|
||||||
|
h.Length = int64(length)
|
||||||
|
|
||||||
|
case length == 126:
|
||||||
|
extra += 2
|
||||||
|
|
||||||
|
case length == 127:
|
||||||
|
extra += 8
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = ErrHeaderLengthUnexpected
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if extra == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increase len of bts to extra bytes need to read.
|
||||||
|
// Overwrite first 2 bytes that was read before.
|
||||||
|
bts = bts[:extra]
|
||||||
|
_, err = io.ReadFull(r, bts)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case length == 126:
|
||||||
|
h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
|
||||||
|
bts = bts[2:]
|
||||||
|
|
||||||
|
case length == 127:
|
||||||
|
if bts[0]&0x80 != 0 {
|
||||||
|
err = ErrHeaderLengthMSB
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
|
||||||
|
bts = bts[8:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.Masked {
|
||||||
|
copy(h.Mask[:], bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFrame reads a frame from r.
|
||||||
|
// It is not designed for high optimized use case cause it makes allocation
|
||||||
|
// for frame.Header.Length size inside to read frame payload into.
|
||||||
|
//
|
||||||
|
// Note that ReadFrame does not unmask payload.
|
||||||
|
func ReadFrame(r io.Reader) (f Frame, err error) {
|
||||||
|
f.Header, err = ReadHeader(r)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.Header.Length > 0 {
|
||||||
|
// int(f.Header.Length) is safe here cause we have
|
||||||
|
// checked it for overflow above in ReadHeader.
|
||||||
|
f.Payload = make([]byte, int(f.Header.Length))
|
||||||
|
_, err = io.ReadFull(r, f.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustReadFrame is like ReadFrame but panics if frame can not be read.
|
||||||
|
func MustReadFrame(r io.Reader) Frame {
|
||||||
|
f, err := ReadFrame(r)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCloseFrameData parses close frame status code and closure reason if any provided.
|
||||||
|
// If there is no status code in the payload
|
||||||
|
// the empty status code is returned (code.Empty()) with empty string as a reason.
|
||||||
|
func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) {
|
||||||
|
if len(payload) < 2 {
|
||||||
|
// We returning empty StatusCode here, preventing the situation
|
||||||
|
// when endpoint really sent code 1005 and we should return ProtocolError on that.
|
||||||
|
//
|
||||||
|
// In other words, we ignoring this rule [RFC6455:7.1.5]:
|
||||||
|
// If this Close control frame contains no status code, _The WebSocket
|
||||||
|
// Connection Close Code_ is considered to be 1005.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code = StatusCode(binary.BigEndian.Uint16(payload))
|
||||||
|
reason = string(payload[2:])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseCloseFrameDataUnsafe is like ParseCloseFrameData except the thing
|
||||||
|
// that it does not copies payload bytes into reason, but prepares unsafe cast.
|
||||||
|
func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) {
|
||||||
|
if len(payload) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code = StatusCode(binary.BigEndian.Uint16(payload))
|
||||||
|
reason = btsToString(payload[2:])
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,607 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gobwas/httphead"
|
||||||
|
"github.com/gobwas/pool/pbufio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Constants used by ConnUpgrader.
|
||||||
|
const (
|
||||||
|
DefaultServerReadBufferSize = 4096
|
||||||
|
DefaultServerWriteBufferSize = 512
|
||||||
|
)
|
||||||
|
|
||||||
|
// Errors used by both client and server when preparing WebSocket handshake.
|
||||||
|
var (
|
||||||
|
ErrHandshakeBadProtocol = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusHTTPVersionNotSupported),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad HTTP protocol version")),
|
||||||
|
)
|
||||||
|
ErrHandshakeBadMethod = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusMethodNotAllowed),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad HTTP request method")),
|
||||||
|
)
|
||||||
|
ErrHandshakeBadHost = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusBadRequest),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)),
|
||||||
|
)
|
||||||
|
ErrHandshakeBadUpgrade = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusBadRequest),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)),
|
||||||
|
)
|
||||||
|
ErrHandshakeBadConnection = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusBadRequest),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)),
|
||||||
|
)
|
||||||
|
ErrHandshakeBadSecAccept = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusBadRequest),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)),
|
||||||
|
)
|
||||||
|
ErrHandshakeBadSecKey = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusBadRequest),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)),
|
||||||
|
)
|
||||||
|
ErrHandshakeBadSecVersion = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusBadRequest),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrMalformedResponse is returned by Dialer to indicate that server response
|
||||||
|
// can not be parsed.
|
||||||
|
var ErrMalformedResponse = fmt.Errorf("malformed HTTP response")
|
||||||
|
|
||||||
|
// ErrMalformedRequest is returned when HTTP request can not be parsed.
|
||||||
|
var ErrMalformedRequest = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusBadRequest),
|
||||||
|
RejectionReason("malformed HTTP request"),
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that
|
||||||
|
// connection is rejected because given WebSocket version is malformed.
|
||||||
|
//
|
||||||
|
// According to RFC6455:
|
||||||
|
// If this version does not match a version understood by the server, the
|
||||||
|
// server MUST abort the WebSocket handshake described in this section and
|
||||||
|
// instead send an appropriate HTTP error code (such as 426 Upgrade Required)
|
||||||
|
// and a |Sec-WebSocket-Version| header field indicating the version(s) the
|
||||||
|
// server is capable of understanding.
|
||||||
|
var ErrHandshakeUpgradeRequired = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusUpgradeRequired),
|
||||||
|
RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")),
|
||||||
|
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrNotHijacker is an error returned when http.ResponseWriter does not
|
||||||
|
// implement http.Hijacker interface.
|
||||||
|
var ErrNotHijacker = RejectConnectionError(
|
||||||
|
RejectionStatus(http.StatusInternalServerError),
|
||||||
|
RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by
|
||||||
|
// UpgradeHTTP function.
|
||||||
|
var DefaultHTTPUpgrader HTTPUpgrader
|
||||||
|
|
||||||
|
// UpgradeHTTP is like HTTPUpgrader{}.Upgrade().
|
||||||
|
func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) {
|
||||||
|
return DefaultHTTPUpgrader.Upgrade(r, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade
|
||||||
|
// function.
|
||||||
|
var DefaultUpgrader Upgrader
|
||||||
|
|
||||||
|
// Upgrade is like Upgrader{}.Upgrade().
|
||||||
|
func Upgrade(conn io.ReadWriter) (Handshake, error) {
|
||||||
|
return DefaultUpgrader.Upgrade(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPUpgrader contains options for upgrading connection to websocket from
|
||||||
|
// net/http Handler arguments.
|
||||||
|
type HTTPUpgrader struct {
|
||||||
|
// Timeout is the maximum amount of time an Upgrade() will spent while
|
||||||
|
// writing handshake response.
|
||||||
|
//
|
||||||
|
// The default is no timeout.
|
||||||
|
Timeout time.Duration
|
||||||
|
|
||||||
|
// Header is an optional http.Header mapping that could be used to
|
||||||
|
// write additional headers to the handshake response.
|
||||||
|
//
|
||||||
|
// Note that if present, it will be written in any result of handshake.
|
||||||
|
Header http.Header
|
||||||
|
|
||||||
|
// Protocol is the select function that is used to select subprotocol from
|
||||||
|
// list requested by client. If this field is set, then the first matched
|
||||||
|
// protocol is sent to a client as negotiated.
|
||||||
|
Protocol func(string) bool
|
||||||
|
|
||||||
|
// Extension is the select function that is used to select extensions from
|
||||||
|
// list requested by client. If this field is set, then the all matched
|
||||||
|
// extensions are sent to a client as negotiated.
|
||||||
|
Extension func(httphead.Option) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade upgrades http connection to the websocket connection.
|
||||||
|
//
|
||||||
|
// It hijacks net.Conn from w and returns received net.Conn and
|
||||||
|
// bufio.ReadWriter. On successful handshake it returns Handshake struct
|
||||||
|
// describing handshake info.
|
||||||
|
func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
|
||||||
|
// Hijack connection first to get the ability to write rejection errors the
|
||||||
|
// same way as in Upgrader.
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if ok {
|
||||||
|
conn, rw, err = hj.Hijack()
|
||||||
|
} else {
|
||||||
|
err = ErrNotHijacker
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
httpError(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-4.1
|
||||||
|
// The method of the request MUST be GET, and the HTTP version MUST be at least 1.1.
|
||||||
|
var nonce string
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
err = ErrHandshakeBadMethod
|
||||||
|
} else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) {
|
||||||
|
err = ErrHandshakeBadProtocol
|
||||||
|
} else if r.Host == "" {
|
||||||
|
err = ErrHandshakeBadHost
|
||||||
|
} else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
|
||||||
|
err = ErrHandshakeBadUpgrade
|
||||||
|
} else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
|
||||||
|
err = ErrHandshakeBadConnection
|
||||||
|
} else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
|
||||||
|
err = ErrHandshakeBadSecKey
|
||||||
|
} else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
|
||||||
|
// According to RFC6455:
|
||||||
|
//
|
||||||
|
// If this version does not match a version understood by the server,
|
||||||
|
// the server MUST abort the WebSocket handshake described in this
|
||||||
|
// section and instead send an appropriate HTTP error code (such as 426
|
||||||
|
// Upgrade Required) and a |Sec-WebSocket-Version| header field
|
||||||
|
// indicating the version(s) the server is capable of understanding.
|
||||||
|
//
|
||||||
|
// So we branching here cause empty or not present version does not
|
||||||
|
// meet the ABNF rules of RFC6455:
|
||||||
|
//
|
||||||
|
// version = DIGIT | (NZDIGIT DIGIT) |
|
||||||
|
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
|
||||||
|
// ; Limited to 0-255 range, with no leading zeros
|
||||||
|
//
|
||||||
|
// That is, if version is really invalid – we sent 426 status, if it
|
||||||
|
// not present or empty – it is 400.
|
||||||
|
if v != "" {
|
||||||
|
err = ErrHandshakeUpgradeRequired
|
||||||
|
} else {
|
||||||
|
err = ErrHandshakeBadSecVersion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if check := u.Protocol; err == nil && check != nil {
|
||||||
|
ps := r.Header[headerSecProtocolCanonical]
|
||||||
|
for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ {
|
||||||
|
var ok bool
|
||||||
|
hs.Protocol, ok = strSelectProtocol(ps[i], check)
|
||||||
|
if !ok {
|
||||||
|
err = ErrMalformedRequest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if check := u.Extension; err == nil && check != nil {
|
||||||
|
xs := r.Header[headerSecExtensionsCanonical]
|
||||||
|
for i := 0; i < len(xs) && err == nil; i++ {
|
||||||
|
var ok bool
|
||||||
|
hs.Extensions, ok = strSelectExtensions(xs[i], hs.Extensions, check)
|
||||||
|
if !ok {
|
||||||
|
err = ErrMalformedRequest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear deadlines set by server.
|
||||||
|
conn.SetDeadline(noDeadline)
|
||||||
|
if t := u.Timeout; t != 0 {
|
||||||
|
conn.SetWriteDeadline(time.Now().Add(t))
|
||||||
|
defer conn.SetWriteDeadline(noDeadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
var header handshakeHeader
|
||||||
|
if h := u.Header; h != nil {
|
||||||
|
header[0] = HandshakeHeaderHTTP(h)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo)
|
||||||
|
err = rw.Writer.Flush()
|
||||||
|
} else {
|
||||||
|
var code int
|
||||||
|
if rej, ok := err.(*rejectConnectionError); ok {
|
||||||
|
code = rej.code
|
||||||
|
header[1] = rej.header
|
||||||
|
}
|
||||||
|
if code == 0 {
|
||||||
|
code = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
httpWriteResponseError(rw.Writer, err, code, header.WriteTo)
|
||||||
|
// Do not store Flush() error to not override already existing one.
|
||||||
|
rw.Writer.Flush()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrader contains options for upgrading connection to websocket.
|
||||||
|
type Upgrader struct {
|
||||||
|
// ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
|
||||||
|
// They used to read and write http data while upgrading to WebSocket.
|
||||||
|
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
|
||||||
|
//
|
||||||
|
// If a size is zero then default value is used.
|
||||||
|
//
|
||||||
|
// Usually it is useful to set read buffer size bigger than write buffer
|
||||||
|
// size because incoming request could contain long header values, such as
|
||||||
|
// Cookie. Response, in other way, could be big only if user write multiple
|
||||||
|
// custom headers. Usually response takes less than 256 bytes.
|
||||||
|
ReadBufferSize, WriteBufferSize int
|
||||||
|
|
||||||
|
// Protocol is a select function that is used to select subprotocol
|
||||||
|
// from list requested by client. If this field is set, then the first matched
|
||||||
|
// protocol is sent to a client as negotiated.
|
||||||
|
//
|
||||||
|
// The argument is only valid until the callback returns.
|
||||||
|
Protocol func([]byte) bool
|
||||||
|
|
||||||
|
// ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually.
|
||||||
|
// Note that returned bytes must be valid until Upgrade returns.
|
||||||
|
// If ProtocolCustom is set, it used instead of Protocol function.
|
||||||
|
ProtocolCustom func([]byte) (string, bool)
|
||||||
|
|
||||||
|
// Extension is a select function that is used to select extensions
|
||||||
|
// from list requested by client. If this field is set, then the all matched
|
||||||
|
// extensions are sent to a client as negotiated.
|
||||||
|
//
|
||||||
|
// The argument is only valid until the callback returns.
|
||||||
|
//
|
||||||
|
// According to the RFC6455 order of extensions passed by a client is
|
||||||
|
// significant. That is, returning true from this function means that no
|
||||||
|
// other extension with the same name should be checked because server
|
||||||
|
// accepted the most preferable extension right now:
|
||||||
|
// "Note that the order of extensions is significant. Any interactions between
|
||||||
|
// multiple extensions MAY be defined in the documents defining the extensions.
|
||||||
|
// In the absence of such definitions, the interpretation is that the header
|
||||||
|
// fields listed by the client in its request represent a preference of the
|
||||||
|
// header fields it wishes to use, with the first options listed being most
|
||||||
|
// preferable."
|
||||||
|
Extension func(httphead.Option) bool
|
||||||
|
|
||||||
|
// ExtensionCustom allow user to parse Sec-WebSocket-Extensions header manually.
|
||||||
|
// Note that returned options should be valid until Upgrade returns.
|
||||||
|
// If ExtensionCustom is set, it used instead of Extension function.
|
||||||
|
ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool)
|
||||||
|
|
||||||
|
// Header is an optional HandshakeHeader instance that could be used to
|
||||||
|
// write additional headers to the handshake response.
|
||||||
|
//
|
||||||
|
// It used instead of any key-value mappings to avoid allocations in user
|
||||||
|
// land.
|
||||||
|
//
|
||||||
|
// Note that if present, it will be written in any result of handshake.
|
||||||
|
Header HandshakeHeader
|
||||||
|
|
||||||
|
// OnRequest is a callback that will be called after request line
|
||||||
|
// successful parsing.
|
||||||
|
//
|
||||||
|
// The arguments are only valid until the callback returns.
|
||||||
|
//
|
||||||
|
// If returned error is non-nil then connection is rejected and response is
|
||||||
|
// sent with appropriate HTTP error code and body set to error message.
|
||||||
|
//
|
||||||
|
// RejectConnectionError could be used to get more control on response.
|
||||||
|
OnRequest func(uri []byte) error
|
||||||
|
|
||||||
|
// OnHost is a callback that will be called after "Host" header successful
|
||||||
|
// parsing.
|
||||||
|
//
|
||||||
|
// It is separated from OnHeader callback because the Host header must be
|
||||||
|
// present in each request since HTTP/1.1. Thus Host header is non-optional
|
||||||
|
// and required for every WebSocket handshake.
|
||||||
|
//
|
||||||
|
// The arguments are only valid until the callback returns.
|
||||||
|
//
|
||||||
|
// If returned error is non-nil then connection is rejected and response is
|
||||||
|
// sent with appropriate HTTP error code and body set to error message.
|
||||||
|
//
|
||||||
|
// RejectConnectionError could be used to get more control on response.
|
||||||
|
OnHost func(host []byte) error
|
||||||
|
|
||||||
|
// OnHeader is a callback that will be called after successful parsing of
|
||||||
|
// header, that is not used during WebSocket handshake procedure. That is,
|
||||||
|
// it will be called with non-websocket headers, which could be relevant
|
||||||
|
// for application-level logic.
|
||||||
|
//
|
||||||
|
// The arguments are only valid until the callback returns.
|
||||||
|
//
|
||||||
|
// If returned error is non-nil then connection is rejected and response is
|
||||||
|
// sent with appropriate HTTP error code and body set to error message.
|
||||||
|
//
|
||||||
|
// RejectConnectionError could be used to get more control on response.
|
||||||
|
OnHeader func(key, value []byte) error
|
||||||
|
|
||||||
|
// OnBeforeUpgrade is a callback that will be called before sending
|
||||||
|
// successful upgrade response.
|
||||||
|
//
|
||||||
|
// Setting OnBeforeUpgrade allows user to make final application-level
|
||||||
|
// checks and decide whether this connection is allowed to successfully
|
||||||
|
// upgrade to WebSocket.
|
||||||
|
//
|
||||||
|
// It must return non-nil either HandshakeHeader or error and never both.
|
||||||
|
//
|
||||||
|
// If returned error is non-nil then connection is rejected and response is
|
||||||
|
// sent with appropriate HTTP error code and body set to error message.
|
||||||
|
//
|
||||||
|
// RejectConnectionError could be used to get more control on response.
|
||||||
|
OnBeforeUpgrade func() (header HandshakeHeader, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade zero-copy upgrades connection to WebSocket. It interprets given conn
|
||||||
|
// as connection with incoming HTTP Upgrade request.
|
||||||
|
//
|
||||||
|
// It is a caller responsibility to manage i/o timeouts on conn.
|
||||||
|
//
|
||||||
|
// Non-nil error means that request for the WebSocket upgrade is invalid or
|
||||||
|
// malformed and usually connection should be closed.
|
||||||
|
// Even when error is non-nil Upgrade will write appropriate response into
|
||||||
|
// connection in compliance with RFC.
|
||||||
|
func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
|
||||||
|
// headerSeen constants helps to report whether or not some header was seen
|
||||||
|
// during reading request bytes.
|
||||||
|
const (
|
||||||
|
headerSeenHost = 1 << iota
|
||||||
|
headerSeenUpgrade
|
||||||
|
headerSeenConnection
|
||||||
|
headerSeenSecVersion
|
||||||
|
headerSeenSecKey
|
||||||
|
|
||||||
|
// headerSeenAll is the value that we expect to receive at the end of
|
||||||
|
// headers read/parse loop.
|
||||||
|
headerSeenAll = 0 |
|
||||||
|
headerSeenHost |
|
||||||
|
headerSeenUpgrade |
|
||||||
|
headerSeenConnection |
|
||||||
|
headerSeenSecVersion |
|
||||||
|
headerSeenSecKey
|
||||||
|
)
|
||||||
|
|
||||||
|
// Prepare I/O buffers.
|
||||||
|
// TODO(gobwas): make it configurable.
|
||||||
|
br := pbufio.GetReader(conn,
|
||||||
|
nonZero(u.ReadBufferSize, DefaultServerReadBufferSize),
|
||||||
|
)
|
||||||
|
bw := pbufio.GetWriter(conn,
|
||||||
|
nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize),
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
pbufio.PutReader(br)
|
||||||
|
pbufio.PutWriter(bw)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Read HTTP request line like "GET /ws HTTP/1.1".
|
||||||
|
rl, err := readLine(br)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Parse request line data like HTTP version, uri and method.
|
||||||
|
req, err := httpParseRequestLine(rl)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare stack-based handshake header list.
|
||||||
|
header := handshakeHeader{
|
||||||
|
0: u.Header,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and check HTTP request.
|
||||||
|
// As RFC6455 says:
|
||||||
|
// The client's opening handshake consists of the following parts. If the
|
||||||
|
// server, while reading the handshake, finds that the client did not
|
||||||
|
// send a handshake that matches the description below (note that as per
|
||||||
|
// [RFC2616], the order of the header fields is not important), including
|
||||||
|
// but not limited to any violations of the ABNF grammar specified for
|
||||||
|
// the components of the handshake, the server MUST stop processing the
|
||||||
|
// client's handshake and return an HTTP response with an appropriate
|
||||||
|
// error code (such as 400 Bad Request).
|
||||||
|
//
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-4.2.1
|
||||||
|
|
||||||
|
// An HTTP/1.1 or higher GET request, including a "Request-URI".
|
||||||
|
//
|
||||||
|
// Even if RFC says "1.1 or higher" without mentioning the part of the
|
||||||
|
// version, we apply it only to minor part.
|
||||||
|
switch {
|
||||||
|
case req.major != 1 || req.minor < 1:
|
||||||
|
// Abort processing the whole request because we do not even know how
|
||||||
|
// to actually parse it.
|
||||||
|
err = ErrHandshakeBadProtocol
|
||||||
|
|
||||||
|
case btsToString(req.method) != http.MethodGet:
|
||||||
|
err = ErrHandshakeBadMethod
|
||||||
|
|
||||||
|
default:
|
||||||
|
if onRequest := u.OnRequest; onRequest != nil {
|
||||||
|
err = onRequest(req.uri)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Start headers read/parse loop.
|
||||||
|
var (
|
||||||
|
// headerSeen reports which header was seen by setting corresponding
|
||||||
|
// bit on.
|
||||||
|
headerSeen byte
|
||||||
|
|
||||||
|
nonce = make([]byte, nonceSize)
|
||||||
|
)
|
||||||
|
for err == nil {
|
||||||
|
line, e := readLine(br)
|
||||||
|
if e != nil {
|
||||||
|
return hs, e
|
||||||
|
}
|
||||||
|
if len(line) == 0 {
|
||||||
|
// Blank line, no more lines to read.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
k, v, ok := httpParseHeaderLine(line)
|
||||||
|
if !ok {
|
||||||
|
err = ErrMalformedRequest
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch btsToString(k) {
|
||||||
|
case headerHostCanonical:
|
||||||
|
headerSeen |= headerSeenHost
|
||||||
|
if onHost := u.OnHost; onHost != nil {
|
||||||
|
err = onHost(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerUpgradeCanonical:
|
||||||
|
headerSeen |= headerSeenUpgrade
|
||||||
|
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
|
||||||
|
err = ErrHandshakeBadUpgrade
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerConnectionCanonical:
|
||||||
|
headerSeen |= headerSeenConnection
|
||||||
|
if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) {
|
||||||
|
err = ErrHandshakeBadConnection
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerSecVersionCanonical:
|
||||||
|
headerSeen |= headerSeenSecVersion
|
||||||
|
if !bytes.Equal(v, specHeaderValueSecVersion) {
|
||||||
|
err = ErrHandshakeUpgradeRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerSecKeyCanonical:
|
||||||
|
headerSeen |= headerSeenSecKey
|
||||||
|
if len(v) != nonceSize {
|
||||||
|
err = ErrHandshakeBadSecKey
|
||||||
|
} else {
|
||||||
|
copy(nonce[:], v)
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerSecProtocolCanonical:
|
||||||
|
if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) {
|
||||||
|
var ok bool
|
||||||
|
if custom != nil {
|
||||||
|
hs.Protocol, ok = custom(v)
|
||||||
|
} else {
|
||||||
|
hs.Protocol, ok = btsSelectProtocol(v, check)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
err = ErrMalformedRequest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case headerSecExtensionsCanonical:
|
||||||
|
if custom, check := u.ExtensionCustom, u.Extension; custom != nil || check != nil {
|
||||||
|
var ok bool
|
||||||
|
if custom != nil {
|
||||||
|
hs.Extensions, ok = custom(v, hs.Extensions)
|
||||||
|
} else {
|
||||||
|
hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
err = ErrMalformedRequest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
if onHeader := u.OnHeader; onHeader != nil {
|
||||||
|
err = onHeader(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case err == nil && headerSeen != headerSeenAll:
|
||||||
|
switch {
|
||||||
|
case headerSeen&headerSeenHost == 0:
|
||||||
|
// As RFC2616 says:
|
||||||
|
// A client MUST include a Host header field in all HTTP/1.1
|
||||||
|
// request messages. If the requested URI does not include an
|
||||||
|
// Internet host name for the service being requested, then the
|
||||||
|
// Host header field MUST be given with an empty value. An
|
||||||
|
// HTTP/1.1 proxy MUST ensure that any request message it
|
||||||
|
// forwards does contain an appropriate Host header field that
|
||||||
|
// identifies the service being requested by the proxy. All
|
||||||
|
// Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad
|
||||||
|
// Request) status code to any HTTP/1.1 request message which
|
||||||
|
// lacks a Host header field.
|
||||||
|
err = ErrHandshakeBadHost
|
||||||
|
case headerSeen&headerSeenUpgrade == 0:
|
||||||
|
err = ErrHandshakeBadUpgrade
|
||||||
|
case headerSeen&headerSeenConnection == 0:
|
||||||
|
err = ErrHandshakeBadConnection
|
||||||
|
case headerSeen&headerSeenSecVersion == 0:
|
||||||
|
// In case of empty or not present version we do not send 426 status,
|
||||||
|
// because it does not meet the ABNF rules of RFC6455:
|
||||||
|
//
|
||||||
|
// version = DIGIT | (NZDIGIT DIGIT) |
|
||||||
|
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
|
||||||
|
// ; Limited to 0-255 range, with no leading zeros
|
||||||
|
//
|
||||||
|
// That is, if version is really invalid – we sent 426 status as above, if it
|
||||||
|
// not present – it is 400.
|
||||||
|
err = ErrHandshakeBadSecVersion
|
||||||
|
case headerSeen&headerSeenSecKey == 0:
|
||||||
|
err = ErrHandshakeBadSecKey
|
||||||
|
default:
|
||||||
|
panic("unknown headers state")
|
||||||
|
}
|
||||||
|
|
||||||
|
case err == nil && u.OnBeforeUpgrade != nil:
|
||||||
|
header[1], err = u.OnBeforeUpgrade()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
var code int
|
||||||
|
if rej, ok := err.(*rejectConnectionError); ok {
|
||||||
|
code = rej.code
|
||||||
|
header[1] = rej.header
|
||||||
|
}
|
||||||
|
if code == 0 {
|
||||||
|
code = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
httpWriteResponseError(bw, err, code, header.WriteTo)
|
||||||
|
// Do not store Flush() error to not override already existing one.
|
||||||
|
bw.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo)
|
||||||
|
err = bw.Flush()
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type handshakeHeader [2]HandshakeHeader
|
||||||
|
|
||||||
|
func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
|
for i := 0; i < len(hs) && err == nil; i++ {
|
||||||
|
if h := hs[i]; h != nil {
|
||||||
|
var m int64
|
||||||
|
m, err = h.WriteTo(w)
|
||||||
|
n += m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
|
@ -0,0 +1,214 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/gobwas/httphead"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SelectFromSlice creates accept function that could be used as Protocol/Extension
|
||||||
|
// select during upgrade.
|
||||||
|
func SelectFromSlice(accept []string) func(string) bool {
|
||||||
|
if len(accept) > 16 {
|
||||||
|
mp := make(map[string]struct{}, len(accept))
|
||||||
|
for _, p := range accept {
|
||||||
|
mp[p] = struct{}{}
|
||||||
|
}
|
||||||
|
return func(p string) bool {
|
||||||
|
_, ok := mp[p]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return func(p string) bool {
|
||||||
|
for _, ok := range accept {
|
||||||
|
if p == ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectEqual creates accept function that could be used as Protocol/Extension
|
||||||
|
// select during upgrade.
|
||||||
|
func SelectEqual(v string) func(string) bool {
|
||||||
|
return func(p string) bool {
|
||||||
|
return v == p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func strToBytes(str string) (bts []byte) {
|
||||||
|
s := (*reflect.StringHeader)(unsafe.Pointer(&str))
|
||||||
|
b := (*reflect.SliceHeader)(unsafe.Pointer(&bts))
|
||||||
|
b.Data = s.Data
|
||||||
|
b.Len = s.Len
|
||||||
|
b.Cap = s.Len
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func btsToString(bts []byte) (str string) {
|
||||||
|
return *(*string)(unsafe.Pointer(&bts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// asciiToInt converts bytes to int.
|
||||||
|
func asciiToInt(bts []byte) (ret int, err error) {
|
||||||
|
// ASCII numbers all start with the high-order bits 0011.
|
||||||
|
// If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those
|
||||||
|
// bits and interpret them directly as an integer.
|
||||||
|
var n int
|
||||||
|
if n = len(bts); n < 1 {
|
||||||
|
return 0, fmt.Errorf("converting empty bytes to int")
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if bts[i]&0xf0 != 0x30 {
|
||||||
|
return 0, fmt.Errorf("%s is not a numeric character", string(bts[i]))
|
||||||
|
}
|
||||||
|
ret += int(bts[i]&0xf) * pow(10, n-i-1)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pow for integers implementation.
|
||||||
|
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3
|
||||||
|
func pow(a, b int) int {
|
||||||
|
p := 1
|
||||||
|
for b > 0 {
|
||||||
|
if b&1 != 0 {
|
||||||
|
p *= a
|
||||||
|
}
|
||||||
|
b >>= 1
|
||||||
|
a *= a
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func bsplit3(bts []byte, sep byte) (b1, b2, b3 []byte) {
|
||||||
|
a := bytes.IndexByte(bts, sep)
|
||||||
|
b := bytes.IndexByte(bts[a+1:], sep)
|
||||||
|
if a == -1 || b == -1 {
|
||||||
|
return bts, nil, nil
|
||||||
|
}
|
||||||
|
b += a + 1
|
||||||
|
return bts[:a], bts[a+1 : b], bts[b+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func btrim(bts []byte) []byte {
|
||||||
|
var i, j int
|
||||||
|
for i = 0; i < len(bts) && (bts[i] == ' ' || bts[i] == '\t'); {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
for j = len(bts); j > i && (bts[j-1] == ' ' || bts[j-1] == '\t'); {
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
return bts[i:j]
|
||||||
|
}
|
||||||
|
|
||||||
|
func strHasToken(header, token string) (has bool) {
|
||||||
|
return btsHasToken(strToBytes(header), strToBytes(token))
|
||||||
|
}
|
||||||
|
|
||||||
|
func btsHasToken(header, token []byte) (has bool) {
|
||||||
|
httphead.ScanTokens(header, func(v []byte) bool {
|
||||||
|
has = bytes.EqualFold(v, token)
|
||||||
|
return !has
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
toLower = 'a' - 'A' // for use with OR.
|
||||||
|
toUpper = ^byte(toLower) // for use with AND.
|
||||||
|
toLower8 = uint64(toLower) |
|
||||||
|
uint64(toLower)<<8 |
|
||||||
|
uint64(toLower)<<16 |
|
||||||
|
uint64(toLower)<<24 |
|
||||||
|
uint64(toLower)<<32 |
|
||||||
|
uint64(toLower)<<40 |
|
||||||
|
uint64(toLower)<<48 |
|
||||||
|
uint64(toLower)<<56
|
||||||
|
)
|
||||||
|
|
||||||
|
// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except
|
||||||
|
// that it operates with slice of bytes and modifies it inplace without copying.
|
||||||
|
func canonicalizeHeaderKey(k []byte) {
|
||||||
|
upper := true
|
||||||
|
for i, c := range k {
|
||||||
|
if upper && 'a' <= c && c <= 'z' {
|
||||||
|
k[i] &= toUpper
|
||||||
|
} else if !upper && 'A' <= c && c <= 'Z' {
|
||||||
|
k[i] |= toLower
|
||||||
|
}
|
||||||
|
upper = c == '-'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLine reads line from br. It reads until '\n' and returns bytes without
|
||||||
|
// '\n' or '\r\n' at the end.
|
||||||
|
// It returns err if and only if line does not end in '\n'. Note that read
|
||||||
|
// bytes returned in any case of error.
|
||||||
|
//
|
||||||
|
// It is much like the textproto/Reader.ReadLine() except the thing that it
|
||||||
|
// returns raw bytes, instead of string. That is, it avoids copying bytes read
|
||||||
|
// from br.
|
||||||
|
//
|
||||||
|
// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be
|
||||||
|
// safe with future I/O operations on br.
|
||||||
|
//
|
||||||
|
// We could control I/O operations on br and do not need to make additional
|
||||||
|
// copy for safety.
|
||||||
|
//
|
||||||
|
// NOTE: it may return copied flag to notify that returned buffer is safe to
|
||||||
|
// use.
|
||||||
|
func readLine(br *bufio.Reader) ([]byte, error) {
|
||||||
|
var line []byte
|
||||||
|
for {
|
||||||
|
bts, err := br.ReadSlice('\n')
|
||||||
|
if err == bufio.ErrBufferFull {
|
||||||
|
// Copy bytes because next read will discard them.
|
||||||
|
line = append(line, bts...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Avoid copy of single read.
|
||||||
|
if line == nil {
|
||||||
|
line = bts
|
||||||
|
} else {
|
||||||
|
line = append(line, bts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return line, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size of line is at least 1.
|
||||||
|
// In other case bufio.ReadSlice() returns error.
|
||||||
|
n := len(line)
|
||||||
|
|
||||||
|
// Cut '\n' or '\r\n'.
|
||||||
|
if n > 1 && line[n-2] == '\r' {
|
||||||
|
line = line[:n-2]
|
||||||
|
} else {
|
||||||
|
line = line[:n-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return line, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func nonZero(a, b int) int {
|
||||||
|
if a != 0 {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
|
@ -0,0 +1,104 @@
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Header size length bounds in bytes.
|
||||||
|
const (
|
||||||
|
MaxHeaderSize = 14
|
||||||
|
MinHeaderSize = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bit0 = 0x80
|
||||||
|
bit1 = 0x40
|
||||||
|
bit2 = 0x20
|
||||||
|
bit3 = 0x10
|
||||||
|
bit4 = 0x08
|
||||||
|
bit5 = 0x04
|
||||||
|
bit6 = 0x02
|
||||||
|
bit7 = 0x01
|
||||||
|
|
||||||
|
len7 = int64(125)
|
||||||
|
len16 = int64(^(uint16(0)))
|
||||||
|
len64 = int64(^(uint64(0)) >> 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeaderSize returns number of bytes that are needed to encode given header.
|
||||||
|
// It returns -1 if header is malformed.
|
||||||
|
func HeaderSize(h Header) (n int) {
|
||||||
|
switch {
|
||||||
|
case h.Length < 126:
|
||||||
|
n = 2
|
||||||
|
case h.Length <= len16:
|
||||||
|
n = 4
|
||||||
|
case h.Length <= len64:
|
||||||
|
n = 10
|
||||||
|
default:
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if h.Masked {
|
||||||
|
n += len(h.Mask)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader writes header binary representation into w.
|
||||||
|
func WriteHeader(w io.Writer, h Header) error {
|
||||||
|
// Make slice of bytes with capacity 14 that could hold any header.
|
||||||
|
bts := make([]byte, MaxHeaderSize)
|
||||||
|
|
||||||
|
if h.Fin {
|
||||||
|
bts[0] |= bit0
|
||||||
|
}
|
||||||
|
bts[0] |= h.Rsv << 4
|
||||||
|
bts[0] |= byte(h.OpCode)
|
||||||
|
|
||||||
|
var n int
|
||||||
|
switch {
|
||||||
|
case h.Length <= len7:
|
||||||
|
bts[1] = byte(h.Length)
|
||||||
|
n = 2
|
||||||
|
|
||||||
|
case h.Length <= len16:
|
||||||
|
bts[1] = 126
|
||||||
|
binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length))
|
||||||
|
n = 4
|
||||||
|
|
||||||
|
case h.Length <= len64:
|
||||||
|
bts[1] = 127
|
||||||
|
binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length))
|
||||||
|
n = 10
|
||||||
|
|
||||||
|
default:
|
||||||
|
return ErrHeaderLengthUnexpected
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.Masked {
|
||||||
|
bts[1] |= bit0
|
||||||
|
n += copy(bts[n:], h.Mask[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := w.Write(bts[:n])
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFrame writes frame binary representation into w.
|
||||||
|
func WriteFrame(w io.Writer, f Frame) error {
|
||||||
|
err := WriteHeader(w, f.Header)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = w.Write(f.Payload)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustWriteFrame is like WriteFrame but panics if frame can not be read.
|
||||||
|
func MustWriteFrame(w io.Writer, f Frame) {
|
||||||
|
if err := WriteFrame(w, f); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,72 @@
|
||||||
|
package wsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/gobwas/pool/pbytes"
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CipherReader implements io.Reader that applies xor-cipher to the bytes read
|
||||||
|
// from source.
|
||||||
|
// It could help to unmask WebSocket frame payload on the fly.
|
||||||
|
type CipherReader struct {
|
||||||
|
r io.Reader
|
||||||
|
mask [4]byte
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCipherReader creates xor-cipher reader from r with given mask.
|
||||||
|
func NewCipherReader(r io.Reader, mask [4]byte) *CipherReader {
|
||||||
|
return &CipherReader{r, mask, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets CipherReader to read from r with given mask.
|
||||||
|
func (c *CipherReader) Reset(r io.Reader, mask [4]byte) {
|
||||||
|
c.r = r
|
||||||
|
c.mask = mask
|
||||||
|
c.pos = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements io.Reader interface. It applies mask given during
|
||||||
|
// initialization to every read byte.
|
||||||
|
func (c *CipherReader) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = c.r.Read(p)
|
||||||
|
ws.Cipher(p[:n], c.mask, c.pos)
|
||||||
|
c.pos += n
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CipherWriter implements io.Writer that applies xor-cipher to the bytes
|
||||||
|
// written to the destination writer. It does not modify the original bytes.
|
||||||
|
type CipherWriter struct {
|
||||||
|
w io.Writer
|
||||||
|
mask [4]byte
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCipherWriter creates xor-cipher writer to w with given mask.
|
||||||
|
func NewCipherWriter(w io.Writer, mask [4]byte) *CipherWriter {
|
||||||
|
return &CipherWriter{w, mask, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset reset CipherWriter to write to w with given mask.
|
||||||
|
func (c *CipherWriter) Reset(w io.Writer, mask [4]byte) {
|
||||||
|
c.w = w
|
||||||
|
c.mask = mask
|
||||||
|
c.pos = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer interface. It applies masking during
|
||||||
|
// initialization to every sent byte. It does not modify original slice.
|
||||||
|
func (c *CipherWriter) Write(p []byte) (n int, err error) {
|
||||||
|
cp := pbytes.GetLen(len(p))
|
||||||
|
defer pbytes.Put(cp)
|
||||||
|
|
||||||
|
copy(cp, p)
|
||||||
|
ws.Cipher(cp, c.mask, c.pos)
|
||||||
|
n, err = c.w.Write(cp)
|
||||||
|
c.pos += n
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,146 @@
|
||||||
|
package wsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
|
||||||
|
// handshake. That is, it gives ability to receive copied HTTP request and
|
||||||
|
// response bytes that made inside Dialer.Dial().
|
||||||
|
//
|
||||||
|
// Note that it must not be used in production applications that requires
|
||||||
|
// Dial() to be efficient.
|
||||||
|
type DebugDialer struct {
|
||||||
|
// Dialer contains WebSocket connection establishment options.
|
||||||
|
Dialer ws.Dialer
|
||||||
|
|
||||||
|
// OnRequest and OnResponse are the callbacks that will be called with the
|
||||||
|
// HTTP request and response respectively.
|
||||||
|
OnRequest, OnResponse func([]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the url host and upgrades connection to WebSocket. It makes
|
||||||
|
// it by calling d.Dialer.Dial().
|
||||||
|
func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
|
||||||
|
// Need to copy Dialer to prevent original object mutation.
|
||||||
|
dialer := d.Dialer
|
||||||
|
var (
|
||||||
|
reqBuf bytes.Buffer
|
||||||
|
resBuf bytes.Buffer
|
||||||
|
|
||||||
|
resContentLength int64
|
||||||
|
)
|
||||||
|
userWrap := dialer.WrapConn
|
||||||
|
dialer.WrapConn = func(c net.Conn) net.Conn {
|
||||||
|
if userWrap != nil {
|
||||||
|
c = userWrap(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the pointer to the raw connection.
|
||||||
|
conn = c
|
||||||
|
|
||||||
|
var (
|
||||||
|
r io.Reader = conn
|
||||||
|
w io.Writer = conn
|
||||||
|
)
|
||||||
|
if d.OnResponse != nil {
|
||||||
|
r = &prefetchResponseReader{
|
||||||
|
source: conn,
|
||||||
|
buffer: &resBuf,
|
||||||
|
contentLength: &resContentLength,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.OnRequest != nil {
|
||||||
|
w = io.MultiWriter(conn, &reqBuf)
|
||||||
|
}
|
||||||
|
return rwConn{conn, r, w}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, br, hs, err = dialer.Dial(ctx, urlstr)
|
||||||
|
|
||||||
|
if onRequest := d.OnRequest; onRequest != nil {
|
||||||
|
onRequest(reqBuf.Bytes())
|
||||||
|
}
|
||||||
|
if onResponse := d.OnResponse; onResponse != nil {
|
||||||
|
// We must split response inside buffered bytes from other received
|
||||||
|
// bytes from server.
|
||||||
|
p := resBuf.Bytes()
|
||||||
|
n := bytes.Index(p, headEnd)
|
||||||
|
h := n + len(headEnd) // Head end index.
|
||||||
|
n = h + int(resContentLength) // Body end index.
|
||||||
|
|
||||||
|
onResponse(p[:n])
|
||||||
|
|
||||||
|
if br != nil {
|
||||||
|
// If br is non-nil, then it mean two things. First is that
|
||||||
|
// handshake is OK and server has sent additional bytes – probably
|
||||||
|
// immediate sent frames (or weird but possible response body).
|
||||||
|
// Second, the bad one, is that br buffer's source is now rwConn
|
||||||
|
// instance from above WrapConn call. It is incorrect, so we must
|
||||||
|
// fix it.
|
||||||
|
var r io.Reader = conn
|
||||||
|
if len(p) > h {
|
||||||
|
// Buffer contains more than just HTTP headers bytes.
|
||||||
|
r = io.MultiReader(
|
||||||
|
bytes.NewReader(p[h:]),
|
||||||
|
conn,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
br.Reset(r)
|
||||||
|
// Must make br.Buffered() to be non-zero.
|
||||||
|
br.Peek(len(p[h:]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, br, hs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type rwConn struct {
|
||||||
|
net.Conn
|
||||||
|
|
||||||
|
r io.Reader
|
||||||
|
w io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rwc rwConn) Read(p []byte) (int, error) {
|
||||||
|
return rwc.r.Read(p)
|
||||||
|
}
|
||||||
|
func (rwc rwConn) Write(p []byte) (int, error) {
|
||||||
|
return rwc.w.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
var headEnd = []byte("\r\n\r\n")
|
||||||
|
|
||||||
|
type prefetchResponseReader struct {
|
||||||
|
source io.Reader // Original connection source.
|
||||||
|
reader io.Reader // Wrapped reader used to read from by clients.
|
||||||
|
buffer *bytes.Buffer
|
||||||
|
|
||||||
|
contentLength *int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
|
||||||
|
if r.reader == nil {
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(
|
||||||
|
io.TeeReader(r.source, r.buffer),
|
||||||
|
), nil)
|
||||||
|
if err == nil {
|
||||||
|
*r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
bts := r.buffer.Bytes()
|
||||||
|
r.reader = io.MultiReader(
|
||||||
|
bytes.NewReader(bts),
|
||||||
|
r.source,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return r.reader.Read(p)
|
||||||
|
}
|
|
@ -0,0 +1,219 @@
|
||||||
|
package wsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gobwas/pool/pbytes"
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClosedError returned when peer has closed the connection with appropriate
|
||||||
|
// code and a textual reason.
|
||||||
|
type ClosedError struct {
|
||||||
|
Code ws.StatusCode
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements error interface.
|
||||||
|
func (err ClosedError) Error() string {
|
||||||
|
return "ws closed: " + strconv.FormatUint(uint64(err.Code), 10) + " " + err.Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
// ControlHandler contains logic of handling control frames.
|
||||||
|
//
|
||||||
|
// The intentional way to use it is to read the next frame header from the
|
||||||
|
// connection, optionally check its validity via ws.CheckHeader() and if it is
|
||||||
|
// not a ws.OpText of ws.OpBinary (or ws.OpContinuation) – pass it to Handle()
|
||||||
|
// method.
|
||||||
|
//
|
||||||
|
// That is, passed header should be checked to get rid of unexpected errors.
|
||||||
|
//
|
||||||
|
// The Handle() method will read out all control frame payload (if any) and
|
||||||
|
// write necessary bytes as a rfc compatible response.
|
||||||
|
type ControlHandler struct {
|
||||||
|
Src io.Reader
|
||||||
|
Dst io.Writer
|
||||||
|
State ws.State
|
||||||
|
|
||||||
|
// DisableSrcCiphering disables unmasking payload data read from Src.
|
||||||
|
// It is useful when wsutil.Reader is used or when frame payload already
|
||||||
|
// pulled and ciphered out from the connection (and introduced by
|
||||||
|
// bytes.Reader, for example).
|
||||||
|
DisableSrcCiphering bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrNotControlFrame is returned by ControlHandler to indicate that given
|
||||||
|
// header could not be handled.
|
||||||
|
var ErrNotControlFrame = errors.New("not a control frame")
|
||||||
|
|
||||||
|
// Handle handles control frames regarding to the c.State and writes responses
|
||||||
|
// to the c.Dst when needed.
|
||||||
|
//
|
||||||
|
// It returns ErrNotControlFrame when given header is not of ws.OpClose,
|
||||||
|
// ws.OpPing or ws.OpPong operation code.
|
||||||
|
func (c ControlHandler) Handle(h ws.Header) error {
|
||||||
|
switch h.OpCode {
|
||||||
|
case ws.OpPing:
|
||||||
|
return c.HandlePing(h)
|
||||||
|
case ws.OpPong:
|
||||||
|
return c.HandlePong(h)
|
||||||
|
case ws.OpClose:
|
||||||
|
return c.HandleClose(h)
|
||||||
|
}
|
||||||
|
return ErrNotControlFrame
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlePing handles ping frame and writes specification compatible response
|
||||||
|
// to the c.Dst.
|
||||||
|
func (c ControlHandler) HandlePing(h ws.Header) error {
|
||||||
|
if h.Length == 0 {
|
||||||
|
// The most common case when ping is empty.
|
||||||
|
// Note that when sending masked frame the mask for empty payload is
|
||||||
|
// just four zero bytes.
|
||||||
|
return ws.WriteHeader(c.Dst, ws.Header{
|
||||||
|
Fin: true,
|
||||||
|
OpCode: ws.OpPong,
|
||||||
|
Masked: c.State.ClientSide(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// In other way reply with Pong frame with copied payload.
|
||||||
|
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
|
||||||
|
Length: h.Length,
|
||||||
|
Masked: c.State.ClientSide(),
|
||||||
|
}))
|
||||||
|
defer pbytes.Put(p)
|
||||||
|
|
||||||
|
// Deal with ciphering i/o:
|
||||||
|
// Masking key is used to mask the "Payload data" defined in the same
|
||||||
|
// section as frame-payload-data, which includes "Extension data" and
|
||||||
|
// "Application data".
|
||||||
|
//
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.3
|
||||||
|
//
|
||||||
|
// NOTE: We prefer ControlWriter with preallocated buffer to
|
||||||
|
// ws.WriteHeader because it performs one syscall instead of two.
|
||||||
|
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpPong, p)
|
||||||
|
r := c.Src
|
||||||
|
if c.State.ServerSide() && !c.DisableSrcCiphering {
|
||||||
|
r = NewCipherReader(r, h.Mask)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := io.Copy(w, r)
|
||||||
|
if err == nil {
|
||||||
|
err = w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlePong handles pong frame by discarding it.
|
||||||
|
func (c ControlHandler) HandlePong(h ws.Header) error {
|
||||||
|
if h.Length == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := pbytes.GetLen(int(h.Length))
|
||||||
|
defer pbytes.Put(buf)
|
||||||
|
|
||||||
|
// Discard pong message according to the RFC6455:
|
||||||
|
// A Pong frame MAY be sent unsolicited. This serves as a
|
||||||
|
// unidirectional heartbeat. A response to an unsolicited Pong frame
|
||||||
|
// is not expected.
|
||||||
|
_, err := io.CopyBuffer(ioutil.Discard, c.Src, buf)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleClose handles close frame, makes protocol validity checks and writes
|
||||||
|
// specification compatible response to the c.Dst.
|
||||||
|
func (c ControlHandler) HandleClose(h ws.Header) error {
|
||||||
|
if h.Length == 0 {
|
||||||
|
err := ws.WriteHeader(c.Dst, ws.Header{
|
||||||
|
Fin: true,
|
||||||
|
OpCode: ws.OpClose,
|
||||||
|
Masked: c.State.ClientSide(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Due to RFC, we should interpret the code as no status code
|
||||||
|
// received:
|
||||||
|
// If this Close control frame contains no status code, _The WebSocket
|
||||||
|
// Connection Close Code_ is considered to be 1005.
|
||||||
|
//
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-7.1.5
|
||||||
|
return ClosedError{
|
||||||
|
Code: ws.StatusNoStatusRcvd,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare bytes both for reading reason and sending response.
|
||||||
|
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
|
||||||
|
Length: h.Length,
|
||||||
|
Masked: c.State.ClientSide(),
|
||||||
|
}))
|
||||||
|
defer pbytes.Put(p)
|
||||||
|
|
||||||
|
// Get the subslice to read the frame payload out.
|
||||||
|
subp := p[:h.Length]
|
||||||
|
|
||||||
|
r := c.Src
|
||||||
|
if c.State.ServerSide() && !c.DisableSrcCiphering {
|
||||||
|
r = NewCipherReader(r, h.Mask)
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(r, subp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
code, reason := ws.ParseCloseFrameData(subp)
|
||||||
|
if err := ws.CheckCloseFrameData(code, reason); err != nil {
|
||||||
|
// Here we could not use the prepared bytes because there is no
|
||||||
|
// guarantee that it may fit our protocol error closure code and a
|
||||||
|
// reason.
|
||||||
|
c.closeWithProtocolError(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deal with ciphering i/o:
|
||||||
|
// Masking key is used to mask the "Payload data" defined in the same
|
||||||
|
// section as frame-payload-data, which includes "Extension data" and
|
||||||
|
// "Application data".
|
||||||
|
//
|
||||||
|
// See https://tools.ietf.org/html/rfc6455#section-5.3
|
||||||
|
//
|
||||||
|
// NOTE: We prefer ControlWriter with preallocated buffer to
|
||||||
|
// ws.WriteHeader because it performs one syscall instead of two.
|
||||||
|
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpClose, p)
|
||||||
|
|
||||||
|
// RFC6455#5.5.1:
|
||||||
|
// If an endpoint receives a Close frame and did not previously
|
||||||
|
// send a Close frame, the endpoint MUST send a Close frame in
|
||||||
|
// response. (When sending a Close frame in response, the endpoint
|
||||||
|
// typically echoes the status code it received.)
|
||||||
|
_, err := w.Write(p[:2])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = w.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return ClosedError{
|
||||||
|
Code: code,
|
||||||
|
Reason: reason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ControlHandler) closeWithProtocolError(reason error) error {
|
||||||
|
f := ws.NewCloseFrame(ws.NewCloseFrameBody(
|
||||||
|
ws.StatusProtocolError, reason.Error(),
|
||||||
|
))
|
||||||
|
if c.State.ClientSide() {
|
||||||
|
ws.MaskFrameInPlace(f)
|
||||||
|
}
|
||||||
|
return ws.WriteFrame(c.Dst, f)
|
||||||
|
}
|
|
@ -0,0 +1,279 @@
|
||||||
|
package wsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message represents a message from peer, that could be presented in one or
|
||||||
|
// more frames. That is, it contains payload of all message fragments and
|
||||||
|
// operation code of initial frame for this message.
|
||||||
|
type Message struct {
|
||||||
|
OpCode ws.OpCode
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadMessage is a helper function that reads next message from r. It appends
|
||||||
|
// received message(s) to the third argument and returns the result of it and
|
||||||
|
// an error if some failure happened. That is, it probably could receive more
|
||||||
|
// than one message when peer sending fragmented message in multiple frames and
|
||||||
|
// want to send some control frame between fragments. Then returned slice will
|
||||||
|
// contain those control frames at first, and then result of gluing fragments.
|
||||||
|
//
|
||||||
|
// TODO(gobwas): add DefaultReader with buffer size options.
|
||||||
|
func ReadMessage(r io.Reader, s ws.State, m []Message) ([]Message, error) {
|
||||||
|
rd := Reader{
|
||||||
|
Source: r,
|
||||||
|
State: s,
|
||||||
|
CheckUTF8: true,
|
||||||
|
OnIntermediate: func(hdr ws.Header, src io.Reader) error {
|
||||||
|
bts, err := ioutil.ReadAll(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m = append(m, Message{hdr.OpCode, bts})
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h, err := rd.NextFrame()
|
||||||
|
if err != nil {
|
||||||
|
return m, err
|
||||||
|
}
|
||||||
|
var p []byte
|
||||||
|
if h.Fin {
|
||||||
|
// No more frames will be read. Use fixed sized buffer to read payload.
|
||||||
|
p = make([]byte, h.Length)
|
||||||
|
// It is not possible to receive io.EOF here because Reader does not
|
||||||
|
// return EOF if frame payload was successfully fetched.
|
||||||
|
// Thus we consistent here with io.Reader behavior.
|
||||||
|
_, err = io.ReadFull(&rd, p)
|
||||||
|
} else {
|
||||||
|
// Frame is fragmented, thus use ioutil.ReadAll behavior.
|
||||||
|
var buf bytes.Buffer
|
||||||
|
_, err = buf.ReadFrom(&rd)
|
||||||
|
p = buf.Bytes()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return m, err
|
||||||
|
}
|
||||||
|
return append(m, Message{h.OpCode, p}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadClientMessage reads next message from r, considering that caller
|
||||||
|
// represents server side.
|
||||||
|
// It is a shortcut for ReadMessage(r, ws.StateServerSide, m)
|
||||||
|
func ReadClientMessage(r io.Reader, m []Message) ([]Message, error) {
|
||||||
|
return ReadMessage(r, ws.StateServerSide, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadServerMessage reads next message from r, considering that caller
|
||||||
|
// represents client side.
|
||||||
|
// It is a shortcut for ReadMessage(r, ws.StateClientSide, m)
|
||||||
|
func ReadServerMessage(r io.Reader, m []Message) ([]Message, error) {
|
||||||
|
return ReadMessage(r, ws.StateClientSide, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadData is a helper function that reads next data (non-control) message
|
||||||
|
// from rw.
|
||||||
|
// It takes care on handling all control frames. It will write response on
|
||||||
|
// control frames to the write part of rw. It blocks until some data frame
|
||||||
|
// will be received.
|
||||||
|
//
|
||||||
|
// Note this may handle and write control frames into the writer part of a
|
||||||
|
// given io.ReadWriter.
|
||||||
|
func ReadData(rw io.ReadWriter, s ws.State) ([]byte, ws.OpCode, error) {
|
||||||
|
return readData(rw, s, ws.OpText|ws.OpBinary)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadClientData reads next data message from rw, considering that caller
|
||||||
|
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
|
||||||
|
//
|
||||||
|
// Note this may handle and write control frames into the writer part of a
|
||||||
|
// given io.ReadWriter.
|
||||||
|
func ReadClientData(rw io.ReadWriter) ([]byte, ws.OpCode, error) {
|
||||||
|
return ReadData(rw, ws.StateServerSide)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadClientText reads next text message from rw, considering that caller
|
||||||
|
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
|
||||||
|
// It discards received binary messages.
|
||||||
|
//
|
||||||
|
// Note this may handle and write control frames into the writer part of a
|
||||||
|
// given io.ReadWriter.
|
||||||
|
func ReadClientText(rw io.ReadWriter) ([]byte, error) {
|
||||||
|
p, _, err := readData(rw, ws.StateServerSide, ws.OpText)
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadClientBinary reads next binary message from rw, considering that caller
|
||||||
|
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
|
||||||
|
// It discards received text messages.
|
||||||
|
//
|
||||||
|
// Note this may handle and write control frames into the writer part of a given
|
||||||
|
// io.ReadWriter.
|
||||||
|
func ReadClientBinary(rw io.ReadWriter) ([]byte, error) {
|
||||||
|
p, _, err := readData(rw, ws.StateServerSide, ws.OpBinary)
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadServerData reads next data message from rw, considering that caller
|
||||||
|
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
|
||||||
|
//
|
||||||
|
// Note this may handle and write control frames into the writer part of a
|
||||||
|
// given io.ReadWriter.
|
||||||
|
func ReadServerData(rw io.ReadWriter) ([]byte, ws.OpCode, error) {
|
||||||
|
return ReadData(rw, ws.StateClientSide)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadServerText reads next text message from rw, considering that caller
|
||||||
|
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
|
||||||
|
// It discards received binary messages.
|
||||||
|
//
|
||||||
|
// Note this may handle and write control frames into the writer part of a given
|
||||||
|
// io.ReadWriter.
|
||||||
|
func ReadServerText(rw io.ReadWriter) ([]byte, error) {
|
||||||
|
p, _, err := readData(rw, ws.StateClientSide, ws.OpText)
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadServerBinary reads next binary message from rw, considering that caller
|
||||||
|
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
|
||||||
|
// It discards received text messages.
|
||||||
|
//
|
||||||
|
// Note this may handle and write control frames into the writer part of a
|
||||||
|
// given io.ReadWriter.
|
||||||
|
func ReadServerBinary(rw io.ReadWriter) ([]byte, error) {
|
||||||
|
p, _, err := readData(rw, ws.StateClientSide, ws.OpBinary)
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteMessage is a helper function that writes message to the w. It
|
||||||
|
// constructs single frame with given operation code and payload.
|
||||||
|
// It uses given state to prepare side-dependent things, like cipher
|
||||||
|
// payload bytes from client to server. It will not mutate p bytes if
|
||||||
|
// cipher must be made.
|
||||||
|
//
|
||||||
|
// If you want to write message in fragmented frames, use Writer instead.
|
||||||
|
func WriteMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error {
|
||||||
|
return writeFrame(w, s, op, true, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteServerMessage writes message to w, considering that caller
|
||||||
|
// represents server side.
|
||||||
|
func WriteServerMessage(w io.Writer, op ws.OpCode, p []byte) error {
|
||||||
|
return WriteMessage(w, ws.StateServerSide, op, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteServerText is the same as WriteServerMessage with
|
||||||
|
// ws.OpText.
|
||||||
|
func WriteServerText(w io.Writer, p []byte) error {
|
||||||
|
return WriteServerMessage(w, ws.OpText, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteServerBinary is the same as WriteServerMessage with
|
||||||
|
// ws.OpBinary.
|
||||||
|
func WriteServerBinary(w io.Writer, p []byte) error {
|
||||||
|
return WriteServerMessage(w, ws.OpBinary, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteClientMessage writes message to w, considering that caller
|
||||||
|
// represents client side.
|
||||||
|
func WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error {
|
||||||
|
return WriteMessage(w, ws.StateClientSide, op, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteClientText is the same as WriteClientMessage with
|
||||||
|
// ws.OpText.
|
||||||
|
func WriteClientText(w io.Writer, p []byte) error {
|
||||||
|
return WriteClientMessage(w, ws.OpText, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteClientBinary is the same as WriteClientMessage with
|
||||||
|
// ws.OpBinary.
|
||||||
|
func WriteClientBinary(w io.Writer, p []byte) error {
|
||||||
|
return WriteClientMessage(w, ws.OpBinary, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleClientControlMessage handles control frame from conn and writes
|
||||||
|
// response when needed.
|
||||||
|
//
|
||||||
|
// It considers that caller represents server side.
|
||||||
|
func HandleClientControlMessage(conn io.Writer, msg Message) error {
|
||||||
|
return HandleControlMessage(conn, ws.StateServerSide, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleServerControlMessage handles control frame from conn and writes
|
||||||
|
// response when needed.
|
||||||
|
//
|
||||||
|
// It considers that caller represents client side.
|
||||||
|
func HandleServerControlMessage(conn io.Writer, msg Message) error {
|
||||||
|
return HandleControlMessage(conn, ws.StateClientSide, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleControlMessage handles message which was read by ReadMessage()
|
||||||
|
// functions.
|
||||||
|
//
|
||||||
|
// That is, it is expected, that payload is already unmasked and frame header
|
||||||
|
// were checked by ws.CheckHeader() call.
|
||||||
|
func HandleControlMessage(conn io.Writer, state ws.State, msg Message) error {
|
||||||
|
return (ControlHandler{
|
||||||
|
DisableSrcCiphering: true,
|
||||||
|
Src: bytes.NewReader(msg.Payload),
|
||||||
|
Dst: conn,
|
||||||
|
State: state,
|
||||||
|
}).Handle(ws.Header{
|
||||||
|
Length: int64(len(msg.Payload)),
|
||||||
|
OpCode: msg.OpCode,
|
||||||
|
Fin: true,
|
||||||
|
Masked: state.ServerSide(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ControlFrameHandler returns FrameHandlerFunc for handling control frames.
|
||||||
|
// For more info see ControlHandler docs.
|
||||||
|
func ControlFrameHandler(w io.Writer, state ws.State) FrameHandlerFunc {
|
||||||
|
return func(h ws.Header, r io.Reader) error {
|
||||||
|
return (ControlHandler{
|
||||||
|
DisableSrcCiphering: true,
|
||||||
|
Src: r,
|
||||||
|
Dst: w,
|
||||||
|
State: state,
|
||||||
|
}).Handle(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readData(rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, ws.OpCode, error) {
|
||||||
|
controlHandler := ControlFrameHandler(rw, s)
|
||||||
|
rd := Reader{
|
||||||
|
Source: rw,
|
||||||
|
State: s,
|
||||||
|
CheckUTF8: true,
|
||||||
|
SkipHeaderCheck: false,
|
||||||
|
OnIntermediate: controlHandler,
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
hdr, err := rd.NextFrame()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if hdr.OpCode.IsControl() {
|
||||||
|
if err := controlHandler(hdr, &rd); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if hdr.OpCode&want == 0 {
|
||||||
|
if err := rd.Discard(); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bts, err := ioutil.ReadAll(&rd)
|
||||||
|
|
||||||
|
return bts, hdr.OpCode, err
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,257 @@
|
||||||
|
package wsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrNoFrameAdvance means that Reader's Read() method was called without
|
||||||
|
// preceding NextFrame() call.
|
||||||
|
var ErrNoFrameAdvance = errors.New("no frame advance")
|
||||||
|
|
||||||
|
// FrameHandlerFunc handles parsed frame header and its body represented by
|
||||||
|
// io.Reader.
|
||||||
|
//
|
||||||
|
// Note that reader represents already unmasked body.
|
||||||
|
type FrameHandlerFunc func(ws.Header, io.Reader) error
|
||||||
|
|
||||||
|
// Reader is a wrapper around source io.Reader which represents WebSocket
|
||||||
|
// connection. It contains options for reading messages from source.
|
||||||
|
//
|
||||||
|
// Reader implements io.Reader, which Read() method reads payload of incoming
|
||||||
|
// WebSocket frames. It also takes care on fragmented frames and possibly
|
||||||
|
// intermediate control frames between them.
|
||||||
|
//
|
||||||
|
// Note that Reader's methods are not goroutine safe.
|
||||||
|
type Reader struct {
|
||||||
|
Source io.Reader
|
||||||
|
State ws.State
|
||||||
|
|
||||||
|
// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
|
||||||
|
SkipHeaderCheck bool
|
||||||
|
|
||||||
|
// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
|
||||||
|
// bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned.
|
||||||
|
CheckUTF8 bool
|
||||||
|
|
||||||
|
// TODO(gobwas): add max frame size limit here.
|
||||||
|
|
||||||
|
OnContinuation FrameHandlerFunc
|
||||||
|
OnIntermediate FrameHandlerFunc
|
||||||
|
|
||||||
|
opCode ws.OpCode // Used to store message op code on fragmentation.
|
||||||
|
frame io.Reader // Used to as frame reader.
|
||||||
|
raw io.LimitedReader // Used to discard frames without cipher.
|
||||||
|
utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true.
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReader creates new frame reader that reads from r keeping given state to
|
||||||
|
// make some protocol validity checks when it needed.
|
||||||
|
func NewReader(r io.Reader, s ws.State) *Reader {
|
||||||
|
return &Reader{
|
||||||
|
Source: r,
|
||||||
|
State: s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientSideReader is a helper function that calls NewReader with r and
|
||||||
|
// ws.StateClientSide.
|
||||||
|
func NewClientSideReader(r io.Reader) *Reader {
|
||||||
|
return NewReader(r, ws.StateClientSide)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerSideReader is a helper function that calls NewReader with r and
|
||||||
|
// ws.StateServerSide.
|
||||||
|
func NewServerSideReader(r io.Reader) *Reader {
|
||||||
|
return NewReader(r, ws.StateServerSide)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements io.Reader. It reads the next message payload into p.
|
||||||
|
// It takes care on fragmented messages.
|
||||||
|
//
|
||||||
|
// The error is io.EOF only if all of message bytes were read.
|
||||||
|
// If an io.EOF happens during reading some but not all the message bytes
|
||||||
|
// Read() returns io.ErrUnexpectedEOF.
|
||||||
|
//
|
||||||
|
// The error is ErrNoFrameAdvance if no NextFrame() call was made before
|
||||||
|
// reading next message bytes.
|
||||||
|
func (r *Reader) Read(p []byte) (n int, err error) {
|
||||||
|
if r.frame == nil {
|
||||||
|
if !r.fragmented() {
|
||||||
|
// Every new Read() must be preceded by NextFrame() call.
|
||||||
|
return 0, ErrNoFrameAdvance
|
||||||
|
}
|
||||||
|
// Read next continuation or intermediate control frame.
|
||||||
|
_, err := r.NextFrame()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if r.frame == nil {
|
||||||
|
// We handled intermediate control and now got nothing to read.
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = r.frame.Read(p)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil && r.raw.N != 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case r.raw.N != 0:
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
|
||||||
|
case r.fragmented():
|
||||||
|
err = nil
|
||||||
|
r.resetFragment()
|
||||||
|
|
||||||
|
case r.CheckUTF8 && !r.utf8.Valid():
|
||||||
|
n = r.utf8.Accepted()
|
||||||
|
err = ErrInvalidUTF8
|
||||||
|
|
||||||
|
default:
|
||||||
|
r.reset()
|
||||||
|
err = io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discard discards current message unread bytes.
|
||||||
|
// It discards all frames of fragmented message.
|
||||||
|
func (r *Reader) Discard() (err error) {
|
||||||
|
for {
|
||||||
|
_, err = io.Copy(ioutil.Discard, &r.raw)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !r.fragmented() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if _, err = r.NextFrame(); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.reset()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextFrame prepares r to read next message. It returns received frame header
|
||||||
|
// and non-nil error on failure.
|
||||||
|
//
|
||||||
|
// Note that next NextFrame() call must be done after receiving or discarding
|
||||||
|
// all current message bytes.
|
||||||
|
func (r *Reader) NextFrame() (hdr ws.Header, err error) {
|
||||||
|
hdr, err = ws.ReadHeader(r.Source)
|
||||||
|
if err == io.EOF && r.fragmented() {
|
||||||
|
// If we are in fragmented state EOF means that is was totally
|
||||||
|
// unexpected.
|
||||||
|
//
|
||||||
|
// NOTE: This is necessary to prevent callers such that
|
||||||
|
// ioutil.ReadAll to receive some amount of bytes without an error.
|
||||||
|
// ReadAll() ignores an io.EOF error, thus caller may think that
|
||||||
|
// whole message fetched, but actually only part of it.
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
if err == nil && !r.SkipHeaderCheck {
|
||||||
|
err = ws.CheckHeader(hdr, r.State)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return hdr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save raw reader to use it on discarding frame without ciphering and
|
||||||
|
// other streaming checks.
|
||||||
|
r.raw = io.LimitedReader{r.Source, hdr.Length}
|
||||||
|
|
||||||
|
frame := io.Reader(&r.raw)
|
||||||
|
if hdr.Masked {
|
||||||
|
frame = NewCipherReader(frame, hdr.Mask)
|
||||||
|
}
|
||||||
|
if r.fragmented() {
|
||||||
|
if hdr.OpCode.IsControl() {
|
||||||
|
if cb := r.OnIntermediate; cb != nil {
|
||||||
|
err = cb(hdr, frame)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
// Ensure that src is empty.
|
||||||
|
_, err = io.Copy(ioutil.Discard, &r.raw)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
r.opCode = hdr.OpCode
|
||||||
|
}
|
||||||
|
if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
|
||||||
|
r.utf8.Source = frame
|
||||||
|
frame = &r.utf8
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save reader with ciphering and other streaming checks.
|
||||||
|
r.frame = frame
|
||||||
|
|
||||||
|
if hdr.OpCode == ws.OpContinuation {
|
||||||
|
if cb := r.OnContinuation; cb != nil {
|
||||||
|
err = cb(hdr, frame)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Fin {
|
||||||
|
r.State = r.State.Clear(ws.StateFragmented)
|
||||||
|
} else {
|
||||||
|
r.State = r.State.Set(ws.StateFragmented)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) fragmented() bool {
|
||||||
|
return r.State.Fragmented()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) resetFragment() {
|
||||||
|
r.raw = io.LimitedReader{}
|
||||||
|
r.frame = nil
|
||||||
|
// Reset source of the UTF8Reader, but not the state.
|
||||||
|
r.utf8.Source = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reader) reset() {
|
||||||
|
r.raw = io.LimitedReader{}
|
||||||
|
r.frame = nil
|
||||||
|
r.utf8 = UTF8Reader{}
|
||||||
|
r.opCode = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// NextReader prepares next message read from r. It returns header that
|
||||||
|
// describes the message and io.Reader to read message's payload. It returns
|
||||||
|
// non-nil error when it is not possible to read message's initial frame.
|
||||||
|
//
|
||||||
|
// Note that next NextReader() on the same r should be done after reading all
|
||||||
|
// bytes from previously returned io.Reader. For more performant way to discard
|
||||||
|
// message use Reader and its Discard() method.
|
||||||
|
//
|
||||||
|
// Note that it will not handle any "intermediate" frames, that possibly could
|
||||||
|
// be received between text/binary continuation frames. That is, if peer sent
|
||||||
|
// text/binary frame with fin flag "false", then it could send ping frame, and
|
||||||
|
// eventually remaining part of text/binary frame with fin "true" – with
|
||||||
|
// NextReader() the ping frame will be dropped without any notice. To handle
|
||||||
|
// this rare, but possible situation (and if you do not know exactly which
|
||||||
|
// frames peer could send), you could use Reader with OnIntermediate field set.
|
||||||
|
func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
|
||||||
|
rd := &Reader{
|
||||||
|
Source: r,
|
||||||
|
State: s,
|
||||||
|
}
|
||||||
|
header, err := rd.NextFrame()
|
||||||
|
if err != nil {
|
||||||
|
return header, nil, err
|
||||||
|
}
|
||||||
|
return header, rd, nil
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
package wsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DebugUpgrader is a wrapper around ws.Upgrader. It tracks I/O of a
|
||||||
|
// WebSocket handshake.
|
||||||
|
//
|
||||||
|
// Note that it must not be used in production applications that requires
|
||||||
|
// Upgrade() to be efficient.
|
||||||
|
type DebugUpgrader struct {
|
||||||
|
// Upgrader contains upgrade to WebSocket options.
|
||||||
|
Upgrader ws.Upgrader
|
||||||
|
|
||||||
|
// OnRequest and OnResponse are the callbacks that will be called with the
|
||||||
|
// HTTP request and response respectively.
|
||||||
|
OnRequest, OnResponse func([]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade calls Upgrade() on underlying ws.Upgrader and tracks I/O on conn.
|
||||||
|
func (d *DebugUpgrader) Upgrade(conn io.ReadWriter) (hs ws.Handshake, err error) {
|
||||||
|
var (
|
||||||
|
// Take the Reader and Writer parts from conn to be probably replaced
|
||||||
|
// below.
|
||||||
|
r io.Reader = conn
|
||||||
|
w io.Writer = conn
|
||||||
|
)
|
||||||
|
if onRequest := d.OnRequest; onRequest != nil {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
// First, we must read the entire request.
|
||||||
|
req, err := http.ReadRequest(bufio.NewReader(
|
||||||
|
io.TeeReader(conn, &buf),
|
||||||
|
))
|
||||||
|
if err == nil {
|
||||||
|
// Fulfill the buffer with the response body.
|
||||||
|
io.Copy(ioutil.Discard, req.Body)
|
||||||
|
req.Body.Close()
|
||||||
|
}
|
||||||
|
onRequest(buf.Bytes())
|
||||||
|
|
||||||
|
r = io.MultiReader(
|
||||||
|
&buf, conn,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if onResponse := d.OnResponse; onResponse != nil {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
// Intercept the response stream written by the Upgrade().
|
||||||
|
w = io.MultiWriter(
|
||||||
|
conn, &buf,
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
onResponse(buf.Bytes())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.Upgrader.Upgrade(struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{r, w})
|
||||||
|
}
|
|
@ -0,0 +1,140 @@
|
||||||
|
package wsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrInvalidUTF8 is returned by UTF8 reader on invalid utf8 sequence.
|
||||||
|
var ErrInvalidUTF8 = fmt.Errorf("invalid utf8")
|
||||||
|
|
||||||
|
// UTF8Reader implements io.Reader that calculates utf8 validity state after
|
||||||
|
// every read byte from Source.
|
||||||
|
//
|
||||||
|
// Note that in some cases client must call r.Valid() after all bytes are read
|
||||||
|
// to ensure that all of them are valid utf8 sequences. That is, some io helper
|
||||||
|
// functions such io.ReadAtLeast or io.ReadFull could discard the error
|
||||||
|
// information returned by the reader when they receive all of requested bytes.
|
||||||
|
// For example, the last read sequence is invalid and UTF8Reader returns number
|
||||||
|
// of bytes read and an error. But helper function decides to discard received
|
||||||
|
// error due to all requested bytes are completely read from the source.
|
||||||
|
//
|
||||||
|
// Another possible case is when some valid sequence become split by the read
|
||||||
|
// bound. Then UTF8Reader can not make decision about validity of the last
|
||||||
|
// sequence cause it is not fully read yet. And if the read stops, Valid() will
|
||||||
|
// return false, even if Read() by itself dit not.
|
||||||
|
type UTF8Reader struct {
|
||||||
|
Source io.Reader
|
||||||
|
|
||||||
|
accepted int
|
||||||
|
|
||||||
|
state uint32
|
||||||
|
codep uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUTF8Reader creates utf8 reader that reads from r.
|
||||||
|
func NewUTF8Reader(r io.Reader) *UTF8Reader {
|
||||||
|
return &UTF8Reader{
|
||||||
|
Source: r,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets utf8 reader to read from r.
|
||||||
|
func (u *UTF8Reader) Reset(r io.Reader) {
|
||||||
|
u.Source = r
|
||||||
|
u.state = 0
|
||||||
|
u.codep = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements io.Reader.
|
||||||
|
func (u *UTF8Reader) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = u.Source.Read(p)
|
||||||
|
|
||||||
|
accepted := 0
|
||||||
|
s, c := u.state, u.codep
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
c, s = decode(s, c, p[i])
|
||||||
|
if s == utf8Reject {
|
||||||
|
u.state = s
|
||||||
|
return accepted, ErrInvalidUTF8
|
||||||
|
}
|
||||||
|
if s == utf8Accept {
|
||||||
|
accepted = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
u.state, u.codep = s, c
|
||||||
|
u.accepted = accepted
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid checks current reader state. It returns true if all read bytes are
|
||||||
|
// valid UTF-8 sequences, and false if not.
|
||||||
|
func (u *UTF8Reader) Valid() bool {
|
||||||
|
return u.state == utf8Accept
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accepted returns number of valid bytes in last Read().
|
||||||
|
func (u *UTF8Reader) Accepted() int {
|
||||||
|
return u.accepted
|
||||||
|
}
|
||||||
|
|
||||||
|
// Below is port of UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
|
||||||
|
//
|
||||||
|
// Copyright (c) 2008-2009 Bjoern Hoehrmann <bjoern@hoehrmann.de>
|
||||||
|
//
|
||||||
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
// of this software and associated documentation files (the "Software"), to
|
||||||
|
// deal in the Software without restriction, including without limitation the
|
||||||
|
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||||
|
// sell copies of the Software, and to permit persons to whom the Software is
|
||||||
|
// furnished to do so, subject to the following conditions:
|
||||||
|
//
|
||||||
|
// The above copyright notice and this permission notice shall be included in
|
||||||
|
// all copies or substantial portions of the Software.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||||
|
// IN THE SOFTWARE.
|
||||||
|
|
||||||
|
const (
|
||||||
|
utf8Accept = 0
|
||||||
|
utf8Reject = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
var utf8d = [...]byte{
|
||||||
|
// The first part of the table maps bytes to character classes that
|
||||||
|
// to reduce the size of the transition table and create bitmasks.
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||||
|
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
||||||
|
8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
|
||||||
|
|
||||||
|
// The second part is a transition table that maps a combination
|
||||||
|
// of a state of the automaton and a character class to a state.
|
||||||
|
0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
|
||||||
|
12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12,
|
||||||
|
12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12,
|
||||||
|
12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12,
|
||||||
|
12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
func decode(state, codep uint32, b byte) (uint32, uint32) {
|
||||||
|
t := uint32(utf8d[b])
|
||||||
|
|
||||||
|
if state != utf8Accept {
|
||||||
|
codep = (uint32(b) & 0x3f) | (codep << 6)
|
||||||
|
} else {
|
||||||
|
codep = (0xff >> t) & uint32(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
return codep, uint32(utf8d[256+state+t])
|
||||||
|
}
|
|
@ -0,0 +1,450 @@
|
||||||
|
package wsutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/gobwas/pool"
|
||||||
|
"github.com/gobwas/pool/pbytes"
|
||||||
|
"github.com/gobwas/ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultWriteBuffer contains size of Writer's default buffer. It used by
|
||||||
|
// Writer constructor functions.
|
||||||
|
var DefaultWriteBuffer = 4096
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNotEmpty is returned by Writer.WriteThrough() to indicate that buffer is
|
||||||
|
// not empty and write through could not be done. That is, caller should call
|
||||||
|
// Writer.FlushFragment() to make buffer empty.
|
||||||
|
ErrNotEmpty = fmt.Errorf("writer not empty")
|
||||||
|
|
||||||
|
// ErrControlOverflow is returned by ControlWriter.Write() to indicate that
|
||||||
|
// no more data could be written to the underlying io.Writer because
|
||||||
|
// MaxControlFramePayloadSize limit is reached.
|
||||||
|
ErrControlOverflow = fmt.Errorf("control frame payload overflow")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Constants which are represent frame length ranges.
|
||||||
|
const (
|
||||||
|
len7 = int64(125) // 126 and 127 are reserved values
|
||||||
|
len16 = int64(^uint16(0))
|
||||||
|
len64 = int64((^uint64(0)) >> 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ControlWriter is a wrapper around Writer that contains some guards for
|
||||||
|
// buffered writes of control frames.
|
||||||
|
type ControlWriter struct {
|
||||||
|
w *Writer
|
||||||
|
limit int
|
||||||
|
n int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewControlWriter contains ControlWriter with Writer inside whose buffer size
|
||||||
|
// is at most ws.MaxControlFramePayloadSize + ws.MaxHeaderSize.
|
||||||
|
func NewControlWriter(dest io.Writer, state ws.State, op ws.OpCode) *ControlWriter {
|
||||||
|
return &ControlWriter{
|
||||||
|
w: NewWriterSize(dest, state, op, ws.MaxControlFramePayloadSize),
|
||||||
|
limit: ws.MaxControlFramePayloadSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewControlWriterBuffer returns a new ControlWriter with buf as a buffer.
|
||||||
|
//
|
||||||
|
// Note that it reserves x bytes of buf for header data, where x could be
|
||||||
|
// ws.MinHeaderSize or ws.MinHeaderSize+4 (depending on state). At most
|
||||||
|
// (ws.MaxControlFramePayloadSize + x) bytes of buf will be used.
|
||||||
|
//
|
||||||
|
// It panics if len(buf) <= ws.MinHeaderSize + x.
|
||||||
|
func NewControlWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *ControlWriter {
|
||||||
|
max := ws.MaxControlFramePayloadSize + headerSize(state, ws.MaxControlFramePayloadSize)
|
||||||
|
if len(buf) > max {
|
||||||
|
buf = buf[:max]
|
||||||
|
}
|
||||||
|
|
||||||
|
w := NewWriterBuffer(dest, state, op, buf)
|
||||||
|
|
||||||
|
return &ControlWriter{
|
||||||
|
w: w,
|
||||||
|
limit: len(w.buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer. It writes to the underlying Writer until it
|
||||||
|
// returns error or until ControlWriter write limit will be exceeded.
|
||||||
|
func (c *ControlWriter) Write(p []byte) (n int, err error) {
|
||||||
|
if c.n+len(p) > c.limit {
|
||||||
|
return 0, ErrControlOverflow
|
||||||
|
}
|
||||||
|
return c.w.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush flushes all buffered data to the underlying io.Writer.
|
||||||
|
func (c *ControlWriter) Flush() error {
|
||||||
|
return c.w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writer contains logic of buffering output data into a WebSocket fragments.
|
||||||
|
// It is much the same as bufio.Writer, except the thing that it works with
|
||||||
|
// WebSocket frames, not the raw data.
|
||||||
|
//
|
||||||
|
// Writer writes frames with specified OpCode.
|
||||||
|
// It uses ws.State to decide whether the output frames must be masked.
|
||||||
|
//
|
||||||
|
// Note that it does not check control frame size or other RFC rules.
|
||||||
|
// That is, it must be used with special care to write control frames without
|
||||||
|
// violation of RFC. You could use ControlWriter that wraps Writer and contains
|
||||||
|
// some guards for writing control frames.
|
||||||
|
//
|
||||||
|
// If an error occurs writing to a Writer, no more data will be accepted and
|
||||||
|
// all subsequent writes will return the error.
|
||||||
|
// After all data has been written, the client should call the Flush() method
|
||||||
|
// to guarantee all data has been forwarded to the underlying io.Writer.
|
||||||
|
type Writer struct {
|
||||||
|
dest io.Writer
|
||||||
|
|
||||||
|
n int // Buffered bytes counter.
|
||||||
|
raw []byte // Raw representation of buffer, including reserved header bytes.
|
||||||
|
buf []byte // Writeable part of buffer, without reserved header bytes.
|
||||||
|
|
||||||
|
op ws.OpCode
|
||||||
|
state ws.State
|
||||||
|
|
||||||
|
dirty bool
|
||||||
|
fragmented bool
|
||||||
|
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
var writers = pool.New(128, 65536)
|
||||||
|
|
||||||
|
// GetWriter tries to reuse Writer getting it from the pool.
|
||||||
|
//
|
||||||
|
// This function is intended for memory consumption optimizations, because
|
||||||
|
// NewWriter*() functions make allocations for inner buffer.
|
||||||
|
//
|
||||||
|
// Note the it ceils n to the power of two.
|
||||||
|
//
|
||||||
|
// If you have your own bytes buffer pool you could use NewWriterBuffer to use
|
||||||
|
// pooled bytes in writer.
|
||||||
|
func GetWriter(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
|
||||||
|
x, m := writers.Get(n)
|
||||||
|
if x != nil {
|
||||||
|
w := x.(*Writer)
|
||||||
|
w.Reset(dest, state, op)
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
// NOTE: we use m instead of n, because m is an attempt to reuse w of such
|
||||||
|
// size in the future.
|
||||||
|
return NewWriterBufferSize(dest, state, op, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutWriter puts w for future reuse by GetWriter().
|
||||||
|
func PutWriter(w *Writer) {
|
||||||
|
w.Reset(nil, 0, 0)
|
||||||
|
writers.Put(w, w.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriter returns a new Writer whose buffer has the DefaultWriteBuffer size.
|
||||||
|
func NewWriter(dest io.Writer, state ws.State, op ws.OpCode) *Writer {
|
||||||
|
return NewWriterBufferSize(dest, state, op, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriterSize returns a new Writer whose buffer size is at most n + ws.MaxHeaderSize.
|
||||||
|
// That is, output frames payload length could be up to n, except the case when
|
||||||
|
// Write() is called on empty Writer with len(p) > n.
|
||||||
|
//
|
||||||
|
// If n <= 0 then the default buffer size is used as Writer's buffer size.
|
||||||
|
func NewWriterSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
|
||||||
|
if n > 0 {
|
||||||
|
n += headerSize(state, n)
|
||||||
|
}
|
||||||
|
return NewWriterBufferSize(dest, state, op, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriterBufferSize returns a new Writer whose buffer size is equal to n.
|
||||||
|
// If n <= ws.MinHeaderSize then the default buffer size is used.
|
||||||
|
//
|
||||||
|
// Note that Writer will reserve x bytes for header data, where x is in range
|
||||||
|
// [ws.MinHeaderSize,ws.MaxHeaderSize]. That is, frames flushed by Writer
|
||||||
|
// will not have payload length equal to n, except the case when Write() is
|
||||||
|
// called on empty Writer with len(p) > n.
|
||||||
|
func NewWriterBufferSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
|
||||||
|
if n <= ws.MinHeaderSize {
|
||||||
|
n = DefaultWriteBuffer
|
||||||
|
}
|
||||||
|
return NewWriterBuffer(dest, state, op, make([]byte, n))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWriterBuffer returns a new Writer with buf as a buffer.
|
||||||
|
//
|
||||||
|
// Note that it reserves x bytes of buf for header data, where x is in range
|
||||||
|
// [ws.MinHeaderSize,ws.MaxHeaderSize] (depending on state and buf size).
|
||||||
|
//
|
||||||
|
// You could use ws.HeaderSize() to calculate number of bytes needed to store
|
||||||
|
// header data.
|
||||||
|
//
|
||||||
|
// It panics if len(buf) is too small to fit header and payload data.
|
||||||
|
func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *Writer {
|
||||||
|
offset := reserve(state, len(buf))
|
||||||
|
if len(buf) <= offset {
|
||||||
|
panic("buffer too small")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Writer{
|
||||||
|
dest: dest,
|
||||||
|
raw: buf,
|
||||||
|
buf: buf[offset:],
|
||||||
|
state: state,
|
||||||
|
op: op,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func reserve(state ws.State, n int) (offset int) {
|
||||||
|
var mask int
|
||||||
|
if state.ClientSide() {
|
||||||
|
mask = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case n <= int(len7)+mask+2:
|
||||||
|
return mask + 2
|
||||||
|
case n <= int(len16)+mask+4:
|
||||||
|
return mask + 4
|
||||||
|
default:
|
||||||
|
return mask + 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// headerSize returns number of bytes needed to encode header of a frame with
|
||||||
|
// given state and length.
|
||||||
|
func headerSize(s ws.State, n int) int {
|
||||||
|
return ws.HeaderSize(ws.Header{
|
||||||
|
Length: int64(n),
|
||||||
|
Masked: s.ClientSide(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset discards any buffered data, clears error, and resets w to have given
|
||||||
|
// state and write frames with given OpCode to dest.
|
||||||
|
func (w *Writer) Reset(dest io.Writer, state ws.State, op ws.OpCode) {
|
||||||
|
w.n = 0
|
||||||
|
w.dirty = false
|
||||||
|
w.fragmented = false
|
||||||
|
w.dest = dest
|
||||||
|
w.state = state
|
||||||
|
w.op = op
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the size of the underlying buffer in bytes.
|
||||||
|
func (w *Writer) Size() int {
|
||||||
|
return len(w.buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Available returns how many bytes are unused in the buffer.
|
||||||
|
func (w *Writer) Available() int {
|
||||||
|
return len(w.buf) - w.n
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffered returns the number of bytes that have been written into the current
|
||||||
|
// buffer.
|
||||||
|
func (w *Writer) Buffered() int {
|
||||||
|
return w.n
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer.
|
||||||
|
//
|
||||||
|
// Note that even if the Writer was created to have N-sized buffer, Write()
|
||||||
|
// with payload of N bytes will not fit into that buffer. Writer reserves some
|
||||||
|
// space to fit WebSocket header data.
|
||||||
|
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||||
|
// Even empty p may make a sense.
|
||||||
|
w.dirty = true
|
||||||
|
|
||||||
|
var nn int
|
||||||
|
for len(p) > w.Available() && w.err == nil {
|
||||||
|
if w.Buffered() == 0 {
|
||||||
|
// Large write, empty buffer. Write directly from p to avoid copy.
|
||||||
|
// Trade off here is that we make additional Write() to underlying
|
||||||
|
// io.Writer when writing frame header.
|
||||||
|
//
|
||||||
|
// On large buffers additional write is better than copying.
|
||||||
|
nn, _ = w.WriteThrough(p)
|
||||||
|
} else {
|
||||||
|
nn = copy(w.buf[w.n:], p)
|
||||||
|
w.n += nn
|
||||||
|
w.FlushFragment()
|
||||||
|
}
|
||||||
|
n += nn
|
||||||
|
p = p[nn:]
|
||||||
|
}
|
||||||
|
if w.err != nil {
|
||||||
|
return n, w.err
|
||||||
|
}
|
||||||
|
nn = copy(w.buf[w.n:], p)
|
||||||
|
w.n += nn
|
||||||
|
n += nn
|
||||||
|
|
||||||
|
// Even if w.Available() == 0 we will not flush buffer preventively because
|
||||||
|
// this could bring unwanted fragmentation. That is, user could create
|
||||||
|
// buffer with size that fits exactly all further Write() call, and then
|
||||||
|
// call Flush(), excepting that single and not fragmented frame will be
|
||||||
|
// sent. With preemptive flush this case will produce two frames – last one
|
||||||
|
// will be empty and just to set fin = true.
|
||||||
|
|
||||||
|
return n, w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteThrough writes data bypassing the buffer.
|
||||||
|
// Note that Writer's buffer must be empty before calling WriteThrough().
|
||||||
|
func (w *Writer) WriteThrough(p []byte) (n int, err error) {
|
||||||
|
if w.err != nil {
|
||||||
|
return 0, w.err
|
||||||
|
}
|
||||||
|
if w.Buffered() != 0 {
|
||||||
|
return 0, ErrNotEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
w.err = writeFrame(w.dest, w.state, w.opCode(), false, p)
|
||||||
|
if w.err == nil {
|
||||||
|
n = len(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.dirty = true
|
||||||
|
w.fragmented = true
|
||||||
|
|
||||||
|
return n, w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFrom implements io.ReaderFrom.
|
||||||
|
func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) {
|
||||||
|
var nn int
|
||||||
|
for err == nil {
|
||||||
|
if w.Available() == 0 {
|
||||||
|
err = w.FlushFragment()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// We copy the behavior of bufio.Writer here.
|
||||||
|
// Also, from the docs on io.ReaderFrom:
|
||||||
|
// ReadFrom reads data from r until EOF or error.
|
||||||
|
//
|
||||||
|
// See https://codereview.appspot.com/76400048/#ps1
|
||||||
|
const maxEmptyReads = 100
|
||||||
|
var nr int
|
||||||
|
for nr < maxEmptyReads {
|
||||||
|
nn, err = src.Read(w.buf[w.n:])
|
||||||
|
if nn != 0 || err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nr++
|
||||||
|
}
|
||||||
|
if nr == maxEmptyReads {
|
||||||
|
return n, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
|
||||||
|
w.n += nn
|
||||||
|
n += int64(nn)
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
// NOTE: Do not flush preemptively.
|
||||||
|
// See the Write() sources for more info.
|
||||||
|
err = nil
|
||||||
|
w.dirty = true
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush writes any buffered data to the underlying io.Writer.
|
||||||
|
// It sends the frame with "fin" flag set to true.
|
||||||
|
//
|
||||||
|
// If no Write() or ReadFrom() was made, then Flush() does nothing.
|
||||||
|
func (w *Writer) Flush() error {
|
||||||
|
if (!w.dirty && w.Buffered() == 0) || w.err != nil {
|
||||||
|
return w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.err = w.flushFragment(true)
|
||||||
|
w.n = 0
|
||||||
|
w.dirty = false
|
||||||
|
w.fragmented = false
|
||||||
|
|
||||||
|
return w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlushFragment writes any buffered data to the underlying io.Writer.
|
||||||
|
// It sends the frame with "fin" flag set to false.
|
||||||
|
func (w *Writer) FlushFragment() error {
|
||||||
|
if w.Buffered() == 0 || w.err != nil {
|
||||||
|
return w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.err = w.flushFragment(false)
|
||||||
|
w.n = 0
|
||||||
|
w.fragmented = true
|
||||||
|
|
||||||
|
return w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) flushFragment(fin bool) error {
|
||||||
|
frame := ws.NewFrame(w.opCode(), fin, w.buf[:w.n])
|
||||||
|
if w.state.ClientSide() {
|
||||||
|
frame = ws.MaskFrameInPlace(frame)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write header to the header segment of the raw buffer.
|
||||||
|
head := len(w.raw) - len(w.buf)
|
||||||
|
offset := head - ws.HeaderSize(frame.Header)
|
||||||
|
buf := bytesWriter{
|
||||||
|
buf: w.raw[offset:head],
|
||||||
|
}
|
||||||
|
if err := ws.WriteHeader(&buf, frame.Header); err != nil {
|
||||||
|
// Must never be reached.
|
||||||
|
panic("dump header error: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := w.dest.Write(w.raw[offset : head+w.n])
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Writer) opCode() ws.OpCode {
|
||||||
|
if w.fragmented {
|
||||||
|
return ws.OpContinuation
|
||||||
|
}
|
||||||
|
return w.op
|
||||||
|
}
|
||||||
|
|
||||||
|
var errNoSpace = fmt.Errorf("not enough buffer space")
|
||||||
|
|
||||||
|
type bytesWriter struct {
|
||||||
|
buf []byte
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *bytesWriter) Write(p []byte) (int, error) {
|
||||||
|
n := copy(w.buf[w.pos:], p)
|
||||||
|
w.pos += n
|
||||||
|
if n != len(p) {
|
||||||
|
return n, errNoSpace
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error {
|
||||||
|
var frame ws.Frame
|
||||||
|
if s.ClientSide() {
|
||||||
|
// Should copy bytes to prevent corruption of caller data.
|
||||||
|
payload := pbytes.GetLen(len(p))
|
||||||
|
defer pbytes.Put(payload)
|
||||||
|
|
||||||
|
copy(payload, p)
|
||||||
|
|
||||||
|
frame = ws.NewFrame(op, fin, payload)
|
||||||
|
frame = ws.MaskFrameInPlace(frame)
|
||||||
|
} else {
|
||||||
|
frame = ws.NewFrame(op, fin, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ws.WriteFrame(w, frame)
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
/*
|
||||||
|
Package wsutil provides utilities for working with WebSocket protocol.
|
||||||
|
|
||||||
|
Overview:
|
||||||
|
|
||||||
|
// Read masked text message from peer and check utf8 encoding.
|
||||||
|
header, err := ws.ReadHeader(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare to read payload.
|
||||||
|
r := io.LimitReader(conn, header.Length)
|
||||||
|
r = wsutil.NewCipherReader(r, header.Mask)
|
||||||
|
r = wsutil.NewUTF8Reader(r)
|
||||||
|
|
||||||
|
payload, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
// handle err
|
||||||
|
}
|
||||||
|
|
||||||
|
You could get the same behavior using just `wsutil.Reader`:
|
||||||
|
|
||||||
|
r := wsutil.Reader{
|
||||||
|
Source: conn,
|
||||||
|
CheckUTF8: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
// handle err
|
||||||
|
}
|
||||||
|
|
||||||
|
Or even simplest:
|
||||||
|
|
||||||
|
payload, err := wsutil.ReadClientText(conn)
|
||||||
|
if err != nil {
|
||||||
|
// handle err
|
||||||
|
}
|
||||||
|
|
||||||
|
Package is also exports tools for buffered writing:
|
||||||
|
|
||||||
|
// Create buffered writer, that will buffer output bytes and send them as
|
||||||
|
// 128-length fragments (with exception on large writes, see the doc).
|
||||||
|
writer := wsutil.NewWriterSize(conn, ws.StateServerSide, ws.OpText, 128)
|
||||||
|
|
||||||
|
_, err := io.CopyN(writer, rand.Reader, 100)
|
||||||
|
if err == nil {
|
||||||
|
err = writer.Flush()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// handle error
|
||||||
|
}
|
||||||
|
|
||||||
|
For more utils and helpers see the documentation.
|
||||||
|
*/
|
||||||
|
package wsutil
|
|
@ -207,6 +207,19 @@ github.com/gliderlabs/ssh
|
||||||
# github.com/go-sql-driver/mysql v1.5.0
|
# github.com/go-sql-driver/mysql v1.5.0
|
||||||
## explicit
|
## explicit
|
||||||
github.com/go-sql-driver/mysql
|
github.com/go-sql-driver/mysql
|
||||||
|
# github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58
|
||||||
|
## explicit
|
||||||
|
github.com/gobwas/httphead
|
||||||
|
# github.com/gobwas/pool v0.2.1
|
||||||
|
## explicit
|
||||||
|
github.com/gobwas/pool
|
||||||
|
github.com/gobwas/pool/internal/pmath
|
||||||
|
github.com/gobwas/pool/pbufio
|
||||||
|
github.com/gobwas/pool/pbytes
|
||||||
|
# github.com/gobwas/ws v1.0.4
|
||||||
|
## explicit
|
||||||
|
github.com/gobwas/ws
|
||||||
|
github.com/gobwas/ws/wsutil
|
||||||
# github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
|
# github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
|
||||||
## explicit
|
## explicit
|
||||||
github.com/golang-collections/collections/queue
|
github.com/golang-collections/collections/queue
|
||||||
|
|
Loading…
Reference in New Issue