From f1b57526b3e9fbc18e0f19df3696e4dc453e8514 Mon Sep 17 00:00:00 2001 From: Sudarsan Reddy Date: Thu, 1 Jul 2021 10:29:53 +0100 Subject: [PATCH] TUN-4626: Proxy non-stream based origin websockets with http Roundtrip. Reuses HTTPProxy's Roundtrip method to directly proxy websockets from eyeball clients (determined by websocket type and ingress not being connection oriented , i.e. Not ssh or smb for example) to proxy websocket traffic. --- ingress/origin_connection_test.go | 66 ----------------------- ingress/origin_proxy.go | 73 ++++---------------------- ingress/origin_proxy_test.go | 56 -------------------- origin/proxy.go | 87 ++++++++++++++++++++----------- origin/proxy_test.go | 12 ++++- 5 files changed, 76 insertions(+), 218 deletions(-) diff --git a/ingress/origin_connection_test.go b/ingress/origin_connection_test.go index d3294fca..040662a0 100644 --- a/ingress/origin_connection_test.go +++ b/ingress/origin_connection_test.go @@ -9,7 +9,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "sync" "testing" "time" @@ -190,71 +189,6 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) { } } -func TestStreamWSConnection(t *testing.T) { - eyeballConn, edgeConn := net.Pipe() - - origin := echoWSOrigin(t, true) - defer origin.Close() - - var svc httpService - err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{ - NoTLSVerify: true, - }) - require.NoError(t, err) - - req, err := http.NewRequest(http.MethodGet, origin.URL, nil) - require.NoError(t, err) - req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - - conn, resp, err := svc.newWebsocketProxyConnection(req) - - require.NoError(t, err) - defer conn.Close() - - require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) - require.Equal(t, "Upgrade", resp.Header.Get("Connection")) - require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept")) - require.Equal(t, "websocket", resp.Header.Get("Upgrade")) - - ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) - defer cancel() - - connClosed := make(chan struct{}) - - errGroup, ctx := errgroup.WithContext(ctx) - errGroup.Go(func() error { - select { - case <-connClosed: - case <-ctx.Done(): - } - if ctx.Err() == context.DeadlineExceeded { - eyeballConn.Close() - edgeConn.Close() - conn.Close() - } - - return ctx.Err() - }) - - errGroup.Go(func() error { - echoWSEyeball(t, eyeballConn) - fmt.Println("closing pipe") - edgeConn.Close() - return eyeballConn.Close() - }) - - errGroup.Go(func() error { - defer conn.Close() - conn.Stream(ctx, edgeConn, testLogger) - close(connClosed) - return nil - }) - - require.NoError(t, errGroup.Wait()) -} - type wsEyeball struct { conn net.Conn } diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index e5956c59..6af8e1cd 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -2,10 +2,8 @@ package ingress import ( "fmt" - "io" "net" "net/http" - "strings" "github.com/pkg/errors" @@ -36,7 +34,15 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { // Rewrite the request URL so that it goes to the origin service. req.URL.Host = o.url.Host - req.URL.Scheme = o.url.Scheme + switch o.url.Scheme { + case "ws": + req.URL.Scheme = "http" + case "wss": + req.URL.Scheme = "https" + default: + req.URL.Scheme = o.url.Scheme + } + if o.hostHeader != "" { // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. req.Host = o.hostHeader @@ -44,67 +50,6 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { return o.transport.RoundTrip(req) } -func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { - req = req.Clone(req.Context()) - - req.URL.Host = o.url.Host - req.URL.Scheme = o.url.Scheme - // allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail - switch req.URL.Scheme { - case "ws": - req.URL.Scheme = "http" - case "wss": - req.URL.Scheme = "https" - } - - if o.hostHeader != "" { - // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. - req.Host = o.hostHeader - } - - return o.newWebsocketProxyConnection(req) -} - -func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) { - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Sec-WebSocket-Version", "13") - - req.ContentLength = 0 - req.Body = nil - - resp, err := o.transport.RoundTrip(req) - if err != nil { - return nil, nil, err - } - - toClose := resp.Body - defer func() { - if toClose != nil { - _ = toClose.Close() - } - }() - - if resp.StatusCode != http.StatusSwitchingProtocols { - return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status) - } - if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" { - return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade")) - } - - rwc, ok := resp.Body.(io.ReadWriteCloser) - if !ok { - return nil, nil, errUnsupportedConnectionType - } - conn := wsProxyConnection{ - rwc: rwc, - } - // clear to prevent defer from closing - toClose = nil - - return &conn, resp, nil -} - func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { return o.resp, nil } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 4939409f..1e54c2fe 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -2,7 +2,6 @@ package ingress import ( "context" - "crypto/tls" "fmt" "net" "net/http" @@ -32,57 +31,6 @@ func assertEstablishConnectionResponse(t *testing.T, assert.Equal(t, expectHeader, resp.Header) } -func TestHTTPServiceEstablishConnection(t *testing.T) { - origin := echoWSOrigin(t, false) - defer origin.Close() - originURL, err := url.Parse(origin.URL) - require.NoError(t, err) - - httpService := &httpService{ - url: originURL, - hostHeader: origin.URL, - transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - } - req, err := http.NewRequest(http.MethodGet, origin.URL, nil) - require.NoError(t, err) - req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") - req.Header.Set("Test-Cloudflared-Echo", t.Name()) - - expectHeader := http.Header{ - "Connection": {"Upgrade"}, - "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, - "Upgrade": {"websocket"}, - "Test-Cloudflared-Echo": {t.Name()}, - } - assertEstablishConnectionResponse(t, httpService, req, expectHeader) -} - -func TestHelloWorldEstablishConnection(t *testing.T) { - var wg sync.WaitGroup - shutdownC := make(chan struct{}) - errC := make(chan error) - helloWorldSerivce := &helloWorld{} - helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{}) - - // Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service - req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil) - require.NoError(t, err) - req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") - - expectHeader := http.Header{ - "Connection": {"Upgrade"}, - "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, - "Upgrade": {"websocket"}, - } - assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader) - - close(shutdownC) -} - func TestRawTCPServiceEstablishConnection(t *testing.T) { originListener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -218,10 +166,6 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - req = req.Clone(context.Background()) - _, resp, err = httpService.EstablishConnection(req) - require.NoError(t, err) - require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) } func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) { diff --git a/origin/proxy.go b/origin/proxy.go index 1a5cfa8b..a10452d3 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -15,6 +15,7 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/ingress" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/websocket" ) const ( @@ -85,27 +86,27 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn } p.logRequest(req, logFields) - if sourceConnectionType == connection.TypeHTTP { - if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil { + switch originProxy := rule.Service.(type) { + case ingress.HTTPOriginProxy: + if err := p.proxyHTTPRequest(w, req, originProxy, sourceConnectionType == connection.TypeWebsocket, + rule.Config.DisableChunkedEncoding, logFields); err != nil { rule, srv := ruleField(p.ingressRules, ruleNum) p.logRequestError(err, cfRay, rule, srv) return err } return nil - } - connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy) - if !ok { - p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service) - return fmt.Errorf("Not a connection-oriented service") - } + case ingress.StreamBasedOriginProxy: + if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil { + rule, srv := ruleField(p.ingressRules, ruleNum) + p.logRequestError(err, cfRay, rule, srv) + return err + } + return nil + default: + return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy) - if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil { - rule, srv := ruleField(p.ingressRules, ruleNum) - p.logRequestError(err, cfRay, rule, srv) - return err } - return nil } func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { @@ -116,26 +117,35 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { return fmt.Sprintf("%d", ruleNum), srv } -func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error { - // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate - if rule.Config.DisableChunkedEncoding { - req.TransferEncoding = []string{"gzip", "deflate"} - cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) - if err == nil { - req.ContentLength = int64(cLength) +func (p *proxy) proxyHTTPRequest( + w connection.ResponseWriter, + req *http.Request, + httpService ingress.HTTPOriginProxy, + isWebsocket bool, + disableChunkedEncoding bool, + fields logFields) error { + roundTripReq := req + if isWebsocket { + roundTripReq = req.Clone(req.Context()) + roundTripReq.Header.Set("Connection", "Upgrade") + roundTripReq.Header.Set("Upgrade", "websocket") + roundTripReq.Header.Set("Sec-Websocket-Version", "13") + roundTripReq.ContentLength = 0 + roundTripReq.Body = nil + } else { + // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate + if disableChunkedEncoding { + roundTripReq.TransferEncoding = []string{"gzip", "deflate"} + cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) + if err == nil { + roundTripReq.ContentLength = int64(cLength) + } } + // Request origin to keep connection alive to improve performance + roundTripReq.Header.Set("Connection", "keep-alive") } - // Request origin to keep connection alive to improve performance - req.Header.Set("Connection", "keep-alive") - - httpService, ok := rule.Service.(ingress.HTTPOriginProxy) - if !ok { - p.log.Error().Msgf("%s is not a http service", rule.Service) - return fmt.Errorf("Not a http service") - } - - resp, err := httpService.RoundTrip(req) + resp, err := httpService.RoundTrip(roundTripReq) if err != nil { return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared") } @@ -145,6 +155,23 @@ func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, if err != nil { return errors.Wrap(err, "Error writing response header") } + + if resp.StatusCode == http.StatusSwitchingProtocols { + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + return errors.New("internal error: unsupported connection type") + } + defer rwc.Close() + + eyeballStream := &bidirectionalStream{ + writer: w, + reader: req.Body, + } + + websocket.Stream(eyeballStream, rwc, p.log) + return nil + } + if connection.IsServerSentEvent(resp.Header) { p.log.Debug().Msg("Detected Server-Side Events from Origin") p.writeEventStream(w, resp.Body) diff --git a/origin/proxy_test.go b/origin/proxy_test.go index b7fecbed..6b785647 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -571,8 +571,14 @@ func TestConnections(t *testing.T) { }, }, want: want{ - message: []byte{}, - err: true, + message: []byte("Forbidden\n"), + err: false, + headers: map[string][]string{ + "Content-Length": {"10"}, + "Content-Type": {"text/plain; charset=utf-8"}, + "Sec-Websocket-Version": {"13"}, + "X-Content-Type-Options": {"nosniff"}, + }, }, }, { @@ -806,6 +812,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { // respHeaders is a test function to read respHeaders func (w *wsRespWriter) headers() http.Header { + // Removing indeterminstic header because it cannot be asserted. + w.responseHeaders.Del("Date") return w.responseHeaders }