TUN-4626: Proxy non-stream based origin websockets with http Roundtrip.

Reuses HTTPProxy's Roundtrip method to directly proxy websockets from
eyeball clients (determined by websocket type and ingress not being
connection oriented , i.e. Not ssh or smb for example) to proxy
websocket traffic.
This commit is contained in:
Sudarsan Reddy 2021-07-01 10:29:53 +01:00
parent 3eb9efd9f0
commit f1b57526b3
5 changed files with 76 additions and 218 deletions

View File

@ -9,7 +9,6 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"time"
@ -190,71 +189,6 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
}
}
func TestStreamWSConnection(t *testing.T) {
eyeballConn, edgeConn := net.Pipe()
origin := echoWSOrigin(t, true)
defer origin.Close()
var svc httpService
err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{
NoTLSVerify: true,
})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
conn, resp, err := svc.newWebsocketProxyConnection(req)
require.NoError(t, err)
defer conn.Close()
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
connClosed := make(chan struct{})
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
select {
case <-connClosed:
case <-ctx.Done():
}
if ctx.Err() == context.DeadlineExceeded {
eyeballConn.Close()
edgeConn.Close()
conn.Close()
}
return ctx.Err()
})
errGroup.Go(func() error {
echoWSEyeball(t, eyeballConn)
fmt.Println("closing pipe")
edgeConn.Close()
return eyeballConn.Close()
})
errGroup.Go(func() error {
defer conn.Close()
conn.Stream(ctx, edgeConn, testLogger)
close(connClosed)
return nil
})
require.NoError(t, errGroup.Wait())
}
type wsEyeball struct {
conn net.Conn
}

View File

@ -2,10 +2,8 @@ package ingress
import (
"fmt"
"io"
"net"
"net/http"
"strings"
"github.com/pkg/errors"
@ -36,7 +34,15 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.url.Host
req.URL.Scheme = o.url.Scheme
switch o.url.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
default:
req.URL.Scheme = o.url.Scheme
}
if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
@ -44,67 +50,6 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req = req.Clone(req.Context())
req.URL.Host = o.url.Host
req.URL.Scheme = o.url.Scheme
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
}
return o.newWebsocketProxyConnection(req)
}
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.ContentLength = 0
req.Body = nil
resp, err := o.transport.RoundTrip(req)
if err != nil {
return nil, nil, err
}
toClose := resp.Body
defer func() {
if toClose != nil {
_ = toClose.Close()
}
}()
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return nil, nil, errUnsupportedConnectionType
}
conn := wsProxyConnection{
rwc: rwc,
}
// clear to prevent defer from closing
toClose = nil
return &conn, resp, nil
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}

View File

@ -2,7 +2,6 @@ package ingress
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
@ -32,57 +31,6 @@ func assertEstablishConnectionResponse(t *testing.T,
assert.Equal(t, expectHeader, resp.Header)
}
func TestHTTPServiceEstablishConnection(t *testing.T) {
origin := echoWSOrigin(t, false)
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
httpService := &httpService{
url: originURL,
hostHeader: origin.URL,
transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Test-Cloudflared-Echo", t.Name())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
"Test-Cloudflared-Echo": {t.Name()},
}
assertEstablishConnectionResponse(t, httpService, req, expectHeader)
}
func TestHelloWorldEstablishConnection(t *testing.T) {
var wg sync.WaitGroup
shutdownC := make(chan struct{})
errC := make(chan error)
helloWorldSerivce := &helloWorld{}
helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{})
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
close(shutdownC)
}
func TestRawTCPServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@ -218,10 +166,6 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
req = req.Clone(context.Background())
_, resp, err = httpService.EstablishConnection(req)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
}
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {

View File

@ -15,6 +15,7 @@ import (
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
)
const (
@ -85,27 +86,27 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
}
p.logRequest(req, logFields)
if sourceConnectionType == connection.TypeHTTP {
if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil {
switch originProxy := rule.Service.(type) {
case ingress.HTTPOriginProxy:
if err := p.proxyHTTPRequest(w, req, originProxy, sourceConnectionType == connection.TypeWebsocket,
rule.Config.DisableChunkedEncoding, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
}
return nil
}
connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy)
if !ok {
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
return fmt.Errorf("Not a connection-oriented service")
}
case ingress.StreamBasedOriginProxy:
if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
}
return nil
default:
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
}
return nil
}
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
@ -116,26 +117,35 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
return fmt.Sprintf("%d", ruleNum), srv
}
func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if rule.Config.DisableChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
req.ContentLength = int64(cLength)
func (p *proxy) proxyHTTPRequest(
w connection.ResponseWriter,
req *http.Request,
httpService ingress.HTTPOriginProxy,
isWebsocket bool,
disableChunkedEncoding bool,
fields logFields) error {
roundTripReq := req
if isWebsocket {
roundTripReq = req.Clone(req.Context())
roundTripReq.Header.Set("Connection", "Upgrade")
roundTripReq.Header.Set("Upgrade", "websocket")
roundTripReq.Header.Set("Sec-Websocket-Version", "13")
roundTripReq.ContentLength = 0
roundTripReq.Body = nil
} else {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if disableChunkedEncoding {
roundTripReq.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
roundTripReq.ContentLength = int64(cLength)
}
}
// Request origin to keep connection alive to improve performance
roundTripReq.Header.Set("Connection", "keep-alive")
}
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
if !ok {
p.log.Error().Msgf("%s is not a http service", rule.Service)
return fmt.Errorf("Not a http service")
}
resp, err := httpService.RoundTrip(req)
resp, err := httpService.RoundTrip(roundTripReq)
if err != nil {
return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared")
}
@ -145,6 +155,23 @@ func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request,
if err != nil {
return errors.Wrap(err, "Error writing response header")
}
if resp.StatusCode == http.StatusSwitchingProtocols {
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return errors.New("internal error: unsupported connection type")
}
defer rwc.Close()
eyeballStream := &bidirectionalStream{
writer: w,
reader: req.Body,
}
websocket.Stream(eyeballStream, rwc, p.log)
return nil
}
if connection.IsServerSentEvent(resp.Header) {
p.log.Debug().Msg("Detected Server-Side Events from Origin")
p.writeEventStream(w, resp.Body)

View File

@ -571,8 +571,14 @@ func TestConnections(t *testing.T) {
},
},
want: want{
message: []byte{},
err: true,
message: []byte("Forbidden\n"),
err: false,
headers: map[string][]string{
"Content-Length": {"10"},
"Content-Type": {"text/plain; charset=utf-8"},
"Sec-Websocket-Version": {"13"},
"X-Content-Type-Options": {"nosniff"},
},
},
},
{
@ -806,6 +812,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
// respHeaders is a test function to read respHeaders
func (w *wsRespWriter) headers() http.Header {
// Removing indeterminstic header because it cannot be asserted.
w.responseHeaders.Del("Date")
return w.responseHeaders
}