TUN-3868: Refactor singleTCPService and bridgeService to tcpOverWSService and rawTCPService

This commit is contained in:
cthuang 2021-02-05 13:01:53 +00:00 committed by Nuno Diegues
parent 5943808746
commit ab4dda5427
10 changed files with 563 additions and 212 deletions

View File

@ -87,6 +87,7 @@ func (t Type) String() string {
} }
type OriginProxy interface { type OriginProxy interface {
// If Proxy returns an error, the caller is responsible for writing the error status to ResponseWriter
Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error
} }

View File

@ -25,7 +25,6 @@ var (
) )
const ( const (
ServiceBridge = "bridge service"
ServiceBastion = "bastion" ServiceBastion = "bastion"
ServiceWarpRouting = "warp-routing" ServiceWarpRouting = "warp-routing"
) )
@ -98,8 +97,7 @@ type WarpRoutingService struct {
} }
func NewWarpRoutingService() *WarpRoutingService { func NewWarpRoutingService() *WarpRoutingService {
warpRoutingService := newBridgeService(DefaultStreamHandler, ServiceWarpRouting) return &WarpRoutingService{Proxy: &rawTCPService{name: ServiceWarpRouting}}
return &WarpRoutingService{Proxy: warpRoutingService}
} }
// Get a single origin service from the CLI/config. // Get a single origin service from the CLI/config.
@ -108,7 +106,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
return new(helloWorld), nil return new(helloWorld), nil
} }
if c.IsSet(config.BastionFlag) { if c.IsSet(config.BastionFlag) {
return newBridgeService(nil, ServiceBastion), nil return newBastionService(), nil
} }
if c.IsSet("url") { if c.IsSet("url") {
originURL, err := config.ValidateUrl(c, allowURLFromArgs) originURL, err := config.ValidateUrl(c, allowURLFromArgs)
@ -120,7 +118,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
url: originURL, url: originURL,
}, nil }, nil
} }
return newSingleTCPService(originURL), nil return newTCPOverWSService(originURL), nil
} }
if c.IsSet("unix-socket") { if c.IsSet("unix-socket") {
path, err := config.ValidateUnixSocket(c) path, err := config.ValidateUnixSocket(c)
@ -182,7 +180,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
// overwrite the localService.URL field when `start` is called. So, // overwrite the localService.URL field when `start` is called. So,
// leave the URL field empty for now. // leave the URL field empty for now.
cfg.BastionMode = true cfg.BastionMode = true
service = newBridgeService(nil, ServiceBastion) service = newBastionService()
} else { } else {
// Validate URL services // Validate URL services
u, err := url.Parse(r.Service) u, err := url.Parse(r.Service)
@ -200,7 +198,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
if isHTTPService(u) { if isHTTPService(u) {
service = &httpService{url: u} service = &httpService{url: u}
} else { } else {
service = newSingleTCPService(u) service = newTCPOverWSService(u)
} }
} }

View File

