Support ingress rule matching for bastion mode

This commit is contained in:
Shayon Mukherjee 2024-05-09 15:07:59 -04:00
parent f27418044b
commit df3ef06169
10 changed files with 227 additions and 73 deletions

View File

@ -16,13 +16,14 @@ import (
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/token"
)
const (
LogFieldOriginURL = "originURL"
CFAccessTokenHeader = "Cf-Access-Token"
cfJumpDestinationHeader = "Cf-Access-Jump-Destination"
CFJumpDestinationHeader = "Cf-Access-Jump-Destination"
)
type StartOptions struct {
@ -163,12 +164,16 @@ func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Reque
func SetBastionDest(header http.Header, destination string) {
if destination != "" {
header.Set(cfJumpDestinationHeader, destination)
header.Set(CFJumpDestinationHeader, destination)
}
}
func ResolveBastionDest(r *http.Request) (string, error) {
jumpDestination := r.Header.Get(cfJumpDestinationHeader)
func ResolveBastionDest(req *http.Request, bastionMode bool, service string) (string, error) {
jumpDestination := req.Header.Get(CFJumpDestinationHeader)
if bastionMode && service != config.BastionFlag {
jumpDestination = service
}
if jumpDestination == "" {
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
}

View File

@ -158,82 +158,112 @@ func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
}
func TestBastionDestination(t *testing.T) {
tests := []struct {
name string
header http.Header
expectedDest string
wantErr bool
bastionMode bool
service string
}{
{
name: "hostname destination",
header: http.Header{
cfJumpDestinationHeader: []string{"localhost"},
CFJumpDestinationHeader: []string{"localhost"},
},
expectedDest: "localhost",
},
{
name: "hostname destination with port",
header: http.Header{
cfJumpDestinationHeader: []string{"localhost:9000"},
CFJumpDestinationHeader: []string{"localhost:9000"},
},
expectedDest: "localhost:9000",
},
{
name: "hostname destination with scheme and port",
header: http.Header{
cfJumpDestinationHeader: []string{"ssh://localhost:9000"},
CFJumpDestinationHeader: []string{"ssh://localhost:9000"},
},
expectedDest: "localhost:9000",
},
{
name: "full hostname url",
header: http.Header{
cfJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"},
CFJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"},
},
expectedDest: "localhost:9000",
},
{
name: "hostname destination with port and path",
header: http.Header{
cfJumpDestinationHeader: []string{"localhost:9000/metrics"},
CFJumpDestinationHeader: []string{"localhost:9000/metrics"},
},
expectedDest: "localhost:9000",
},
{
name: "ip destination",
header: http.Header{
cfJumpDestinationHeader: []string{"127.0.0.1"},
CFJumpDestinationHeader: []string{"127.0.0.1"},
},
expectedDest: "127.0.0.1",
},
{
name: "ip destination with port",
header: http.Header{
cfJumpDestinationHeader: []string{"127.0.0.1:9000"},
CFJumpDestinationHeader: []string{"127.0.0.1:9000"},
},
expectedDest: "127.0.0.1:9000",
},
{
name: "ip destination with port and path",
header: http.Header{
cfJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
},
expectedDest: "127.0.0.1:9000",
},
{
name: "ip destination with schem and port",
header: http.Header{
cfJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"},
CFJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"},
},
expectedDest: "127.0.0.1:9000",
},
{
name: "full ip url",
header: http.Header{
cfJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"},
CFJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"},
},
expectedDest: "127.0.0.1:9000",
},
{
name: "full ip url with bastion mode",
header: http.Header{
CFJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"},
},
bastionMode: true,
service: "ssh://127.0.0.1:9002/metrics",
expectedDest: "127.0.0.1:9002",
},
{
name: "ip destination with port and path with bastion mode",
header: http.Header{
CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
},
bastionMode: true,
service: "127.0.0.1:9002/metrics",
expectedDest: "127.0.0.1:9002",
},
{
name: "ip destination with port and path without bastion mode",
header: http.Header{
CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
},
bastionMode: false,
service: "127.0.0.1:9002/metrics",
expectedDest: "127.0.0.1:9000",
},
{
name: "no destination",
wantErr: true,
@ -243,7 +273,7 @@ func TestBastionDestination(t *testing.T) {
r := &http.Request{
Header: test.header,
}
dest, err := ResolveBastionDest(r)
dest, err := ResolveBastionDest(r, test.bastionMode, test.service)
if test.wantErr {
assert.Error(t, err, "Test %s expects error", test.name)
} else {

View File

@ -138,7 +138,7 @@ func testURLCommand(c *cli.Context) error {
return errors.Wrap(err, "Validation failed")
}
_, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path)
_, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path, "")
fmt.Printf("Matched rule #%d\n", i)
fmt.Println(ing.Rules[i].MultiLineString())
return nil

View File

@ -28,7 +28,6 @@ var (
)
const (
ServiceBastion = "bastion"
ServiceSocksProxy = "socks-proxy"
ServiceWarpRouting = "warp-routing"
)
@ -38,12 +37,13 @@ const (
// which is the case if the rules were instantiated via the ingress#Validate method.
//
// Negative index rule signifies local cloudflared rules (not-user defined).
func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) {
func (ing Ingress) FindMatchingRule(hostname, path string, cfJumpDestinationHeader string) (*Rule, int) {
// The hostname might contain port. We only want to compare the host part with the rule
host, _, err := net.SplitHostPort(hostname)
if err == nil {
hostname = host
}
derivedHostName := hostname
for i, rule := range ing.InternalRules {
if rule.Matches(hostname, path) {
// Local rule matches return a negative rule index to distiguish local rules from user-defined rules in logs
@ -52,7 +52,15 @@ func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) {
}
}
for i, rule := range ing.Rules {
if rule.Matches(hostname, path) {
// If bastion mode is turned on and request is made as bastion, attempt
// to match a rule where jump destination header matches the hostname
if rule.Config.BastionMode && len(cfJumpDestinationHeader) > 0 {
jumpDestinationUri, err := url.Parse(cfJumpDestinationHeader)
if err == nil {
derivedHostName = jumpDestinationUri.Hostname()
}
}
if rule.Matches(derivedHostName, path) {
return &rule, i
}
}
@ -265,6 +273,7 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq
}
srv := newStatusCode(statusCode)
service = &srv
} else if r.Service == HelloWorldFlag || r.Service == HelloWorldService {
service = new(helloWorld)
} else if r.Service == ServiceSocksProxy {
@ -284,12 +293,21 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq
}
service = newSocksProxyOverWSService(accessPolicy)
} else if r.Service == ServiceBastion || cfg.BastionMode {
} else if r.Service == config.BastionFlag || cfg.BastionMode {
// Bastion mode will always start a Websocket proxy server, which will
// overwrite the localService.URL field when `start` is called. So,
// leave the URL field empty for now.
cfg.BastionMode = true
if cfg.BastionMode && r.Service != config.BastionFlag {
u, err := url.Parse(r.Service)
if err != nil {
return Ingress{}, err
}
service = newBastionServiceWithDest(u)
} else {
service = newBastionService()
}
} else {
// Validate URL services
u, err := url.Parse(r.Service)

View File

@ -439,6 +439,50 @@ ingress:
},
},
},
{
name: "Bastion mode turned on with with custom service",
args: args{rawYAML: `
ingress:
- hostname: bastiondest.foo.com
service: http://localhost:9000
originRequest:
bastionMode: true
- service: http_status:404
`},
want: []Rule{
{
Hostname: "bastiondest.foo.com",
Service: newBastionServiceWithDest(MustParseURL(t, "http://localhost:9000")),
Config: setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
},
{
Service: &fourOhFour,
Config: defaultConfig,
},
},
},
{
name: "TCP service with Bastion mode turned off",
args: args{rawYAML: `
ingress:
- hostname: tcp.foo.com
service: tcp://localhost:9000
originRequest:
bastionMode: false
- service: http_status:404
`},
want: []Rule{
{
Hostname: "tcp.foo.com",
Service: newTCPOverWSService(MustParseURL(t, "tcp://localhost:9000")),
Config: defaultConfig,
},
{
Service: &fourOhFour,
Config: defaultConfig,
},
},
},
{
name: "Hostname contains port",
args: args{rawYAML: `
@ -656,6 +700,7 @@ func TestSingleOriginServices_URL(t *testing.T) {
}
func TestFindMatchingRule(t *testing.T) {
ingress := Ingress{
Rules: []Rule{
{
@ -666,6 +711,13 @@ func TestFindMatchingRule(t *testing.T) {
Hostname: "tunnel-b.example.com",
Path: MustParsePath(t, "/health"),
},
{
Hostname: "tunnel-d.example.com",
Path: nil,
Config: OriginRequestConfig{
BastionMode: true,
},
},
{
Hostname: "*",
},
@ -675,43 +727,62 @@ func TestFindMatchingRule(t *testing.T) {
tests := []struct {
host string
path string
cfJumpDestinationHeader string
req *http.Request
wantRuleIndex int
}{
{
host: "tunnel-a.example.com",
path: "/",
cfJumpDestinationHeader: "",
wantRuleIndex: 0,
},
{
host: "tunnel-a.example.com",
path: "/pages/about",
cfJumpDestinationHeader: "",
wantRuleIndex: 0,
},
{
host: "tunnel-a.example.com:443",
path: "/pages/about",
cfJumpDestinationHeader: "",
wantRuleIndex: 0,
},
{
host: "tunnel-b.example.com",
path: "/health",
cfJumpDestinationHeader: "",
wantRuleIndex: 1,
},
{
host: "tunnel-b.example.com",
path: "/index.html",
cfJumpDestinationHeader: "",
wantRuleIndex: 3,
},
{
host: "tunnel-d.example.com",
path: "/",
cfJumpDestinationHeader: "https://tunnel-d.example.com",
wantRuleIndex: 2,
},
{
host: "tunnel-d.example.com",
path: "/",
cfJumpDestinationHeader: "https://tunnel-d.example.com",
wantRuleIndex: 2,
},
{
host: "tunnel-c.example.com",
path: "/",
wantRuleIndex: 2,
cfJumpDestinationHeader: "",
wantRuleIndex: 3,
},
}
for _, test := range tests {
_, ruleIndex := ingress.FindMatchingRule(test.host, test.path)
_, ruleIndex := ingress.FindMatchingRule(test.host, test.path, test.cfJumpDestinationHeader)
assert.Equal(t, test.wantRuleIndex, ruleIndex, fmt.Sprintf("Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, ruleIndex))
}
}
@ -786,9 +857,9 @@ ingress:
}
for n := 0; n < b.N; n++ {
ing.FindMatchingRule("tunnel1.example.com", "")
ing.FindMatchingRule("tunnel2.example.com", "")
ing.FindMatchingRule("tunnel3.example.com", "")
ing.FindMatchingRule("tunnel1.example.com", "", "")
ing.FindMatchingRule("tunnel2.example.com", "", "")
ing.FindMatchingRule("tunnel3.example.com", "", "")
}
}

View File

@ -58,6 +58,8 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
bastionReq := baseReq.Clone(context.Background())
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
u, err := url.Parse("https://place-holder1")
require.NoError(t, err)
tests := []struct {
testCase string
@ -81,12 +83,23 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
req: baseReq,
expectErr: true,
},
{
testCase: "bastion service",
service: newBastionServiceWithDest(u),
req: bastionReq,
},
{
testCase: "bastion service",
service: newBastionServiceWithDest(u),
req: bastionReq,
expectErr: true,
},
}
for _, test := range tests {
t.Run(test.testCase, func(t *testing.T) {
if test.expectErr {
bastionHost, _ := carrier.ResolveBastionDest(test.req)
bastionHost, _ := carrier.ResolveBastionDest(test.req, false, "bastion")
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err)
}
@ -98,7 +111,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
bastionHost, _ := carrier.ResolveBastionDest(bastionReq, false, "bastion")
_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err)
}

View File

@ -15,6 +15,7 @@ import (
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/ipaccess"
"github.com/cloudflare/cloudflared/management"
@ -151,6 +152,14 @@ func newBastionService() *tcpOverWSService {
}
}
func newBastionServiceWithDest(url *url.URL) *tcpOverWSService {
return &tcpOverWSService{
isBastion: true,
scheme: url.Scheme,
dest: url.Host,
}
}
func newSocksProxyOverWSService(accessPolicy *ipaccess.Policy) *socksProxyOverWSService {
proxy := socksProxyOverWSService{
conn: &socksProxyOverWSConnection{
@ -170,8 +179,8 @@ func addPortIfMissing(uri *url.URL, port int) {
}
func (o *tcpOverWSService) String() string {
if o.isBastion {
return ServiceBastion
if o.isBastion && len(o.dest) == 0 {
return config.BastionFlag
}
if o.scheme != "" {

View File

@ -59,6 +59,7 @@ func (r *Rule) Matches(hostname, path string) bool {
} else {
hostMatch = matchHost(r.Hostname, hostname)
}
punycodeHostMatch := false
if r.punycodeHostname != "" {
punycodeHostMatch = matchHost(r.punycodeHostname, hostname)

View File

@ -15,6 +15,7 @@ import (
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/stream"
@ -86,7 +87,7 @@ func (p *Proxy) ProxyHTTP(
_, ruleSpan := tr.Tracer().Start(req.Context(), "ingress_match",
trace.WithAttributes(attribute.String("req-host", req.Host)))
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path, req.Header.Get(carrier.CFJumpDestinationHeader))
ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum))
ruleSpan.End()
logger := newHTTPLogger(p.log, tr.ConnIndex, req, ruleNum, rule.Service.String())
@ -98,6 +99,29 @@ func (p *Proxy) ProxyHTTP(
}
return err
}
// Handling for StreamBasedOriginProxy or BastionMode
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode {
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode {
return fmt.Errorf("Unrecognized service: %s", rule.Service)
}
dest, err := getDestFromRule(rule, req)
if err != nil {
return err
}
flusher, ok := w.(http.Flusher)
if !ok {
return fmt.Errorf("response writer is not a flusher")
}
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil {
logRequestError(&logger, err)
return err
}
return nil
}
switch originProxy := rule.Service.(type) {
case ingress.HTTPOriginProxy:
@ -113,22 +137,6 @@ func (p *Proxy) ProxyHTTP(
return err
}
return nil
case ingress.StreamBasedOriginProxy:
dest, err := getDestFromRule(rule, req)
if err != nil {
return err
}
flusher, ok := w.(http.Flusher)
if !ok {
return fmt.Errorf("response writer is not a flusher")
}
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy, &logger); err != nil {
logRequestError(&logger, err)
return err
}
return nil
case ingress.HTTPLocalProxy:
p.proxyLocalRequest(originProxy, w, req, isWebsocket)
return nil
@ -335,10 +343,9 @@ func copyTrailers(w connection.ResponseWriter, response *http.Response) {
}
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
switch rule.Service.String() {
case ingress.ServiceBastion:
return carrier.ResolveBastionDest(req)
default:
if rule.Config.BastionMode || rule.Service.String() == config.BastionFlag {
return carrier.ResolveBastionDest(req, rule.Config.BastionMode, rule.Service.String())
} else {
return rule.Service.String(), nil
}
}