diff --git a/CHANGES.md b/CHANGES.md index 45c8117d..6e63848d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,24 @@ **Experimental**: This is a new format for release notes. The format and availability is subject to change. +## UNRELEASED + +### Backward Incompatible Changes + +- none + +### New Features + +- none + +### Improvements + +- none + +### Bug Fixes + +- Fixed proxying of websocket requests to avoid possibility of losing initial frames that were sent in the same TCP + packet as response headers [#345](https://github.com/cloudflare/cloudflared/issues/345). + ## 2021.3.6 ### Bug Fixes diff --git a/carrier/websocket.go b/carrier/websocket.go index 6ae4252b..0bf7cea9 100644 --- a/carrier/websocket.go +++ b/carrier/websocket.go @@ -4,6 +4,7 @@ import ( "io" "net/http" "net/http/httputil" + "net/url" "github.com/gorilla/websocket" "github.com/rs/zerolog" @@ -60,7 +61,7 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso TLSClientConfig: options.TLSClientConfig, Proxy: http.ProxyFromEnvironment, } - wsConn, resp, err := cfwebsocket.ClientConnect(req, dialer) + wsConn, resp, err := clientConnect(req, dialer) defer closeRespBody(resp) if err != nil && IsAccessResponse(resp) { @@ -87,6 +88,63 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso return &cfwebsocket.GorillaConn{Conn: wsConn}, nil } +var stripWebsocketHeaders = []string{ + "Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Extensions", +} + +// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key, +// Sec-WebSocket-Version and Sec-Websocket-Extensions headers. +// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194. +func websocketHeaders(req *http.Request) http.Header { + wsHeaders := make(http.Header) + for key, val := range req.Header { + wsHeaders[key] = val + } + // Assume the header keys are in canonical format. + for _, header := range stripWebsocketHeaders { + wsHeaders.Del(header) + } + wsHeaders.Set("Host", req.Host) // See TUN-1097 + return wsHeaders +} + +// clientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing +// the connection. The response body may not contain the entire response and does +// not need to be closed by the application. +func clientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) { + req.URL.Scheme = changeRequestScheme(req.URL) + wsHeaders := websocketHeaders(req) + if dialler == nil { + dialler = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + } + } + conn, response, err := dialler.Dial(req.URL.String(), wsHeaders) + if err != nil { + return nil, response, err + } + return conn, response, nil +} + +// changeRequestScheme is needed as the gorilla websocket library requires the ws scheme. +// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.) +func changeRequestScheme(reqURL *url.URL) string { + switch reqURL.Scheme { + case "https": + return "wss" + case "http": + return "ws" + case "": + return "ws" + default: + return reqURL.Scheme + } +} + // createAccessAuthenticatedStream will try load a token from storage and make // a connection with the token set on the request. If it still get redirect, // this probably means the token in storage is invalid (expired/revoked). If that @@ -126,7 +184,7 @@ func createAccessWebSocketStream(options *StartOptions, log *zerolog.Logger) (*w dump, err := httputil.DumpRequest(req, false) log.Debug().Msgf("Access Websocket request: %s", string(dump)) - conn, resp, err := cfwebsocket.ClientConnect(req, nil) + conn, resp, err := clientConnect(req, nil) if resp != nil { r, err := httputil.DumpResponse(resp, true) diff --git a/carrier/websocket_test.go b/carrier/websocket_test.go new file mode 100644 index 00000000..a0a32bb8 --- /dev/null +++ b/carrier/websocket_test.go @@ -0,0 +1,123 @@ +package carrier + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "math/rand" + "testing" + "time" + + gws "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/websocket" + + "github.com/cloudflare/cloudflared/hello" + "github.com/cloudflare/cloudflared/tlsconfig" + cfwebsocket "github.com/cloudflare/cloudflared/websocket" +) + +func websocketClientTLSConfig(t *testing.T) *tls.Config { + certPool := x509.NewCertPool() + helloCert, err := tlsconfig.GetHelloCertificateX509() + assert.NoError(t, err) + certPool.AddCert(helloCert) + assert.NotNil(t, certPool) + return &tls.Config{RootCAs: certPool} +} + +func TestWebsocketHeaders(t *testing.T) { + req := testRequest(t, "http://example.com", nil) + wsHeaders := websocketHeaders(req) + for _, header := range stripWebsocketHeaders { + assert.Empty(t, wsHeaders[header]) + } + assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent")) +} + +func TestServe(t *testing.T) { + log := zerolog.Nop() + shutdownC := make(chan struct{}) + errC := make(chan error) + listener, err := hello.CreateTLSListener("localhost:1111") + assert.NoError(t, err) + defer listener.Close() + + go func() { + errC <- hello.StartHelloWorldServer(&log, listener, shutdownC) + }() + + req := testRequest(t, "https://localhost:1111/ws", nil) + + tlsConfig := websocketClientTLSConfig(t) + assert.NotNil(t, tlsConfig) + d := gws.Dialer{TLSClientConfig: tlsConfig} + conn, resp, err := clientConnect(req, &d) + assert.NoError(t, err) + assert.Equal(t, "websocket", resp.Header.Get("Upgrade")) + + for i := 0; i < 1000; i++ { + messageSize := rand.Int()%2048 + 1 + clientMessage := make([]byte, messageSize) + // rand.Read always returns len(clientMessage) and a nil error + rand.Read(clientMessage) + err = conn.WriteMessage(websocket.BinaryFrame, clientMessage) + assert.NoError(t, err) + + messageType, message, err := conn.ReadMessage() + assert.NoError(t, err) + assert.Equal(t, websocket.BinaryFrame, messageType) + assert.Equal(t, clientMessage, message) + } + + _ = conn.Close() + close(shutdownC) + <-errC +} + +func TestWebsocketWrapper(t *testing.T) { + listener, err := hello.CreateTLSListener("localhost:0") + require.NoError(t, err) + + serverErrorChan := make(chan error) + helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background()) + defer func() { <-serverErrorChan }() + defer cancelHelloSvr() + go func() { + log := zerolog.Nop() + serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done()) + }() + + tlsConfig := websocketClientTLSConfig(t) + d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute} + testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String()) + req := testRequest(t, testAddr, nil) + conn, resp, err := clientConnect(req, &d) + require.NoError(t, err) + assert.Equal(t, "websocket", resp.Header.Get("Upgrade")) + + // Websocket now connected to test server so lets check our wrapper + wrapper := cfwebsocket.GorillaConn{Conn: conn} + buf := make([]byte, 100) + wrapper.Write([]byte("abc")) + n, err := wrapper.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 3) + require.Equal(t, "abc", string(buf[:n])) + + // Test partial read, read 1 of 3 bytes in one read and the other 2 in another read + wrapper.Write([]byte("abc")) + buf = buf[:1] + n, err = wrapper.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 1) + require.Equal(t, "a", string(buf[:n])) + buf = buf[:cap(buf)] + n, err = wrapper.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 2) + require.Equal(t, "bc", string(buf[:n])) +} diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index 3ed17186..c97d42fd 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -2,12 +2,9 @@ package ingress import ( "context" - "crypto/tls" "io" "net" - "net/http" - gws "github.com/gorilla/websocket" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/ipaccess" @@ -58,35 +55,6 @@ func (wc *tcpOverWSConnection) Close() { wc.conn.Close() } -// wsConnection is an OriginConnection that streams WS between eyeball and origin. -type wsConnection struct { - wsConn *gws.Conn - resp *http.Response -} - -func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { - websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log) -} - -func (wsc *wsConnection) Close() { - wsc.resp.Body.Close() - wsc.wsConn.Close() -} - -func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) { - d := &gws.Dialer{ - TLSClientConfig: clientTLSConfig, - } - wsConn, resp, err := websocket.ClientConnect(r, d) - if err != nil { - return nil, nil, err - } - return &wsConnection{ - wsConn, - resp, - }, resp, nil -} - // socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS. // The connection to the origin happens inside the SOCKS code as the client specifies the origin // details in the packet. @@ -100,3 +68,16 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io. func (sp *socksProxyOverWSConnection) Close() { } + +// wsProxyConnection represents a bidirectional stream for a websocket connection to the origin +type wsProxyConnection struct { + rwc io.ReadWriteCloser +} + +func (conn *wsProxyConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { + websocket.Stream(tunnelConn, conn.rwc, log) +} + +func (conn *wsProxyConnection) Close() { + conn.rwc.Close() +} diff --git a/ingress/origin_connection_test.go b/ingress/origin_connection_test.go index 78a2a151..d3294fca 100644 --- a/ingress/origin_connection_test.go +++ b/ingress/origin_connection_test.go @@ -3,13 +3,13 @@ package ingress import ( "bytes" "context" - "crypto/tls" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" + "sync" "testing" "time" @@ -193,18 +193,26 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) { func TestStreamWSConnection(t *testing.T) { eyeballConn, edgeConn := net.Pipe() - origin := echoWSOrigin(t) + origin := echoWSOrigin(t, true) defer origin.Close() + var svc httpService + err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{ + NoTLSVerify: true, + }) + require.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, origin.URL, nil) require.NoError(t, err) req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + conn, resp, err := svc.newWebsocketProxyConnection(req) - clientTLSConfig := &tls.Config{ - InsecureSkipVerify: true, - } - wsConn, resp, err := newWSConnection(clientTLSConfig, req) require.NoError(t, err) + defer conn.Close() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) require.Equal(t, "Upgrade", resp.Header.Get("Connection")) require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept")) @@ -213,13 +221,37 @@ func TestStreamWSConnection(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) defer cancel() + connClosed := make(chan struct{}) + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + select { + case <-connClosed: + case <-ctx.Done(): + } + if ctx.Err() == context.DeadlineExceeded { + eyeballConn.Close() + edgeConn.Close() + conn.Close() + } + + return ctx.Err() + }) + errGroup.Go(func() error { echoWSEyeball(t, eyeballConn) + fmt.Println("closing pipe") + edgeConn.Close() + return eyeballConn.Close() + }) + + errGroup.Go(func() error { + defer conn.Close() + conn.Stream(ctx, edgeConn, testLogger) + close(connClosed) return nil }) - wsConn.Stream(ctx, edgeConn, testLogger) require.NoError(t, errGroup.Wait()) } @@ -241,17 +273,23 @@ func (wse *wsEyeball) Write(p []byte) (int, error) { } func echoWSEyeball(t *testing.T, conn net.Conn) { - require.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) + defer func() { + assert.NoError(t, conn.Close()) + }() + + if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) { + return + } readMsg, err := wsutil.ReadServerBinary(conn) - require.NoError(t, err) + if !assert.NoError(t, err) { + return + } - require.Equal(t, testResponse, readMsg) - - require.NoError(t, conn.Close()) + assert.Equal(t, testResponse, readMsg) } -func echoWSOrigin(t *testing.T) *httptest.Server { +func echoWSOrigin(t *testing.T, expectMessages bool) *httptest.Server { var upgrader = gorillaWS.Upgrader{ ReadBufferSize: 10, WriteBufferSize: 10, @@ -268,12 +306,17 @@ func echoWSOrigin(t *testing.T) *httptest.Server { require.NoError(t, err) defer conn.Close() + sawMessage := false for { messageType, p, err := conn.ReadMessage() if err != nil { + if expectMessages && !sawMessage { + t.Errorf("unexpected error: %v", err) + } return } - require.Equal(t, testMessage, p) + assert.Equal(t, testMessage, p) + sawMessage = true if err := conn.WriteMessage(messageType, testResponse); err != nil { return } diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 08ef4bd9..e5956c59 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -2,8 +2,10 @@ package ingress import ( "fmt" + "io" "net" "net/http" + "strings" "github.com/pkg/errors" @@ -12,7 +14,8 @@ import ( ) var ( - switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) + switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) + errUnsupportedConnectionType = errors.New("internal error: unsupported connection type") ) // HTTPOriginProxy can be implemented by origin services that want to proxy http requests. @@ -42,26 +45,64 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { } func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { + req = req.Clone(req.Context()) + req.URL.Host = o.url.Host - req.URL.Scheme = websocket.ChangeRequestScheme(o.url) + req.URL.Scheme = o.url.Scheme + // allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail + switch req.URL.Scheme { + case "ws": + req.URL.Scheme = "http" + case "wss": + req.URL.Scheme = "https" + } + if o.hostHeader != "" { // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. req.Host = o.hostHeader } - return newWSConnection(o.transport.TLSClientConfig, req) + + return o.newWebsocketProxyConnection(req) } -func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { - // Rewrite the request URL so that it goes to the Hello World server. - req.URL.Host = o.server.Addr().String() - req.URL.Scheme = "https" - return o.transport.RoundTrip(req) -} +func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) { + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") -func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { - req.URL.Host = o.server.Addr().String() - req.URL.Scheme = "wss" - return newWSConnection(o.transport.TLSClientConfig, req) + req.ContentLength = 0 + req.Body = nil + + resp, err := o.transport.RoundTrip(req) + if err != nil { + return nil, nil, err + } + + toClose := resp.Body + defer func() { + if toClose != nil { + _ = toClose.Close() + } + }() + + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status) + } + if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" { + return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade")) + } + + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + return nil, nil, errUnsupportedConnectionType + } + conn := wsProxyConnection{ + rwc: rwc, + } + // clear to prevent defer from closing + toClose = nil + + return &conn, resp, nil } func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 22266383..4939409f 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -33,7 +33,7 @@ func assertEstablishConnectionResponse(t *testing.T, } func TestHTTPServiceEstablishConnection(t *testing.T) { - origin := echoWSOrigin(t) + origin := echoWSOrigin(t, false) defer origin.Close() originURL, err := url.Parse(origin.URL) require.NoError(t, err) @@ -71,11 +71,11 @@ func TestHelloWorldEstablishConnection(t *testing.T) { // Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil) require.NoError(t, err) + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") expectHeader := http.Header{ - "Connection": {"Upgrade"}, - // Accept key when Sec-Websocket-Key is not specified - "Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="}, + "Connection": {"Upgrade"}, + "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, "Upgrade": {"websocket"}, } assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader) diff --git a/ingress/origin_service.go b/ingress/origin_service.go index eeff7ebc..a915ec72 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -11,7 +11,6 @@ import ( "sync" "time" - gws "github.com/gorilla/websocket" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -19,7 +18,6 @@ import ( "github.com/cloudflare/cloudflared/ipaccess" "github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/tlsconfig" - "github.com/cloudflare/cloudflared/websocket" ) // originService is something a tunnel can proxy traffic to. @@ -50,16 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown return nil } -func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) { - d := &gws.Dialer{ - NetDial: o.transport.Dial, - NetDialContext: o.transport.DialContext, - TLSClientConfig: o.transport.TLSClientConfig, - } - reqURL.Scheme = websocket.ChangeRequestScheme(reqURL) - return d.Dial(reqURL.String(), headers) -} - type httpService struct { url *url.URL hostHeader string @@ -171,8 +159,8 @@ func (o *socksProxyOverWSService) String() string { // HelloWorld is an OriginService for the built-in Hello World server. // Users only use this for testing and experimenting with cloudflared. type helloWorld struct { - server net.Listener - transport *http.Transport + httpService + server net.Listener } func (o *helloWorld) String() string { @@ -187,11 +175,10 @@ func (o *helloWorld) start( errC chan error, cfg OriginRequestConfig, ) error { - transport, err := newHTTPTransport(o, cfg, log) - if err != nil { + if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil { return err } - o.transport = transport + helloListener, err := hello.CreateTLSListener("127.0.0.1:") if err != nil { return errors.Wrap(err, "Cannot start Hello World Server") @@ -202,6 +189,12 @@ func (o *helloWorld) start( _ = hello.StartHelloWorldServer(log, helloListener, shutdownC) }() o.server = helloListener + + o.httpService.url = &url.URL{ + Scheme: "https", + Host: o.server.Addr().String(), + } + return nil } diff --git a/origin/proxy.go b/origin/proxy.go index 0e114ab0..87f77b3a 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -67,7 +67,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn lbProbe: lbProbe, rule: ingress.ServiceWarpRouting, } - if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil { + if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil { p.logRequestError(err, cfRay, ingress.ServiceWarpRouting) return err } @@ -96,7 +96,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn return fmt.Errorf("Not a connection-oriented service") } - if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil { + if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil { p.logRequestError(err, cfRay, ruleNum) return err } @@ -152,7 +152,6 @@ func (p *proxy) proxyStreamRequest( serveCtx context.Context, w connection.ResponseWriter, req *http.Request, - sourceConnectionType connection.Type, connectionProxy ingress.StreamBasedOriginProxy, fields logFields, ) error { diff --git a/origin/proxy_test.go b/origin/proxy_test.go index dc4bdbe4..1af9add4 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/urfave/cli/v2" + "golang.org/x/sync/errgroup" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" @@ -79,6 +80,11 @@ func (w *mockWSRespWriter) respBody() io.ReadWriter { return bytes.NewBuffer(data) } +func (w *mockWSRespWriter) Close() error { + close(w.writeNotification) + return nil +} + func (w *mockWSRespWriter) Read(data []byte) (int, error) { return w.reader.Read(data) } @@ -125,14 +131,14 @@ func TestProxySingleOrigin(t *testing.T) { require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC)) proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log) - t.Run("testProxyHTTP", testProxyHTTP(t, proxy)) - t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy)) - t.Run("testProxySSE", testProxySSE(t, proxy)) + t.Run("testProxyHTTP", testProxyHTTP(proxy)) + t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) + t.Run("testProxySSE", testProxySSE(proxy)) cancel() wg.Wait() } -func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) { +func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) { return func(t *testing.T) { responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) @@ -145,23 +151,43 @@ func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T } } -func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) { +func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) { return func(t *testing.T) { // WSRoute is a websocket echo handler - ctx, cancel := context.WithCancel(context.Background()) + const testTimeout = 5 * time.Second * 1000 + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() readPipe, writePipe := io.Pipe() req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe) - responseWriter := newMockWSRespWriter(readPipe) + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + responseWriter := newMockWSRespWriter(nil) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() + finished := make(chan struct{}) + + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket) require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code) - }() + return nil + }) + + errGroup.Go(func() error { + select { + case <-finished: + case <-ctx.Done(): + } + if ctx.Err() == context.DeadlineExceeded { + t.Errorf("Test timed out") + readPipe.Close() + writePipe.Close() + responseWriter.Close() + } + return nil + }) msg := []byte("test websocket") err = wsutil.WriteClientText(writePipe, msg) @@ -179,12 +205,16 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test require.NoError(t, err) require.Equal(t, msg, returnedMsg) - cancel() - wg.Wait() + _ = readPipe.Close() + _ = writePipe.Close() + _ = responseWriter.Close() + + close(finished) + errGroup.Wait() } } -func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) { +func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) { return func(t *testing.T) { var ( pushCount = 50 diff --git a/websocket/websocket.go b/websocket/websocket.go index 6a619b70..b94b4f54 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -3,116 +3,47 @@ package websocket import ( "crypto/sha1" "encoding/base64" + "encoding/hex" + "errors" + "fmt" "io" "net/http" - "net/url" + "time" "github.com/gorilla/websocket" "github.com/rs/zerolog" ) -var stripWebsocketHeaders = []string{ - "Upgrade", - "Connection", - "Sec-Websocket-Key", - "Sec-Websocket-Version", - "Sec-Websocket-Extensions", -} - // IsWebSocketUpgrade checks to see if the request is a WebSocket connection. func IsWebSocketUpgrade(req *http.Request) bool { return websocket.IsWebSocketUpgrade(req) } -// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing -// the connection. The response body may not contain the entire response and does -// not need to be closed by the application. -func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) { - req.URL.Scheme = ChangeRequestScheme(req.URL) - wsHeaders := websocketHeaders(req) - if dialler == nil { - dialler = &websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - } - } - conn, response, err := dialler.Dial(req.URL.String(), wsHeaders) - if err != nil { - return nil, response, err - } - response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req)) - return conn, response, nil -} - // NewResponseHeader returns headers needed to return to origin for completing handshake func NewResponseHeader(req *http.Request) http.Header { header := http.Header{} header.Add("Connection", "Upgrade") - header.Add("Sec-Websocket-Accept", generateAcceptKey(req)) + header.Add("Sec-Websocket-Accept", generateAcceptKey(req.Header.Get("Sec-WebSocket-Key"))) header.Add("Upgrade", "websocket") return header } -// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key, -// Sec-WebSocket-Version and Sec-Websocket-Extensions headers. -// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194. -func websocketHeaders(req *http.Request) http.Header { - wsHeaders := make(http.Header) - for key, val := range req.Header { - wsHeaders[key] = val - } - // Assume the header keys are in canonical format. - for _, header := range stripWebsocketHeaders { - wsHeaders.Del(header) - } - wsHeaders.Set("Host", req.Host) // See TUN-1097 - return wsHeaders -} - -// sha1Base64 sha1 and then base64 encodes str. -func sha1Base64(str string) string { - hasher := sha1.New() - _, _ = io.WriteString(hasher, str) - hash := hasher.Sum(nil) - return base64.StdEncoding.EncodeToString(hash) -} - -// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header. -// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail. -func generateAcceptKey(req *http.Request) string { - return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") -} - -// ChangeRequestScheme is needed as the gorilla websocket library requires the ws scheme. -// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.) -func ChangeRequestScheme(reqURL *url.URL) string { - switch reqURL.Scheme { - case "https": - return "wss" - case "http": - return "ws" - case "": - return "ws" - default: - return reqURL.Scheme - } -} - // Stream copies copy data to & from provided io.ReadWriters. -func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) { +func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) { proxyDone := make(chan struct{}, 2) go func() { - _, err := io.Copy(conn, backendConn) + _, err := copyData(tunnelConn, originConn, "origin->tunnel") if err != nil { - log.Debug().Msgf("conn to backendConn copy: %v", err) + log.Debug().Msgf("origin to tunnel copy: %v", err) } proxyDone <- struct{}{} }() go func() { - _, err := io.Copy(backendConn, conn) + _, err := copyData(originConn, tunnelConn, "tunnel->origin") if err != nil { - log.Debug().Msgf("backendConn to conn copy: %v", err) + log.Debug().Msgf("tunnel to origin copy: %v", err) } proxyDone <- struct{}{} }() @@ -120,3 +51,60 @@ func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) { // If one side is done, we are done. <-proxyDone } + +// when set to true, enables logging of content copied to/from origin and tunnel +const debugCopy = false + +func copyData(dst io.Writer, src io.Reader, dir string) (written int64, err error) { + if debugCopy { + // copyBuffer is based on stdio Copy implementation but shows copied data + copyBuffer := func(dst io.Writer, src io.Reader, dir string) (written int64, err error) { + var buf []byte + size := 32 * 1024 + buf = make([]byte, size) + for { + t := time.Now() + nr, er := src.Read(buf) + if nr > 0 { + fmt.Println(dir, t.UnixNano(), "\n"+hex.Dump(buf[0:nr])) + nw, ew := dst.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errors.New("invalid write") + } + } + written += int64(nw) + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err + } + return copyBuffer(dst, src, dir) + } else { + return io.Copy(dst, src) + } +} + +// from RFC-6455 +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func generateAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index e738a106..5d661a88 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -1,24 +1,9 @@ package websocket import ( - "context" - "crypto/tls" - "crypto/x509" - "fmt" - "io" - "math/rand" - "net/http" "testing" - "time" - gws "github.com/gorilla/websocket" - "github.com/rs/zerolog" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/websocket" - - "github.com/cloudflare/cloudflared/hello" - "github.com/cloudflare/cloudflared/tlsconfig" ) const ( @@ -28,126 +13,6 @@ const ( testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" ) -func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { - req, err := http.NewRequest("GET", url, stream) - if err != nil { - t.Fatalf("testRequestHeader error") - } - - req.Header.Add("Connection", "Upgrade") - req.Header.Add("Upgrade", "WebSocket") - req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey) - req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol") - req.Header.Add("Sec-Websocket-Version", "13") - req.Header.Add("User-Agent", "curl/7.59.0") - - return req -} - -func websocketClientTLSConfig(t *testing.T) *tls.Config { - certPool := x509.NewCertPool() - helloCert, err := tlsconfig.GetHelloCertificateX509() - assert.NoError(t, err) - certPool.AddCert(helloCert) - assert.NotNil(t, certPool) - return &tls.Config{RootCAs: certPool} -} - -func TestWebsocketHeaders(t *testing.T) { - req := testRequest(t, "http://example.com", nil) - wsHeaders := websocketHeaders(req) - for _, header := range stripWebsocketHeaders { - assert.Empty(t, wsHeaders[header]) - } - assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent")) -} - func TestGenerateAcceptKey(t *testing.T) { - req := testRequest(t, "http://example.com", nil) - assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req)) -} - -func TestServe(t *testing.T) { - log := zerolog.Nop() - shutdownC := make(chan struct{}) - errC := make(chan error) - listener, err := hello.CreateTLSListener("localhost:1111") - assert.NoError(t, err) - defer listener.Close() - - go func() { - errC <- hello.StartHelloWorldServer(&log, listener, shutdownC) - }() - - req := testRequest(t, "https://localhost:1111/ws", nil) - - tlsConfig := websocketClientTLSConfig(t) - assert.NotNil(t, tlsConfig) - d := gws.Dialer{TLSClientConfig: tlsConfig} - conn, resp, err := ClientConnect(req, &d) - assert.NoError(t, err) - assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept")) - - for i := 0; i < 1000; i++ { - messageSize := rand.Int()%2048 + 1 - clientMessage := make([]byte, messageSize) - // rand.Read always returns len(clientMessage) and a nil error - rand.Read(clientMessage) - err = conn.WriteMessage(websocket.BinaryFrame, clientMessage) - assert.NoError(t, err) - - messageType, message, err := conn.ReadMessage() - assert.NoError(t, err) - assert.Equal(t, websocket.BinaryFrame, messageType) - assert.Equal(t, clientMessage, message) - } - - _ = conn.Close() - close(shutdownC) - <-errC -} - -func TestWebsocketWrapper(t *testing.T) { - - listener, err := hello.CreateTLSListener("localhost:0") - require.NoError(t, err) - - serverErrorChan := make(chan error) - helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background()) - defer func() { <-serverErrorChan }() - defer cancelHelloSvr() - go func() { - log := zerolog.Nop() - serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done()) - }() - - tlsConfig := websocketClientTLSConfig(t) - d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute} - testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String()) - req := testRequest(t, testAddr, nil) - conn, resp, err := ClientConnect(req, &d) - require.NoError(t, err) - require.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept")) - - // Websocket now connected to test server so lets check our wrapper - wrapper := GorillaConn{Conn: conn} - buf := make([]byte, 100) - wrapper.Write([]byte("abc")) - n, err := wrapper.Read(buf) - require.NoError(t, err) - require.Equal(t, n, 3) - require.Equal(t, "abc", string(buf[:n])) - - // Test partial read, read 1 of 3 bytes in one read and the other 2 in another read - wrapper.Write([]byte("abc")) - buf = buf[:1] - n, err = wrapper.Read(buf) - require.NoError(t, err) - require.Equal(t, n, 1) - require.Equal(t, "a", string(buf[:n])) - buf = buf[:cap(buf)] - n, err = wrapper.Read(buf) - require.NoError(t, err) - require.Equal(t, n, 2) - require.Equal(t, "bc", string(buf[:n])) + assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(testSecWebsocketKey)) }