@ -238,12 +238,12 @@ ingress:
want: []Rule{ want: []Rule{
{ {
Hostname: "tcp.foo.com", Hostname: "tcp.foo.com",
Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")), Service: newTCPOverWSService(MustParseURL(t, "tcp://127.0.0.1:7864")),
Config: defaultConfig, Config: defaultConfig,
}, },
{ {
Hostname: "tcp2.foo.com", Hostname: "tcp2.foo.com",
Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")), Service: newTCPOverWSService(MustParseURL(t, "tcp://localhost:8000")),
Config: defaultConfig, Config: defaultConfig,
}, },
{ {
@ -260,7 +260,7 @@ ingress:
`}, `},
want: []Rule{ want: []Rule{
{ {
Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")), Service: newTCPOverWSService(MustParseURL(t, "ssh://127.0.0.1:22")),
Config: defaultConfig, Config: defaultConfig,
}, },
}, },
@ -273,7 +273,7 @@ ingress:
`}, `},
want: []Rule{ want: []Rule{
{ {
Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")), Service: newTCPOverWSService(MustParseURL(t, "rdp://127.0.0.1:3389")),
Config: defaultConfig, Config: defaultConfig,
}, },
}, },
@ -286,7 +286,7 @@ ingress:
`}, `},
want: []Rule{ want: []Rule{
{ {
Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")), Service: newTCPOverWSService(MustParseURL(t, "smb://127.0.0.1:445")),
Config: defaultConfig, Config: defaultConfig,
}, },
}, },
@ -299,7 +299,7 @@ ingress:
`}, `},
want: []Rule{ want: []Rule{
{ {
Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")), Service: newTCPOverWSService(MustParseURL(t, "ftp://127.0.0.1")),
Config: defaultConfig, Config: defaultConfig,
}, },
}, },
@ -316,7 +316,7 @@ ingress:
want: []Rule{ want: []Rule{
{ {
Hostname: "bastion.foo.com", Hostname: "bastion.foo.com",
Service: newBridgeService(nil, ServiceBastion), Service: newBastionService(),
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
}, },
{ {
@ -336,7 +336,7 @@ ingress:
want: []Rule{ want: []Rule{
{ {
Hostname: "bastion.foo.com", Hostname: "bastion.foo.com",
Service: newBridgeService(nil, ServiceBastion), Service: newBastionService(),
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
}, },
{ {

View File

@ -1,11 +1,12 @@
package ingress package ingress
import ( import (
"context"
"crypto/tls"
"io" "io"
"net" "net"
"net/http" "net/http"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
gws "github.com/gorilla/websocket" gws "github.com/gorilla/websocket"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -15,9 +16,8 @@ import (
// Different concrete implementations will stream different protocols as long as they are io.ReadWriters. // Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
type OriginConnection interface { type OriginConnection interface {
// Stream should generally be implemented as a bidirectional io.Copy. // Stream should generally be implemented as a bidirectional io.Copy.
Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger)
Close() Close()
Type() connection.Type
} }
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
@ -54,30 +54,38 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
// tcpConnection is an OriginConnection that directly streams to raw TCP. // tcpConnection is an OriginConnection that directly streams to raw TCP.
type tcpConnection struct { type tcpConnection struct {
conn net.Conn conn net.Conn
streamHandler streamHandlerFunc
} }
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) { func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
tc.streamHandler(tunnelConn, tc.conn, log) Stream(tunnelConn, tc.conn, log)
} }
func (tc *tcpConnection) Close() { func (tc *tcpConnection) Close() {
tc.conn.Close() tc.conn.Close()
} }
func (*tcpConnection) Type() connection.Type { // tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
return connection.TypeTCP type tcpOverWSConnection struct {
conn net.Conn
streamHandler streamHandlerFunc
} }
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets. func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does. wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log)
}
func (wc *tcpOverWSConnection) Close() {
wc.conn.Close()
}
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
type wsConnection struct { type wsConnection struct {
wsConn *gws.Conn wsConn *gws.Conn
resp *http.Response resp *http.Response
} }
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) { func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log) Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
} }
@ -86,13 +94,9 @@ func (wsc *wsConnection) Close() {
wsc.wsConn.Close() wsc.wsConn.Close()
} }
func (wsc *wsConnection) Type() connection.Type { func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
return connection.TypeWebsocket
}
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) {
d := &gws.Dialer{ d := &gws.Dialer{
TLSClientConfig: transport.TLSClientConfig, TLSClientConfig: clientTLSConfig,
} }
wsConn, resp, err := websocket.ClientConnect(r, d) wsConn, resp, err := websocket.ClientConnect(r, d)
if err != nil { if err != nil {

View File

@ -0,0 +1,177 @@
package ingress
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/cloudflare/cloudflared/logger"
"github.com/gobwas/ws/wsutil"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
const (
testStreamTimeout = time.Second * 3
)
var (
testLogger = logger.Create(nil)
testMessage = []byte("TestStreamOriginConnection")
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
)
func TestStreamTCPConnection(t *testing.T) {
cfdConn, originConn := net.Pipe()
tcpConn := tcpConnection{
conn: cfdConn,
}
eyeballConn, edgeConn := net.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
_, err := eyeballConn.Write(testMessage)
readBuffer := make([]byte, len(testResponse))
_, err = eyeballConn.Read(readBuffer)
require.NoError(t, err)
require.Equal(t, testResponse, readBuffer)
return nil
})
errGroup.Go(func() error {
echoTCPOrigin(t, originConn)
originConn.Close()
return nil
})
tcpConn.Stream(ctx, edgeConn, testLogger)
require.NoError(t, errGroup.Wait())
}
func TestStreamWSOverTCPConnection(t *testing.T) {
cfdConn, originConn := net.Pipe()
tcpOverWSConn := tcpOverWSConnection{
conn: cfdConn,
streamHandler: DefaultStreamHandler,
}
eyeballConn, edgeConn := net.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
echoWSEyeball(t, eyeballConn)
return nil
})
errGroup.Go(func() error {
echoTCPOrigin(t, originConn)
originConn.Close()
return nil
})
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
require.NoError(t, errGroup.Wait())
}
func TestStreamWSConnection(t *testing.T) {
eyeballConn, edgeConn := net.Pipe()
origin := echoWSOrigin(t)
defer origin.Close()
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
clientTLSConfig := &tls.Config{
InsecureSkipVerify: true,
}
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
echoWSEyeball(t, eyeballConn)
return nil
})
wsConn.Stream(ctx, edgeConn, testLogger)
require.NoError(t, errGroup.Wait())
}
func echoWSEyeball(t *testing.T, conn net.Conn) {
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
readMsg, err := wsutil.ReadServerBinary(conn)
require.NoError(t, err)
require.Equal(t, testResponse, readMsg)
require.NoError(t, conn.Close())
}
func echoWSOrigin(t *testing.T) *httptest.Server {
var upgrader = websocket.Upgrader{
ReadBufferSize: 10,
WriteBufferSize: 10,
}
ws := func(w http.ResponseWriter, r *http.Request) {
header := make(http.Header)
for k, vs := range r.Header {
if k == "Test-Cloudflared-Echo" {
header[k] = vs
}
}
conn, err := upgrader.Upgrade(w, r, header)
require.NoError(t, err)
defer conn.Close()
for {
messageType, p, err := conn.ReadMessage()
if err != nil {
return
}
require.Equal(t, testMessage, p)
if err := conn.WriteMessage(messageType, testResponse); err != nil {
return
}
}
}
// NewTLSServer starts the server in another thread
return httptest.NewTLSServer(http.HandlerFunc(ws))
}
func echoTCPOrigin(t *testing.T, conn net.Conn) {
readBuffer := make([]byte, len(testMessage))
_, err := conn.Read(readBuffer)
assert.NoError(t, err)
assert.Equal(t, testMessage, readBuffer)
_, err = conn.Write(testResponse)
assert.NoError(t, err)
}

View File

@ -7,19 +7,22 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
var (
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
)
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests. // HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
type HTTPOriginProxy interface { type HTTPOriginProxy interface {
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services // RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
http.RoundTripper http.RoundTripper
} }
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level. // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface { type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
} }
@ -28,11 +31,6 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req) return o.transport.RoundTrip(req)
} }
// TODO: TUN-3636: establish connection to origins over UDS
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
}
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service. // Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.url.Host req.URL.Host = o.url.Host
@ -51,7 +49,7 @@ func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection,
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader req.Host = o.hostHeader
} }
return newWSConnection(o.transport, req) return newWSConnection(o.transport.TLSClientConfig, req)
} }
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
@ -64,20 +62,32 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.URL.Host = o.server.Addr().String() req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "wss" req.URL.Scheme = "wss"
return newWSConnection(o.transport, req) return newWSConnection(o.transport.TLSClientConfig, req)
} }
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil return o.resp, nil
} }
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := o.destination(r) dest, err := getRequestHost(r)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
conn, err := o.client.connect(r, dest) conn, err := net.Dial("tcp", dest)
return conn, nil, err if err != nil {
return nil, nil, err
}
originConn := &tcpConnection{
conn: conn,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
return originConn, resp, nil
} }
// getRequestHost returns the host of the http.Request. // getRequestHost returns the host of the http.Request.
@ -91,10 +101,35 @@ func getRequestHost(r *http.Request) (string, error) {
return "", errors.New("host not found") return "", errors.New("host not found")
} }
func (o *bridgeService) destination(r *http.Request) (string, error) { func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
if connection.IsTCPStream(r) { var err error
return getRequestHost(r) dest := o.dest
if o.isBastion {
dest, err = o.bastionDest(r)
if err != nil {
return nil, nil, err
}
} }
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
}
originConn := &tcpOverWSConnection{
conn: conn,
streamHandler: o.streamHandler,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
}
func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) {
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader) jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
if jumpDestination == "" { if jumpDestination == "" {
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side") return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
@ -110,24 +145,3 @@ func (o *bridgeService) destination(r *http.Request) (string, error) {
func removePath(dest string) string { func removePath(dest string) string {
return strings.SplitN(dest, "/", 2)[0] return strings.SplitN(dest, "/", 2)[0]
} }
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
conn, err := o.client.connect(r, o.dest)
return conn, nil, err
}
type tcpClient struct {
streamHandler streamHandlerFunc
}
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &tcpConnection{
conn: conn,
streamHandler: c.streamHandler,
}, nil
}

View File

@ -2,6 +2,9 @@ package ingress
import ( import (
"context" "context"
"crypto/tls"
"fmt"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -10,12 +13,168 @@ import (
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestBridgeServiceDestination(t *testing.T) { // TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
// the expected response
func assertEstablishConnectionResponse(t *testing.T,
originProxy StreamBasedOriginProxy,
req *http.Request,
expectHeader http.Header,
) {
_, resp, err := originProxy.EstablishConnection(req)
assert.NoError(t, err)
assert.Equal(t, switchingProtocolText, resp.Status)
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
assert.Equal(t, expectHeader, resp.Header)
}
func TestHTTPServiceEstablishConnection(t *testing.T) {
origin := echoWSOrigin(t)
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
httpService := &httpService{
url: originURL,
hostHeader: origin.URL,
transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Test-Cloudflared-Echo", t.Name())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
"Test-Cloudflared-Echo": {t.Name()},
}
assertEstablishConnectionResponse(t, httpService, req, expectHeader)
}
func TestHelloWorldEstablishConnection(t *testing.T) {
var wg sync.WaitGroup
shutdownC := make(chan struct{})
errC := make(chan error)
helloWorldSerivce := &helloWorld{}
helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{})
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
require.NoError(t, err)
expectHeader := http.Header{
"Connection": {"Upgrade"},
// Accept key when Sec-Websocket-Key is not specified
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
"Upgrade": {"websocket"},
}
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
close(shutdownC)
}
func TestRawTCPServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
listenerClosed := make(chan struct{})
tcpListenRoutine(originListener, listenerClosed)
rawTCPService := &rawTCPService{name: ServiceWarpRouting}
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
require.NoError(t, err)
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
originListener.Close()
<-listenerClosed
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
require.NoError(t, err)
// Origin not listening for new connection, should return an error
_, resp, err := rawTCPService.EstablishConnection(req)
require.Error(t, err)
require.Nil(t, resp)
}
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
listenerClosed := make(chan struct{})
tcpListenRoutine(originListener, listenerClosed)
originURL := &url.URL{
Scheme: "tcp",
Host: originListener.Addr().String(),
}
baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
require.NoError(t, err)
baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
bastionReq := baseReq.Clone(context.Background())
bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
tests := []struct {
service *tcpOverWSService
req *http.Request
expectErr bool
}{
{
service: newTCPOverWSService(originURL),
req: baseReq,
},
{
service: newBastionService(),
req: bastionReq,
},
{
service: newBastionService(),
req: baseReq,
expectErr: true,
},
}
for _, test := range tests {
if test.expectErr {
_, resp, err := test.service.EstablishConnection(test.req)
assert.Error(t, err)
assert.Nil(t, resp)
} else {
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
}
}
originListener.Close()
<-listenerClosed
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error
_, resp, err := service.EstablishConnection(bastionReq)
assert.Error(t, err)
assert.Nil(t, resp)
}
}
func TestBastionDestination(t *testing.T) {
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader) canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
tests := []struct { tests := []struct {
name string name string
@ -98,12 +257,12 @@ func TestBridgeServiceDestination(t *testing.T) {
wantErr: true, wantErr: true,
}, },
} }
s := newBridgeService(nil, ServiceBastion) s := newBastionService()
for _, test := range tests { for _, test := range tests {
r := &http.Request{ r := &http.Request{
Header: test.header, Header: test.header,
} }
dest, err := s.destination(r) dest, err := s.bastionDest(r)
if test.wantErr { if test.wantErr {
assert.Error(t, err, "Test %s expects error", test.name) assert.Error(t, err, "Test %s expects error", test.name)
} else { } else {
@ -139,10 +298,9 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
url: originURL, url: originURL,
} }
var wg sync.WaitGroup var wg sync.WaitGroup
log := zerolog.Nop()
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
errC := make(chan error) errC := make(chan error)
require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg)) require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg))
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err) require.NoError(t, err)
@ -156,3 +314,17 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
} }
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
go func() {
for {
conn, err := listener.Accept()
if err != nil {
close(closeChan)
return
}
// Close immediately, this test is not about testing read/write on connection
conn.Close()
}
}()
}

View File

@ -78,46 +78,29 @@ func (o *httpService) String() string {
return o.url.String() return o.url.String()
} }
// bridgeService is like a jump host, the destination is specified by the client // rawTCPService dials TCP to the destination specified by the client
type bridgeService struct { // It's used by warp routing
client *tcpClient type rawTCPService struct {
serviceName string name string
} }
// if streamHandler is nil, a default one is set. func (o *rawTCPService) String() string {
func newBridgeService(streamHandler streamHandlerFunc, serviceName string) *bridgeService { return o.name
return &bridgeService{
client: &tcpClient{
streamHandler: streamHandler,
},
serviceName: serviceName,
}
} }
func (o *bridgeService) String() string { func (o *rawTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
return ServiceBridge + ":" + o.serviceName
}
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
// streamHandler is already set by the constructor.
if o.client.streamHandler != nil {
return nil
}
if cfg.ProxyType == socksProxy {
o.client.streamHandler = socks.StreamHandler
} else {
o.client.streamHandler = DefaultStreamHandler
}
return nil return nil
} }
type singleTCPService struct { // tcpOverWSService models TCP origins serving eyeballs connecting over websocket, such as
dest string // cloudflared access commands.
client *tcpClient type tcpOverWSService struct {
dest string
isBastion bool
streamHandler streamHandlerFunc
} }
func newSingleTCPService(url *url.URL) *singleTCPService { func newTCPOverWSService(url *url.URL) *tcpOverWSService {
switch url.Scheme { switch url.Scheme {
case "ssh": case "ssh":
addPortIfMissing(url, 22) addPortIfMissing(url, 22)
@ -128,9 +111,14 @@ func newSingleTCPService(url *url.URL) *singleTCPService {
case "tcp": case "tcp":
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
} }
return &singleTCPService{ return &tcpOverWSService{
dest: url.Host, dest: url.Host,
client: &tcpClient{}, }
}
func newBastionService() *tcpOverWSService {
return &tcpOverWSService{
isBastion: true,
} }
} }
@ -140,15 +128,18 @@ func addPortIfMissing(uri *url.URL, port int) {
} }
} }
func (o *singleTCPService) String() string { func (o *tcpOverWSService) String() string {
if o.isBastion {
return ServiceBastion
}
return o.dest return o.dest
} }
func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
if cfg.ProxyType == socksProxy { if cfg.ProxyType == socksProxy {
o.client.streamHandler = socks.StreamHandler o.streamHandler = socks.StreamHandler
} else { } else {
o.client.streamHandler = DefaultStreamHandler o.streamHandler = DefaultStreamHandler
} }
return nil return nil
} }

View File

@ -13,7 +13,6 @@ import (
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
@ -45,6 +44,7 @@ func NewOriginProxy(
} }
} }
// Caller is responsible for writing any error to ResponseWriter
func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error { func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error {
incrementRequests() incrementRequests()
defer decrementConcurrentRequests() defer decrementConcurrentRequests()
@ -62,27 +62,31 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
p.log.Error().Msg(err.Error()) p.log.Error().Msg(err.Error())
return err return err
} }
resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy) logFields := logFields{
if err != nil { cfRay: cfRay,
lbProbe: lbProbe,
rule: ingress.ServiceWarpRouting,
}
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil {
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting) p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
w.WriteErrorResponse()
return err return err
} }
p.logOriginResponse(resp, cfRay, lbProbe, ingress.ServiceWarpRouting)
return nil return nil
} }
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
p.logRequest(req, cfRay, lbProbe, ruleNum) logFields := logFields{
cfRay: cfRay,
lbProbe: lbProbe,
rule: ruleNum,
}
p.logRequest(req, logFields)
if sourceConnectionType == connection.TypeHTTP { if sourceConnectionType == connection.TypeHTTP {
resp, err := p.proxyHTTP(w, req, rule) if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil {
if err != nil { p.logRequestError(err, cfRay, ruleNum)
p.logErrorAndWriteResponse(w, err, cfRay, ruleNum)
return err return err
} }
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
return nil return nil
} }
@ -92,22 +96,14 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return fmt.Errorf("Not a connection-oriented service") return fmt.Errorf("Not a connection-oriented service")
} }
resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, connectionProxy) if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil {
if err != nil { p.logRequestError(err, cfRay, ruleNum)
p.logErrorAndWriteResponse(w, err, cfRay, ruleNum)
return err return err
} }
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
return nil return nil
} }
func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) { func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error {
p.logRequestError(err, cfRay, ruleNum)
w.WriteErrorResponse()
}
func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if rule.Config.DisableChunkedEncoding { if rule.Config.DisableChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"} req.TransferEncoding = []string{"gzip", "deflate"}
@ -123,18 +119,18 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *
httpService, ok := rule.Service.(ingress.HTTPOriginProxy) httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
if !ok { if !ok {
p.log.Error().Msgf("%s is not a http service", rule.Service) p.log.Error().Msgf("%s is not a http service", rule.Service)
return nil, fmt.Errorf("Not a http service") return fmt.Errorf("Not a http service")
} }
resp, err := httpService.RoundTrip(req) resp, err := httpService.RoundTrip(req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin") return errors.Wrap(err, "Error proxying request to origin")
} }
defer resp.Body.Close() defer resp.Body.Close()
err = w.WriteRespHeaders(resp.StatusCode, resp.Header) err = w.WriteRespHeaders(resp.StatusCode, resp.Header)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error writing response header") return errors.Wrap(err, "Error writing response header")
} }
if connection.IsServerSentEvent(resp.Header) { if connection.IsServerSentEvent(resp.Header) {
p.log.Debug().Msg("Detected Server-Side Events from Origin") p.log.Debug().Msg("Detected Server-Side Events from Origin")
@ -146,43 +142,30 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *
defer p.bufferPool.Put(buf) defer p.bufferPool.Put(buf)
_, _ = io.CopyBuffer(w, resp.Body, buf) _, _ = io.CopyBuffer(w, resp.Body, buf)
} }
return resp, nil p.logOriginResponse(resp, fields)
return nil
} }
func (p *proxy) proxyConnection( // proxyStreamRequest first establish a connection with origin, then it writes the status code and headers, and finally it streams data between
// eyeball and origin.
func (p *proxy) proxyStreamRequest(
serveCtx context.Context, serveCtx context.Context,
w connection.ResponseWriter, w connection.ResponseWriter,
req *http.Request, req *http.Request,
sourceConnectionType connection.Type, sourceConnectionType connection.Type,
connectionProxy ingress.StreamBasedOriginProxy, connectionProxy ingress.StreamBasedOriginProxy,
) (*http.Response, error) { fields logFields,
originConn, connectionResp, err := connectionProxy.EstablishConnection(req) ) error {
originConn, resp, err := connectionProxy.EstablishConnection(req)
if err != nil { if err != nil {
return nil, err return err
}
if resp.Body != nil {
defer resp.Body.Close()
} }
var eyeballConn io.ReadWriter = w if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
respHeader := http.Header{} return err
if connectionResp != nil {
respHeader = connectionResp.Header
}
if sourceConnectionType == connection.TypeWebsocket {
wsReadWriter := websocket.NewConn(serveCtx, w, p.log)
// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
if originConn.Type() != sourceConnectionType {
eyeballConn = wsReadWriter
}
}
status := http.StatusSwitchingProtocols
resp := &http.Response{
Status: http.StatusText(status),
StatusCode: status,
Header: respHeader,
ContentLength: -1,
}
w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader)
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
} }
streamCtx, cancel := context.WithCancel(serveCtx) streamCtx, cancel := context.WithCancel(serveCtx)
@ -194,8 +177,9 @@ func (p *proxy) proxyConnection(
originConn.Close() originConn.Close()
}() }()
originConn.Stream(eyeballConn, p.log) originConn.Stream(serveCtx, w, p.log)
return resp, nil p.logOriginResponse(resp, fields)
return nil
} }
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
@ -215,39 +199,45 @@ func (p *proxy) appendTagHeaders(r *http.Request) {
} }
} }
func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, rule interface{}) { type logFields struct {
if cfRay != "" { cfRay string
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) lbProbe bool
} else if lbProbe { rule interface{}
p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) }
func (p *proxy) logRequest(r *http.Request, fields logFields) {
if fields.cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
} else if fields.lbProbe {
p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
} else { } else {
p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto) p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
} }
p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", fields.cfRay, r.Header)
p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", cfRay, rule) p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", fields.cfRay, fields.rule)
if contentLen := r.ContentLength; contentLen == -1 { if contentLen := r.ContentLength; contentLen == -1 {
p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay) p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", fields.cfRay)
} else { } else {
p.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen) p.log.Debug().Msgf("CF-RAY: %s Request content length %d", fields.cfRay, contentLen)
} }
} }
func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, rule interface{}) { func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) {
responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc() responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
if cfRay != "" { if fields.cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, rule) p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", fields.cfRay, resp.Status, fields.rule)
} else if lbProbe { } else if fields.lbProbe {
p.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status) p.log.Debug().Msgf("Response to Load Balancer health check %s", resp.Status)
} else { } else {
p.log.Debug().Msgf("Status: %s served by ingress %v", r.Status, rule) p.log.Debug().Msgf("Status: %s served by ingress %v", resp.Status, fields.rule)
} }
p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", fields.cfRay, resp.Header)
if contentLen := r.ContentLength; contentLen == -1 { if contentLen := resp.ContentLength; contentLen == -1 {
p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay) p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", fields.cfRay)
} else { } else {
p.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen) p.log.Debug().Msgf("CF-RAY: %s Response content length %d", fields.cfRay, contentLen)
} }
} }

