TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).
This commit is contained in:
parent
b25d38dd72
commit
3ad99b241c
19
CHANGES.md
19
CHANGES.md
|
@ -1,5 +1,24 @@
|
|||
**Experimental**: This is a new format for release notes. The format and availability is subject to change.
|
||||
|
||||
## UNRELEASED
|
||||
|
||||
### Backward Incompatible Changes
|
||||
|
||||
- none
|
||||
|
||||
### New Features
|
||||
|
||||
- none
|
||||
|
||||
### Improvements
|
||||
|
||||
- none
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- Fixed proxying of websocket requests to avoid possibility of losing initial frames that were sent in the same TCP
|
||||
packet as response headers [#345](https://github.com/cloudflare/cloudflared/issues/345).
|
||||
|
||||
## 2021.3.6
|
||||
|
||||
### Bug Fixes
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -60,7 +61,7 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso
|
|||
TLSClientConfig: options.TLSClientConfig,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
wsConn, resp, err := cfwebsocket.ClientConnect(req, dialer)
|
||||
wsConn, resp, err := clientConnect(req, dialer)
|
||||
defer closeRespBody(resp)
|
||||
|
||||
if err != nil && IsAccessResponse(resp) {
|
||||
|
@ -87,6 +88,63 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso
|
|||
return &cfwebsocket.GorillaConn{Conn: wsConn}, nil
|
||||
}
|
||||
|
||||
var stripWebsocketHeaders = []string{
|
||||
"Upgrade",
|
||||
"Connection",
|
||||
"Sec-Websocket-Key",
|
||||
"Sec-Websocket-Version",
|
||||
"Sec-Websocket-Extensions",
|
||||
}
|
||||
|
||||
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
||||
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
|
||||
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
|
||||
func websocketHeaders(req *http.Request) http.Header {
|
||||
wsHeaders := make(http.Header)
|
||||
for key, val := range req.Header {
|
||||
wsHeaders[key] = val
|
||||
}
|
||||
// Assume the header keys are in canonical format.
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
wsHeaders.Del(header)
|
||||
}
|
||||
wsHeaders.Set("Host", req.Host) // See TUN-1097
|
||||
return wsHeaders
|
||||
}
|
||||
|
||||
// clientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
||||
// the connection. The response body may not contain the entire response and does
|
||||
// not need to be closed by the application.
|
||||
func clientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
|
||||
req.URL.Scheme = changeRequestScheme(req.URL)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
if dialler == nil {
|
||||
dialler = &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
}
|
||||
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
|
||||
if err != nil {
|
||||
return nil, response, err
|
||||
}
|
||||
return conn, response, nil
|
||||
}
|
||||
|
||||
// changeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
|
||||
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
|
||||
func changeRequestScheme(reqURL *url.URL) string {
|
||||
switch reqURL.Scheme {
|
||||
case "https":
|
||||
return "wss"
|
||||
case "http":
|
||||
return "ws"
|
||||
case "":
|
||||
return "ws"
|
||||
default:
|
||||
return reqURL.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
// createAccessAuthenticatedStream will try load a token from storage and make
|
||||
// a connection with the token set on the request. If it still get redirect,
|
||||
// this probably means the token in storage is invalid (expired/revoked). If that
|
||||
|
@ -126,7 +184,7 @@ func createAccessWebSocketStream(options *StartOptions, log *zerolog.Logger) (*w
|
|||
dump, err := httputil.DumpRequest(req, false)
|
||||
log.Debug().Msgf("Access Websocket request: %s", string(dump))
|
||||
|
||||
conn, resp, err := cfwebsocket.ClientConnect(req, nil)
|
||||
conn, resp, err := clientConnect(req, nil)
|
||||
|
||||
if resp != nil {
|
||||
r, err := httputil.DumpResponse(resp, true)
|
||||
|
|
|
@ -0,0 +1,123 @@
|
|||
package carrier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
cfwebsocket "github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
func websocketClientTLSConfig(t *testing.T) *tls.Config {
|
||||
certPool := x509.NewCertPool()
|
||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||
assert.NoError(t, err)
|
||||
certPool.AddCert(helloCert)
|
||||
assert.NotNil(t, certPool)
|
||||
return &tls.Config{RootCAs: certPool}
|
||||
}
|
||||
|
||||
func TestWebsocketHeaders(t *testing.T) {
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
assert.Empty(t, wsHeaders[header])
|
||||
}
|
||||
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
listener, err := hello.CreateTLSListener("localhost:1111")
|
||||
assert.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
go func() {
|
||||
errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
|
||||
}()
|
||||
|
||||
req := testRequest(t, "https://localhost:1111/ws", nil)
|
||||
|
||||
tlsConfig := websocketClientTLSConfig(t)
|
||||
assert.NotNil(t, tlsConfig)
|
||||
d := gws.Dialer{TLSClientConfig: tlsConfig}
|
||||
conn, resp, err := clientConnect(req, &d)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
messageSize := rand.Int()%2048 + 1
|
||||
clientMessage := make([]byte, messageSize)
|
||||
// rand.Read always returns len(clientMessage) and a nil error
|
||||
rand.Read(clientMessage)
|
||||
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
messageType, message, err := conn.ReadMessage()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, websocket.BinaryFrame, messageType)
|
||||
assert.Equal(t, clientMessage, message)
|
||||
}
|
||||
|
||||
_ = conn.Close()
|
||||
close(shutdownC)
|
||||
<-errC
|
||||
}
|
||||
|
||||
func TestWebsocketWrapper(t *testing.T) {
|
||||
listener, err := hello.CreateTLSListener("localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
serverErrorChan := make(chan error)
|
||||
helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
|
||||
defer func() { <-serverErrorChan }()
|
||||
defer cancelHelloSvr()
|
||||
go func() {
|
||||
log := zerolog.Nop()
|
||||
serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
|
||||
}()
|
||||
|
||||
tlsConfig := websocketClientTLSConfig(t)
|
||||
d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
|
||||
testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
|
||||
req := testRequest(t, testAddr, nil)
|
||||
conn, resp, err := clientConnect(req, &d)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
|
||||
|
||||
// Websocket now connected to test server so lets check our wrapper
|
||||
wrapper := cfwebsocket.GorillaConn{Conn: conn}
|
||||
buf := make([]byte, 100)
|
||||
wrapper.Write([]byte("abc"))
|
||||
n, err := wrapper.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 3)
|
||||
require.Equal(t, "abc", string(buf[:n]))
|
||||
|
||||
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
|
||||
wrapper.Write([]byte("abc"))
|
||||
buf = buf[:1]
|
||||
n, err = wrapper.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 1)
|
||||
require.Equal(t, "a", string(buf[:n]))
|
||||
buf = buf[:cap(buf)]
|
||||
n, err = wrapper.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 2)
|
||||
require.Equal(t, "bc", string(buf[:n]))
|
||||
}
|
|
@ -2,12 +2,9 @@ package ingress
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/ipaccess"
|
||||
|
@ -58,35 +55,6 @@ func (wc *tcpOverWSConnection) Close() {
|
|||
wc.conn.Close()
|
||||
}
|
||||
|
||||
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
|
||||
type wsConnection struct {
|
||||
wsConn *gws.Conn
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Close() {
|
||||
wsc.resp.Body.Close()
|
||||
wsc.wsConn.Close()
|
||||
}
|
||||
|
||||
func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
TLSClientConfig: clientTLSConfig,
|
||||
}
|
||||
wsConn, resp, err := websocket.ClientConnect(r, d)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &wsConnection{
|
||||
wsConn,
|
||||
resp,
|
||||
}, resp, nil
|
||||
}
|
||||
|
||||
// socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS.
|
||||
// The connection to the origin happens inside the SOCKS code as the client specifies the origin
|
||||
// details in the packet.
|
||||
|
@ -100,3 +68,16 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
|
|||
|
||||
func (sp *socksProxyOverWSConnection) Close() {
|
||||
}
|
||||
|
||||
// wsProxyConnection represents a bidirectional stream for a websocket connection to the origin
|
||||
type wsProxyConnection struct {
|
||||
rwc io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (conn *wsProxyConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
websocket.Stream(tunnelConn, conn.rwc, log)
|
||||
}
|
||||
|
||||
func (conn *wsProxyConnection) Close() {
|
||||
conn.rwc.Close()
|
||||
}
|
||||
|
|
|
@ -3,13 +3,13 @@ package ingress
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -193,18 +193,26 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
|||
func TestStreamWSConnection(t *testing.T) {
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
|
||||
origin := echoWSOrigin(t)
|
||||
origin := echoWSOrigin(t, true)
|
||||
defer origin.Close()
|
||||
|
||||
var svc httpService
|
||||
err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{
|
||||
NoTLSVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
|
||||
conn, resp, err := svc.newWebsocketProxyConnection(req)
|
||||
|
||||
clientTLSConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
|
||||
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
|
||||
|
@ -213,13 +221,37 @@ func TestStreamWSConnection(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
connClosed := make(chan struct{})
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
select {
|
||||
case <-connClosed:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
eyeballConn.Close()
|
||||
edgeConn.Close()
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
return ctx.Err()
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
echoWSEyeball(t, eyeballConn)
|
||||
fmt.Println("closing pipe")
|
||||
edgeConn.Close()
|
||||
return eyeballConn.Close()
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
defer conn.Close()
|
||||
conn.Stream(ctx, edgeConn, testLogger)
|
||||
close(connClosed)
|
||||
return nil
|
||||
})
|
||||
|
||||
wsConn.Stream(ctx, edgeConn, testLogger)
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
|
@ -241,17 +273,23 @@ func (wse *wsEyeball) Write(p []byte) (int, error) {
|
|||
}
|
||||
|
||||
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
||||
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
|
||||
defer func() {
|
||||
assert.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
readMsg, err := wsutil.ReadServerBinary(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, testResponse, readMsg)
|
||||
|
||||
require.NoError(t, conn.Close())
|
||||
if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) {
|
||||
return
|
||||
}
|
||||
|
||||
func echoWSOrigin(t *testing.T) *httptest.Server {
|
||||
readMsg, err := wsutil.ReadServerBinary(conn)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, testResponse, readMsg)
|
||||
}
|
||||
|
||||
func echoWSOrigin(t *testing.T, expectMessages bool) *httptest.Server {
|
||||
var upgrader = gorillaWS.Upgrader{
|
||||
ReadBufferSize: 10,
|
||||
WriteBufferSize: 10,
|
||||
|
@ -268,12 +306,17 @@ func echoWSOrigin(t *testing.T) *httptest.Server {
|
|||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
sawMessage := false
|
||||
for {
|
||||
messageType, p, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if expectMessages && !sawMessage {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
require.Equal(t, testMessage, p)
|
||||
assert.Equal(t, testMessage, p)
|
||||
sawMessage = true
|
||||
if err := conn.WriteMessage(messageType, testResponse); err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -2,8 +2,10 @@ package ingress
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
|
@ -13,6 +15,7 @@ import (
|
|||
|
||||
var (
|
||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||
errUnsupportedConnectionType = errors.New("internal error: unsupported connection type")
|
||||
)
|
||||
|
||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||
|
@ -42,26 +45,64 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
}
|
||||
|
||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
||||
req.URL.Scheme = o.url.Scheme
|
||||
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
|
||||
switch req.URL.Scheme {
|
||||
case "ws":
|
||||
req.URL.Scheme = "http"
|
||||
case "wss":
|
||||
req.URL.Scheme = "https"
|
||||
}
|
||||
|
||||
if o.hostHeader != "" {
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
}
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
|
||||
return o.newWebsocketProxyConnection(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the Hello World server.
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "https"
|
||||
return o.transport.RoundTrip(req)
|
||||
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||
|
||||
req.ContentLength = 0
|
||||
req.Body = nil
|
||||
|
||||
resp, err := o.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "wss"
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
toClose := resp.Body
|
||||
defer func() {
|
||||
if toClose != nil {
|
||||
_ = toClose.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
|
||||
}
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
|
||||
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
|
||||
}
|
||||
|
||||
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
return nil, nil, errUnsupportedConnectionType
|
||||
}
|
||||
conn := wsProxyConnection{
|
||||
rwc: rwc,
|
||||
}
|
||||
// clear to prevent defer from closing
|
||||
toClose = nil
|
||||
|
||||
return &conn, resp, nil
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
|
|
|
@ -33,7 +33,7 @@ func assertEstablishConnectionResponse(t *testing.T,
|
|||
}
|
||||
|
||||
func TestHTTPServiceEstablishConnection(t *testing.T) {
|
||||
origin := echoWSOrigin(t)
|
||||
origin := echoWSOrigin(t, false)
|
||||
defer origin.Close()
|
||||
originURL, err := url.Parse(origin.URL)
|
||||
require.NoError(t, err)
|
||||
|
@ -71,11 +71,11 @@ func TestHelloWorldEstablishConnection(t *testing.T) {
|
|||
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
|
||||
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
// Accept key when Sec-Websocket-Key is not specified
|
||||
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
}
|
||||
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
|
@ -19,7 +18,6 @@ import (
|
|||
"github.com/cloudflare/cloudflared/ipaccess"
|
||||
"github.com/cloudflare/cloudflared/socks"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
// originService is something a tunnel can proxy traffic to.
|
||||
|
@ -50,16 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
|
|||
return nil
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
NetDial: o.transport.Dial,
|
||||
NetDialContext: o.transport.DialContext,
|
||||
TLSClientConfig: o.transport.TLSClientConfig,
|
||||
}
|
||||
reqURL.Scheme = websocket.ChangeRequestScheme(reqURL)
|
||||
return d.Dial(reqURL.String(), headers)
|
||||
}
|
||||
|
||||
type httpService struct {
|
||||
url *url.URL
|
||||
hostHeader string
|
||||
|
@ -171,8 +159,8 @@ func (o *socksProxyOverWSService) String() string {
|
|||
// HelloWorld is an OriginService for the built-in Hello World server.
|
||||
// Users only use this for testing and experimenting with cloudflared.
|
||||
type helloWorld struct {
|
||||
httpService
|
||||
server net.Listener
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (o *helloWorld) String() string {
|
||||
|
@ -187,11 +175,10 @@ func (o *helloWorld) start(
|
|||
errC chan error,
|
||||
cfg OriginRequestConfig,
|
||||
) error {
|
||||
transport, err := newHTTPTransport(o, cfg, log)
|
||||
if err != nil {
|
||||
if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
o.transport = transport
|
||||
|
||||
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Cannot start Hello World Server")
|
||||
|
@ -202,6 +189,12 @@ func (o *helloWorld) start(
|
|||
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
|
||||
}()
|
||||
o.server = helloListener
|
||||
|
||||
o.httpService.url = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: o.server.Addr().String(),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
|||
lbProbe: lbProbe,
|
||||
rule: ingress.ServiceWarpRouting,
|
||||
}
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil {
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
|
||||
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
|
||||
return err
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
|||
return fmt.Errorf("Not a connection-oriented service")
|
||||
}
|
||||
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil {
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
|
||||
p.logRequestError(err, cfRay, ruleNum)
|
||||
return err
|
||||
}
|
||||
|
@ -152,7 +152,6 @@ func (p *proxy) proxyStreamRequest(
|
|||
serveCtx context.Context,
|
||||
w connection.ResponseWriter,
|
||||
req *http.Request,
|
||||
sourceConnectionType connection.Type,
|
||||
connectionProxy ingress.StreamBasedOriginProxy,
|
||||
fields logFields,
|
||||
) error {
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/cli/v2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
|
@ -79,6 +80,11 @@ func (w *mockWSRespWriter) respBody() io.ReadWriter {
|
|||
return bytes.NewBuffer(data)
|
||||
}
|
||||
|
||||
func (w *mockWSRespWriter) Close() error {
|
||||
close(w.writeNotification)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
||||
return w.reader.Read(data)
|
||||
}
|
||||
|
@ -125,14 +131,14 @@ func TestProxySingleOrigin(t *testing.T) {
|
|||
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
|
||||
|
||||
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log)
|
||||
t.Run("testProxyHTTP", testProxyHTTP(t, proxy))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy))
|
||||
t.Run("testProxySSE", testProxySSE(t, proxy))
|
||||
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
||||
t.Run("testProxySSE", testProxySSE(proxy))
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||
func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
responseWriter := newMockHTTPRespWriter()
|
||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
||||
|
@ -145,23 +151,43 @@ func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T
|
|||
}
|
||||
}
|
||||
|
||||
func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||
func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
// WSRoute is a websocket echo handler
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
const testTimeout = 5 * time.Second * 1000
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
readPipe, writePipe := io.Pipe()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe)
|
||||
responseWriter := newMockWSRespWriter(readPipe)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
responseWriter := newMockWSRespWriter(nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
finished := make(chan struct{})
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
select {
|
||||
case <-finished:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
t.Errorf("Test timed out")
|
||||
readPipe.Close()
|
||||
writePipe.Close()
|
||||
responseWriter.Close()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
msg := []byte("test websocket")
|
||||
err = wsutil.WriteClientText(writePipe, msg)
|
||||
|
@ -179,12 +205,16 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, msg, returnedMsg)
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
_ = readPipe.Close()
|
||||
_ = writePipe.Close()
|
||||
_ = responseWriter.Close()
|
||||
|
||||
close(finished)
|
||||
errGroup.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||
func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
var (
|
||||
pushCount = 50
|
||||
|
|
|
@ -3,116 +3,47 @@ package websocket
|
|||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var stripWebsocketHeaders = []string{
|
||||
"Upgrade",
|
||||
"Connection",
|
||||
"Sec-Websocket-Key",
|
||||
"Sec-Websocket-Version",
|
||||
"Sec-Websocket-Extensions",
|
||||
}
|
||||
|
||||
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||
return websocket.IsWebSocketUpgrade(req)
|
||||
}
|
||||
|
||||
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
||||
// the connection. The response body may not contain the entire response and does
|
||||
// not need to be closed by the application.
|
||||
func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
|
||||
req.URL.Scheme = ChangeRequestScheme(req.URL)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
if dialler == nil {
|
||||
dialler = &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
}
|
||||
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
|
||||
if err != nil {
|
||||
return nil, response, err
|
||||
}
|
||||
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
||||
return conn, response, nil
|
||||
}
|
||||
|
||||
// NewResponseHeader returns headers needed to return to origin for completing handshake
|
||||
func NewResponseHeader(req *http.Request) http.Header {
|
||||
header := http.Header{}
|
||||
header.Add("Connection", "Upgrade")
|
||||
header.Add("Sec-Websocket-Accept", generateAcceptKey(req))
|
||||
header.Add("Sec-Websocket-Accept", generateAcceptKey(req.Header.Get("Sec-WebSocket-Key")))
|
||||
header.Add("Upgrade", "websocket")
|
||||
return header
|
||||
}
|
||||
|
||||
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
||||
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
|
||||
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
|
||||
func websocketHeaders(req *http.Request) http.Header {
|
||||
wsHeaders := make(http.Header)
|
||||
for key, val := range req.Header {
|
||||
wsHeaders[key] = val
|
||||
}
|
||||
// Assume the header keys are in canonical format.
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
wsHeaders.Del(header)
|
||||
}
|
||||
wsHeaders.Set("Host", req.Host) // See TUN-1097
|
||||
return wsHeaders
|
||||
}
|
||||
|
||||
// sha1Base64 sha1 and then base64 encodes str.
|
||||
func sha1Base64(str string) string {
|
||||
hasher := sha1.New()
|
||||
_, _ = io.WriteString(hasher, str)
|
||||
hash := hasher.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
|
||||
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
|
||||
func generateAcceptKey(req *http.Request) string {
|
||||
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
}
|
||||
|
||||
// ChangeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
|
||||
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
|
||||
func ChangeRequestScheme(reqURL *url.URL) string {
|
||||
switch reqURL.Scheme {
|
||||
case "https":
|
||||
return "wss"
|
||||
case "http":
|
||||
return "ws"
|
||||
case "":
|
||||
return "ws"
|
||||
default:
|
||||
return reqURL.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
// Stream copies copy data to & from provided io.ReadWriters.
|
||||
func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
|
||||
func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
|
||||
proxyDone := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(conn, backendConn)
|
||||
_, err := copyData(tunnelConn, originConn, "origin->tunnel")
|
||||
if err != nil {
|
||||
log.Debug().Msgf("conn to backendConn copy: %v", err)
|
||||
log.Debug().Msgf("origin to tunnel copy: %v", err)
|
||||
}
|
||||
proxyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(backendConn, conn)
|
||||
_, err := copyData(originConn, tunnelConn, "tunnel->origin")
|
||||
if err != nil {
|
||||
log.Debug().Msgf("backendConn to conn copy: %v", err)
|
||||
log.Debug().Msgf("tunnel to origin copy: %v", err)
|
||||
}
|
||||
proxyDone <- struct{}{}
|
||||
}()
|
||||
|
@ -120,3 +51,60 @@ func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
|
|||
// If one side is done, we are done.
|
||||
<-proxyDone
|
||||
}
|
||||
|
||||
// when set to true, enables logging of content copied to/from origin and tunnel
|
||||
const debugCopy = false
|
||||
|
||||
func copyData(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
|
||||
if debugCopy {
|
||||
// copyBuffer is based on stdio Copy implementation but shows copied data
|
||||
copyBuffer := func(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
|
||||
var buf []byte
|
||||
size := 32 * 1024
|
||||
buf = make([]byte, size)
|
||||
for {
|
||||
t := time.Now()
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
fmt.Println(dir, t.UnixNano(), "\n"+hex.Dump(buf[0:nr]))
|
||||
nw, ew := dst.Write(buf[0:nr])
|
||||
if nw < 0 || nr < nw {
|
||||
nw = 0
|
||||
if ew == nil {
|
||||
ew = errors.New("invalid write")
|
||||
}
|
||||
}
|
||||
written += int64(nw)
|
||||
if ew != nil {
|
||||
err = ew
|
||||
break
|
||||
}
|
||||
if nr != nw {
|
||||
err = io.ErrShortWrite
|
||||
break
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er != io.EOF {
|
||||
err = er
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
return copyBuffer(dst, src, dir)
|
||||
} else {
|
||||
return io.Copy(dst, src)
|
||||
}
|
||||
}
|
||||
|
||||
// from RFC-6455
|
||||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
|
||||
func generateAcceptKey(challengeKey string) string {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(challengeKey))
|
||||
h.Write(keyGUID)
|
||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
|
|
@ -1,24 +1,9 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -28,126 +13,6 @@ const (
|
|||
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
)
|
||||
|
||||
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
|
||||
req, err := http.NewRequest("GET", url, stream)
|
||||
if err != nil {
|
||||
t.Fatalf("testRequestHeader error")
|
||||
}
|
||||
|
||||
req.Header.Add("Connection", "Upgrade")
|
||||
req.Header.Add("Upgrade", "WebSocket")
|
||||
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
|
||||
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
|
||||
req.Header.Add("Sec-Websocket-Version", "13")
|
||||
req.Header.Add("User-Agent", "curl/7.59.0")
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func websocketClientTLSConfig(t *testing.T) *tls.Config {
|
||||
certPool := x509.NewCertPool()
|
||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||
assert.NoError(t, err)
|
||||
certPool.AddCert(helloCert)
|
||||
assert.NotNil(t, certPool)
|
||||
return &tls.Config{RootCAs: certPool}
|
||||
}
|
||||
|
||||
func TestWebsocketHeaders(t *testing.T) {
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
assert.Empty(t, wsHeaders[header])
|
||||
}
|
||||
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
|
||||
}
|
||||
|
||||
func TestGenerateAcceptKey(t *testing.T) {
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req))
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
listener, err := hello.CreateTLSListener("localhost:1111")
|
||||
assert.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
go func() {
|
||||
errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
|
||||
}()
|
||||
|
||||
req := testRequest(t, "https://localhost:1111/ws", nil)
|
||||
|
||||
tlsConfig := websocketClientTLSConfig(t)
|
||||
assert.NotNil(t, tlsConfig)
|
||||
d := gws.Dialer{TLSClientConfig: tlsConfig}
|
||||
conn, resp, err := ClientConnect(req, &d)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
messageSize := rand.Int()%2048 + 1
|
||||
clientMessage := make([]byte, messageSize)
|
||||
// rand.Read always returns len(clientMessage) and a nil error
|
||||
rand.Read(clientMessage)
|
||||
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
messageType, message, err := conn.ReadMessage()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, websocket.BinaryFrame, messageType)
|
||||
assert.Equal(t, clientMessage, message)
|
||||
}
|
||||
|
||||
_ = conn.Close()
|
||||
close(shutdownC)
|
||||
<-errC
|
||||
}
|
||||
|
||||
func TestWebsocketWrapper(t *testing.T) {
|
||||
|
||||
listener, err := hello.CreateTLSListener("localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
serverErrorChan := make(chan error)
|
||||
helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
|
||||
defer func() { <-serverErrorChan }()
|
||||
defer cancelHelloSvr()
|
||||
go func() {
|
||||
log := zerolog.Nop()
|
||||
serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
|
||||
}()
|
||||
|
||||
tlsConfig := websocketClientTLSConfig(t)
|
||||
d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
|
||||
testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
|
||||
req := testRequest(t, testAddr, nil)
|
||||
conn, resp, err := ClientConnect(req, &d)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
||||
|
||||
// Websocket now connected to test server so lets check our wrapper
|
||||
wrapper := GorillaConn{Conn: conn}
|
||||
buf := make([]byte, 100)
|
||||
wrapper.Write([]byte("abc"))
|
||||
n, err := wrapper.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 3)
|
||||
require.Equal(t, "abc", string(buf[:n]))
|
||||
|
||||
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
|
||||
wrapper.Write([]byte("abc"))
|
||||
buf = buf[:1]
|
||||
n, err = wrapper.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 1)
|
||||
require.Equal(t, "a", string(buf[:n]))
|
||||
buf = buf[:cap(buf)]
|
||||
n, err = wrapper.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 2)
|
||||
require.Equal(t, "bc", string(buf[:n]))
|
||||
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(testSecWebsocketKey))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue