TUN-4626: Proxy non-stream based origin websockets with http Roundtrip.

Reuses HTTPProxy's Roundtrip method to directly proxy websockets from
eyeball clients (determined by websocket type and ingress not being
connection oriented , i.e. Not ssh or smb for example) to proxy
websocket traffic.
This commit is contained in:
Sudarsan Reddy 2021-07-01 10:29:53 +01:00
parent 3eb9efd9f0
commit f1b57526b3
5 changed files with 76 additions and 218 deletions

View File

@ -9,7 +9,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"sync"
"testing" "testing"
"time" "time"
@ -190,71 +189,6 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
} }
} }
func TestStreamWSConnection(t *testing.T) {
eyeballConn, edgeConn := net.Pipe()
origin := echoWSOrigin(t, true)
defer origin.Close()
var svc httpService
err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{
NoTLSVerify: true,
})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
conn, resp, err := svc.newWebsocketProxyConnection(req)
require.NoError(t, err)
defer conn.Close()
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
connClosed := make(chan struct{})
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
select {
case <-connClosed:
case <-ctx.Done():
}
if ctx.Err() == context.DeadlineExceeded {
eyeballConn.Close()
edgeConn.Close()
conn.Close()
}
return ctx.Err()
})
errGroup.Go(func() error {
echoWSEyeball(t, eyeballConn)
fmt.Println("closing pipe")
edgeConn.Close()
return eyeballConn.Close()
})
errGroup.Go(func() error {
defer conn.Close()
conn.Stream(ctx, edgeConn, testLogger)
close(connClosed)
return nil
})
require.NoError(t, errGroup.Wait())
}
type wsEyeball struct { type wsEyeball struct {
conn net.Conn conn net.Conn
} }

View File

@ -2,10 +2,8 @@ package ingress
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -36,7 +34,15 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service. // Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.url.Host req.URL.Host = o.url.Host
req.URL.Scheme = o.url.Scheme switch o.url.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
default:
req.URL.Scheme = o.url.Scheme
}
if o.hostHeader != "" { if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader req.Host = o.hostHeader
@ -44,67 +50,6 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req) return o.transport.RoundTrip(req)
} }
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req = req.Clone(req.Context())
req.URL.Host = o.url.Host
req.URL.Scheme = o.url.Scheme
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
}
return o.newWebsocketProxyConnection(req)
}
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.ContentLength = 0
req.Body = nil
resp, err := o.transport.RoundTrip(req)
if err != nil {
return nil, nil, err
}
toClose := resp.Body
defer func() {
if toClose != nil {
_ = toClose.Close()
}
}()
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return nil, nil, errUnsupportedConnectionType
}
conn := wsProxyConnection{
rwc: rwc,
}
// clear to prevent defer from closing
toClose = nil
return &conn, resp, nil
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil return o.resp, nil
} }

View File