View File

@ -347,10 +347,7 @@ func TestProxyError(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err) assert.NoError(t, err)
err = proxy.Proxy(respWriter, req, connection.TypeHTTP) assert.Error(t, proxy.Proxy(respWriter, req, connection.TypeHTTP))
assert.Error(t, err)
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
assert.Equal(t, "http response error", respWriter.Body.String())
} }
type replayer struct { type replayer struct {
@ -421,15 +418,17 @@ func TestConnections(t *testing.T) {
originService: runEchoWSService, originService: runEchoWSService,
eyeballService: newWSRespWriter([]byte("test1"), replayer), eyeballService: newWSRespWriter([]byte("test1"), replayer),
connectionType: connection.TypeWebsocket, connectionType: connection.TypeWebsocket,
requestHeaders: map[string][]string{ requestHeaders: http.Header{
"Test-Cloudflared-Echo": []string{"Echo"}, // Example key from https://tools.ietf.org/html/rfc6455#section-1.2
"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
"Test-Cloudflared-Echo": {"Echo"},
}, },
wantMessage: []byte("echo-test1"), wantMessage: []byte("echo-test1"),
wantHeaders: map[string][]string{ wantHeaders: http.Header{
"Connection": []string{"Upgrade"}, "Connection": {"Upgrade"},
"Sec-Websocket-Accept": []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="}, "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": []string{"websocket"}, "Upgrade": {"websocket"},
"Test-Cloudflared-Echo": []string{"Echo"}, "Test-Cloudflared-Echo": {"Echo"},
}, },
}, },
{ {
@ -441,25 +440,23 @@ func TestConnections(t *testing.T) {
replayer, replayer,
), ),
connectionType: connection.TypeTCP, connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{ requestHeaders: http.Header{
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
}, },
wantMessage: []byte("echo-test2"), wantMessage: []byte("echo-test2"),
wantHeaders: http.Header{},
}, },
{ {
name: "tcp-ws proxy", name: "tcp-ws proxy",
ingressServicePrefix: "ws://", ingressServicePrefix: "ws://",
originService: runEchoWSService, originService: runEchoWSService,
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")), eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
requestHeaders: map[string][]string{ requestHeaders: http.Header{
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
}, },
connectionType: connection.TypeTCP, connectionType: connection.TypeTCP,
wantMessage: []byte("echo-test3"), wantMessage: []byte("echo-test3"),
// We expect no headers here because they are sent back via // We expect no headers here because they are sent back via
// the stream. // the stream.
wantHeaders: http.Header{},
}, },
{ {
name: "ws-tcp proxy", name: "ws-tcp proxy",
@ -467,8 +464,16 @@ func TestConnections(t *testing.T) {
originService: runEchoTCPService, originService: runEchoTCPService,
eyeballService: newWSRespWriter([]byte("test4"), replayer), eyeballService: newWSRespWriter([]byte("test4"), replayer),
connectionType: connection.TypeWebsocket, connectionType: connection.TypeWebsocket,
wantMessage: []byte("echo-test4"), requestHeaders: http.Header{
wantHeaders: http.Header{}, // Example key from https://tools.ietf.org/html/rfc6455#section-1.2
"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
},
wantMessage: []byte("echo-test4"),
wantHeaders: http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
},
}, },
} }
@ -477,19 +482,18 @@ func TestConnections(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
ln, err := net.Listen("tcp", "127.0.0.1:0") ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)
// Starts origin service
test.originService(t, ln) test.originService(t, ln)
ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String()) ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String())
var wg sync.WaitGroup var wg sync.WaitGroup
errC := make(chan error) errC := make(chan error)
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil) req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
require.NoError(t, err) require.NoError(t, err)
reqHeaders := make(http.Header) req.Header = test.requestHeaders
for k, vs := range test.requestHeaders {
reqHeaders[k] = vs
}
req.Header = reqHeaders
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok { if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
go func() { go func() {