TUN-4655: ingress.StreamBasedProxy.EstablishConnection takes dest input
This change extracts the need for EstablishConnection to know about a request's entire context. It also removes the concern of populating the http.Response from EstablishConnection's responsibilities.
This commit is contained in:
		
							parent
							
								
									f1b57526b3
								
							
						
					
					
						commit
						d678584d89
					
				|  | @ -103,7 +103,7 @@ func NewWarpRoutingService() *WarpRoutingService { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Get a single origin service from the CLI/config.
 | // Get a single origin service from the CLI/config.
 | ||||||
| func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) { | func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) { | ||||||
| 	if c.IsSet("hello-world") { | 	if c.IsSet("hello-world") { | ||||||
| 		return new(helloWorld), nil | 		return new(helloWorld), nil | ||||||
| 	} | 	} | ||||||
|  | @ -167,7 +167,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon | ||||||
| 	rules := make([]Rule, len(ingress)) | 	rules := make([]Rule, len(ingress)) | ||||||
| 	for i, r := range ingress { | 	for i, r := range ingress { | ||||||
| 		cfg := setConfig(defaults, r.OriginRequest) | 		cfg := setConfig(defaults, r.OriginRequest) | ||||||
| 		var service originService | 		var service OriginService | ||||||
| 
 | 
 | ||||||
| 		if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) { | 		if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) { | ||||||
| 			// No validation necessary for unix socket filepath services
 | 			// No validation necessary for unix socket filepath services
 | ||||||
|  |  | ||||||
|  | @ -6,9 +6,6 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pkg/errors" | 	"github.com/pkg/errors" | ||||||
| 
 |  | ||||||
| 	"github.com/cloudflare/cloudflared/carrier" |  | ||||||
| 	"github.com/cloudflare/cloudflared/websocket" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ( | var ( | ||||||
|  | @ -24,7 +21,7 @@ type HTTPOriginProxy interface { | ||||||
| 
 | 
 | ||||||
| // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
 | // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
 | ||||||
| type StreamBasedOriginProxy interface { | type StreamBasedOriginProxy interface { | ||||||
| 	EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) | 	EstablishConnection(dest string) (OriginConnection, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { | func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { | ||||||
|  | @ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { | ||||||
| 	return o.resp, nil | 	return o.resp, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) { | ||||||
| 	dest, err := getRequestHost(r) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, nil, err |  | ||||||
| 	} |  | ||||||
| 	conn, err := net.Dial("tcp", dest) | 	conn, err := net.Dial("tcp", dest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	originConn := &tcpConnection{ | 	originConn := &tcpConnection{ | ||||||
| 		conn: conn, | 		conn: conn, | ||||||
| 	} | 	} | ||||||
| 	resp := &http.Response{ | 	return originConn, nil | ||||||
| 		Status:        switchingProtocolText, |  | ||||||
| 		StatusCode:    http.StatusSwitchingProtocols, |  | ||||||
| 		ContentLength: -1, |  | ||||||
| 	} |  | ||||||
| 	return originConn, resp, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getRequestHost returns the host of the http.Request.
 | func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) { | ||||||
| func getRequestHost(r *http.Request) (string, error) { |  | ||||||
| 	if r.Host != "" { |  | ||||||
| 		return r.Host, nil |  | ||||||
| 	} |  | ||||||
| 	if r.URL != nil { |  | ||||||
| 		return r.URL.Host, nil |  | ||||||
| 	} |  | ||||||
| 	return "", errors.New("host not found") |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { |  | ||||||
| 	var err error | 	var err error | ||||||
| 	dest := o.dest | 	if !o.isBastion { | ||||||
| 	if o.isBastion { | 		dest = o.dest | ||||||
| 		dest, err = carrier.ResolveBastionDest(r) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return nil, nil, err |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	conn, err := net.Dial("tcp", dest) | 	conn, err := net.Dial("tcp", dest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	originConn := &tcpOverWSConnection{ | 	originConn := &tcpOverWSConnection{ | ||||||
| 		conn:          conn, | 		conn:          conn, | ||||||
| 		streamHandler: o.streamHandler, | 		streamHandler: o.streamHandler, | ||||||
| 	} | 	} | ||||||
| 	resp := &http.Response{ | 	return originConn, nil | ||||||
| 		Status:        switchingProtocolText, |  | ||||||
| 		StatusCode:    http.StatusSwitchingProtocols, |  | ||||||
| 		Header:        websocket.NewResponseHeader(r), |  | ||||||
| 		ContentLength: -1, |  | ||||||
| 	} |  | ||||||
| 	return originConn, resp, nil |  | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) { | ||||||
| 	originConn := o.conn | 	return o.conn, nil | ||||||
| 	resp := &http.Response{ |  | ||||||
| 		Status:        switchingProtocolText, |  | ||||||
| 		StatusCode:    http.StatusSwitchingProtocols, |  | ||||||
| 		Header:        websocket.NewResponseHeader(r), |  | ||||||
| 		ContentLength: -1, |  | ||||||
| 	} |  | ||||||
| 	return originConn, resp, nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -17,20 +17,6 @@ import ( | ||||||
| 	"github.com/cloudflare/cloudflared/websocket" | 	"github.com/cloudflare/cloudflared/websocket" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
 |  | ||||||
| // the expected response
 |  | ||||||
| func assertEstablishConnectionResponse(t *testing.T, |  | ||||||
| 	originProxy StreamBasedOriginProxy, |  | ||||||
| 	req *http.Request, |  | ||||||
| 	expectHeader http.Header, |  | ||||||
| ) { |  | ||||||
| 	_, resp, err := originProxy.EstablishConnection(req) |  | ||||||
| 	assert.NoError(t, err) |  | ||||||
| 	assert.Equal(t, switchingProtocolText, resp.Status) |  | ||||||
| 	assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) |  | ||||||
| 	assert.Equal(t, expectHeader, resp.Header) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestRawTCPServiceEstablishConnection(t *testing.T) { | func TestRawTCPServiceEstablishConnection(t *testing.T) { | ||||||
| 	originListener, err := net.Listen("tcp", "127.0.0.1:0") | 	originListener, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
|  | @ -43,8 +29,6 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) { | ||||||
| 	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) | 	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	assertEstablishConnectionResponse(t, rawTCPService, req, nil) |  | ||||||
| 
 |  | ||||||
| 	originListener.Close() | 	originListener.Close() | ||||||
| 	<-listenerClosed | 	<-listenerClosed | ||||||
| 
 | 
 | ||||||
|  | @ -52,9 +36,8 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) { | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	// Origin not listening for new connection, should return an error
 | 	// Origin not listening for new connection, should return an error
 | ||||||
| 	_, resp, err := rawTCPService.EstablishConnection(req) | 	_, err = rawTCPService.EstablishConnection(req.URL.String()) | ||||||
| 	require.Error(t, err) | 	require.Error(t, err) | ||||||
| 	require.Nil(t, resp) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestTCPOverWSServiceEstablishConnection(t *testing.T) { | func TestTCPOverWSServiceEstablishConnection(t *testing.T) { | ||||||
|  | @ -76,12 +59,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { | ||||||
| 	bastionReq := baseReq.Clone(context.Background()) | 	bastionReq := baseReq.Clone(context.Background()) | ||||||
| 	carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String()) | 	carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String()) | ||||||
| 
 | 
 | ||||||
| 	expectHeader := http.Header{ |  | ||||||
| 		"Connection":           {"Upgrade"}, |  | ||||||
| 		"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, |  | ||||||
| 		"Upgrade":              {"websocket"}, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| 		testCase  string | 		testCase  string | ||||||
| 		service   *tcpOverWSService | 		service   *tcpOverWSService | ||||||
|  | @ -109,11 +86,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { | ||||||
| 	for _, test := range tests { | 	for _, test := range tests { | ||||||
| 		t.Run(test.testCase, func(t *testing.T) { | 		t.Run(test.testCase, func(t *testing.T) { | ||||||
| 			if test.expectErr { | 			if test.expectErr { | ||||||
| 				_, resp, err := test.service.EstablishConnection(test.req) | 				bastionHost, _ := carrier.ResolveBastionDest(test.req) | ||||||
|  | 				_, err := test.service.EstablishConnection(bastionHost) | ||||||
| 				assert.Error(t, err) | 				assert.Error(t, err) | ||||||
| 				assert.Nil(t, resp) |  | ||||||
| 			} else { |  | ||||||
| 				assertEstablishConnectionResponse(t, test.service, test.req, expectHeader) |  | ||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
|  | @ -123,9 +98,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { | 	for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { | ||||||
| 		// Origin not listening for new connection, should return an error
 | 		// Origin not listening for new connection, should return an error
 | ||||||
| 		_, resp, err := service.EstablishConnection(bastionReq) | 		bastionHost, _ := carrier.ResolveBastionDest(bastionReq) | ||||||
|  | 		_, err := service.EstablishConnection(bastionHost) | ||||||
| 		assert.Error(t, err) | 		assert.Error(t, err) | ||||||
| 		assert.Nil(t, resp) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -20,8 +20,8 @@ import ( | ||||||
| 	"github.com/cloudflare/cloudflared/tlsconfig" | 	"github.com/cloudflare/cloudflared/tlsconfig" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // originService is something a tunnel can proxy traffic to.
 | // OriginService is something a tunnel can proxy traffic to.
 | ||||||
| type originService interface { | type OriginService interface { | ||||||
| 	String() string | 	String() string | ||||||
| 	// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
 | 	// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
 | ||||||
| 	// If it's not managed by cloudflared, this is a no-op because the user is responsible for
 | 	// If it's not managed by cloudflared, this is a no-op because the user is responsible for
 | ||||||
|  | @ -238,7 +238,7 @@ func (nrc *NopReadCloser) Close() error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) { | func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) { | ||||||
| 	originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log) | 	originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, errors.Wrap(err, "Error loading cert pool") | 		return nil, errors.Wrap(err, "Error loading cert pool") | ||||||
|  |  | ||||||
|  | @ -17,7 +17,7 @@ type Rule struct { | ||||||
| 	// A (probably local) address. Requests for a hostname which matches this
 | 	// A (probably local) address. Requests for a hostname which matches this
 | ||||||
| 	// rule's hostname pattern will be proxied to the service running on this
 | 	// rule's hostname pattern will be proxied to the service running on this
 | ||||||
| 	// address.
 | 	// address.
 | ||||||
| 	Service originService | 	Service OriginService | ||||||
| 
 | 
 | ||||||
| 	// Configure the request cloudflared sends to this specific origin.
 | 	// Configure the request cloudflared sends to this specific origin.
 | ||||||
| 	Config OriginRequestConfig | 	Config OriginRequestConfig | ||||||
|  |  | ||||||
|  | @ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) { | ||||||
| 	type fields struct { | 	type fields struct { | ||||||
| 		Hostname string | 		Hostname string | ||||||
| 		Path     *regexp.Regexp | 		Path     *regexp.Regexp | ||||||
| 		Service  originService | 		Service  OriginService | ||||||
| 	} | 	} | ||||||
| 	type args struct { | 	type args struct { | ||||||
| 		requestURL *url.URL | 		requestURL *url.URL | ||||||
|  |  | ||||||
|  | @ -12,6 +12,7 @@ import ( | ||||||
| 	"github.com/pkg/errors" | 	"github.com/pkg/errors" | ||||||
| 	"github.com/rs/zerolog" | 	"github.com/rs/zerolog" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/cloudflare/cloudflared/carrier" | ||||||
| 	"github.com/cloudflare/cloudflared/connection" | 	"github.com/cloudflare/cloudflared/connection" | ||||||
| 	"github.com/cloudflare/cloudflared/ingress" | 	"github.com/cloudflare/cloudflared/ingress" | ||||||
| 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" | 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||||
|  | @ -33,6 +34,8 @@ type proxy struct { | ||||||
| 	bufferPool   *bufferPool | 	bufferPool   *bufferPool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) | ||||||
|  | 
 | ||||||
| func NewOriginProxy( | func NewOriginProxy( | ||||||
| 	ingressRules ingress.Ingress, | 	ingressRules ingress.Ingress, | ||||||
| 	warpRouting *ingress.WarpRoutingService, | 	warpRouting *ingress.WarpRoutingService, | ||||||
|  | @ -71,7 +74,13 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn | ||||||
| 			lbProbe: lbProbe, | 			lbProbe: lbProbe, | ||||||
| 			rule:    ingress.ServiceWarpRouting, | 			rule:    ingress.ServiceWarpRouting, | ||||||
| 		} | 		} | ||||||
| 		if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil { | 
 | ||||||
|  | 		host, err := getRequestHost(req) | ||||||
|  | 		if err != nil { | ||||||
|  | 			err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err) | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil { | ||||||
| 			p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting) | 			p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting) | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|  | @ -97,7 +106,11 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn | ||||||
| 		return nil | 		return nil | ||||||
| 
 | 
 | ||||||
| 	case ingress.StreamBasedOriginProxy: | 	case ingress.StreamBasedOriginProxy: | ||||||
| 		if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil { | 		dest, err := getDestFromRule(rule, req) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil { | ||||||
| 			rule, srv := ruleField(p.ingressRules, ruleNum) | 			rule, srv := ruleField(p.ingressRules, ruleNum) | ||||||
| 			p.logRequestError(err, cfRay, rule, srv) | 			p.logRequestError(err, cfRay, rule, srv) | ||||||
| 			return err | 			return err | ||||||
|  | @ -105,10 +118,29 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn | ||||||
| 		return nil | 		return nil | ||||||
| 	default: | 	default: | ||||||
| 		return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy) | 		return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy) | ||||||
| 
 |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) { | ||||||
|  | 	switch rule.Service.String() { | ||||||
|  | 	case ingress.ServiceBastion: | ||||||
|  | 		return carrier.ResolveBastionDest(req) | ||||||
|  | 	default: | ||||||
|  | 		return rule.Service.String(), nil | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getRequestHost returns the host of the http.Request.
 | ||||||
|  | func getRequestHost(r *http.Request) (string, error) { | ||||||
|  | 	if r.Host != "" { | ||||||
|  | 		return r.Host, nil | ||||||
|  | 	} | ||||||
|  | 	if r.URL != nil { | ||||||
|  | 		return r.URL.Host, nil | ||||||
|  | 	} | ||||||
|  | 	return "", errors.New("host not set in incoming request") | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { | func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { | ||||||
| 	srv = ing.Rules[ruleNum].Service.String() | 	srv = ing.Rules[ruleNum].Service.String() | ||||||
| 	if ing.IsSingleRule() { | 	if ing.IsSingleRule() { | ||||||
|  | @ -191,16 +223,24 @@ func (p *proxy) proxyHTTPRequest( | ||||||
| func (p *proxy) proxyStreamRequest( | func (p *proxy) proxyStreamRequest( | ||||||
| 	serveCtx context.Context, | 	serveCtx context.Context, | ||||||
| 	w connection.ResponseWriter, | 	w connection.ResponseWriter, | ||||||
|  | 	dest string, | ||||||
| 	req *http.Request, | 	req *http.Request, | ||||||
| 	connectionProxy ingress.StreamBasedOriginProxy, | 	connectionProxy ingress.StreamBasedOriginProxy, | ||||||
| 	fields logFields, | 	fields logFields, | ||||||
| ) error { | ) error { | ||||||
| 	originConn, resp, err := connectionProxy.EstablishConnection(req) | 	originConn, err := connectionProxy.EstablishConnection(dest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	if resp.Body != nil { | 
 | ||||||
| 		defer resp.Body.Close() | 	resp := &http.Response{ | ||||||
|  | 		Status:        switchingProtocolText, | ||||||
|  | 		StatusCode:    http.StatusSwitchingProtocols, | ||||||
|  | 		ContentLength: -1, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" { | ||||||
|  | 		resp.Header = websocket.NewResponseHeader(req) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { | 	if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue