TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).
This commit is contained in:
parent
b25d38dd72
commit
3ad99b241c
19
CHANGES.md
19
CHANGES.md
|
@ -1,5 +1,24 @@
|
||||||
**Experimental**: This is a new format for release notes. The format and availability is subject to change.
|
**Experimental**: This is a new format for release notes. The format and availability is subject to change.
|
||||||
|
|
||||||
|
## UNRELEASED
|
||||||
|
|
||||||
|
### Backward Incompatible Changes
|
||||||
|
|
||||||
|
- none
|
||||||
|
|
||||||
|
### New Features
|
||||||
|
|
||||||
|
- none
|
||||||
|
|
||||||
|
### Improvements
|
||||||
|
|
||||||
|
- none
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
- Fixed proxying of websocket requests to avoid possibility of losing initial frames that were sent in the same TCP
|
||||||
|
packet as response headers [#345](https://github.com/cloudflare/cloudflared/issues/345).
|
||||||
|
|
||||||
## 2021.3.6
|
## 2021.3.6
|
||||||
|
|
||||||
### Bug Fixes
|
### Bug Fixes
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -60,7 +61,7 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso
|
||||||
TLSClientConfig: options.TLSClientConfig,
|
TLSClientConfig: options.TLSClientConfig,
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
}
|
}
|
||||||
wsConn, resp, err := cfwebsocket.ClientConnect(req, dialer)
|
wsConn, resp, err := clientConnect(req, dialer)
|
||||||
defer closeRespBody(resp)
|
defer closeRespBody(resp)
|
||||||
|
|
||||||
if err != nil && IsAccessResponse(resp) {
|
if err != nil && IsAccessResponse(resp) {
|
||||||
|
@ -87,6 +88,63 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso
|
||||||
return &cfwebsocket.GorillaConn{Conn: wsConn}, nil
|
return &cfwebsocket.GorillaConn{Conn: wsConn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var stripWebsocketHeaders = []string{
|
||||||
|
"Upgrade",
|
||||||
|
"Connection",
|
||||||
|
"Sec-Websocket-Key",
|
||||||
|
"Sec-Websocket-Version",
|
||||||
|
"Sec-Websocket-Extensions",
|
||||||
|
}
|
||||||
|
|
||||||
|
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
||||||
|
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
|
||||||
|
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
|
||||||
|
func websocketHeaders(req *http.Request) http.Header {
|
||||||
|
wsHeaders := make(http.Header)
|
||||||
|
for key, val := range req.Header {
|
||||||
|
wsHeaders[key] = val
|
||||||
|
}
|
||||||
|
// Assume the header keys are in canonical format.
|
||||||
|
for _, header := range stripWebsocketHeaders {
|
||||||
|
wsHeaders.Del(header)
|
||||||
|
}
|
||||||
|
wsHeaders.Set("Host", req.Host) // See TUN-1097
|
||||||
|
return wsHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
||||||
|
// the connection. The response body may not contain the entire response and does
|
||||||
|
// not need to be closed by the application.
|
||||||
|
func clientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
|
||||||
|
req.URL.Scheme = changeRequestScheme(req.URL)
|
||||||
|
wsHeaders := websocketHeaders(req)
|
||||||
|
if dialler == nil {
|
||||||
|
dialler = &websocket.Dialer{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
|
||||||
|
if err != nil {
|
||||||
|
return nil, response, err
|
||||||
|
}
|
||||||
|
return conn, response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// changeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
|
||||||
|
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
|
||||||
|
func changeRequestScheme(reqURL *url.URL) string {
|
||||||
|
switch reqURL.Scheme {
|
||||||
|
case "https":
|
||||||
|
return "wss"
|
||||||
|
case "http":
|
||||||
|
return "ws"
|
||||||
|
case "":
|
||||||
|
return "ws"
|
||||||
|
default:
|
||||||
|
return reqURL.Scheme
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// createAccessAuthenticatedStream will try load a token from storage and make
|
// createAccessAuthenticatedStream will try load a token from storage and make
|
||||||
// a connection with the token set on the request. If it still get redirect,
|
// a connection with the token set on the request. If it still get redirect,
|
||||||
// this probably means the token in storage is invalid (expired/revoked). If that
|
// this probably means the token in storage is invalid (expired/revoked). If that
|
||||||
|
@ -126,7 +184,7 @@ func createAccessWebSocketStream(options *StartOptions, log *zerolog.Logger) (*w
|
||||||
dump, err := httputil.DumpRequest(req, false)
|
dump, err := httputil.DumpRequest(req, false)
|
||||||
log.Debug().Msgf("Access Websocket request: %s", string(dump))
|
log.Debug().Msgf("Access Websocket request: %s", string(dump))
|
||||||
|
|
||||||
conn, resp, err := cfwebsocket.ClientConnect(req, nil)
|
conn, resp, err := clientConnect(req, nil)
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
r, err := httputil.DumpResponse(resp, true)
|
r, err := httputil.DumpResponse(resp, true)
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
package carrier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gws "github.com/gorilla/websocket"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/websocket"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
|
cfwebsocket "github.com/cloudflare/cloudflared/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func websocketClientTLSConfig(t *testing.T) *tls.Config {
|
||||||
|
certPool := x509.NewCertPool()
|
||||||
|
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
certPool.AddCert(helloCert)
|
||||||
|
assert.NotNil(t, certPool)
|
||||||
|
return &tls.Config{RootCAs: certPool}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketHeaders(t *testing.T) {
|
||||||
|
req := testRequest(t, "http://example.com", nil)
|
||||||
|
wsHeaders := websocketHeaders(req)
|
||||||
|
for _, header := range stripWebsocketHeaders {
|
||||||
|
assert.Empty(t, wsHeaders[header])
|
||||||
|
}
|
||||||
|
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServe(t *testing.T) {
|
||||||
|
log := zerolog.Nop()
|
||||||
|
shutdownC := make(chan struct{})
|
||||||
|
errC := make(chan error)
|
||||||
|
listener, err := hello.CreateTLSListener("localhost:1111")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
|
||||||
|
}()
|
||||||
|
|
||||||
|
req := testRequest(t, "https://localhost:1111/ws", nil)
|
||||||
|
|
||||||
|
tlsConfig := websocketClientTLSConfig(t)
|
||||||
|
assert.NotNil(t, tlsConfig)
|
||||||
|
d := gws.Dialer{TLSClientConfig: tlsConfig}
|
||||||
|
conn, resp, err := clientConnect(req, &d)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
messageSize := rand.Int()%2048 + 1
|
||||||
|
clientMessage := make([]byte, messageSize)
|
||||||
|
// rand.Read always returns len(clientMessage) and a nil error
|
||||||
|
rand.Read(clientMessage)
|
||||||
|
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
messageType, message, err := conn.ReadMessage()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, websocket.BinaryFrame, messageType)
|
||||||
|
assert.Equal(t, clientMessage, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = conn.Close()
|
||||||
|
close(shutdownC)
|
||||||
|
<-errC
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketWrapper(t *testing.T) {
|
||||||
|
listener, err := hello.CreateTLSListener("localhost:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverErrorChan := make(chan error)
|
||||||
|
helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
|
||||||
|
defer func() { <-serverErrorChan }()
|
||||||
|
defer cancelHelloSvr()
|
||||||
|
go func() {
|
||||||
|
log := zerolog.Nop()
|
||||||
|
serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
|
||||||
|
}()
|
||||||
|
|
||||||
|
tlsConfig := websocketClientTLSConfig(t)
|
||||||
|
d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
|
||||||
|
testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
|
||||||
|
req := testRequest(t, testAddr, nil)
|
||||||
|
conn, resp, err := clientConnect(req, &d)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
|
||||||
|
|
||||||
|
// Websocket now connected to test server so lets check our wrapper
|
||||||
|
wrapper := cfwebsocket.GorillaConn{Conn: conn}
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
wrapper.Write([]byte("abc"))
|
||||||
|
n, err := wrapper.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, n, 3)
|
||||||
|
require.Equal(t, "abc", string(buf[:n]))
|
||||||
|
|
||||||
|
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
|
||||||
|
wrapper.Write([]byte("abc"))
|
||||||
|
buf = buf[:1]
|
||||||
|
n, err = wrapper.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, n, 1)
|
||||||
|
require.Equal(t, "a", string(buf[:n]))
|
||||||
|
buf = buf[:cap(buf)]
|
||||||
|
n, err = wrapper.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, n, 2)
|
||||||
|
require.Equal(t, "bc", string(buf[:n]))
|
||||||
|
}
|
|
@ -2,12 +2,9 @@ package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
|
|
||||||
gws "github.com/gorilla/websocket"
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/ipaccess"
|
"github.com/cloudflare/cloudflared/ipaccess"
|
||||||
|
@ -58,35 +55,6 @@ func (wc *tcpOverWSConnection) Close() {
|
||||||
wc.conn.Close()
|
wc.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
|
|
||||||
type wsConnection struct {
|
|
||||||
wsConn *gws.Conn
|
|
||||||
resp *http.Response
|
|
||||||
}
|
|
||||||
|
|
||||||
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
|
||||||
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (wsc *wsConnection) Close() {
|
|
||||||
wsc.resp.Body.Close()
|
|
||||||
wsc.wsConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
|
|
||||||
d := &gws.Dialer{
|
|
||||||
TLSClientConfig: clientTLSConfig,
|
|
||||||
}
|
|
||||||
wsConn, resp, err := websocket.ClientConnect(r, d)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return &wsConnection{
|
|
||||||
wsConn,
|
|
||||||
resp,
|
|
||||||
}, resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS.
|
// socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS.
|
||||||
// The connection to the origin happens inside the SOCKS code as the client specifies the origin
|
// The connection to the origin happens inside the SOCKS code as the client specifies the origin
|
||||||
// details in the packet.
|
// details in the packet.
|
||||||
|
@ -100,3 +68,16 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
|
||||||
|
|
||||||
func (sp *socksProxyOverWSConnection) Close() {
|
func (sp *socksProxyOverWSConnection) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wsProxyConnection represents a bidirectional stream for a websocket connection to the origin
|
||||||
|
type wsProxyConnection struct {
|
||||||
|
rwc io.ReadWriteCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *wsProxyConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||||
|
websocket.Stream(tunnelConn, conn.rwc, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *wsProxyConnection) Close() {
|
||||||
|
conn.rwc.Close()
|
||||||
|
}
|
||||||
|
|
|
@ -3,13 +3,13 @@ package ingress
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -193,18 +193,26 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||||
func TestStreamWSConnection(t *testing.T) {
|
func TestStreamWSConnection(t *testing.T) {
|
||||||
eyeballConn, edgeConn := net.Pipe()
|
eyeballConn, edgeConn := net.Pipe()
|
||||||
|
|
||||||
origin := echoWSOrigin(t)
|
origin := echoWSOrigin(t, true)
|
||||||
defer origin.Close()
|
defer origin.Close()
|
||||||
|
|
||||||
|
var svc httpService
|
||||||
|
err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{
|
||||||
|
NoTLSVerify: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
|
||||||
|
conn, resp, err := svc.newWebsocketProxyConnection(req)
|
||||||
|
|
||||||
clientTLSConfig := &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
}
|
|
||||||
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||||
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
|
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
|
||||||
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
|
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
|
||||||
|
@ -213,13 +221,37 @@ func TestStreamWSConnection(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
connClosed := make(chan struct{})
|
||||||
|
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
select {
|
||||||
|
case <-connClosed:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
eyeballConn.Close()
|
||||||
|
edgeConn.Close()
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.Err()
|
||||||
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
echoWSEyeball(t, eyeballConn)
|
echoWSEyeball(t, eyeballConn)
|
||||||
|
fmt.Println("closing pipe")
|
||||||
|
edgeConn.Close()
|
||||||
|
return eyeballConn.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
defer conn.Close()
|
||||||
|
conn.Stream(ctx, edgeConn, testLogger)
|
||||||
|
close(connClosed)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
wsConn.Stream(ctx, edgeConn, testLogger)
|
|
||||||
require.NoError(t, errGroup.Wait())
|
require.NoError(t, errGroup.Wait())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,17 +273,23 @@ func (wse *wsEyeball) Write(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
||||||
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
|
defer func() {
|
||||||
|
assert.NoError(t, conn.Close())
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
readMsg, err := wsutil.ReadServerBinary(conn)
|
readMsg, err := wsutil.ReadServerBinary(conn)
|
||||||
require.NoError(t, err)
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
require.Equal(t, testResponse, readMsg)
|
assert.Equal(t, testResponse, readMsg)
|
||||||
|
|
||||||
require.NoError(t, conn.Close())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func echoWSOrigin(t *testing.T) *httptest.Server {
|
func echoWSOrigin(t *testing.T, expectMessages bool) *httptest.Server {
|
||||||
var upgrader = gorillaWS.Upgrader{
|
var upgrader = gorillaWS.Upgrader{
|
||||||
ReadBufferSize: 10,
|
ReadBufferSize: 10,
|
||||||
WriteBufferSize: 10,
|
WriteBufferSize: 10,
|
||||||
|
@ -268,12 +306,17 @@ func echoWSOrigin(t *testing.T) *httptest.Server {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
sawMessage := false
|
||||||
for {
|
for {
|
||||||
messageType, p, err := conn.ReadMessage()
|
messageType, p, err := conn.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if expectMessages && !sawMessage {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
require.Equal(t, testMessage, p)
|
assert.Equal(t, testMessage, p)
|
||||||
|
sawMessage = true
|
||||||
if err := conn.WriteMessage(messageType, testResponse); err != nil {
|
if err := conn.WriteMessage(messageType, testResponse); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,8 +2,10 @@ package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
@ -12,7 +14,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||||
|
errUnsupportedConnectionType = errors.New("internal error: unsupported connection type")
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||||
|
@ -42,26 +45,64 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
|
req = req.Clone(req.Context())
|
||||||
|
|
||||||
req.URL.Host = o.url.Host
|
req.URL.Host = o.url.Host
|
||||||
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
req.URL.Scheme = o.url.Scheme
|
||||||
|
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
|
||||||
|
switch req.URL.Scheme {
|
||||||
|
case "ws":
|
||||||
|
req.URL.Scheme = "http"
|
||||||
|
case "wss":
|
||||||
|
req.URL.Scheme = "https"
|
||||||
|
}
|
||||||
|
|
||||||
if o.hostHeader != "" {
|
if o.hostHeader != "" {
|
||||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||||
req.Host = o.hostHeader
|
req.Host = o.hostHeader
|
||||||
}
|
}
|
||||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
|
||||||
|
return o.newWebsocketProxyConnection(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
// Rewrite the request URL so that it goes to the Hello World server.
|
req.Header.Set("Connection", "Upgrade")
|
||||||
req.URL.Host = o.server.Addr().String()
|
req.Header.Set("Upgrade", "websocket")
|
||||||
req.URL.Scheme = "https"
|
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||||
return o.transport.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
req.ContentLength = 0
|
||||||
req.URL.Host = o.server.Addr().String()
|
req.Body = nil
|
||||||
req.URL.Scheme = "wss"
|
|
||||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
resp, err := o.transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
toClose := resp.Body
|
||||||
|
defer func() {
|
||||||
|
if toClose != nil {
|
||||||
|
_ = toClose.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
|
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
|
||||||
|
}
|
||||||
|
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
|
||||||
|
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
|
||||||
|
}
|
||||||
|
|
||||||
|
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errUnsupportedConnectionType
|
||||||
|
}
|
||||||
|
conn := wsProxyConnection{
|
||||||
|
rwc: rwc,
|
||||||
|
}
|
||||||
|
// clear to prevent defer from closing
|
||||||
|
toClose = nil
|
||||||
|
|
||||||
|
return &conn, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||||
|
|
|
@ -33,7 +33,7 @@ func assertEstablishConnectionResponse(t *testing.T,
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPServiceEstablishConnection(t *testing.T) {
|
func TestHTTPServiceEstablishConnection(t *testing.T) {
|
||||||
origin := echoWSOrigin(t)
|
origin := echoWSOrigin(t, false)
|
||||||
defer origin.Close()
|
defer origin.Close()
|
||||||
originURL, err := url.Parse(origin.URL)
|
originURL, err := url.Parse(origin.URL)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -71,11 +71,11 @@ func TestHelloWorldEstablishConnection(t *testing.T) {
|
||||||
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
|
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
|
||||||
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
|
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
|
||||||
expectHeader := http.Header{
|
expectHeader := http.Header{
|
||||||
"Connection": {"Upgrade"},
|
"Connection": {"Upgrade"},
|
||||||
// Accept key when Sec-Websocket-Key is not specified
|
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||||
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
|
|
||||||
"Upgrade": {"websocket"},
|
"Upgrade": {"websocket"},
|
||||||
}
|
}
|
||||||
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
|
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
gws "github.com/gorilla/websocket"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
@ -19,7 +18,6 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/ipaccess"
|
"github.com/cloudflare/cloudflared/ipaccess"
|
||||||
"github.com/cloudflare/cloudflared/socks"
|
"github.com/cloudflare/cloudflared/socks"
|
||||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// originService is something a tunnel can proxy traffic to.
|
// originService is something a tunnel can proxy traffic to.
|
||||||
|
@ -50,16 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
|
||||||
d := &gws.Dialer{
|
|
||||||
NetDial: o.transport.Dial,
|
|
||||||
NetDialContext: o.transport.DialContext,
|
|
||||||
TLSClientConfig: o.transport.TLSClientConfig,
|
|
||||||
}
|
|
||||||
reqURL.Scheme = websocket.ChangeRequestScheme(reqURL)
|
|
||||||
return d.Dial(reqURL.String(), headers)
|
|
||||||
}
|
|
||||||
|
|
||||||
type httpService struct {
|
type httpService struct {
|
||||||
url *url.URL
|
url *url.URL
|
||||||
hostHeader string
|
hostHeader string
|
||||||
|
@ -171,8 +159,8 @@ func (o *socksProxyOverWSService) String() string {
|
||||||
// HelloWorld is an OriginService for the built-in Hello World server.
|
// HelloWorld is an OriginService for the built-in Hello World server.
|
||||||
// Users only use this for testing and experimenting with cloudflared.
|
// Users only use this for testing and experimenting with cloudflared.
|
||||||
type helloWorld struct {
|
type helloWorld struct {
|
||||||
server net.Listener
|
httpService
|
||||||
transport *http.Transport
|
server net.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *helloWorld) String() string {
|
func (o *helloWorld) String() string {
|
||||||
|
@ -187,11 +175,10 @@ func (o *helloWorld) start(
|
||||||
errC chan error,
|
errC chan error,
|
||||||
cfg OriginRequestConfig,
|
cfg OriginRequestConfig,
|
||||||
) error {
|
) error {
|
||||||
transport, err := newHTTPTransport(o, cfg, log)
|
if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
o.transport = transport
|
|
||||||
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
|
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Cannot start Hello World Server")
|
return errors.Wrap(err, "Cannot start Hello World Server")
|
||||||
|
@ -202,6 +189,12 @@ func (o *helloWorld) start(
|
||||||
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
|
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
|
||||||
}()
|
}()
|
||||||
o.server = helloListener
|
o.server = helloListener
|
||||||
|
|
||||||
|
o.httpService.url = &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: o.server.Addr().String(),
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
lbProbe: lbProbe,
|
lbProbe: lbProbe,
|
||||||
rule: ingress.ServiceWarpRouting,
|
rule: ingress.ServiceWarpRouting,
|
||||||
}
|
}
|
||||||
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil {
|
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
|
||||||
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
|
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -96,7 +96,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
return fmt.Errorf("Not a connection-oriented service")
|
return fmt.Errorf("Not a connection-oriented service")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil {
|
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
|
||||||
p.logRequestError(err, cfRay, ruleNum)
|
p.logRequestError(err, cfRay, ruleNum)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -152,7 +152,6 @@ func (p *proxy) proxyStreamRequest(
|
||||||
serveCtx context.Context,
|
serveCtx context.Context,
|
||||||
w connection.ResponseWriter,
|
w connection.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
sourceConnectionType connection.Type,
|
|
||||||
connectionProxy ingress.StreamBasedOriginProxy,
|
connectionProxy ingress.StreamBasedOriginProxy,
|
||||||
fields logFields,
|
fields logFields,
|
||||||
) error {
|
) error {
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
@ -79,6 +80,11 @@ func (w *mockWSRespWriter) respBody() io.ReadWriter {
|
||||||
return bytes.NewBuffer(data)
|
return bytes.NewBuffer(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *mockWSRespWriter) Close() error {
|
||||||
|
close(w.writeNotification)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
||||||
return w.reader.Read(data)
|
return w.reader.Read(data)
|
||||||
}
|
}
|
||||||
|
@ -125,14 +131,14 @@ func TestProxySingleOrigin(t *testing.T) {
|
||||||
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
|
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log)
|
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log)
|
||||||
t.Run("testProxyHTTP", testProxyHTTP(t, proxy))
|
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
||||||
t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy))
|
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
||||||
t.Run("testProxySSE", testProxySSE(t, proxy))
|
t.Run("testProxySSE", testProxySSE(proxy))
|
||||||
cancel()
|
cancel()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
responseWriter := newMockHTTPRespWriter()
|
responseWriter := newMockHTTPRespWriter()
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
||||||
|
@ -145,23 +151,43 @@ func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
// WSRoute is a websocket echo handler
|
// WSRoute is a websocket echo handler
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
const testTimeout = 5 * time.Second * 1000
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||||
|
defer cancel()
|
||||||
readPipe, writePipe := io.Pipe()
|
readPipe, writePipe := io.Pipe()
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe)
|
||||||
responseWriter := newMockWSRespWriter(readPipe)
|
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
responseWriter := newMockWSRespWriter(nil)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
finished := make(chan struct{})
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
defer wg.Done()
|
errGroup.Go(func() error {
|
||||||
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
|
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
|
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
|
||||||
}()
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
select {
|
||||||
|
case <-finished:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
t.Errorf("Test timed out")
|
||||||
|
readPipe.Close()
|
||||||
|
writePipe.Close()
|
||||||
|
responseWriter.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
msg := []byte("test websocket")
|
msg := []byte("test websocket")
|
||||||
err = wsutil.WriteClientText(writePipe, msg)
|
err = wsutil.WriteClientText(writePipe, msg)
|
||||||
|
@ -179,12 +205,16 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, msg, returnedMsg)
|
require.Equal(t, msg, returnedMsg)
|
||||||
|
|
||||||
cancel()
|
_ = readPipe.Close()
|
||||||
wg.Wait()
|
_ = writePipe.Close()
|
||||||
|
_ = responseWriter.Close()
|
||||||
|
|
||||||
|
close(finished)
|
||||||
|
errGroup.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
pushCount = 50
|
pushCount = 50
|
||||||
|
|
|
@ -3,116 +3,47 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
)
|
)
|
||||||
|
|
||||||
var stripWebsocketHeaders = []string{
|
|
||||||
"Upgrade",
|
|
||||||
"Connection",
|
|
||||||
"Sec-Websocket-Key",
|
|
||||||
"Sec-Websocket-Version",
|
|
||||||
"Sec-Websocket-Extensions",
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
||||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||||
return websocket.IsWebSocketUpgrade(req)
|
return websocket.IsWebSocketUpgrade(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
|
||||||
// the connection. The response body may not contain the entire response and does
|
|
||||||
// not need to be closed by the application.
|
|
||||||
func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
|
|
||||||
req.URL.Scheme = ChangeRequestScheme(req.URL)
|
|
||||||
wsHeaders := websocketHeaders(req)
|
|
||||||
if dialler == nil {
|
|
||||||
dialler = &websocket.Dialer{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
|
|
||||||
if err != nil {
|
|
||||||
return nil, response, err
|
|
||||||
}
|
|
||||||
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
|
||||||
return conn, response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewResponseHeader returns headers needed to return to origin for completing handshake
|
// NewResponseHeader returns headers needed to return to origin for completing handshake
|
||||||
func NewResponseHeader(req *http.Request) http.Header {
|
func NewResponseHeader(req *http.Request) http.Header {
|
||||||
header := http.Header{}
|
header := http.Header{}
|
||||||
header.Add("Connection", "Upgrade")
|
header.Add("Connection", "Upgrade")
|
||||||
header.Add("Sec-Websocket-Accept", generateAcceptKey(req))
|
header.Add("Sec-Websocket-Accept", generateAcceptKey(req.Header.Get("Sec-WebSocket-Key")))
|
||||||
header.Add("Upgrade", "websocket")
|
header.Add("Upgrade", "websocket")
|
||||||
return header
|
return header
|
||||||
}
|
}
|
||||||
|
|
||||||
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
|
||||||
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
|
|
||||||
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
|
|
||||||
func websocketHeaders(req *http.Request) http.Header {
|
|
||||||
wsHeaders := make(http.Header)
|
|
||||||
for key, val := range req.Header {
|
|
||||||
wsHeaders[key] = val
|
|
||||||
}
|
|
||||||
// Assume the header keys are in canonical format.
|
|
||||||
for _, header := range stripWebsocketHeaders {
|
|
||||||
wsHeaders.Del(header)
|
|
||||||
}
|
|
||||||
wsHeaders.Set("Host", req.Host) // See TUN-1097
|
|
||||||
return wsHeaders
|
|
||||||
}
|
|
||||||
|
|
||||||
// sha1Base64 sha1 and then base64 encodes str.
|
|
||||||
func sha1Base64(str string) string {
|
|
||||||
hasher := sha1.New()
|
|
||||||
_, _ = io.WriteString(hasher, str)
|
|
||||||
hash := hasher.Sum(nil)
|
|
||||||
return base64.StdEncoding.EncodeToString(hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
|
|
||||||
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
|
|
||||||
func generateAcceptKey(req *http.Request) string {
|
|
||||||
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
|
|
||||||
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
|
|
||||||
func ChangeRequestScheme(reqURL *url.URL) string {
|
|
||||||
switch reqURL.Scheme {
|
|
||||||
case "https":
|
|
||||||
return "wss"
|
|
||||||
case "http":
|
|
||||||
return "ws"
|
|
||||||
case "":
|
|
||||||
return "ws"
|
|
||||||
default:
|
|
||||||
return reqURL.Scheme
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stream copies copy data to & from provided io.ReadWriters.
|
// Stream copies copy data to & from provided io.ReadWriters.
|
||||||
func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
|
func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
|
||||||
proxyDone := make(chan struct{}, 2)
|
proxyDone := make(chan struct{}, 2)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(conn, backendConn)
|
_, err := copyData(tunnelConn, originConn, "origin->tunnel")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Msgf("conn to backendConn copy: %v", err)
|
log.Debug().Msgf("origin to tunnel copy: %v", err)
|
||||||
}
|
}
|
||||||
proxyDone <- struct{}{}
|
proxyDone <- struct{}{}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(backendConn, conn)
|
_, err := copyData(originConn, tunnelConn, "tunnel->origin")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Msgf("backendConn to conn copy: %v", err)
|
log.Debug().Msgf("tunnel to origin copy: %v", err)
|
||||||
}
|
}
|
||||||
proxyDone <- struct{}{}
|
proxyDone <- struct{}{}
|
||||||
}()
|
}()
|
||||||
|
@ -120,3 +51,60 @@ func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
|
||||||
// If one side is done, we are done.
|
// If one side is done, we are done.
|
||||||
<-proxyDone
|
<-proxyDone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// when set to true, enables logging of content copied to/from origin and tunnel
|
||||||
|
const debugCopy = false
|
||||||
|
|
||||||
|
func copyData(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
|
||||||
|
if debugCopy {
|
||||||
|
// copyBuffer is based on stdio Copy implementation but shows copied data
|
||||||
|
copyBuffer := func(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
|
||||||
|
var buf []byte
|
||||||
|
size := 32 * 1024
|
||||||
|
buf = make([]byte, size)
|
||||||
|
for {
|
||||||
|
t := time.Now()
|
||||||
|
nr, er := src.Read(buf)
|
||||||
|
if nr > 0 {
|
||||||
|
fmt.Println(dir, t.UnixNano(), "\n"+hex.Dump(buf[0:nr]))
|
||||||
|
nw, ew := dst.Write(buf[0:nr])
|
||||||
|
if nw < 0 || nr < nw {
|
||||||
|
nw = 0
|
||||||
|
if ew == nil {
|
||||||
|
ew = errors.New("invalid write")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
written += int64(nw)
|
||||||
|
if ew != nil {
|
||||||
|
err = ew
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if nr != nw {
|
||||||
|
err = io.ErrShortWrite
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if er != nil {
|
||||||
|
if er != io.EOF {
|
||||||
|
err = er
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
return copyBuffer(dst, src, dir)
|
||||||
|
} else {
|
||||||
|
return io.Copy(dst, src)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// from RFC-6455
|
||||||
|
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
|
||||||
|
func generateAcceptKey(challengeKey string) string {
|
||||||
|
h := sha1.New()
|
||||||
|
h.Write([]byte(challengeKey))
|
||||||
|
h.Write(keyGUID)
|
||||||
|
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
|
@ -1,24 +1,9 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/rand"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
gws "github.com/gorilla/websocket"
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/net/websocket"
|
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
|
||||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -28,126 +13,6 @@ const (
|
||||||
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||||
)
|
)
|
||||||
|
|
||||||
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
|
|
||||||
req, err := http.NewRequest("GET", url, stream)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("testRequestHeader error")
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Add("Connection", "Upgrade")
|
|
||||||
req.Header.Add("Upgrade", "WebSocket")
|
|
||||||
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
|
|
||||||
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
|
|
||||||
req.Header.Add("Sec-Websocket-Version", "13")
|
|
||||||
req.Header.Add("User-Agent", "curl/7.59.0")
|
|
||||||
|
|
||||||
return req
|
|
||||||
}
|
|
||||||
|
|
||||||
func websocketClientTLSConfig(t *testing.T) *tls.Config {
|
|
||||||
certPool := x509.NewCertPool()
|
|
||||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
certPool.AddCert(helloCert)
|
|
||||||
assert.NotNil(t, certPool)
|
|
||||||
return &tls.Config{RootCAs: certPool}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWebsocketHeaders(t *testing.T) {
|
|
||||||
req := testRequest(t, "http://example.com", nil)
|
|
||||||
wsHeaders := websocketHeaders(req)
|
|
||||||
for _, header := range stripWebsocketHeaders {
|
|
||||||
assert.Empty(t, wsHeaders[header])
|
|
||||||
}
|
|
||||||
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateAcceptKey(t *testing.T) {
|
func TestGenerateAcceptKey(t *testing.T) {
|
||||||
req := testRequest(t, "http://example.com", nil)
|
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(testSecWebsocketKey))
|
||||||
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServe(t *testing.T) {
|
|
||||||
log := zerolog.Nop()
|
|
||||||
shutdownC := make(chan struct{})
|
|
||||||
errC := make(chan error)
|
|
||||||
listener, err := hello.CreateTLSListener("localhost:1111")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer listener.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
|
|
||||||
}()
|
|
||||||
|
|
||||||
req := testRequest(t, "https://localhost:1111/ws", nil)
|
|
||||||
|
|
||||||
tlsConfig := websocketClientTLSConfig(t)
|
|
||||||
assert.NotNil(t, tlsConfig)
|
|
||||||
d := gws.Dialer{TLSClientConfig: tlsConfig}
|
|
||||||
conn, resp, err := ClientConnect(req, &d)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
|
||||||
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
messageSize := rand.Int()%2048 + 1
|
|
||||||
clientMessage := make([]byte, messageSize)
|
|
||||||
// rand.Read always returns len(clientMessage) and a nil error
|
|
||||||
rand.Read(clientMessage)
|
|
||||||
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
messageType, message, err := conn.ReadMessage()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, websocket.BinaryFrame, messageType)
|
|
||||||
assert.Equal(t, clientMessage, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = conn.Close()
|
|
||||||
close(shutdownC)
|
|
||||||
<-errC
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWebsocketWrapper(t *testing.T) {
|
|
||||||
|
|
||||||
listener, err := hello.CreateTLSListener("localhost:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverErrorChan := make(chan error)
|
|
||||||
helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
|
|
||||||
defer func() { <-serverErrorChan }()
|
|
||||||
defer cancelHelloSvr()
|
|
||||||
go func() {
|
|
||||||
log := zerolog.Nop()
|
|
||||||
serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
|
|
||||||
}()
|
|
||||||
|
|
||||||
tlsConfig := websocketClientTLSConfig(t)
|
|
||||||
d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
|
|
||||||
testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
|
|
||||||
req := testRequest(t, testAddr, nil)
|
|
||||||
conn, resp, err := ClientConnect(req, &d)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
|
||||||
|
|
||||||
// Websocket now connected to test server so lets check our wrapper
|
|
||||||
wrapper := GorillaConn{Conn: conn}
|
|
||||||
buf := make([]byte, 100)
|
|
||||||
wrapper.Write([]byte("abc"))
|
|
||||||
n, err := wrapper.Read(buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, n, 3)
|
|
||||||
require.Equal(t, "abc", string(buf[:n]))
|
|
||||||
|
|
||||||
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
|
|
||||||
wrapper.Write([]byte("abc"))
|
|
||||||
buf = buf[:1]
|
|
||||||
n, err = wrapper.Read(buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, n, 1)
|
|
||||||
require.Equal(t, "a", string(buf[:n]))
|
|
||||||
buf = buf[:cap(buf)]
|
|
||||||
n, err = wrapper.Read(buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, n, 2)
|
|
||||||
require.Equal(t, "bc", string(buf[:n]))
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue