diff --git a/connection/connection.go b/connection/connection.go index 145beae2..54ceaa02 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -87,6 +87,7 @@ func (t Type) String() string { } type OriginProxy interface { + // If Proxy returns an error, the caller is responsible for writing the error status to ResponseWriter Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error } diff --git a/ingress/ingress.go b/ingress/ingress.go index 396d36fb..c9463a89 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -25,7 +25,6 @@ var ( ) const ( - ServiceBridge = "bridge service" ServiceBastion = "bastion" ServiceWarpRouting = "warp-routing" ) @@ -98,8 +97,7 @@ type WarpRoutingService struct { } func NewWarpRoutingService() *WarpRoutingService { - warpRoutingService := newBridgeService(DefaultStreamHandler, ServiceWarpRouting) - return &WarpRoutingService{Proxy: warpRoutingService} + return &WarpRoutingService{Proxy: &rawTCPService{name: ServiceWarpRouting}} } // Get a single origin service from the CLI/config. @@ -108,7 +106,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ return new(helloWorld), nil } if c.IsSet(config.BastionFlag) { - return newBridgeService(nil, ServiceBastion), nil + return newBastionService(), nil } if c.IsSet("url") { originURL, err := config.ValidateUrl(c, allowURLFromArgs) @@ -120,7 +118,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ url: originURL, }, nil } - return newSingleTCPService(originURL), nil + return newTCPOverWSService(originURL), nil } if c.IsSet("unix-socket") { path, err := config.ValidateUnixSocket(c) @@ -182,7 +180,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon // overwrite the localService.URL field when `start` is called. So, // leave the URL field empty for now. cfg.BastionMode = true - service = newBridgeService(nil, ServiceBastion) + service = newBastionService() } else { // Validate URL services u, err := url.Parse(r.Service) @@ -200,7 +198,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon if isHTTPService(u) { service = &httpService{url: u} } else { - service = newSingleTCPService(u) + service = newTCPOverWSService(u) } } diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 911e24bf..bd23acb1 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -238,12 +238,12 @@ ingress: want: []Rule{ { Hostname: "tcp.foo.com", - Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")), + Service: newTCPOverWSService(MustParseURL(t, "tcp://127.0.0.1:7864")), Config: defaultConfig, }, { Hostname: "tcp2.foo.com", - Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")), + Service: newTCPOverWSService(MustParseURL(t, "tcp://localhost:8000")), Config: defaultConfig, }, { @@ -260,7 +260,7 @@ ingress: `}, want: []Rule{ { - Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")), + Service: newTCPOverWSService(MustParseURL(t, "ssh://127.0.0.1:22")), Config: defaultConfig, }, }, @@ -273,7 +273,7 @@ ingress: `}, want: []Rule{ { - Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")), + Service: newTCPOverWSService(MustParseURL(t, "rdp://127.0.0.1:3389")), Config: defaultConfig, }, }, @@ -286,7 +286,7 @@ ingress: `}, want: []Rule{ { - Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")), + Service: newTCPOverWSService(MustParseURL(t, "smb://127.0.0.1:445")), Config: defaultConfig, }, }, @@ -299,7 +299,7 @@ ingress: `}, want: []Rule{ { - Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")), + Service: newTCPOverWSService(MustParseURL(t, "ftp://127.0.0.1")), Config: defaultConfig, }, }, @@ -316,7 +316,7 @@ ingress: want: []Rule{ { Hostname: "bastion.foo.com", - Service: newBridgeService(nil, ServiceBastion), + Service: newBastionService(), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { @@ -336,7 +336,7 @@ ingress: want: []Rule{ { Hostname: "bastion.foo.com", - Service: newBridgeService(nil, ServiceBastion), + Service: newBastionService(), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index 4212ae79..44cfdc4d 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -1,11 +1,12 @@ package ingress import ( + "context" + "crypto/tls" "io" "net" "net/http" - "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/websocket" gws "github.com/gorilla/websocket" "github.com/rs/zerolog" @@ -15,9 +16,8 @@ import ( // Different concrete implementations will stream different protocols as long as they are io.ReadWriters. type OriginConnection interface { // Stream should generally be implemented as a bidirectional io.Copy. - Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) + Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) Close() - Type() connection.Type } type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) @@ -54,30 +54,38 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze // tcpConnection is an OriginConnection that directly streams to raw TCP. type tcpConnection struct { - conn net.Conn - streamHandler streamHandlerFunc + conn net.Conn } -func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) { - tc.streamHandler(tunnelConn, tc.conn, log) +func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { + Stream(tunnelConn, tc.conn, log) } func (tc *tcpConnection) Close() { tc.conn.Close() } -func (*tcpConnection) Type() connection.Type { - return connection.TypeTCP +// tcpOverWSConnection is an OriginConnection that streams to TCP over WS. +type tcpOverWSConnection struct { + conn net.Conn + streamHandler streamHandlerFunc } -// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets. -// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does. +func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { + wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log) +} + +func (wc *tcpOverWSConnection) Close() { + wc.conn.Close() +} + +// wsConnection is an OriginConnection that streams WS between eyeball and origin. type wsConnection struct { wsConn *gws.Conn resp *http.Response } -func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) { +func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log) } @@ -86,13 +94,9 @@ func (wsc *wsConnection) Close() { wsc.wsConn.Close() } -func (wsc *wsConnection) Type() connection.Type { - return connection.TypeWebsocket -} - -func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) { +func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) { d := &gws.Dialer{ - TLSClientConfig: transport.TLSClientConfig, + TLSClientConfig: clientTLSConfig, } wsConn, resp, err := websocket.ClientConnect(r, d) if err != nil { diff --git a/ingress/origin_connection_test.go b/ingress/origin_connection_test.go new file mode 100644 index 00000000..80bcf562 --- /dev/null +++ b/ingress/origin_connection_test.go @@ -0,0 +1,177 @@ +package ingress + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/cloudflare/cloudflared/logger" + "github.com/gobwas/ws/wsutil" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +const ( + testStreamTimeout = time.Second * 3 +) + +var ( + testLogger = logger.Create(nil) + testMessage = []byte("TestStreamOriginConnection") + testResponse = []byte(fmt.Sprintf("echo-%s", testMessage)) +) + +func TestStreamTCPConnection(t *testing.T) { + cfdConn, originConn := net.Pipe() + tcpConn := tcpConnection{ + conn: cfdConn, + } + + eyeballConn, edgeConn := net.Pipe() + + ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) + defer cancel() + + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + _, err := eyeballConn.Write(testMessage) + + readBuffer := make([]byte, len(testResponse)) + _, err = eyeballConn.Read(readBuffer) + require.NoError(t, err) + + require.Equal(t, testResponse, readBuffer) + + return nil + }) + errGroup.Go(func() error { + echoTCPOrigin(t, originConn) + originConn.Close() + return nil + }) + + tcpConn.Stream(ctx, edgeConn, testLogger) + require.NoError(t, errGroup.Wait()) +} + +func TestStreamWSOverTCPConnection(t *testing.T) { + cfdConn, originConn := net.Pipe() + tcpOverWSConn := tcpOverWSConnection{ + conn: cfdConn, + streamHandler: DefaultStreamHandler, + } + + eyeballConn, edgeConn := net.Pipe() + + ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) + defer cancel() + + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + echoWSEyeball(t, eyeballConn) + return nil + }) + errGroup.Go(func() error { + echoTCPOrigin(t, originConn) + originConn.Close() + return nil + }) + + tcpOverWSConn.Stream(ctx, edgeConn, testLogger) + require.NoError(t, errGroup.Wait()) +} + +func TestStreamWSConnection(t *testing.T) { + eyeballConn, edgeConn := net.Pipe() + + origin := echoWSOrigin(t) + defer origin.Close() + + req, err := http.NewRequest(http.MethodGet, origin.URL, nil) + require.NoError(t, err) + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + + clientTLSConfig := &tls.Config{ + InsecureSkipVerify: true, + } + wsConn, resp, err := newWSConnection(clientTLSConfig, req) + require.NoError(t, err) + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.Equal(t, "Upgrade", resp.Header.Get("Connection")) + require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept")) + require.Equal(t, "websocket", resp.Header.Get("Upgrade")) + + ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) + defer cancel() + + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + echoWSEyeball(t, eyeballConn) + return nil + }) + + wsConn.Stream(ctx, edgeConn, testLogger) + require.NoError(t, errGroup.Wait()) +} + +func echoWSEyeball(t *testing.T, conn net.Conn) { + require.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) + + readMsg, err := wsutil.ReadServerBinary(conn) + require.NoError(t, err) + + require.Equal(t, testResponse, readMsg) + + require.NoError(t, conn.Close()) +} + +func echoWSOrigin(t *testing.T) *httptest.Server { + var upgrader = websocket.Upgrader{ + ReadBufferSize: 10, + WriteBufferSize: 10, + } + + ws := func(w http.ResponseWriter, r *http.Request) { + header := make(http.Header) + for k, vs := range r.Header { + if k == "Test-Cloudflared-Echo" { + header[k] = vs + } + } + conn, err := upgrader.Upgrade(w, r, header) + require.NoError(t, err) + defer conn.Close() + + for { + messageType, p, err := conn.ReadMessage() + if err != nil { + return + } + require.Equal(t, testMessage, p) + if err := conn.WriteMessage(messageType, testResponse); err != nil { + return + } + } + } + + // NewTLSServer starts the server in another thread + return httptest.NewTLSServer(http.HandlerFunc(ws)) +} + +func echoTCPOrigin(t *testing.T, conn net.Conn) { + readBuffer := make([]byte, len(testMessage)) + _, err := conn.Read(readBuffer) + assert.NoError(t, err) + + assert.Equal(t, testMessage, readBuffer) + + _, err = conn.Write(testResponse) + assert.NoError(t, err) +} diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 23d89cf5..98f144e4 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -7,19 +7,22 @@ import ( "net/url" "strings" - "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/websocket" "github.com/pkg/errors" ) +var ( + switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) +) + // HTTPOriginProxy can be implemented by origin services that want to proxy http requests. type HTTPOriginProxy interface { // RoundTrip is how cloudflared proxies eyeball requests to the actual origin services http.RoundTripper } -// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level. +// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP. type StreamBasedOriginProxy interface { EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) } @@ -28,11 +31,6 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { return o.transport.RoundTrip(req) } -// TODO: TUN-3636: establish connection to origins over UDS -func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { - return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections") -} - func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { // Rewrite the request URL so that it goes to the origin service. req.URL.Host = o.url.Host @@ -51,7 +49,7 @@ func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. req.Host = o.hostHeader } - return newWSConnection(o.transport, req) + return newWSConnection(o.transport.TLSClientConfig, req) } func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { @@ -64,20 +62,32 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { req.URL.Host = o.server.Addr().String() req.URL.Scheme = "wss" - return newWSConnection(o.transport, req) + return newWSConnection(o.transport.TLSClientConfig, req) } func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { return o.resp, nil } -func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { - dest, err := o.destination(r) +func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { + dest, err := getRequestHost(r) if err != nil { return nil, nil, err } - conn, err := o.client.connect(r, dest) - return conn, nil, err + conn, err := net.Dial("tcp", dest) + if err != nil { + return nil, nil, err + } + + originConn := &tcpConnection{ + conn: conn, + } + resp := &http.Response{ + Status: switchingProtocolText, + StatusCode: http.StatusSwitchingProtocols, + ContentLength: -1, + } + return originConn, resp, nil } // getRequestHost returns the host of the http.Request. @@ -91,10 +101,35 @@ func getRequestHost(r *http.Request) (string, error) { return "", errors.New("host not found") } -func (o *bridgeService) destination(r *http.Request) (string, error) { - if connection.IsTCPStream(r) { - return getRequestHost(r) +func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { + var err error + dest := o.dest + if o.isBastion { + dest, err = o.bastionDest(r) + if err != nil { + return nil, nil, err + } } + + conn, err := net.Dial("tcp", dest) + if err != nil { + return nil, nil, err + } + originConn := &tcpOverWSConnection{ + conn: conn, + streamHandler: o.streamHandler, + } + resp := &http.Response{ + Status: switchingProtocolText, + StatusCode: http.StatusSwitchingProtocols, + Header: websocket.NewResponseHeader(r), + ContentLength: -1, + } + return originConn, resp, nil + +} + +func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) { jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader) if jumpDestination == "" { return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side") @@ -110,24 +145,3 @@ func (o *bridgeService) destination(r *http.Request) (string, error) { func removePath(dest string) string { return strings.SplitN(dest, "/", 2)[0] } - -func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { - conn, err := o.client.connect(r, o.dest) - return conn, nil, err - -} - -type tcpClient struct { - streamHandler streamHandlerFunc -} - -func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) { - conn, err := net.Dial("tcp", addr) - if err != nil { - return nil, err - } - return &tcpConnection{ - conn: conn, - streamHandler: c.streamHandler, - }, nil -} diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index bf664838..b0202b5f 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -2,6 +2,9 @@ package ingress import ( "context" + "crypto/tls" + "fmt" + "net" "net/http" "net/http/httptest" "net/url" @@ -10,12 +13,168 @@ import ( "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/websocket" - "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestBridgeServiceDestination(t *testing.T) { +// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns +// the expected response +func assertEstablishConnectionResponse(t *testing.T, + originProxy StreamBasedOriginProxy, + req *http.Request, + expectHeader http.Header, +) { + _, resp, err := originProxy.EstablishConnection(req) + assert.NoError(t, err) + assert.Equal(t, switchingProtocolText, resp.Status) + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + assert.Equal(t, expectHeader, resp.Header) +} + +func TestHTTPServiceEstablishConnection(t *testing.T) { + origin := echoWSOrigin(t) + defer origin.Close() + originURL, err := url.Parse(origin.URL) + require.NoError(t, err) + + httpService := &httpService{ + url: originURL, + hostHeader: origin.URL, + transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } + req, err := http.NewRequest(http.MethodGet, origin.URL, nil) + require.NoError(t, err) + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Test-Cloudflared-Echo", t.Name()) + + expectHeader := http.Header{ + "Connection": {"Upgrade"}, + "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, + "Upgrade": {"websocket"}, + "Test-Cloudflared-Echo": {t.Name()}, + } + assertEstablishConnectionResponse(t, httpService, req, expectHeader) +} + +func TestHelloWorldEstablishConnection(t *testing.T) { + var wg sync.WaitGroup + shutdownC := make(chan struct{}) + errC := make(chan error) + helloWorldSerivce := &helloWorld{} + helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{}) + + // Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service + req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil) + require.NoError(t, err) + + expectHeader := http.Header{ + "Connection": {"Upgrade"}, + // Accept key when Sec-Websocket-Key is not specified + "Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="}, + "Upgrade": {"websocket"}, + } + assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader) + + close(shutdownC) +} + +func TestRawTCPServiceEstablishConnection(t *testing.T) { + originListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + listenerClosed := make(chan struct{}) + tcpListenRoutine(originListener, listenerClosed) + + rawTCPService := &rawTCPService{name: ServiceWarpRouting} + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) + require.NoError(t, err) + + assertEstablishConnectionResponse(t, rawTCPService, req, nil) + + originListener.Close() + <-listenerClosed + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) + require.NoError(t, err) + + // Origin not listening for new connection, should return an error + _, resp, err := rawTCPService.EstablishConnection(req) + require.Error(t, err) + require.Nil(t, resp) +} + +func TestTCPOverWSServiceEstablishConnection(t *testing.T) { + originListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + listenerClosed := make(chan struct{}) + tcpListenRoutine(originListener, listenerClosed) + + originURL := &url.URL{ + Scheme: "tcp", + Host: originListener.Addr().String(), + } + + baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil) + require.NoError(t, err) + baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + + bastionReq := baseReq.Clone(context.Background()) + bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String()) + + expectHeader := http.Header{ + "Connection": {"Upgrade"}, + "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, + "Upgrade": {"websocket"}, + } + + tests := []struct { + service *tcpOverWSService + req *http.Request + expectErr bool + }{ + { + service: newTCPOverWSService(originURL), + req: baseReq, + }, + { + service: newBastionService(), + req: bastionReq, + }, + { + service: newBastionService(), + req: baseReq, + expectErr: true, + }, + } + + for _, test := range tests { + if test.expectErr { + _, resp, err := test.service.EstablishConnection(test.req) + assert.Error(t, err) + assert.Nil(t, resp) + } else { + assertEstablishConnectionResponse(t, test.service, test.req, expectHeader) + } + } + + originListener.Close() + <-listenerClosed + + for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { + // Origin not listening for new connection, should return an error + _, resp, err := service.EstablishConnection(bastionReq) + assert.Error(t, err) + assert.Nil(t, resp) + } +} + +func TestBastionDestination(t *testing.T) { canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader) tests := []struct { name string @@ -98,12 +257,12 @@ func TestBridgeServiceDestination(t *testing.T) { wantErr: true, }, } - s := newBridgeService(nil, ServiceBastion) + s := newBastionService() for _, test := range tests { r := &http.Request{ Header: test.header, } - dest, err := s.destination(r) + dest, err := s.bastionDest(r) if test.wantErr { assert.Error(t, err, "Test %s expects error", test.name) } else { @@ -139,10 +298,9 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { url: originURL, } var wg sync.WaitGroup - log := zerolog.Nop() shutdownC := make(chan struct{}) errC := make(chan error) - require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg)) + require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg)) req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) require.NoError(t, err) @@ -156,3 +314,17 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) } + +func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) { + go func() { + for { + conn, err := listener.Accept() + if err != nil { + close(closeChan) + return + } + // Close immediately, this test is not about testing read/write on connection + conn.Close() + } + }() +} diff --git a/ingress/origin_service.go b/ingress/origin_service.go index cc89abfd..8d55e7eb 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -78,46 +78,29 @@ func (o *httpService) String() string { return o.url.String() } -// bridgeService is like a jump host, the destination is specified by the client -type bridgeService struct { - client *tcpClient - serviceName string +// rawTCPService dials TCP to the destination specified by the client +// It's used by warp routing +type rawTCPService struct { + name string } -// if streamHandler is nil, a default one is set. -func newBridgeService(streamHandler streamHandlerFunc, serviceName string) *bridgeService { - return &bridgeService{ - client: &tcpClient{ - streamHandler: streamHandler, - }, - serviceName: serviceName, - } +func (o *rawTCPService) String() string { + return o.name } -func (o *bridgeService) String() string { - return ServiceBridge + ":" + o.serviceName -} - -func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { - // streamHandler is already set by the constructor. - if o.client.streamHandler != nil { - return nil - } - - if cfg.ProxyType == socksProxy { - o.client.streamHandler = socks.StreamHandler - } else { - o.client.streamHandler = DefaultStreamHandler - } +func (o *rawTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { return nil } -type singleTCPService struct { - dest string - client *tcpClient +// tcpOverWSService models TCP origins serving eyeballs connecting over websocket, such as +// cloudflared access commands. +type tcpOverWSService struct { + dest string + isBastion bool + streamHandler streamHandlerFunc } -func newSingleTCPService(url *url.URL) *singleTCPService { +func newTCPOverWSService(url *url.URL) *tcpOverWSService { switch url.Scheme { case "ssh": addPortIfMissing(url, 22) @@ -128,9 +111,14 @@ func newSingleTCPService(url *url.URL) *singleTCPService { case "tcp": addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case } - return &singleTCPService{ - dest: url.Host, - client: &tcpClient{}, + return &tcpOverWSService{ + dest: url.Host, + } +} + +func newBastionService() *tcpOverWSService { + return &tcpOverWSService{ + isBastion: true, } } @@ -140,15 +128,18 @@ func addPortIfMissing(uri *url.URL, port int) { } } -func (o *singleTCPService) String() string { +func (o *tcpOverWSService) String() string { + if o.isBastion { + return ServiceBastion + } return o.dest } -func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { if cfg.ProxyType == socksProxy { - o.client.streamHandler = socks.StreamHandler + o.streamHandler = socks.StreamHandler } else { - o.client.streamHandler = DefaultStreamHandler + o.streamHandler = DefaultStreamHandler } return nil } diff --git a/origin/proxy.go b/origin/proxy.go index 1acbff4d..9e0cbf46 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -13,7 +13,6 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/ingress" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/cloudflare/cloudflared/websocket" "github.com/pkg/errors" "github.com/rs/zerolog" ) @@ -45,6 +44,7 @@ func NewOriginProxy( } } +// Caller is responsible for writing any error to ResponseWriter func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error { incrementRequests() defer decrementConcurrentRequests() @@ -62,27 +62,31 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn p.log.Error().Msg(err.Error()) return err } - resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy) - if err != nil { + logFields := logFields{ + cfRay: cfRay, + lbProbe: lbProbe, + rule: ingress.ServiceWarpRouting, + } + if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil { p.logRequestError(err, cfRay, ingress.ServiceWarpRouting) - w.WriteErrorResponse() return err } - p.logOriginResponse(resp, cfRay, lbProbe, ingress.ServiceWarpRouting) return nil } rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) - p.logRequest(req, cfRay, lbProbe, ruleNum) + logFields := logFields{ + cfRay: cfRay, + lbProbe: lbProbe, + rule: ruleNum, + } + p.logRequest(req, logFields) if sourceConnectionType == connection.TypeHTTP { - resp, err := p.proxyHTTP(w, req, rule) - if err != nil { - p.logErrorAndWriteResponse(w, err, cfRay, ruleNum) + if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil { + p.logRequestError(err, cfRay, ruleNum) return err } - - p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) return nil } @@ -92,22 +96,14 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn return fmt.Errorf("Not a connection-oriented service") } - resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, connectionProxy) - if err != nil { - p.logErrorAndWriteResponse(w, err, cfRay, ruleNum) + if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil { + p.logRequestError(err, cfRay, ruleNum) return err } - - p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) return nil } -func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) { - p.logRequestError(err, cfRay, ruleNum) - w.WriteErrorResponse() -} - -func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) { +func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error { // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate if rule.Config.DisableChunkedEncoding { req.TransferEncoding = []string{"gzip", "deflate"} @@ -123,18 +119,18 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule * httpService, ok := rule.Service.(ingress.HTTPOriginProxy) if !ok { p.log.Error().Msgf("%s is not a http service", rule.Service) - return nil, fmt.Errorf("Not a http service") + return fmt.Errorf("Not a http service") } resp, err := httpService.RoundTrip(req) if err != nil { - return nil, errors.Wrap(err, "Error proxying request to origin") + return errors.Wrap(err, "Error proxying request to origin") } defer resp.Body.Close() err = w.WriteRespHeaders(resp.StatusCode, resp.Header) if err != nil { - return nil, errors.Wrap(err, "Error writing response header") + return errors.Wrap(err, "Error writing response header") } if connection.IsServerSentEvent(resp.Header) { p.log.Debug().Msg("Detected Server-Side Events from Origin") @@ -146,43 +142,30 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule * defer p.bufferPool.Put(buf) _, _ = io.CopyBuffer(w, resp.Body, buf) } - return resp, nil + p.logOriginResponse(resp, fields) + return nil } -func (p *proxy) proxyConnection( +// proxyStreamRequest first establish a connection with origin, then it writes the status code and headers, and finally it streams data between +// eyeball and origin. +func (p *proxy) proxyStreamRequest( serveCtx context.Context, w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type, connectionProxy ingress.StreamBasedOriginProxy, -) (*http.Response, error) { - originConn, connectionResp, err := connectionProxy.EstablishConnection(req) + fields logFields, +) error { + originConn, resp, err := connectionProxy.EstablishConnection(req) if err != nil { - return nil, err + return err + } + if resp.Body != nil { + defer resp.Body.Close() } - var eyeballConn io.ReadWriter = w - respHeader := http.Header{} - if connectionResp != nil { - respHeader = connectionResp.Header - } - if sourceConnectionType == connection.TypeWebsocket { - wsReadWriter := websocket.NewConn(serveCtx, w, p.log) - // If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames - if originConn.Type() != sourceConnectionType { - eyeballConn = wsReadWriter - } - } - status := http.StatusSwitchingProtocols - resp := &http.Response{ - Status: http.StatusText(status), - StatusCode: status, - Header: respHeader, - ContentLength: -1, - } - w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader) - if err != nil { - return nil, errors.Wrap(err, "Error writing response header") + if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { + return err } streamCtx, cancel := context.WithCancel(serveCtx) @@ -194,8 +177,9 @@ func (p *proxy) proxyConnection( originConn.Close() }() - originConn.Stream(eyeballConn, p.log) - return resp, nil + originConn.Stream(serveCtx, w, p.log) + p.logOriginResponse(resp, fields) + return nil } func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { @@ -215,39 +199,45 @@ func (p *proxy) appendTagHeaders(r *http.Request) { } } -func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, rule interface{}) { - if cfRay != "" { - p.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) - } else if lbProbe { - p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) +type logFields struct { + cfRay string + lbProbe bool + rule interface{} +} + +func (p *proxy) logRequest(r *http.Request, fields logFields) { + if fields.cfRay != "" { + p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto) + } else if fields.lbProbe { + p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto) } else { p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto) } - p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) - p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", cfRay, rule) + p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", fields.cfRay, r.Header) + p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", fields.cfRay, fields.rule) if contentLen := r.ContentLength; contentLen == -1 { - p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay) + p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", fields.cfRay) } else { - p.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen) + p.log.Debug().Msgf("CF-RAY: %s Request content length %d", fields.cfRay, contentLen) } } -func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, rule interface{}) { - responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc() - if cfRay != "" { - p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, rule) - } else if lbProbe { - p.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status) +func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) { + responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc() + if fields.cfRay != "" { + p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", fields.cfRay, resp.Status, fields.rule) + } else if fields.lbProbe { + p.log.Debug().Msgf("Response to Load Balancer health check %s", resp.Status) } else { - p.log.Debug().Msgf("Status: %s served by ingress %v", r.Status, rule) + p.log.Debug().Msgf("Status: %s served by ingress %v", resp.Status, fields.rule) } - p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) + p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", fields.cfRay, resp.Header) - if contentLen := r.ContentLength; contentLen == -1 { - p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay) + if contentLen := resp.ContentLength; contentLen == -1 { + p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", fields.cfRay) } else { - p.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen) + p.log.Debug().Msgf("CF-RAY: %s Response content length %d", fields.cfRay, contentLen) } } diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 04ca922e..d12744d9 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -347,10 +347,7 @@ func TestProxyError(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) assert.NoError(t, err) - err = proxy.Proxy(respWriter, req, connection.TypeHTTP) - assert.Error(t, err) - assert.Equal(t, http.StatusBadGateway, respWriter.Code) - assert.Equal(t, "http response error", respWriter.Body.String()) + assert.Error(t, proxy.Proxy(respWriter, req, connection.TypeHTTP)) } type replayer struct { @@ -421,15 +418,17 @@ func TestConnections(t *testing.T) { originService: runEchoWSService, eyeballService: newWSRespWriter([]byte("test1"), replayer), connectionType: connection.TypeWebsocket, - requestHeaders: map[string][]string{ - "Test-Cloudflared-Echo": []string{"Echo"}, + requestHeaders: http.Header{ + // Example key from https://tools.ietf.org/html/rfc6455#section-1.2 + "Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="}, + "Test-Cloudflared-Echo": {"Echo"}, }, wantMessage: []byte("echo-test1"), - wantHeaders: map[string][]string{ - "Connection": []string{"Upgrade"}, - "Sec-Websocket-Accept": []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="}, - "Upgrade": []string{"websocket"}, - "Test-Cloudflared-Echo": []string{"Echo"}, + wantHeaders: http.Header{ + "Connection": {"Upgrade"}, + "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, + "Upgrade": {"websocket"}, + "Test-Cloudflared-Echo": {"Echo"}, }, }, { @@ -441,25 +440,23 @@ func TestConnections(t *testing.T) { replayer, ), connectionType: connection.TypeTCP, - requestHeaders: map[string][]string{ - "Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, + requestHeaders: http.Header{ + "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, }, wantMessage: []byte("echo-test2"), - wantHeaders: http.Header{}, }, { name: "tcp-ws proxy", ingressServicePrefix: "ws://", originService: runEchoWSService, eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")), - requestHeaders: map[string][]string{ - "Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, + requestHeaders: http.Header{ + "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, }, connectionType: connection.TypeTCP, wantMessage: []byte("echo-test3"), // We expect no headers here because they are sent back via // the stream. - wantHeaders: http.Header{}, }, { name: "ws-tcp proxy", @@ -467,8 +464,16 @@ func TestConnections(t *testing.T) { originService: runEchoTCPService, eyeballService: newWSRespWriter([]byte("test4"), replayer), connectionType: connection.TypeWebsocket, - wantMessage: []byte("echo-test4"), - wantHeaders: http.Header{}, + requestHeaders: http.Header{ + // Example key from https://tools.ietf.org/html/rfc6455#section-1.2 + "Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="}, + }, + wantMessage: []byte("echo-test4"), + wantHeaders: http.Header{ + "Connection": {"Upgrade"}, + "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, + "Upgrade": {"websocket"}, + }, }, } @@ -477,19 +482,18 @@ func TestConnections(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + // Starts origin service test.originService(t, ln) + ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String()) var wg sync.WaitGroup errC := make(chan error) ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) + req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil) require.NoError(t, err) - reqHeaders := make(http.Header) - for k, vs := range test.requestHeaders { - reqHeaders[k] = vs - } - req.Header = reqHeaders + req.Header = test.requestHeaders if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok { go func() {