From df3ef06169644e2afa75c8049af9958a99734982 Mon Sep 17 00:00:00 2001 From: Shayon Mukherjee Date: Thu, 9 May 2024 15:07:59 -0400 Subject: [PATCH] Support ingress rule matching for bastion mode --- Makefile | 2 +- carrier/carrier.go | 13 +- carrier/carrier_test.go | 52 ++++++-- cmd/cloudflared/tunnel/ingress_subcommands.go | 2 +- ingress/ingress.go | 28 +++- ingress/ingress_test.go | 123 ++++++++++++++---- ingress/origin_proxy_test.go | 17 ++- ingress/origin_service.go | 13 +- ingress/rule.go | 1 + proxy/proxy.go | 49 ++++--- 10 files changed, 227 insertions(+), 73 deletions(-) diff --git a/Makefile b/Makefile index 665a1bf5..0155bbf0 100644 --- a/Makefile +++ b/Makefile @@ -109,7 +109,7 @@ ifneq ($(TARGET_ARM), ) ARM_COMMAND := GOARM=$(TARGET_ARM) endif -ifeq ($(TARGET_ARM), 7) +ifeq ($(TARGET_ARM), 7) PACKAGE_ARCH := armhf else PACKAGE_ARCH := $(TARGET_ARCH) diff --git a/carrier/carrier.go b/carrier/carrier.go index b44e1324..640228bb 100644 --- a/carrier/carrier.go +++ b/carrier/carrier.go @@ -16,13 +16,14 @@ import ( "github.com/pkg/errors" "github.com/rs/zerolog" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/token" ) const ( LogFieldOriginURL = "originURL" CFAccessTokenHeader = "Cf-Access-Token" - cfJumpDestinationHeader = "Cf-Access-Jump-Destination" + CFJumpDestinationHeader = "Cf-Access-Jump-Destination" ) type StartOptions struct { @@ -163,12 +164,16 @@ func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Reque func SetBastionDest(header http.Header, destination string) { if destination != "" { - header.Set(cfJumpDestinationHeader, destination) + header.Set(CFJumpDestinationHeader, destination) } } -func ResolveBastionDest(r *http.Request) (string, error) { - jumpDestination := r.Header.Get(cfJumpDestinationHeader) +func ResolveBastionDest(req *http.Request, bastionMode bool, service string) (string, error) { + jumpDestination := req.Header.Get(CFJumpDestinationHeader) + if bastionMode && service != config.BastionFlag { + jumpDestination = service + } + if jumpDestination == "" { return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side") } diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index 84738e7f..bedaad07 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -158,82 +158,112 @@ func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { } func TestBastionDestination(t *testing.T) { + tests := []struct { name string header http.Header expectedDest string wantErr bool + bastionMode bool + service string }{ { name: "hostname destination", header: http.Header{ - cfJumpDestinationHeader: []string{"localhost"}, + CFJumpDestinationHeader: []string{"localhost"}, }, expectedDest: "localhost", }, { name: "hostname destination with port", header: http.Header{ - cfJumpDestinationHeader: []string{"localhost:9000"}, + CFJumpDestinationHeader: []string{"localhost:9000"}, }, expectedDest: "localhost:9000", }, { name: "hostname destination with scheme and port", header: http.Header{ - cfJumpDestinationHeader: []string{"ssh://localhost:9000"}, + CFJumpDestinationHeader: []string{"ssh://localhost:9000"}, }, expectedDest: "localhost:9000", }, { name: "full hostname url", header: http.Header{ - cfJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"}, + CFJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"}, }, expectedDest: "localhost:9000", }, { name: "hostname destination with port and path", header: http.Header{ - cfJumpDestinationHeader: []string{"localhost:9000/metrics"}, + CFJumpDestinationHeader: []string{"localhost:9000/metrics"}, }, expectedDest: "localhost:9000", }, { name: "ip destination", header: http.Header{ - cfJumpDestinationHeader: []string{"127.0.0.1"}, + CFJumpDestinationHeader: []string{"127.0.0.1"}, }, expectedDest: "127.0.0.1", }, { name: "ip destination with port", header: http.Header{ - cfJumpDestinationHeader: []string{"127.0.0.1:9000"}, + CFJumpDestinationHeader: []string{"127.0.0.1:9000"}, }, expectedDest: "127.0.0.1:9000", }, { name: "ip destination with port and path", header: http.Header{ - cfJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"}, + CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"}, }, expectedDest: "127.0.0.1:9000", }, { name: "ip destination with schem and port", header: http.Header{ - cfJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"}, + CFJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"}, }, expectedDest: "127.0.0.1:9000", }, { name: "full ip url", header: http.Header{ - cfJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"}, + CFJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"}, }, expectedDest: "127.0.0.1:9000", }, + { + name: "full ip url with bastion mode", + header: http.Header{ + CFJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"}, + }, + bastionMode: true, + service: "ssh://127.0.0.1:9002/metrics", + expectedDest: "127.0.0.1:9002", + }, + { + name: "ip destination with port and path with bastion mode", + header: http.Header{ + CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"}, + }, + bastionMode: true, + service: "127.0.0.1:9002/metrics", + expectedDest: "127.0.0.1:9002", + }, + { + name: "ip destination with port and path without bastion mode", + header: http.Header{ + CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"}, + }, + bastionMode: false, + service: "127.0.0.1:9002/metrics", + expectedDest: "127.0.0.1:9000", + }, { name: "no destination", wantErr: true, @@ -243,7 +273,7 @@ func TestBastionDestination(t *testing.T) { r := &http.Request{ Header: test.header, } - dest, err := ResolveBastionDest(r) + dest, err := ResolveBastionDest(r, test.bastionMode, test.service) if test.wantErr { assert.Error(t, err, "Test %s expects error", test.name) } else { diff --git a/cmd/cloudflared/tunnel/ingress_subcommands.go b/cmd/cloudflared/tunnel/ingress_subcommands.go index 82ef7561..8f41a176 100644 --- a/cmd/cloudflared/tunnel/ingress_subcommands.go +++ b/cmd/cloudflared/tunnel/ingress_subcommands.go @@ -138,7 +138,7 @@ func testURLCommand(c *cli.Context) error { return errors.Wrap(err, "Validation failed") } - _, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path) + _, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path, "") fmt.Printf("Matched rule #%d\n", i) fmt.Println(ing.Rules[i].MultiLineString()) return nil diff --git a/ingress/ingress.go b/ingress/ingress.go index 60ee87ac..cc034351 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -28,7 +28,6 @@ var ( ) const ( - ServiceBastion = "bastion" ServiceSocksProxy = "socks-proxy" ServiceWarpRouting = "warp-routing" ) @@ -38,12 +37,13 @@ const ( // which is the case if the rules were instantiated via the ingress#Validate method. // // Negative index rule signifies local cloudflared rules (not-user defined). -func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) { +func (ing Ingress) FindMatchingRule(hostname, path string, cfJumpDestinationHeader string) (*Rule, int) { // The hostname might contain port. We only want to compare the host part with the rule host, _, err := net.SplitHostPort(hostname) if err == nil { hostname = host } + derivedHostName := hostname for i, rule := range ing.InternalRules { if rule.Matches(hostname, path) { // Local rule matches return a negative rule index to distiguish local rules from user-defined rules in logs @@ -52,7 +52,15 @@ func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) { } } for i, rule := range ing.Rules { - if rule.Matches(hostname, path) { + // If bastion mode is turned on and request is made as bastion, attempt + // to match a rule where jump destination header matches the hostname + if rule.Config.BastionMode && len(cfJumpDestinationHeader) > 0 { + jumpDestinationUri, err := url.Parse(cfJumpDestinationHeader) + if err == nil { + derivedHostName = jumpDestinationUri.Hostname() + } + } + if rule.Matches(derivedHostName, path) { return &rule, i } } @@ -265,6 +273,7 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq } srv := newStatusCode(statusCode) service = &srv + } else if r.Service == HelloWorldFlag || r.Service == HelloWorldService { service = new(helloWorld) } else if r.Service == ServiceSocksProxy { @@ -284,12 +293,21 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq } service = newSocksProxyOverWSService(accessPolicy) - } else if r.Service == ServiceBastion || cfg.BastionMode { + } else if r.Service == config.BastionFlag || cfg.BastionMode { // Bastion mode will always start a Websocket proxy server, which will // overwrite the localService.URL field when `start` is called. So, // leave the URL field empty for now. cfg.BastionMode = true - service = newBastionService() + + if cfg.BastionMode && r.Service != config.BastionFlag { + u, err := url.Parse(r.Service) + if err != nil { + return Ingress{}, err + } + service = newBastionServiceWithDest(u) + } else { + service = newBastionService() + } } else { // Validate URL services u, err := url.Parse(r.Service) diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 109cb353..5bb5049a 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -439,6 +439,50 @@ ingress: }, }, }, + { + name: "Bastion mode turned on with with custom service", + args: args{rawYAML: ` +ingress: +- hostname: bastiondest.foo.com + service: http://localhost:9000 + originRequest: + bastionMode: true +- service: http_status:404 +`}, + want: []Rule{ + { + Hostname: "bastiondest.foo.com", + Service: newBastionServiceWithDest(MustParseURL(t, "http://localhost:9000")), + Config: setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), + }, + { + Service: &fourOhFour, + Config: defaultConfig, + }, + }, + }, + { + name: "TCP service with Bastion mode turned off", + args: args{rawYAML: ` +ingress: +- hostname: tcp.foo.com + service: tcp://localhost:9000 + originRequest: + bastionMode: false +- service: http_status:404 +`}, + want: []Rule{ + { + Hostname: "tcp.foo.com", + Service: newTCPOverWSService(MustParseURL(t, "tcp://localhost:9000")), + Config: defaultConfig, + }, + { + Service: &fourOhFour, + Config: defaultConfig, + }, + }, + }, { name: "Hostname contains port", args: args{rawYAML: ` @@ -656,6 +700,7 @@ func TestSingleOriginServices_URL(t *testing.T) { } func TestFindMatchingRule(t *testing.T) { + ingress := Ingress{ Rules: []Rule{ { @@ -666,6 +711,13 @@ func TestFindMatchingRule(t *testing.T) { Hostname: "tunnel-b.example.com", Path: MustParsePath(t, "/health"), }, + { + Hostname: "tunnel-d.example.com", + Path: nil, + Config: OriginRequestConfig{ + BastionMode: true, + }, + }, { Hostname: "*", }, @@ -673,45 +725,64 @@ func TestFindMatchingRule(t *testing.T) { } tests := []struct { - host string - path string - req *http.Request - wantRuleIndex int + host string + path string + cfJumpDestinationHeader string + req *http.Request + wantRuleIndex int }{ { - host: "tunnel-a.example.com", - path: "/", - wantRuleIndex: 0, + host: "tunnel-a.example.com", + path: "/", + cfJumpDestinationHeader: "", + wantRuleIndex: 0, }, { - host: "tunnel-a.example.com", - path: "/pages/about", - wantRuleIndex: 0, + host: "tunnel-a.example.com", + path: "/pages/about", + cfJumpDestinationHeader: "", + wantRuleIndex: 0, }, { - host: "tunnel-a.example.com:443", - path: "/pages/about", - wantRuleIndex: 0, + host: "tunnel-a.example.com:443", + path: "/pages/about", + cfJumpDestinationHeader: "", + wantRuleIndex: 0, }, { - host: "tunnel-b.example.com", - path: "/health", - wantRuleIndex: 1, + host: "tunnel-b.example.com", + path: "/health", + cfJumpDestinationHeader: "", + wantRuleIndex: 1, }, { - host: "tunnel-b.example.com", - path: "/index.html", - wantRuleIndex: 2, + host: "tunnel-b.example.com", + path: "/index.html", + cfJumpDestinationHeader: "", + wantRuleIndex: 3, }, { - host: "tunnel-c.example.com", - path: "/", - wantRuleIndex: 2, + host: "tunnel-d.example.com", + path: "/", + cfJumpDestinationHeader: "https://tunnel-d.example.com", + wantRuleIndex: 2, + }, + { + host: "tunnel-d.example.com", + path: "/", + cfJumpDestinationHeader: "https://tunnel-d.example.com", + wantRuleIndex: 2, + }, + { + host: "tunnel-c.example.com", + path: "/", + cfJumpDestinationHeader: "", + wantRuleIndex: 3, }, } for _, test := range tests { - _, ruleIndex := ingress.FindMatchingRule(test.host, test.path) + _, ruleIndex := ingress.FindMatchingRule(test.host, test.path, test.cfJumpDestinationHeader) assert.Equal(t, test.wantRuleIndex, ruleIndex, fmt.Sprintf("Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, ruleIndex)) } } @@ -786,9 +857,9 @@ ingress: } for n := 0; n < b.N; n++ { - ing.FindMatchingRule("tunnel1.example.com", "") - ing.FindMatchingRule("tunnel2.example.com", "") - ing.FindMatchingRule("tunnel3.example.com", "") + ing.FindMatchingRule("tunnel1.example.com", "", "") + ing.FindMatchingRule("tunnel2.example.com", "", "") + ing.FindMatchingRule("tunnel3.example.com", "", "") } } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 7a6170a2..9387cd2d 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -58,6 +58,8 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { bastionReq := baseReq.Clone(context.Background()) carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String()) + u, err := url.Parse("https://place-holder1") + require.NoError(t, err) tests := []struct { testCase string @@ -81,12 +83,23 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { req: baseReq, expectErr: true, }, + { + testCase: "bastion service", + service: newBastionServiceWithDest(u), + req: bastionReq, + }, + { + testCase: "bastion service", + service: newBastionServiceWithDest(u), + req: bastionReq, + expectErr: true, + }, } for _, test := range tests { t.Run(test.testCase, func(t *testing.T) { if test.expectErr { - bastionHost, _ := carrier.ResolveBastionDest(test.req) + bastionHost, _ := carrier.ResolveBastionDest(test.req, false, "bastion") _, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger) assert.Error(t, err) } @@ -98,7 +111,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { // Origin not listening for new connection, should return an error - bastionHost, _ := carrier.ResolveBastionDest(bastionReq) + bastionHost, _ := carrier.ResolveBastionDest(bastionReq, false, "bastion") _, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger) assert.Error(t, err) } diff --git a/ingress/origin_service.go b/ingress/origin_service.go index e13204c5..dccbf0b9 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -15,6 +15,7 @@ import ( "github.com/pkg/errors" "github.com/rs/zerolog" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/ipaccess" "github.com/cloudflare/cloudflared/management" @@ -151,6 +152,14 @@ func newBastionService() *tcpOverWSService { } } +func newBastionServiceWithDest(url *url.URL) *tcpOverWSService { + return &tcpOverWSService{ + isBastion: true, + scheme: url.Scheme, + dest: url.Host, + } +} + func newSocksProxyOverWSService(accessPolicy *ipaccess.Policy) *socksProxyOverWSService { proxy := socksProxyOverWSService{ conn: &socksProxyOverWSConnection{ @@ -170,8 +179,8 @@ func addPortIfMissing(uri *url.URL, port int) { } func (o *tcpOverWSService) String() string { - if o.isBastion { - return ServiceBastion + if o.isBastion && len(o.dest) == 0 { + return config.BastionFlag } if o.scheme != "" { diff --git a/ingress/rule.go b/ingress/rule.go index 43c7ad5e..79cdf1d5 100644 --- a/ingress/rule.go +++ b/ingress/rule.go @@ -59,6 +59,7 @@ func (r *Rule) Matches(hostname, path string) bool { } else { hostMatch = matchHost(r.Hostname, hostname) } + punycodeHostMatch := false if r.punycodeHostname != "" { punycodeHostMatch = matchHost(r.punycodeHostname, hostname) diff --git a/proxy/proxy.go b/proxy/proxy.go index dc02eeac..c0cf7052 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -15,6 +15,7 @@ import ( "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/cfio" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/stream" @@ -86,7 +87,7 @@ func (p *Proxy) ProxyHTTP( _, ruleSpan := tr.Tracer().Start(req.Context(), "ingress_match", trace.WithAttributes(attribute.String("req-host", req.Host))) - rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) + rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path, req.Header.Get(carrier.CFJumpDestinationHeader)) ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum)) ruleSpan.End() logger := newHTTPLogger(p.log, tr.ConnIndex, req, ruleNum, rule.Service.String()) @@ -98,6 +99,29 @@ func (p *Proxy) ProxyHTTP( } return err } + // Handling for StreamBasedOriginProxy or BastionMode + if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode { + if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode { + return fmt.Errorf("Unrecognized service: %s", rule.Service) + } + + dest, err := getDestFromRule(rule, req) + if err != nil { + return err + } + + flusher, ok := w.(http.Flusher) + if !ok { + return fmt.Errorf("response writer is not a flusher") + } + rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req) + logger := logger.With().Str(logFieldDestAddr, dest).Logger() + if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil { + logRequestError(&logger, err) + return err + } + return nil + } switch originProxy := rule.Service.(type) { case ingress.HTTPOriginProxy: @@ -113,22 +137,6 @@ func (p *Proxy) ProxyHTTP( return err } return nil - case ingress.StreamBasedOriginProxy: - dest, err := getDestFromRule(rule, req) - if err != nil { - return err - } - flusher, ok := w.(http.Flusher) - if !ok { - return fmt.Errorf("response writer is not a flusher") - } - rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req) - logger := logger.With().Str(logFieldDestAddr, dest).Logger() - if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy, &logger); err != nil { - logRequestError(&logger, err) - return err - } - return nil case ingress.HTTPLocalProxy: p.proxyLocalRequest(originProxy, w, req, isWebsocket) return nil @@ -335,10 +343,9 @@ func copyTrailers(w connection.ResponseWriter, response *http.Response) { } func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) { - switch rule.Service.String() { - case ingress.ServiceBastion: - return carrier.ResolveBastionDest(req) - default: + if rule.Config.BastionMode || rule.Service.String() == config.BastionFlag { + return carrier.ResolveBastionDest(req, rule.Config.BastionMode, rule.Service.String()) + } else { return rule.Service.String(), nil } }