TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).

This commit is contained in:
Igor Postelnik 2021-04-02 01:10:43 -05:00
parent b25d38dd72
commit 3ad99b241c
12 changed files with 455 additions and 315 deletions

View File

@ -1,5 +1,24 @@
**Experimental**: This is a new format for release notes. The format and availability is subject to change.
## UNRELEASED
### Backward Incompatible Changes
- none
### New Features
- none
### Improvements
- none
### Bug Fixes
- Fixed proxying of websocket requests to avoid possibility of losing initial frames that were sent in the same TCP
packet as response headers [#345](https://github.com/cloudflare/cloudflared/issues/345).
## 2021.3.6
### Bug Fixes

View File

@ -4,6 +4,7 @@ import (
"io"
"net/http"
"net/http/httputil"
"net/url"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
@ -60,7 +61,7 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso
TLSClientConfig: options.TLSClientConfig,
Proxy: http.ProxyFromEnvironment,
}
wsConn, resp, err := cfwebsocket.ClientConnect(req, dialer)
wsConn, resp, err := clientConnect(req, dialer)
defer closeRespBody(resp)
if err != nil && IsAccessResponse(resp) {
@ -87,6 +88,63 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso
return &cfwebsocket.GorillaConn{Conn: wsConn}, nil
}
var stripWebsocketHeaders = []string{
"Upgrade",
"Connection",
"Sec-Websocket-Key",
"Sec-Websocket-Version",
"Sec-Websocket-Extensions",
}
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
func websocketHeaders(req *http.Request) http.Header {
wsHeaders := make(http.Header)
for key, val := range req.Header {
wsHeaders[key] = val
}
// Assume the header keys are in canonical format.
for _, header := range stripWebsocketHeaders {
wsHeaders.Del(header)
}
wsHeaders.Set("Host", req.Host) // See TUN-1097
return wsHeaders
}
// clientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
// the connection. The response body may not contain the entire response and does
// not need to be closed by the application.
func clientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = changeRequestScheme(req.URL)
wsHeaders := websocketHeaders(req)
if dialler == nil {
dialler = &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
}
}
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
if err != nil {
return nil, response, err
}
return conn, response, nil
}
// changeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
func changeRequestScheme(reqURL *url.URL) string {
switch reqURL.Scheme {
case "https":
return "wss"
case "http":
return "ws"
case "":
return "ws"
default:
return reqURL.Scheme
}
}
// createAccessAuthenticatedStream will try load a token from storage and make
// a connection with the token set on the request. If it still get redirect,
// this probably means the token in storage is invalid (expired/revoked). If that
@ -126,7 +184,7 @@ func createAccessWebSocketStream(options *StartOptions, log *zerolog.Logger) (*w
dump, err := httputil.DumpRequest(req, false)
log.Debug().Msgf("Access Websocket request: %s", string(dump))
conn, resp, err := cfwebsocket.ClientConnect(req, nil)
conn, resp, err := clientConnect(req, nil)
if resp != nil {
r, err := httputil.DumpResponse(resp, true)

123
carrier/websocket_test.go Normal file
View File

@ -0,0 +1,123 @@
package carrier
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"math/rand"
"testing"
"time"
gws "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/tlsconfig"
cfwebsocket "github.com/cloudflare/cloudflared/websocket"
)
func websocketClientTLSConfig(t *testing.T) *tls.Config {
certPool := x509.NewCertPool()
helloCert, err := tlsconfig.GetHelloCertificateX509()
assert.NoError(t, err)
certPool.AddCert(helloCert)
assert.NotNil(t, certPool)
return &tls.Config{RootCAs: certPool}
}
func TestWebsocketHeaders(t *testing.T) {
req := testRequest(t, "http://example.com", nil)
wsHeaders := websocketHeaders(req)
for _, header := range stripWebsocketHeaders {
assert.Empty(t, wsHeaders[header])
}
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
}
func TestServe(t *testing.T) {
log := zerolog.Nop()
shutdownC := make(chan struct{})
errC := make(chan error)
listener, err := hello.CreateTLSListener("localhost:1111")
assert.NoError(t, err)
defer listener.Close()
go func() {
errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
}()
req := testRequest(t, "https://localhost:1111/ws", nil)
tlsConfig := websocketClientTLSConfig(t)
assert.NotNil(t, tlsConfig)
d := gws.Dialer{TLSClientConfig: tlsConfig}
conn, resp, err := clientConnect(req, &d)
assert.NoError(t, err)
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
for i := 0; i < 1000; i++ {
messageSize := rand.Int()%2048 + 1
clientMessage := make([]byte, messageSize)
// rand.Read always returns len(clientMessage) and a nil error
rand.Read(clientMessage)
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
assert.NoError(t, err)
messageType, message, err := conn.ReadMessage()
assert.NoError(t, err)
assert.Equal(t, websocket.BinaryFrame, messageType)
assert.Equal(t, clientMessage, message)
}
_ = conn.Close()
close(shutdownC)
<-errC
}
func TestWebsocketWrapper(t *testing.T) {
listener, err := hello.CreateTLSListener("localhost:0")
require.NoError(t, err)
serverErrorChan := make(chan error)
helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
defer func() { <-serverErrorChan }()
defer cancelHelloSvr()
go func() {
log := zerolog.Nop()
serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
}()
tlsConfig := websocketClientTLSConfig(t)
d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
req := testRequest(t, testAddr, nil)
conn, resp, err := clientConnect(req, &d)
require.NoError(t, err)
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
// Websocket now connected to test server so lets check our wrapper
wrapper := cfwebsocket.GorillaConn{Conn: conn}
buf := make([]byte, 100)
wrapper.Write([]byte("abc"))
n, err := wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 3)
require.Equal(t, "abc", string(buf[:n]))
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
wrapper.Write([]byte("abc"))
buf = buf[:1]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 1)
require.Equal(t, "a", string(buf[:n]))
buf = buf[:cap(buf)]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 2)
require.Equal(t, "bc", string(buf[:n]))
}

View File

@ -2,12 +2,9 @@ package ingress
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
gws "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ipaccess"
@ -58,35 +55,6 @@ func (wc *tcpOverWSConnection) Close() {
wc.conn.Close()
}
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
type wsConnection struct {
wsConn *gws.Conn
resp *http.Response
}
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
}
func (wsc *wsConnection) Close() {
wsc.resp.Body.Close()
wsc.wsConn.Close()
}
func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
d := &gws.Dialer{
TLSClientConfig: clientTLSConfig,
}
wsConn, resp, err := websocket.ClientConnect(r, d)
if err != nil {
return nil, nil, err
}
return &wsConnection{
wsConn,
resp,
}, resp, nil
}
// socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS.
// The connection to the origin happens inside the SOCKS code as the client specifies the origin
// details in the packet.
@ -100,3 +68,16 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
func (sp *socksProxyOverWSConnection) Close() {
}
// wsProxyConnection represents a bidirectional stream for a websocket connection to the origin
type wsProxyConnection struct {
rwc io.ReadWriteCloser
}
func (conn *wsProxyConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
websocket.Stream(tunnelConn, conn.rwc, log)
}
func (conn *wsProxyConnection) Close() {
conn.rwc.Close()
}

View File

@ -3,13 +3,13 @@ package ingress
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"time"
@ -193,18 +193,26 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
func TestStreamWSConnection(t *testing.T) {
eyeballConn, edgeConn := net.Pipe()
origin := echoWSOrigin(t)
origin := echoWSOrigin(t, true)
defer origin.Close()
var svc httpService
err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{
NoTLSVerify: true,
})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
conn, resp, err := svc.newWebsocketProxyConnection(req)
clientTLSConfig := &tls.Config{
InsecureSkipVerify: true,
}
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
require.NoError(t, err)
defer conn.Close()
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
@ -213,13 +221,37 @@ func TestStreamWSConnection(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
connClosed := make(chan struct{})
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
select {
case <-connClosed:
case <-ctx.Done():
}
if ctx.Err() == context.DeadlineExceeded {
eyeballConn.Close()
edgeConn.Close()
conn.Close()
}
return ctx.Err()
})
errGroup.Go(func() error {
echoWSEyeball(t, eyeballConn)
fmt.Println("closing pipe")
edgeConn.Close()
return eyeballConn.Close()
})
errGroup.Go(func() error {
defer conn.Close()
conn.Stream(ctx, edgeConn, testLogger)
close(connClosed)
return nil
})
wsConn.Stream(ctx, edgeConn, testLogger)
require.NoError(t, errGroup.Wait())
}
@ -241,17 +273,23 @@ func (wse *wsEyeball) Write(p []byte) (int, error) {
}
func echoWSEyeball(t *testing.T, conn net.Conn) {
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
defer func() {
assert.NoError(t, conn.Close())
}()
if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) {
return
}
readMsg, err := wsutil.ReadServerBinary(conn)
require.NoError(t, err)
if !assert.NoError(t, err) {
return
}
require.Equal(t, testResponse, readMsg)
require.NoError(t, conn.Close())
assert.Equal(t, testResponse, readMsg)
}
func echoWSOrigin(t *testing.T) *httptest.Server {
func echoWSOrigin(t *testing.T, expectMessages bool) *httptest.Server {
var upgrader = gorillaWS.Upgrader{
ReadBufferSize: 10,
WriteBufferSize: 10,
@ -268,12 +306,17 @@ func echoWSOrigin(t *testing.T) *httptest.Server {
require.NoError(t, err)
defer conn.Close()
sawMessage := false
for {
messageType, p, err := conn.ReadMessage()
if err != nil {
if expectMessages && !sawMessage {
t.Errorf("unexpected error: %v", err)
}
return
}
require.Equal(t, testMessage, p)
assert.Equal(t, testMessage, p)
sawMessage = true
if err := conn.WriteMessage(messageType, testResponse); err != nil {
return
}

View File

@ -2,8 +2,10 @@ package ingress
import (
"fmt"
"io"
"net"
"net/http"
"strings"
"github.com/pkg/errors"
@ -12,7 +14,8 @@ import (
)
var (
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
errUnsupportedConnectionType = errors.New("internal error: unsupported connection type")
)
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@ -42,26 +45,64 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
}
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req = req.Clone(req.Context())
req.URL.Host = o.url.Host
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
req.URL.Scheme = o.url.Scheme
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
}
return newWSConnection(o.transport.TLSClientConfig, req)
return o.newWebsocketProxyConnection(req)
}
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the Hello World server.
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "https"
return o.transport.RoundTrip(req)
}
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "wss"
return newWSConnection(o.transport.TLSClientConfig, req)
req.ContentLength = 0
req.Body = nil
resp, err := o.transport.RoundTrip(req)
if err != nil {
return nil, nil, err
}
toClose := resp.Body
defer func() {
if toClose != nil {
_ = toClose.Close()
}
}()
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return nil, nil, errUnsupportedConnectionType
}
conn := wsProxyConnection{
rwc: rwc,
}
// clear to prevent defer from closing
toClose = nil
return &conn, resp, nil
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {

View File

@ -33,7 +33,7 @@ func assertEstablishConnectionResponse(t *testing.T,
}
func TestHTTPServiceEstablishConnection(t *testing.T) {
origin := echoWSOrigin(t)
origin := echoWSOrigin(t, false)
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
@ -71,11 +71,11 @@ func TestHelloWorldEstablishConnection(t *testing.T) {
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
expectHeader := http.Header{
"Connection": {"Upgrade"},
// Accept key when Sec-Websocket-Key is not specified
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)

View File

@ -11,7 +11,6 @@ import (
"sync"
"time"
gws "github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/rs/zerolog"
@ -19,7 +18,6 @@ import (
"github.com/cloudflare/cloudflared/ipaccess"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/websocket"
)
// originService is something a tunnel can proxy traffic to.
@ -50,16 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
return nil
}
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{
NetDial: o.transport.Dial,
NetDialContext: o.transport.DialContext,
TLSClientConfig: o.transport.TLSClientConfig,
}
reqURL.Scheme = websocket.ChangeRequestScheme(reqURL)
return d.Dial(reqURL.String(), headers)
}
type httpService struct {
url *url.URL
hostHeader string
@ -171,8 +159,8 @@ func (o *socksProxyOverWSService) String() string {
// HelloWorld is an OriginService for the built-in Hello World server.
// Users only use this for testing and experimenting with cloudflared.
type helloWorld struct {
server net.Listener
transport *http.Transport
httpService
server net.Listener
}
func (o *helloWorld) String() string {
@ -187,11 +175,10 @@ func (o *helloWorld) start(
errC chan error,
cfg OriginRequestConfig,
) error {
transport, err := newHTTPTransport(o, cfg, log)
if err != nil {
if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil {
return err
}
o.transport = transport
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
return errors.Wrap(err, "Cannot start Hello World Server")
@ -202,6 +189,12 @@ func (o *helloWorld) start(
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
}()
o.server = helloListener
o.httpService.url = &url.URL{
Scheme: "https",
Host: o.server.Addr().String(),
}
return nil
}

View File

@ -67,7 +67,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
lbProbe: lbProbe,
rule: ingress.ServiceWarpRouting,
}
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil {
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
return err
}
@ -96,7 +96,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return fmt.Errorf("Not a connection-oriented service")
}
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil {
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
p.logRequestError(err, cfRay, ruleNum)
return err
}
@ -152,7 +152,6 @@ func (p *proxy) proxyStreamRequest(
serveCtx context.Context,
w connection.ResponseWriter,
req *http.Request,
sourceConnectionType connection.Type,
connectionProxy ingress.StreamBasedOriginProxy,
fields logFields,
) error {

View File

@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
@ -79,6 +80,11 @@ func (w *mockWSRespWriter) respBody() io.ReadWriter {
return bytes.NewBuffer(data)
}
func (w *mockWSRespWriter) Close() error {
close(w.writeNotification)
return nil
}
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
return w.reader.Read(data)
}
@ -125,14 +131,14 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log)
t.Run("testProxyHTTP", testProxyHTTP(t, proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy))
t.Run("testProxySSE", testProxySSE(t, proxy))
t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy))
cancel()
wg.Wait()
}
func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
@ -145,23 +151,43 @@ func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T
}
}
func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
// WSRoute is a websocket echo handler
ctx, cancel := context.WithCancel(context.Background())
const testTimeout = 5 * time.Second * 1000
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
readPipe, writePipe := io.Pipe()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe)
responseWriter := newMockWSRespWriter(readPipe)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
responseWriter := newMockWSRespWriter(nil)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
finished := make(chan struct{})
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
}()
return nil
})
errGroup.Go(func() error {
select {
case <-finished:
case <-ctx.Done():
}
if ctx.Err() == context.DeadlineExceeded {
t.Errorf("Test timed out")
readPipe.Close()
writePipe.Close()
responseWriter.Close()
}
return nil
})
msg := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, msg)
@ -179,12 +205,16 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test
require.NoError(t, err)
require.Equal(t, msg, returnedMsg)
cancel()
wg.Wait()
_ = readPipe.Close()
_ = writePipe.Close()
_ = responseWriter.Close()
close(finished)
errGroup.Wait()
}
}
func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
var (
pushCount = 50

View File

@ -3,116 +3,47 @@ package websocket
import (
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
var stripWebsocketHeaders = []string{
"Upgrade",
"Connection",
"Sec-Websocket-Key",
"Sec-Websocket-Version",
"Sec-Websocket-Extensions",
}
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
func IsWebSocketUpgrade(req *http.Request) bool {
return websocket.IsWebSocketUpgrade(req)
}
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
// the connection. The response body may not contain the entire response and does
// not need to be closed by the application.
func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = ChangeRequestScheme(req.URL)
wsHeaders := websocketHeaders(req)
if dialler == nil {
dialler = &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
}
}
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
if err != nil {
return nil, response, err
}
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
return conn, response, nil
}
// NewResponseHeader returns headers needed to return to origin for completing handshake
func NewResponseHeader(req *http.Request) http.Header {
header := http.Header{}
header.Add("Connection", "Upgrade")
header.Add("Sec-Websocket-Accept", generateAcceptKey(req))
header.Add("Sec-Websocket-Accept", generateAcceptKey(req.Header.Get("Sec-WebSocket-Key")))
header.Add("Upgrade", "websocket")
return header
}
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
func websocketHeaders(req *http.Request) http.Header {
wsHeaders := make(http.Header)
for key, val := range req.Header {
wsHeaders[key] = val
}
// Assume the header keys are in canonical format.
for _, header := range stripWebsocketHeaders {
wsHeaders.Del(header)
}
wsHeaders.Set("Host", req.Host) // See TUN-1097
return wsHeaders
}
// sha1Base64 sha1 and then base64 encodes str.
func sha1Base64(str string) string {
hasher := sha1.New()
_, _ = io.WriteString(hasher, str)
hash := hasher.Sum(nil)
return base64.StdEncoding.EncodeToString(hash)
}
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
func generateAcceptKey(req *http.Request) string {
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
}
// ChangeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
func ChangeRequestScheme(reqURL *url.URL) string {
switch reqURL.Scheme {
case "https":
return "wss"
case "http":
return "ws"
case "":
return "ws"
default:
return reqURL.Scheme
}
}
// Stream copies copy data to & from provided io.ReadWriters.
func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
proxyDone := make(chan struct{}, 2)
go func() {
_, err := io.Copy(conn, backendConn)
_, err := copyData(tunnelConn, originConn, "origin->tunnel")
if err != nil {
log.Debug().Msgf("conn to backendConn copy: %v", err)
log.Debug().Msgf("origin to tunnel copy: %v", err)
}
proxyDone <- struct{}{}
}()
go func() {
_, err := io.Copy(backendConn, conn)
_, err := copyData(originConn, tunnelConn, "tunnel->origin")
if err != nil {
log.Debug().Msgf("backendConn to conn copy: %v", err)
log.Debug().Msgf("tunnel to origin copy: %v", err)
}
proxyDone <- struct{}{}
}()
@ -120,3 +51,60 @@ func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
// If one side is done, we are done.
<-proxyDone
}
// when set to true, enables logging of content copied to/from origin and tunnel
const debugCopy = false
func copyData(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
if debugCopy {
// copyBuffer is based on stdio Copy implementation but shows copied data
copyBuffer := func(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
var buf []byte
size := 32 * 1024
buf = make([]byte, size)
for {
t := time.Now()
nr, er := src.Read(buf)
if nr > 0 {
fmt.Println(dir, t.UnixNano(), "\n"+hex.Dump(buf[0:nr]))
nw, ew := dst.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errors.New("invalid write")
}
}
written += int64(nw)
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}
return copyBuffer(dst, src, dir)
} else {
return io.Copy(dst, src)
}
}
// from RFC-6455
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func generateAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

View File

@ -1,24 +1,9 @@
package websocket
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"math/rand"
"net/http"
"testing"
"time"
gws "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/tlsconfig"
)
const (
@ -28,126 +13,6 @@ const (
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
)
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
req, err := http.NewRequest("GET", url, stream)
if err != nil {
t.Fatalf("testRequestHeader error")
}
req.Header.Add("Connection", "Upgrade")
req.Header.Add("Upgrade", "WebSocket")
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
req.Header.Add("Sec-Websocket-Version", "13")
req.Header.Add("User-Agent", "curl/7.59.0")
return req
}
func websocketClientTLSConfig(t *testing.T) *tls.Config {
certPool := x509.NewCertPool()
helloCert, err := tlsconfig.GetHelloCertificateX509()
assert.NoError(t, err)
certPool.AddCert(helloCert)
assert.NotNil(t, certPool)
return &tls.Config{RootCAs: certPool}
}
func TestWebsocketHeaders(t *testing.T) {
req := testRequest(t, "http://example.com", nil)
wsHeaders := websocketHeaders(req)
for _, header := range stripWebsocketHeaders {
assert.Empty(t, wsHeaders[header])
}
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
}
func TestGenerateAcceptKey(t *testing.T) {
req := testRequest(t, "http://example.com", nil)
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req))
}
func TestServe(t *testing.T) {
log := zerolog.Nop()
shutdownC := make(chan struct{})
errC := make(chan error)
listener, err := hello.CreateTLSListener("localhost:1111")
assert.NoError(t, err)
defer listener.Close()
go func() {
errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
}()
req := testRequest(t, "https://localhost:1111/ws", nil)
tlsConfig := websocketClientTLSConfig(t)
assert.NotNil(t, tlsConfig)
d := gws.Dialer{TLSClientConfig: tlsConfig}
conn, resp, err := ClientConnect(req, &d)
assert.NoError(t, err)
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
for i := 0; i < 1000; i++ {
messageSize := rand.Int()%2048 + 1
clientMessage := make([]byte, messageSize)
// rand.Read always returns len(clientMessage) and a nil error
rand.Read(clientMessage)
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
assert.NoError(t, err)
messageType, message, err := conn.ReadMessage()
assert.NoError(t, err)
assert.Equal(t, websocket.BinaryFrame, messageType)
assert.Equal(t, clientMessage, message)
}
_ = conn.Close()
close(shutdownC)
<-errC
}
func TestWebsocketWrapper(t *testing.T) {
listener, err := hello.CreateTLSListener("localhost:0")
require.NoError(t, err)
serverErrorChan := make(chan error)
helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
defer func() { <-serverErrorChan }()
defer cancelHelloSvr()
go func() {
log := zerolog.Nop()
serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
}()
tlsConfig := websocketClientTLSConfig(t)
d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
req := testRequest(t, testAddr, nil)
conn, resp, err := ClientConnect(req, &d)
require.NoError(t, err)
require.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
// Websocket now connected to test server so lets check our wrapper
wrapper := GorillaConn{Conn: conn}
buf := make([]byte, 100)
wrapper.Write([]byte("abc"))
n, err := wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 3)
require.Equal(t, "abc", string(buf[:n]))
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
wrapper.Write([]byte("abc"))
buf = buf[:1]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 1)
require.Equal(t, "a", string(buf[:n]))
buf = buf[:cap(buf)]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 2)
require.Equal(t, "bc", string(buf[:n]))
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(testSecWebsocketKey))
}