TUN-3403: Unit test for origin/proxy to test serving HTTP and Websocket
This commit is contained in:
parent
a490443630
commit
6b86f81c4a
|
@ -8,16 +8,14 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
responseMetaHeaderField = "cf-cloudflared-response-meta"
|
||||
responseSourceCloudflared = "cloudflared"
|
||||
responseSourceOrigin = "origin"
|
||||
responseMetaHeaderField = "cf-cloudflared-response-meta"
|
||||
)
|
||||
|
||||
var (
|
||||
canonicalResponseUserHeadersField = http.CanonicalHeaderKey(h2mux.ResponseUserHeadersField)
|
||||
canonicalResponseMetaHeaderField = http.CanonicalHeaderKey(responseMetaHeaderField)
|
||||
responseMetaHeaderCfd = mustInitRespMetaHeader(responseSourceCloudflared)
|
||||
responseMetaHeaderOrigin = mustInitRespMetaHeader(responseSourceOrigin)
|
||||
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
|
||||
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
||||
)
|
||||
|
||||
type responseMetaHeader struct {
|
||||
|
|
3
go.mod
3
go.mod
|
@ -27,6 +27,9 @@ require (
|
|||
github.com/getsentry/raven-go v0.0.0-20180517221441-ed7bcb39ff10
|
||||
github.com/gliderlabs/ssh v0.0.0-20191009160644-63518b5243e0
|
||||
github.com/go-sql-driver/mysql v1.5.0
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/gobwas/ws v1.0.4
|
||||
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
|
||||
github.com/google/go-cmp v0.5.2 // indirect
|
||||
github.com/google/uuid v1.1.2
|
||||
|
|
6
go.sum
6
go.sum
|
@ -233,6 +233,12 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG
|
|||
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
|
||||
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ=
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs=
|
||||
github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
|
||||
|
|
|
@ -18,6 +18,11 @@ import (
|
|||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
UptimeRoute = "/uptime"
|
||||
WSRoute = "/ws"
|
||||
)
|
||||
|
||||
type templateData struct {
|
||||
ServerName string
|
||||
Request *http.Request
|
||||
|
@ -104,8 +109,8 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow
|
|||
}
|
||||
|
||||
muxer := http.NewServeMux()
|
||||
muxer.HandleFunc("/uptime", uptimeHandler(time.Now()))
|
||||
muxer.HandleFunc("/ws", websocketHandler(logger, upgrader))
|
||||
muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
|
||||
muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader))
|
||||
muxer.HandleFunc("/", rootHandler(serverName))
|
||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}
|
||||
go func() {
|
||||
|
|
|
@ -2,6 +2,7 @@ package origin
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net/http"
|
||||
|
@ -112,12 +113,17 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt
|
|||
|
||||
func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
|
||||
c.setHostHeader(req)
|
||||
|
||||
conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
serveCtx, cancel := context.WithCancel(req.Context())
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-serveCtx.Done()
|
||||
conn.Close()
|
||||
}()
|
||||
err = w.WriteRespHeaders(resp)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error writing response header")
|
||||
|
@ -125,7 +131,6 @@ func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request)
|
|||
// Copy to/from stream to the undelying connection. Use the underlying
|
||||
// connection because cloudflared doesn't operate on the message themselves
|
||||
websocket.Stream(conn.UnderlyingConn(), w)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
package origin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockHTTPRespWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
}
|
||||
|
||||
func newMockHTTPRespWriter() *mockHTTPRespWriter {
|
||||
return &mockHTTPRespWriter{
|
||||
httptest.NewRecorder(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
for header, val := range resp.Header {
|
||||
w.Header()[header] = val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mockHTTPRespWriter) WriteErrorResponse(err error) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
}
|
||||
|
||||
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
||||
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
||||
}
|
||||
|
||||
type mockWSRespWriter struct {
|
||||
*mockHTTPRespWriter
|
||||
writeNotification chan []byte
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newMockWSRespWriter(httpRespWriter *mockHTTPRespWriter, reader io.Reader) *mockWSRespWriter {
|
||||
return &mockWSRespWriter{
|
||||
httpRespWriter,
|
||||
make(chan []byte),
|
||||
reader,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mockWSRespWriter) Write(data []byte) (int, error) {
|
||||
w.writeNotification <- data
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *mockWSRespWriter) respBody() io.ReadWriter {
|
||||
data := <-w.writeNotification
|
||||
return bytes.NewBuffer(data)
|
||||
}
|
||||
|
||||
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
||||
return w.reader.Read(data)
|
||||
}
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
logger, err := logger.New()
|
||||
require.NoError(t, err)
|
||||
// let runtime pick an available port
|
||||
listener, err := hello.CreateTLSListener("127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
originURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: listener.Addr().String(),
|
||||
}
|
||||
originCA := x509.NewCertPool()
|
||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||
require.NoError(t, err)
|
||||
originCA.AddCert(helloCert)
|
||||
clientTLS := &tls.Config{
|
||||
RootCAs: originCA,
|
||||
}
|
||||
proxyConfig := &ProxyConfig{
|
||||
Client: &http.Transport{
|
||||
TLSClientConfig: clientTLS,
|
||||
},
|
||||
URL: originURL,
|
||||
TLSConfig: clientTLS,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
hello.StartHelloWorldServer(logger, listener, ctx.Done())
|
||||
}()
|
||||
|
||||
client := NewClient(proxyConfig, logger)
|
||||
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
|
||||
cancel()
|
||||
}
|
||||
|
||||
func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
respWriter := newMockHTTPRespWriter()
|
||||
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Proxy(respWriter, req, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, respWriter.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL *url.URL, tlsConfig *tls.Config) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
// WSRoute is a websocket echo handler
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
|
||||
|
||||
readPipe, writePipe := io.Pipe()
|
||||
respWriter := newMockWSRespWriter(newMockHTTPRespWriter(), readPipe)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err = client.Proxy(respWriter, req, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
||||
}()
|
||||
|
||||
msg := []byte("test websocket")
|
||||
err = wsutil.WriteClientText(writePipe, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ReadServerText reads next data message from rw, considering that caller represents client side.
|
||||
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg, returnedMsg)
|
||||
|
||||
err = wsutil.WriteClientBinary(writePipe, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
returnedMsg, err = wsutil.ReadServerBinary(respWriter.respBody())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg, returnedMsg)
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2017 Sergey Kamardin
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,63 @@
|
|||
# httphead.[go](https://golang.org)
|
||||
|
||||
[![GoDoc][godoc-image]][godoc-url]
|
||||
|
||||
> Tiny HTTP header value parsing library in go.
|
||||
|
||||
## Overview
|
||||
|
||||
This library contains low-level functions for scanning HTTP RFC2616 compatible header value grammars.
|
||||
|
||||
## Install
|
||||
|
||||
```shell
|
||||
go get github.com/gobwas/httphead
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
The example below shows how multiple-choise HTTP header value could be parsed with this library:
|
||||
|
||||
```go
|
||||
options, ok := httphead.ParseOptions([]byte(`foo;bar=1,baz`), nil)
|
||||
fmt.Println(options, ok)
|
||||
// Output: [{foo map[bar:1]} {baz map[]}] true
|
||||
```
|
||||
|
||||
The low-level example below shows how to optimize keys skipping and selection
|
||||
of some key:
|
||||
|
||||
```go
|
||||
// The right part of full header line like:
|
||||
// X-My-Header: key;foo=bar;baz,key;baz
|
||||
header := []byte(`foo;a=0,foo;a=1,foo;a=2,foo;a=3`)
|
||||
|
||||
// We want to search key "foo" with an "a" parameter that equal to "2".
|
||||
var (
|
||||
foo = []byte(`foo`)
|
||||
a = []byte(`a`)
|
||||
v = []byte(`2`)
|
||||
)
|
||||
var found bool
|
||||
httphead.ScanOptions(header, func(i int, key, param, value []byte) Control {
|
||||
if !bytes.Equal(key, foo) {
|
||||
return ControlSkip
|
||||
}
|
||||
if !bytes.Equal(param, a) {
|
||||
if bytes.Equal(value, v) {
|
||||
// Found it!
|
||||
found = true
|
||||
return ControlBreak
|
||||
}
|
||||
return ControlSkip
|
||||
}
|
||||
return ControlContinue
|
||||
})
|
||||
```
|
||||
|
||||
For more usage examples please see [docs][godoc-url] or package tests.
|
||||
|
||||
[godoc-image]: https://godoc.org/github.com/gobwas/httphead?status.svg
|
||||
[godoc-url]: https://godoc.org/github.com/gobwas/httphead
|
||||
[travis-image]: https://travis-ci.org/gobwas/httphead.svg?branch=master
|
||||
[travis-url]: https://travis-ci.org/gobwas/httphead
|
|
@ -0,0 +1,200 @@
|
|||
package httphead
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// ScanCookie scans cookie pairs from data using DefaultCookieScanner.Scan()
|
||||
// method.
|
||||
func ScanCookie(data []byte, it func(key, value []byte) bool) bool {
|
||||
return DefaultCookieScanner.Scan(data, it)
|
||||
}
|
||||
|
||||
// DefaultCookieScanner is a CookieScanner which is used by ScanCookie().
|
||||
// Note that it is intended to have the same behavior as http.Request.Cookies()
|
||||
// has.
|
||||
var DefaultCookieScanner = CookieScanner{}
|
||||
|
||||
// CookieScanner contains options for scanning cookie pairs.
|
||||
// See https://tools.ietf.org/html/rfc6265#section-4.1.1
|
||||
type CookieScanner struct {
|
||||
// DisableNameValidation disables name validation of a cookie. If false,
|
||||
// only RFC2616 "tokens" are accepted.
|
||||
DisableNameValidation bool
|
||||
|
||||
// DisableValueValidation disables value validation of a cookie. If false,
|
||||
// only RFC6265 "cookie-octet" characters are accepted.
|
||||
//
|
||||
// Note that Strict option also affects validation of a value.
|
||||
//
|
||||
// If Strict is false, then scanner begins to allow space and comma
|
||||
// characters inside the value for better compatibility with non standard
|
||||
// cookies implementations.
|
||||
DisableValueValidation bool
|
||||
|
||||
// BreakOnPairError sets scanner to immediately return after first pair syntax
|
||||
// validation error.
|
||||
// If false, scanner will try to skip invalid pair bytes and go ahead.
|
||||
BreakOnPairError bool
|
||||
|
||||
// Strict enables strict RFC6265 mode scanning. It affects name and value
|
||||
// validation, as also some other rules.
|
||||
// If false, it is intended to bring the same behavior as
|
||||
// http.Request.Cookies().
|
||||
Strict bool
|
||||
}
|
||||
|
||||
// Scan maps data to name and value pairs. Usually data represents value of the
|
||||
// Cookie header.
|
||||
func (c CookieScanner) Scan(data []byte, it func(name, value []byte) bool) bool {
|
||||
lexer := &Scanner{data: data}
|
||||
|
||||
const (
|
||||
statePair = iota
|
||||
stateBefore
|
||||
)
|
||||
|
||||
state := statePair
|
||||
|
||||
for lexer.Buffered() > 0 {
|
||||
switch state {
|
||||
case stateBefore:
|
||||
// Pairs separated by ";" and space, according to the RFC6265:
|
||||
// cookie-pair *( ";" SP cookie-pair )
|
||||
//
|
||||
// Cookie pairs MUST be separated by (";" SP). So our only option
|
||||
// here is to fail as syntax error.
|
||||
a, b := lexer.Peek2()
|
||||
if a != ';' {
|
||||
return false
|
||||
}
|
||||
|
||||
state = statePair
|
||||
|
||||
advance := 1
|
||||
if b == ' ' {
|
||||
advance++
|
||||
} else if c.Strict {
|
||||
return false
|
||||
}
|
||||
|
||||
lexer.Advance(advance)
|
||||
|
||||
case statePair:
|
||||
if !lexer.FetchUntil(';') {
|
||||
return false
|
||||
}
|
||||
|
||||
var value []byte
|
||||
name := lexer.Bytes()
|
||||
if i := bytes.IndexByte(name, '='); i != -1 {
|
||||
value = name[i+1:]
|
||||
name = name[:i]
|
||||
} else if c.Strict {
|
||||
if !c.BreakOnPairError {
|
||||
goto nextPair
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if !c.Strict {
|
||||
trimLeft(name)
|
||||
}
|
||||
if !c.DisableNameValidation && !ValidCookieName(name) {
|
||||
if !c.BreakOnPairError {
|
||||
goto nextPair
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if !c.Strict {
|
||||
value = trimRight(value)
|
||||
}
|
||||
value = stripQuotes(value)
|
||||
if !c.DisableValueValidation && !ValidCookieValue(value, c.Strict) {
|
||||
if !c.BreakOnPairError {
|
||||
goto nextPair
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if !it(name, value) {
|
||||
return true
|
||||
}
|
||||
|
||||
nextPair:
|
||||
state = stateBefore
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidCookieValue reports whether given value is a valid RFC6265
|
||||
// "cookie-octet" bytes.
|
||||
//
|
||||
// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
|
||||
// ; US-ASCII characters excluding CTLs,
|
||||
// ; whitespace DQUOTE, comma, semicolon,
|
||||
// ; and backslash
|
||||
//
|
||||
// Note that the false strict parameter disables errors on space 0x20 and comma
|
||||
// 0x2c. This could be useful to bring some compatibility with non-compliant
|
||||
// clients/servers in the real world.
|
||||
// It acts the same as standard library cookie parser if strict is false.
|
||||
func ValidCookieValue(value []byte, strict bool) bool {
|
||||
if len(value) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, c := range value {
|
||||
switch c {
|
||||
case '"', ';', '\\':
|
||||
return false
|
||||
case ',', ' ':
|
||||
if strict {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if c <= 0x20 {
|
||||
return false
|
||||
}
|
||||
if c >= 0x7f {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidCookieName reports wheter given bytes is a valid RFC2616 "token" bytes.
|
||||
func ValidCookieName(name []byte) bool {
|
||||
for _, c := range name {
|
||||
if !OctetTypes[c].IsToken() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stripQuotes(bts []byte) []byte {
|
||||
if last := len(bts) - 1; last > 0 && bts[0] == '"' && bts[last] == '"' {
|
||||
return bts[1:last]
|
||||
}
|
||||
return bts
|
||||
}
|
||||
|
||||
func trimLeft(p []byte) []byte {
|
||||
var i int
|
||||
for i < len(p) && OctetTypes[p[i]].IsSpace() {
|
||||
i++
|
||||
}
|
||||
return p[i:]
|
||||
}
|
||||
|
||||
func trimRight(p []byte) []byte {
|
||||
j := len(p)
|
||||
for j > 0 && OctetTypes[p[j-1]].IsSpace() {
|
||||
j--
|
||||
}
|
||||
return p[:j]
|
||||
}
|
|
@ -0,0 +1,275 @@
|
|||
package httphead
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// Version contains protocol major and minor version.
|
||||
type Version struct {
|
||||
Major int
|
||||
Minor int
|
||||
}
|
||||
|
||||
// RequestLine contains parameters parsed from the first request line.
|
||||
type RequestLine struct {
|
||||
Method []byte
|
||||
URI []byte
|
||||
Version Version
|
||||
}
|
||||
|
||||
// ResponseLine contains parameters parsed from the first response line.
|
||||
type ResponseLine struct {
|
||||
Version Version
|
||||
Status int
|
||||
Reason []byte
|
||||
}
|
||||
|
||||
// SplitRequestLine splits given slice of bytes into three chunks without
|
||||
// parsing.
|
||||
func SplitRequestLine(line []byte) (method, uri, version []byte) {
|
||||
return split3(line, ' ')
|
||||
}
|
||||
|
||||
// ParseRequestLine parses http request line like "GET / HTTP/1.0".
|
||||
func ParseRequestLine(line []byte) (r RequestLine, ok bool) {
|
||||
var i int
|
||||
for i = 0; i < len(line); i++ {
|
||||
c := line[i]
|
||||
if !OctetTypes[c].IsToken() {
|
||||
if i > 0 && c == ' ' {
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if i == len(line) {
|
||||
return
|
||||
}
|
||||
|
||||
var proto []byte
|
||||
r.Method = line[:i]
|
||||
r.URI, proto = split2(line[i+1:], ' ')
|
||||
if len(r.URI) == 0 {
|
||||
return
|
||||
}
|
||||
if major, minor, ok := ParseVersion(proto); ok {
|
||||
r.Version.Major = major
|
||||
r.Version.Minor = minor
|
||||
return r, true
|
||||
}
|
||||
|
||||
return r, false
|
||||
}
|
||||
|
||||
// SplitResponseLine splits given slice of bytes into three chunks without
|
||||
// parsing.
|
||||
func SplitResponseLine(line []byte) (version, status, reason []byte) {
|
||||
return split3(line, ' ')
|
||||
}
|
||||
|
||||
// ParseResponseLine parses first response line into ResponseLine struct.
|
||||
func ParseResponseLine(line []byte) (r ResponseLine, ok bool) {
|
||||
var (
|
||||
proto []byte
|
||||
status []byte
|
||||
)
|
||||
proto, status, r.Reason = split3(line, ' ')
|
||||
if major, minor, ok := ParseVersion(proto); ok {
|
||||
r.Version.Major = major
|
||||
r.Version.Minor = minor
|
||||
} else {
|
||||
return r, false
|
||||
}
|
||||
if n, ok := IntFromASCII(status); ok {
|
||||
r.Status = n
|
||||
} else {
|
||||
return r, false
|
||||
}
|
||||
// TODO(gobwas): parse here r.Reason fot TEXT rule:
|
||||
// TEXT = <any OCTET except CTLs,
|
||||
// but including LWS>
|
||||
return r, true
|
||||
}
|
||||
|
||||
var (
|
||||
httpVersion10 = []byte("HTTP/1.0")
|
||||
httpVersion11 = []byte("HTTP/1.1")
|
||||
httpVersionPrefix = []byte("HTTP/")
|
||||
)
|
||||
|
||||
// ParseVersion parses major and minor version of HTTP protocol.
|
||||
// It returns parsed values and true if parse is ok.
|
||||
func ParseVersion(bts []byte) (major, minor int, ok bool) {
|
||||
switch {
|
||||
case bytes.Equal(bts, httpVersion11):
|
||||
return 1, 1, true
|
||||
case bytes.Equal(bts, httpVersion10):
|
||||
return 1, 0, true
|
||||
case len(bts) < 8:
|
||||
return
|
||||
case !bytes.Equal(bts[:5], httpVersionPrefix):
|
||||
return
|
||||
}
|
||||
|
||||
bts = bts[5:]
|
||||
|
||||
dot := bytes.IndexByte(bts, '.')
|
||||
if dot == -1 {
|
||||
return
|
||||
}
|
||||
major, ok = IntFromASCII(bts[:dot])
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
minor, ok = IntFromASCII(bts[dot+1:])
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
return major, minor, true
|
||||
}
|
||||
|
||||
// ReadLine reads line from br. It reads until '\n' and returns bytes without
|
||||
// '\n' or '\r\n' at the end.
|
||||
// It returns err if and only if line does not end in '\n'. Note that read
|
||||
// bytes returned in any case of error.
|
||||
//
|
||||
// It is much like the textproto/Reader.ReadLine() except the thing that it
|
||||
// returns raw bytes, instead of string. That is, it avoids copying bytes read
|
||||
// from br.
|
||||
//
|
||||
// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be
|
||||
// safe with future I/O operations on br.
|
||||
//
|
||||
// We could control I/O operations on br and do not need to make additional
|
||||
// copy for safety.
|
||||
func ReadLine(br *bufio.Reader) ([]byte, error) {
|
||||
var line []byte
|
||||
for {
|
||||
bts, err := br.ReadSlice('\n')
|
||||
if err == bufio.ErrBufferFull {
|
||||
// Copy bytes because next read will discard them.
|
||||
line = append(line, bts...)
|
||||
continue
|
||||
}
|
||||
// Avoid copy of single read.
|
||||
if line == nil {
|
||||
line = bts
|
||||
} else {
|
||||
line = append(line, bts...)
|
||||
}
|
||||
if err != nil {
|
||||
return line, err
|
||||
}
|
||||
// Size of line is at least 1.
|
||||
// In other case bufio.ReadSlice() returns error.
|
||||
n := len(line)
|
||||
// Cut '\n' or '\r\n'.
|
||||
if n > 1 && line[n-2] == '\r' {
|
||||
line = line[:n-2]
|
||||
} else {
|
||||
line = line[:n-1]
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ParseHeaderLine parses HTTP header as key-value pair. It returns parsed
|
||||
// values and true if parse is ok.
|
||||
func ParseHeaderLine(line []byte) (k, v []byte, ok bool) {
|
||||
colon := bytes.IndexByte(line, ':')
|
||||
if colon == -1 {
|
||||
return
|
||||
}
|
||||
k = trim(line[:colon])
|
||||
for _, c := range k {
|
||||
if !OctetTypes[c].IsToken() {
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
v = trim(line[colon+1:])
|
||||
return k, v, true
|
||||
}
|
||||
|
||||
// IntFromASCII converts ascii encoded decimal numeric value from HTTP entities
|
||||
// to an integer.
|
||||
func IntFromASCII(bts []byte) (ret int, ok bool) {
|
||||
// ASCII numbers all start with the high-order bits 0011.
|
||||
// If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those
|
||||
// bits and interpret them directly as an integer.
|
||||
var n int
|
||||
if n = len(bts); n < 1 {
|
||||
return 0, false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
if bts[i]&0xf0 != 0x30 {
|
||||
return 0, false
|
||||
}
|
||||
ret += int(bts[i]&0xf) * pow(10, n-i-1)
|
||||
}
|
||||
return ret, true
|
||||
}
|
||||
|
||||
const (
|
||||
toLower = 'a' - 'A' // for use with OR.
|
||||
toUpper = ^byte(toLower) // for use with AND.
|
||||
)
|
||||
|
||||
// CanonicalizeHeaderKey is like standard textproto/CanonicalMIMEHeaderKey,
|
||||
// except that it operates with slice of bytes and modifies it inplace without
|
||||
// copying.
|
||||
func CanonicalizeHeaderKey(k []byte) {
|
||||
upper := true
|
||||
for i, c := range k {
|
||||
if upper && 'a' <= c && c <= 'z' {
|
||||
k[i] &= toUpper
|
||||
} else if !upper && 'A' <= c && c <= 'Z' {
|
||||
k[i] |= toLower
|
||||
}
|
||||
upper = c == '-'
|
||||
}
|
||||
}
|
||||
|
||||
// pow for integers implementation.
|
||||
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3
|
||||
func pow(a, b int) int {
|
||||
p := 1
|
||||
for b > 0 {
|
||||
if b&1 != 0 {
|
||||
p *= a
|
||||
}
|
||||
b >>= 1
|
||||
a *= a
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func split3(p []byte, sep byte) (p1, p2, p3 []byte) {
|
||||
a := bytes.IndexByte(p, sep)
|
||||
b := bytes.IndexByte(p[a+1:], sep)
|
||||
if a == -1 || b == -1 {
|
||||
return p, nil, nil
|
||||
}
|
||||
b += a + 1
|
||||
return p[:a], p[a+1 : b], p[b+1:]
|
||||
}
|
||||
|
||||
func split2(p []byte, sep byte) (p1, p2 []byte) {
|
||||
i := bytes.IndexByte(p, sep)
|
||||
if i == -1 {
|
||||
return p, nil
|
||||
}
|
||||
return p[:i], p[i+1:]
|
||||
}
|
||||
|
||||
func trim(p []byte) []byte {
|
||||
var i, j int
|
||||
for i = 0; i < len(p) && (p[i] == ' ' || p[i] == '\t'); {
|
||||
i++
|
||||
}
|
||||
for j = len(p); j > i && (p[j-1] == ' ' || p[j-1] == '\t'); {
|
||||
j--
|
||||
}
|
||||
return p[i:j]
|
||||
}
|
|
@ -0,0 +1,331 @@
|
|||
// Package httphead contains utils for parsing HTTP and HTTP-grammar compatible
|
||||
// text protocols headers.
|
||||
//
|
||||
// That is, this package first aim is to bring ability to easily parse
|
||||
// constructions, described here https://tools.ietf.org/html/rfc2616#section-2
|
||||
package httphead
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ScanTokens parses data in this form:
|
||||
//
|
||||
// list = 1#token
|
||||
//
|
||||
// It returns false if data is malformed.
|
||||
func ScanTokens(data []byte, it func([]byte) bool) bool {
|
||||
lexer := &Scanner{data: data}
|
||||
|
||||
var ok bool
|
||||
for lexer.Next() {
|
||||
switch lexer.Type() {
|
||||
case ItemToken:
|
||||
ok = true
|
||||
if !it(lexer.Bytes()) {
|
||||
return true
|
||||
}
|
||||
case ItemSeparator:
|
||||
if !isComma(lexer.Bytes()) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return ok && !lexer.err
|
||||
}
|
||||
|
||||
// ParseOptions parses all header options and appends it to given slice of
|
||||
// Option. It returns flag of successful (wellformed input) parsing.
|
||||
//
|
||||
// Note that appended options are all consist of subslices of data. That is,
|
||||
// mutation of data will mutate appended options.
|
||||
func ParseOptions(data []byte, options []Option) ([]Option, bool) {
|
||||
var i int
|
||||
index := -1
|
||||
return options, ScanOptions(data, func(idx int, name, attr, val []byte) Control {
|
||||
if idx != index {
|
||||
index = idx
|
||||
i = len(options)
|
||||
options = append(options, Option{Name: name})
|
||||
}
|
||||
if attr != nil {
|
||||
options[i].Parameters.Set(attr, val)
|
||||
}
|
||||
return ControlContinue
|
||||
})
|
||||
}
|
||||
|
||||
// SelectFlag encodes way of options selection.
|
||||
type SelectFlag byte
|
||||
|
||||
// String represetns flag as string.
|
||||
func (f SelectFlag) String() string {
|
||||
var flags [2]string
|
||||
var n int
|
||||
if f&SelectCopy != 0 {
|
||||
flags[n] = "copy"
|
||||
n++
|
||||
}
|
||||
if f&SelectUnique != 0 {
|
||||
flags[n] = "unique"
|
||||
n++
|
||||
}
|
||||
return "[" + strings.Join(flags[:n], "|") + "]"
|
||||
}
|
||||
|
||||
const (
|
||||
// SelectCopy causes selector to copy selected option before appending it
|
||||
// to resulting slice.
|
||||
// If SelectCopy flag is not passed to selector, then appended options will
|
||||
// contain sub-slices of the initial data.
|
||||
SelectCopy SelectFlag = 1 << iota
|
||||
|
||||
// SelectUnique causes selector to append only not yet existing option to
|
||||
// resulting slice. Unique is checked by comparing option names.
|
||||
SelectUnique
|
||||
)
|
||||
|
||||
// OptionSelector contains configuration for selecting Options from header value.
|
||||
type OptionSelector struct {
|
||||
// Check is a filter function that applied to every Option that possibly
|
||||
// could be selected.
|
||||
// If Check is nil all options will be selected.
|
||||
Check func(Option) bool
|
||||
|
||||
// Flags contains flags for options selection.
|
||||
Flags SelectFlag
|
||||
|
||||
// Alloc used to allocate slice of bytes when selector is configured with
|
||||
// SelectCopy flag. It will be called with number of bytes needed for copy
|
||||
// of single Option.
|
||||
// If Alloc is nil make is used.
|
||||
Alloc func(n int) []byte
|
||||
}
|
||||
|
||||
// Select parses header data and appends it to given slice of Option.
|
||||
// It also returns flag of successful (wellformed input) parsing.
|
||||
func (s OptionSelector) Select(data []byte, options []Option) ([]Option, bool) {
|
||||
var current Option
|
||||
var has bool
|
||||
index := -1
|
||||
|
||||
alloc := s.Alloc
|
||||
if alloc == nil {
|
||||
alloc = defaultAlloc
|
||||
}
|
||||
check := s.Check
|
||||
if check == nil {
|
||||
check = defaultCheck
|
||||
}
|
||||
|
||||
ok := ScanOptions(data, func(idx int, name, attr, val []byte) Control {
|
||||
if idx != index {
|
||||
if has && check(current) {
|
||||
if s.Flags&SelectCopy != 0 {
|
||||
current = current.Copy(alloc(current.Size()))
|
||||
}
|
||||
options = append(options, current)
|
||||
has = false
|
||||
}
|
||||
if s.Flags&SelectUnique != 0 {
|
||||
for i := len(options) - 1; i >= 0; i-- {
|
||||
if bytes.Equal(options[i].Name, name) {
|
||||
return ControlSkip
|
||||
}
|
||||
}
|
||||
}
|
||||
index = idx
|
||||
current = Option{Name: name}
|
||||
has = true
|
||||
}
|
||||
if attr != nil {
|
||||
current.Parameters.Set(attr, val)
|
||||
}
|
||||
|
||||
return ControlContinue
|
||||
})
|
||||
if has && check(current) {
|
||||
if s.Flags&SelectCopy != 0 {
|
||||
current = current.Copy(alloc(current.Size()))
|
||||
}
|
||||
options = append(options, current)
|
||||
}
|
||||
|
||||
return options, ok
|
||||
}
|
||||
|
||||
func defaultAlloc(n int) []byte { return make([]byte, n) }
|
||||
func defaultCheck(Option) bool { return true }
|
||||
|
||||
// Control represents operation that scanner should perform.
|
||||
type Control byte
|
||||
|
||||
const (
|
||||
// ControlContinue causes scanner to continue scan tokens.
|
||||
ControlContinue Control = iota
|
||||
// ControlBreak causes scanner to stop scan tokens.
|
||||
ControlBreak
|
||||
// ControlSkip causes scanner to skip current entity.
|
||||
ControlSkip
|
||||
)
|
||||
|
||||
// ScanOptions parses data in this form:
|
||||
//
|
||||
// values = 1#value
|
||||
// value = token *( ";" param )
|
||||
// param = token [ "=" (token | quoted-string) ]
|
||||
//
|
||||
// It calls given callback with the index of the option, option itself and its
|
||||
// parameter (attribute and its value, both could be nil). Index is useful when
|
||||
// header contains multiple choises for the same named option.
|
||||
//
|
||||
// Given callback should return one of the defined Control* values.
|
||||
// ControlSkip means that passed key is not in caller's interest. That is, all
|
||||
// parameters of that key will be skipped.
|
||||
// ControlBreak means that no more keys and parameters should be parsed. That
|
||||
// is, it must break parsing immediately.
|
||||
// ControlContinue means that caller want to receive next parameter and its
|
||||
// value or the next key.
|
||||
//
|
||||
// It returns false if data is malformed.
|
||||
func ScanOptions(data []byte, it func(index int, option, attribute, value []byte) Control) bool {
|
||||
lexer := &Scanner{data: data}
|
||||
|
||||
var ok bool
|
||||
var state int
|
||||
const (
|
||||
stateKey = iota
|
||||
stateParamBeforeName
|
||||
stateParamName
|
||||
stateParamBeforeValue
|
||||
stateParamValue
|
||||
)
|
||||
|
||||
var (
|
||||
index int
|
||||
key, param, value []byte
|
||||
mustCall bool
|
||||
)
|
||||
for lexer.Next() {
|
||||
var (
|
||||
call bool
|
||||
growIndex int
|
||||
)
|
||||
|
||||
t := lexer.Type()
|
||||
v := lexer.Bytes()
|
||||
|
||||
switch t {
|
||||
case ItemToken:
|
||||
switch state {
|
||||
case stateKey, stateParamBeforeName:
|
||||
key = v
|
||||
state = stateParamBeforeName
|
||||
mustCall = true
|
||||
case stateParamName:
|
||||
param = v
|
||||
state = stateParamBeforeValue
|
||||
mustCall = true
|
||||
case stateParamValue:
|
||||
value = v
|
||||
state = stateParamBeforeName
|
||||
call = true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
case ItemString:
|
||||
if state != stateParamValue {
|
||||
return false
|
||||
}
|
||||
value = v
|
||||
state = stateParamBeforeName
|
||||
call = true
|
||||
|
||||
case ItemSeparator:
|
||||
switch {
|
||||
case isComma(v) && state == stateKey:
|
||||
// Nothing to do.
|
||||
|
||||
case isComma(v) && state == stateParamBeforeName:
|
||||
state = stateKey
|
||||
// Make call only if we have not called this key yet.
|
||||
call = mustCall
|
||||
if !call {
|
||||
// If we have already called callback with the key
|
||||
// that just ended.
|
||||
index++
|
||||
} else {
|
||||
// Else grow the index after calling callback.
|
||||
growIndex = 1
|
||||
}
|
||||
|
||||
case isComma(v) && state == stateParamBeforeValue:
|
||||
state = stateKey
|
||||
growIndex = 1
|
||||
call = true
|
||||
|
||||
case isSemicolon(v) && state == stateParamBeforeName:
|
||||
state = stateParamName
|
||||
|
||||
case isSemicolon(v) && state == stateParamBeforeValue:
|
||||
state = stateParamName
|
||||
call = true
|
||||
|
||||
case isEquality(v) && state == stateParamBeforeValue:
|
||||
state = stateParamValue
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
if call {
|
||||
switch it(index, key, param, value) {
|
||||
case ControlBreak:
|
||||
// User want to stop to parsing parameters.
|
||||
return true
|
||||
|
||||
case ControlSkip:
|
||||
// User want to skip current param.
|
||||
state = stateKey
|
||||
lexer.SkipEscaped(',')
|
||||
|
||||
case ControlContinue:
|
||||
// User is interested in rest of parameters.
|
||||
// Nothing to do.
|
||||
|
||||
default:
|
||||
panic("unexpected control value")
|
||||
}
|
||||
ok = true
|
||||
param = nil
|
||||
value = nil
|
||||
mustCall = false
|
||||
index += growIndex
|
||||
}
|
||||
}
|
||||
if mustCall {
|
||||
ok = true
|
||||
it(index, key, param, value)
|
||||
}
|
||||
|
||||
return ok && !lexer.err
|
||||
}
|
||||
|
||||
func isComma(b []byte) bool {
|
||||
return len(b) == 1 && b[0] == ','
|
||||
}
|
||||
func isSemicolon(b []byte) bool {
|
||||
return len(b) == 1 && b[0] == ';'
|
||||
}
|
||||
func isEquality(b []byte) bool {
|
||||
return len(b) == 1 && b[0] == '='
|
||||
}
|
|
@ -0,0 +1,360 @@
|
|||
package httphead
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
// ItemType encodes type of the lexing token.
|
||||
type ItemType int
|
||||
|
||||
const (
|
||||
// ItemUndef reports that token is undefined.
|
||||
ItemUndef ItemType = iota
|
||||
// ItemToken reports that token is RFC2616 token.
|
||||
ItemToken
|
||||
// ItemSeparator reports that token is RFC2616 separator.
|
||||
ItemSeparator
|
||||
// ItemString reports that token is RFC2616 quouted string.
|
||||
ItemString
|
||||
// ItemComment reports that token is RFC2616 comment.
|
||||
ItemComment
|
||||
// ItemOctet reports that token is octet slice.
|
||||
ItemOctet
|
||||
)
|
||||
|
||||
// Scanner represents header tokens scanner.
|
||||
// See https://tools.ietf.org/html/rfc2616#section-2
|
||||
type Scanner struct {
|
||||
data []byte
|
||||
pos int
|
||||
|
||||
itemType ItemType
|
||||
itemBytes []byte
|
||||
|
||||
err bool
|
||||
}
|
||||
|
||||
// NewScanner creates new RFC2616 data scanner.
|
||||
func NewScanner(data []byte) *Scanner {
|
||||
return &Scanner{data: data}
|
||||
}
|
||||
|
||||
// Next scans for next token. It returns true on successful scanning, and false
|
||||
// on error or EOF.
|
||||
func (l *Scanner) Next() bool {
|
||||
c, ok := l.nextChar()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch c {
|
||||
case '"': // quoted-string;
|
||||
return l.fetchQuotedString()
|
||||
|
||||
case '(': // comment;
|
||||
return l.fetchComment()
|
||||
|
||||
case '\\', ')': // unexpected chars;
|
||||
l.err = true
|
||||
return false
|
||||
|
||||
default:
|
||||
return l.fetchToken()
|
||||
}
|
||||
}
|
||||
|
||||
// FetchUntil fetches ItemOctet from current scanner position to first
|
||||
// occurence of the c or to the end of the underlying data.
|
||||
func (l *Scanner) FetchUntil(c byte) bool {
|
||||
l.resetItem()
|
||||
if l.pos == len(l.data) {
|
||||
return false
|
||||
}
|
||||
return l.fetchOctet(c)
|
||||
}
|
||||
|
||||
// Peek reads byte at current position without advancing it. On end of data it
|
||||
// returns 0.
|
||||
func (l *Scanner) Peek() byte {
|
||||
if l.pos == len(l.data) {
|
||||
return 0
|
||||
}
|
||||
return l.data[l.pos]
|
||||
}
|
||||
|
||||
// Peek2 reads two first bytes at current position without advancing it.
|
||||
// If there not enough data it returs 0.
|
||||
func (l *Scanner) Peek2() (a, b byte) {
|
||||
if l.pos == len(l.data) {
|
||||
return 0, 0
|
||||
}
|
||||
if l.pos+1 == len(l.data) {
|
||||
return l.data[l.pos], 0
|
||||
}
|
||||
return l.data[l.pos], l.data[l.pos+1]
|
||||
}
|
||||
|
||||
// Buffered reporst how many bytes there are left to scan.
|
||||
func (l *Scanner) Buffered() int {
|
||||
return len(l.data) - l.pos
|
||||
}
|
||||
|
||||
// Advance moves current position index at n bytes. It returns true on
|
||||
// successful move.
|
||||
func (l *Scanner) Advance(n int) bool {
|
||||
l.pos += n
|
||||
if l.pos > len(l.data) {
|
||||
l.pos = len(l.data)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip skips all bytes until first occurence of c.
|
||||
func (l *Scanner) Skip(c byte) {
|
||||
if l.err {
|
||||
return
|
||||
}
|
||||
// Reset scanner state.
|
||||
l.resetItem()
|
||||
|
||||
if i := bytes.IndexByte(l.data[l.pos:], c); i == -1 {
|
||||
// Reached the end of data.
|
||||
l.pos = len(l.data)
|
||||
} else {
|
||||
l.pos += i + 1
|
||||
}
|
||||
}
|
||||
|
||||
// SkipEscaped skips all bytes until first occurence of non-escaped c.
|
||||
func (l *Scanner) SkipEscaped(c byte) {
|
||||
if l.err {
|
||||
return
|
||||
}
|
||||
// Reset scanner state.
|
||||
l.resetItem()
|
||||
|
||||
if i := ScanUntil(l.data[l.pos:], c); i == -1 {
|
||||
// Reached the end of data.
|
||||
l.pos = len(l.data)
|
||||
} else {
|
||||
l.pos += i + 1
|
||||
}
|
||||
}
|
||||
|
||||
// Type reports current token type.
|
||||
func (l *Scanner) Type() ItemType {
|
||||
return l.itemType
|
||||
}
|
||||
|
||||
// Bytes returns current token bytes.
|
||||
func (l *Scanner) Bytes() []byte {
|
||||
return l.itemBytes
|
||||
}
|
||||
|
||||
func (l *Scanner) nextChar() (byte, bool) {
|
||||
// Reset scanner state.
|
||||
l.resetItem()
|
||||
|
||||
if l.err {
|
||||
return 0, false
|
||||
}
|
||||
l.pos += SkipSpace(l.data[l.pos:])
|
||||
if l.pos == len(l.data) {
|
||||
return 0, false
|
||||
}
|
||||
return l.data[l.pos], true
|
||||
}
|
||||
|
||||
func (l *Scanner) resetItem() {
|
||||
l.itemType = ItemUndef
|
||||
l.itemBytes = nil
|
||||
}
|
||||
|
||||
func (l *Scanner) fetchOctet(c byte) bool {
|
||||
i := l.pos
|
||||
if j := bytes.IndexByte(l.data[l.pos:], c); j == -1 {
|
||||
// Reached the end of data.
|
||||
l.pos = len(l.data)
|
||||
} else {
|
||||
l.pos += j
|
||||
}
|
||||
|
||||
l.itemType = ItemOctet
|
||||
l.itemBytes = l.data[i:l.pos]
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *Scanner) fetchToken() bool {
|
||||
n, t := ScanToken(l.data[l.pos:])
|
||||
if n == -1 {
|
||||
l.err = true
|
||||
return false
|
||||
}
|
||||
|
||||
l.itemType = t
|
||||
l.itemBytes = l.data[l.pos : l.pos+n]
|
||||
l.pos += n
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *Scanner) fetchQuotedString() (ok bool) {
|
||||
l.pos++
|
||||
|
||||
n := ScanUntil(l.data[l.pos:], '"')
|
||||
if n == -1 {
|
||||
l.err = true
|
||||
return false
|
||||
}
|
||||
|
||||
l.itemType = ItemString
|
||||
l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\')
|
||||
l.pos += n + 1
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *Scanner) fetchComment() (ok bool) {
|
||||
l.pos++
|
||||
|
||||
n := ScanPairGreedy(l.data[l.pos:], '(', ')')
|
||||
if n == -1 {
|
||||
l.err = true
|
||||
return false
|
||||
}
|
||||
|
||||
l.itemType = ItemComment
|
||||
l.itemBytes = RemoveByte(l.data[l.pos:l.pos+n], '\\')
|
||||
l.pos += n + 1
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ScanUntil scans for first non-escaped character c in given data.
|
||||
// It returns index of matched c and -1 if c is not found.
|
||||
func ScanUntil(data []byte, c byte) (n int) {
|
||||
for {
|
||||
i := bytes.IndexByte(data[n:], c)
|
||||
if i == -1 {
|
||||
return -1
|
||||
}
|
||||
n += i
|
||||
if n == 0 || data[n-1] != '\\' {
|
||||
break
|
||||
}
|
||||
n++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ScanPairGreedy scans for complete pair of opening and closing chars in greedy manner.
|
||||
// Note that first opening byte must not be present in data.
|
||||
func ScanPairGreedy(data []byte, open, close byte) (n int) {
|
||||
var m int
|
||||
opened := 1
|
||||
for {
|
||||
i := bytes.IndexByte(data[n:], close)
|
||||
if i == -1 {
|
||||
return -1
|
||||
}
|
||||
n += i
|
||||
// If found index is not escaped then it is the end.
|
||||
if n == 0 || data[n-1] != '\\' {
|
||||
opened--
|
||||
}
|
||||
|
||||
for m < i {
|
||||
j := bytes.IndexByte(data[m:i], open)
|
||||
if j == -1 {
|
||||
break
|
||||
}
|
||||
m += j + 1
|
||||
opened++
|
||||
}
|
||||
|
||||
if opened == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
n++
|
||||
m = n
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// RemoveByte returns data without c. If c is not present in data it returns
|
||||
// the same slice. If not, it copies data without c.
|
||||
func RemoveByte(data []byte, c byte) []byte {
|
||||
j := bytes.IndexByte(data, c)
|
||||
if j == -1 {
|
||||
return data
|
||||
}
|
||||
|
||||
n := len(data) - 1
|
||||
|
||||
// If character is present, than allocate slice with n-1 capacity. That is,
|
||||
// resulting bytes could be at most n-1 length.
|
||||
result := make([]byte, n)
|
||||
k := copy(result, data[:j])
|
||||
|
||||
for i := j + 1; i < n; {
|
||||
j = bytes.IndexByte(data[i:], c)
|
||||
if j != -1 {
|
||||
k += copy(result[k:], data[i:i+j])
|
||||
i = i + j + 1
|
||||
} else {
|
||||
k += copy(result[k:], data[i:])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return result[:k]
|
||||
}
|
||||
|
||||
// SkipSpace skips spaces and lws-sequences from p.
|
||||
// It returns number ob bytes skipped.
|
||||
func SkipSpace(p []byte) (n int) {
|
||||
for len(p) > 0 {
|
||||
switch {
|
||||
case len(p) >= 3 &&
|
||||
p[0] == '\r' &&
|
||||
p[1] == '\n' &&
|
||||
OctetTypes[p[2]].IsSpace():
|
||||
p = p[3:]
|
||||
n += 3
|
||||
case OctetTypes[p[0]].IsSpace():
|
||||
p = p[1:]
|
||||
n++
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ScanToken scan for next token in p. It returns length of the token and its
|
||||
// type. It do not trim p.
|
||||
func ScanToken(p []byte) (n int, t ItemType) {
|
||||
if len(p) == 0 {
|
||||
return 0, ItemUndef
|
||||
}
|
||||
|
||||
c := p[0]
|
||||
switch {
|
||||
case OctetTypes[c].IsSeparator():
|
||||
return 1, ItemSeparator
|
||||
|
||||
case OctetTypes[c].IsToken():
|
||||
for n = 1; n < len(p); n++ {
|
||||
c := p[n]
|
||||
if !OctetTypes[c].IsToken() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return n, ItemToken
|
||||
|
||||
default:
|
||||
return -1, ItemUndef
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
package httphead
|
||||
|
||||
// OctetType desribes character type.
|
||||
//
|
||||
// From the "Basic Rules" chapter of RFC2616
|
||||
// See https://tools.ietf.org/html/rfc2616#section-2.2
|
||||
//
|
||||
// OCTET = <any 8-bit sequence of data>
|
||||
// CHAR = <any US-ASCII character (octets 0 - 127)>
|
||||
// UPALPHA = <any US-ASCII uppercase letter "A".."Z">
|
||||
// LOALPHA = <any US-ASCII lowercase letter "a".."z">
|
||||
// ALPHA = UPALPHA | LOALPHA
|
||||
// DIGIT = <any US-ASCII digit "0".."9">
|
||||
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
|
||||
// CR = <US-ASCII CR, carriage return (13)>
|
||||
// LF = <US-ASCII LF, linefeed (10)>
|
||||
// SP = <US-ASCII SP, space (32)>
|
||||
// HT = <US-ASCII HT, horizontal-tab (9)>
|
||||
// <"> = <US-ASCII double-quote mark (34)>
|
||||
// CRLF = CR LF
|
||||
// LWS = [CRLF] 1*( SP | HT )
|
||||
//
|
||||
// Many HTTP/1.1 header field values consist of words separated by LWS
|
||||
// or special characters. These special characters MUST be in a quoted
|
||||
// string to be used within a parameter value (as defined in section
|
||||
// 3.6).
|
||||
//
|
||||
// token = 1*<any CHAR except CTLs or separators>
|
||||
// separators = "(" | ")" | "<" | ">" | "@"
|
||||
// | "," | ";" | ":" | "\" | <">
|
||||
// | "/" | "[" | "]" | "?" | "="
|
||||
// | "{" | "}" | SP | HT
|
||||
type OctetType byte
|
||||
|
||||
// IsChar reports whether octet is CHAR.
|
||||
func (t OctetType) IsChar() bool { return t&octetChar != 0 }
|
||||
|
||||
// IsControl reports whether octet is CTL.
|
||||
func (t OctetType) IsControl() bool { return t&octetControl != 0 }
|
||||
|
||||
// IsSeparator reports whether octet is separator.
|
||||
func (t OctetType) IsSeparator() bool { return t&octetSeparator != 0 }
|
||||
|
||||
// IsSpace reports whether octet is space (SP or HT).
|
||||
func (t OctetType) IsSpace() bool { return t&octetSpace != 0 }
|
||||
|
||||
// IsToken reports whether octet is token.
|
||||
func (t OctetType) IsToken() bool { return t&octetToken != 0 }
|
||||
|
||||
const (
|
||||
octetChar OctetType = 1 << iota
|
||||
octetControl
|
||||
octetSpace
|
||||
octetSeparator
|
||||
octetToken
|
||||
)
|
||||
|
||||
// OctetTypes is a table of octets.
|
||||
var OctetTypes [256]OctetType
|
||||
|
||||
func init() {
|
||||
for c := 32; c < 256; c++ {
|
||||
var t OctetType
|
||||
if c <= 127 {
|
||||
t |= octetChar
|
||||
}
|
||||
if 0 <= c && c <= 31 || c == 127 {
|
||||
t |= octetControl
|
||||
}
|
||||
switch c {
|
||||
case '(', ')', '<', '>', '@', ',', ';', ':', '"', '/', '[', ']', '?', '=', '{', '}', '\\':
|
||||
t |= octetSeparator
|
||||
case ' ', '\t':
|
||||
t |= octetSpace | octetSeparator
|
||||
}
|
||||
|
||||
if t.IsChar() && !t.IsControl() && !t.IsSeparator() && !t.IsSpace() {
|
||||
t |= octetToken
|
||||
}
|
||||
|
||||
OctetTypes[c] = t
|
||||
}
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
package httphead
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Option represents a header option.
|
||||
type Option struct {
|
||||
Name []byte
|
||||
Parameters Parameters
|
||||
}
|
||||
|
||||
// Size returns number of bytes need to be allocated for use in opt.Copy.
|
||||
func (opt Option) Size() int {
|
||||
return len(opt.Name) + opt.Parameters.bytes
|
||||
}
|
||||
|
||||
// Copy copies all underlying []byte slices into p and returns new Option.
|
||||
// Note that p must be at least of opt.Size() length.
|
||||
func (opt Option) Copy(p []byte) Option {
|
||||
n := copy(p, opt.Name)
|
||||
opt.Name = p[:n]
|
||||
opt.Parameters, p = opt.Parameters.Copy(p[n:])
|
||||
return opt
|
||||
}
|
||||
|
||||
// Clone is a shorthand for making slice of opt.Size() sequenced with Copy()
|
||||
// call.
|
||||
func (opt Option) Clone() Option {
|
||||
return opt.Copy(make([]byte, opt.Size()))
|
||||
}
|
||||
|
||||
// String represents option as a string.
|
||||
func (opt Option) String() string {
|
||||
return "{" + string(opt.Name) + " " + opt.Parameters.String() + "}"
|
||||
}
|
||||
|
||||
// NewOption creates named option with given parameters.
|
||||
func NewOption(name string, params map[string]string) Option {
|
||||
p := Parameters{}
|
||||
for k, v := range params {
|
||||
p.Set([]byte(k), []byte(v))
|
||||
}
|
||||
return Option{
|
||||
Name: []byte(name),
|
||||
Parameters: p,
|
||||
}
|
||||
}
|
||||
|
||||
// Equal reports whether option is equal to b.
|
||||
func (opt Option) Equal(b Option) bool {
|
||||
if bytes.Equal(opt.Name, b.Name) {
|
||||
return opt.Parameters.Equal(b.Parameters)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Parameters represents option's parameters.
|
||||
type Parameters struct {
|
||||
pos int
|
||||
bytes int
|
||||
arr [8]pair
|
||||
dyn []pair
|
||||
}
|
||||
|
||||
// Equal reports whether a equal to b.
|
||||
func (p Parameters) Equal(b Parameters) bool {
|
||||
switch {
|
||||
case p.dyn == nil && b.dyn == nil:
|
||||
case p.dyn != nil && b.dyn != nil:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
ad, bd := p.data(), b.data()
|
||||
if len(ad) != len(bd) {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Sort(pairs(ad))
|
||||
sort.Sort(pairs(bd))
|
||||
|
||||
for i := 0; i < len(ad); i++ {
|
||||
av, bv := ad[i], bd[i]
|
||||
if !bytes.Equal(av.key, bv.key) || !bytes.Equal(av.value, bv.value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Size returns number of bytes that needed to copy p.
|
||||
func (p *Parameters) Size() int {
|
||||
return p.bytes
|
||||
}
|
||||
|
||||
// Copy copies all underlying []byte slices into dst and returns new
|
||||
// Parameters.
|
||||
// Note that dst must be at least of p.Size() length.
|
||||
func (p *Parameters) Copy(dst []byte) (Parameters, []byte) {
|
||||
ret := Parameters{
|
||||
pos: p.pos,
|
||||
bytes: p.bytes,
|
||||
}
|
||||
if p.dyn != nil {
|
||||
ret.dyn = make([]pair, len(p.dyn))
|
||||
for i, v := range p.dyn {
|
||||
ret.dyn[i], dst = v.copy(dst)
|
||||
}
|
||||
} else {
|
||||
for i, p := range p.arr {
|
||||
ret.arr[i], dst = p.copy(dst)
|
||||
}
|
||||
}
|
||||
return ret, dst
|
||||
}
|
||||
|
||||
// Get returns value by key and flag about existence such value.
|
||||
func (p *Parameters) Get(key string) (value []byte, ok bool) {
|
||||
for _, v := range p.data() {
|
||||
if string(v.key) == key {
|
||||
return v.value, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Set sets value by key.
|
||||
func (p *Parameters) Set(key, value []byte) {
|
||||
p.bytes += len(key) + len(value)
|
||||
|
||||
if p.pos < len(p.arr) {
|
||||
p.arr[p.pos] = pair{key, value}
|
||||
p.pos++
|
||||
return
|
||||
}
|
||||
|
||||
if p.dyn == nil {
|
||||
p.dyn = make([]pair, len(p.arr), len(p.arr)+1)
|
||||
copy(p.dyn, p.arr[:])
|
||||
}
|
||||
p.dyn = append(p.dyn, pair{key, value})
|
||||
}
|
||||
|
||||
// ForEach iterates over parameters key-value pairs and calls cb for each one.
|
||||
func (p *Parameters) ForEach(cb func(k, v []byte) bool) {
|
||||
for _, v := range p.data() {
|
||||
if !cb(v.key, v.value) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// String represents parameters as a string.
|
||||
func (p *Parameters) String() (ret string) {
|
||||
ret = "["
|
||||
for i, v := range p.data() {
|
||||
if i > 0 {
|
||||
ret += " "
|
||||
}
|
||||
ret += string(v.key) + ":" + string(v.value)
|
||||
}
|
||||
return ret + "]"
|
||||
}
|
||||
|
||||
func (p *Parameters) data() []pair {
|
||||
if p.dyn != nil {
|
||||
return p.dyn
|
||||
}
|
||||
return p.arr[:p.pos]
|
||||
}
|
||||
|
||||
type pair struct {
|
||||
key, value []byte
|
||||
}
|
||||
|
||||
func (p pair) copy(dst []byte) (pair, []byte) {
|
||||
n := copy(dst, p.key)
|
||||
p.key = dst[:n]
|
||||
m := n + copy(dst[n:], p.value)
|
||||
p.value = dst[n:m]
|
||||
|
||||
dst = dst[m:]
|
||||
|
||||
return p, dst
|
||||
}
|
||||
|
||||
type pairs []pair
|
||||
|
||||
func (p pairs) Len() int { return len(p) }
|
||||
func (p pairs) Less(a, b int) bool { return bytes.Compare(p[a].key, p[b].key) == -1 }
|
||||
func (p pairs) Swap(a, b int) { p[a], p[b] = p[b], p[a] }
|
|
@ -0,0 +1,101 @@
|
|||
package httphead
|
||||
|
||||
import "io"
|
||||
|
||||
var (
|
||||
comma = []byte{','}
|
||||
equality = []byte{'='}
|
||||
semicolon = []byte{';'}
|
||||
quote = []byte{'"'}
|
||||
escape = []byte{'\\'}
|
||||
)
|
||||
|
||||
// WriteOptions write options list to the dest.
|
||||
// It uses the same form as {Scan,Parse}Options functions:
|
||||
// values = 1#value
|
||||
// value = token *( ";" param )
|
||||
// param = token [ "=" (token | quoted-string) ]
|
||||
//
|
||||
// It wraps valuse into the quoted-string sequence if it contains any
|
||||
// non-token characters.
|
||||
func WriteOptions(dest io.Writer, options []Option) (n int, err error) {
|
||||
w := writer{w: dest}
|
||||
for i, opt := range options {
|
||||
if i > 0 {
|
||||
w.write(comma)
|
||||
}
|
||||
|
||||
writeTokenSanitized(&w, opt.Name)
|
||||
|
||||
for _, p := range opt.Parameters.data() {
|
||||
w.write(semicolon)
|
||||
writeTokenSanitized(&w, p.key)
|
||||
if len(p.value) != 0 {
|
||||
w.write(equality)
|
||||
writeTokenSanitized(&w, p.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
return w.result()
|
||||
}
|
||||
|
||||
// writeTokenSanitized writes token as is or as quouted string if it contains
|
||||
// non-token characters.
|
||||
//
|
||||
// Note that is is not expects LWS sequnces be in s, cause LWS is used only as
|
||||
// header field continuation:
|
||||
// "A CRLF is allowed in the definition of TEXT only as part of a header field
|
||||
// continuation. It is expected that the folding LWS will be replaced with a
|
||||
// single SP before interpretation of the TEXT value."
|
||||
// See https://tools.ietf.org/html/rfc2616#section-2
|
||||
//
|
||||
// That is we sanitizing s for writing, so there could not be any header field
|
||||
// continuation.
|
||||
// That is any CRLF will be escaped as any other control characters not allowd in TEXT.
|
||||
func writeTokenSanitized(bw *writer, bts []byte) {
|
||||
var qt bool
|
||||
var pos int
|
||||
for i := 0; i < len(bts); i++ {
|
||||
c := bts[i]
|
||||
if !OctetTypes[c].IsToken() && !qt {
|
||||
qt = true
|
||||
bw.write(quote)
|
||||
}
|
||||
if OctetTypes[c].IsControl() || c == '"' {
|
||||
if !qt {
|
||||
qt = true
|
||||
bw.write(quote)
|
||||
}
|
||||
bw.write(bts[pos:i])
|
||||
bw.write(escape)
|
||||
bw.write(bts[i : i+1])
|
||||
pos = i + 1
|
||||
}
|
||||
}
|
||||
if !qt {
|
||||
bw.write(bts)
|
||||
} else {
|
||||
bw.write(bts[pos:])
|
||||
bw.write(quote)
|
||||
}
|
||||
}
|
||||
|
||||
type writer struct {
|
||||
w io.Writer
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
func (w *writer) write(p []byte) {
|
||||
if w.err != nil {
|
||||
return
|
||||
}
|
||||
var n int
|
||||
n, w.err = w.w.Write(p)
|
||||
w.n += n
|
||||
return
|
||||
}
|
||||
|
||||
func (w *writer) result() (int, error) {
|
||||
return w.n, w.err
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2017-2019 Sergey Kamardin <gobwas@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,107 @@
|
|||
# pool
|
||||
|
||||
[![GoDoc][godoc-image]][godoc-url]
|
||||
|
||||
> Tiny memory reuse helpers for Go.
|
||||
|
||||
## generic
|
||||
|
||||
Without use of subpackages, `pool` allows to reuse any struct distinguishable
|
||||
by size in generic way:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import "github.com/gobwas/pool"
|
||||
|
||||
func main() {
|
||||
x, n := pool.Get(100) // Returns object with size 128 or nil.
|
||||
if x == nil {
|
||||
// Create x somehow with knowledge that n is 128.
|
||||
}
|
||||
defer pool.Put(x, n)
|
||||
|
||||
// Work with x.
|
||||
}
|
||||
```
|
||||
|
||||
Pool allows you to pass specific options for constructing custom pool:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import "github.com/gobwas/pool"
|
||||
|
||||
func main() {
|
||||
p := pool.Custom(
|
||||
pool.WithLogSizeMapping(), // Will ceil size n passed to Get(n) to nearest power of two.
|
||||
pool.WithLogSizeRange(64, 512), // Will reuse objects in logarithmic range [64, 512].
|
||||
pool.WithSize(65536), // Will reuse object with size 65536.
|
||||
)
|
||||
x, n := p.Get(1000) // Returns nil and 1000 because mapped size 1000 => 1024 is not reusing by the pool.
|
||||
defer pool.Put(x, n) // Will not reuse x.
|
||||
|
||||
// Work with x.
|
||||
}
|
||||
```
|
||||
|
||||
Note that there are few non-generic pooling implementations inside subpackages.
|
||||
|
||||
## pbytes
|
||||
|
||||
Subpackage `pbytes` is intended for `[]byte` reuse.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import "github.com/gobwas/pool/pbytes"
|
||||
|
||||
func main() {
|
||||
bts := pbytes.GetCap(100) // Returns make([]byte, 0, 128).
|
||||
defer pbytes.Put(bts)
|
||||
|
||||
// Work with bts.
|
||||
}
|
||||
```
|
||||
|
||||
You can also create your own range for pooling:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import "github.com/gobwas/pool/pbytes"
|
||||
|
||||
func main() {
|
||||
// Reuse only slices whose capacity is 128, 256, 512 or 1024.
|
||||
pool := pbytes.New(128, 1024)
|
||||
|
||||
bts := pool.GetCap(100) // Returns make([]byte, 0, 128).
|
||||
defer pool.Put(bts)
|
||||
|
||||
// Work with bts.
|
||||
}
|
||||
```
|
||||
|
||||
## pbufio
|
||||
|
||||
Subpackage `pbufio` is intended for `*bufio.{Reader, Writer}` reuse.
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import "github.com/gobwas/pool/pbufio"
|
||||
|
||||
func main() {
|
||||
bw := pbufio.GetWriter(os.Stdout, 100) // Returns bufio.NewWriterSize(128).
|
||||
defer pbufio.PutWriter(bw)
|
||||
|
||||
// Work with bw.
|
||||
}
|
||||
```
|
||||
|
||||
Like with `pbytes`, you can also create pool with custom reuse bounds.
|
||||
|
||||
|
||||
|
||||
[godoc-image]: https://godoc.org/github.com/gobwas/pool?status.svg
|
||||
[godoc-url]: https://godoc.org/github.com/gobwas/pool
|
|
@ -0,0 +1,87 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/gobwas/pool/internal/pmath"
|
||||
)
|
||||
|
||||
var DefaultPool = New(128, 65536)
|
||||
|
||||
// Get pulls object whose generic size is at least of given size. It also
|
||||
// returns a real size of x for further pass to Put(). It returns -1 as real
|
||||
// size for nil x. Size >-1 does not mean that x is non-nil, so checks must be
|
||||
// done.
|
||||
//
|
||||
// Note that size could be ceiled to the next power of two.
|
||||
//
|
||||
// Get is a wrapper around DefaultPool.Get().
|
||||
func Get(size int) (interface{}, int) { return DefaultPool.Get(size) }
|
||||
|
||||
// Put takes x and its size for future reuse.
|
||||
// Put is a wrapper around DefaultPool.Put().
|
||||
func Put(x interface{}, size int) { DefaultPool.Put(x, size) }
|
||||
|
||||
// Pool contains logic of reusing objects distinguishable by size in generic
|
||||
// way.
|
||||
type Pool struct {
|
||||
pool map[int]*sync.Pool
|
||||
size func(int) int
|
||||
}
|
||||
|
||||
// New creates new Pool that reuses objects which size is in logarithmic range
|
||||
// [min, max].
|
||||
//
|
||||
// Note that it is a shortcut for Custom() constructor with Options provided by
|
||||
// WithLogSizeMapping() and WithLogSizeRange(min, max) calls.
|
||||
func New(min, max int) *Pool {
|
||||
return Custom(
|
||||
WithLogSizeMapping(),
|
||||
WithLogSizeRange(min, max),
|
||||
)
|
||||
}
|
||||
|
||||
// Custom creates new Pool with given options.
|
||||
func Custom(opts ...Option) *Pool {
|
||||
p := &Pool{
|
||||
pool: make(map[int]*sync.Pool),
|
||||
size: pmath.Identity,
|
||||
}
|
||||
|
||||
c := (*poolConfig)(p)
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Get pulls object whose generic size is at least of given size.
|
||||
// It also returns a real size of x for further pass to Put() even if x is nil.
|
||||
// Note that size could be ceiled to the next power of two.
|
||||
func (p *Pool) Get(size int) (interface{}, int) {
|
||||
n := p.size(size)
|
||||
if pool := p.pool[n]; pool != nil {
|
||||
return pool.Get(), n
|
||||
}
|
||||
return nil, size
|
||||
}
|
||||
|
||||
// Put takes x and its size for future reuse.
|
||||
func (p *Pool) Put(x interface{}, size int) {
|
||||
if pool := p.pool[size]; pool != nil {
|
||||
pool.Put(x)
|
||||
}
|
||||
}
|
||||
|
||||
type poolConfig Pool
|
||||
|
||||
// AddSize adds size n to the map.
|
||||
func (p *poolConfig) AddSize(n int) {
|
||||
p.pool[n] = new(sync.Pool)
|
||||
}
|
||||
|
||||
// SetSizeMapping sets up incoming size mapping function.
|
||||
func (p *poolConfig) SetSizeMapping(size func(int) int) {
|
||||
p.size = size
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package pmath
|
||||
|
||||
const (
|
||||
bitsize = 32 << (^uint(0) >> 63)
|
||||
maxint = int(1<<(bitsize-1) - 1)
|
||||
maxintHeadBit = 1 << (bitsize - 2)
|
||||
)
|
||||
|
||||
// LogarithmicRange iterates from ceiled to power of two min to max,
|
||||
// calling cb on each iteration.
|
||||
func LogarithmicRange(min, max int, cb func(int)) {
|
||||
if min == 0 {
|
||||
min = 1
|
||||
}
|
||||
for n := CeilToPowerOfTwo(min); n <= max; n <<= 1 {
|
||||
cb(n)
|
||||
}
|
||||
}
|
||||
|
||||
// IsPowerOfTwo reports whether given integer is a power of two.
|
||||
func IsPowerOfTwo(n int) bool {
|
||||
return n&(n-1) == 0
|
||||
}
|
||||
|
||||
// Identity is identity.
|
||||
func Identity(n int) int {
|
||||
return n
|
||||
}
|
||||
|
||||
// CeilToPowerOfTwo returns the least power of two integer value greater than
|
||||
// or equal to n.
|
||||
func CeilToPowerOfTwo(n int) int {
|
||||
if n&maxintHeadBit != 0 && n > maxintHeadBit {
|
||||
panic("argument is too large")
|
||||
}
|
||||
if n <= 2 {
|
||||
return n
|
||||
}
|
||||
n--
|
||||
n = fillBits(n)
|
||||
n++
|
||||
return n
|
||||
}
|
||||
|
||||
// FloorToPowerOfTwo returns the greatest power of two integer value less than
|
||||
// or equal to n.
|
||||
func FloorToPowerOfTwo(n int) int {
|
||||
if n <= 2 {
|
||||
return n
|
||||
}
|
||||
n = fillBits(n)
|
||||
n >>= 1
|
||||
n++
|
||||
return n
|
||||
}
|
||||
|
||||
func fillBits(n int) int {
|
||||
n |= n >> 1
|
||||
n |= n >> 2
|
||||
n |= n >> 4
|
||||
n |= n >> 8
|
||||
n |= n >> 16
|
||||
n |= n >> 32
|
||||
return n
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package pool
|
||||
|
||||
import "github.com/gobwas/pool/internal/pmath"
|
||||
|
||||
// Option configures pool.
|
||||
type Option func(Config)
|
||||
|
||||
// Config describes generic pool configuration.
|
||||
type Config interface {
|
||||
AddSize(n int)
|
||||
SetSizeMapping(func(int) int)
|
||||
}
|
||||
|
||||
// WithSizeLogRange returns an Option that will add logarithmic range of
|
||||
// pooling sizes containing [min, max] values.
|
||||
func WithLogSizeRange(min, max int) Option {
|
||||
return func(c Config) {
|
||||
pmath.LogarithmicRange(min, max, func(n int) {
|
||||
c.AddSize(n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithSize returns an Option that will add given pooling size to the pool.
|
||||
func WithSize(n int) Option {
|
||||
return func(c Config) {
|
||||
c.AddSize(n)
|
||||
}
|
||||
}
|
||||
|
||||
func WithSizeMapping(sz func(int) int) Option {
|
||||
return func(c Config) {
|
||||
c.SetSizeMapping(sz)
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogSizeMapping() Option {
|
||||
return WithSizeMapping(pmath.CeilToPowerOfTwo)
|
||||
}
|
||||
|
||||
func WithIdentitySizeMapping() Option {
|
||||
return WithSizeMapping(pmath.Identity)
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
// Package pbufio contains tools for pooling bufio.Reader and bufio.Writers.
|
||||
package pbufio
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
|
||||
"github.com/gobwas/pool"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultWriterPool = NewWriterPool(256, 65536)
|
||||
DefaultReaderPool = NewReaderPool(256, 65536)
|
||||
)
|
||||
|
||||
// GetWriter returns bufio.Writer whose buffer has at least size bytes.
|
||||
// Note that size could be ceiled to the next power of two.
|
||||
// GetWriter is a wrapper around DefaultWriterPool.Get().
|
||||
func GetWriter(w io.Writer, size int) *bufio.Writer { return DefaultWriterPool.Get(w, size) }
|
||||
|
||||
// PutWriter takes bufio.Writer for future reuse.
|
||||
// It does not reuse bufio.Writer which underlying buffer size is not power of
|
||||
// PutWriter is a wrapper around DefaultWriterPool.Put().
|
||||
func PutWriter(bw *bufio.Writer) { DefaultWriterPool.Put(bw) }
|
||||
|
||||
// GetReader returns bufio.Reader whose buffer has at least size bytes. It returns
|
||||
// its capacity for further pass to Put().
|
||||
// Note that size could be ceiled to the next power of two.
|
||||
// GetReader is a wrapper around DefaultReaderPool.Get().
|
||||
func GetReader(w io.Reader, size int) *bufio.Reader { return DefaultReaderPool.Get(w, size) }
|
||||
|
||||
// PutReader takes bufio.Reader and its size for future reuse.
|
||||
// It does not reuse bufio.Reader if size is not power of two or is out of pool
|
||||
// min/max range.
|
||||
// PutReader is a wrapper around DefaultReaderPool.Put().
|
||||
func PutReader(bw *bufio.Reader) { DefaultReaderPool.Put(bw) }
|
||||
|
||||
// WriterPool contains logic of *bufio.Writer reuse with various size.
|
||||
type WriterPool struct {
|
||||
pool *pool.Pool
|
||||
}
|
||||
|
||||
// NewWriterPool creates new WriterPool that reuses writers which size is in
|
||||
// logarithmic range [min, max].
|
||||
func NewWriterPool(min, max int) *WriterPool {
|
||||
return &WriterPool{pool.New(min, max)}
|
||||
}
|
||||
|
||||
// CustomWriterPool creates new WriterPool with given options.
|
||||
func CustomWriterPool(opts ...pool.Option) *WriterPool {
|
||||
return &WriterPool{pool.Custom(opts...)}
|
||||
}
|
||||
|
||||
// Get returns bufio.Writer whose buffer has at least size bytes.
|
||||
func (wp *WriterPool) Get(w io.Writer, size int) *bufio.Writer {
|
||||
v, n := wp.pool.Get(size)
|
||||
if v != nil {
|
||||
bw := v.(*bufio.Writer)
|
||||
bw.Reset(w)
|
||||
return bw
|
||||
}
|
||||
return bufio.NewWriterSize(w, n)
|
||||
}
|
||||
|
||||
// Put takes ownership of bufio.Writer for further reuse.
|
||||
func (wp *WriterPool) Put(bw *bufio.Writer) {
|
||||
// Should reset even if we do Reset() inside Get().
|
||||
// This is done to prevent locking underlying io.Writer from GC.
|
||||
bw.Reset(nil)
|
||||
wp.pool.Put(bw, writerSize(bw))
|
||||
}
|
||||
|
||||
// ReaderPool contains logic of *bufio.Reader reuse with various size.
|
||||
type ReaderPool struct {
|
||||
pool *pool.Pool
|
||||
}
|
||||
|
||||
// NewReaderPool creates new ReaderPool that reuses writers which size is in
|
||||
// logarithmic range [min, max].
|
||||
func NewReaderPool(min, max int) *ReaderPool {
|
||||
return &ReaderPool{pool.New(min, max)}
|
||||
}
|
||||
|
||||
// CustomReaderPool creates new ReaderPool with given options.
|
||||
func CustomReaderPool(opts ...pool.Option) *ReaderPool {
|
||||
return &ReaderPool{pool.Custom(opts...)}
|
||||
}
|
||||
|
||||
// Get returns bufio.Reader whose buffer has at least size bytes.
|
||||
func (rp *ReaderPool) Get(r io.Reader, size int) *bufio.Reader {
|
||||
v, n := rp.pool.Get(size)
|
||||
if v != nil {
|
||||
br := v.(*bufio.Reader)
|
||||
br.Reset(r)
|
||||
return br
|
||||
}
|
||||
return bufio.NewReaderSize(r, n)
|
||||
}
|
||||
|
||||
// Put takes ownership of bufio.Reader for further reuse.
|
||||
func (rp *ReaderPool) Put(br *bufio.Reader) {
|
||||
// Should reset even if we do Reset() inside Get().
|
||||
// This is done to prevent locking underlying io.Reader from GC.
|
||||
br.Reset(nil)
|
||||
rp.pool.Put(br, readerSize(br))
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
// +build go1.10
|
||||
|
||||
package pbufio
|
||||
|
||||
import "bufio"
|
||||
|
||||
func writerSize(bw *bufio.Writer) int {
|
||||
return bw.Size()
|
||||
}
|
||||
|
||||
func readerSize(br *bufio.Reader) int {
|
||||
return br.Size()
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
// +build !go1.10
|
||||
|
||||
package pbufio
|
||||
|
||||
import "bufio"
|
||||
|
||||
func writerSize(bw *bufio.Writer) int {
|
||||
return bw.Available() + bw.Buffered()
|
||||
}
|
||||
|
||||
// readerSize returns buffer size of the given buffered reader.
|
||||
// NOTE: current workaround implementation resets underlying io.Reader.
|
||||
func readerSize(br *bufio.Reader) int {
|
||||
br.Reset(sizeReader)
|
||||
br.ReadByte()
|
||||
n := br.Buffered() + 1
|
||||
br.Reset(nil)
|
||||
return n
|
||||
}
|
||||
|
||||
var sizeReader optimisticReader
|
||||
|
||||
type optimisticReader struct{}
|
||||
|
||||
func (optimisticReader) Read(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
// Package pbytes contains tools for pooling byte pool.
|
||||
// Note that by default it reuse slices with capacity from 128 to 65536 bytes.
|
||||
package pbytes
|
||||
|
||||
// DefaultPool is used by pacakge level functions.
|
||||
var DefaultPool = New(128, 65536)
|
||||
|
||||
// Get returns probably reused slice of bytes with at least capacity of c and
|
||||
// exactly len of n.
|
||||
// Get is a wrapper around DefaultPool.Get().
|
||||
func Get(n, c int) []byte { return DefaultPool.Get(n, c) }
|
||||
|
||||
// GetCap returns probably reused slice of bytes with at least capacity of n.
|
||||
// GetCap is a wrapper around DefaultPool.GetCap().
|
||||
func GetCap(c int) []byte { return DefaultPool.GetCap(c) }
|
||||
|
||||
// GetLen returns probably reused slice of bytes with at least capacity of n
|
||||
// and exactly len of n.
|
||||
// GetLen is a wrapper around DefaultPool.GetLen().
|
||||
func GetLen(n int) []byte { return DefaultPool.GetLen(n) }
|
||||
|
||||
// Put returns given slice to reuse pool.
|
||||
// Put is a wrapper around DefaultPool.Put().
|
||||
func Put(p []byte) { DefaultPool.Put(p) }
|
|
@ -0,0 +1,59 @@
|
|||
// +build !pool_sanitize
|
||||
|
||||
package pbytes
|
||||
|
||||
import "github.com/gobwas/pool"
|
||||
|
||||
// Pool contains logic of reusing byte slices of various size.
|
||||
type Pool struct {
|
||||
pool *pool.Pool
|
||||
}
|
||||
|
||||
// New creates new Pool that reuses slices which size is in logarithmic range
|
||||
// [min, max].
|
||||
//
|
||||
// Note that it is a shortcut for Custom() constructor with Options provided by
|
||||
// pool.WithLogSizeMapping() and pool.WithLogSizeRange(min, max) calls.
|
||||
func New(min, max int) *Pool {
|
||||
return &Pool{pool.New(min, max)}
|
||||
}
|
||||
|
||||
// New creates new Pool with given options.
|
||||
func Custom(opts ...pool.Option) *Pool {
|
||||
return &Pool{pool.Custom(opts...)}
|
||||
}
|
||||
|
||||
// Get returns probably reused slice of bytes with at least capacity of c and
|
||||
// exactly len of n.
|
||||
func (p *Pool) Get(n, c int) []byte {
|
||||
if n > c {
|
||||
panic("requested length is greater than capacity")
|
||||
}
|
||||
|
||||
v, x := p.pool.Get(c)
|
||||
if v != nil {
|
||||
bts := v.([]byte)
|
||||
bts = bts[:n]
|
||||
return bts
|
||||
}
|
||||
|
||||
return make([]byte, n, x)
|
||||
}
|
||||
|
||||
// Put returns given slice to reuse pool.
|
||||
// It does not reuse bytes whose size is not power of two or is out of pool
|
||||
// min/max range.
|
||||
func (p *Pool) Put(bts []byte) {
|
||||
p.pool.Put(bts, cap(bts))
|
||||
}
|
||||
|
||||
// GetCap returns probably reused slice of bytes with at least capacity of n.
|
||||
func (p *Pool) GetCap(c int) []byte {
|
||||
return p.Get(0, c)
|
||||
}
|
||||
|
||||
// GetLen returns probably reused slice of bytes with at least capacity of n
|
||||
// and exactly len of n.
|
||||
func (p *Pool) GetLen(n int) []byte {
|
||||
return p.Get(n, n)
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
// +build pool_sanitize
|
||||
|
||||
package pbytes
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const magic = uint64(0x777742)
|
||||
|
||||
type guard struct {
|
||||
magic uint64
|
||||
size int
|
||||
owners int32
|
||||
}
|
||||
|
||||
const guardSize = int(unsafe.Sizeof(guard{}))
|
||||
|
||||
type Pool struct {
|
||||
min, max int
|
||||
}
|
||||
|
||||
func New(min, max int) *Pool {
|
||||
return &Pool{min, max}
|
||||
}
|
||||
|
||||
// Get returns probably reused slice of bytes with at least capacity of c and
|
||||
// exactly len of n.
|
||||
func (p *Pool) Get(n, c int) []byte {
|
||||
if n > c {
|
||||
panic("requested length is greater than capacity")
|
||||
}
|
||||
|
||||
pageSize := syscall.Getpagesize()
|
||||
pages := (c+guardSize)/pageSize + 1
|
||||
size := pages * pageSize
|
||||
|
||||
bts := alloc(size)
|
||||
|
||||
g := (*guard)(unsafe.Pointer(&bts[0]))
|
||||
*g = guard{
|
||||
magic: magic,
|
||||
size: size,
|
||||
owners: 1,
|
||||
}
|
||||
|
||||
return bts[guardSize : guardSize+n]
|
||||
}
|
||||
|
||||
func (p *Pool) GetCap(c int) []byte { return p.Get(0, c) }
|
||||
func (p *Pool) GetLen(n int) []byte { return Get(n, n) }
|
||||
|
||||
// Put returns given slice to reuse pool.
|
||||
func (p *Pool) Put(bts []byte) {
|
||||
hdr := *(*reflect.SliceHeader)(unsafe.Pointer(&bts))
|
||||
ptr := hdr.Data - uintptr(guardSize)
|
||||
|
||||
g := (*guard)(unsafe.Pointer(ptr))
|
||||
if g.magic != magic {
|
||||
panic("unknown slice returned to the pool")
|
||||
}
|
||||
if n := atomic.AddInt32(&g.owners, -1); n < 0 {
|
||||
panic("multiple Put() detected")
|
||||
}
|
||||
|
||||
// Disable read and write on bytes memory pages. This will cause panic on
|
||||
// incorrect access to returned slice.
|
||||
mprotect(ptr, false, false, g.size)
|
||||
|
||||
runtime.SetFinalizer(&bts, func(b *[]byte) {
|
||||
mprotect(ptr, true, true, g.size)
|
||||
free(*(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
|
||||
Data: ptr,
|
||||
Len: g.size,
|
||||
Cap: g.size,
|
||||
})))
|
||||
})
|
||||
}
|
||||
|
||||
func alloc(n int) []byte {
|
||||
b, err := unix.Mmap(-1, 0, n, unix.PROT_READ|unix.PROT_WRITE|unix.PROT_EXEC, unix.MAP_SHARED|unix.MAP_ANONYMOUS)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func free(b []byte) {
|
||||
if err := unix.Munmap(b); err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func mprotect(ptr uintptr, r, w bool, size int) {
|
||||
// Need to avoid "EINVAL addr is not a valid pointer,
|
||||
// or not a multiple of PAGESIZE."
|
||||
start := ptr & ^(uintptr(syscall.Getpagesize() - 1))
|
||||
|
||||
prot := uintptr(syscall.PROT_EXEC)
|
||||
switch {
|
||||
case r && w:
|
||||
prot |= syscall.PROT_READ | syscall.PROT_WRITE
|
||||
case r:
|
||||
prot |= syscall.PROT_READ
|
||||
case w:
|
||||
prot |= syscall.PROT_WRITE
|
||||
}
|
||||
|
||||
_, _, err := syscall.Syscall(syscall.SYS_MPROTECT,
|
||||
start, uintptr(size), prot,
|
||||
)
|
||||
if err != 0 {
|
||||
panic(err.Error())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
// Package pool contains helpers for pooling structures distinguishable by
|
||||
// size.
|
||||
//
|
||||
// Quick example:
|
||||
//
|
||||
// import "github.com/gobwas/pool"
|
||||
//
|
||||
// func main() {
|
||||
// // Reuse objects in logarithmic range from 0 to 64 (0,1,2,4,6,8,16,32,64).
|
||||
// p := pool.New(0, 64)
|
||||
//
|
||||
// buf, n := p.Get(10) // Returns buffer with 16 capacity.
|
||||
// if buf == nil {
|
||||
// buf = bytes.NewBuffer(make([]byte, n))
|
||||
// }
|
||||
// defer p.Put(buf, n)
|
||||
//
|
||||
// // Work with buf.
|
||||
// }
|
||||
//
|
||||
// There are non-generic implementations for pooling:
|
||||
// - pool/pbytes for []byte reuse;
|
||||
// - pool/pbufio for *bufio.Reader and *bufio.Writer reuse;
|
||||
//
|
||||
package pool
|
|
@ -0,0 +1,5 @@
|
|||
bin/
|
||||
reports/
|
||||
cpu.out
|
||||
mem.out
|
||||
ws.test
|
|
@ -0,0 +1,25 @@
|
|||
sudo: required
|
||||
|
||||
language: go
|
||||
|
||||
services:
|
||||
- docker
|
||||
|
||||
os:
|
||||
- linux
|
||||
- windows
|
||||
|
||||
go:
|
||||
- 1.8.x
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
- 1.x
|
||||
|
||||
install:
|
||||
- go get github.com/gobwas/pool
|
||||
- go get github.com/gobwas/httphead
|
||||
|
||||
script:
|
||||
- if [ "$TRAVIS_OS_NAME" = "windows" ]; then go test ./...; fi
|
||||
- if [ "$TRAVIS_OS_NAME" = "linux" ]; then make test autobahn; fi
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2017-2018 Sergey Kamardin <gobwas@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,47 @@
|
|||
BENCH ?=.
|
||||
BENCH_BASE?=master
|
||||
|
||||
clean:
|
||||
rm -f bin/reporter
|
||||
rm -fr autobahn/report/*
|
||||
|
||||
bin/reporter:
|
||||
go build -o bin/reporter ./autobahn
|
||||
|
||||
bin/gocovmerge:
|
||||
go build -o bin/gocovmerge github.com/wadey/gocovmerge
|
||||
|
||||
.PHONY: autobahn
|
||||
autobahn: clean bin/reporter
|
||||
./autobahn/script/test.sh --build
|
||||
bin/reporter $(PWD)/autobahn/report/index.json
|
||||
|
||||
test:
|
||||
go test -coverprofile=ws.coverage .
|
||||
go test -coverprofile=wsutil.coverage ./wsutil
|
||||
|
||||
cover: bin/gocovmerge test autobahn
|
||||
bin/gocovmerge ws.coverage wsutil.coverage autobahn/report/server.coverage > total.coverage
|
||||
|
||||
benchcmp: BENCH_BRANCH=$(shell git rev-parse --abbrev-ref HEAD)
|
||||
benchcmp: BENCH_OLD:=$(shell mktemp -t old.XXXX)
|
||||
benchcmp: BENCH_NEW:=$(shell mktemp -t new.XXXX)
|
||||
benchcmp:
|
||||
if [ ! -z "$(shell git status -s)" ]; then\
|
||||
echo "could not compare with $(BENCH_BASE) – found unstaged changes";\
|
||||
exit 1;\
|
||||
fi;\
|
||||
if [ "$(BENCH_BRANCH)" == "$(BENCH_BASE)" ]; then\
|
||||
echo "comparing the same branches";\
|
||||
exit 1;\
|
||||
fi;\
|
||||
echo "benchmarking $(BENCH_BRANCH)...";\
|
||||
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_NEW);\
|
||||
echo "benchmarking $(BENCH_BASE)...";\
|
||||
git checkout -q $(BENCH_BASE);\
|
||||
go test -run=none -bench=$(BENCH) -benchmem > $(BENCH_OLD);\
|
||||
git checkout -q $(BENCH_BRANCH);\
|
||||
echo "\nresults:";\
|
||||
echo "========\n";\
|
||||
benchcmp $(BENCH_OLD) $(BENCH_NEW);\
|
||||
|
|
@ -0,0 +1,360 @@
|
|||
# ws
|
||||
|
||||
[![GoDoc][godoc-image]][godoc-url]
|
||||
[![Travis][travis-image]][travis-url]
|
||||
|
||||
> [RFC6455][rfc-url] WebSocket implementation in Go.
|
||||
|
||||
# Features
|
||||
|
||||
- Zero-copy upgrade
|
||||
- No intermediate allocations during I/O
|
||||
- Low-level API which allows to build your own logic of packet handling and
|
||||
buffers reuse
|
||||
- High-level wrappers and helpers around API in `wsutil` package, which allow
|
||||
to start fast without digging the protocol internals
|
||||
|
||||
# Documentation
|
||||
|
||||
[GoDoc][godoc-url].
|
||||
|
||||
# Why
|
||||
|
||||
Existing WebSocket implementations do not allow users to reuse I/O buffers
|
||||
between connections in clear way. This library aims to export efficient
|
||||
low-level interface for working with the protocol without forcing only one way
|
||||
it could be used.
|
||||
|
||||
By the way, if you want get the higher-level tools, you can use `wsutil`
|
||||
package.
|
||||
|
||||
# Status
|
||||
|
||||
Library is tagged as `v1*` so its API must not be broken during some
|
||||
improvements or refactoring.
|
||||
|
||||
This implementation of RFC6455 passes [Autobahn Test
|
||||
Suite](https://github.com/crossbario/autobahn-testsuite) and currently has
|
||||
about 78% coverage.
|
||||
|
||||
# Examples
|
||||
|
||||
Example applications using `ws` are developed in separate repository
|
||||
[ws-examples](https://github.com/gobwas/ws-examples).
|
||||
|
||||
# Usage
|
||||
|
||||
The higher-level example of WebSocket echo server:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
)
|
||||
|
||||
func main() {
|
||||
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, _, _, err := ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
msg, op, err := wsutil.ReadClientData(conn)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
err = wsutil.WriteServerMessage(conn, op, msg)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
}
|
||||
}()
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
Lower-level, but still high-level example:
|
||||
|
||||
|
||||
```go
|
||||
import (
|
||||
"net/http"
|
||||
"io"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
)
|
||||
|
||||
func main() {
|
||||
http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, _, _, err := ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
|
||||
var (
|
||||
state = ws.StateServerSide
|
||||
reader = wsutil.NewReader(conn, state)
|
||||
writer = wsutil.NewWriter(conn, state, ws.OpText)
|
||||
)
|
||||
for {
|
||||
header, err := reader.NextFrame()
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
// Reset writer to write frame with right operation code.
|
||||
writer.Reset(conn, state, header.OpCode)
|
||||
|
||||
if _, err = io.Copy(writer, reader); err != nil {
|
||||
// handle error
|
||||
}
|
||||
if err = writer.Flush(); err != nil {
|
||||
// handle error
|
||||
}
|
||||
}
|
||||
}()
|
||||
}))
|
||||
}
|
||||
```
|
||||
|
||||
We can apply the same pattern to read and write structured responses through a JSON encoder and decoder.:
|
||||
|
||||
```go
|
||||
...
|
||||
var (
|
||||
r = wsutil.NewReader(conn, ws.StateServerSide)
|
||||
w = wsutil.NewWriter(conn, ws.StateServerSide, ws.OpText)
|
||||
decoder = json.NewDecoder(r)
|
||||
encoder = json.NewEncoder(w)
|
||||
)
|
||||
for {
|
||||
hdr, err = r.NextFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if hdr.OpCode == ws.OpClose {
|
||||
return io.EOF
|
||||
}
|
||||
var req Request
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
var resp Response
|
||||
if err := encoder.Encode(&resp); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = w.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
...
|
||||
```
|
||||
|
||||
The lower-level example without `wsutil`:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"io"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ln, err := net.Listen("tcp", "localhost:8080")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
_, err = ws.Upgrade(conn)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
header, err := ws.ReadHeader(conn)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
payload := make([]byte, header.Length)
|
||||
_, err = io.ReadFull(conn, payload)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
if header.Masked {
|
||||
ws.Cipher(payload, header.Mask, 0)
|
||||
}
|
||||
|
||||
// Reset the Masked flag, server frames must not be masked as
|
||||
// RFC6455 says.
|
||||
header.Masked = false
|
||||
|
||||
if err := ws.WriteHeader(conn, header); err != nil {
|
||||
// handle error
|
||||
}
|
||||
if _, err := conn.Write(payload); err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
if header.OpCode == ws.OpClose {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
# Zero-copy upgrade
|
||||
|
||||
Zero-copy upgrade helps to avoid unnecessary allocations and copying while
|
||||
handling HTTP Upgrade request.
|
||||
|
||||
Processing of all non-websocket headers is made in place with use of registered
|
||||
user callbacks whose arguments are only valid until callback returns.
|
||||
|
||||
The simple example looks like this:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"log"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ln, err := net.Listen("tcp", "localhost:8080")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
u := ws.Upgrader{
|
||||
OnHeader: func(key, value []byte) (err error) {
|
||||
log.Printf("non-websocket header: %q=%q", key, value)
|
||||
return
|
||||
},
|
||||
}
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
_, err = u.Upgrade(conn)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Usage of `ws.Upgrader` here brings ability to control incoming connections on
|
||||
tcp level and simply not to accept them by some logic.
|
||||
|
||||
Zero-copy upgrade is for high-load services which have to control many
|
||||
resources such as connections buffers.
|
||||
|
||||
The real life example could be like this:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"github.com/gobwas/httphead"
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ln, err := net.Listen("tcp", "localhost:8080")
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
// Prepare handshake header writer from http.Header mapping.
|
||||
header := ws.HandshakeHeaderHTTP(http.Header{
|
||||
"X-Go-Version": []string{runtime.Version()},
|
||||
})
|
||||
|
||||
u := ws.Upgrader{
|
||||
OnHost: func(host []byte) error {
|
||||
if string(host) == "github.com" {
|
||||
return nil
|
||||
}
|
||||
return ws.RejectConnectionError(
|
||||
ws.RejectionStatus(403),
|
||||
ws.RejectionHeader(ws.HandshakeHeaderString(
|
||||
"X-Want-Host: github.com\r\n",
|
||||
)),
|
||||
)
|
||||
},
|
||||
OnHeader: func(key, value []byte) error {
|
||||
if string(key) != "Cookie" {
|
||||
return nil
|
||||
}
|
||||
ok := httphead.ScanCookie(value, func(key, value []byte) bool {
|
||||
// Check session here or do some other stuff with cookies.
|
||||
// Maybe copy some values for future use.
|
||||
return true
|
||||
})
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
return ws.RejectConnectionError(
|
||||
ws.RejectionReason("bad cookie"),
|
||||
ws.RejectionStatus(400),
|
||||
)
|
||||
},
|
||||
OnBeforeUpgrade: func() (ws.HandshakeHeader, error) {
|
||||
return header, nil
|
||||
},
|
||||
}
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_, err = u.Upgrade(conn)
|
||||
if err != nil {
|
||||
log.Printf("upgrade error: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
[rfc-url]: https://tools.ietf.org/html/rfc6455
|
||||
[godoc-image]: https://godoc.org/github.com/gobwas/ws?status.svg
|
||||
[godoc-url]: https://godoc.org/github.com/gobwas/ws
|
||||
[travis-image]: https://travis-ci.org/gobwas/ws.svg?branch=master
|
||||
[travis-url]: https://travis-ci.org/gobwas/ws
|
|
@ -0,0 +1,145 @@
|
|||
package ws
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
// State represents state of websocket endpoint.
|
||||
// It used by some functions to be more strict when checking compatibility with RFC6455.
|
||||
type State uint8
|
||||
|
||||
const (
|
||||
// StateServerSide means that endpoint (caller) is a server.
|
||||
StateServerSide State = 0x1 << iota
|
||||
// StateClientSide means that endpoint (caller) is a client.
|
||||
StateClientSide
|
||||
// StateExtended means that extension was negotiated during handshake.
|
||||
StateExtended
|
||||
// StateFragmented means that endpoint (caller) has received fragmented
|
||||
// frame and waits for continuation parts.
|
||||
StateFragmented
|
||||
)
|
||||
|
||||
// Is checks whether the s has v enabled.
|
||||
func (s State) Is(v State) bool {
|
||||
return uint8(s)&uint8(v) != 0
|
||||
}
|
||||
|
||||
// Set enables v state on s.
|
||||
func (s State) Set(v State) State {
|
||||
return s | v
|
||||
}
|
||||
|
||||
// Clear disables v state on s.
|
||||
func (s State) Clear(v State) State {
|
||||
return s & (^v)
|
||||
}
|
||||
|
||||
// ServerSide reports whether states represents server side.
|
||||
func (s State) ServerSide() bool { return s.Is(StateServerSide) }
|
||||
|
||||
// ClientSide reports whether state represents client side.
|
||||
func (s State) ClientSide() bool { return s.Is(StateClientSide) }
|
||||
|
||||
// Extended reports whether state is extended.
|
||||
func (s State) Extended() bool { return s.Is(StateExtended) }
|
||||
|
||||
// Fragmented reports whether state is fragmented.
|
||||
func (s State) Fragmented() bool { return s.Is(StateFragmented) }
|
||||
|
||||
// ProtocolError describes error during checking/parsing websocket frames or
|
||||
// headers.
|
||||
type ProtocolError string
|
||||
|
||||
// Error implements error interface.
|
||||
func (p ProtocolError) Error() string { return string(p) }
|
||||
|
||||
// Errors used by the protocol checkers.
|
||||
var (
|
||||
ErrProtocolOpCodeReserved = ProtocolError("use of reserved op code")
|
||||
ErrProtocolControlPayloadOverflow = ProtocolError("control frame payload limit exceeded")
|
||||
ErrProtocolControlNotFinal = ProtocolError("control frame is not final")
|
||||
ErrProtocolNonZeroRsv = ProtocolError("non-zero rsv bits with no extension negotiated")
|
||||
ErrProtocolMaskRequired = ProtocolError("frames from client to server must be masked")
|
||||
ErrProtocolMaskUnexpected = ProtocolError("frames from server to client must be not masked")
|
||||
ErrProtocolContinuationExpected = ProtocolError("unexpected non-continuation data frame")
|
||||
ErrProtocolContinuationUnexpected = ProtocolError("unexpected continuation data frame")
|
||||
ErrProtocolStatusCodeNotInUse = ProtocolError("status code is not in use")
|
||||
ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level")
|
||||
ErrProtocolStatusCodeNoMeaning = ProtocolError("status code has no meaning yet")
|
||||
ErrProtocolStatusCodeUnknown = ProtocolError("status code is not defined in spec")
|
||||
ErrProtocolInvalidUTF8 = ProtocolError("invalid utf8 sequence in close reason")
|
||||
)
|
||||
|
||||
// CheckHeader checks h to contain valid header data for given state s.
|
||||
//
|
||||
// Note that zero state (0) means that state is clean,
|
||||
// neither server or client side, nor fragmented, nor extended.
|
||||
func CheckHeader(h Header, s State) error {
|
||||
if h.OpCode.IsReserved() {
|
||||
return ErrProtocolOpCodeReserved
|
||||
}
|
||||
if h.OpCode.IsControl() {
|
||||
if h.Length > MaxControlFramePayloadSize {
|
||||
return ErrProtocolControlPayloadOverflow
|
||||
}
|
||||
if !h.Fin {
|
||||
return ErrProtocolControlNotFinal
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
// [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
|
||||
// non-zero values. If a nonzero value is received and none of the
|
||||
// negotiated extensions defines the meaning of such a nonzero value, the
|
||||
// receiving endpoint MUST _Fail the WebSocket Connection_.
|
||||
case h.Rsv != 0 && !s.Extended():
|
||||
return ErrProtocolNonZeroRsv
|
||||
|
||||
// [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
|
||||
// In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
|
||||
// as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
|
||||
// A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
|
||||
// status code 1002 (protocol error) as defined in Section 7.4.1.
|
||||
case s.ServerSide() && !h.Masked:
|
||||
return ErrProtocolMaskRequired
|
||||
case s.ClientSide() && h.Masked:
|
||||
return ErrProtocolMaskUnexpected
|
||||
|
||||
// [RFC6455]: See detailed explanation in 5.4 section.
|
||||
case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
|
||||
return ErrProtocolContinuationExpected
|
||||
case !s.Fragmented() && h.OpCode == OpContinuation:
|
||||
return ErrProtocolContinuationUnexpected
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// CheckCloseFrameData checks received close information
|
||||
// to be valid RFC6455 compatible close info.
|
||||
//
|
||||
// Note that code.Empty() or code.IsAppLevel() will raise error.
|
||||
//
|
||||
// If endpoint sends close frame without status code (with frame.Length = 0),
|
||||
// application should not check its payload.
|
||||
func CheckCloseFrameData(code StatusCode, reason string) error {
|
||||
switch {
|
||||
case code.IsNotUsed():
|
||||
return ErrProtocolStatusCodeNotInUse
|
||||
|
||||
case code.IsProtocolReserved():
|
||||
return ErrProtocolStatusCodeApplicationLevel
|
||||
|
||||
case code == StatusNoMeaningYet:
|
||||
return ErrProtocolStatusCodeNoMeaning
|
||||
|
||||
case code.IsProtocolSpec() && !code.IsProtocolDefined():
|
||||
return ErrProtocolStatusCodeUnknown
|
||||
|
||||
case !utf8.ValidString(reason):
|
||||
return ErrProtocolInvalidUTF8
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// Cipher applies XOR cipher to the payload using mask.
|
||||
// Offset is used to cipher chunked data (e.g. in io.Reader implementations).
|
||||
//
|
||||
// To convert masked data into unmasked data, or vice versa, the following
|
||||
// algorithm is applied. The same algorithm applies regardless of the
|
||||
// direction of the translation, e.g., the same steps are applied to
|
||||
// mask the data as to unmask the data.
|
||||
func Cipher(payload []byte, mask [4]byte, offset int) {
|
||||
n := len(payload)
|
||||
if n < 8 {
|
||||
for i := 0; i < n; i++ {
|
||||
payload[i] ^= mask[(offset+i)%4]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate position in mask due to previously processed bytes number.
|
||||
mpos := offset % 4
|
||||
// Count number of bytes will processed one by one from the beginning of payload.
|
||||
ln := remain[mpos]
|
||||
// Count number of bytes will processed one by one from the end of payload.
|
||||
// This is done to process payload by 8 bytes in each iteration of main loop.
|
||||
rn := (n - ln) % 8
|
||||
|
||||
for i := 0; i < ln; i++ {
|
||||
payload[i] ^= mask[(mpos+i)%4]
|
||||
}
|
||||
for i := n - rn; i < n; i++ {
|
||||
payload[i] ^= mask[(mpos+i)%4]
|
||||
}
|
||||
|
||||
// NOTE: we use here binary.LittleEndian regardless of what is real
|
||||
// endianess on machine is. To do so, we have to use binary.LittleEndian in
|
||||
// the masking loop below as well.
|
||||
var (
|
||||
m = binary.LittleEndian.Uint32(mask[:])
|
||||
m2 = uint64(m)<<32 | uint64(m)
|
||||
)
|
||||
// Skip already processed right part.
|
||||
// Get number of uint64 parts remaining to process.
|
||||
n = (n - ln - rn) >> 3
|
||||
for i := 0; i < n; i++ {
|
||||
var (
|
||||
j = ln + (i << 3)
|
||||
chunk = payload[j : j+8]
|
||||
)
|
||||
p := binary.LittleEndian.Uint64(chunk)
|
||||
p = p ^ m2
|
||||
binary.LittleEndian.PutUint64(chunk, p)
|
||||
}
|
||||
}
|
||||
|
||||
// remain maps position in masking key [0,4) to number
|
||||
// of bytes that need to be processed manually inside Cipher().
|
||||
var remain = [4]int{0, 3, 2, 1}
|
|
@ -0,0 +1,556 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/httphead"
|
||||
"github.com/gobwas/pool/pbufio"
|
||||
)
|
||||
|
||||
// Constants used by Dialer.
|
||||
const (
|
||||
DefaultClientReadBufferSize = 4096
|
||||
DefaultClientWriteBufferSize = 4096
|
||||
)
|
||||
|
||||
// Handshake represents handshake result.
|
||||
type Handshake struct {
|
||||
// Protocol is the subprotocol selected during handshake.
|
||||
Protocol string
|
||||
|
||||
// Extensions is the list of negotiated extensions.
|
||||
Extensions []httphead.Option
|
||||
}
|
||||
|
||||
// Errors used by the websocket client.
|
||||
var (
|
||||
ErrHandshakeBadStatus = fmt.Errorf("unexpected http status")
|
||||
ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol)
|
||||
ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol)
|
||||
)
|
||||
|
||||
// DefaultDialer is dialer that holds no options and is used by Dial function.
|
||||
var DefaultDialer Dialer
|
||||
|
||||
// Dial is like Dialer{}.Dial().
|
||||
func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) {
|
||||
return DefaultDialer.Dial(ctx, urlstr)
|
||||
}
|
||||
|
||||
// Dialer contains options for establishing websocket connection to an url.
|
||||
type Dialer struct {
|
||||
// ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
|
||||
// They used to read and write http data while upgrading to WebSocket.
|
||||
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
|
||||
//
|
||||
// If a size is zero then default value is used.
|
||||
ReadBufferSize, WriteBufferSize int
|
||||
|
||||
// Timeout is the maximum amount of time a Dial() will wait for a connect
|
||||
// and an handshake to complete.
|
||||
//
|
||||
// The default is no timeout.
|
||||
Timeout time.Duration
|
||||
|
||||
// Protocols is the list of subprotocols that the client wants to speak,
|
||||
// ordered by preference.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc6455#section-4.1
|
||||
Protocols []string
|
||||
|
||||
// Extensions is the list of extensions that client wants to speak.
|
||||
//
|
||||
// Note that if server decides to use some of this extensions, Dial() will
|
||||
// return Handshake struct containing a slice of items, which are the
|
||||
// shallow copies of the items from this list. That is, internals of
|
||||
// Extensions items are shared during Dial().
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc6455#section-4.1
|
||||
// See https://tools.ietf.org/html/rfc6455#section-9.1
|
||||
Extensions []httphead.Option
|
||||
|
||||
// Header is an optional HandshakeHeader instance that could be used to
|
||||
// write additional headers to the handshake request.
|
||||
//
|
||||
// It used instead of any key-value mappings to avoid allocations in user
|
||||
// land.
|
||||
Header HandshakeHeader
|
||||
|
||||
// OnStatusError is the callback that will be called after receiving non
|
||||
// "101 Continue" HTTP response status. It receives an io.Reader object
|
||||
// representing server response bytes. That is, it gives ability to parse
|
||||
// HTTP response somehow (probably with http.ReadResponse call) and make a
|
||||
// decision of further logic.
|
||||
//
|
||||
// The arguments are only valid until the callback returns.
|
||||
OnStatusError func(status int, reason []byte, resp io.Reader)
|
||||
|
||||
// OnHeader is the callback that will be called after successful parsing of
|
||||
// header, that is not used during WebSocket handshake procedure. That is,
|
||||
// it will be called with non-websocket headers, which could be relevant
|
||||
// for application-level logic.
|
||||
//
|
||||
// The arguments are only valid until the callback returns.
|
||||
//
|
||||
// Returned value could be used to prevent processing response.
|
||||
OnHeader func(key, value []byte) (err error)
|
||||
|
||||
// NetDial is the function that is used to get plain tcp connection.
|
||||
// If it is not nil, then it is used instead of net.Dialer.
|
||||
NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// TLSClient is the callback that will be called after successful dial with
|
||||
// received connection and its remote host name. If it is nil, then the
|
||||
// default tls.Client() will be used.
|
||||
// If it is not nil, then TLSConfig field is ignored.
|
||||
TLSClient func(conn net.Conn, hostname string) net.Conn
|
||||
|
||||
// TLSConfig is passed to tls.Client() to start TLS over established
|
||||
// connection. If TLSClient is not nil, then it is ignored. If TLSConfig is
|
||||
// non-nil and its ServerName is empty, then for every Dial() it will be
|
||||
// cloned and appropriate ServerName will be set.
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// WrapConn is the optional callback that will be called when connection is
|
||||
// ready for an i/o. That is, it will be called after successful dial and
|
||||
// TLS initialization (for "wss" schemes). It may be helpful for different
|
||||
// user land purposes such as end to end encryption.
|
||||
//
|
||||
// Note that for debugging purposes of an http handshake (e.g. sent request
|
||||
// and received response), there is an wsutil.DebugDialer struct.
|
||||
WrapConn func(conn net.Conn) net.Conn
|
||||
}
|
||||
|
||||
// Dial connects to the url host and upgrades connection to WebSocket.
|
||||
//
|
||||
// If server has sent frames right after successful handshake then returned
|
||||
// buffer will be non-nil. In other cases buffer is always nil. For better
|
||||
// memory efficiency received non-nil bufio.Reader should be returned to the
|
||||
// inner pool with PutReader() function after use.
|
||||
//
|
||||
// Note that Dialer does not implement IDNA (RFC5895) logic as net/http does.
|
||||
// If you want to dial non-ascii host name, take care of its name serialization
|
||||
// avoiding bad request issues. For more info see net/http Request.Write()
|
||||
// implementation, especially cleanHost() function.
|
||||
func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) {
|
||||
u, err := url.ParseRequestURI(urlstr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare context to dial with. Initially it is the same as original, but
|
||||
// if d.Timeout is non-zero and points to time that is before ctx.Deadline,
|
||||
// we use more shorter context for dial.
|
||||
dialctx := ctx
|
||||
|
||||
var deadline time.Time
|
||||
if t := d.Timeout; t != 0 {
|
||||
deadline = time.Now().Add(t)
|
||||
if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
|
||||
var cancel context.CancelFunc
|
||||
dialctx, cancel = context.WithDeadline(ctx, deadline)
|
||||
defer cancel()
|
||||
}
|
||||
}
|
||||
if conn, err = d.dial(dialctx, u); err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
if ctx == context.Background() {
|
||||
// No need to start I/O interrupter goroutine which is not zero-cost.
|
||||
conn.SetDeadline(deadline)
|
||||
defer conn.SetDeadline(noDeadline)
|
||||
} else {
|
||||
// Context could be canceled or its deadline could be exceeded.
|
||||
// Start the interrupter goroutine to handle context cancelation.
|
||||
done := setupContextDeadliner(ctx, conn)
|
||||
defer func() {
|
||||
// Map Upgrade() error to a possible context expiration error. That
|
||||
// is, even if Upgrade() err is nil, context could be already
|
||||
// expired and connection be "poisoned" by SetDeadline() call.
|
||||
// In that case we must not return ctx.Err() error.
|
||||
done(&err)
|
||||
}()
|
||||
}
|
||||
|
||||
br, hs, err = d.Upgrade(conn, u)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
// netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if
|
||||
// Dialer.NetDial is not provided.
|
||||
netEmptyDialer net.Dialer
|
||||
// tlsEmptyConfig is an empty tls.Config used as default one.
|
||||
tlsEmptyConfig tls.Config
|
||||
)
|
||||
|
||||
func tlsDefaultConfig() *tls.Config {
|
||||
return &tlsEmptyConfig
|
||||
}
|
||||
|
||||
func hostport(host string, defaultPort string) (hostname, addr string) {
|
||||
var (
|
||||
colon = strings.LastIndexByte(host, ':')
|
||||
bracket = strings.IndexByte(host, ']')
|
||||
)
|
||||
if colon > bracket {
|
||||
return host[:colon], host
|
||||
}
|
||||
return host, host + defaultPort
|
||||
}
|
||||
|
||||
func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) {
|
||||
dial := d.NetDial
|
||||
if dial == nil {
|
||||
dial = netEmptyDialer.DialContext
|
||||
}
|
||||
switch u.Scheme {
|
||||
case "ws":
|
||||
_, addr := hostport(u.Host, ":80")
|
||||
conn, err = dial(ctx, "tcp", addr)
|
||||
case "wss":
|
||||
hostname, addr := hostport(u.Host, ":443")
|
||||
conn, err = dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tlsClient := d.TLSClient
|
||||
if tlsClient == nil {
|
||||
tlsClient = d.tlsClient
|
||||
}
|
||||
conn = tlsClient(conn, hostname)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme)
|
||||
}
|
||||
if wrap := d.WrapConn; wrap != nil {
|
||||
conn = wrap(conn)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn {
|
||||
config := d.TLSConfig
|
||||
if config == nil {
|
||||
config = tlsDefaultConfig()
|
||||
}
|
||||
if config.ServerName == "" {
|
||||
config = tlsCloneConfig(config)
|
||||
config.ServerName = hostname
|
||||
}
|
||||
// Do not make conn.Handshake() here because downstairs we will prepare
|
||||
// i/o on this conn with proper context's timeout handling.
|
||||
return tls.Client(conn, config)
|
||||
}
|
||||
|
||||
var (
|
||||
// This variables are set like in net/net.go.
|
||||
// noDeadline is just zero value for readability.
|
||||
noDeadline = time.Time{}
|
||||
// aLongTimeAgo is a non-zero time, far in the past, used for immediate
|
||||
// cancelation of dials.
|
||||
aLongTimeAgo = time.Unix(42, 0)
|
||||
)
|
||||
|
||||
// Upgrade writes an upgrade request to the given io.ReadWriter conn at given
|
||||
// url u and reads a response from it.
|
||||
//
|
||||
// It is a caller responsibility to manage I/O deadlines on conn.
|
||||
//
|
||||
// It returns handshake info and some bytes which could be written by the peer
|
||||
// right after response and be caught by us during buffered read.
|
||||
func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) {
|
||||
// headerSeen constants helps to report whether or not some header was seen
|
||||
// during reading request bytes.
|
||||
const (
|
||||
headerSeenUpgrade = 1 << iota
|
||||
headerSeenConnection
|
||||
headerSeenSecAccept
|
||||
|
||||
// headerSeenAll is the value that we expect to receive at the end of
|
||||
// headers read/parse loop.
|
||||
headerSeenAll = 0 |
|
||||
headerSeenUpgrade |
|
||||
headerSeenConnection |
|
||||
headerSeenSecAccept
|
||||
)
|
||||
|
||||
br = pbufio.GetReader(conn,
|
||||
nonZero(d.ReadBufferSize, DefaultClientReadBufferSize),
|
||||
)
|
||||
bw := pbufio.GetWriter(conn,
|
||||
nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize),
|
||||
)
|
||||
defer func() {
|
||||
pbufio.PutWriter(bw)
|
||||
if br.Buffered() == 0 || err != nil {
|
||||
// Server does not wrote additional bytes to the connection or
|
||||
// error occurred. That is, no reason to return buffer.
|
||||
pbufio.PutReader(br)
|
||||
br = nil
|
||||
}
|
||||
}()
|
||||
|
||||
nonce := make([]byte, nonceSize)
|
||||
initNonce(nonce)
|
||||
|
||||
httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header)
|
||||
if err = bw.Flush(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Read HTTP status line like "HTTP/1.1 101 Switching Protocols".
|
||||
sl, err := readLine(br)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Begin validation of the response.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-4.2.2
|
||||
// Parse request line data like HTTP version, uri and method.
|
||||
resp, err := httpParseResponseLine(sl)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Even if RFC says "1.1 or higher" without mentioning the part of the
|
||||
// version, we apply it only to minor part.
|
||||
if resp.major != 1 || resp.minor < 1 {
|
||||
err = ErrHandshakeBadProtocol
|
||||
return
|
||||
}
|
||||
if resp.status != 101 {
|
||||
err = StatusError(resp.status)
|
||||
if onStatusError := d.OnStatusError; onStatusError != nil {
|
||||
// Invoke callback with multireader of status-line bytes br.
|
||||
onStatusError(resp.status, resp.reason,
|
||||
io.MultiReader(
|
||||
bytes.NewReader(sl),
|
||||
strings.NewReader(crlf),
|
||||
br,
|
||||
),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
// If response status is 101 then we expect all technical headers to be
|
||||
// valid. If not, then we stop processing response without giving user
|
||||
// ability to read non-technical headers. That is, we do not distinguish
|
||||
// technical errors (such as parsing error) and protocol errors.
|
||||
var headerSeen byte
|
||||
for {
|
||||
line, e := readLine(br)
|
||||
if e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
if len(line) == 0 {
|
||||
// Blank line, no more lines to read.
|
||||
break
|
||||
}
|
||||
|
||||
k, v, ok := httpParseHeaderLine(line)
|
||||
if !ok {
|
||||
err = ErrMalformedResponse
|
||||
return
|
||||
}
|
||||
|
||||
switch btsToString(k) {
|
||||
case headerUpgradeCanonical:
|
||||
headerSeen |= headerSeenUpgrade
|
||||
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
|
||||
err = ErrHandshakeBadUpgrade
|
||||
return
|
||||
}
|
||||
|
||||
case headerConnectionCanonical:
|
||||
headerSeen |= headerSeenConnection
|
||||
// Note that as RFC6455 says:
|
||||
// > A |Connection| header field with value "Upgrade".
|
||||
// That is, in server side, "Connection" header could contain
|
||||
// multiple token. But in response it must contains exactly one.
|
||||
if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) {
|
||||
err = ErrHandshakeBadConnection
|
||||
return
|
||||
}
|
||||
|
||||
case headerSecAcceptCanonical:
|
||||
headerSeen |= headerSeenSecAccept
|
||||
if !checkAcceptFromNonce(v, nonce) {
|
||||
err = ErrHandshakeBadSecAccept
|
||||
return
|
||||
}
|
||||
|
||||
case headerSecProtocolCanonical:
|
||||
// RFC6455 1.3:
|
||||
// "The server selects one or none of the acceptable protocols
|
||||
// and echoes that value in its handshake to indicate that it has
|
||||
// selected that protocol."
|
||||
for _, want := range d.Protocols {
|
||||
if string(v) == want {
|
||||
hs.Protocol = want
|
||||
break
|
||||
}
|
||||
}
|
||||
if hs.Protocol == "" {
|
||||
// Server echoed subprotocol that is not present in client
|
||||
// requested protocols.
|
||||
err = ErrHandshakeBadSubProtocol
|
||||
return
|
||||
}
|
||||
|
||||
case headerSecExtensionsCanonical:
|
||||
hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
if onHeader := d.OnHeader; onHeader != nil {
|
||||
if e := onHeader(k, v); e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if err == nil && headerSeen != headerSeenAll {
|
||||
switch {
|
||||
case headerSeen&headerSeenUpgrade == 0:
|
||||
err = ErrHandshakeBadUpgrade
|
||||
case headerSeen&headerSeenConnection == 0:
|
||||
err = ErrHandshakeBadConnection
|
||||
case headerSeen&headerSeenSecAccept == 0:
|
||||
err = ErrHandshakeBadSecAccept
|
||||
default:
|
||||
panic("unknown headers state")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PutReader returns bufio.Reader instance to the inner reuse pool.
|
||||
// It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which
|
||||
// contains unprocessed buffered data, that was sent by the server quickly
|
||||
// right after handshake.
|
||||
func PutReader(br *bufio.Reader) {
|
||||
pbufio.PutReader(br)
|
||||
}
|
||||
|
||||
// StatusError contains an unexpected status-line code from the server.
|
||||
type StatusError int
|
||||
|
||||
func (s StatusError) Error() string {
|
||||
return "unexpected HTTP response status: " + strconv.Itoa(int(s))
|
||||
}
|
||||
|
||||
func isTimeoutError(err error) bool {
|
||||
t, ok := err.(net.Error)
|
||||
return ok && t.Timeout()
|
||||
}
|
||||
|
||||
func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) {
|
||||
if len(selected) == 0 {
|
||||
return received, nil
|
||||
}
|
||||
var (
|
||||
index int
|
||||
option httphead.Option
|
||||
err error
|
||||
)
|
||||
index = -1
|
||||
match := func() (ok bool) {
|
||||
for _, want := range wanted {
|
||||
if option.Equal(want) {
|
||||
// Check parsed extension to be present in client
|
||||
// requested extensions. We move matched extension
|
||||
// from client list to avoid allocation.
|
||||
received = append(received, want)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control {
|
||||
if i != index {
|
||||
// Met next option.
|
||||
index = i
|
||||
if i != 0 && !match() {
|
||||
// Server returned non-requested extension.
|
||||
err = ErrHandshakeBadExtensions
|
||||
return httphead.ControlBreak
|
||||
}
|
||||
option = httphead.Option{Name: name}
|
||||
}
|
||||
if attr != nil {
|
||||
option.Parameters.Set(attr, val)
|
||||
}
|
||||
return httphead.ControlContinue
|
||||
})
|
||||
if !ok {
|
||||
err = ErrMalformedResponse
|
||||
return received, err
|
||||
}
|
||||
if !match() {
|
||||
return received, ErrHandshakeBadExtensions
|
||||
}
|
||||
return received, err
|
||||
}
|
||||
|
||||
// setupContextDeadliner is a helper function that starts connection I/O
|
||||
// interrupter goroutine.
|
||||
//
|
||||
// Started goroutine calls SetDeadline() with long time ago value when context
|
||||
// become expired to make any I/O operations failed. It returns done function
|
||||
// that stops started goroutine and maps error received from conn I/O methods
|
||||
// to possible context expiration error.
|
||||
//
|
||||
// In concern with possible SetDeadline() call inside interrupter goroutine,
|
||||
// caller passes pointer to its I/O error (even if it is nil) to done(&err).
|
||||
// That is, even if I/O error is nil, context could be already expired and
|
||||
// connection "poisoned" by SetDeadline() call. In that case done(&err) will
|
||||
// store at *err ctx.Err() result. If err is caused not by timeout, it will
|
||||
// leaved untouched.
|
||||
func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) {
|
||||
var (
|
||||
quit = make(chan struct{})
|
||||
interrupt = make(chan error, 1)
|
||||
)
|
||||
go func() {
|
||||
select {
|
||||
case <-quit:
|
||||
interrupt <- nil
|
||||
case <-ctx.Done():
|
||||
// Cancel i/o immediately.
|
||||
conn.SetDeadline(aLongTimeAgo)
|
||||
interrupt <- ctx.Err()
|
||||
}
|
||||
}()
|
||||
return func(err *error) {
|
||||
close(quit)
|
||||
// If ctx.Err() is non-nil and the original err is net.Error with
|
||||
// Timeout() == true, then it means that I/O was canceled by us by
|
||||
// SetDeadline(aLongTimeAgo) call, or by somebody else previously
|
||||
// by conn.SetDeadline(x).
|
||||
//
|
||||
// Even on race condition when both deadlines are expired
|
||||
// (SetDeadline() made not by us and context's), we prefer ctx.Err() to
|
||||
// be returned.
|
||||
if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) {
|
||||
*err = ctxErr
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
// +build !go1.8
|
||||
|
||||
package ws
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
func tlsCloneConfig(c *tls.Config) *tls.Config {
|
||||
// NOTE: we copying SessionTicketsDisabled and SessionTicketKey here
|
||||
// without calling inner c.initOnceServer somehow because we only could get
|
||||
// here from the ws.Dialer code, which is obviously a client and makes
|
||||
// tls.Client() when it gets new net.Conn.
|
||||
return &tls.Config{
|
||||
Rand: c.Rand,
|
||||
Time: c.Time,
|
||||
Certificates: c.Certificates,
|
||||
NameToCertificate: c.NameToCertificate,
|
||||
GetCertificate: c.GetCertificate,
|
||||
RootCAs: c.RootCAs,
|
||||
NextProtos: c.NextProtos,
|
||||
ServerName: c.ServerName,
|
||||
ClientAuth: c.ClientAuth,
|
||||
ClientCAs: c.ClientCAs,
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
CipherSuites: c.CipherSuites,
|
||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||
SessionTicketKey: c.SessionTicketKey,
|
||||
ClientSessionCache: c.ClientSessionCache,
|
||||
MinVersion: c.MinVersion,
|
||||
MaxVersion: c.MaxVersion,
|
||||
CurvePreferences: c.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||
Renegotiation: c.Renegotiation,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
// +build go1.8
|
||||
|
||||
package ws
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
func tlsCloneConfig(c *tls.Config) *tls.Config {
|
||||
return c.Clone()
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
Package ws implements a client and server for the WebSocket protocol as
|
||||
specified in RFC 6455.
|
||||
|
||||
The main purpose of this package is to provide simple low-level API for
|
||||
efficient work with protocol.
|
||||
|
||||
Overview.
|
||||
|
||||
Upgrade to WebSocket (or WebSocket handshake) can be done in two ways.
|
||||
|
||||
The first way is to use `net/http` server:
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, _, _, err := ws.UpgradeHTTP(r, w)
|
||||
})
|
||||
|
||||
The second and much more efficient way is so-called "zero-copy upgrade". It
|
||||
avoids redundant allocations and copying of not used headers or other request
|
||||
data. User decides by himself which data should be copied.
|
||||
|
||||
ln, err := net.Listen("tcp", ":8080")
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
handshake, err := ws.Upgrade(conn)
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
For customization details see `ws.Upgrader` documentation.
|
||||
|
||||
After WebSocket handshake you can work with connection in multiple ways.
|
||||
That is, `ws` does not force the only one way of how to work with WebSocket:
|
||||
|
||||
header, err := ws.ReadHeader(conn)
|
||||
if err != nil {
|
||||
// handle err
|
||||
}
|
||||
|
||||
buf := make([]byte, header.Length)
|
||||
_, err := io.ReadFull(conn, buf)
|
||||
if err != nil {
|
||||
// handle err
|
||||
}
|
||||
|
||||
resp := ws.NewBinaryFrame([]byte("hello, world!"))
|
||||
if err := ws.WriteFrame(conn, frame); err != nil {
|
||||
// handle err
|
||||
}
|
||||
|
||||
As you can see, it stream friendly:
|
||||
|
||||
const N = 42
|
||||
|
||||
ws.WriteHeader(ws.Header{
|
||||
Fin: true,
|
||||
Length: N,
|
||||
OpCode: ws.OpBinary,
|
||||
})
|
||||
|
||||
io.CopyN(conn, rand.Reader, N)
|
||||
|
||||
Or:
|
||||
|
||||
header, err := ws.ReadHeader(conn)
|
||||
if err != nil {
|
||||
// handle err
|
||||
}
|
||||
|
||||
io.CopyN(ioutil.Discard, conn, header.Length)
|
||||
|
||||
For more info see the documentation.
|
||||
*/
|
||||
package ws
|
|
@ -0,0 +1,54 @@
|
|||
package ws
|
||||
|
||||
// RejectOption represents an option used to control the way connection is
|
||||
// rejected.
|
||||
type RejectOption func(*rejectConnectionError)
|
||||
|
||||
// RejectionReason returns an option that makes connection to be rejected with
|
||||
// given reason.
|
||||
func RejectionReason(reason string) RejectOption {
|
||||
return func(err *rejectConnectionError) {
|
||||
err.reason = reason
|
||||
}
|
||||
}
|
||||
|
||||
// RejectionStatus returns an option that makes connection to be rejected with
|
||||
// given HTTP status code.
|
||||
func RejectionStatus(code int) RejectOption {
|
||||
return func(err *rejectConnectionError) {
|
||||
err.code = code
|
||||
}
|
||||
}
|
||||
|
||||
// RejectionHeader returns an option that makes connection to be rejected with
|
||||
// given HTTP headers.
|
||||
func RejectionHeader(h HandshakeHeader) RejectOption {
|
||||
return func(err *rejectConnectionError) {
|
||||
err.header = h
|
||||
}
|
||||
}
|
||||
|
||||
// RejectConnectionError constructs an error that could be used to control the way
|
||||
// handshake is rejected by Upgrader.
|
||||
func RejectConnectionError(options ...RejectOption) error {
|
||||
err := new(rejectConnectionError)
|
||||
for _, opt := range options {
|
||||
opt(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// rejectConnectionError represents a rejection of upgrade error.
|
||||
//
|
||||
// It can be returned by Upgrader's On* hooks to control the way WebSocket
|
||||
// handshake is rejected.
|
||||
type rejectConnectionError struct {
|
||||
reason string
|
||||
code int
|
||||
header HandshakeHeader
|
||||
}
|
||||
|
||||
// Error implements error interface.
|
||||
func (r *rejectConnectionError) Error() string {
|
||||
return r.reason
|
||||
}
|
|
@ -0,0 +1,389 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
// Constants defined by specification.
|
||||
const (
|
||||
// All control frames MUST have a payload length of 125 bytes or less and MUST NOT be fragmented.
|
||||
MaxControlFramePayloadSize = 125
|
||||
)
|
||||
|
||||
// OpCode represents operation code.
|
||||
type OpCode byte
|
||||
|
||||
// Operation codes defined by specification.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
const (
|
||||
OpContinuation OpCode = 0x0
|
||||
OpText OpCode = 0x1
|
||||
OpBinary OpCode = 0x2
|
||||
OpClose OpCode = 0x8
|
||||
OpPing OpCode = 0x9
|
||||
OpPong OpCode = 0xa
|
||||
)
|
||||
|
||||
// IsControl checks whether the c is control operation code.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.5
|
||||
func (c OpCode) IsControl() bool {
|
||||
// RFC6455: Control frames are identified by opcodes where
|
||||
// the most significant bit of the opcode is 1.
|
||||
//
|
||||
// Note that OpCode is only 4 bit length.
|
||||
return c&0x8 != 0
|
||||
}
|
||||
|
||||
// IsData checks whether the c is data operation code.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.6
|
||||
func (c OpCode) IsData() bool {
|
||||
// RFC6455: Data frames (e.g., non-control frames) are identified by opcodes
|
||||
// where the most significant bit of the opcode is 0.
|
||||
//
|
||||
// Note that OpCode is only 4 bit length.
|
||||
return c&0x8 == 0
|
||||
}
|
||||
|
||||
// IsReserved checks whether the c is reserved operation code.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
func (c OpCode) IsReserved() bool {
|
||||
// RFC6455:
|
||||
// %x3-7 are reserved for further non-control frames
|
||||
// %xB-F are reserved for further control frames
|
||||
return (0x3 <= c && c <= 0x7) || (0xb <= c && c <= 0xf)
|
||||
}
|
||||
|
||||
// StatusCode represents the encoded reason for closure of websocket connection.
|
||||
//
|
||||
// There are few helper methods on StatusCode that helps to define a range in
|
||||
// which given code is lay in. accordingly to ranges defined in specification.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc6455#section-7.4
|
||||
type StatusCode uint16
|
||||
|
||||
// StatusCodeRange describes range of StatusCode values.
|
||||
type StatusCodeRange struct {
|
||||
Min, Max StatusCode
|
||||
}
|
||||
|
||||
// Status code ranges defined by specification.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-7.4.2
|
||||
var (
|
||||
StatusRangeNotInUse = StatusCodeRange{0, 999}
|
||||
StatusRangeProtocol = StatusCodeRange{1000, 2999}
|
||||
StatusRangeApplication = StatusCodeRange{3000, 3999}
|
||||
StatusRangePrivate = StatusCodeRange{4000, 4999}
|
||||
)
|
||||
|
||||
// Status codes defined by specification.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-7.4.1
|
||||
const (
|
||||
StatusNormalClosure StatusCode = 1000
|
||||
StatusGoingAway StatusCode = 1001
|
||||
StatusProtocolError StatusCode = 1002
|
||||
StatusUnsupportedData StatusCode = 1003
|
||||
StatusNoMeaningYet StatusCode = 1004
|
||||
StatusInvalidFramePayloadData StatusCode = 1007
|
||||
StatusPolicyViolation StatusCode = 1008
|
||||
StatusMessageTooBig StatusCode = 1009
|
||||
StatusMandatoryExt StatusCode = 1010
|
||||
StatusInternalServerError StatusCode = 1011
|
||||
StatusTLSHandshake StatusCode = 1015
|
||||
|
||||
// StatusAbnormalClosure is a special code designated for use in
|
||||
// applications.
|
||||
StatusAbnormalClosure StatusCode = 1006
|
||||
|
||||
// StatusNoStatusRcvd is a special code designated for use in applications.
|
||||
StatusNoStatusRcvd StatusCode = 1005
|
||||
)
|
||||
|
||||
// In reports whether the code is defined in given range.
|
||||
func (s StatusCode) In(r StatusCodeRange) bool {
|
||||
return r.Min <= s && s <= r.Max
|
||||
}
|
||||
|
||||
// Empty reports whether the code is empty.
|
||||
// Empty code has no any meaning neither app level codes nor other.
|
||||
// This method is useful just to check that code is golang default value 0.
|
||||
func (s StatusCode) Empty() bool {
|
||||
return s == 0
|
||||
}
|
||||
|
||||
// IsNotUsed reports whether the code is predefined in not used range.
|
||||
func (s StatusCode) IsNotUsed() bool {
|
||||
return s.In(StatusRangeNotInUse)
|
||||
}
|
||||
|
||||
// IsApplicationSpec reports whether the code should be defined by
|
||||
// application, framework or libraries specification.
|
||||
func (s StatusCode) IsApplicationSpec() bool {
|
||||
return s.In(StatusRangeApplication)
|
||||
}
|
||||
|
||||
// IsPrivateSpec reports whether the code should be defined privately.
|
||||
func (s StatusCode) IsPrivateSpec() bool {
|
||||
return s.In(StatusRangePrivate)
|
||||
}
|
||||
|
||||
// IsProtocolSpec reports whether the code should be defined by protocol specification.
|
||||
func (s StatusCode) IsProtocolSpec() bool {
|
||||
return s.In(StatusRangeProtocol)
|
||||
}
|
||||
|
||||
// IsProtocolDefined reports whether the code is already defined by protocol specification.
|
||||
func (s StatusCode) IsProtocolDefined() bool {
|
||||
switch s {
|
||||
case StatusNormalClosure,
|
||||
StatusGoingAway,
|
||||
StatusProtocolError,
|
||||
StatusUnsupportedData,
|
||||
StatusInvalidFramePayloadData,
|
||||
StatusPolicyViolation,
|
||||
StatusMessageTooBig,
|
||||
StatusMandatoryExt,
|
||||
StatusInternalServerError,
|
||||
StatusNoStatusRcvd,
|
||||
StatusAbnormalClosure,
|
||||
StatusTLSHandshake:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsProtocolReserved reports whether the code is defined by protocol specification
|
||||
// to be reserved only for application usage purpose.
|
||||
func (s StatusCode) IsProtocolReserved() bool {
|
||||
switch s {
|
||||
// [RFC6455]: {1005,1006,1015} is a reserved value and MUST NOT be set as a status code in a
|
||||
// Close control frame by an endpoint.
|
||||
case StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Compiled control frames for common use cases.
|
||||
// For construct-serialize optimizations.
|
||||
var (
|
||||
CompiledPing = MustCompileFrame(NewPingFrame(nil))
|
||||
CompiledPong = MustCompileFrame(NewPongFrame(nil))
|
||||
CompiledClose = MustCompileFrame(NewCloseFrame(nil))
|
||||
|
||||
CompiledCloseNormalClosure = MustCompileFrame(closeFrameNormalClosure)
|
||||
CompiledCloseGoingAway = MustCompileFrame(closeFrameGoingAway)
|
||||
CompiledCloseProtocolError = MustCompileFrame(closeFrameProtocolError)
|
||||
CompiledCloseUnsupportedData = MustCompileFrame(closeFrameUnsupportedData)
|
||||
CompiledCloseNoMeaningYet = MustCompileFrame(closeFrameNoMeaningYet)
|
||||
CompiledCloseInvalidFramePayloadData = MustCompileFrame(closeFrameInvalidFramePayloadData)
|
||||
CompiledClosePolicyViolation = MustCompileFrame(closeFramePolicyViolation)
|
||||
CompiledCloseMessageTooBig = MustCompileFrame(closeFrameMessageTooBig)
|
||||
CompiledCloseMandatoryExt = MustCompileFrame(closeFrameMandatoryExt)
|
||||
CompiledCloseInternalServerError = MustCompileFrame(closeFrameInternalServerError)
|
||||
CompiledCloseTLSHandshake = MustCompileFrame(closeFrameTLSHandshake)
|
||||
)
|
||||
|
||||
// Header represents websocket frame header.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
type Header struct {
|
||||
Fin bool
|
||||
Rsv byte
|
||||
OpCode OpCode
|
||||
Masked bool
|
||||
Mask [4]byte
|
||||
Length int64
|
||||
}
|
||||
|
||||
// Rsv1 reports whether the header has first rsv bit set.
|
||||
func (h Header) Rsv1() bool { return h.Rsv&bit5 != 0 }
|
||||
|
||||
// Rsv2 reports whether the header has second rsv bit set.
|
||||
func (h Header) Rsv2() bool { return h.Rsv&bit6 != 0 }
|
||||
|
||||
// Rsv3 reports whether the header has third rsv bit set.
|
||||
func (h Header) Rsv3() bool { return h.Rsv&bit7 != 0 }
|
||||
|
||||
// Frame represents websocket frame.
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.2
|
||||
type Frame struct {
|
||||
Header Header
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
// NewFrame creates frame with given operation code,
|
||||
// flag of completeness and payload bytes.
|
||||
func NewFrame(op OpCode, fin bool, p []byte) Frame {
|
||||
return Frame{
|
||||
Header: Header{
|
||||
Fin: fin,
|
||||
OpCode: op,
|
||||
Length: int64(len(p)),
|
||||
},
|
||||
Payload: p,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTextFrame creates text frame with p as payload.
|
||||
// Note that p is not copied.
|
||||
func NewTextFrame(p []byte) Frame {
|
||||
return NewFrame(OpText, true, p)
|
||||
}
|
||||
|
||||
// NewBinaryFrame creates binary frame with p as payload.
|
||||
// Note that p is not copied.
|
||||
func NewBinaryFrame(p []byte) Frame {
|
||||
return NewFrame(OpBinary, true, p)
|
||||
}
|
||||
|
||||
// NewPingFrame creates ping frame with p as payload.
|
||||
// Note that p is not copied.
|
||||
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
|
||||
// to RFC.
|
||||
func NewPingFrame(p []byte) Frame {
|
||||
return NewFrame(OpPing, true, p)
|
||||
}
|
||||
|
||||
// NewPongFrame creates pong frame with p as payload.
|
||||
// Note that p is not copied.
|
||||
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
|
||||
// to RFC.
|
||||
func NewPongFrame(p []byte) Frame {
|
||||
return NewFrame(OpPong, true, p)
|
||||
}
|
||||
|
||||
// NewCloseFrame creates close frame with given close body.
|
||||
// Note that p is not copied.
|
||||
// Note that p must have length of MaxControlFramePayloadSize bytes or less due
|
||||
// to RFC.
|
||||
func NewCloseFrame(p []byte) Frame {
|
||||
return NewFrame(OpClose, true, p)
|
||||
}
|
||||
|
||||
// NewCloseFrameBody encodes a closure code and a reason into a binary
|
||||
// representation.
|
||||
//
|
||||
// It returns slice which is at most MaxControlFramePayloadSize bytes length.
|
||||
// If the reason is too big it will be cropped to fit the limit defined by the
|
||||
// spec.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.5
|
||||
func NewCloseFrameBody(code StatusCode, reason string) []byte {
|
||||
n := min(2+len(reason), MaxControlFramePayloadSize)
|
||||
p := make([]byte, n)
|
||||
|
||||
crop := min(MaxControlFramePayloadSize-2, len(reason))
|
||||
PutCloseFrameBody(p, code, reason[:crop])
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// PutCloseFrameBody encodes code and reason into buf.
|
||||
//
|
||||
// It will panic if the buffer is too small to accommodate a code or a reason.
|
||||
//
|
||||
// PutCloseFrameBody does not check buffer to be RFC compliant, but note that
|
||||
// by RFC it must be at most MaxControlFramePayloadSize.
|
||||
func PutCloseFrameBody(p []byte, code StatusCode, reason string) {
|
||||
_ = p[1+len(reason)]
|
||||
binary.BigEndian.PutUint16(p, uint16(code))
|
||||
copy(p[2:], reason)
|
||||
}
|
||||
|
||||
// MaskFrame masks frame and returns frame with masked payload and Mask header's field set.
|
||||
// Note that it copies f payload to prevent collisions.
|
||||
// For less allocations you could use MaskFrameInPlace or construct frame manually.
|
||||
func MaskFrame(f Frame) Frame {
|
||||
return MaskFrameWith(f, NewMask())
|
||||
}
|
||||
|
||||
// MaskFrameWith masks frame with given mask and returns frame
|
||||
// with masked payload and Mask header's field set.
|
||||
// Note that it copies f payload to prevent collisions.
|
||||
// For less allocations you could use MaskFrameInPlaceWith or construct frame manually.
|
||||
func MaskFrameWith(f Frame, mask [4]byte) Frame {
|
||||
// TODO(gobwas): check CopyCipher ws copy() Cipher().
|
||||
p := make([]byte, len(f.Payload))
|
||||
copy(p, f.Payload)
|
||||
f.Payload = p
|
||||
return MaskFrameInPlaceWith(f, mask)
|
||||
}
|
||||
|
||||
// MaskFrameInPlace masks frame and returns frame with masked payload and Mask
|
||||
// header's field set.
|
||||
// Note that it applies xor cipher to f.Payload without copying, that is, it
|
||||
// modifies f.Payload inplace.
|
||||
func MaskFrameInPlace(f Frame) Frame {
|
||||
return MaskFrameInPlaceWith(f, NewMask())
|
||||
}
|
||||
|
||||
// MaskFrameInPlaceWith masks frame with given mask and returns frame
|
||||
// with masked payload and Mask header's field set.
|
||||
// Note that it applies xor cipher to f.Payload without copying, that is, it
|
||||
// modifies f.Payload inplace.
|
||||
func MaskFrameInPlaceWith(f Frame, m [4]byte) Frame {
|
||||
f.Header.Masked = true
|
||||
f.Header.Mask = m
|
||||
Cipher(f.Payload, m, 0)
|
||||
return f
|
||||
}
|
||||
|
||||
// NewMask creates new random mask.
|
||||
func NewMask() (ret [4]byte) {
|
||||
binary.BigEndian.PutUint32(ret[:], rand.Uint32())
|
||||
return
|
||||
}
|
||||
|
||||
// CompileFrame returns byte representation of given frame.
|
||||
// In terms of memory consumption it is useful to precompile static frames
|
||||
// which are often used.
|
||||
func CompileFrame(f Frame) (bts []byte, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 16))
|
||||
err = WriteFrame(buf, f)
|
||||
bts = buf.Bytes()
|
||||
return
|
||||
}
|
||||
|
||||
// MustCompileFrame is like CompileFrame but panics if frame can not be
|
||||
// encoded.
|
||||
func MustCompileFrame(f Frame) []byte {
|
||||
bts, err := CompileFrame(f)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return bts
|
||||
}
|
||||
|
||||
// Rsv creates rsv byte representation.
|
||||
func Rsv(r1, r2, r3 bool) (rsv byte) {
|
||||
if r1 {
|
||||
rsv |= bit5
|
||||
}
|
||||
if r2 {
|
||||
rsv |= bit6
|
||||
}
|
||||
if r3 {
|
||||
rsv |= bit7
|
||||
}
|
||||
return rsv
|
||||
}
|
||||
|
||||
func makeCloseFrame(code StatusCode) Frame {
|
||||
return NewCloseFrame(NewCloseFrameBody(code, ""))
|
||||
}
|
||||
|
||||
var (
|
||||
closeFrameNormalClosure = makeCloseFrame(StatusNormalClosure)
|
||||
closeFrameGoingAway = makeCloseFrame(StatusGoingAway)
|
||||
closeFrameProtocolError = makeCloseFrame(StatusProtocolError)
|
||||
closeFrameUnsupportedData = makeCloseFrame(StatusUnsupportedData)
|
||||
closeFrameNoMeaningYet = makeCloseFrame(StatusNoMeaningYet)
|
||||
closeFrameInvalidFramePayloadData = makeCloseFrame(StatusInvalidFramePayloadData)
|
||||
closeFramePolicyViolation = makeCloseFrame(StatusPolicyViolation)
|
||||
closeFrameMessageTooBig = makeCloseFrame(StatusMessageTooBig)
|
||||
closeFrameMandatoryExt = makeCloseFrame(StatusMandatoryExt)
|
||||
closeFrameInternalServerError = makeCloseFrame(StatusInternalServerError)
|
||||
closeFrameTLSHandshake = makeCloseFrame(StatusTLSHandshake)
|
||||
)
|
|
@ -0,0 +1,468 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/gobwas/httphead"
|
||||
)
|
||||
|
||||
const (
|
||||
crlf = "\r\n"
|
||||
colonAndSpace = ": "
|
||||
commaAndSpace = ", "
|
||||
)
|
||||
|
||||
const (
|
||||
textHeadUpgrade = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
|
||||
)
|
||||
|
||||
var (
|
||||
textHeadBadRequest = statusText(http.StatusBadRequest)
|
||||
textHeadInternalServerError = statusText(http.StatusInternalServerError)
|
||||
textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
|
||||
|
||||
textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
|
||||
textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
|
||||
textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
|
||||
textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
|
||||
textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
|
||||
textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
|
||||
textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
|
||||
textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
|
||||
textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
|
||||
)
|
||||
|
||||
var (
|
||||
headerHost = "Host"
|
||||
headerUpgrade = "Upgrade"
|
||||
headerConnection = "Connection"
|
||||
headerSecVersion = "Sec-WebSocket-Version"
|
||||
headerSecProtocol = "Sec-WebSocket-Protocol"
|
||||
headerSecExtensions = "Sec-WebSocket-Extensions"
|
||||
headerSecKey = "Sec-WebSocket-Key"
|
||||
headerSecAccept = "Sec-WebSocket-Accept"
|
||||
|
||||
headerHostCanonical = textproto.CanonicalMIMEHeaderKey(headerHost)
|
||||
headerUpgradeCanonical = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
|
||||
headerConnectionCanonical = textproto.CanonicalMIMEHeaderKey(headerConnection)
|
||||
headerSecVersionCanonical = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
|
||||
headerSecProtocolCanonical = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
|
||||
headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
|
||||
headerSecKeyCanonical = textproto.CanonicalMIMEHeaderKey(headerSecKey)
|
||||
headerSecAcceptCanonical = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
|
||||
)
|
||||
|
||||
var (
|
||||
specHeaderValueUpgrade = []byte("websocket")
|
||||
specHeaderValueConnection = []byte("Upgrade")
|
||||
specHeaderValueConnectionLower = []byte("upgrade")
|
||||
specHeaderValueSecVersion = []byte("13")
|
||||
)
|
||||
|
||||
var (
|
||||
httpVersion1_0 = []byte("HTTP/1.0")
|
||||
httpVersion1_1 = []byte("HTTP/1.1")
|
||||
httpVersionPrefix = []byte("HTTP/")
|
||||
)
|
||||
|
||||
type httpRequestLine struct {
|
||||
method, uri []byte
|
||||
major, minor int
|
||||
}
|
||||
|
||||
type httpResponseLine struct {
|
||||
major, minor int
|
||||
status int
|
||||
reason []byte
|
||||
}
|
||||
|
||||
// httpParseRequestLine parses http request line like "GET / HTTP/1.0".
|
||||
func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
|
||||
var proto []byte
|
||||
req.method, req.uri, proto = bsplit3(line, ' ')
|
||||
|
||||
var ok bool
|
||||
req.major, req.minor, ok = httpParseVersion(proto)
|
||||
if !ok {
|
||||
err = ErrMalformedRequest
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
|
||||
var (
|
||||
proto []byte
|
||||
status []byte
|
||||
)
|
||||
proto, status, resp.reason = bsplit3(line, ' ')
|
||||
|
||||
var ok bool
|
||||
resp.major, resp.minor, ok = httpParseVersion(proto)
|
||||
if !ok {
|
||||
return resp, ErrMalformedResponse
|
||||
}
|
||||
|
||||
var convErr error
|
||||
resp.status, convErr = asciiToInt(status)
|
||||
if convErr != nil {
|
||||
return resp, ErrMalformedResponse
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// httpParseVersion parses major and minor version of HTTP protocol. It returns
|
||||
// parsed values and true if parse is ok.
|
||||
func httpParseVersion(bts []byte) (major, minor int, ok bool) {
|
||||
switch {
|
||||
case bytes.Equal(bts, httpVersion1_0):
|
||||
return 1, 0, true
|
||||
case bytes.Equal(bts, httpVersion1_1):
|
||||
return 1, 1, true
|
||||
case len(bts) < 8:
|
||||
return
|
||||
case !bytes.Equal(bts[:5], httpVersionPrefix):
|
||||
return
|
||||
}
|
||||
|
||||
bts = bts[5:]
|
||||
|
||||
dot := bytes.IndexByte(bts, '.')
|
||||
if dot == -1 {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
major, err = asciiToInt(bts[:dot])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
minor, err = asciiToInt(bts[dot+1:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return major, minor, true
|
||||
}
|
||||
|
||||
// httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
|
||||
// values and true if parse is ok.
|
||||
func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
|
||||
colon := bytes.IndexByte(line, ':')
|
||||
if colon == -1 {
|
||||
return
|
||||
}
|
||||
|
||||
k = btrim(line[:colon])
|
||||
// TODO(gobwas): maybe use just lower here?
|
||||
canonicalizeHeaderKey(k)
|
||||
|
||||
v = btrim(line[colon+1:])
|
||||
|
||||
return k, v, true
|
||||
}
|
||||
|
||||
// httpGetHeader is the same as textproto.MIMEHeader.Get, except the thing,
|
||||
// that key is already canonical. This helps to increase performance.
|
||||
func httpGetHeader(h http.Header, key string) string {
|
||||
if h == nil {
|
||||
return ""
|
||||
}
|
||||
v := h[key]
|
||||
if len(v) == 0 {
|
||||
return ""
|
||||
}
|
||||
return v[0]
|
||||
}
|
||||
|
||||
// The request MAY include a header field with the name
|
||||
// |Sec-WebSocket-Protocol|. If present, this value indicates one or more
|
||||
// comma-separated subprotocol the client wishes to speak, ordered by
|
||||
// preference. The elements that comprise this value MUST be non-empty strings
|
||||
// with characters in the range U+0021 to U+007E not including separator
|
||||
// characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
|
||||
// for the value of this header field is 1#token, where the definitions of
|
||||
// constructs and rules are as given in [RFC2616].
|
||||
func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
|
||||
ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
|
||||
if check(btsToString(v)) {
|
||||
ret = string(v)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
|
||||
var selected []byte
|
||||
ok = httphead.ScanTokens(h, func(v []byte) bool {
|
||||
if check(v) {
|
||||
selected = v
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if ok && selected != nil {
|
||||
return string(selected), true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func strSelectExtensions(h string, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
|
||||
return btsSelectExtensions(strToBytes(h), selected, check)
|
||||
}
|
||||
|
||||
func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
|
||||
s := httphead.OptionSelector{
|
||||
Flags: httphead.SelectUnique | httphead.SelectCopy,
|
||||
Check: check,
|
||||
}
|
||||
return s.Select(h, selected)
|
||||
}
|
||||
|
||||
func httpWriteHeader(bw *bufio.Writer, key, value string) {
|
||||
httpWriteHeaderKey(bw, key)
|
||||
bw.WriteString(value)
|
||||
bw.WriteString(crlf)
|
||||
}
|
||||
|
||||
func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
|
||||
httpWriteHeaderKey(bw, key)
|
||||
bw.Write(value)
|
||||
bw.WriteString(crlf)
|
||||
}
|
||||
|
||||
func httpWriteHeaderKey(bw *bufio.Writer, key string) {
|
||||
bw.WriteString(key)
|
||||
bw.WriteString(colonAndSpace)
|
||||
}
|
||||
|
||||
func httpWriteUpgradeRequest(
|
||||
bw *bufio.Writer,
|
||||
u *url.URL,
|
||||
nonce []byte,
|
||||
protocols []string,
|
||||
extensions []httphead.Option,
|
||||
header HandshakeHeader,
|
||||
) {
|
||||
bw.WriteString("GET ")
|
||||
bw.WriteString(u.RequestURI())
|
||||
bw.WriteString(" HTTP/1.1\r\n")
|
||||
|
||||
httpWriteHeader(bw, headerHost, u.Host)
|
||||
|
||||
httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
|
||||
httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
|
||||
httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
|
||||
|
||||
// NOTE: write nonce bytes as a string to prevent heap allocation –
|
||||
// WriteString() copy given string into its inner buffer, unlike Write()
|
||||
// which may write p directly to the underlying io.Writer – which in turn
|
||||
// will lead to p escape.
|
||||
httpWriteHeader(bw, headerSecKey, btsToString(nonce))
|
||||
|
||||
if len(protocols) > 0 {
|
||||
httpWriteHeaderKey(bw, headerSecProtocol)
|
||||
for i, p := range protocols {
|
||||
if i > 0 {
|
||||
bw.WriteString(commaAndSpace)
|
||||
}
|
||||
bw.WriteString(p)
|
||||
}
|
||||
bw.WriteString(crlf)
|
||||
}
|
||||
|
||||
if len(extensions) > 0 {
|
||||
httpWriteHeaderKey(bw, headerSecExtensions)
|
||||
httphead.WriteOptions(bw, extensions)
|
||||
bw.WriteString(crlf)
|
||||
}
|
||||
|
||||
if header != nil {
|
||||
header.WriteTo(bw)
|
||||
}
|
||||
|
||||
bw.WriteString(crlf)
|
||||
}
|
||||
|
||||
func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
|
||||
bw.WriteString(textHeadUpgrade)
|
||||
|
||||
httpWriteHeaderKey(bw, headerSecAccept)
|
||||
writeAccept(bw, nonce)
|
||||
bw.WriteString(crlf)
|
||||
|
||||
if hs.Protocol != "" {
|
||||
httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
|
||||
}
|
||||
if len(hs.Extensions) > 0 {
|
||||
httpWriteHeaderKey(bw, headerSecExtensions)
|
||||
httphead.WriteOptions(bw, hs.Extensions)
|
||||
bw.WriteString(crlf)
|
||||
}
|
||||
if header != nil {
|
||||
header(bw)
|
||||
}
|
||||
|
||||
bw.WriteString(crlf)
|
||||
}
|
||||
|
||||
func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
bw.WriteString(textHeadBadRequest)
|
||||
case http.StatusInternalServerError:
|
||||
bw.WriteString(textHeadInternalServerError)
|
||||
case http.StatusUpgradeRequired:
|
||||
bw.WriteString(textHeadUpgradeRequired)
|
||||
default:
|
||||
writeStatusText(bw, code)
|
||||
}
|
||||
|
||||
// Write custom headers.
|
||||
if header != nil {
|
||||
header(bw)
|
||||
}
|
||||
|
||||
switch err {
|
||||
case ErrHandshakeBadProtocol:
|
||||
bw.WriteString(textTailErrHandshakeBadProtocol)
|
||||
case ErrHandshakeBadMethod:
|
||||
bw.WriteString(textTailErrHandshakeBadMethod)
|
||||
case ErrHandshakeBadHost:
|
||||
bw.WriteString(textTailErrHandshakeBadHost)
|
||||
case ErrHandshakeBadUpgrade:
|
||||
bw.WriteString(textTailErrHandshakeBadUpgrade)
|
||||
case ErrHandshakeBadConnection:
|
||||
bw.WriteString(textTailErrHandshakeBadConnection)
|
||||
case ErrHandshakeBadSecAccept:
|
||||
bw.WriteString(textTailErrHandshakeBadSecAccept)
|
||||
case ErrHandshakeBadSecKey:
|
||||
bw.WriteString(textTailErrHandshakeBadSecKey)
|
||||
case ErrHandshakeBadSecVersion:
|
||||
bw.WriteString(textTailErrHandshakeBadSecVersion)
|
||||
case ErrHandshakeUpgradeRequired:
|
||||
bw.WriteString(textTailErrUpgradeRequired)
|
||||
case nil:
|
||||
bw.WriteString(crlf)
|
||||
default:
|
||||
writeErrorText(bw, err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeStatusText(bw *bufio.Writer, code int) {
|
||||
bw.WriteString("HTTP/1.1 ")
|
||||
bw.WriteString(strconv.Itoa(code))
|
||||
bw.WriteByte(' ')
|
||||
bw.WriteString(http.StatusText(code))
|
||||
bw.WriteString(crlf)
|
||||
bw.WriteString("Content-Type: text/plain; charset=utf-8")
|
||||
bw.WriteString(crlf)
|
||||
}
|
||||
|
||||
func writeErrorText(bw *bufio.Writer, err error) {
|
||||
body := err.Error()
|
||||
bw.WriteString("Content-Length: ")
|
||||
bw.WriteString(strconv.Itoa(len(body)))
|
||||
bw.WriteString(crlf)
|
||||
bw.WriteString(crlf)
|
||||
bw.WriteString(body)
|
||||
}
|
||||
|
||||
// httpError is like the http.Error with WebSocket context exception.
|
||||
func httpError(w http.ResponseWriter, body string, code int) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||
w.WriteHeader(code)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
|
||||
// statusText is a non-performant status text generator.
|
||||
// NOTE: Used only to generate constants.
|
||||
func statusText(code int) string {
|
||||
var buf bytes.Buffer
|
||||
bw := bufio.NewWriter(&buf)
|
||||
writeStatusText(bw, code)
|
||||
bw.Flush()
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// errorText is a non-performant error text generator.
|
||||
// NOTE: Used only to generate constants.
|
||||
func errorText(err error) string {
|
||||
var buf bytes.Buffer
|
||||
bw := bufio.NewWriter(&buf)
|
||||
writeErrorText(bw, err)
|
||||
bw.Flush()
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// HandshakeHeader is the interface that writes both upgrade request or
|
||||
// response headers into a given io.Writer.
|
||||
type HandshakeHeader interface {
|
||||
io.WriterTo
|
||||
}
|
||||
|
||||
// HandshakeHeaderString is an adapter to allow the use of headers represented
|
||||
// by ordinary string as HandshakeHeader.
|
||||
type HandshakeHeaderString string
|
||||
|
||||
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||||
func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
|
||||
n, err := io.WriteString(w, string(s))
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// HandshakeHeaderBytes is an adapter to allow the use of headers represented
|
||||
// by ordinary slice of bytes as HandshakeHeader.
|
||||
type HandshakeHeaderBytes []byte
|
||||
|
||||
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||||
func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
|
||||
n, err := w.Write(b)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// HandshakeHeaderFunc is an adapter to allow the use of headers represented by
|
||||
// ordinary function as HandshakeHeader.
|
||||
type HandshakeHeaderFunc func(io.Writer) (int64, error)
|
||||
|
||||
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||||
func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
|
||||
return f(w)
|
||||
}
|
||||
|
||||
// HandshakeHeaderHTTP is an adapter to allow the use of http.Header as
|
||||
// HandshakeHeader.
|
||||
type HandshakeHeaderHTTP http.Header
|
||||
|
||||
// WriteTo implements HandshakeHeader (and io.WriterTo) interface.
|
||||
func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
|
||||
wr := writer{w: w}
|
||||
err := http.Header(h).Write(&wr)
|
||||
return wr.n, err
|
||||
}
|
||||
|
||||
type writer struct {
|
||||
n int64
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (w *writer) WriteString(s string) (int, error) {
|
||||
n, err := io.WriteString(w.w, s)
|
||||
w.n += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *writer) Write(p []byte) (int, error) {
|
||||
n, err := w.w.Write(p)
|
||||
w.n += int64(n)
|
||||
return n, err
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
)
|
||||
|
||||
const (
|
||||
// RFC6455: The value of this header field MUST be a nonce consisting of a
|
||||
// randomly selected 16-byte value that has been base64-encoded (see
|
||||
// Section 4 of [RFC4648]). The nonce MUST be selected randomly for each
|
||||
// connection.
|
||||
nonceKeySize = 16
|
||||
nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
|
||||
|
||||
// RFC6455: The value of this header field is constructed by concatenating
|
||||
// /key/, defined above in step 4 in Section 4.2.2, with the string
|
||||
// "258EAFA5- E914-47DA-95CA-C5AB0DC85B11", taking the SHA-1 hash of this
|
||||
// concatenated value to obtain a 20-byte value and base64- encoding (see
|
||||
// Section 4 of [RFC4648]) this 20-byte hash.
|
||||
acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
|
||||
)
|
||||
|
||||
// initNonce fills given slice with random base64-encoded nonce bytes.
|
||||
func initNonce(dst []byte) {
|
||||
// NOTE: bts does not escape.
|
||||
bts := make([]byte, nonceKeySize)
|
||||
if _, err := rand.Read(bts); err != nil {
|
||||
panic(fmt.Sprintf("rand read error: %s", err))
|
||||
}
|
||||
base64.StdEncoding.Encode(dst, bts)
|
||||
}
|
||||
|
||||
// checkAcceptFromNonce reports whether given accept bytes are valid for given
|
||||
// nonce bytes.
|
||||
func checkAcceptFromNonce(accept, nonce []byte) bool {
|
||||
if len(accept) != acceptSize {
|
||||
return false
|
||||
}
|
||||
// NOTE: expect does not escape.
|
||||
expect := make([]byte, acceptSize)
|
||||
initAcceptFromNonce(expect, nonce)
|
||||
return bytes.Equal(expect, accept)
|
||||
}
|
||||
|
||||
// initAcceptFromNonce fills given slice with accept bytes generated from given
|
||||
// nonce bytes. Given buffer should be exactly acceptSize bytes.
|
||||
func initAcceptFromNonce(accept, nonce []byte) {
|
||||
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
if len(accept) != acceptSize {
|
||||
panic("accept buffer is invalid")
|
||||
}
|
||||
if len(nonce) != nonceSize {
|
||||
panic("nonce is invalid")
|
||||
}
|
||||
|
||||
p := make([]byte, nonceSize+len(magic))
|
||||
copy(p[:nonceSize], nonce)
|
||||
copy(p[nonceSize:], magic)
|
||||
|
||||
sum := sha1.Sum(p)
|
||||
base64.StdEncoding.Encode(accept, sum[:])
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func writeAccept(bw *bufio.Writer, nonce []byte) (int, error) {
|
||||
accept := make([]byte, acceptSize)
|
||||
initAcceptFromNonce(accept, nonce)
|
||||
// NOTE: write accept bytes as a string to prevent heap allocation –
|
||||
// WriteString() copy given string into its inner buffer, unlike Write()
|
||||
// which may write p directly to the underlying io.Writer – which in turn
|
||||
// will lead to p escape.
|
||||
return bw.WriteString(btsToString(accept))
|
||||
}
|
|
@ -0,0 +1,147 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Errors used by frame reader.
|
||||
var (
|
||||
ErrHeaderLengthMSB = fmt.Errorf("header error: the most significant bit must be 0")
|
||||
ErrHeaderLengthUnexpected = fmt.Errorf("header error: unexpected payload length bits")
|
||||
)
|
||||
|
||||
// ReadHeader reads a frame header from r.
|
||||
func ReadHeader(r io.Reader) (h Header, err error) {
|
||||
// Make slice of bytes with capacity 12 that could hold any header.
|
||||
//
|
||||
// The maximum header size is 14, but due to the 2 hop reads,
|
||||
// after first hop that reads first 2 constant bytes, we could reuse 2 bytes.
|
||||
// So 14 - 2 = 12.
|
||||
bts := make([]byte, 2, MaxHeaderSize-2)
|
||||
|
||||
// Prepare to hold first 2 bytes to choose size of next read.
|
||||
_, err = io.ReadFull(r, bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.Fin = bts[0]&bit0 != 0
|
||||
h.Rsv = (bts[0] & 0x70) >> 4
|
||||
h.OpCode = OpCode(bts[0] & 0x0f)
|
||||
|
||||
var extra int
|
||||
|
||||
if bts[1]&bit0 != 0 {
|
||||
h.Masked = true
|
||||
extra += 4
|
||||
}
|
||||
|
||||
length := bts[1] & 0x7f
|
||||
switch {
|
||||
case length < 126:
|
||||
h.Length = int64(length)
|
||||
|
||||
case length == 126:
|
||||
extra += 2
|
||||
|
||||
case length == 127:
|
||||
extra += 8
|
||||
|
||||
default:
|
||||
err = ErrHeaderLengthUnexpected
|
||||
return
|
||||
}
|
||||
|
||||
if extra == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Increase len of bts to extra bytes need to read.
|
||||
// Overwrite first 2 bytes that was read before.
|
||||
bts = bts[:extra]
|
||||
_, err = io.ReadFull(r, bts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case length == 126:
|
||||
h.Length = int64(binary.BigEndian.Uint16(bts[:2]))
|
||||
bts = bts[2:]
|
||||
|
||||
case length == 127:
|
||||
if bts[0]&0x80 != 0 {
|
||||
err = ErrHeaderLengthMSB
|
||||
return
|
||||
}
|
||||
h.Length = int64(binary.BigEndian.Uint64(bts[:8]))
|
||||
bts = bts[8:]
|
||||
}
|
||||
|
||||
if h.Masked {
|
||||
copy(h.Mask[:], bts)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ReadFrame reads a frame from r.
|
||||
// It is not designed for high optimized use case cause it makes allocation
|
||||
// for frame.Header.Length size inside to read frame payload into.
|
||||
//
|
||||
// Note that ReadFrame does not unmask payload.
|
||||
func ReadFrame(r io.Reader) (f Frame, err error) {
|
||||
f.Header, err = ReadHeader(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if f.Header.Length > 0 {
|
||||
// int(f.Header.Length) is safe here cause we have
|
||||
// checked it for overflow above in ReadHeader.
|
||||
f.Payload = make([]byte, int(f.Header.Length))
|
||||
_, err = io.ReadFull(r, f.Payload)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MustReadFrame is like ReadFrame but panics if frame can not be read.
|
||||
func MustReadFrame(r io.Reader) Frame {
|
||||
f, err := ReadFrame(r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// ParseCloseFrameData parses close frame status code and closure reason if any provided.
|
||||
// If there is no status code in the payload
|
||||
// the empty status code is returned (code.Empty()) with empty string as a reason.
|
||||
func ParseCloseFrameData(payload []byte) (code StatusCode, reason string) {
|
||||
if len(payload) < 2 {
|
||||
// We returning empty StatusCode here, preventing the situation
|
||||
// when endpoint really sent code 1005 and we should return ProtocolError on that.
|
||||
//
|
||||
// In other words, we ignoring this rule [RFC6455:7.1.5]:
|
||||
// If this Close control frame contains no status code, _The WebSocket
|
||||
// Connection Close Code_ is considered to be 1005.
|
||||
return
|
||||
}
|
||||
code = StatusCode(binary.BigEndian.Uint16(payload))
|
||||
reason = string(payload[2:])
|
||||
return
|
||||
}
|
||||
|
||||
// ParseCloseFrameDataUnsafe is like ParseCloseFrameData except the thing
|
||||
// that it does not copies payload bytes into reason, but prepares unsafe cast.
|
||||
func ParseCloseFrameDataUnsafe(payload []byte) (code StatusCode, reason string) {
|
||||
if len(payload) < 2 {
|
||||
return
|
||||
}
|
||||
code = StatusCode(binary.BigEndian.Uint16(payload))
|
||||
reason = btsToString(payload[2:])
|
||||
return
|
||||
}
|
|
@ -0,0 +1,607 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/httphead"
|
||||
"github.com/gobwas/pool/pbufio"
|
||||
)
|
||||
|
||||
// Constants used by ConnUpgrader.
|
||||
const (
|
||||
DefaultServerReadBufferSize = 4096
|
||||
DefaultServerWriteBufferSize = 512
|
||||
)
|
||||
|
||||
// Errors used by both client and server when preparing WebSocket handshake.
|
||||
var (
|
||||
ErrHandshakeBadProtocol = RejectConnectionError(
|
||||
RejectionStatus(http.StatusHTTPVersionNotSupported),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad HTTP protocol version")),
|
||||
)
|
||||
ErrHandshakeBadMethod = RejectConnectionError(
|
||||
RejectionStatus(http.StatusMethodNotAllowed),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad HTTP request method")),
|
||||
)
|
||||
ErrHandshakeBadHost = RejectConnectionError(
|
||||
RejectionStatus(http.StatusBadRequest),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerHost)),
|
||||
)
|
||||
ErrHandshakeBadUpgrade = RejectConnectionError(
|
||||
RejectionStatus(http.StatusBadRequest),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerUpgrade)),
|
||||
)
|
||||
ErrHandshakeBadConnection = RejectConnectionError(
|
||||
RejectionStatus(http.StatusBadRequest),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerConnection)),
|
||||
)
|
||||
ErrHandshakeBadSecAccept = RejectConnectionError(
|
||||
RejectionStatus(http.StatusBadRequest),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecAccept)),
|
||||
)
|
||||
ErrHandshakeBadSecKey = RejectConnectionError(
|
||||
RejectionStatus(http.StatusBadRequest),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecKey)),
|
||||
)
|
||||
ErrHandshakeBadSecVersion = RejectConnectionError(
|
||||
RejectionStatus(http.StatusBadRequest),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
|
||||
)
|
||||
)
|
||||
|
||||
// ErrMalformedResponse is returned by Dialer to indicate that server response
|
||||
// can not be parsed.
|
||||
var ErrMalformedResponse = fmt.Errorf("malformed HTTP response")
|
||||
|
||||
// ErrMalformedRequest is returned when HTTP request can not be parsed.
|
||||
var ErrMalformedRequest = RejectConnectionError(
|
||||
RejectionStatus(http.StatusBadRequest),
|
||||
RejectionReason("malformed HTTP request"),
|
||||
)
|
||||
|
||||
// ErrHandshakeUpgradeRequired is returned by Upgrader to indicate that
|
||||
// connection is rejected because given WebSocket version is malformed.
|
||||
//
|
||||
// According to RFC6455:
|
||||
// If this version does not match a version understood by the server, the
|
||||
// server MUST abort the WebSocket handshake described in this section and
|
||||
// instead send an appropriate HTTP error code (such as 426 Upgrade Required)
|
||||
// and a |Sec-WebSocket-Version| header field indicating the version(s) the
|
||||
// server is capable of understanding.
|
||||
var ErrHandshakeUpgradeRequired = RejectConnectionError(
|
||||
RejectionStatus(http.StatusUpgradeRequired),
|
||||
RejectionHeader(HandshakeHeaderString(headerSecVersion+": 13\r\n")),
|
||||
RejectionReason(fmt.Sprintf("handshake error: bad %q header", headerSecVersion)),
|
||||
)
|
||||
|
||||
// ErrNotHijacker is an error returned when http.ResponseWriter does not
|
||||
// implement http.Hijacker interface.
|
||||
var ErrNotHijacker = RejectConnectionError(
|
||||
RejectionStatus(http.StatusInternalServerError),
|
||||
RejectionReason("given http.ResponseWriter is not a http.Hijacker"),
|
||||
)
|
||||
|
||||
// DefaultHTTPUpgrader is an HTTPUpgrader that holds no options and is used by
|
||||
// UpgradeHTTP function.
|
||||
var DefaultHTTPUpgrader HTTPUpgrader
|
||||
|
||||
// UpgradeHTTP is like HTTPUpgrader{}.Upgrade().
|
||||
func UpgradeHTTP(r *http.Request, w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, Handshake, error) {
|
||||
return DefaultHTTPUpgrader.Upgrade(r, w)
|
||||
}
|
||||
|
||||
// DefaultUpgrader is an Upgrader that holds no options and is used by Upgrade
|
||||
// function.
|
||||
var DefaultUpgrader Upgrader
|
||||
|
||||
// Upgrade is like Upgrader{}.Upgrade().
|
||||
func Upgrade(conn io.ReadWriter) (Handshake, error) {
|
||||
return DefaultUpgrader.Upgrade(conn)
|
||||
}
|
||||
|
||||
// HTTPUpgrader contains options for upgrading connection to websocket from
|
||||
// net/http Handler arguments.
|
||||
type HTTPUpgrader struct {
|
||||
// Timeout is the maximum amount of time an Upgrade() will spent while
|
||||
// writing handshake response.
|
||||
//
|
||||
// The default is no timeout.
|
||||
Timeout time.Duration
|
||||
|
||||
// Header is an optional http.Header mapping that could be used to
|
||||
// write additional headers to the handshake response.
|
||||
//
|
||||
// Note that if present, it will be written in any result of handshake.
|
||||
Header http.Header
|
||||
|
||||
// Protocol is the select function that is used to select subprotocol from
|
||||
// list requested by client. If this field is set, then the first matched
|
||||
// protocol is sent to a client as negotiated.
|
||||
Protocol func(string) bool
|
||||
|
||||
// Extension is the select function that is used to select extensions from
|
||||
// list requested by client. If this field is set, then the all matched
|
||||
// extensions are sent to a client as negotiated.
|
||||
Extension func(httphead.Option) bool
|
||||
}
|
||||
|
||||
// Upgrade upgrades http connection to the websocket connection.
|
||||
//
|
||||
// It hijacks net.Conn from w and returns received net.Conn and
|
||||
// bufio.ReadWriter. On successful handshake it returns Handshake struct
|
||||
// describing handshake info.
|
||||
func (u HTTPUpgrader) Upgrade(r *http.Request, w http.ResponseWriter) (conn net.Conn, rw *bufio.ReadWriter, hs Handshake, err error) {
|
||||
// Hijack connection first to get the ability to write rejection errors the
|
||||
// same way as in Upgrader.
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if ok {
|
||||
conn, rw, err = hj.Hijack()
|
||||
} else {
|
||||
err = ErrNotHijacker
|
||||
}
|
||||
if err != nil {
|
||||
httpError(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// See https://tools.ietf.org/html/rfc6455#section-4.1
|
||||
// The method of the request MUST be GET, and the HTTP version MUST be at least 1.1.
|
||||
var nonce string
|
||||
if r.Method != http.MethodGet {
|
||||
err = ErrHandshakeBadMethod
|
||||
} else if r.ProtoMajor < 1 || (r.ProtoMajor == 1 && r.ProtoMinor < 1) {
|
||||
err = ErrHandshakeBadProtocol
|
||||
} else if r.Host == "" {
|
||||
err = ErrHandshakeBadHost
|
||||
} else if u := httpGetHeader(r.Header, headerUpgradeCanonical); u != "websocket" && !strings.EqualFold(u, "websocket") {
|
||||
err = ErrHandshakeBadUpgrade
|
||||
} else if c := httpGetHeader(r.Header, headerConnectionCanonical); c != "Upgrade" && !strHasToken(c, "upgrade") {
|
||||
err = ErrHandshakeBadConnection
|
||||
} else if nonce = httpGetHeader(r.Header, headerSecKeyCanonical); len(nonce) != nonceSize {
|
||||
err = ErrHandshakeBadSecKey
|
||||
} else if v := httpGetHeader(r.Header, headerSecVersionCanonical); v != "13" {
|
||||
// According to RFC6455:
|
||||
//
|
||||
// If this version does not match a version understood by the server,
|
||||
// the server MUST abort the WebSocket handshake described in this
|
||||
// section and instead send an appropriate HTTP error code (such as 426
|
||||
// Upgrade Required) and a |Sec-WebSocket-Version| header field
|
||||
// indicating the version(s) the server is capable of understanding.
|
||||
//
|
||||
// So we branching here cause empty or not present version does not
|
||||
// meet the ABNF rules of RFC6455:
|
||||
//
|
||||
// version = DIGIT | (NZDIGIT DIGIT) |
|
||||
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
|
||||
// ; Limited to 0-255 range, with no leading zeros
|
||||
//
|
||||
// That is, if version is really invalid – we sent 426 status, if it
|
||||
// not present or empty – it is 400.
|
||||
if v != "" {
|
||||
err = ErrHandshakeUpgradeRequired
|
||||
} else {
|
||||
err = ErrHandshakeBadSecVersion
|
||||
}
|
||||
}
|
||||
if check := u.Protocol; err == nil && check != nil {
|
||||
ps := r.Header[headerSecProtocolCanonical]
|
||||
for i := 0; i < len(ps) && err == nil && hs.Protocol == ""; i++ {
|
||||
var ok bool
|
||||
hs.Protocol, ok = strSelectProtocol(ps[i], check)
|
||||
if !ok {
|
||||
err = ErrMalformedRequest
|
||||
}
|
||||
}
|
||||
}
|
||||
if check := u.Extension; err == nil && check != nil {
|
||||
xs := r.Header[headerSecExtensionsCanonical]
|
||||
for i := 0; i < len(xs) && err == nil; i++ {
|
||||
var ok bool
|
||||
hs.Extensions, ok = strSelectExtensions(xs[i], hs.Extensions, check)
|
||||
if !ok {
|
||||
err = ErrMalformedRequest
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear deadlines set by server.
|
||||
conn.SetDeadline(noDeadline)
|
||||
if t := u.Timeout; t != 0 {
|
||||
conn.SetWriteDeadline(time.Now().Add(t))
|
||||
defer conn.SetWriteDeadline(noDeadline)
|
||||
}
|
||||
|
||||
var header handshakeHeader
|
||||
if h := u.Header; h != nil {
|
||||
header[0] = HandshakeHeaderHTTP(h)
|
||||
}
|
||||
if err == nil {
|
||||
httpWriteResponseUpgrade(rw.Writer, strToBytes(nonce), hs, header.WriteTo)
|
||||
err = rw.Writer.Flush()
|
||||
} else {
|
||||
var code int
|
||||
if rej, ok := err.(*rejectConnectionError); ok {
|
||||
code = rej.code
|
||||
header[1] = rej.header
|
||||
}
|
||||
if code == 0 {
|
||||
code = http.StatusInternalServerError
|
||||
}
|
||||
httpWriteResponseError(rw.Writer, err, code, header.WriteTo)
|
||||
// Do not store Flush() error to not override already existing one.
|
||||
rw.Writer.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Upgrader contains options for upgrading connection to websocket.
|
||||
type Upgrader struct {
|
||||
// ReadBufferSize and WriteBufferSize is an I/O buffer sizes.
|
||||
// They used to read and write http data while upgrading to WebSocket.
|
||||
// Allocated buffers are pooled with sync.Pool to avoid extra allocations.
|
||||
//
|
||||
// If a size is zero then default value is used.
|
||||
//
|
||||
// Usually it is useful to set read buffer size bigger than write buffer
|
||||
// size because incoming request could contain long header values, such as
|
||||
// Cookie. Response, in other way, could be big only if user write multiple
|
||||
// custom headers. Usually response takes less than 256 bytes.
|
||||
ReadBufferSize, WriteBufferSize int
|
||||
|
||||
// Protocol is a select function that is used to select subprotocol
|
||||
// from list requested by client. If this field is set, then the first matched
|
||||
// protocol is sent to a client as negotiated.
|
||||
//
|
||||
// The argument is only valid until the callback returns.
|
||||
Protocol func([]byte) bool
|
||||
|
||||
// ProtocolCustrom allow user to parse Sec-WebSocket-Protocol header manually.
|
||||
// Note that returned bytes must be valid until Upgrade returns.
|
||||
// If ProtocolCustom is set, it used instead of Protocol function.
|
||||
ProtocolCustom func([]byte) (string, bool)
|
||||
|
||||
// Extension is a select function that is used to select extensions
|
||||
// from list requested by client. If this field is set, then the all matched
|
||||
// extensions are sent to a client as negotiated.
|
||||
//
|
||||
// The argument is only valid until the callback returns.
|
||||
//
|
||||
// According to the RFC6455 order of extensions passed by a client is
|
||||
// significant. That is, returning true from this function means that no
|
||||
// other extension with the same name should be checked because server
|
||||
// accepted the most preferable extension right now:
|
||||
// "Note that the order of extensions is significant. Any interactions between
|
||||
// multiple extensions MAY be defined in the documents defining the extensions.
|
||||
// In the absence of such definitions, the interpretation is that the header
|
||||
// fields listed by the client in its request represent a preference of the
|
||||
// header fields it wishes to use, with the first options listed being most
|
||||
// preferable."
|
||||
Extension func(httphead.Option) bool
|
||||
|
||||
// ExtensionCustom allow user to parse Sec-WebSocket-Extensions header manually.
|
||||
// Note that returned options should be valid until Upgrade returns.
|
||||
// If ExtensionCustom is set, it used instead of Extension function.
|
||||
ExtensionCustom func([]byte, []httphead.Option) ([]httphead.Option, bool)
|
||||
|
||||
// Header is an optional HandshakeHeader instance that could be used to
|
||||
// write additional headers to the handshake response.
|
||||
//
|
||||
// It used instead of any key-value mappings to avoid allocations in user
|
||||
// land.
|
||||
//
|
||||
// Note that if present, it will be written in any result of handshake.
|
||||
Header HandshakeHeader
|
||||
|
||||
// OnRequest is a callback that will be called after request line
|
||||
// successful parsing.
|
||||
//
|
||||
// The arguments are only valid until the callback returns.
|
||||
//
|
||||
// If returned error is non-nil then connection is rejected and response is
|
||||
// sent with appropriate HTTP error code and body set to error message.
|
||||
//
|
||||
// RejectConnectionError could be used to get more control on response.
|
||||
OnRequest func(uri []byte) error
|
||||
|
||||
// OnHost is a callback that will be called after "Host" header successful
|
||||
// parsing.
|
||||
//
|
||||
// It is separated from OnHeader callback because the Host header must be
|
||||
// present in each request since HTTP/1.1. Thus Host header is non-optional
|
||||
// and required for every WebSocket handshake.
|
||||
//
|
||||
// The arguments are only valid until the callback returns.
|
||||
//
|
||||
// If returned error is non-nil then connection is rejected and response is
|
||||
// sent with appropriate HTTP error code and body set to error message.
|
||||
//
|
||||
// RejectConnectionError could be used to get more control on response.
|
||||
OnHost func(host []byte) error
|
||||
|
||||
// OnHeader is a callback that will be called after successful parsing of
|
||||
// header, that is not used during WebSocket handshake procedure. That is,
|
||||
// it will be called with non-websocket headers, which could be relevant
|
||||
// for application-level logic.
|
||||
//
|
||||
// The arguments are only valid until the callback returns.
|
||||
//
|
||||
// If returned error is non-nil then connection is rejected and response is
|
||||
// sent with appropriate HTTP error code and body set to error message.
|
||||
//
|
||||
// RejectConnectionError could be used to get more control on response.
|
||||
OnHeader func(key, value []byte) error
|
||||
|
||||
// OnBeforeUpgrade is a callback that will be called before sending
|
||||
// successful upgrade response.
|
||||
//
|
||||
// Setting OnBeforeUpgrade allows user to make final application-level
|
||||
// checks and decide whether this connection is allowed to successfully
|
||||
// upgrade to WebSocket.
|
||||
//
|
||||
// It must return non-nil either HandshakeHeader or error and never both.
|
||||
//
|
||||
// If returned error is non-nil then connection is rejected and response is
|
||||
// sent with appropriate HTTP error code and body set to error message.
|
||||
//
|
||||
// RejectConnectionError could be used to get more control on response.
|
||||
OnBeforeUpgrade func() (header HandshakeHeader, err error)
|
||||
}
|
||||
|
||||
// Upgrade zero-copy upgrades connection to WebSocket. It interprets given conn
|
||||
// as connection with incoming HTTP Upgrade request.
|
||||
//
|
||||
// It is a caller responsibility to manage i/o timeouts on conn.
|
||||
//
|
||||
// Non-nil error means that request for the WebSocket upgrade is invalid or
|
||||
// malformed and usually connection should be closed.
|
||||
// Even when error is non-nil Upgrade will write appropriate response into
|
||||
// connection in compliance with RFC.
|
||||
func (u Upgrader) Upgrade(conn io.ReadWriter) (hs Handshake, err error) {
|
||||
// headerSeen constants helps to report whether or not some header was seen
|
||||
// during reading request bytes.
|
||||
const (
|
||||
headerSeenHost = 1 << iota
|
||||
headerSeenUpgrade
|
||||
headerSeenConnection
|
||||
headerSeenSecVersion
|
||||
headerSeenSecKey
|
||||
|
||||
// headerSeenAll is the value that we expect to receive at the end of
|
||||
// headers read/parse loop.
|
||||
headerSeenAll = 0 |
|
||||
headerSeenHost |
|
||||
headerSeenUpgrade |
|
||||
headerSeenConnection |
|
||||
headerSeenSecVersion |
|
||||
headerSeenSecKey
|
||||
)
|
||||
|
||||
// Prepare I/O buffers.
|
||||
// TODO(gobwas): make it configurable.
|
||||
br := pbufio.GetReader(conn,
|
||||
nonZero(u.ReadBufferSize, DefaultServerReadBufferSize),
|
||||
)
|
||||
bw := pbufio.GetWriter(conn,
|
||||
nonZero(u.WriteBufferSize, DefaultServerWriteBufferSize),
|
||||
)
|
||||
defer func() {
|
||||
pbufio.PutReader(br)
|
||||
pbufio.PutWriter(bw)
|
||||
}()
|
||||
|
||||
// Read HTTP request line like "GET /ws HTTP/1.1".
|
||||
rl, err := readLine(br)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Parse request line data like HTTP version, uri and method.
|
||||
req, err := httpParseRequestLine(rl)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare stack-based handshake header list.
|
||||
header := handshakeHeader{
|
||||
0: u.Header,
|
||||
}
|
||||
|
||||
// Parse and check HTTP request.
|
||||
// As RFC6455 says:
|
||||
// The client's opening handshake consists of the following parts. If the
|
||||
// server, while reading the handshake, finds that the client did not
|
||||
// send a handshake that matches the description below (note that as per
|
||||
// [RFC2616], the order of the header fields is not important), including
|
||||
// but not limited to any violations of the ABNF grammar specified for
|
||||
// the components of the handshake, the server MUST stop processing the
|
||||
// client's handshake and return an HTTP response with an appropriate
|
||||
// error code (such as 400 Bad Request).
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc6455#section-4.2.1
|
||||
|
||||
// An HTTP/1.1 or higher GET request, including a "Request-URI".
|
||||
//
|
||||
// Even if RFC says "1.1 or higher" without mentioning the part of the
|
||||
// version, we apply it only to minor part.
|
||||
switch {
|
||||
case req.major != 1 || req.minor < 1:
|
||||
// Abort processing the whole request because we do not even know how
|
||||
// to actually parse it.
|
||||
err = ErrHandshakeBadProtocol
|
||||
|
||||
case btsToString(req.method) != http.MethodGet:
|
||||
err = ErrHandshakeBadMethod
|
||||
|
||||
default:
|
||||
if onRequest := u.OnRequest; onRequest != nil {
|
||||
err = onRequest(req.uri)
|
||||
}
|
||||
}
|
||||
// Start headers read/parse loop.
|
||||
var (
|
||||
// headerSeen reports which header was seen by setting corresponding
|
||||
// bit on.
|
||||
headerSeen byte
|
||||
|
||||
nonce = make([]byte, nonceSize)
|
||||
)
|
||||
for err == nil {
|
||||
line, e := readLine(br)
|
||||
if e != nil {
|
||||
return hs, e
|
||||
}
|
||||
if len(line) == 0 {
|
||||
// Blank line, no more lines to read.
|
||||
break
|
||||
}
|
||||
|
||||
k, v, ok := httpParseHeaderLine(line)
|
||||
if !ok {
|
||||
err = ErrMalformedRequest
|
||||
break
|
||||
}
|
||||
|
||||
switch btsToString(k) {
|
||||
case headerHostCanonical:
|
||||
headerSeen |= headerSeenHost
|
||||
if onHost := u.OnHost; onHost != nil {
|
||||
err = onHost(v)
|
||||
}
|
||||
|
||||
case headerUpgradeCanonical:
|
||||
headerSeen |= headerSeenUpgrade
|
||||
if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) {
|
||||
err = ErrHandshakeBadUpgrade
|
||||
}
|
||||
|
||||
case headerConnectionCanonical:
|
||||
headerSeen |= headerSeenConnection
|
||||
if !bytes.Equal(v, specHeaderValueConnection) && !btsHasToken(v, specHeaderValueConnectionLower) {
|
||||
err = ErrHandshakeBadConnection
|
||||
}
|
||||
|
||||
case headerSecVersionCanonical:
|
||||
headerSeen |= headerSeenSecVersion
|
||||
if !bytes.Equal(v, specHeaderValueSecVersion) {
|
||||
err = ErrHandshakeUpgradeRequired
|
||||
}
|
||||
|
||||
case headerSecKeyCanonical:
|
||||
headerSeen |= headerSeenSecKey
|
||||
if len(v) != nonceSize {
|
||||
err = ErrHandshakeBadSecKey
|
||||
} else {
|
||||
copy(nonce[:], v)
|
||||
}
|
||||
|
||||
case headerSecProtocolCanonical:
|
||||
if custom, check := u.ProtocolCustom, u.Protocol; hs.Protocol == "" && (custom != nil || check != nil) {
|
||||
var ok bool
|
||||
if custom != nil {
|
||||
hs.Protocol, ok = custom(v)
|
||||
} else {
|
||||
hs.Protocol, ok = btsSelectProtocol(v, check)
|
||||
}
|
||||
if !ok {
|
||||
err = ErrMalformedRequest
|
||||
}
|
||||
}
|
||||
|
||||
case headerSecExtensionsCanonical:
|
||||
if custom, check := u.ExtensionCustom, u.Extension; custom != nil || check != nil {
|
||||
var ok bool
|
||||
if custom != nil {
|
||||
hs.Extensions, ok = custom(v, hs.Extensions)
|
||||
} else {
|
||||
hs.Extensions, ok = btsSelectExtensions(v, hs.Extensions, check)
|
||||
}
|
||||
if !ok {
|
||||
err = ErrMalformedRequest
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
if onHeader := u.OnHeader; onHeader != nil {
|
||||
err = onHeader(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case err == nil && headerSeen != headerSeenAll:
|
||||
switch {
|
||||
case headerSeen&headerSeenHost == 0:
|
||||
// As RFC2616 says:
|
||||
// A client MUST include a Host header field in all HTTP/1.1
|
||||
// request messages. If the requested URI does not include an
|
||||
// Internet host name for the service being requested, then the
|
||||
// Host header field MUST be given with an empty value. An
|
||||
// HTTP/1.1 proxy MUST ensure that any request message it
|
||||
// forwards does contain an appropriate Host header field that
|
||||
// identifies the service being requested by the proxy. All
|
||||
// Internet-based HTTP/1.1 servers MUST respond with a 400 (Bad
|
||||
// Request) status code to any HTTP/1.1 request message which
|
||||
// lacks a Host header field.
|
||||
err = ErrHandshakeBadHost
|
||||
case headerSeen&headerSeenUpgrade == 0:
|
||||
err = ErrHandshakeBadUpgrade
|
||||
case headerSeen&headerSeenConnection == 0:
|
||||
err = ErrHandshakeBadConnection
|
||||
case headerSeen&headerSeenSecVersion == 0:
|
||||
// In case of empty or not present version we do not send 426 status,
|
||||
// because it does not meet the ABNF rules of RFC6455:
|
||||
//
|
||||
// version = DIGIT | (NZDIGIT DIGIT) |
|
||||
// ("1" DIGIT DIGIT) | ("2" DIGIT DIGIT)
|
||||
// ; Limited to 0-255 range, with no leading zeros
|
||||
//
|
||||
// That is, if version is really invalid – we sent 426 status as above, if it
|
||||
// not present – it is 400.
|
||||
err = ErrHandshakeBadSecVersion
|
||||
case headerSeen&headerSeenSecKey == 0:
|
||||
err = ErrHandshakeBadSecKey
|
||||
default:
|
||||
panic("unknown headers state")
|
||||
}
|
||||
|
||||
case err == nil && u.OnBeforeUpgrade != nil:
|
||||
header[1], err = u.OnBeforeUpgrade()
|
||||
}
|
||||
if err != nil {
|
||||
var code int
|
||||
if rej, ok := err.(*rejectConnectionError); ok {
|
||||
code = rej.code
|
||||
header[1] = rej.header
|
||||
}
|
||||
if code == 0 {
|
||||
code = http.StatusInternalServerError
|
||||
}
|
||||
httpWriteResponseError(bw, err, code, header.WriteTo)
|
||||
// Do not store Flush() error to not override already existing one.
|
||||
bw.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
httpWriteResponseUpgrade(bw, nonce, hs, header.WriteTo)
|
||||
err = bw.Flush()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type handshakeHeader [2]HandshakeHeader
|
||||
|
||||
func (hs handshakeHeader) WriteTo(w io.Writer) (n int64, err error) {
|
||||
for i := 0; i < len(hs) && err == nil; i++ {
|
||||
if h := hs[i]; h != nil {
|
||||
var m int64
|
||||
m, err = h.WriteTo(w)
|
||||
n += m
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
|
@ -0,0 +1,214 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gobwas/httphead"
|
||||
)
|
||||
|
||||
// SelectFromSlice creates accept function that could be used as Protocol/Extension
|
||||
// select during upgrade.
|
||||
func SelectFromSlice(accept []string) func(string) bool {
|
||||
if len(accept) > 16 {
|
||||
mp := make(map[string]struct{}, len(accept))
|
||||
for _, p := range accept {
|
||||
mp[p] = struct{}{}
|
||||
}
|
||||
return func(p string) bool {
|
||||
_, ok := mp[p]
|
||||
return ok
|
||||
}
|
||||
}
|
||||
return func(p string) bool {
|
||||
for _, ok := range accept {
|
||||
if p == ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SelectEqual creates accept function that could be used as Protocol/Extension
|
||||
// select during upgrade.
|
||||
func SelectEqual(v string) func(string) bool {
|
||||
return func(p string) bool {
|
||||
return v == p
|
||||
}
|
||||
}
|
||||
|
||||
func strToBytes(str string) (bts []byte) {
|
||||
s := (*reflect.StringHeader)(unsafe.Pointer(&str))
|
||||
b := (*reflect.SliceHeader)(unsafe.Pointer(&bts))
|
||||
b.Data = s.Data
|
||||
b.Len = s.Len
|
||||
b.Cap = s.Len
|
||||
return
|
||||
}
|
||||
|
||||
func btsToString(bts []byte) (str string) {
|
||||
return *(*string)(unsafe.Pointer(&bts))
|
||||
}
|
||||
|
||||
// asciiToInt converts bytes to int.
|
||||
func asciiToInt(bts []byte) (ret int, err error) {
|
||||
// ASCII numbers all start with the high-order bits 0011.
|
||||
// If you see that, and the next bits are 0-9 (0000 - 1001) you can grab those
|
||||
// bits and interpret them directly as an integer.
|
||||
var n int
|
||||
if n = len(bts); n < 1 {
|
||||
return 0, fmt.Errorf("converting empty bytes to int")
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
if bts[i]&0xf0 != 0x30 {
|
||||
return 0, fmt.Errorf("%s is not a numeric character", string(bts[i]))
|
||||
}
|
||||
ret += int(bts[i]&0xf) * pow(10, n-i-1)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// pow for integers implementation.
|
||||
// See Donald Knuth, The Art of Computer Programming, Volume 2, Section 4.6.3
|
||||
func pow(a, b int) int {
|
||||
p := 1
|
||||
for b > 0 {
|
||||
if b&1 != 0 {
|
||||
p *= a
|
||||
}
|
||||
b >>= 1
|
||||
a *= a
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func bsplit3(bts []byte, sep byte) (b1, b2, b3 []byte) {
|
||||
a := bytes.IndexByte(bts, sep)
|
||||
b := bytes.IndexByte(bts[a+1:], sep)
|
||||
if a == -1 || b == -1 {
|
||||
return bts, nil, nil
|
||||
}
|
||||
b += a + 1
|
||||
return bts[:a], bts[a+1 : b], bts[b+1:]
|
||||
}
|
||||
|
||||
func btrim(bts []byte) []byte {
|
||||
var i, j int
|
||||
for i = 0; i < len(bts) && (bts[i] == ' ' || bts[i] == '\t'); {
|
||||
i++
|
||||
}
|
||||
for j = len(bts); j > i && (bts[j-1] == ' ' || bts[j-1] == '\t'); {
|
||||
j--
|
||||
}
|
||||
return bts[i:j]
|
||||
}
|
||||
|
||||
func strHasToken(header, token string) (has bool) {
|
||||
return btsHasToken(strToBytes(header), strToBytes(token))
|
||||
}
|
||||
|
||||
func btsHasToken(header, token []byte) (has bool) {
|
||||
httphead.ScanTokens(header, func(v []byte) bool {
|
||||
has = bytes.EqualFold(v, token)
|
||||
return !has
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
toLower = 'a' - 'A' // for use with OR.
|
||||
toUpper = ^byte(toLower) // for use with AND.
|
||||
toLower8 = uint64(toLower) |
|
||||
uint64(toLower)<<8 |
|
||||
uint64(toLower)<<16 |
|
||||
uint64(toLower)<<24 |
|
||||
uint64(toLower)<<32 |
|
||||
uint64(toLower)<<40 |
|
||||
uint64(toLower)<<48 |
|
||||
uint64(toLower)<<56
|
||||
)
|
||||
|
||||
// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except
|
||||
// that it operates with slice of bytes and modifies it inplace without copying.
|
||||
func canonicalizeHeaderKey(k []byte) {
|
||||
upper := true
|
||||
for i, c := range k {
|
||||
if upper && 'a' <= c && c <= 'z' {
|
||||
k[i] &= toUpper
|
||||
} else if !upper && 'A' <= c && c <= 'Z' {
|
||||
k[i] |= toLower
|
||||
}
|
||||
upper = c == '-'
|
||||
}
|
||||
}
|
||||
|
||||
// readLine reads line from br. It reads until '\n' and returns bytes without
|
||||
// '\n' or '\r\n' at the end.
|
||||
// It returns err if and only if line does not end in '\n'. Note that read
|
||||
// bytes returned in any case of error.
|
||||
//
|
||||
// It is much like the textproto/Reader.ReadLine() except the thing that it
|
||||
// returns raw bytes, instead of string. That is, it avoids copying bytes read
|
||||
// from br.
|
||||
//
|
||||
// textproto/Reader.ReadLineBytes() is also makes copy of resulting bytes to be
|
||||
// safe with future I/O operations on br.
|
||||
//
|
||||
// We could control I/O operations on br and do not need to make additional
|
||||
// copy for safety.
|
||||
//
|
||||
// NOTE: it may return copied flag to notify that returned buffer is safe to
|
||||
// use.
|
||||
func readLine(br *bufio.Reader) ([]byte, error) {
|
||||
var line []byte
|
||||
for {
|
||||
bts, err := br.ReadSlice('\n')
|
||||
if err == bufio.ErrBufferFull {
|
||||
// Copy bytes because next read will discard them.
|
||||
line = append(line, bts...)
|
||||
continue
|
||||
}
|
||||
|
||||
// Avoid copy of single read.
|
||||
if line == nil {
|
||||
line = bts
|
||||
} else {
|
||||
line = append(line, bts...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return line, err
|
||||
}
|
||||
|
||||
// Size of line is at least 1.
|
||||
// In other case bufio.ReadSlice() returns error.
|
||||
n := len(line)
|
||||
|
||||
// Cut '\n' or '\r\n'.
|
||||
if n > 1 && line[n-2] == '\r' {
|
||||
line = line[:n-2]
|
||||
} else {
|
||||
line = line[:n-1]
|
||||
}
|
||||
|
||||
return line, nil
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func nonZero(a, b int) int {
|
||||
if a != 0 {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
package ws
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Header size length bounds in bytes.
|
||||
const (
|
||||
MaxHeaderSize = 14
|
||||
MinHeaderSize = 2
|
||||
)
|
||||
|
||||
const (
|
||||
bit0 = 0x80
|
||||
bit1 = 0x40
|
||||
bit2 = 0x20
|
||||
bit3 = 0x10
|
||||
bit4 = 0x08
|
||||
bit5 = 0x04
|
||||
bit6 = 0x02
|
||||
bit7 = 0x01
|
||||
|
||||
len7 = int64(125)
|
||||
len16 = int64(^(uint16(0)))
|
||||
len64 = int64(^(uint64(0)) >> 1)
|
||||
)
|
||||
|
||||
// HeaderSize returns number of bytes that are needed to encode given header.
|
||||
// It returns -1 if header is malformed.
|
||||
func HeaderSize(h Header) (n int) {
|
||||
switch {
|
||||
case h.Length < 126:
|
||||
n = 2
|
||||
case h.Length <= len16:
|
||||
n = 4
|
||||
case h.Length <= len64:
|
||||
n = 10
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
if h.Masked {
|
||||
n += len(h.Mask)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// WriteHeader writes header binary representation into w.
|
||||
func WriteHeader(w io.Writer, h Header) error {
|
||||
// Make slice of bytes with capacity 14 that could hold any header.
|
||||
bts := make([]byte, MaxHeaderSize)
|
||||
|
||||
if h.Fin {
|
||||
bts[0] |= bit0
|
||||
}
|
||||
bts[0] |= h.Rsv << 4
|
||||
bts[0] |= byte(h.OpCode)
|
||||
|
||||
var n int
|
||||
switch {
|
||||
case h.Length <= len7:
|
||||
bts[1] = byte(h.Length)
|
||||
n = 2
|
||||
|
||||
case h.Length <= len16:
|
||||
bts[1] = 126
|
||||
binary.BigEndian.PutUint16(bts[2:4], uint16(h.Length))
|
||||
n = 4
|
||||
|
||||
case h.Length <= len64:
|
||||
bts[1] = 127
|
||||
binary.BigEndian.PutUint64(bts[2:10], uint64(h.Length))
|
||||
n = 10
|
||||
|
||||
default:
|
||||
return ErrHeaderLengthUnexpected
|
||||
}
|
||||
|
||||
if h.Masked {
|
||||
bts[1] |= bit0
|
||||
n += copy(bts[n:], h.Mask[:])
|
||||
}
|
||||
|
||||
_, err := w.Write(bts[:n])
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteFrame writes frame binary representation into w.
|
||||
func WriteFrame(w io.Writer, f Frame) error {
|
||||
err := WriteHeader(w, f.Header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(f.Payload)
|
||||
return err
|
||||
}
|
||||
|
||||
// MustWriteFrame is like WriteFrame but panics if frame can not be read.
|
||||
func MustWriteFrame(w io.Writer, f Frame) {
|
||||
if err := WriteFrame(w, f); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/gobwas/pool/pbytes"
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
// CipherReader implements io.Reader that applies xor-cipher to the bytes read
|
||||
// from source.
|
||||
// It could help to unmask WebSocket frame payload on the fly.
|
||||
type CipherReader struct {
|
||||
r io.Reader
|
||||
mask [4]byte
|
||||
pos int
|
||||
}
|
||||
|
||||
// NewCipherReader creates xor-cipher reader from r with given mask.
|
||||
func NewCipherReader(r io.Reader, mask [4]byte) *CipherReader {
|
||||
return &CipherReader{r, mask, 0}
|
||||
}
|
||||
|
||||
// Reset resets CipherReader to read from r with given mask.
|
||||
func (c *CipherReader) Reset(r io.Reader, mask [4]byte) {
|
||||
c.r = r
|
||||
c.mask = mask
|
||||
c.pos = 0
|
||||
}
|
||||
|
||||
// Read implements io.Reader interface. It applies mask given during
|
||||
// initialization to every read byte.
|
||||
func (c *CipherReader) Read(p []byte) (n int, err error) {
|
||||
n, err = c.r.Read(p)
|
||||
ws.Cipher(p[:n], c.mask, c.pos)
|
||||
c.pos += n
|
||||
return
|
||||
}
|
||||
|
||||
// CipherWriter implements io.Writer that applies xor-cipher to the bytes
|
||||
// written to the destination writer. It does not modify the original bytes.
|
||||
type CipherWriter struct {
|
||||
w io.Writer
|
||||
mask [4]byte
|
||||
pos int
|
||||
}
|
||||
|
||||
// NewCipherWriter creates xor-cipher writer to w with given mask.
|
||||
func NewCipherWriter(w io.Writer, mask [4]byte) *CipherWriter {
|
||||
return &CipherWriter{w, mask, 0}
|
||||
}
|
||||
|
||||
// Reset reset CipherWriter to write to w with given mask.
|
||||
func (c *CipherWriter) Reset(w io.Writer, mask [4]byte) {
|
||||
c.w = w
|
||||
c.mask = mask
|
||||
c.pos = 0
|
||||
}
|
||||
|
||||
// Write implements io.Writer interface. It applies masking during
|
||||
// initialization to every sent byte. It does not modify original slice.
|
||||
func (c *CipherWriter) Write(p []byte) (n int, err error) {
|
||||
cp := pbytes.GetLen(len(p))
|
||||
defer pbytes.Put(cp)
|
||||
|
||||
copy(cp, p)
|
||||
ws.Cipher(cp, c.mask, c.pos)
|
||||
n, err = c.w.Write(cp)
|
||||
c.pos += n
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
// DebugDialer is a wrapper around ws.Dialer. It tracks i/o of WebSocket
|
||||
// handshake. That is, it gives ability to receive copied HTTP request and
|
||||
// response bytes that made inside Dialer.Dial().
|
||||
//
|
||||
// Note that it must not be used in production applications that requires
|
||||
// Dial() to be efficient.
|
||||
type DebugDialer struct {
|
||||
// Dialer contains WebSocket connection establishment options.
|
||||
Dialer ws.Dialer
|
||||
|
||||
// OnRequest and OnResponse are the callbacks that will be called with the
|
||||
// HTTP request and response respectively.
|
||||
OnRequest, OnResponse func([]byte)
|
||||
}
|
||||
|
||||
// Dial connects to the url host and upgrades connection to WebSocket. It makes
|
||||
// it by calling d.Dialer.Dial().
|
||||
func (d *DebugDialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs ws.Handshake, err error) {
|
||||
// Need to copy Dialer to prevent original object mutation.
|
||||
dialer := d.Dialer
|
||||
var (
|
||||
reqBuf bytes.Buffer
|
||||
resBuf bytes.Buffer
|
||||
|
||||
resContentLength int64
|
||||
)
|
||||
userWrap := dialer.WrapConn
|
||||
dialer.WrapConn = func(c net.Conn) net.Conn {
|
||||
if userWrap != nil {
|
||||
c = userWrap(c)
|
||||
}
|
||||
|
||||
// Save the pointer to the raw connection.
|
||||
conn = c
|
||||
|
||||
var (
|
||||
r io.Reader = conn
|
||||
w io.Writer = conn
|
||||
)
|
||||
if d.OnResponse != nil {
|
||||
r = &prefetchResponseReader{
|
||||
source: conn,
|
||||
buffer: &resBuf,
|
||||
contentLength: &resContentLength,
|
||||
}
|
||||
}
|
||||
if d.OnRequest != nil {
|
||||
w = io.MultiWriter(conn, &reqBuf)
|
||||
}
|
||||
return rwConn{conn, r, w}
|
||||
}
|
||||
|
||||
_, br, hs, err = dialer.Dial(ctx, urlstr)
|
||||
|
||||
if onRequest := d.OnRequest; onRequest != nil {
|
||||
onRequest(reqBuf.Bytes())
|
||||
}
|
||||
if onResponse := d.OnResponse; onResponse != nil {
|
||||
// We must split response inside buffered bytes from other received
|
||||
// bytes from server.
|
||||
p := resBuf.Bytes()
|
||||
n := bytes.Index(p, headEnd)
|
||||
h := n + len(headEnd) // Head end index.
|
||||
n = h + int(resContentLength) // Body end index.
|
||||
|
||||
onResponse(p[:n])
|
||||
|
||||
if br != nil {
|
||||
// If br is non-nil, then it mean two things. First is that
|
||||
// handshake is OK and server has sent additional bytes – probably
|
||||
// immediate sent frames (or weird but possible response body).
|
||||
// Second, the bad one, is that br buffer's source is now rwConn
|
||||
// instance from above WrapConn call. It is incorrect, so we must
|
||||
// fix it.
|
||||
var r io.Reader = conn
|
||||
if len(p) > h {
|
||||
// Buffer contains more than just HTTP headers bytes.
|
||||
r = io.MultiReader(
|
||||
bytes.NewReader(p[h:]),
|
||||
conn,
|
||||
)
|
||||
}
|
||||
br.Reset(r)
|
||||
// Must make br.Buffered() to be non-zero.
|
||||
br.Peek(len(p[h:]))
|
||||
}
|
||||
}
|
||||
|
||||
return conn, br, hs, err
|
||||
}
|
||||
|
||||
type rwConn struct {
|
||||
net.Conn
|
||||
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (rwc rwConn) Read(p []byte) (int, error) {
|
||||
return rwc.r.Read(p)
|
||||
}
|
||||
func (rwc rwConn) Write(p []byte) (int, error) {
|
||||
return rwc.w.Write(p)
|
||||
}
|
||||
|
||||
var headEnd = []byte("\r\n\r\n")
|
||||
|
||||
type prefetchResponseReader struct {
|
||||
source io.Reader // Original connection source.
|
||||
reader io.Reader // Wrapped reader used to read from by clients.
|
||||
buffer *bytes.Buffer
|
||||
|
||||
contentLength *int64
|
||||
}
|
||||
|
||||
func (r *prefetchResponseReader) Read(p []byte) (int, error) {
|
||||
if r.reader == nil {
|
||||
resp, err := http.ReadResponse(bufio.NewReader(
|
||||
io.TeeReader(r.source, r.buffer),
|
||||
), nil)
|
||||
if err == nil {
|
||||
*r.contentLength, _ = io.Copy(ioutil.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
}
|
||||
bts := r.buffer.Bytes()
|
||||
r.reader = io.MultiReader(
|
||||
bytes.NewReader(bts),
|
||||
r.source,
|
||||
)
|
||||
}
|
||||
return r.reader.Read(p)
|
||||
}
|
|
@ -0,0 +1,219 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
|
||||
"github.com/gobwas/pool/pbytes"
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
// ClosedError returned when peer has closed the connection with appropriate
|
||||
// code and a textual reason.
|
||||
type ClosedError struct {
|
||||
Code ws.StatusCode
|
||||
Reason string
|
||||
}
|
||||
|
||||
// Error implements error interface.
|
||||
func (err ClosedError) Error() string {
|
||||
return "ws closed: " + strconv.FormatUint(uint64(err.Code), 10) + " " + err.Reason
|
||||
}
|
||||
|
||||
// ControlHandler contains logic of handling control frames.
|
||||
//
|
||||
// The intentional way to use it is to read the next frame header from the
|
||||
// connection, optionally check its validity via ws.CheckHeader() and if it is
|
||||
// not a ws.OpText of ws.OpBinary (or ws.OpContinuation) – pass it to Handle()
|
||||
// method.
|
||||
//
|
||||
// That is, passed header should be checked to get rid of unexpected errors.
|
||||
//
|
||||
// The Handle() method will read out all control frame payload (if any) and
|
||||
// write necessary bytes as a rfc compatible response.
|
||||
type ControlHandler struct {
|
||||
Src io.Reader
|
||||
Dst io.Writer
|
||||
State ws.State
|
||||
|
||||
// DisableSrcCiphering disables unmasking payload data read from Src.
|
||||
// It is useful when wsutil.Reader is used or when frame payload already
|
||||
// pulled and ciphered out from the connection (and introduced by
|
||||
// bytes.Reader, for example).
|
||||
DisableSrcCiphering bool
|
||||
}
|
||||
|
||||
// ErrNotControlFrame is returned by ControlHandler to indicate that given
|
||||
// header could not be handled.
|
||||
var ErrNotControlFrame = errors.New("not a control frame")
|
||||
|
||||
// Handle handles control frames regarding to the c.State and writes responses
|
||||
// to the c.Dst when needed.
|
||||
//
|
||||
// It returns ErrNotControlFrame when given header is not of ws.OpClose,
|
||||
// ws.OpPing or ws.OpPong operation code.
|
||||
func (c ControlHandler) Handle(h ws.Header) error {
|
||||
switch h.OpCode {
|
||||
case ws.OpPing:
|
||||
return c.HandlePing(h)
|
||||
case ws.OpPong:
|
||||
return c.HandlePong(h)
|
||||
case ws.OpClose:
|
||||
return c.HandleClose(h)
|
||||
}
|
||||
return ErrNotControlFrame
|
||||
}
|
||||
|
||||
// HandlePing handles ping frame and writes specification compatible response
|
||||
// to the c.Dst.
|
||||
func (c ControlHandler) HandlePing(h ws.Header) error {
|
||||
if h.Length == 0 {
|
||||
// The most common case when ping is empty.
|
||||
// Note that when sending masked frame the mask for empty payload is
|
||||
// just four zero bytes.
|
||||
return ws.WriteHeader(c.Dst, ws.Header{
|
||||
Fin: true,
|
||||
OpCode: ws.OpPong,
|
||||
Masked: c.State.ClientSide(),
|
||||
})
|
||||
}
|
||||
|
||||
// In other way reply with Pong frame with copied payload.
|
||||
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
|
||||
Length: h.Length,
|
||||
Masked: c.State.ClientSide(),
|
||||
}))
|
||||
defer pbytes.Put(p)
|
||||
|
||||
// Deal with ciphering i/o:
|
||||
// Masking key is used to mask the "Payload data" defined in the same
|
||||
// section as frame-payload-data, which includes "Extension data" and
|
||||
// "Application data".
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
//
|
||||
// NOTE: We prefer ControlWriter with preallocated buffer to
|
||||
// ws.WriteHeader because it performs one syscall instead of two.
|
||||
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpPong, p)
|
||||
r := c.Src
|
||||
if c.State.ServerSide() && !c.DisableSrcCiphering {
|
||||
r = NewCipherReader(r, h.Mask)
|
||||
}
|
||||
|
||||
_, err := io.Copy(w, r)
|
||||
if err == nil {
|
||||
err = w.Flush()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// HandlePong handles pong frame by discarding it.
|
||||
func (c ControlHandler) HandlePong(h ws.Header) error {
|
||||
if h.Length == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
buf := pbytes.GetLen(int(h.Length))
|
||||
defer pbytes.Put(buf)
|
||||
|
||||
// Discard pong message according to the RFC6455:
|
||||
// A Pong frame MAY be sent unsolicited. This serves as a
|
||||
// unidirectional heartbeat. A response to an unsolicited Pong frame
|
||||
// is not expected.
|
||||
_, err := io.CopyBuffer(ioutil.Discard, c.Src, buf)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleClose handles close frame, makes protocol validity checks and writes
|
||||
// specification compatible response to the c.Dst.
|
||||
func (c ControlHandler) HandleClose(h ws.Header) error {
|
||||
if h.Length == 0 {
|
||||
err := ws.WriteHeader(c.Dst, ws.Header{
|
||||
Fin: true,
|
||||
OpCode: ws.OpClose,
|
||||
Masked: c.State.ClientSide(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Due to RFC, we should interpret the code as no status code
|
||||
// received:
|
||||
// If this Close control frame contains no status code, _The WebSocket
|
||||
// Connection Close Code_ is considered to be 1005.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc6455#section-7.1.5
|
||||
return ClosedError{
|
||||
Code: ws.StatusNoStatusRcvd,
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare bytes both for reading reason and sending response.
|
||||
p := pbytes.GetLen(int(h.Length) + ws.HeaderSize(ws.Header{
|
||||
Length: h.Length,
|
||||
Masked: c.State.ClientSide(),
|
||||
}))
|
||||
defer pbytes.Put(p)
|
||||
|
||||
// Get the subslice to read the frame payload out.
|
||||
subp := p[:h.Length]
|
||||
|
||||
r := c.Src
|
||||
if c.State.ServerSide() && !c.DisableSrcCiphering {
|
||||
r = NewCipherReader(r, h.Mask)
|
||||
}
|
||||
if _, err := io.ReadFull(r, subp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code, reason := ws.ParseCloseFrameData(subp)
|
||||
if err := ws.CheckCloseFrameData(code, reason); err != nil {
|
||||
// Here we could not use the prepared bytes because there is no
|
||||
// guarantee that it may fit our protocol error closure code and a
|
||||
// reason.
|
||||
c.closeWithProtocolError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Deal with ciphering i/o:
|
||||
// Masking key is used to mask the "Payload data" defined in the same
|
||||
// section as frame-payload-data, which includes "Extension data" and
|
||||
// "Application data".
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
//
|
||||
// NOTE: We prefer ControlWriter with preallocated buffer to
|
||||
// ws.WriteHeader because it performs one syscall instead of two.
|
||||
w := NewControlWriterBuffer(c.Dst, c.State, ws.OpClose, p)
|
||||
|
||||
// RFC6455#5.5.1:
|
||||
// If an endpoint receives a Close frame and did not previously
|
||||
// send a Close frame, the endpoint MUST send a Close frame in
|
||||
// response. (When sending a Close frame in response, the endpoint
|
||||
// typically echoes the status code it received.)
|
||||
_, err := w.Write(p[:2])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = w.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ClosedError{
|
||||
Code: code,
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
|
||||
func (c ControlHandler) closeWithProtocolError(reason error) error {
|
||||
f := ws.NewCloseFrame(ws.NewCloseFrameBody(
|
||||
ws.StatusProtocolError, reason.Error(),
|
||||
))
|
||||
if c.State.ClientSide() {
|
||||
ws.MaskFrameInPlace(f)
|
||||
}
|
||||
return ws.WriteFrame(c.Dst, f)
|
||||
}
|
|
@ -0,0 +1,279 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
// Message represents a message from peer, that could be presented in one or
|
||||
// more frames. That is, it contains payload of all message fragments and
|
||||
// operation code of initial frame for this message.
|
||||
type Message struct {
|
||||
OpCode ws.OpCode
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
// ReadMessage is a helper function that reads next message from r. It appends
|
||||
// received message(s) to the third argument and returns the result of it and
|
||||
// an error if some failure happened. That is, it probably could receive more
|
||||
// than one message when peer sending fragmented message in multiple frames and
|
||||
// want to send some control frame between fragments. Then returned slice will
|
||||
// contain those control frames at first, and then result of gluing fragments.
|
||||
//
|
||||
// TODO(gobwas): add DefaultReader with buffer size options.
|
||||
func ReadMessage(r io.Reader, s ws.State, m []Message) ([]Message, error) {
|
||||
rd := Reader{
|
||||
Source: r,
|
||||
State: s,
|
||||
CheckUTF8: true,
|
||||
OnIntermediate: func(hdr ws.Header, src io.Reader) error {
|
||||
bts, err := ioutil.ReadAll(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m = append(m, Message{hdr.OpCode, bts})
|
||||
return nil
|
||||
},
|
||||
}
|
||||
h, err := rd.NextFrame()
|
||||
if err != nil {
|
||||
return m, err
|
||||
}
|
||||
var p []byte
|
||||
if h.Fin {
|
||||
// No more frames will be read. Use fixed sized buffer to read payload.
|
||||
p = make([]byte, h.Length)
|
||||
// It is not possible to receive io.EOF here because Reader does not
|
||||
// return EOF if frame payload was successfully fetched.
|
||||
// Thus we consistent here with io.Reader behavior.
|
||||
_, err = io.ReadFull(&rd, p)
|
||||
} else {
|
||||
// Frame is fragmented, thus use ioutil.ReadAll behavior.
|
||||
var buf bytes.Buffer
|
||||
_, err = buf.ReadFrom(&rd)
|
||||
p = buf.Bytes()
|
||||
}
|
||||
if err != nil {
|
||||
return m, err
|
||||
}
|
||||
return append(m, Message{h.OpCode, p}), nil
|
||||
}
|
||||
|
||||
// ReadClientMessage reads next message from r, considering that caller
|
||||
// represents server side.
|
||||
// It is a shortcut for ReadMessage(r, ws.StateServerSide, m)
|
||||
func ReadClientMessage(r io.Reader, m []Message) ([]Message, error) {
|
||||
return ReadMessage(r, ws.StateServerSide, m)
|
||||
}
|
||||
|
||||
// ReadServerMessage reads next message from r, considering that caller
|
||||
// represents client side.
|
||||
// It is a shortcut for ReadMessage(r, ws.StateClientSide, m)
|
||||
func ReadServerMessage(r io.Reader, m []Message) ([]Message, error) {
|
||||
return ReadMessage(r, ws.StateClientSide, m)
|
||||
}
|
||||
|
||||
// ReadData is a helper function that reads next data (non-control) message
|
||||
// from rw.
|
||||
// It takes care on handling all control frames. It will write response on
|
||||
// control frames to the write part of rw. It blocks until some data frame
|
||||
// will be received.
|
||||
//
|
||||
// Note this may handle and write control frames into the writer part of a
|
||||
// given io.ReadWriter.
|
||||
func ReadData(rw io.ReadWriter, s ws.State) ([]byte, ws.OpCode, error) {
|
||||
return readData(rw, s, ws.OpText|ws.OpBinary)
|
||||
}
|
||||
|
||||
// ReadClientData reads next data message from rw, considering that caller
|
||||
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
|
||||
//
|
||||
// Note this may handle and write control frames into the writer part of a
|
||||
// given io.ReadWriter.
|
||||
func ReadClientData(rw io.ReadWriter) ([]byte, ws.OpCode, error) {
|
||||
return ReadData(rw, ws.StateServerSide)
|
||||
}
|
||||
|
||||
// ReadClientText reads next text message from rw, considering that caller
|
||||
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
|
||||
// It discards received binary messages.
|
||||
//
|
||||
// Note this may handle and write control frames into the writer part of a
|
||||
// given io.ReadWriter.
|
||||
func ReadClientText(rw io.ReadWriter) ([]byte, error) {
|
||||
p, _, err := readData(rw, ws.StateServerSide, ws.OpText)
|
||||
return p, err
|
||||
}
|
||||
|
||||
// ReadClientBinary reads next binary message from rw, considering that caller
|
||||
// represents server side. It is a shortcut for ReadData(rw, ws.StateServerSide).
|
||||
// It discards received text messages.
|
||||
//
|
||||
// Note this may handle and write control frames into the writer part of a given
|
||||
// io.ReadWriter.
|
||||
func ReadClientBinary(rw io.ReadWriter) ([]byte, error) {
|
||||
p, _, err := readData(rw, ws.StateServerSide, ws.OpBinary)
|
||||
return p, err
|
||||
}
|
||||
|
||||
// ReadServerData reads next data message from rw, considering that caller
|
||||
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
|
||||
//
|
||||
// Note this may handle and write control frames into the writer part of a
|
||||
// given io.ReadWriter.
|
||||
func ReadServerData(rw io.ReadWriter) ([]byte, ws.OpCode, error) {
|
||||
return ReadData(rw, ws.StateClientSide)
|
||||
}
|
||||
|
||||
// ReadServerText reads next text message from rw, considering that caller
|
||||
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
|
||||
// It discards received binary messages.
|
||||
//
|
||||
// Note this may handle and write control frames into the writer part of a given
|
||||
// io.ReadWriter.
|
||||
func ReadServerText(rw io.ReadWriter) ([]byte, error) {
|
||||
p, _, err := readData(rw, ws.StateClientSide, ws.OpText)
|
||||
return p, err
|
||||
}
|
||||
|
||||
// ReadServerBinary reads next binary message from rw, considering that caller
|
||||
// represents client side. It is a shortcut for ReadData(rw, ws.StateClientSide).
|
||||
// It discards received text messages.
|
||||
//
|
||||
// Note this may handle and write control frames into the writer part of a
|
||||
// given io.ReadWriter.
|
||||
func ReadServerBinary(rw io.ReadWriter) ([]byte, error) {
|
||||
p, _, err := readData(rw, ws.StateClientSide, ws.OpBinary)
|
||||
return p, err
|
||||
}
|
||||
|
||||
// WriteMessage is a helper function that writes message to the w. It
|
||||
// constructs single frame with given operation code and payload.
|
||||
// It uses given state to prepare side-dependent things, like cipher
|
||||
// payload bytes from client to server. It will not mutate p bytes if
|
||||
// cipher must be made.
|
||||
//
|
||||
// If you want to write message in fragmented frames, use Writer instead.
|
||||
func WriteMessage(w io.Writer, s ws.State, op ws.OpCode, p []byte) error {
|
||||
return writeFrame(w, s, op, true, p)
|
||||
}
|
||||
|
||||
// WriteServerMessage writes message to w, considering that caller
|
||||
// represents server side.
|
||||
func WriteServerMessage(w io.Writer, op ws.OpCode, p []byte) error {
|
||||
return WriteMessage(w, ws.StateServerSide, op, p)
|
||||
}
|
||||
|
||||
// WriteServerText is the same as WriteServerMessage with
|
||||
// ws.OpText.
|
||||
func WriteServerText(w io.Writer, p []byte) error {
|
||||
return WriteServerMessage(w, ws.OpText, p)
|
||||
}
|
||||
|
||||
// WriteServerBinary is the same as WriteServerMessage with
|
||||
// ws.OpBinary.
|
||||
func WriteServerBinary(w io.Writer, p []byte) error {
|
||||
return WriteServerMessage(w, ws.OpBinary, p)
|
||||
}
|
||||
|
||||
// WriteClientMessage writes message to w, considering that caller
|
||||
// represents client side.
|
||||
func WriteClientMessage(w io.Writer, op ws.OpCode, p []byte) error {
|
||||
return WriteMessage(w, ws.StateClientSide, op, p)
|
||||
}
|
||||
|
||||
// WriteClientText is the same as WriteClientMessage with
|
||||
// ws.OpText.
|
||||
func WriteClientText(w io.Writer, p []byte) error {
|
||||
return WriteClientMessage(w, ws.OpText, p)
|
||||
}
|
||||
|
||||
// WriteClientBinary is the same as WriteClientMessage with
|
||||
// ws.OpBinary.
|
||||
func WriteClientBinary(w io.Writer, p []byte) error {
|
||||
return WriteClientMessage(w, ws.OpBinary, p)
|
||||
}
|
||||
|
||||
// HandleClientControlMessage handles control frame from conn and writes
|
||||
// response when needed.
|
||||
//
|
||||
// It considers that caller represents server side.
|
||||
func HandleClientControlMessage(conn io.Writer, msg Message) error {
|
||||
return HandleControlMessage(conn, ws.StateServerSide, msg)
|
||||
}
|
||||
|
||||
// HandleServerControlMessage handles control frame from conn and writes
|
||||
// response when needed.
|
||||
//
|
||||
// It considers that caller represents client side.
|
||||
func HandleServerControlMessage(conn io.Writer, msg Message) error {
|
||||
return HandleControlMessage(conn, ws.StateClientSide, msg)
|
||||
}
|
||||
|
||||
// HandleControlMessage handles message which was read by ReadMessage()
|
||||
// functions.
|
||||
//
|
||||
// That is, it is expected, that payload is already unmasked and frame header
|
||||
// were checked by ws.CheckHeader() call.
|
||||
func HandleControlMessage(conn io.Writer, state ws.State, msg Message) error {
|
||||
return (ControlHandler{
|
||||
DisableSrcCiphering: true,
|
||||
Src: bytes.NewReader(msg.Payload),
|
||||
Dst: conn,
|
||||
State: state,
|
||||
}).Handle(ws.Header{
|
||||
Length: int64(len(msg.Payload)),
|
||||
OpCode: msg.OpCode,
|
||||
Fin: true,
|
||||
Masked: state.ServerSide(),
|
||||
})
|
||||
}
|
||||
|
||||
// ControlFrameHandler returns FrameHandlerFunc for handling control frames.
|
||||
// For more info see ControlHandler docs.
|
||||
func ControlFrameHandler(w io.Writer, state ws.State) FrameHandlerFunc {
|
||||
return func(h ws.Header, r io.Reader) error {
|
||||
return (ControlHandler{
|
||||
DisableSrcCiphering: true,
|
||||
Src: r,
|
||||
Dst: w,
|
||||
State: state,
|
||||
}).Handle(h)
|
||||
}
|
||||
}
|
||||
|
||||
func readData(rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, ws.OpCode, error) {
|
||||
controlHandler := ControlFrameHandler(rw, s)
|
||||
rd := Reader{
|
||||
Source: rw,
|
||||
State: s,
|
||||
CheckUTF8: true,
|
||||
SkipHeaderCheck: false,
|
||||
OnIntermediate: controlHandler,
|
||||
}
|
||||
for {
|
||||
hdr, err := rd.NextFrame()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if hdr.OpCode.IsControl() {
|
||||
if err := controlHandler(hdr, &rd); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if hdr.OpCode&want == 0 {
|
||||
if err := rd.Discard(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
bts, err := ioutil.ReadAll(&rd)
|
||||
|
||||
return bts, hdr.OpCode, err
|
||||
}
|
||||
}
|
|
@ -0,0 +1,257 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
// ErrNoFrameAdvance means that Reader's Read() method was called without
|
||||
// preceding NextFrame() call.
|
||||
var ErrNoFrameAdvance = errors.New("no frame advance")
|
||||
|
||||
// FrameHandlerFunc handles parsed frame header and its body represented by
|
||||
// io.Reader.
|
||||
//
|
||||
// Note that reader represents already unmasked body.
|
||||
type FrameHandlerFunc func(ws.Header, io.Reader) error
|
||||
|
||||
// Reader is a wrapper around source io.Reader which represents WebSocket
|
||||
// connection. It contains options for reading messages from source.
|
||||
//
|
||||
// Reader implements io.Reader, which Read() method reads payload of incoming
|
||||
// WebSocket frames. It also takes care on fragmented frames and possibly
|
||||
// intermediate control frames between them.
|
||||
//
|
||||
// Note that Reader's methods are not goroutine safe.
|
||||
type Reader struct {
|
||||
Source io.Reader
|
||||
State ws.State
|
||||
|
||||
// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
|
||||
SkipHeaderCheck bool
|
||||
|
||||
// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
|
||||
// bytes are not valid UTF-8 sequence, ErrInvalidUTF8 returned.
|
||||
CheckUTF8 bool
|
||||
|
||||
// TODO(gobwas): add max frame size limit here.
|
||||
|
||||
OnContinuation FrameHandlerFunc
|
||||
OnIntermediate FrameHandlerFunc
|
||||
|
||||
opCode ws.OpCode // Used to store message op code on fragmentation.
|
||||
frame io.Reader // Used to as frame reader.
|
||||
raw io.LimitedReader // Used to discard frames without cipher.
|
||||
utf8 UTF8Reader // Used to check UTF8 sequences if CheckUTF8 is true.
|
||||
}
|
||||
|
||||
// NewReader creates new frame reader that reads from r keeping given state to
|
||||
// make some protocol validity checks when it needed.
|
||||
func NewReader(r io.Reader, s ws.State) *Reader {
|
||||
return &Reader{
|
||||
Source: r,
|
||||
State: s,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClientSideReader is a helper function that calls NewReader with r and
|
||||
// ws.StateClientSide.
|
||||
func NewClientSideReader(r io.Reader) *Reader {
|
||||
return NewReader(r, ws.StateClientSide)
|
||||
}
|
||||
|
||||
// NewServerSideReader is a helper function that calls NewReader with r and
|
||||
// ws.StateServerSide.
|
||||
func NewServerSideReader(r io.Reader) *Reader {
|
||||
return NewReader(r, ws.StateServerSide)
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It reads the next message payload into p.
|
||||
// It takes care on fragmented messages.
|
||||
//
|
||||
// The error is io.EOF only if all of message bytes were read.
|
||||
// If an io.EOF happens during reading some but not all the message bytes
|
||||
// Read() returns io.ErrUnexpectedEOF.
|
||||
//
|
||||
// The error is ErrNoFrameAdvance if no NextFrame() call was made before
|
||||
// reading next message bytes.
|
||||
func (r *Reader) Read(p []byte) (n int, err error) {
|
||||
if r.frame == nil {
|
||||
if !r.fragmented() {
|
||||
// Every new Read() must be preceded by NextFrame() call.
|
||||
return 0, ErrNoFrameAdvance
|
||||
}
|
||||
// Read next continuation or intermediate control frame.
|
||||
_, err := r.NextFrame()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if r.frame == nil {
|
||||
// We handled intermediate control and now got nothing to read.
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, err = r.frame.Read(p)
|
||||
if err != nil && err != io.EOF {
|
||||
return
|
||||
}
|
||||
if err == nil && r.raw.N != 0 {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.raw.N != 0:
|
||||
err = io.ErrUnexpectedEOF
|
||||
|
||||
case r.fragmented():
|
||||
err = nil
|
||||
r.resetFragment()
|
||||
|
||||
case r.CheckUTF8 && !r.utf8.Valid():
|
||||
n = r.utf8.Accepted()
|
||||
err = ErrInvalidUTF8
|
||||
|
||||
default:
|
||||
r.reset()
|
||||
err = io.EOF
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Discard discards current message unread bytes.
|
||||
// It discards all frames of fragmented message.
|
||||
func (r *Reader) Discard() (err error) {
|
||||
for {
|
||||
_, err = io.Copy(ioutil.Discard, &r.raw)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if !r.fragmented() {
|
||||
break
|
||||
}
|
||||
if _, err = r.NextFrame(); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
r.reset()
|
||||
return err
|
||||
}
|
||||
|
||||
// NextFrame prepares r to read next message. It returns received frame header
|
||||
// and non-nil error on failure.
|
||||
//
|
||||
// Note that next NextFrame() call must be done after receiving or discarding
|
||||
// all current message bytes.
|
||||
func (r *Reader) NextFrame() (hdr ws.Header, err error) {
|
||||
hdr, err = ws.ReadHeader(r.Source)
|
||||
if err == io.EOF && r.fragmented() {
|
||||
// If we are in fragmented state EOF means that is was totally
|
||||
// unexpected.
|
||||
//
|
||||
// NOTE: This is necessary to prevent callers such that
|
||||
// ioutil.ReadAll to receive some amount of bytes without an error.
|
||||
// ReadAll() ignores an io.EOF error, thus caller may think that
|
||||
// whole message fetched, but actually only part of it.
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
if err == nil && !r.SkipHeaderCheck {
|
||||
err = ws.CheckHeader(hdr, r.State)
|
||||
}
|
||||
if err != nil {
|
||||
return hdr, err
|
||||
}
|
||||
|
||||
// Save raw reader to use it on discarding frame without ciphering and
|
||||
// other streaming checks.
|
||||
r.raw = io.LimitedReader{r.Source, hdr.Length}
|
||||
|
||||
frame := io.Reader(&r.raw)
|
||||
if hdr.Masked {
|
||||
frame = NewCipherReader(frame, hdr.Mask)
|
||||
}
|
||||
if r.fragmented() {
|
||||
if hdr.OpCode.IsControl() {
|
||||
if cb := r.OnIntermediate; cb != nil {
|
||||
err = cb(hdr, frame)
|
||||
}
|
||||
if err == nil {
|
||||
// Ensure that src is empty.
|
||||
_, err = io.Copy(ioutil.Discard, &r.raw)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else {
|
||||
r.opCode = hdr.OpCode
|
||||
}
|
||||
if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
|
||||
r.utf8.Source = frame
|
||||
frame = &r.utf8
|
||||
}
|
||||
|
||||
// Save reader with ciphering and other streaming checks.
|
||||
r.frame = frame
|
||||
|
||||
if hdr.OpCode == ws.OpContinuation {
|
||||
if cb := r.OnContinuation; cb != nil {
|
||||
err = cb(hdr, frame)
|
||||
}
|
||||
}
|
||||
|
||||
if hdr.Fin {
|
||||
r.State = r.State.Clear(ws.StateFragmented)
|
||||
} else {
|
||||
r.State = r.State.Set(ws.StateFragmented)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Reader) fragmented() bool {
|
||||
return r.State.Fragmented()
|
||||
}
|
||||
|
||||
func (r *Reader) resetFragment() {
|
||||
r.raw = io.LimitedReader{}
|
||||
r.frame = nil
|
||||
// Reset source of the UTF8Reader, but not the state.
|
||||
r.utf8.Source = nil
|
||||
}
|
||||
|
||||
func (r *Reader) reset() {
|
||||
r.raw = io.LimitedReader{}
|
||||
r.frame = nil
|
||||
r.utf8 = UTF8Reader{}
|
||||
r.opCode = 0
|
||||
}
|
||||
|
||||
// NextReader prepares next message read from r. It returns header that
|
||||
// describes the message and io.Reader to read message's payload. It returns
|
||||
// non-nil error when it is not possible to read message's initial frame.
|
||||
//
|
||||
// Note that next NextReader() on the same r should be done after reading all
|
||||
// bytes from previously returned io.Reader. For more performant way to discard
|
||||
// message use Reader and its Discard() method.
|
||||
//
|
||||
// Note that it will not handle any "intermediate" frames, that possibly could
|
||||
// be received between text/binary continuation frames. That is, if peer sent
|
||||
// text/binary frame with fin flag "false", then it could send ping frame, and
|
||||
// eventually remaining part of text/binary frame with fin "true" – with
|
||||
// NextReader() the ping frame will be dropped without any notice. To handle
|
||||
// this rare, but possible situation (and if you do not know exactly which
|
||||
// frames peer could send), you could use Reader with OnIntermediate field set.
|
||||
func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
|
||||
rd := &Reader{
|
||||
Source: r,
|
||||
State: s,
|
||||
}
|
||||
header, err := rd.NextFrame()
|
||||
if err != nil {
|
||||
return header, nil, err
|
||||
}
|
||||
return header, rd, nil
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
// DebugUpgrader is a wrapper around ws.Upgrader. It tracks I/O of a
|
||||
// WebSocket handshake.
|
||||
//
|
||||
// Note that it must not be used in production applications that requires
|
||||
// Upgrade() to be efficient.
|
||||
type DebugUpgrader struct {
|
||||
// Upgrader contains upgrade to WebSocket options.
|
||||
Upgrader ws.Upgrader
|
||||
|
||||
// OnRequest and OnResponse are the callbacks that will be called with the
|
||||
// HTTP request and response respectively.
|
||||
OnRequest, OnResponse func([]byte)
|
||||
}
|
||||
|
||||
// Upgrade calls Upgrade() on underlying ws.Upgrader and tracks I/O on conn.
|
||||
func (d *DebugUpgrader) Upgrade(conn io.ReadWriter) (hs ws.Handshake, err error) {
|
||||
var (
|
||||
// Take the Reader and Writer parts from conn to be probably replaced
|
||||
// below.
|
||||
r io.Reader = conn
|
||||
w io.Writer = conn
|
||||
)
|
||||
if onRequest := d.OnRequest; onRequest != nil {
|
||||
var buf bytes.Buffer
|
||||
// First, we must read the entire request.
|
||||
req, err := http.ReadRequest(bufio.NewReader(
|
||||
io.TeeReader(conn, &buf),
|
||||
))
|
||||
if err == nil {
|
||||
// Fulfill the buffer with the response body.
|
||||
io.Copy(ioutil.Discard, req.Body)
|
||||
req.Body.Close()
|
||||
}
|
||||
onRequest(buf.Bytes())
|
||||
|
||||
r = io.MultiReader(
|
||||
&buf, conn,
|
||||
)
|
||||
}
|
||||
|
||||
if onResponse := d.OnResponse; onResponse != nil {
|
||||
var buf bytes.Buffer
|
||||
// Intercept the response stream written by the Upgrade().
|
||||
w = io.MultiWriter(
|
||||
conn, &buf,
|
||||
)
|
||||
defer func() {
|
||||
onResponse(buf.Bytes())
|
||||
}()
|
||||
}
|
||||
|
||||
return d.Upgrader.Upgrade(struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{r, w})
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrInvalidUTF8 is returned by UTF8 reader on invalid utf8 sequence.
|
||||
var ErrInvalidUTF8 = fmt.Errorf("invalid utf8")
|
||||
|
||||
// UTF8Reader implements io.Reader that calculates utf8 validity state after
|
||||
// every read byte from Source.
|
||||
//
|
||||
// Note that in some cases client must call r.Valid() after all bytes are read
|
||||
// to ensure that all of them are valid utf8 sequences. That is, some io helper
|
||||
// functions such io.ReadAtLeast or io.ReadFull could discard the error
|
||||
// information returned by the reader when they receive all of requested bytes.
|
||||
// For example, the last read sequence is invalid and UTF8Reader returns number
|
||||
// of bytes read and an error. But helper function decides to discard received
|
||||
// error due to all requested bytes are completely read from the source.
|
||||
//
|
||||
// Another possible case is when some valid sequence become split by the read
|
||||
// bound. Then UTF8Reader can not make decision about validity of the last
|
||||
// sequence cause it is not fully read yet. And if the read stops, Valid() will
|
||||
// return false, even if Read() by itself dit not.
|
||||
type UTF8Reader struct {
|
||||
Source io.Reader
|
||||
|
||||
accepted int
|
||||
|
||||
state uint32
|
||||
codep uint32
|
||||
}
|
||||
|
||||
// NewUTF8Reader creates utf8 reader that reads from r.
|
||||
func NewUTF8Reader(r io.Reader) *UTF8Reader {
|
||||
return &UTF8Reader{
|
||||
Source: r,
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets utf8 reader to read from r.
|
||||
func (u *UTF8Reader) Reset(r io.Reader) {
|
||||
u.Source = r
|
||||
u.state = 0
|
||||
u.codep = 0
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (u *UTF8Reader) Read(p []byte) (n int, err error) {
|
||||
n, err = u.Source.Read(p)
|
||||
|
||||
accepted := 0
|
||||
s, c := u.state, u.codep
|
||||
for i := 0; i < n; i++ {
|
||||
c, s = decode(s, c, p[i])
|
||||
if s == utf8Reject {
|
||||
u.state = s
|
||||
return accepted, ErrInvalidUTF8
|
||||
}
|
||||
if s == utf8Accept {
|
||||
accepted = i + 1
|
||||
}
|
||||
}
|
||||
u.state, u.codep = s, c
|
||||
u.accepted = accepted
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Valid checks current reader state. It returns true if all read bytes are
|
||||
// valid UTF-8 sequences, and false if not.
|
||||
func (u *UTF8Reader) Valid() bool {
|
||||
return u.state == utf8Accept
|
||||
}
|
||||
|
||||
// Accepted returns number of valid bytes in last Read().
|
||||
func (u *UTF8Reader) Accepted() int {
|
||||
return u.accepted
|
||||
}
|
||||
|
||||
// Below is port of UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
|
||||
//
|
||||
// Copyright (c) 2008-2009 Bjoern Hoehrmann <bjoern@hoehrmann.de>
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to
|
||||
// deal in the Software without restriction, including without limitation the
|
||||
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||
// sell copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
// IN THE SOFTWARE.
|
||||
|
||||
const (
|
||||
utf8Accept = 0
|
||||
utf8Reject = 12
|
||||
)
|
||||
|
||||
var utf8d = [...]byte{
|
||||
// The first part of the table maps bytes to character classes that
|
||||
// to reduce the size of the transition table and create bitmasks.
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
|
||||
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
||||
8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||
10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
|
||||
|
||||
// The second part is a transition table that maps a combination
|
||||
// of a state of the automaton and a character class to a state.
|
||||
0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
|
||||
12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12,
|
||||
12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12,
|
||||
12, 12, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12,
|
||||
12, 36, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
|
||||
}
|
||||
|
||||
func decode(state, codep uint32, b byte) (uint32, uint32) {
|
||||
t := uint32(utf8d[b])
|
||||
|
||||
if state != utf8Accept {
|
||||
codep = (uint32(b) & 0x3f) | (codep << 6)
|
||||
} else {
|
||||
codep = (0xff >> t) & uint32(b)
|
||||
}
|
||||
|
||||
return codep, uint32(utf8d[256+state+t])
|
||||
}
|
|
@ -0,0 +1,450 @@
|
|||
package wsutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/gobwas/pool"
|
||||
"github.com/gobwas/pool/pbytes"
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
// DefaultWriteBuffer contains size of Writer's default buffer. It used by
|
||||
// Writer constructor functions.
|
||||
var DefaultWriteBuffer = 4096
|
||||
|
||||
var (
|
||||
// ErrNotEmpty is returned by Writer.WriteThrough() to indicate that buffer is
|
||||
// not empty and write through could not be done. That is, caller should call
|
||||
// Writer.FlushFragment() to make buffer empty.
|
||||
ErrNotEmpty = fmt.Errorf("writer not empty")
|
||||
|
||||
// ErrControlOverflow is returned by ControlWriter.Write() to indicate that
|
||||
// no more data could be written to the underlying io.Writer because
|
||||
// MaxControlFramePayloadSize limit is reached.
|
||||
ErrControlOverflow = fmt.Errorf("control frame payload overflow")
|
||||
)
|
||||
|
||||
// Constants which are represent frame length ranges.
|
||||
const (
|
||||
len7 = int64(125) // 126 and 127 are reserved values
|
||||
len16 = int64(^uint16(0))
|
||||
len64 = int64((^uint64(0)) >> 1)
|
||||
)
|
||||
|
||||
// ControlWriter is a wrapper around Writer that contains some guards for
|
||||
// buffered writes of control frames.
|
||||
type ControlWriter struct {
|
||||
w *Writer
|
||||
limit int
|
||||
n int
|
||||
}
|
||||
|
||||
// NewControlWriter contains ControlWriter with Writer inside whose buffer size
|
||||
// is at most ws.MaxControlFramePayloadSize + ws.MaxHeaderSize.
|
||||
func NewControlWriter(dest io.Writer, state ws.State, op ws.OpCode) *ControlWriter {
|
||||
return &ControlWriter{
|
||||
w: NewWriterSize(dest, state, op, ws.MaxControlFramePayloadSize),
|
||||
limit: ws.MaxControlFramePayloadSize,
|
||||
}
|
||||
}
|
||||
|
||||
// NewControlWriterBuffer returns a new ControlWriter with buf as a buffer.
|
||||
//
|
||||
// Note that it reserves x bytes of buf for header data, where x could be
|
||||
// ws.MinHeaderSize or ws.MinHeaderSize+4 (depending on state). At most
|
||||
// (ws.MaxControlFramePayloadSize + x) bytes of buf will be used.
|
||||
//
|
||||
// It panics if len(buf) <= ws.MinHeaderSize + x.
|
||||
func NewControlWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *ControlWriter {
|
||||
max := ws.MaxControlFramePayloadSize + headerSize(state, ws.MaxControlFramePayloadSize)
|
||||
if len(buf) > max {
|
||||
buf = buf[:max]
|
||||
}
|
||||
|
||||
w := NewWriterBuffer(dest, state, op, buf)
|
||||
|
||||
return &ControlWriter{
|
||||
w: w,
|
||||
limit: len(w.buf),
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer. It writes to the underlying Writer until it
|
||||
// returns error or until ControlWriter write limit will be exceeded.
|
||||
func (c *ControlWriter) Write(p []byte) (n int, err error) {
|
||||
if c.n+len(p) > c.limit {
|
||||
return 0, ErrControlOverflow
|
||||
}
|
||||
return c.w.Write(p)
|
||||
}
|
||||
|
||||
// Flush flushes all buffered data to the underlying io.Writer.
|
||||
func (c *ControlWriter) Flush() error {
|
||||
return c.w.Flush()
|
||||
}
|
||||
|
||||
// Writer contains logic of buffering output data into a WebSocket fragments.
|
||||
// It is much the same as bufio.Writer, except the thing that it works with
|
||||
// WebSocket frames, not the raw data.
|
||||
//
|
||||
// Writer writes frames with specified OpCode.
|
||||
// It uses ws.State to decide whether the output frames must be masked.
|
||||
//
|
||||
// Note that it does not check control frame size or other RFC rules.
|
||||
// That is, it must be used with special care to write control frames without
|
||||
// violation of RFC. You could use ControlWriter that wraps Writer and contains
|
||||
// some guards for writing control frames.
|
||||
//
|
||||
// If an error occurs writing to a Writer, no more data will be accepted and
|
||||
// all subsequent writes will return the error.
|
||||
// After all data has been written, the client should call the Flush() method
|
||||
// to guarantee all data has been forwarded to the underlying io.Writer.
|
||||
type Writer struct {
|
||||
dest io.Writer
|
||||
|
||||
n int // Buffered bytes counter.
|
||||
raw []byte // Raw representation of buffer, including reserved header bytes.
|
||||
buf []byte // Writeable part of buffer, without reserved header bytes.
|
||||
|
||||
op ws.OpCode
|
||||
state ws.State
|
||||
|
||||
dirty bool
|
||||
fragmented bool
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
var writers = pool.New(128, 65536)
|
||||
|
||||
// GetWriter tries to reuse Writer getting it from the pool.
|
||||
//
|
||||
// This function is intended for memory consumption optimizations, because
|
||||
// NewWriter*() functions make allocations for inner buffer.
|
||||
//
|
||||
// Note the it ceils n to the power of two.
|
||||
//
|
||||
// If you have your own bytes buffer pool you could use NewWriterBuffer to use
|
||||
// pooled bytes in writer.
|
||||
func GetWriter(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
|
||||
x, m := writers.Get(n)
|
||||
if x != nil {
|
||||
w := x.(*Writer)
|
||||
w.Reset(dest, state, op)
|
||||
return w
|
||||
}
|
||||
// NOTE: we use m instead of n, because m is an attempt to reuse w of such
|
||||
// size in the future.
|
||||
return NewWriterBufferSize(dest, state, op, m)
|
||||
}
|
||||
|
||||
// PutWriter puts w for future reuse by GetWriter().
|
||||
func PutWriter(w *Writer) {
|
||||
w.Reset(nil, 0, 0)
|
||||
writers.Put(w, w.Size())
|
||||
}
|
||||
|
||||
// NewWriter returns a new Writer whose buffer has the DefaultWriteBuffer size.
|
||||
func NewWriter(dest io.Writer, state ws.State, op ws.OpCode) *Writer {
|
||||
return NewWriterBufferSize(dest, state, op, 0)
|
||||
}
|
||||
|
||||
// NewWriterSize returns a new Writer whose buffer size is at most n + ws.MaxHeaderSize.
|
||||
// That is, output frames payload length could be up to n, except the case when
|
||||
// Write() is called on empty Writer with len(p) > n.
|
||||
//
|
||||
// If n <= 0 then the default buffer size is used as Writer's buffer size.
|
||||
func NewWriterSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
|
||||
if n > 0 {
|
||||
n += headerSize(state, n)
|
||||
}
|
||||
return NewWriterBufferSize(dest, state, op, n)
|
||||
}
|
||||
|
||||
// NewWriterBufferSize returns a new Writer whose buffer size is equal to n.
|
||||
// If n <= ws.MinHeaderSize then the default buffer size is used.
|
||||
//
|
||||
// Note that Writer will reserve x bytes for header data, where x is in range
|
||||
// [ws.MinHeaderSize,ws.MaxHeaderSize]. That is, frames flushed by Writer
|
||||
// will not have payload length equal to n, except the case when Write() is
|
||||
// called on empty Writer with len(p) > n.
|
||||
func NewWriterBufferSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
|
||||
if n <= ws.MinHeaderSize {
|
||||
n = DefaultWriteBuffer
|
||||
}
|
||||
return NewWriterBuffer(dest, state, op, make([]byte, n))
|
||||
}
|
||||
|
||||
// NewWriterBuffer returns a new Writer with buf as a buffer.
|
||||
//
|
||||
// Note that it reserves x bytes of buf for header data, where x is in range
|
||||
// [ws.MinHeaderSize,ws.MaxHeaderSize] (depending on state and buf size).
|
||||
//
|
||||
// You could use ws.HeaderSize() to calculate number of bytes needed to store
|
||||
// header data.
|
||||
//
|
||||
// It panics if len(buf) is too small to fit header and payload data.
|
||||
func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *Writer {
|
||||
offset := reserve(state, len(buf))
|
||||
if len(buf) <= offset {
|
||||
panic("buffer too small")
|
||||
}
|
||||
|
||||
return &Writer{
|
||||
dest: dest,
|
||||
raw: buf,
|
||||
buf: buf[offset:],
|
||||
state: state,
|
||||
op: op,
|
||||
}
|
||||
}
|
||||
|
||||
func reserve(state ws.State, n int) (offset int) {
|
||||
var mask int
|
||||
if state.ClientSide() {
|
||||
mask = 4
|
||||
}
|
||||
|
||||
switch {
|
||||
case n <= int(len7)+mask+2:
|
||||
return mask + 2
|
||||
case n <= int(len16)+mask+4:
|
||||
return mask + 4
|
||||
default:
|
||||
return mask + 10
|
||||
}
|
||||
}
|
||||
|
||||
// headerSize returns number of bytes needed to encode header of a frame with
|
||||
// given state and length.
|
||||
func headerSize(s ws.State, n int) int {
|
||||
return ws.HeaderSize(ws.Header{
|
||||
Length: int64(n),
|
||||
Masked: s.ClientSide(),
|
||||
})
|
||||
}
|
||||
|
||||
// Reset discards any buffered data, clears error, and resets w to have given
|
||||
// state and write frames with given OpCode to dest.
|
||||
func (w *Writer) Reset(dest io.Writer, state ws.State, op ws.OpCode) {
|
||||
w.n = 0
|
||||
w.dirty = false
|
||||
w.fragmented = false
|
||||
w.dest = dest
|
||||
w.state = state
|
||||
w.op = op
|
||||
}
|
||||
|
||||
// Size returns the size of the underlying buffer in bytes.
|
||||
func (w *Writer) Size() int {
|
||||
return len(w.buf)
|
||||
}
|
||||
|
||||
// Available returns how many bytes are unused in the buffer.
|
||||
func (w *Writer) Available() int {
|
||||
return len(w.buf) - w.n
|
||||
}
|
||||
|
||||
// Buffered returns the number of bytes that have been written into the current
|
||||
// buffer.
|
||||
func (w *Writer) Buffered() int {
|
||||
return w.n
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
//
|
||||
// Note that even if the Writer was created to have N-sized buffer, Write()
|
||||
// with payload of N bytes will not fit into that buffer. Writer reserves some
|
||||
// space to fit WebSocket header data.
|
||||
func (w *Writer) Write(p []byte) (n int, err error) {
|
||||
// Even empty p may make a sense.
|
||||
w.dirty = true
|
||||
|
||||
var nn int
|
||||
for len(p) > w.Available() && w.err == nil {
|
||||
if w.Buffered() == 0 {
|
||||
// Large write, empty buffer. Write directly from p to avoid copy.
|
||||
// Trade off here is that we make additional Write() to underlying
|
||||
// io.Writer when writing frame header.
|
||||
//
|
||||
// On large buffers additional write is better than copying.
|
||||
nn, _ = w.WriteThrough(p)
|
||||
} else {
|
||||
nn = copy(w.buf[w.n:], p)
|
||||
w.n += nn
|
||||
w.FlushFragment()
|
||||
}
|
||||
n += nn
|
||||
p = p[nn:]
|
||||
}
|
||||
if w.err != nil {
|
||||
return n, w.err
|
||||
}
|
||||
nn = copy(w.buf[w.n:], p)
|
||||
w.n += nn
|
||||
n += nn
|
||||
|
||||
// Even if w.Available() == 0 we will not flush buffer preventively because
|
||||
// this could bring unwanted fragmentation. That is, user could create
|
||||
// buffer with size that fits exactly all further Write() call, and then
|
||||
// call Flush(), excepting that single and not fragmented frame will be
|
||||
// sent. With preemptive flush this case will produce two frames – last one
|
||||
// will be empty and just to set fin = true.
|
||||
|
||||
return n, w.err
|
||||
}
|
||||
|
||||
// WriteThrough writes data bypassing the buffer.
|
||||
// Note that Writer's buffer must be empty before calling WriteThrough().
|
||||
func (w *Writer) WriteThrough(p []byte) (n int, err error) {
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
if w.Buffered() != 0 {
|
||||
return 0, ErrNotEmpty
|
||||
}
|
||||
|
||||
w.err = writeFrame(w.dest, w.state, w.opCode(), false, p)
|
||||
if w.err == nil {
|
||||
n = len(p)
|
||||
}
|
||||
|
||||
w.dirty = true
|
||||
w.fragmented = true
|
||||
|
||||
return n, w.err
|
||||
}
|
||||
|
||||
// ReadFrom implements io.ReaderFrom.
|
||||
func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) {
|
||||
var nn int
|
||||
for err == nil {
|
||||
if w.Available() == 0 {
|
||||
err = w.FlushFragment()
|
||||
continue
|
||||
}
|
||||
|
||||
// We copy the behavior of bufio.Writer here.
|
||||
// Also, from the docs on io.ReaderFrom:
|
||||
// ReadFrom reads data from r until EOF or error.
|
||||
//
|
||||
// See https://codereview.appspot.com/76400048/#ps1
|
||||
const maxEmptyReads = 100
|
||||
var nr int
|
||||
for nr < maxEmptyReads {
|
||||
nn, err = src.Read(w.buf[w.n:])
|
||||
if nn != 0 || err != nil {
|
||||
break
|
||||
}
|
||||
nr++
|
||||
}
|
||||
if nr == maxEmptyReads {
|
||||
return n, io.ErrNoProgress
|
||||
}
|
||||
|
||||
w.n += nn
|
||||
n += int64(nn)
|
||||
}
|
||||
if err == io.EOF {
|
||||
// NOTE: Do not flush preemptively.
|
||||
// See the Write() sources for more info.
|
||||
err = nil
|
||||
w.dirty = true
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Flush writes any buffered data to the underlying io.Writer.
|
||||
// It sends the frame with "fin" flag set to true.
|
||||
//
|
||||
// If no Write() or ReadFrom() was made, then Flush() does nothing.
|
||||
func (w *Writer) Flush() error {
|
||||
if (!w.dirty && w.Buffered() == 0) || w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
|
||||
w.err = w.flushFragment(true)
|
||||
w.n = 0
|
||||
w.dirty = false
|
||||
w.fragmented = false
|
||||
|
||||
return w.err
|
||||
}
|
||||
|
||||
// FlushFragment writes any buffered data to the underlying io.Writer.
|
||||
// It sends the frame with "fin" flag set to false.
|
||||
func (w *Writer) FlushFragment() error {
|
||||
if w.Buffered() == 0 || w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
|
||||
w.err = w.flushFragment(false)
|
||||
w.n = 0
|
||||
w.fragmented = true
|
||||
|
||||
return w.err
|
||||
}
|
||||
|
||||
func (w *Writer) flushFragment(fin bool) error {
|
||||
frame := ws.NewFrame(w.opCode(), fin, w.buf[:w.n])
|
||||
if w.state.ClientSide() {
|
||||
frame = ws.MaskFrameInPlace(frame)
|
||||
}
|
||||
|
||||
// Write header to the header segment of the raw buffer.
|
||||
head := len(w.raw) - len(w.buf)
|
||||
offset := head - ws.HeaderSize(frame.Header)
|
||||
buf := bytesWriter{
|
||||
buf: w.raw[offset:head],
|
||||
}
|
||||
if err := ws.WriteHeader(&buf, frame.Header); err != nil {
|
||||
// Must never be reached.
|
||||
panic("dump header error: " + err.Error())
|
||||
}
|
||||
|
||||
_, err := w.dest.Write(w.raw[offset : head+w.n])
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Writer) opCode() ws.OpCode {
|
||||
if w.fragmented {
|
||||
return ws.OpContinuation
|
||||
}
|
||||
return w.op
|
||||
}
|
||||
|
||||
var errNoSpace = fmt.Errorf("not enough buffer space")
|
||||
|
||||
type bytesWriter struct {
|
||||
buf []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (w *bytesWriter) Write(p []byte) (int, error) {
|
||||
n := copy(w.buf[w.pos:], p)
|
||||
w.pos += n
|
||||
if n != len(p) {
|
||||
return n, errNoSpace
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error {
|
||||
var frame ws.Frame
|
||||
if s.ClientSide() {
|
||||
// Should copy bytes to prevent corruption of caller data.
|
||||
payload := pbytes.GetLen(len(p))
|
||||
defer pbytes.Put(payload)
|
||||
|
||||
copy(payload, p)
|
||||
|
||||
frame = ws.NewFrame(op, fin, payload)
|
||||
frame = ws.MaskFrameInPlace(frame)
|
||||
} else {
|
||||
frame = ws.NewFrame(op, fin, p)
|
||||
}
|
||||
|
||||
return ws.WriteFrame(w, frame)
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
Package wsutil provides utilities for working with WebSocket protocol.
|
||||
|
||||
Overview:
|
||||
|
||||
// Read masked text message from peer and check utf8 encoding.
|
||||
header, err := ws.ReadHeader(conn)
|
||||
if err != nil {
|
||||
// handle err
|
||||
}
|
||||
|
||||
// Prepare to read payload.
|
||||
r := io.LimitReader(conn, header.Length)
|
||||
r = wsutil.NewCipherReader(r, header.Mask)
|
||||
r = wsutil.NewUTF8Reader(r)
|
||||
|
||||
payload, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
// handle err
|
||||
}
|
||||
|
||||
You could get the same behavior using just `wsutil.Reader`:
|
||||
|
||||
r := wsutil.Reader{
|
||||
Source: conn,
|
||||
CheckUTF8: true,
|
||||
}
|
||||
|
||||
payload, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
// handle err
|
||||
}
|
||||
|
||||
Or even simplest:
|
||||
|
||||
payload, err := wsutil.ReadClientText(conn)
|
||||
if err != nil {
|
||||
// handle err
|
||||
}
|
||||
|
||||
Package is also exports tools for buffered writing:
|
||||
|
||||
// Create buffered writer, that will buffer output bytes and send them as
|
||||
// 128-length fragments (with exception on large writes, see the doc).
|
||||
writer := wsutil.NewWriterSize(conn, ws.StateServerSide, ws.OpText, 128)
|
||||
|
||||
_, err := io.CopyN(writer, rand.Reader, 100)
|
||||
if err == nil {
|
||||
err = writer.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
|
||||
For more utils and helpers see the documentation.
|
||||
*/
|
||||
package wsutil
|
|
@ -207,6 +207,19 @@ github.com/gliderlabs/ssh
|
|||
# github.com/go-sql-driver/mysql v1.5.0
|
||||
## explicit
|
||||
github.com/go-sql-driver/mysql
|
||||
# github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58
|
||||
## explicit
|
||||
github.com/gobwas/httphead
|
||||
# github.com/gobwas/pool v0.2.1
|
||||
## explicit
|
||||
github.com/gobwas/pool
|
||||
github.com/gobwas/pool/internal/pmath
|
||||
github.com/gobwas/pool/pbufio
|
||||
github.com/gobwas/pool/pbytes
|
||||
# github.com/gobwas/ws v1.0.4
|
||||
## explicit
|
||||
github.com/gobwas/ws
|
||||
github.com/gobwas/ws/wsutil
|
||||
# github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
|
||||
## explicit
|
||||
github.com/golang-collections/collections/queue
|
||||
|
|
Loading…
Reference in New Issue