TUN-3868: Refactor singleTCPService and bridgeService to tcpOverWSService and rawTCPService
This commit is contained in:
		
							parent
							
								
									5943808746
								
							
						
					
					
						commit
						ab4dda5427
					
				|  | @ -87,6 +87,7 @@ func (t Type) String() string { | |||
| } | ||||
| 
 | ||||
| type OriginProxy interface { | ||||
| 	// If Proxy returns an error, the caller is responsible for writing the error status to ResponseWriter
 | ||||
| 	Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -25,7 +25,6 @@ var ( | |||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	ServiceBridge      = "bridge service" | ||||
| 	ServiceBastion     = "bastion" | ||||
| 	ServiceWarpRouting = "warp-routing" | ||||
| ) | ||||
|  | @ -98,8 +97,7 @@ type WarpRoutingService struct { | |||
| } | ||||
| 
 | ||||
| func NewWarpRoutingService() *WarpRoutingService { | ||||
| 	warpRoutingService := newBridgeService(DefaultStreamHandler, ServiceWarpRouting) | ||||
| 	return &WarpRoutingService{Proxy: warpRoutingService} | ||||
| 	return &WarpRoutingService{Proxy: &rawTCPService{name: ServiceWarpRouting}} | ||||
| } | ||||
| 
 | ||||
| // Get a single origin service from the CLI/config.
 | ||||
|  | @ -108,7 +106,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ | |||
| 		return new(helloWorld), nil | ||||
| 	} | ||||
| 	if c.IsSet(config.BastionFlag) { | ||||
| 		return newBridgeService(nil, ServiceBastion), nil | ||||
| 		return newBastionService(), nil | ||||
| 	} | ||||
| 	if c.IsSet("url") { | ||||
| 		originURL, err := config.ValidateUrl(c, allowURLFromArgs) | ||||
|  | @ -120,7 +118,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ | |||
| 				url: originURL, | ||||
| 			}, nil | ||||
| 		} | ||||
| 		return newSingleTCPService(originURL), nil | ||||
| 		return newTCPOverWSService(originURL), nil | ||||
| 	} | ||||
| 	if c.IsSet("unix-socket") { | ||||
| 		path, err := config.ValidateUnixSocket(c) | ||||
|  | @ -182,7 +180,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon | |||
| 			// overwrite the localService.URL field when `start` is called. So,
 | ||||
| 			// leave the URL field empty for now.
 | ||||
| 			cfg.BastionMode = true | ||||
| 			service = newBridgeService(nil, ServiceBastion) | ||||
| 			service = newBastionService() | ||||
| 		} else { | ||||
| 			// Validate URL services
 | ||||
| 			u, err := url.Parse(r.Service) | ||||
|  | @ -200,7 +198,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon | |||
| 			if isHTTPService(u) { | ||||
| 				service = &httpService{url: u} | ||||
| 			} else { | ||||
| 				service = newSingleTCPService(u) | ||||
| 				service = newTCPOverWSService(u) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
|  |  | |||
|  | @ -238,12 +238,12 @@ ingress: | |||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Hostname: "tcp.foo.com", | ||||
| 					Service:  newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")), | ||||
| 					Service:  newTCPOverWSService(MustParseURL(t, "tcp://127.0.0.1:7864")), | ||||
| 					Config:   defaultConfig, | ||||
| 				}, | ||||
| 				{ | ||||
| 					Hostname: "tcp2.foo.com", | ||||
| 					Service:  newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")), | ||||
| 					Service:  newTCPOverWSService(MustParseURL(t, "tcp://localhost:8000")), | ||||
| 					Config:   defaultConfig, | ||||
| 				}, | ||||
| 				{ | ||||
|  | @ -260,7 +260,7 @@ ingress: | |||
| `}, | ||||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")), | ||||
| 					Service: newTCPOverWSService(MustParseURL(t, "ssh://127.0.0.1:22")), | ||||
| 					Config:  defaultConfig, | ||||
| 				}, | ||||
| 			}, | ||||
|  | @ -273,7 +273,7 @@ ingress: | |||
| `}, | ||||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")), | ||||
| 					Service: newTCPOverWSService(MustParseURL(t, "rdp://127.0.0.1:3389")), | ||||
| 					Config:  defaultConfig, | ||||
| 				}, | ||||
| 			}, | ||||
|  | @ -286,7 +286,7 @@ ingress: | |||
| `}, | ||||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")), | ||||
| 					Service: newTCPOverWSService(MustParseURL(t, "smb://127.0.0.1:445")), | ||||
| 					Config:  defaultConfig, | ||||
| 				}, | ||||
| 			}, | ||||
|  | @ -299,7 +299,7 @@ ingress: | |||
| `}, | ||||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")), | ||||
| 					Service: newTCPOverWSService(MustParseURL(t, "ftp://127.0.0.1")), | ||||
| 					Config:  defaultConfig, | ||||
| 				}, | ||||
| 			}, | ||||
|  | @ -316,7 +316,7 @@ ingress: | |||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Hostname: "bastion.foo.com", | ||||
| 					Service:  newBridgeService(nil, ServiceBastion), | ||||
| 					Service:  newBastionService(), | ||||
| 					Config:   setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), | ||||
| 				}, | ||||
| 				{ | ||||
|  | @ -336,7 +336,7 @@ ingress: | |||
| 			want: []Rule{ | ||||
| 				{ | ||||
| 					Hostname: "bastion.foo.com", | ||||
| 					Service:  newBridgeService(nil, ServiceBastion), | ||||
| 					Service:  newBastionService(), | ||||
| 					Config:   setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), | ||||
| 				}, | ||||
| 				{ | ||||
|  |  | |||
|  | @ -1,11 +1,12 @@ | |||
| package ingress | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/connection" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| 	gws "github.com/gorilla/websocket" | ||||
| 	"github.com/rs/zerolog" | ||||
|  | @ -15,9 +16,8 @@ import ( | |||
| // Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
 | ||||
| type OriginConnection interface { | ||||
| 	// Stream should generally be implemented as a bidirectional io.Copy.
 | ||||
| 	Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) | ||||
| 	Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) | ||||
| 	Close() | ||||
| 	Type() connection.Type | ||||
| } | ||||
| 
 | ||||
| type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) | ||||
|  | @ -54,30 +54,38 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze | |||
| 
 | ||||
| // tcpConnection is an OriginConnection that directly streams to raw TCP.
 | ||||
| type tcpConnection struct { | ||||
| 	conn          net.Conn | ||||
| 	streamHandler streamHandlerFunc | ||||
| 	conn net.Conn | ||||
| } | ||||
| 
 | ||||
| func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) { | ||||
| 	tc.streamHandler(tunnelConn, tc.conn, log) | ||||
| func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { | ||||
| 	Stream(tunnelConn, tc.conn, log) | ||||
| } | ||||
| 
 | ||||
| func (tc *tcpConnection) Close() { | ||||
| 	tc.conn.Close() | ||||
| } | ||||
| 
 | ||||
| func (*tcpConnection) Type() connection.Type { | ||||
| 	return connection.TypeTCP | ||||
| // tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
 | ||||
| type tcpOverWSConnection struct { | ||||
| 	conn          net.Conn | ||||
| 	streamHandler streamHandlerFunc | ||||
| } | ||||
| 
 | ||||
| // wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
 | ||||
| // TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
 | ||||
| func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { | ||||
| 	wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log) | ||||
| } | ||||
| 
 | ||||
| func (wc *tcpOverWSConnection) Close() { | ||||
| 	wc.conn.Close() | ||||
| } | ||||
| 
 | ||||
| // wsConnection is an OriginConnection that streams WS between eyeball and origin.
 | ||||
| type wsConnection struct { | ||||
| 	wsConn *gws.Conn | ||||
| 	resp   *http.Response | ||||
| } | ||||
| 
 | ||||
| func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) { | ||||
| func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { | ||||
| 	Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log) | ||||
| } | ||||
| 
 | ||||
|  | @ -86,13 +94,9 @@ func (wsc *wsConnection) Close() { | |||
| 	wsc.wsConn.Close() | ||||
| } | ||||
| 
 | ||||
| func (wsc *wsConnection) Type() connection.Type { | ||||
| 	return connection.TypeWebsocket | ||||
| } | ||||
| 
 | ||||
| func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) { | ||||
| func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) { | ||||
| 	d := &gws.Dialer{ | ||||
| 		TLSClientConfig: transport.TLSClientConfig, | ||||
| 		TLSClientConfig: clientTLSConfig, | ||||
| 	} | ||||
| 	wsConn, resp, err := websocket.ClientConnect(r, d) | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -0,0 +1,177 @@ | |||
| package ingress | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/logger" | ||||
| 	"github.com/gobwas/ws/wsutil" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"golang.org/x/sync/errgroup" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	testStreamTimeout = time.Second * 3 | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	testLogger   = logger.Create(nil) | ||||
| 	testMessage  = []byte("TestStreamOriginConnection") | ||||
| 	testResponse = []byte(fmt.Sprintf("echo-%s", testMessage)) | ||||
| ) | ||||
| 
 | ||||
| func TestStreamTCPConnection(t *testing.T) { | ||||
| 	cfdConn, originConn := net.Pipe() | ||||
| 	tcpConn := tcpConnection{ | ||||
| 		conn: cfdConn, | ||||
| 	} | ||||
| 
 | ||||
| 	eyeballConn, edgeConn := net.Pipe() | ||||
| 
 | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) | ||||
| 	defer cancel() | ||||
| 
 | ||||
| 	errGroup, ctx := errgroup.WithContext(ctx) | ||||
| 	errGroup.Go(func() error { | ||||
| 		_, err := eyeballConn.Write(testMessage) | ||||
| 
 | ||||
| 		readBuffer := make([]byte, len(testResponse)) | ||||
| 		_, err = eyeballConn.Read(readBuffer) | ||||
| 		require.NoError(t, err) | ||||
| 
 | ||||
| 		require.Equal(t, testResponse, readBuffer) | ||||
| 
 | ||||
| 		return nil | ||||
| 	}) | ||||
| 	errGroup.Go(func() error { | ||||
| 		echoTCPOrigin(t, originConn) | ||||
| 		originConn.Close() | ||||
| 		return nil | ||||
| 	}) | ||||
| 
 | ||||
| 	tcpConn.Stream(ctx, edgeConn, testLogger) | ||||
| 	require.NoError(t, errGroup.Wait()) | ||||
| } | ||||
| 
 | ||||
| func TestStreamWSOverTCPConnection(t *testing.T) { | ||||
| 	cfdConn, originConn := net.Pipe() | ||||
| 	tcpOverWSConn := tcpOverWSConnection{ | ||||
| 		conn:          cfdConn, | ||||
| 		streamHandler: DefaultStreamHandler, | ||||
| 	} | ||||
| 
 | ||||
| 	eyeballConn, edgeConn := net.Pipe() | ||||
| 
 | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) | ||||
| 	defer cancel() | ||||
| 
 | ||||
| 	errGroup, ctx := errgroup.WithContext(ctx) | ||||
| 	errGroup.Go(func() error { | ||||
| 		echoWSEyeball(t, eyeballConn) | ||||
| 		return nil | ||||
| 	}) | ||||
| 	errGroup.Go(func() error { | ||||
| 		echoTCPOrigin(t, originConn) | ||||
| 		originConn.Close() | ||||
| 		return nil | ||||
| 	}) | ||||
| 
 | ||||
| 	tcpOverWSConn.Stream(ctx, edgeConn, testLogger) | ||||
| 	require.NoError(t, errGroup.Wait()) | ||||
| } | ||||
| 
 | ||||
| func TestStreamWSConnection(t *testing.T) { | ||||
| 	eyeballConn, edgeConn := net.Pipe() | ||||
| 
 | ||||
| 	origin := echoWSOrigin(t) | ||||
| 	defer origin.Close() | ||||
| 
 | ||||
| 	req, err := http.NewRequest(http.MethodGet, origin.URL, nil) | ||||
| 	require.NoError(t, err) | ||||
| 	req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") | ||||
| 
 | ||||
| 	clientTLSConfig := &tls.Config{ | ||||
| 		InsecureSkipVerify: true, | ||||
| 	} | ||||
| 	wsConn, resp, err := newWSConnection(clientTLSConfig, req) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) | ||||
| 	require.Equal(t, "Upgrade", resp.Header.Get("Connection")) | ||||
| 	require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept")) | ||||
| 	require.Equal(t, "websocket", resp.Header.Get("Upgrade")) | ||||
| 
 | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout) | ||||
| 	defer cancel() | ||||
| 
 | ||||
| 	errGroup, ctx := errgroup.WithContext(ctx) | ||||
| 	errGroup.Go(func() error { | ||||
| 		echoWSEyeball(t, eyeballConn) | ||||
| 		return nil | ||||
| 	}) | ||||
| 
 | ||||
| 	wsConn.Stream(ctx, edgeConn, testLogger) | ||||
| 	require.NoError(t, errGroup.Wait()) | ||||
| } | ||||
| 
 | ||||
| func echoWSEyeball(t *testing.T, conn net.Conn) { | ||||
| 	require.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) | ||||
| 
 | ||||
| 	readMsg, err := wsutil.ReadServerBinary(conn) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	require.Equal(t, testResponse, readMsg) | ||||
| 
 | ||||
| 	require.NoError(t, conn.Close()) | ||||
| } | ||||
| 
 | ||||
| func echoWSOrigin(t *testing.T) *httptest.Server { | ||||
| 	var upgrader = websocket.Upgrader{ | ||||
| 		ReadBufferSize:  10, | ||||
| 		WriteBufferSize: 10, | ||||
| 	} | ||||
| 
 | ||||
| 	ws := func(w http.ResponseWriter, r *http.Request) { | ||||
| 		header := make(http.Header) | ||||
| 		for k, vs := range r.Header { | ||||
| 			if k == "Test-Cloudflared-Echo" { | ||||
| 				header[k] = vs | ||||
| 			} | ||||
| 		} | ||||
| 		conn, err := upgrader.Upgrade(w, r, header) | ||||
| 		require.NoError(t, err) | ||||
| 		defer conn.Close() | ||||
| 
 | ||||
| 		for { | ||||
| 			messageType, p, err := conn.ReadMessage() | ||||
| 			if err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 			require.Equal(t, testMessage, p) | ||||
| 			if err := conn.WriteMessage(messageType, testResponse); err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// NewTLSServer starts the server in another thread
 | ||||
| 	return httptest.NewTLSServer(http.HandlerFunc(ws)) | ||||
| } | ||||
| 
 | ||||
| func echoTCPOrigin(t *testing.T, conn net.Conn) { | ||||
| 	readBuffer := make([]byte, len(testMessage)) | ||||
| 	_, err := conn.Read(readBuffer) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	assert.Equal(t, testMessage, readBuffer) | ||||
| 
 | ||||
| 	_, err = conn.Write(testResponse) | ||||
| 	assert.NoError(t, err) | ||||
| } | ||||
|  | @ -7,19 +7,22 @@ import ( | |||
| 	"net/url" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/connection" | ||||
| 	"github.com/cloudflare/cloudflared/h2mux" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
| 
 | ||||
| var ( | ||||
| 	switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) | ||||
| ) | ||||
| 
 | ||||
| // HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
 | ||||
| type HTTPOriginProxy interface { | ||||
| 	// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
 | ||||
| 	http.RoundTripper | ||||
| } | ||||
| 
 | ||||
| // StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
 | ||||
| // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
 | ||||
| type StreamBasedOriginProxy interface { | ||||
| 	EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) | ||||
| } | ||||
|  | @ -28,11 +31,6 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { | |||
| 	return o.transport.RoundTrip(req) | ||||
| } | ||||
| 
 | ||||
| // TODO: TUN-3636: establish connection to origins over UDS
 | ||||
| func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | ||||
| 	return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections") | ||||
| } | ||||
| 
 | ||||
| func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { | ||||
| 	// Rewrite the request URL so that it goes to the origin service.
 | ||||
| 	req.URL.Host = o.url.Host | ||||
|  | @ -51,7 +49,7 @@ func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, | |||
| 		// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
 | ||||
| 		req.Host = o.hostHeader | ||||
| 	} | ||||
| 	return newWSConnection(o.transport, req) | ||||
| 	return newWSConnection(o.transport.TLSClientConfig, req) | ||||
| } | ||||
| 
 | ||||
| func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { | ||||
|  | @ -64,20 +62,32 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { | |||
| func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { | ||||
| 	req.URL.Host = o.server.Addr().String() | ||||
| 	req.URL.Scheme = "wss" | ||||
| 	return newWSConnection(o.transport, req) | ||||
| 	return newWSConnection(o.transport.TLSClientConfig, req) | ||||
| } | ||||
| 
 | ||||
| func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { | ||||
| 	return o.resp, nil | ||||
| } | ||||
| 
 | ||||
| func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | ||||
| 	dest, err := o.destination(r) | ||||
| func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | ||||
| 	dest, err := getRequestHost(r) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	conn, err := o.client.connect(r, dest) | ||||
| 	return conn, nil, err | ||||
| 	conn, err := net.Dial("tcp", dest) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	originConn := &tcpConnection{ | ||||
| 		conn: conn, | ||||
| 	} | ||||
| 	resp := &http.Response{ | ||||
| 		Status:        switchingProtocolText, | ||||
| 		StatusCode:    http.StatusSwitchingProtocols, | ||||
| 		ContentLength: -1, | ||||
| 	} | ||||
| 	return originConn, resp, nil | ||||
| } | ||||
| 
 | ||||
| // getRequestHost returns the host of the http.Request.
 | ||||
|  | @ -91,10 +101,35 @@ func getRequestHost(r *http.Request) (string, error) { | |||
| 	return "", errors.New("host not found") | ||||
| } | ||||
| 
 | ||||
| func (o *bridgeService) destination(r *http.Request) (string, error) { | ||||
| 	if connection.IsTCPStream(r) { | ||||
| 		return getRequestHost(r) | ||||
| func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | ||||
| 	var err error | ||||
| 	dest := o.dest | ||||
| 	if o.isBastion { | ||||
| 		dest, err = o.bastionDest(r) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	conn, err := net.Dial("tcp", dest) | ||||
| 	if err != nil { | ||||
| 		return nil, nil, err | ||||
| 	} | ||||
| 	originConn := &tcpOverWSConnection{ | ||||
| 		conn:          conn, | ||||
| 		streamHandler: o.streamHandler, | ||||
| 	} | ||||
| 	resp := &http.Response{ | ||||
| 		Status:        switchingProtocolText, | ||||
| 		StatusCode:    http.StatusSwitchingProtocols, | ||||
| 		Header:        websocket.NewResponseHeader(r), | ||||
| 		ContentLength: -1, | ||||
| 	} | ||||
| 	return originConn, resp, nil | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) { | ||||
| 	jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader) | ||||
| 	if jumpDestination == "" { | ||||
| 		return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side") | ||||
|  | @ -110,24 +145,3 @@ func (o *bridgeService) destination(r *http.Request) (string, error) { | |||
| func removePath(dest string) string { | ||||
| 	return strings.SplitN(dest, "/", 2)[0] | ||||
| } | ||||
| 
 | ||||
| func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | ||||
| 	conn, err := o.client.connect(r, o.dest) | ||||
| 	return conn, nil, err | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| type tcpClient struct { | ||||
| 	streamHandler streamHandlerFunc | ||||
| } | ||||
| 
 | ||||
| func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) { | ||||
| 	conn, err := net.Dial("tcp", addr) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &tcpConnection{ | ||||
| 		conn:          conn, | ||||
| 		streamHandler: c.streamHandler, | ||||
| 	}, nil | ||||
| } | ||||
|  |  | |||
|  | @ -2,6 +2,9 @@ package ingress | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
|  | @ -10,12 +13,168 @@ import ( | |||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/h2mux" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| 	"github.com/rs/zerolog" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func TestBridgeServiceDestination(t *testing.T) { | ||||
| // TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
 | ||||
| // the expected response
 | ||||
| func assertEstablishConnectionResponse(t *testing.T, | ||||
| 	originProxy StreamBasedOriginProxy, | ||||
| 	req *http.Request, | ||||
| 	expectHeader http.Header, | ||||
| ) { | ||||
| 	_, resp, err := originProxy.EstablishConnection(req) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, switchingProtocolText, resp.Status) | ||||
| 	assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) | ||||
| 	assert.Equal(t, expectHeader, resp.Header) | ||||
| } | ||||
| 
 | ||||
| func TestHTTPServiceEstablishConnection(t *testing.T) { | ||||
| 	origin := echoWSOrigin(t) | ||||
| 	defer origin.Close() | ||||
| 	originURL, err := url.Parse(origin.URL) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	httpService := &httpService{ | ||||
| 		url:        originURL, | ||||
| 		hostHeader: origin.URL, | ||||
| 		transport: &http.Transport{ | ||||
| 			TLSClientConfig: &tls.Config{ | ||||
| 				InsecureSkipVerify: true, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	req, err := http.NewRequest(http.MethodGet, origin.URL, nil) | ||||
| 	require.NoError(t, err) | ||||
| 	req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") | ||||
| 	req.Header.Set("Test-Cloudflared-Echo", t.Name()) | ||||
| 
 | ||||
| 	expectHeader := http.Header{ | ||||
| 		"Connection":            {"Upgrade"}, | ||||
| 		"Sec-Websocket-Accept":  {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, | ||||
| 		"Upgrade":               {"websocket"}, | ||||
| 		"Test-Cloudflared-Echo": {t.Name()}, | ||||
| 	} | ||||
| 	assertEstablishConnectionResponse(t, httpService, req, expectHeader) | ||||
| } | ||||
| 
 | ||||
| func TestHelloWorldEstablishConnection(t *testing.T) { | ||||
| 	var wg sync.WaitGroup | ||||
| 	shutdownC := make(chan struct{}) | ||||
| 	errC := make(chan error) | ||||
| 	helloWorldSerivce := &helloWorld{} | ||||
| 	helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{}) | ||||
| 
 | ||||
| 	// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
 | ||||
| 	req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	expectHeader := http.Header{ | ||||
| 		"Connection": {"Upgrade"}, | ||||
| 		// Accept key when Sec-Websocket-Key is not specified
 | ||||
| 		"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="}, | ||||
| 		"Upgrade":              {"websocket"}, | ||||
| 	} | ||||
| 	assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader) | ||||
| 
 | ||||
| 	close(shutdownC) | ||||
| } | ||||
| 
 | ||||
| func TestRawTCPServiceEstablishConnection(t *testing.T) { | ||||
| 	originListener, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	listenerClosed := make(chan struct{}) | ||||
| 	tcpListenRoutine(originListener, listenerClosed) | ||||
| 
 | ||||
| 	rawTCPService := &rawTCPService{name: ServiceWarpRouting} | ||||
| 
 | ||||
| 	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	assertEstablishConnectionResponse(t, rawTCPService, req, nil) | ||||
| 
 | ||||
| 	originListener.Close() | ||||
| 	<-listenerClosed | ||||
| 
 | ||||
| 	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	// Origin not listening for new connection, should return an error
 | ||||
| 	_, resp, err := rawTCPService.EstablishConnection(req) | ||||
| 	require.Error(t, err) | ||||
| 	require.Nil(t, resp) | ||||
| } | ||||
| 
 | ||||
| func TestTCPOverWSServiceEstablishConnection(t *testing.T) { | ||||
| 	originListener, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 	require.NoError(t, err) | ||||
| 
 | ||||
| 	listenerClosed := make(chan struct{}) | ||||
| 	tcpListenRoutine(originListener, listenerClosed) | ||||
| 
 | ||||
| 	originURL := &url.URL{ | ||||
| 		Scheme: "tcp", | ||||
| 		Host:   originListener.Addr().String(), | ||||
| 	} | ||||
| 
 | ||||
| 	baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil) | ||||
| 	require.NoError(t, err) | ||||
| 	baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") | ||||
| 
 | ||||
| 	bastionReq := baseReq.Clone(context.Background()) | ||||
| 	bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String()) | ||||
| 
 | ||||
| 	expectHeader := http.Header{ | ||||
| 		"Connection":           {"Upgrade"}, | ||||
| 		"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, | ||||
| 		"Upgrade":              {"websocket"}, | ||||
| 	} | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		service   *tcpOverWSService | ||||
| 		req       *http.Request | ||||
| 		expectErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			service: newTCPOverWSService(originURL), | ||||
| 			req:     baseReq, | ||||
| 		}, | ||||
| 		{ | ||||
| 			service: newBastionService(), | ||||
| 			req:     bastionReq, | ||||
| 		}, | ||||
| 		{ | ||||
| 			service:   newBastionService(), | ||||
| 			req:       baseReq, | ||||
| 			expectErr: true, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		if test.expectErr { | ||||
| 			_, resp, err := test.service.EstablishConnection(test.req) | ||||
| 			assert.Error(t, err) | ||||
| 			assert.Nil(t, resp) | ||||
| 		} else { | ||||
| 			assertEstablishConnectionResponse(t, test.service, test.req, expectHeader) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	originListener.Close() | ||||
| 	<-listenerClosed | ||||
| 
 | ||||
| 	for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { | ||||
| 		// Origin not listening for new connection, should return an error
 | ||||
| 		_, resp, err := service.EstablishConnection(bastionReq) | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Nil(t, resp) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestBastionDestination(t *testing.T) { | ||||
| 	canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader) | ||||
| 	tests := []struct { | ||||
| 		name         string | ||||
|  | @ -98,12 +257,12 @@ func TestBridgeServiceDestination(t *testing.T) { | |||
| 			wantErr: true, | ||||
| 		}, | ||||
| 	} | ||||
| 	s := newBridgeService(nil, ServiceBastion) | ||||
| 	s := newBastionService() | ||||
| 	for _, test := range tests { | ||||
| 		r := &http.Request{ | ||||
| 			Header: test.header, | ||||
| 		} | ||||
| 		dest, err := s.destination(r) | ||||
| 		dest, err := s.bastionDest(r) | ||||
| 		if test.wantErr { | ||||
| 			assert.Error(t, err, "Test %s expects error", test.name) | ||||
| 		} else { | ||||
|  | @ -139,10 +298,9 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { | |||
| 		url: originURL, | ||||
| 	} | ||||
| 	var wg sync.WaitGroup | ||||
| 	log := zerolog.Nop() | ||||
| 	shutdownC := make(chan struct{}) | ||||
| 	errC := make(chan error) | ||||
| 	require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg)) | ||||
| 	require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg)) | ||||
| 
 | ||||
| 	req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) | ||||
| 	require.NoError(t, err) | ||||
|  | @ -156,3 +314,17 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { | |||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) | ||||
| } | ||||
| 
 | ||||
| func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) { | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			conn, err := listener.Accept() | ||||
| 			if err != nil { | ||||
| 				close(closeChan) | ||||
| 				return | ||||
| 			} | ||||
| 			// Close immediately, this test is not about testing read/write on connection
 | ||||
| 			conn.Close() | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|  |  | |||
|  | @ -78,46 +78,29 @@ func (o *httpService) String() string { | |||
| 	return o.url.String() | ||||
| } | ||||
| 
 | ||||
| // bridgeService is like a jump host, the destination is specified by the client
 | ||||
| type bridgeService struct { | ||||
| 	client      *tcpClient | ||||
| 	serviceName string | ||||
| // rawTCPService dials TCP to the destination specified by the client
 | ||||
| // It's used by warp routing
 | ||||
| type rawTCPService struct { | ||||
| 	name string | ||||
| } | ||||
| 
 | ||||
| // if streamHandler is nil, a default one is set.
 | ||||
| func newBridgeService(streamHandler streamHandlerFunc, serviceName string) *bridgeService { | ||||
| 	return &bridgeService{ | ||||
| 		client: &tcpClient{ | ||||
| 			streamHandler: streamHandler, | ||||
| 		}, | ||||
| 		serviceName: serviceName, | ||||
| 	} | ||||
| func (o *rawTCPService) String() string { | ||||
| 	return o.name | ||||
| } | ||||
| 
 | ||||
| func (o *bridgeService) String() string { | ||||
| 	return ServiceBridge + ":" + o.serviceName | ||||
| } | ||||
| 
 | ||||
| func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { | ||||
| 	// streamHandler is already set by the constructor.
 | ||||
| 	if o.client.streamHandler != nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	if cfg.ProxyType == socksProxy { | ||||
| 		o.client.streamHandler = socks.StreamHandler | ||||
| 	} else { | ||||
| 		o.client.streamHandler = DefaultStreamHandler | ||||
| 	} | ||||
| func (o *rawTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| type singleTCPService struct { | ||||
| 	dest   string | ||||
| 	client *tcpClient | ||||
| // tcpOverWSService models TCP origins serving eyeballs connecting over websocket, such as
 | ||||
| // cloudflared access commands.
 | ||||
| type tcpOverWSService struct { | ||||
| 	dest          string | ||||
| 	isBastion     bool | ||||
| 	streamHandler streamHandlerFunc | ||||
| } | ||||
| 
 | ||||
| func newSingleTCPService(url *url.URL) *singleTCPService { | ||||
| func newTCPOverWSService(url *url.URL) *tcpOverWSService { | ||||
| 	switch url.Scheme { | ||||
| 	case "ssh": | ||||
| 		addPortIfMissing(url, 22) | ||||
|  | @ -128,9 +111,14 @@ func newSingleTCPService(url *url.URL) *singleTCPService { | |||
| 	case "tcp": | ||||
| 		addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
 | ||||
| 	} | ||||
| 	return &singleTCPService{ | ||||
| 		dest:   url.Host, | ||||
| 		client: &tcpClient{}, | ||||
| 	return &tcpOverWSService{ | ||||
| 		dest: url.Host, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func newBastionService() *tcpOverWSService { | ||||
| 	return &tcpOverWSService{ | ||||
| 		isBastion: true, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -140,15 +128,18 @@ func addPortIfMissing(uri *url.URL, port int) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (o *singleTCPService) String() string { | ||||
| func (o *tcpOverWSService) String() string { | ||||
| 	if o.isBastion { | ||||
| 		return ServiceBastion | ||||
| 	} | ||||
| 	return o.dest | ||||
| } | ||||
| 
 | ||||
| func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { | ||||
| func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { | ||||
| 	if cfg.ProxyType == socksProxy { | ||||
| 		o.client.streamHandler = socks.StreamHandler | ||||
| 		o.streamHandler = socks.StreamHandler | ||||
| 	} else { | ||||
| 		o.client.streamHandler = DefaultStreamHandler | ||||
| 		o.streamHandler = DefaultStreamHandler | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  |  | |||
							
								
								
									
										138
									
								
								origin/proxy.go
								
								
								
								
							
							
						
						
									
										138
									
								
								origin/proxy.go
								
								
								
								
							|  | @ -13,7 +13,6 @@ import ( | |||
| 	"github.com/cloudflare/cloudflared/connection" | ||||
| 	"github.com/cloudflare/cloudflared/ingress" | ||||
| 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/rs/zerolog" | ||||
| ) | ||||
|  | @ -45,6 +44,7 @@ func NewOriginProxy( | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| // Caller is responsible for writing any error to ResponseWriter
 | ||||
| func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error { | ||||
| 	incrementRequests() | ||||
| 	defer decrementConcurrentRequests() | ||||
|  | @ -62,27 +62,31 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn | |||
| 			p.log.Error().Msg(err.Error()) | ||||
| 			return err | ||||
| 		} | ||||
| 		resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy) | ||||
| 		if err != nil { | ||||
| 		logFields := logFields{ | ||||
| 			cfRay:   cfRay, | ||||
| 			lbProbe: lbProbe, | ||||
| 			rule:    ingress.ServiceWarpRouting, | ||||
| 		} | ||||
| 		if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil { | ||||
| 			p.logRequestError(err, cfRay, ingress.ServiceWarpRouting) | ||||
| 			w.WriteErrorResponse() | ||||
| 			return err | ||||
| 		} | ||||
| 		p.logOriginResponse(resp, cfRay, lbProbe, ingress.ServiceWarpRouting) | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) | ||||
| 	p.logRequest(req, cfRay, lbProbe, ruleNum) | ||||
| 	logFields := logFields{ | ||||
| 		cfRay:   cfRay, | ||||
| 		lbProbe: lbProbe, | ||||
| 		rule:    ruleNum, | ||||
| 	} | ||||
| 	p.logRequest(req, logFields) | ||||
| 
 | ||||
| 	if sourceConnectionType == connection.TypeHTTP { | ||||
| 		resp, err := p.proxyHTTP(w, req, rule) | ||||
| 		if err != nil { | ||||
| 			p.logErrorAndWriteResponse(w, err, cfRay, ruleNum) | ||||
| 		if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil { | ||||
| 			p.logRequestError(err, cfRay, ruleNum) | ||||
| 			return err | ||||
| 		} | ||||
| 
 | ||||
| 		p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
|  | @ -92,22 +96,14 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn | |||
| 		return fmt.Errorf("Not a connection-oriented service") | ||||
| 	} | ||||
| 
 | ||||
| 	resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, connectionProxy) | ||||
| 	if err != nil { | ||||
| 		p.logErrorAndWriteResponse(w, err, cfRay, ruleNum) | ||||
| 	if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil { | ||||
| 		p.logRequestError(err, cfRay, ruleNum) | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) { | ||||
| 	p.logRequestError(err, cfRay, ruleNum) | ||||
| 	w.WriteErrorResponse() | ||||
| } | ||||
| 
 | ||||
| func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) { | ||||
| func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error { | ||||
| 	// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
 | ||||
| 	if rule.Config.DisableChunkedEncoding { | ||||
| 		req.TransferEncoding = []string{"gzip", "deflate"} | ||||
|  | @ -123,18 +119,18 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule * | |||
| 	httpService, ok := rule.Service.(ingress.HTTPOriginProxy) | ||||
| 	if !ok { | ||||
| 		p.log.Error().Msgf("%s is not a http service", rule.Service) | ||||
| 		return nil, fmt.Errorf("Not a http service") | ||||
| 		return fmt.Errorf("Not a http service") | ||||
| 	} | ||||
| 
 | ||||
| 	resp, err := httpService.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error proxying request to origin") | ||||
| 		return errors.Wrap(err, "Error proxying request to origin") | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 
 | ||||
| 	err = w.WriteRespHeaders(resp.StatusCode, resp.Header) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error writing response header") | ||||
| 		return errors.Wrap(err, "Error writing response header") | ||||
| 	} | ||||
| 	if connection.IsServerSentEvent(resp.Header) { | ||||
| 		p.log.Debug().Msg("Detected Server-Side Events from Origin") | ||||
|  | @ -146,43 +142,30 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule * | |||
| 		defer p.bufferPool.Put(buf) | ||||
| 		_, _ = io.CopyBuffer(w, resp.Body, buf) | ||||
| 	} | ||||
| 	return resp, nil | ||||
| 	p.logOriginResponse(resp, fields) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *proxy) proxyConnection( | ||||
| // proxyStreamRequest first establish a connection with origin, then it writes the status code and headers, and finally it streams data between
 | ||||
| // eyeball and origin.
 | ||||
| func (p *proxy) proxyStreamRequest( | ||||
| 	serveCtx context.Context, | ||||
| 	w connection.ResponseWriter, | ||||
| 	req *http.Request, | ||||
| 	sourceConnectionType connection.Type, | ||||
| 	connectionProxy ingress.StreamBasedOriginProxy, | ||||
| ) (*http.Response, error) { | ||||
| 	originConn, connectionResp, err := connectionProxy.EstablishConnection(req) | ||||
| 	fields logFields, | ||||
| ) error { | ||||
| 	originConn, resp, err := connectionProxy.EstablishConnection(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 		return err | ||||
| 	} | ||||
| 	if resp.Body != nil { | ||||
| 		defer resp.Body.Close() | ||||
| 	} | ||||
| 
 | ||||
| 	var eyeballConn io.ReadWriter = w | ||||
| 	respHeader := http.Header{} | ||||
| 	if connectionResp != nil { | ||||
| 		respHeader = connectionResp.Header | ||||
| 	} | ||||
| 	if sourceConnectionType == connection.TypeWebsocket { | ||||
| 		wsReadWriter := websocket.NewConn(serveCtx, w, p.log) | ||||
| 		// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
 | ||||
| 		if originConn.Type() != sourceConnectionType { | ||||
| 			eyeballConn = wsReadWriter | ||||
| 		} | ||||
| 	} | ||||
| 	status := http.StatusSwitchingProtocols | ||||
| 	resp := &http.Response{ | ||||
| 		Status:        http.StatusText(status), | ||||
| 		StatusCode:    status, | ||||
| 		Header:        respHeader, | ||||
| 		ContentLength: -1, | ||||
| 	} | ||||
| 	w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error writing response header") | ||||
| 	if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	streamCtx, cancel := context.WithCancel(serveCtx) | ||||
|  | @ -194,8 +177,9 @@ func (p *proxy) proxyConnection( | |||
| 		originConn.Close() | ||||
| 	}() | ||||
| 
 | ||||
| 	originConn.Stream(eyeballConn, p.log) | ||||
| 	return resp, nil | ||||
| 	originConn.Stream(serveCtx, w, p.log) | ||||
| 	p.logOriginResponse(resp, fields) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { | ||||
|  | @ -215,39 +199,45 @@ func (p *proxy) appendTagHeaders(r *http.Request) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, rule interface{}) { | ||||
| 	if cfRay != "" { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) | ||||
| 	} else if lbProbe { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) | ||||
| type logFields struct { | ||||
| 	cfRay   string | ||||
| 	lbProbe bool | ||||
| 	rule    interface{} | ||||
| } | ||||
| 
 | ||||
| func (p *proxy) logRequest(r *http.Request, fields logFields) { | ||||
| 	if fields.cfRay != "" { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto) | ||||
| 	} else if fields.lbProbe { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto) | ||||
| 	} else { | ||||
| 		p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto) | ||||
| 	} | ||||
| 	p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) | ||||
| 	p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", cfRay, rule) | ||||
| 	p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", fields.cfRay, r.Header) | ||||
| 	p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", fields.cfRay, fields.rule) | ||||
| 
 | ||||
| 	if contentLen := r.ContentLength; contentLen == -1 { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay) | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", fields.cfRay) | ||||
| 	} else { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen) | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Request content length %d", fields.cfRay, contentLen) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, rule interface{}) { | ||||
| 	responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc() | ||||
| 	if cfRay != "" { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, rule) | ||||
| 	} else if lbProbe { | ||||
| 		p.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status) | ||||
| func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) { | ||||
| 	responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc() | ||||
| 	if fields.cfRay != "" { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", fields.cfRay, resp.Status, fields.rule) | ||||
| 	} else if fields.lbProbe { | ||||
| 		p.log.Debug().Msgf("Response to Load Balancer health check %s", resp.Status) | ||||
| 	} else { | ||||
| 		p.log.Debug().Msgf("Status: %s served by ingress %v", r.Status, rule) | ||||
| 		p.log.Debug().Msgf("Status: %s served by ingress %v", resp.Status, fields.rule) | ||||
| 	} | ||||
| 	p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) | ||||
| 	p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", fields.cfRay, resp.Header) | ||||
| 
 | ||||
| 	if contentLen := r.ContentLength; contentLen == -1 { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay) | ||||
| 	if contentLen := resp.ContentLength; contentLen == -1 { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", fields.cfRay) | ||||
| 	} else { | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen) | ||||
| 		p.log.Debug().Msgf("CF-RAY: %s Response content length %d", fields.cfRay, contentLen) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -347,10 +347,7 @@ func TestProxyError(t *testing.T) { | |||
| 	req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	err = proxy.Proxy(respWriter, req, connection.TypeHTTP) | ||||
| 	assert.Error(t, err) | ||||
| 	assert.Equal(t, http.StatusBadGateway, respWriter.Code) | ||||
| 	assert.Equal(t, "http response error", respWriter.Body.String()) | ||||
| 	assert.Error(t, proxy.Proxy(respWriter, req, connection.TypeHTTP)) | ||||
| } | ||||
| 
 | ||||
| type replayer struct { | ||||
|  | @ -421,15 +418,17 @@ func TestConnections(t *testing.T) { | |||
| 			originService:        runEchoWSService, | ||||
| 			eyeballService:       newWSRespWriter([]byte("test1"), replayer), | ||||
| 			connectionType:       connection.TypeWebsocket, | ||||
| 			requestHeaders: map[string][]string{ | ||||
| 				"Test-Cloudflared-Echo": []string{"Echo"}, | ||||
| 			requestHeaders: http.Header{ | ||||
| 				// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
 | ||||
| 				"Sec-Websocket-Key":     {"dGhlIHNhbXBsZSBub25jZQ=="}, | ||||
| 				"Test-Cloudflared-Echo": {"Echo"}, | ||||
| 			}, | ||||
| 			wantMessage: []byte("echo-test1"), | ||||
| 			wantHeaders: map[string][]string{ | ||||
| 				"Connection":            []string{"Upgrade"}, | ||||
| 				"Sec-Websocket-Accept":  []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="}, | ||||
| 				"Upgrade":               []string{"websocket"}, | ||||
| 				"Test-Cloudflared-Echo": []string{"Echo"}, | ||||
| 			wantHeaders: http.Header{ | ||||
| 				"Connection":            {"Upgrade"}, | ||||
| 				"Sec-Websocket-Accept":  {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, | ||||
| 				"Upgrade":               {"websocket"}, | ||||
| 				"Test-Cloudflared-Echo": {"Echo"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
|  | @ -441,25 +440,23 @@ func TestConnections(t *testing.T) { | |||
| 				replayer, | ||||
| 			), | ||||
| 			connectionType: connection.TypeTCP, | ||||
| 			requestHeaders: map[string][]string{ | ||||
| 				"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, | ||||
| 			requestHeaders: http.Header{ | ||||
| 				"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, | ||||
| 			}, | ||||
| 			wantMessage: []byte("echo-test2"), | ||||
| 			wantHeaders: http.Header{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:                 "tcp-ws proxy", | ||||
| 			ingressServicePrefix: "ws://", | ||||
| 			originService:        runEchoWSService, | ||||
| 			eyeballService:       newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")), | ||||
| 			requestHeaders: map[string][]string{ | ||||
| 				"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, | ||||
| 			requestHeaders: http.Header{ | ||||
| 				"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, | ||||
| 			}, | ||||
| 			connectionType: connection.TypeTCP, | ||||
| 			wantMessage:    []byte("echo-test3"), | ||||
| 			// We expect no headers here because they are sent back via
 | ||||
| 			// the stream.
 | ||||
| 			wantHeaders: http.Header{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:                 "ws-tcp proxy", | ||||
|  | @ -467,8 +464,16 @@ func TestConnections(t *testing.T) { | |||
| 			originService:        runEchoTCPService, | ||||
| 			eyeballService:       newWSRespWriter([]byte("test4"), replayer), | ||||
| 			connectionType:       connection.TypeWebsocket, | ||||
| 			wantMessage:          []byte("echo-test4"), | ||||
| 			wantHeaders:          http.Header{}, | ||||
| 			requestHeaders: http.Header{ | ||||
| 				// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
 | ||||
| 				"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="}, | ||||
| 			}, | ||||
| 			wantMessage: []byte("echo-test4"), | ||||
| 			wantHeaders: http.Header{ | ||||
| 				"Connection":           {"Upgrade"}, | ||||
| 				"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, | ||||
| 				"Upgrade":              {"websocket"}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
|  | @ -477,19 +482,18 @@ func TestConnections(t *testing.T) { | |||
| 			ctx, cancel := context.WithCancel(context.Background()) | ||||
| 			ln, err := net.Listen("tcp", "127.0.0.1:0") | ||||
| 			require.NoError(t, err) | ||||
| 			// Starts origin service
 | ||||
| 			test.originService(t, ln) | ||||
| 
 | ||||
| 			ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String()) | ||||
| 			var wg sync.WaitGroup | ||||
| 			errC := make(chan error) | ||||
| 			ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) | ||||
| 			proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) | ||||
| 
 | ||||
| 			req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil) | ||||
| 			require.NoError(t, err) | ||||
| 			reqHeaders := make(http.Header) | ||||
| 			for k, vs := range test.requestHeaders { | ||||
| 				reqHeaders[k] = vs | ||||
| 			} | ||||
| 			req.Header = reqHeaders | ||||
| 			req.Header = test.requestHeaders | ||||
| 
 | ||||
| 			if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok { | ||||
| 				go func() { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue