diff --git a/carrier/websocket.go b/carrier/websocket.go index e05dcb8c..7011ea50 100644 --- a/carrier/websocket.go +++ b/carrier/websocket.go @@ -8,6 +8,7 @@ import ( "net/http/httputil" "github.com/cloudflare/cloudflared/cmd/cloudflared/token" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/socks" cfwebsocket "github.com/cloudflare/cloudflared/websocket" @@ -61,7 +62,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro _ = socksServer.Serve(conn) } else { - cfwebsocket.Stream(wsConn, conn) + ingress.Stream(wsConn, conn) } return nil } @@ -69,7 +70,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro // StartServer creates a Websocket server to listen for connections. // This is used on the origin (tunnel) side to take data from the muxer and send it to the origin func (ws *Websocket) StartServer(listener net.Listener, remote string, shutdownC <-chan struct{}) error { - return cfwebsocket.StartProxyServer(ws.log, listener, remote, shutdownC, cfwebsocket.DefaultStreamHandler) + return cfwebsocket.StartProxyServer(ws.log, listener, remote, shutdownC, ingress.DefaultStreamHandler) } // createWebsocketStream will create a WebSocket connection to stream data over diff --git a/connection/connection.go b/connection/connection.go index 98e86baa..bc59b1e2 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -50,8 +50,17 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool { return c.Hostname == "" } +// Type indicates the connection type of the connection. +type Type int + +const ( + TypeWebsocket Type = iota + TypeTCP + TypeHTTP +) + type OriginProxy interface { - Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error + Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error } type ResponseWriter interface { diff --git a/connection/connection_test.go b/connection/connection_test.go index 7fe02d17..0fa3b29f 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -41,8 +41,8 @@ type testRequest struct { type mockOriginProxy struct { } -func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error { - if isWebsocket { +func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, sourceConnectionType Type) error { + if sourceConnectionType == TypeWebsocket { return wsEndpoint(w, r) } switch r.URL.Path { diff --git a/connection/h2mux.go b/connection/h2mux.go index 31921847..5fdf59d5 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -216,7 +216,12 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { return reqErr } - err := h.config.OriginProxy.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req)) + var sourceConnectionType = TypeHTTP + if websocket.IsWebSocketUpgrade(req) { + sourceConnectionType = TypeWebsocket + } + + err := h.config.OriginProxy.Proxy(respWriter, req, sourceConnectionType) if err != nil { respWriter.WriteErrorResponse() return err diff --git a/connection/http2.go b/connection/http2.go index d938dd53..c46a8ead 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -19,6 +19,7 @@ import ( const ( internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade" + tcpStreamHeader = "Cf-Cloudflared-Proxy-Src" websocketUpgrade = "websocket" controlStreamUpgrade = "control-stream" ) @@ -107,21 +108,33 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } respWriter.flusher = flusher - var err error - if isControlStreamUpgrade(r) { + + switch { + case isControlStreamUpgrade(r): respWriter.shouldFlush = true - err = c.serveControlStream(r.Context(), respWriter) - c.controlStreamErr = err - } else if isWebsocketUpgrade(r) { + if err := c.serveControlStream(r.Context(), respWriter); err != nil { + respWriter.WriteErrorResponse() + } + return + + case isWebsocketUpgrade(r): respWriter.shouldFlush = true stripWebsocketUpgradeHeader(r) - err = c.config.OriginProxy.Proxy(respWriter, r, true) - } else { - err = c.config.OriginProxy.Proxy(respWriter, r, false) - } + if err := c.config.OriginProxy.Proxy(respWriter, r, TypeWebsocket); err != nil { + respWriter.WriteErrorResponse() + } + return - if err != nil { - respWriter.WriteErrorResponse() + case IsTCPStream(r): + if err := c.config.OriginProxy.Proxy(respWriter, r, TypeTCP); err != nil { + respWriter.WriteErrorResponse() + } + return + + default: + if err := c.config.OriginProxy.Proxy(respWriter, r, TypeHTTP); err != nil { + respWriter.WriteErrorResponse() + } } } @@ -231,11 +244,16 @@ func (rp *http2RespWriter) Close() error { } func isControlStreamUpgrade(r *http.Request) bool { - return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlStreamUpgrade + return r.Header.Get(internalUpgradeHeader) == controlStreamUpgrade } func isWebsocketUpgrade(r *http.Request) bool { - return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade + return r.Header.Get(internalUpgradeHeader) == websocketUpgrade +} + +// IsTCPStream discerns if the connection request needs a tcp stream proxy. +func IsTCPStream(r *http.Request) bool { + return r.Header.Get(tcpStreamHeader) != "" } func stripWebsocketUpgradeHeader(r *http.Request) { diff --git a/ingress/ingress.go b/ingress/ingress.go index e3edd6a2..3e2f8f4c 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -24,6 +24,11 @@ var ( ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules") ) +const ( + ServiceBastion = "bastion" + ServiceTeamnet = "teamnet-proxy" +) + // FindMatchingRule returns the index of the Ingress Rule which matches the given // hostname and path. This function assumes the last rule matches everything, // which is the case if the rules were instantiated via the ingress#Validate method @@ -90,7 +95,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ return new(helloWorld), nil } if c.IsSet(config.BastionFlag) { - return newBridgeService(), nil + return newBridgeService(nil), nil } if c.IsSet("url") { originURL, err := config.ValidateUrl(c, allowURLFromArgs) @@ -159,12 +164,14 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon service = &srv } else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" { service = new(helloWorld) - } else if r.Service == "bastion" || cfg.BastionMode { + } else if r.Service == ServiceBastion || cfg.BastionMode { // Bastion mode will always start a Websocket proxy server, which will // overwrite the localService.URL field when `start` is called. So, // leave the URL field empty for now. cfg.BastionMode = true - service = newBridgeService() + service = newBridgeService(nil) + } else if r.Service == ServiceTeamnet { + service = newBridgeService(DefaultStreamHandler) } else { // Validate URL services u, err := url.Parse(r.Service) diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 64ad3315..6a143b1b 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -315,7 +315,7 @@ ingress: want: []Rule{ { Hostname: "bastion.foo.com", - Service: newBridgeService(), + Service: newBridgeService(nil), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { @@ -335,7 +335,7 @@ ingress: want: []Rule{ { Hostname: "bastion.foo.com", - Service: newBridgeService(), + Service: newBridgeService(nil), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index 9e2b9671..d49ca945 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -17,10 +17,36 @@ type OriginConnection interface { Close() } +type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn) + +// Stream copies copy data to & from provided io.ReadWriters. +func Stream(conn, backendConn io.ReadWriter) { + proxyDone := make(chan struct{}, 2) + + go func() { + io.Copy(conn, backendConn) + proxyDone <- struct{}{} + }() + + go func() { + io.Copy(backendConn, conn) + proxyDone <- struct{}{} + }() + + // If one side is done, we are done. + <-proxyDone +} + +// DefaultStreamHandler is an implementation of streamHandlerFunc that +// performs a two way io.Copy between originConn and remoteConn. +func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) { + Stream(originConn, remoteConn) +} + // tcpConnection is an OriginConnection that directly streams to raw TCP. type tcpConnection struct { conn net.Conn - streamHandler func(tunnelConn io.ReadWriter, originConn net.Conn) + streamHandler streamHandlerFunc } func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter) { @@ -39,7 +65,7 @@ type wsConnection struct { } func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter) { - websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn()) + Stream(tunnelConn, wsc.wsConn.UnderlyingConn()) } func (wsc *wsConnection) Close() { diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index ed02ce72..0c617d6a 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -2,13 +2,14 @@ package ingress import ( "fmt" - "io" "net" "net/http" "net/url" "strings" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/h2mux" + "github.com/pkg/errors" ) // HTTPOriginProxy can be implemented by origin services that want to proxy http requests. @@ -63,7 +64,21 @@ func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, return o.client.connect(r, dest) } +// getRequestHost returns the host of the http.Request. +func getRequestHost(r *http.Request) (string, error) { + if r.Host != "" { + return r.Host, nil + } + if r.URL != nil { + return r.URL.Host, nil + } + return "", errors.New("host not found") +} + func (o *bridgeService) destination(r *http.Request) (string, error) { + if connection.IsTCPStream(r) { + return getRequestHost(r) + } jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader) if jumpDestination == "" { return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side") @@ -85,7 +100,7 @@ func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnectio } type tcpClient struct { - streamHandler func(originConn io.ReadWriter, remoteConn net.Conn) + streamHandler streamHandlerFunc } func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) { diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 13b8ff63..16534eac 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -91,7 +91,7 @@ func TestBridgeServiceDestination(t *testing.T) { wantErr: true, }, } - s := newBridgeService() + s := newBridgeService(nil) for _, test := range tests { r := &http.Request{ Header: test.header, diff --git a/ingress/origin_service.go b/ingress/origin_service.go index d8f0a015..ba8892df 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -81,9 +81,12 @@ type bridgeService struct { client *tcpClient } -func newBridgeService() *bridgeService { +// if streamHandler is nil, a default one is set. +func newBridgeService(streamHandler streamHandlerFunc) *bridgeService { return &bridgeService{ - client: &tcpClient{}, + client: &tcpClient{ + streamHandler: streamHandler, + }, } } @@ -92,10 +95,15 @@ func (o *bridgeService) String() string { } func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + // streamHandler is already set by the constructor. + if o.client.streamHandler != nil { + return nil + } + if cfg.ProxyType == socksProxy { o.client.streamHandler = socks.StreamHandler } else { - o.client.streamHandler = websocket.DefaultStreamHandler + o.client.streamHandler = DefaultStreamHandler } return nil } @@ -136,7 +144,7 @@ func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdo if cfg.ProxyType == socksProxy { o.client.streamHandler = socks.StreamHandler } else { - o.client.streamHandler = websocket.DefaultStreamHandler + o.client.streamHandler = DefaultStreamHandler } return nil } diff --git a/origin/proxy.go b/origin/proxy.go index 852355ee..aaafa9e6 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -38,7 +38,7 @@ func NewOriginProxy(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *ze } } -func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error { +func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error { incrementRequests() defer decrementConcurrentRequests() @@ -49,43 +49,50 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocke rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) p.logRequest(req, cfRay, lbProbe, ruleNum) - var ( - resp *http.Response - err error - ) - - if isWebsocket { - go websocket.NewConn(w, p.log).Pinger(req.Context()) - - connClosedChan := make(chan struct{}) - err = p.proxyConnection(connClosedChan, w, req, rule) - if err == nil { - respHeader := websocket.NewResponseHeader(req) - status := http.StatusSwitchingProtocols - resp = &http.Response{ - Status: http.StatusText(status), - StatusCode: status, - Header: respHeader, - ContentLength: -1, - } - - w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader) - <-connClosedChan + if sourceConnectionType == connection.TypeHTTP { + resp, err := p.proxyHTTP(w, req, rule) + if err != nil { + p.logErrorAndWriteResponse(w, err, cfRay, ruleNum) + return err } - } else { - resp, err = p.proxyHTTP(w, req, rule) + + p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) + return nil } + + respHeader := http.Header{} + if sourceConnectionType == connection.TypeWebsocket { + go websocket.NewConn(w, p.log).Pinger(req.Context()) + respHeader = websocket.NewResponseHeader(req) + } + + connClosedChan := make(chan struct{}) + err := p.proxyConnection(connClosedChan, w, req, rule) if err != nil { - p.logRequestError(err, cfRay, ruleNum) - w.WriteErrorResponse() + p.logErrorAndWriteResponse(w, err, cfRay, ruleNum) return err } - p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) + status := http.StatusSwitchingProtocols + resp := &http.Response{ + Status: http.StatusText(status), + StatusCode: status, + Header: respHeader, + ContentLength: -1, + } + w.WriteRespHeaders(http.StatusSwitchingProtocols, nil) + <-connClosedChan + + p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) return nil } +func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) { + p.logRequestError(err, cfRay, ruleNum) + w.WriteErrorResponse() +} + func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) { // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate if rule.Config.DisableChunkedEncoding { diff --git a/origin/proxy_test.go b/origin/proxy_test.go index f1590255..fa31d15b 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -143,7 +143,7 @@ func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) require.NoError(t, err) - err = proxy.Proxy(respWriter, req, false) + err = proxy.Proxy(respWriter, req, connection.TypeHTTP) require.NoError(t, err) assert.Equal(t, http.StatusOK, respWriter.Code) @@ -163,7 +163,7 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test wg.Add(1) go func() { defer wg.Done() - err = proxy.Proxy(respWriter, req, true) + err = proxy.Proxy(respWriter, req, connection.TypeWebsocket) require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code) @@ -205,7 +205,7 @@ func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) wg.Add(1) go func() { defer wg.Done() - err = proxy.Proxy(respWriter, req, false) + err = proxy.Proxy(respWriter, req, connection.TypeHTTP) require.NoError(t, err) require.Equal(t, http.StatusOK, respWriter.Code) @@ -298,7 +298,7 @@ func TestProxyMultipleOrigins(t *testing.T) { req, err := http.NewRequest(http.MethodGet, test.url, nil) require.NoError(t, err) - err = proxy.Proxy(respWriter, req, false) + err = proxy.Proxy(respWriter, req, connection.TypeHTTP) require.NoError(t, err) assert.Equal(t, test.expectedStatus, respWriter.Code) @@ -346,7 +346,7 @@ func TestProxyError(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) assert.NoError(t, err) - err = proxy.Proxy(respWriter, req, false) + err = proxy.Proxy(respWriter, req, connection.TypeHTTP) assert.Error(t, err) assert.Equal(t, http.StatusBadGateway, respWriter.Code) assert.Equal(t, "http response error", respWriter.Body.String()) @@ -376,12 +376,10 @@ func TestProxyBastionMode(t *testing.T) { t.Run("testBastionWebsocket", testBastionWebsocket(proxy)) cancel() - } func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) { return func(t *testing.T) { - // WSRoute is a websocket echo handler ctx, cancel := context.WithCancel(context.Background()) readPipe, _ := io.Pipe() respWriter := newMockWSRespWriter(readPipe) @@ -389,14 +387,15 @@ func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) { var wg sync.WaitGroup msgFromConn := []byte("data from websocket proxy") ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) wg.Add(1) go func() { defer wg.Done() defer ln.Close() - server, err := ln.Accept() + conn, err := ln.Accept() require.NoError(t, err) - conn := websocket.NewConn(server, nil) - conn.Write(msgFromConn) + wsConn := websocket.NewConn(conn, nil) + wsConn.Write(msgFromConn) }() req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil) @@ -405,7 +404,7 @@ func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err = proxy.Proxy(respWriter, req, true) + err = proxy.Proxy(respWriter, req, connection.TypeWebsocket) require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code) @@ -422,3 +421,92 @@ func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) { wg.Wait() } } + +func TestTCPStream(t *testing.T) { + logger := logger.Create(nil) + + ctx, cancel := context.WithCancel(context.Background()) + + ingressConfig := &config.Configuration{ + Ingress: []config.UnvalidatedIngressRule{ + config.UnvalidatedIngressRule{ + Hostname: "*", + Service: ingress.ServiceTeamnet, + }, + }, + } + ingressRule, err := ingress.ParseIngress(ingressConfig) + require.NoError(t, err) + + var wg sync.WaitGroup + errC := make(chan error) + ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) + + proxy := NewOriginProxy(ingressRule, testTags, logger) + + t.Run("testTCPStream", testTCPStreamProxy(proxy)) + cancel() + wg.Wait() +} + +type mockTCPRespWriter struct { + w io.Writer + code int +} + +func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) { + return len(p), nil +} + +func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) { + return m.w.Write(p) +} + +func (m *mockTCPRespWriter) WriteErrorResponse() { +} + +func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error { + m.code = status + return nil +} + +func testTCPStreamProxy(proxy connection.OriginProxy) func(t *testing.T) { + return func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + readPipe, writePipe := io.Pipe() + respWriter := &mockTCPRespWriter{ + w: writePipe, + } + msgFromConn := []byte("data from tcp proxy") + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { + defer ln.Close() + conn, err := ln.Accept() + require.NoError(t, err) + defer conn.Close() + _, err = conn.Write(msgFromConn) + require.NoError(t, err) + }() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil) + require.NoError(t, err) + + req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value") + req.Host = ln.Addr().String() + err = proxy.Proxy(respWriter, req, connection.TypeTCP) + require.NoError(t, err) + + require.Equal(t, http.StatusSwitchingProtocols, respWriter.code) + + returnedMsg := make([]byte, len(msgFromConn)) + + _, err = readPipe.Read(returnedMsg) + + require.NoError(t, err) + require.Equal(t, msgFromConn, returnedMsg) + + cancel() + } +} diff --git a/websocket/websocket.go b/websocket/websocket.go index 9807b6fe..39657117 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -47,30 +47,6 @@ func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Con return conn, response, nil } -// Stream copies copy data to & from provided io.ReadWriters. -func Stream(conn, backendConn io.ReadWriter) { - proxyDone := make(chan struct{}, 2) - - go func() { - _, _ = io.Copy(conn, backendConn) - proxyDone <- struct{}{} - }() - - go func() { - _, _ = io.Copy(backendConn, conn) - proxyDone <- struct{}{} - }() - - // If one side is done, we are done. - <-proxyDone -} - -// DefaultStreamHandler is provided to the the standard websocket to origin stream -// This exist to allow SOCKS to deframe data before it gets to the origin -func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) { - Stream(originConn, remoteConn) -} - // StartProxyServer will start a websocket server that will decode // the websocket data and write the resulting data to the provided func StartProxyServer(