TUN-4626: Proxy non-stream based origin websockets with http Roundtrip.
Reuses HTTPProxy's Roundtrip method to directly proxy websockets from eyeball clients (determined by websocket type and ingress not being connection oriented , i.e. Not ssh or smb for example) to proxy websocket traffic.
This commit is contained in:
parent
3eb9efd9f0
commit
f1b57526b3
|
@ -9,7 +9,6 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -190,71 +189,6 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStreamWSConnection(t *testing.T) {
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
|
||||
origin := echoWSOrigin(t, true)
|
||||
defer origin.Close()
|
||||
|
||||
var svc httpService
|
||||
err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{
|
||||
NoTLSVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
|
||||
conn, resp, err := svc.newWebsocketProxyConnection(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
|
||||
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
|
||||
require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
connClosed := make(chan struct{})
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
select {
|
||||
case <-connClosed:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
eyeballConn.Close()
|
||||
edgeConn.Close()
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
return ctx.Err()
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
echoWSEyeball(t, eyeballConn)
|
||||
fmt.Println("closing pipe")
|
||||
edgeConn.Close()
|
||||
return eyeballConn.Close()
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
defer conn.Close()
|
||||
conn.Stream(ctx, edgeConn, testLogger)
|
||||
close(connClosed)
|
||||
return nil
|
||||
})
|
||||
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
type wsEyeball struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
|
|
@ -2,10 +2,8 @@ package ingress
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
|
@ -36,7 +34,15 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = o.url.Scheme
|
||||
switch o.url.Scheme {
|
||||
case "ws":
|
||||
req.URL.Scheme = "http"
|
||||
case "wss":
|
||||
req.URL.Scheme = "https"
|
||||
default:
|
||||
req.URL.Scheme = o.url.Scheme
|
||||
}
|
||||
|
||||
if o.hostHeader != "" {
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
|
@ -44,67 +50,6 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = o.url.Scheme
|
||||
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
|
||||
switch req.URL.Scheme {
|
||||
case "ws":
|
||||
req.URL.Scheme = "http"
|
||||
case "wss":
|
||||
req.URL.Scheme = "https"
|
||||
}
|
||||
|
||||
if o.hostHeader != "" {
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
}
|
||||
|
||||
return o.newWebsocketProxyConnection(req)
|
||||
}
|
||||
|
||||
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||
|
||||
req.ContentLength = 0
|
||||
req.Body = nil
|
||||
|
||||
resp, err := o.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
toClose := resp.Body
|
||||
defer func() {
|
||||
if toClose != nil {
|
||||
_ = toClose.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
|
||||
}
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
|
||||
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
|
||||
}
|
||||
|
||||
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
return nil, nil, errUnsupportedConnectionType
|
||||
}
|
||||
conn := wsProxyConnection{
|
||||
rwc: rwc,
|
||||
}
|
||||
// clear to prevent defer from closing
|
||||
toClose = nil
|
||||
|
||||
return &conn, resp, nil
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
return o.resp, nil
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package ingress
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -32,57 +31,6 @@ func assertEstablishConnectionResponse(t *testing.T,
|
|||
assert.Equal(t, expectHeader, resp.Header)
|
||||
}
|
||||
|
||||
func TestHTTPServiceEstablishConnection(t *testing.T) {
|
||||
origin := echoWSOrigin(t, false)
|
||||
defer origin.Close()
|
||||
originURL, err := url.Parse(origin.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
httpService := &httpService{
|
||||
url: originURL,
|
||||
hostHeader: origin.URL,
|
||||
transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Test-Cloudflared-Echo", t.Name())
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
"Test-Cloudflared-Echo": {t.Name()},
|
||||
}
|
||||
assertEstablishConnectionResponse(t, httpService, req, expectHeader)
|
||||
}
|
||||
|
||||
func TestHelloWorldEstablishConnection(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
helloWorldSerivce := &helloWorld{}
|
||||
helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{})
|
||||
|
||||
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
|
||||
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
}
|
||||
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
|
||||
|
||||
close(shutdownC)
|
||||
}
|
||||
|
||||
func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
@ -218,10 +166,6 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
req = req.Clone(context.Background())
|
||||
_, resp, err = httpService.EstablishConnection(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
}
|
||||
|
||||
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -85,27 +86,27 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
|||
}
|
||||
p.logRequest(req, logFields)
|
||||
|
||||
if sourceConnectionType == connection.TypeHTTP {
|
||||
if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil {
|
||||
switch originProxy := rule.Service.(type) {
|
||||
case ingress.HTTPOriginProxy:
|
||||
if err := p.proxyHTTPRequest(w, req, originProxy, sourceConnectionType == connection.TypeWebsocket,
|
||||
rule.Config.DisableChunkedEncoding, logFields); err != nil {
|
||||
rule, srv := ruleField(p.ingressRules, ruleNum)
|
||||
p.logRequestError(err, cfRay, rule, srv)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy)
|
||||
if !ok {
|
||||
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
|
||||
return fmt.Errorf("Not a connection-oriented service")
|
||||
}
|
||||
case ingress.StreamBasedOriginProxy:
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil {
|
||||
rule, srv := ruleField(p.ingressRules, ruleNum)
|
||||
p.logRequestError(err, cfRay, rule, srv)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
|
||||
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
|
||||
rule, srv := ruleField(p.ingressRules, ruleNum)
|
||||
p.logRequestError(err, cfRay, rule, srv)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
|
||||
|
@ -116,26 +117,35 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
|
|||
return fmt.Sprintf("%d", ruleNum), srv
|
||||
}
|
||||
|
||||
func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error {
|
||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||
if rule.Config.DisableChunkedEncoding {
|
||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||
if err == nil {
|
||||
req.ContentLength = int64(cLength)
|
||||
func (p *proxy) proxyHTTPRequest(
|
||||
w connection.ResponseWriter,
|
||||
req *http.Request,
|
||||
httpService ingress.HTTPOriginProxy,
|
||||
isWebsocket bool,
|
||||
disableChunkedEncoding bool,
|
||||
fields logFields) error {
|
||||
roundTripReq := req
|
||||
if isWebsocket {
|
||||
roundTripReq = req.Clone(req.Context())
|
||||
roundTripReq.Header.Set("Connection", "Upgrade")
|
||||
roundTripReq.Header.Set("Upgrade", "websocket")
|
||||
roundTripReq.Header.Set("Sec-Websocket-Version", "13")
|
||||
roundTripReq.ContentLength = 0
|
||||
roundTripReq.Body = nil
|
||||
} else {
|
||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||
if disableChunkedEncoding {
|
||||
roundTripReq.TransferEncoding = []string{"gzip", "deflate"}
|
||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||
if err == nil {
|
||||
roundTripReq.ContentLength = int64(cLength)
|
||||
}
|
||||
}
|
||||
// Request origin to keep connection alive to improve performance
|
||||
roundTripReq.Header.Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
|
||||
if !ok {
|
||||
p.log.Error().Msgf("%s is not a http service", rule.Service)
|
||||
return fmt.Errorf("Not a http service")
|
||||
}
|
||||
|
||||
resp, err := httpService.RoundTrip(req)
|
||||
resp, err := httpService.RoundTrip(roundTripReq)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared")
|
||||
}
|
||||
|
@ -145,6 +155,23 @@ func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request,
|
|||
if err != nil {
|
||||
return errors.Wrap(err, "Error writing response header")
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusSwitchingProtocols {
|
||||
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
return errors.New("internal error: unsupported connection type")
|
||||
}
|
||||
defer rwc.Close()
|
||||
|
||||
eyeballStream := &bidirectionalStream{
|
||||
writer: w,
|
||||
reader: req.Body,
|
||||
}
|
||||
|
||||
websocket.Stream(eyeballStream, rwc, p.log)
|
||||
return nil
|
||||
}
|
||||
|
||||
if connection.IsServerSentEvent(resp.Header) {
|
||||
p.log.Debug().Msg("Detected Server-Side Events from Origin")
|
||||
p.writeEventStream(w, resp.Body)
|
||||
|
|
|
@ -571,8 +571,14 @@ func TestConnections(t *testing.T) {
|
|||
},
|
||||
},
|
||||
want: want{
|
||||
message: []byte{},
|
||||
err: true,
|
||||
message: []byte("Forbidden\n"),
|
||||
err: false,
|
||||
headers: map[string][]string{
|
||||
"Content-Length": {"10"},
|
||||
"Content-Type": {"text/plain; charset=utf-8"},
|
||||
"Sec-Websocket-Version": {"13"},
|
||||
"X-Content-Type-Options": {"nosniff"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -806,6 +812,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
|||
|
||||
// respHeaders is a test function to read respHeaders
|
||||
func (w *wsRespWriter) headers() http.Header {
|
||||
// Removing indeterminstic header because it cannot be asserted.
|
||||
w.responseHeaders.Del("Date")
|
||||
return w.responseHeaders
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue