TUN-4655: ingress.StreamBasedProxy.EstablishConnection takes dest input

This change extracts the need for EstablishConnection to know about a
request's entire context. It also removes the concern of populating the
http.Response from EstablishConnection's responsibilities.
This commit is contained in:
Sudarsan Reddy 2021-07-01 19:30:26 +01:00
parent f1b57526b3
commit d678584d89
7 changed files with 69 additions and 94 deletions

View File

@ -103,7 +103,7 @@ func NewWarpRoutingService() *WarpRoutingService {
} }
// Get a single origin service from the CLI/config. // Get a single origin service from the CLI/config.
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) { func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
if c.IsSet("hello-world") { if c.IsSet("hello-world") {
return new(helloWorld), nil return new(helloWorld), nil
} }
@ -167,7 +167,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
rules := make([]Rule, len(ingress)) rules := make([]Rule, len(ingress))
for i, r := range ingress { for i, r := range ingress {
cfg := setConfig(defaults, r.OriginRequest) cfg := setConfig(defaults, r.OriginRequest)
var service originService var service OriginService
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) { if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
// No validation necessary for unix socket filepath services // No validation necessary for unix socket filepath services

View File

@ -6,9 +6,6 @@ import (
"net/http" "net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/websocket"
) )
var ( var (
@ -24,7 +21,7 @@ type HTTPOriginProxy interface {
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP. // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface { type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) EstablishConnection(dest string) (OriginConnection, error)
} }
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
@ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil return o.resp, nil
} }
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) {
dest, err := getRequestHost(r)
if err != nil {
return nil, nil, err
}
conn, err := net.Dial("tcp", dest) conn, err := net.Dial("tcp", dest)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
originConn := &tcpConnection{ originConn := &tcpConnection{
conn: conn, conn: conn,
} }
resp := &http.Response{ return originConn, nil
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
return originConn, resp, nil
} }
// getRequestHost returns the host of the http.Request. func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
func getRequestHost(r *http.Request) (string, error) {
if r.Host != "" {
return r.Host, nil
}
if r.URL != nil {
return r.URL.Host, nil
}
return "", errors.New("host not found")
}
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
var err error var err error
dest := o.dest if !o.isBastion {
if o.isBastion { dest = o.dest
dest, err = carrier.ResolveBastionDest(r)
if err != nil {
return nil, nil, err
}
} }
conn, err := net.Dial("tcp", dest) conn, err := net.Dial("tcp", dest)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
originConn := &tcpOverWSConnection{ originConn := &tcpOverWSConnection{
conn: conn, conn: conn,
streamHandler: o.streamHandler, streamHandler: o.streamHandler,
} }
resp := &http.Response{ return originConn, nil
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
} }
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
originConn := o.conn return o.conn, nil
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
} }

View File

@ -17,20 +17,6 @@ import (
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
) )
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
// the expected response
func assertEstablishConnectionResponse(t *testing.T,
originProxy StreamBasedOriginProxy,
req *http.Request,
expectHeader http.Header,
) {
_, resp, err := originProxy.EstablishConnection(req)
assert.NoError(t, err)
assert.Equal(t, switchingProtocolText, resp.Status)
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
assert.Equal(t, expectHeader, resp.Header)
}
func TestRawTCPServiceEstablishConnection(t *testing.T) { func TestRawTCPServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0") originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)
@ -43,8 +29,6 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
require.NoError(t, err) require.NoError(t, err)
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
originListener.Close() originListener.Close()
<-listenerClosed <-listenerClosed
@ -52,9 +36,8 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Origin not listening for new connection, should return an error // Origin not listening for new connection, should return an error
_, resp, err := rawTCPService.EstablishConnection(req) _, err = rawTCPService.EstablishConnection(req.URL.String())
require.Error(t, err) require.Error(t, err)
require.Nil(t, resp)
} }
func TestTCPOverWSServiceEstablishConnection(t *testing.T) { func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
@ -76,12 +59,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
bastionReq := baseReq.Clone(context.Background()) bastionReq := baseReq.Clone(context.Background())
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String()) carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
tests := []struct { tests := []struct {
testCase string testCase string
service *tcpOverWSService service *tcpOverWSService
@ -109,11 +86,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.testCase, func(t *testing.T) { t.Run(test.testCase, func(t *testing.T) {
if test.expectErr { if test.expectErr {
_, resp, err := test.service.EstablishConnection(test.req) bastionHost, _ := carrier.ResolveBastionDest(test.req)
_, err := test.service.EstablishConnection(bastionHost)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, resp)
} else {
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
} }
}) })
} }
@ -123,9 +98,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error // Origin not listening for new connection, should return an error
_, resp, err := service.EstablishConnection(bastionReq) bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
_, err := service.EstablishConnection(bastionHost)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, resp)
} }
} }

View File

@ -20,8 +20,8 @@ import (
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
) )
// originService is something a tunnel can proxy traffic to. // OriginService is something a tunnel can proxy traffic to.
type originService interface { type OriginService interface {
String() string String() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World. // Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for // If it's not managed by cloudflared, this is a no-op because the user is responsible for
@ -238,7 +238,7 @@ func (nrc *NopReadCloser) Close() error {
return nil return nil
} }
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) { func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log) originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error loading cert pool") return nil, errors.Wrap(err, "Error loading cert pool")

View File

@ -17,7 +17,7 @@ type Rule struct {
// A (probably local) address. Requests for a hostname which matches this // A (probably local) address. Requests for a hostname which matches this
// rule's hostname pattern will be proxied to the service running on this // rule's hostname pattern will be proxied to the service running on this
// address. // address.
Service originService Service OriginService
// Configure the request cloudflared sends to this specific origin. // Configure the request cloudflared sends to this specific origin.
Config OriginRequestConfig Config OriginRequestConfig

View File

@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) {
type fields struct { type fields struct {
Hostname string Hostname string
Path *regexp.Regexp Path *regexp.Regexp
Service originService Service OriginService
} }
type args struct { type args struct {
requestURL *url.URL requestURL *url.URL

View File

@ -12,6 +12,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -33,6 +34,8 @@ type proxy struct {
bufferPool *bufferPool bufferPool *bufferPool
} }
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
func NewOriginProxy( func NewOriginProxy(
ingressRules ingress.Ingress, ingressRules ingress.Ingress,
warpRouting *ingress.WarpRoutingService, warpRouting *ingress.WarpRoutingService,
@ -71,7 +74,13 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
lbProbe: lbProbe, lbProbe: lbProbe,
rule: ingress.ServiceWarpRouting, rule: ingress.ServiceWarpRouting,
} }
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
host, err := getRequestHost(req)
if err != nil {
err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err)
return err
}
if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil {
p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting) p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting)
return err return err
} }
@ -97,7 +106,11 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return nil return nil
case ingress.StreamBasedOriginProxy: case ingress.StreamBasedOriginProxy:
if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil { dest, err := getDestFromRule(rule, req)
if err != nil {
return err
}
if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum) rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv) p.logRequestError(err, cfRay, rule, srv)
return err return err
@ -105,10 +118,29 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return nil return nil
default: default:
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy) return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
} }
} }
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
switch rule.Service.String() {
case ingress.ServiceBastion:
return carrier.ResolveBastionDest(req)
default:
return rule.Service.String(), nil
}
}
// getRequestHost returns the host of the http.Request.
func getRequestHost(r *http.Request) (string, error) {
if r.Host != "" {
return r.Host, nil
}
if r.URL != nil {
return r.URL.Host, nil
}
return "", errors.New("host not set in incoming request")
}
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
srv = ing.Rules[ruleNum].Service.String() srv = ing.Rules[ruleNum].Service.String()
if ing.IsSingleRule() { if ing.IsSingleRule() {
@ -191,16 +223,24 @@ func (p *proxy) proxyHTTPRequest(
func (p *proxy) proxyStreamRequest( func (p *proxy) proxyStreamRequest(
serveCtx context.Context, serveCtx context.Context,
w connection.ResponseWriter, w connection.ResponseWriter,
dest string,
req *http.Request, req *http.Request,
connectionProxy ingress.StreamBasedOriginProxy, connectionProxy ingress.StreamBasedOriginProxy,
fields logFields, fields logFields,
) error { ) error {
originConn, resp, err := connectionProxy.EstablishConnection(req) originConn, err := connectionProxy.EstablishConnection(dest)
if err != nil { if err != nil {
return err return err
} }
if resp.Body != nil {
defer resp.Body.Close() resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" {
resp.Header = websocket.NewResponseHeader(req)
} }
if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {