TUN-3868: Refactor singleTCPService and bridgeService to tcpOverWSService and rawTCPService

This commit is contained in:
cthuang 2021-02-05 13:01:53 +00:00 committed by Nuno Diegues
parent 5943808746
commit ab4dda5427
10 changed files with 563 additions and 212 deletions

View File

@ -87,6 +87,7 @@ func (t Type) String() string {
}
type OriginProxy interface {
// If Proxy returns an error, the caller is responsible for writing the error status to ResponseWriter
Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error
}

View File

@ -25,7 +25,6 @@ var (
)
const (
ServiceBridge = "bridge service"
ServiceBastion = "bastion"
ServiceWarpRouting = "warp-routing"
)
@ -98,8 +97,7 @@ type WarpRoutingService struct {
}
func NewWarpRoutingService() *WarpRoutingService {
warpRoutingService := newBridgeService(DefaultStreamHandler, ServiceWarpRouting)
return &WarpRoutingService{Proxy: warpRoutingService}
return &WarpRoutingService{Proxy: &rawTCPService{name: ServiceWarpRouting}}
}
// Get a single origin service from the CLI/config.
@ -108,7 +106,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
return new(helloWorld), nil
}
if c.IsSet(config.BastionFlag) {
return newBridgeService(nil, ServiceBastion), nil
return newBastionService(), nil
}
if c.IsSet("url") {
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
@ -120,7 +118,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
url: originURL,
}, nil
}
return newSingleTCPService(originURL), nil
return newTCPOverWSService(originURL), nil
}
if c.IsSet("unix-socket") {
path, err := config.ValidateUnixSocket(c)
@ -182,7 +180,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
// overwrite the localService.URL field when `start` is called. So,
// leave the URL field empty for now.
cfg.BastionMode = true
service = newBridgeService(nil, ServiceBastion)
service = newBastionService()
} else {
// Validate URL services
u, err := url.Parse(r.Service)
@ -200,7 +198,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
if isHTTPService(u) {
service = &httpService{url: u}
} else {
service = newSingleTCPService(u)
service = newTCPOverWSService(u)
}
}

View File

@ -238,12 +238,12 @@ ingress:
want: []Rule{
{
Hostname: "tcp.foo.com",
Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")),
Service: newTCPOverWSService(MustParseURL(t, "tcp://127.0.0.1:7864")),
Config: defaultConfig,
},
{
Hostname: "tcp2.foo.com",
Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")),
Service: newTCPOverWSService(MustParseURL(t, "tcp://localhost:8000")),
Config: defaultConfig,
},
{
@ -260,7 +260,7 @@ ingress:
`},
want: []Rule{
{
Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")),
Service: newTCPOverWSService(MustParseURL(t, "ssh://127.0.0.1:22")),
Config: defaultConfig,
},
},
@ -273,7 +273,7 @@ ingress:
`},
want: []Rule{
{
Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")),
Service: newTCPOverWSService(MustParseURL(t, "rdp://127.0.0.1:3389")),
Config: defaultConfig,
},
},
@ -286,7 +286,7 @@ ingress:
`},
want: []Rule{
{
Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")),
Service: newTCPOverWSService(MustParseURL(t, "smb://127.0.0.1:445")),
Config: defaultConfig,
},
},
@ -299,7 +299,7 @@ ingress:
`},
want: []Rule{
{
Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")),
Service: newTCPOverWSService(MustParseURL(t, "ftp://127.0.0.1")),
Config: defaultConfig,
},
},
@ -316,7 +316,7 @@ ingress:
want: []Rule{
{
Hostname: "bastion.foo.com",
Service: newBridgeService(nil, ServiceBastion),
Service: newBastionService(),
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
},
{
@ -336,7 +336,7 @@ ingress:
want: []Rule{
{
Hostname: "bastion.foo.com",
Service: newBridgeService(nil, ServiceBastion),
Service: newBastionService(),
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
},
{

View File

@ -1,11 +1,12 @@
package ingress
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/websocket"
gws "github.com/gorilla/websocket"
"github.com/rs/zerolog"
@ -15,9 +16,8 @@ import (
// Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
type OriginConnection interface {
// Stream should generally be implemented as a bidirectional io.Copy.
Stream(tunnelConn io.ReadWriter, log *zerolog.Logger)
Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger)
Close()
Type() connection.Type
}
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
@ -54,30 +54,38 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
// tcpConnection is an OriginConnection that directly streams to raw TCP.
type tcpConnection struct {
conn net.Conn
streamHandler streamHandlerFunc
conn net.Conn
}
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) {
tc.streamHandler(tunnelConn, tc.conn, log)
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
Stream(tunnelConn, tc.conn, log)
}
func (tc *tcpConnection) Close() {
tc.conn.Close()
}
func (*tcpConnection) Type() connection.Type {
return connection.TypeTCP
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
type tcpOverWSConnection struct {
conn net.Conn
streamHandler streamHandlerFunc
}
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log)
}
func (wc *tcpOverWSConnection) Close() {
wc.conn.Close()
}
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
type wsConnection struct {
wsConn *gws.Conn
resp *http.Response
}
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) {
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
}
@ -86,13 +94,9 @@ func (wsc *wsConnection) Close() {
wsc.wsConn.Close()
}
func (wsc *wsConnection) Type() connection.Type {
return connection.TypeWebsocket
}
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) {
func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
d := &gws.Dialer{
TLSClientConfig: transport.TLSClientConfig,
TLSClientConfig: clientTLSConfig,
}
wsConn, resp, err := websocket.ClientConnect(r, d)
if err != nil {

View File

@ -0,0 +1,177 @@
package ingress
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/cloudflare/cloudflared/logger"
"github.com/gobwas/ws/wsutil"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
const (
testStreamTimeout = time.Second * 3
)
var (
testLogger = logger.Create(nil)
testMessage = []byte("TestStreamOriginConnection")
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
)
func TestStreamTCPConnection(t *testing.T) {
cfdConn, originConn := net.Pipe()
tcpConn := tcpConnection{
conn: cfdConn,
}
eyeballConn, edgeConn := net.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
_, err := eyeballConn.Write(testMessage)
readBuffer := make([]byte, len(testResponse))
_, err = eyeballConn.Read(readBuffer)
require.NoError(t, err)
require.Equal(t, testResponse, readBuffer)
return nil
})
errGroup.Go(func() error {
echoTCPOrigin(t, originConn)
originConn.Close()
return nil
})
tcpConn.Stream(ctx, edgeConn, testLogger)
require.NoError(t, errGroup.Wait())
}
func TestStreamWSOverTCPConnection(t *testing.T) {
cfdConn, originConn := net.Pipe()
tcpOverWSConn := tcpOverWSConnection{
conn: cfdConn,
streamHandler: DefaultStreamHandler,
}
eyeballConn, edgeConn := net.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
echoWSEyeball(t, eyeballConn)
return nil
})
errGroup.Go(func() error {
echoTCPOrigin(t, originConn)
originConn.Close()
return nil
})
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
require.NoError(t, errGroup.Wait())
}
func TestStreamWSConnection(t *testing.T) {
eyeballConn, edgeConn := net.Pipe()
origin := echoWSOrigin(t)
defer origin.Close()
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
clientTLSConfig := &tls.Config{
InsecureSkipVerify: true,
}
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
echoWSEyeball(t, eyeballConn)
return nil
})
wsConn.Stream(ctx, edgeConn, testLogger)
require.NoError(t, errGroup.Wait())
}
func echoWSEyeball(t *testing.T, conn net.Conn) {
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
readMsg, err := wsutil.ReadServerBinary(conn)
require.NoError(t, err)
require.Equal(t, testResponse, readMsg)
require.NoError(t, conn.Close())
}
func echoWSOrigin(t *testing.T) *httptest.Server {
var upgrader = websocket.Upgrader{
ReadBufferSize: 10,
WriteBufferSize: 10,
}
ws := func(w http.ResponseWriter, r *http.Request) {
header := make(http.Header)
for k, vs := range r.Header {
if k == "Test-Cloudflared-Echo" {
header[k] = vs
}
}
conn, err := upgrader.Upgrade(w, r, header)
require.NoError(t, err)
defer conn.Close()
for {
messageType, p, err := conn.ReadMessage()
if err != nil {
return
}
require.Equal(t, testMessage, p)
if err := conn.WriteMessage(messageType, testResponse); err != nil {
return
}
}
}
// NewTLSServer starts the server in another thread
return httptest.NewTLSServer(http.HandlerFunc(ws))
}
func echoTCPOrigin(t *testing.T, conn net.Conn) {
readBuffer := make([]byte, len(testMessage))
_, err := conn.Read(readBuffer)
assert.NoError(t, err)
assert.Equal(t, testMessage, readBuffer)
_, err = conn.Write(testResponse)
assert.NoError(t, err)
}

View File

@ -7,19 +7,22 @@ import (
"net/url"
"strings"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
var (
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
)
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
type HTTPOriginProxy interface {
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
http.RoundTripper
}
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
}
@ -28,11 +31,6 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
// TODO: TUN-3636: establish connection to origins over UDS
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
}
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.url.Host
@ -51,7 +49,7 @@ func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection,
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
}
return newWSConnection(o.transport, req)
return newWSConnection(o.transport.TLSClientConfig, req)
}
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
@ -64,20 +62,32 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "wss"
return newWSConnection(o.transport, req)
return newWSConnection(o.transport.TLSClientConfig, req)
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := o.destination(r)
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := getRequestHost(r)
if err != nil {
return nil, nil, err
}
conn, err := o.client.connect(r, dest)
return conn, nil, err
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
}
originConn := &tcpConnection{
conn: conn,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
return originConn, resp, nil
}
// getRequestHost returns the host of the http.Request.
@ -91,10 +101,35 @@ func getRequestHost(r *http.Request) (string, error) {
return "", errors.New("host not found")
}
func (o *bridgeService) destination(r *http.Request) (string, error) {
if connection.IsTCPStream(r) {
return getRequestHost(r)
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
var err error
dest := o.dest
if o.isBastion {
dest, err = o.bastionDest(r)
if err != nil {
return nil, nil, err
}
}
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
}
originConn := &tcpOverWSConnection{
conn: conn,
streamHandler: o.streamHandler,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
}
func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) {
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
if jumpDestination == "" {
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
@ -110,24 +145,3 @@ func (o *bridgeService) destination(r *http.Request) (string, error) {
func removePath(dest string) string {
return strings.SplitN(dest, "/", 2)[0]
}
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
conn, err := o.client.connect(r, o.dest)
return conn, nil, err
}
type tcpClient struct {
streamHandler streamHandlerFunc
}
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &tcpConnection{
conn: conn,
streamHandler: c.streamHandler,
}, nil
}

View File

@ -2,6 +2,9 @@ package ingress
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
@ -10,12 +13,168 @@ import (
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/websocket"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBridgeServiceDestination(t *testing.T) {
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
// the expected response
func assertEstablishConnectionResponse(t *testing.T,
originProxy StreamBasedOriginProxy,
req *http.Request,
expectHeader http.Header,
) {
_, resp, err := originProxy.EstablishConnection(req)
assert.NoError(t, err)
assert.Equal(t, switchingProtocolText, resp.Status)
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
assert.Equal(t, expectHeader, resp.Header)
}
func TestHTTPServiceEstablishConnection(t *testing.T) {
origin := echoWSOrigin(t)
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
httpService := &httpService{
url: originURL,
hostHeader: origin.URL,
transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Test-Cloudflared-Echo", t.Name())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
"Test-Cloudflared-Echo": {t.Name()},
}
assertEstablishConnectionResponse(t, httpService, req, expectHeader)
}
func TestHelloWorldEstablishConnection(t *testing.T) {
var wg sync.WaitGroup
shutdownC := make(chan struct{})
errC := make(chan error)
helloWorldSerivce := &helloWorld{}
helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{})
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
require.NoError(t, err)
expectHeader := http.Header{
"Connection": {"Upgrade"},
// Accept key when Sec-Websocket-Key is not specified
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
"Upgrade": {"websocket"},
}
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
close(shutdownC)
}
func TestRawTCPServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
listenerClosed := make(chan struct{})
tcpListenRoutine(originListener, listenerClosed)
rawTCPService := &rawTCPService{name: ServiceWarpRouting}
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
require.NoError(t, err)
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
originListener.Close()
<-listenerClosed
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
require.NoError(t, err)
// Origin not listening for new connection, should return an error
_, resp, err := rawTCPService.EstablishConnection(req)
require.Error(t, err)
require.Nil(t, resp)
}
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
listenerClosed := make(chan struct{})
tcpListenRoutine(originListener, listenerClosed)
originURL := &url.URL{
Scheme: "tcp",
Host: originListener.Addr().String(),
}
baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
require.NoError(t, err)
baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
bastionReq := baseReq.Clone(context.Background())
bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
tests := []struct {
service *tcpOverWSService
req *http.Request
expectErr bool
}{
{
service: newTCPOverWSService(originURL),
req: baseReq,
},
{
service: newBastionService(),
req: bastionReq,
},
{
service: newBastionService(),
req: baseReq,
expectErr: true,
},
}
for _, test := range tests {
if test.expectErr {
_, resp, err := test.service.EstablishConnection(test.req)
assert.Error(t, err)
assert.Nil(t, resp)
} else {
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
}
}
originListener.Close()
<-listenerClosed
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error
_, resp, err := service.EstablishConnection(bastionReq)
assert.Error(t, err)
assert.Nil(t, resp)
}
}
func TestBastionDestination(t *testing.T) {
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
tests := []struct {
name string
@ -98,12 +257,12 @@ func TestBridgeServiceDestination(t *testing.T) {
wantErr: true,
},
}
s := newBridgeService(nil, ServiceBastion)
s := newBastionService()
for _, test := range tests {
r := &http.Request{
Header: test.header,
}
dest, err := s.destination(r)
dest, err := s.bastionDest(r)
if test.wantErr {
assert.Error(t, err, "Test %s expects error", test.name)
} else {
@ -139,10 +298,9 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
url: originURL,
}
var wg sync.WaitGroup
log := zerolog.Nop()
shutdownC := make(chan struct{})
errC := make(chan error)
require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg))
require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg))
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err)
@ -156,3 +314,17 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
}
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
go func() {
for {
conn, err := listener.Accept()
if err != nil {
close(closeChan)
return
}
// Close immediately, this test is not about testing read/write on connection
conn.Close()
}
}()
}

View File

@ -78,46 +78,29 @@ func (o *httpService) String() string {
return o.url.String()
}
// bridgeService is like a jump host, the destination is specified by the client
type bridgeService struct {
client *tcpClient
serviceName string
// rawTCPService dials TCP to the destination specified by the client
// It's used by warp routing
type rawTCPService struct {
name string
}
// if streamHandler is nil, a default one is set.
func newBridgeService(streamHandler streamHandlerFunc, serviceName string) *bridgeService {
return &bridgeService{
client: &tcpClient{
streamHandler: streamHandler,
},
serviceName: serviceName,
}
func (o *rawTCPService) String() string {
return o.name
}
func (o *bridgeService) String() string {
return ServiceBridge + ":" + o.serviceName
}
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
// streamHandler is already set by the constructor.
if o.client.streamHandler != nil {
return nil
}
if cfg.ProxyType == socksProxy {
o.client.streamHandler = socks.StreamHandler
} else {
o.client.streamHandler = DefaultStreamHandler
}
func (o *rawTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
return nil
}
type singleTCPService struct {
dest string
client *tcpClient
// tcpOverWSService models TCP origins serving eyeballs connecting over websocket, such as
// cloudflared access commands.
type tcpOverWSService struct {
dest string
isBastion bool
streamHandler streamHandlerFunc
}
func newSingleTCPService(url *url.URL) *singleTCPService {
func newTCPOverWSService(url *url.URL) *tcpOverWSService {
switch url.Scheme {
case "ssh":
addPortIfMissing(url, 22)
@ -128,9 +111,14 @@ func newSingleTCPService(url *url.URL) *singleTCPService {
case "tcp":
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
}
return &singleTCPService{
dest: url.Host,
client: &tcpClient{},
return &tcpOverWSService{
dest: url.Host,
}
}
func newBastionService() *tcpOverWSService {
return &tcpOverWSService{
isBastion: true,
}
}
@ -140,15 +128,18 @@ func addPortIfMissing(uri *url.URL, port int) {
}
}
func (o *singleTCPService) String() string {
func (o *tcpOverWSService) String() string {
if o.isBastion {
return ServiceBastion
}
return o.dest
}
func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
if cfg.ProxyType == socksProxy {
o.client.streamHandler = socks.StreamHandler
o.streamHandler = socks.StreamHandler
} else {
o.client.streamHandler = DefaultStreamHandler
o.streamHandler = DefaultStreamHandler
}
return nil
}

View File

@ -13,7 +13,6 @@ import (
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
@ -45,6 +44,7 @@ func NewOriginProxy(
}
}
// Caller is responsible for writing any error to ResponseWriter
func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error {
incrementRequests()
defer decrementConcurrentRequests()
@ -62,27 +62,31 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
p.log.Error().Msg(err.Error())
return err
}
resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy)
if err != nil {
logFields := logFields{
cfRay: cfRay,
lbProbe: lbProbe,
rule: ingress.ServiceWarpRouting,
}
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil {
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
w.WriteErrorResponse()
return err
}
p.logOriginResponse(resp, cfRay, lbProbe, ingress.ServiceWarpRouting)
return nil
}
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
p.logRequest(req, cfRay, lbProbe, ruleNum)
logFields := logFields{
cfRay: cfRay,
lbProbe: lbProbe,
rule: ruleNum,
}
p.logRequest(req, logFields)
if sourceConnectionType == connection.TypeHTTP {
resp, err := p.proxyHTTP(w, req, rule)
if err != nil {
p.logErrorAndWriteResponse(w, err, cfRay, ruleNum)
if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil {
p.logRequestError(err, cfRay, ruleNum)
return err
}
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
return nil
}
@ -92,22 +96,14 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return fmt.Errorf("Not a connection-oriented service")
}
resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, connectionProxy)
if err != nil {
p.logErrorAndWriteResponse(w, err, cfRay, ruleNum)
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil {
p.logRequestError(err, cfRay, ruleNum)
return err
}
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
return nil
}
func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) {
p.logRequestError(err, cfRay, ruleNum)
w.WriteErrorResponse()
}
func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if rule.Config.DisableChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
@ -123,18 +119,18 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
if !ok {
p.log.Error().Msgf("%s is not a http service", rule.Service)
return nil, fmt.Errorf("Not a http service")
return fmt.Errorf("Not a http service")
}
resp, err := httpService.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to origin")
return errors.Wrap(err, "Error proxying request to origin")
}
defer resp.Body.Close()
err = w.WriteRespHeaders(resp.StatusCode, resp.Header)
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
return errors.Wrap(err, "Error writing response header")
}
if connection.IsServerSentEvent(resp.Header) {
p.log.Debug().Msg("Detected Server-Side Events from Origin")
@ -146,43 +142,30 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *
defer p.bufferPool.Put(buf)
_, _ = io.CopyBuffer(w, resp.Body, buf)
}
return resp, nil
p.logOriginResponse(resp, fields)
return nil
}
func (p *proxy) proxyConnection(
// proxyStreamRequest first establish a connection with origin, then it writes the status code and headers, and finally it streams data between
// eyeball and origin.
func (p *proxy) proxyStreamRequest(
serveCtx context.Context,
w connection.ResponseWriter,
req *http.Request,
sourceConnectionType connection.Type,
connectionProxy ingress.StreamBasedOriginProxy,
) (*http.Response, error) {
originConn, connectionResp, err := connectionProxy.EstablishConnection(req)
fields logFields,
) error {
originConn, resp, err := connectionProxy.EstablishConnection(req)
if err != nil {
return nil, err
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
var eyeballConn io.ReadWriter = w
respHeader := http.Header{}
if connectionResp != nil {
respHeader = connectionResp.Header
}
if sourceConnectionType == connection.TypeWebsocket {
wsReadWriter := websocket.NewConn(serveCtx, w, p.log)
// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
if originConn.Type() != sourceConnectionType {
eyeballConn = wsReadWriter
}
}
status := http.StatusSwitchingProtocols
resp := &http.Response{
Status: http.StatusText(status),
StatusCode: status,
Header: respHeader,
ContentLength: -1,
}
w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader)
if err != nil {
return nil, errors.Wrap(err, "Error writing response header")
if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
return err
}
streamCtx, cancel := context.WithCancel(serveCtx)
@ -194,8 +177,9 @@ func (p *proxy) proxyConnection(
originConn.Close()
}()
originConn.Stream(eyeballConn, p.log)
return resp, nil
originConn.Stream(serveCtx, w, p.log)
p.logOriginResponse(resp, fields)
return nil
}
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
@ -215,39 +199,45 @@ func (p *proxy) appendTagHeaders(r *http.Request) {
}
}
func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, rule interface{}) {
if cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
} else if lbProbe {
p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
type logFields struct {
cfRay string
lbProbe bool
rule interface{}
}
func (p *proxy) logRequest(r *http.Request, fields logFields) {
if fields.cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
} else if fields.lbProbe {
p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
} else {
p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
}
p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", cfRay, rule)
p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", fields.cfRay, r.Header)
p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", fields.cfRay, fields.rule)
if contentLen := r.ContentLength; contentLen == -1 {
p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay)
p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", fields.cfRay)
} else {
p.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen)
p.log.Debug().Msgf("CF-RAY: %s Request content length %d", fields.cfRay, contentLen)
}
}
func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, rule interface{}) {
responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc()
if cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, rule)
} else if lbProbe {
p.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status)
func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) {
responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
if fields.cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", fields.cfRay, resp.Status, fields.rule)
} else if fields.lbProbe {
p.log.Debug().Msgf("Response to Load Balancer health check %s", resp.Status)
} else {
p.log.Debug().Msgf("Status: %s served by ingress %v", r.Status, rule)
p.log.Debug().Msgf("Status: %s served by ingress %v", resp.Status, fields.rule)
}
p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", fields.cfRay, resp.Header)
if contentLen := r.ContentLength; contentLen == -1 {
p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay)
if contentLen := resp.ContentLength; contentLen == -1 {
p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", fields.cfRay)
} else {
p.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen)
p.log.Debug().Msgf("CF-RAY: %s Response content length %d", fields.cfRay, contentLen)
}
}

View File

@ -347,10 +347,7 @@ func TestProxyError(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err)
err = proxy.Proxy(respWriter, req, connection.TypeHTTP)
assert.Error(t, err)
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
assert.Equal(t, "http response error", respWriter.Body.String())
assert.Error(t, proxy.Proxy(respWriter, req, connection.TypeHTTP))
}
type replayer struct {
@ -421,15 +418,17 @@ func TestConnections(t *testing.T) {
originService: runEchoWSService,
eyeballService: newWSRespWriter([]byte("test1"), replayer),
connectionType: connection.TypeWebsocket,
requestHeaders: map[string][]string{
"Test-Cloudflared-Echo": []string{"Echo"},
requestHeaders: http.Header{
// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
"Test-Cloudflared-Echo": {"Echo"},
},
wantMessage: []byte("echo-test1"),
wantHeaders: map[string][]string{
"Connection": []string{"Upgrade"},
"Sec-Websocket-Accept": []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
"Upgrade": []string{"websocket"},
"Test-Cloudflared-Echo": []string{"Echo"},
wantHeaders: http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
"Test-Cloudflared-Echo": {"Echo"},
},
},
{
@ -441,25 +440,23 @@ func TestConnections(t *testing.T) {
replayer,
),
connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
requestHeaders: http.Header{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
wantMessage: []byte("echo-test2"),
wantHeaders: http.Header{},
},
{
name: "tcp-ws proxy",
ingressServicePrefix: "ws://",
originService: runEchoWSService,
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
requestHeaders: http.Header{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
connectionType: connection.TypeTCP,
wantMessage: []byte("echo-test3"),
// We expect no headers here because they are sent back via
// the stream.
wantHeaders: http.Header{},
},
{
name: "ws-tcp proxy",
@ -467,8 +464,16 @@ func TestConnections(t *testing.T) {
originService: runEchoTCPService,
eyeballService: newWSRespWriter([]byte("test4"), replayer),
connectionType: connection.TypeWebsocket,
wantMessage: []byte("echo-test4"),
wantHeaders: http.Header{},
requestHeaders: http.Header{
// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
},
wantMessage: []byte("echo-test4"),
wantHeaders: http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
},
},
}
@ -477,19 +482,18 @@ func TestConnections(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
// Starts origin service
test.originService(t, ln)
ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String())
var wg sync.WaitGroup
errC := make(chan error)
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
require.NoError(t, err)
reqHeaders := make(http.Header)
for k, vs := range test.requestHeaders {
reqHeaders[k] = vs
}
req.Header = reqHeaders
req.Header = test.requestHeaders
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
go func() {