@ -2,7 +2,6 @@ package ingress
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -32,57 +31,6 @@ func assertEstablishConnectionResponse(t *testing.T,
assert.Equal(t, expectHeader, resp.Header) assert.Equal(t, expectHeader, resp.Header)
} }
func TestHTTPServiceEstablishConnection(t *testing.T) {
origin := echoWSOrigin(t, false)
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
httpService := &httpService{
url: originURL,
hostHeader: origin.URL,
transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Test-Cloudflared-Echo", t.Name())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
"Test-Cloudflared-Echo": {t.Name()},
}
assertEstablishConnectionResponse(t, httpService, req, expectHeader)
}
func TestHelloWorldEstablishConnection(t *testing.T) {
var wg sync.WaitGroup
shutdownC := make(chan struct{})
errC := make(chan error)
helloWorldSerivce := &helloWorld{}
helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{})
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
close(shutdownC)
}
func TestRawTCPServiceEstablishConnection(t *testing.T) { func TestRawTCPServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0") originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)
@ -218,10 +166,6 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, http.StatusOK, resp.StatusCode)
req = req.Clone(context.Background())
_, resp, err = httpService.EstablishConnection(req)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
} }
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) { func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {

View File

@ -15,6 +15,7 @@ import (
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
) )
const ( const (
@ -85,27 +86,27 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
} }
p.logRequest(req, logFields) p.logRequest(req, logFields)
if sourceConnectionType == connection.TypeHTTP { switch originProxy := rule.Service.(type) {
if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil { case ingress.HTTPOriginProxy:
if err := p.proxyHTTPRequest(w, req, originProxy, sourceConnectionType == connection.TypeWebsocket,
rule.Config.DisableChunkedEncoding, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum) rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv) p.logRequestError(err, cfRay, rule, srv)
return err return err
} }
return nil return nil
}
connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy) case ingress.StreamBasedOriginProxy:
if !ok { if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil {
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service) rule, srv := ruleField(p.ingressRules, ruleNum)
return fmt.Errorf("Not a connection-oriented service") p.logRequestError(err, cfRay, rule, srv)
} return err
}
return nil
default:
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
} }
return nil
} }
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
@ -116,26 +117,35 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
return fmt.Sprintf("%d", ruleNum), srv return fmt.Sprintf("%d", ruleNum), srv
} }
func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error { func (p *proxy) proxyHTTPRequest(
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate w connection.ResponseWriter,
if rule.Config.DisableChunkedEncoding { req *http.Request,
req.TransferEncoding = []string{"gzip", "deflate"} httpService ingress.HTTPOriginProxy,
cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) isWebsocket bool,
if err == nil { disableChunkedEncoding bool,
req.ContentLength = int64(cLength) fields logFields) error {
roundTripReq := req
if isWebsocket {
roundTripReq = req.Clone(req.Context())
roundTripReq.Header.Set("Connection", "Upgrade")
roundTripReq.Header.Set("Upgrade", "websocket")
roundTripReq.Header.Set("Sec-Websocket-Version", "13")
roundTripReq.ContentLength = 0
roundTripReq.Body = nil
} else {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if disableChunkedEncoding {
roundTripReq.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
roundTripReq.ContentLength = int64(cLength)
}
} }
// Request origin to keep connection alive to improve performance
roundTripReq.Header.Set("Connection", "keep-alive")
} }
// Request origin to keep connection alive to improve performance resp, err := httpService.RoundTrip(roundTripReq)
req.Header.Set("Connection", "keep-alive")
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
if !ok {
p.log.Error().Msgf("%s is not a http service", rule.Service)
return fmt.Errorf("Not a http service")
}
resp, err := httpService.RoundTrip(req)
if err != nil { if err != nil {
return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared") return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared")
} }
@ -145,6 +155,23 @@ func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request,
if err != nil { if err != nil {
return errors.Wrap(err, "Error writing response header") return errors.Wrap(err, "Error writing response header")
} }
if resp.StatusCode == http.StatusSwitchingProtocols {
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return errors.New("internal error: unsupported connection type")
}
defer rwc.Close()
eyeballStream := &bidirectionalStream{
writer: w,
reader: req.Body,
}
websocket.Stream(eyeballStream, rwc, p.log)
return nil
}
if connection.IsServerSentEvent(resp.Header) { if connection.IsServerSentEvent(resp.Header) {
p.log.Debug().Msg("Detected Server-Side Events from Origin") p.log.Debug().Msg("Detected Server-Side Events from Origin")
p.writeEventStream(w, resp.Body) p.writeEventStream(w, resp.Body)

View File

@ -571,8 +571,14 @@ func TestConnections(t *testing.T) {
}, },
}, },
want: want{ want: want{
message: []byte{}, message: []byte("Forbidden\n"),
err: true, err: false,
headers: map[string][]string{
"Content-Length": {"10"},
"Content-Type": {"text/plain; charset=utf-8"},
"Sec-Websocket-Version": {"13"},
"X-Content-Type-Options": {"nosniff"},
},
}, },
}, },
{ {
@ -806,6 +812,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
// respHeaders is a test function to read respHeaders // respHeaders is a test function to read respHeaders
func (w *wsRespWriter) headers() http.Header { func (w *wsRespWriter) headers() http.Header {
// Removing indeterminstic header because it cannot be asserted.
w.responseHeaders.Del("Date")
return w.responseHeaders return w.responseHeaders
} }