Support ingress rule matching for bastion mode
This commit is contained in:
parent
f27418044b
commit
df3ef06169
|
@ -16,13 +16,14 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/token"
|
"github.com/cloudflare/cloudflared/token"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
LogFieldOriginURL = "originURL"
|
LogFieldOriginURL = "originURL"
|
||||||
CFAccessTokenHeader = "Cf-Access-Token"
|
CFAccessTokenHeader = "Cf-Access-Token"
|
||||||
cfJumpDestinationHeader = "Cf-Access-Jump-Destination"
|
CFJumpDestinationHeader = "Cf-Access-Jump-Destination"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StartOptions struct {
|
type StartOptions struct {
|
||||||
|
@ -163,12 +164,16 @@ func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Reque
|
||||||
|
|
||||||
func SetBastionDest(header http.Header, destination string) {
|
func SetBastionDest(header http.Header, destination string) {
|
||||||
if destination != "" {
|
if destination != "" {
|
||||||
header.Set(cfJumpDestinationHeader, destination)
|
header.Set(CFJumpDestinationHeader, destination)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResolveBastionDest(r *http.Request) (string, error) {
|
func ResolveBastionDest(req *http.Request, bastionMode bool, service string) (string, error) {
|
||||||
jumpDestination := r.Header.Get(cfJumpDestinationHeader)
|
jumpDestination := req.Header.Get(CFJumpDestinationHeader)
|
||||||
|
if bastionMode && service != config.BastionFlag {
|
||||||
|
jumpDestination = service
|
||||||
|
}
|
||||||
|
|
||||||
if jumpDestination == "" {
|
if jumpDestination == "" {
|
||||||
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
||||||
}
|
}
|
||||||
|
|
|
@ -158,82 +158,112 @@ func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBastionDestination(t *testing.T) {
|
func TestBastionDestination(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
header http.Header
|
header http.Header
|
||||||
expectedDest string
|
expectedDest string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
|
bastionMode bool
|
||||||
|
service string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "hostname destination",
|
name: "hostname destination",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"localhost"},
|
CFJumpDestinationHeader: []string{"localhost"},
|
||||||
},
|
},
|
||||||
expectedDest: "localhost",
|
expectedDest: "localhost",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "hostname destination with port",
|
name: "hostname destination with port",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"localhost:9000"},
|
CFJumpDestinationHeader: []string{"localhost:9000"},
|
||||||
},
|
},
|
||||||
expectedDest: "localhost:9000",
|
expectedDest: "localhost:9000",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "hostname destination with scheme and port",
|
name: "hostname destination with scheme and port",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"ssh://localhost:9000"},
|
CFJumpDestinationHeader: []string{"ssh://localhost:9000"},
|
||||||
},
|
},
|
||||||
expectedDest: "localhost:9000",
|
expectedDest: "localhost:9000",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "full hostname url",
|
name: "full hostname url",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"},
|
CFJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"},
|
||||||
},
|
},
|
||||||
expectedDest: "localhost:9000",
|
expectedDest: "localhost:9000",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "hostname destination with port and path",
|
name: "hostname destination with port and path",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"localhost:9000/metrics"},
|
CFJumpDestinationHeader: []string{"localhost:9000/metrics"},
|
||||||
},
|
},
|
||||||
expectedDest: "localhost:9000",
|
expectedDest: "localhost:9000",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ip destination",
|
name: "ip destination",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"127.0.0.1"},
|
CFJumpDestinationHeader: []string{"127.0.0.1"},
|
||||||
},
|
},
|
||||||
expectedDest: "127.0.0.1",
|
expectedDest: "127.0.0.1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ip destination with port",
|
name: "ip destination with port",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"127.0.0.1:9000"},
|
CFJumpDestinationHeader: []string{"127.0.0.1:9000"},
|
||||||
},
|
},
|
||||||
expectedDest: "127.0.0.1:9000",
|
expectedDest: "127.0.0.1:9000",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ip destination with port and path",
|
name: "ip destination with port and path",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
|
CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
|
||||||
},
|
},
|
||||||
expectedDest: "127.0.0.1:9000",
|
expectedDest: "127.0.0.1:9000",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ip destination with schem and port",
|
name: "ip destination with schem and port",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"},
|
CFJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"},
|
||||||
},
|
},
|
||||||
expectedDest: "127.0.0.1:9000",
|
expectedDest: "127.0.0.1:9000",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "full ip url",
|
name: "full ip url",
|
||||||
header: http.Header{
|
header: http.Header{
|
||||||
cfJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"},
|
CFJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"},
|
||||||
},
|
},
|
||||||
expectedDest: "127.0.0.1:9000",
|
expectedDest: "127.0.0.1:9000",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "full ip url with bastion mode",
|
||||||
|
header: http.Header{
|
||||||
|
CFJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"},
|
||||||
|
},
|
||||||
|
bastionMode: true,
|
||||||
|
service: "ssh://127.0.0.1:9002/metrics",
|
||||||
|
expectedDest: "127.0.0.1:9002",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination with port and path with bastion mode",
|
||||||
|
header: http.Header{
|
||||||
|
CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
|
||||||
|
},
|
||||||
|
bastionMode: true,
|
||||||
|
service: "127.0.0.1:9002/metrics",
|
||||||
|
expectedDest: "127.0.0.1:9002",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip destination with port and path without bastion mode",
|
||||||
|
header: http.Header{
|
||||||
|
CFJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"},
|
||||||
|
},
|
||||||
|
bastionMode: false,
|
||||||
|
service: "127.0.0.1:9002/metrics",
|
||||||
|
expectedDest: "127.0.0.1:9000",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "no destination",
|
name: "no destination",
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
|
@ -243,7 +273,7 @@ func TestBastionDestination(t *testing.T) {
|
||||||
r := &http.Request{
|
r := &http.Request{
|
||||||
Header: test.header,
|
Header: test.header,
|
||||||
}
|
}
|
||||||
dest, err := ResolveBastionDest(r)
|
dest, err := ResolveBastionDest(r, test.bastionMode, test.service)
|
||||||
if test.wantErr {
|
if test.wantErr {
|
||||||
assert.Error(t, err, "Test %s expects error", test.name)
|
assert.Error(t, err, "Test %s expects error", test.name)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -138,7 +138,7 @@ func testURLCommand(c *cli.Context) error {
|
||||||
return errors.Wrap(err, "Validation failed")
|
return errors.Wrap(err, "Validation failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path)
|
_, i := ing.FindMatchingRule(requestURL.Hostname(), requestURL.Path, "")
|
||||||
fmt.Printf("Matched rule #%d\n", i)
|
fmt.Printf("Matched rule #%d\n", i)
|
||||||
fmt.Println(ing.Rules[i].MultiLineString())
|
fmt.Println(ing.Rules[i].MultiLineString())
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -28,7 +28,6 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ServiceBastion = "bastion"
|
|
||||||
ServiceSocksProxy = "socks-proxy"
|
ServiceSocksProxy = "socks-proxy"
|
||||||
ServiceWarpRouting = "warp-routing"
|
ServiceWarpRouting = "warp-routing"
|
||||||
)
|
)
|
||||||
|
@ -38,12 +37,13 @@ const (
|
||||||
// which is the case if the rules were instantiated via the ingress#Validate method.
|
// which is the case if the rules were instantiated via the ingress#Validate method.
|
||||||
//
|
//
|
||||||
// Negative index rule signifies local cloudflared rules (not-user defined).
|
// Negative index rule signifies local cloudflared rules (not-user defined).
|
||||||
func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) {
|
func (ing Ingress) FindMatchingRule(hostname, path string, cfJumpDestinationHeader string) (*Rule, int) {
|
||||||
// The hostname might contain port. We only want to compare the host part with the rule
|
// The hostname might contain port. We only want to compare the host part with the rule
|
||||||
host, _, err := net.SplitHostPort(hostname)
|
host, _, err := net.SplitHostPort(hostname)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
hostname = host
|
hostname = host
|
||||||
}
|
}
|
||||||
|
derivedHostName := hostname
|
||||||
for i, rule := range ing.InternalRules {
|
for i, rule := range ing.InternalRules {
|
||||||
if rule.Matches(hostname, path) {
|
if rule.Matches(hostname, path) {
|
||||||
// Local rule matches return a negative rule index to distiguish local rules from user-defined rules in logs
|
// Local rule matches return a negative rule index to distiguish local rules from user-defined rules in logs
|
||||||
|
@ -52,7 +52,15 @@ func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i, rule := range ing.Rules {
|
for i, rule := range ing.Rules {
|
||||||
if rule.Matches(hostname, path) {
|
// If bastion mode is turned on and request is made as bastion, attempt
|
||||||
|
// to match a rule where jump destination header matches the hostname
|
||||||
|
if rule.Config.BastionMode && len(cfJumpDestinationHeader) > 0 {
|
||||||
|
jumpDestinationUri, err := url.Parse(cfJumpDestinationHeader)
|
||||||
|
if err == nil {
|
||||||
|
derivedHostName = jumpDestinationUri.Hostname()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if rule.Matches(derivedHostName, path) {
|
||||||
return &rule, i
|
return &rule, i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -265,6 +273,7 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq
|
||||||
}
|
}
|
||||||
srv := newStatusCode(statusCode)
|
srv := newStatusCode(statusCode)
|
||||||
service = &srv
|
service = &srv
|
||||||
|
|
||||||
} else if r.Service == HelloWorldFlag || r.Service == HelloWorldService {
|
} else if r.Service == HelloWorldFlag || r.Service == HelloWorldService {
|
||||||
service = new(helloWorld)
|
service = new(helloWorld)
|
||||||
} else if r.Service == ServiceSocksProxy {
|
} else if r.Service == ServiceSocksProxy {
|
||||||
|
@ -284,12 +293,21 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq
|
||||||
}
|
}
|
||||||
|
|
||||||
service = newSocksProxyOverWSService(accessPolicy)
|
service = newSocksProxyOverWSService(accessPolicy)
|
||||||
} else if r.Service == ServiceBastion || cfg.BastionMode {
|
} else if r.Service == config.BastionFlag || cfg.BastionMode {
|
||||||
// Bastion mode will always start a Websocket proxy server, which will
|
// Bastion mode will always start a Websocket proxy server, which will
|
||||||
// overwrite the localService.URL field when `start` is called. So,
|
// overwrite the localService.URL field when `start` is called. So,
|
||||||
// leave the URL field empty for now.
|
// leave the URL field empty for now.
|
||||||
cfg.BastionMode = true
|
cfg.BastionMode = true
|
||||||
|
|
||||||
|
if cfg.BastionMode && r.Service != config.BastionFlag {
|
||||||
|
u, err := url.Parse(r.Service)
|
||||||
|
if err != nil {
|
||||||
|
return Ingress{}, err
|
||||||
|
}
|
||||||
|
service = newBastionServiceWithDest(u)
|
||||||
|
} else {
|
||||||
service = newBastionService()
|
service = newBastionService()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Validate URL services
|
// Validate URL services
|
||||||
u, err := url.Parse(r.Service)
|
u, err := url.Parse(r.Service)
|
||||||
|
|
|
@ -439,6 +439,50 @@ ingress:
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Bastion mode turned on with with custom service",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- hostname: bastiondest.foo.com
|
||||||
|
service: http://localhost:9000
|
||||||
|
originRequest:
|
||||||
|
bastionMode: true
|
||||||
|
- service: http_status:404
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Hostname: "bastiondest.foo.com",
|
||||||
|
Service: newBastionServiceWithDest(MustParseURL(t, "http://localhost:9000")),
|
||||||
|
Config: setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Service: &fourOhFour,
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP service with Bastion mode turned off",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- hostname: tcp.foo.com
|
||||||
|
service: tcp://localhost:9000
|
||||||
|
originRequest:
|
||||||
|
bastionMode: false
|
||||||
|
- service: http_status:404
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Hostname: "tcp.foo.com",
|
||||||
|
Service: newTCPOverWSService(MustParseURL(t, "tcp://localhost:9000")),
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Service: &fourOhFour,
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Hostname contains port",
|
name: "Hostname contains port",
|
||||||
args: args{rawYAML: `
|
args: args{rawYAML: `
|
||||||
|
@ -656,6 +700,7 @@ func TestSingleOriginServices_URL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindMatchingRule(t *testing.T) {
|
func TestFindMatchingRule(t *testing.T) {
|
||||||
|
|
||||||
ingress := Ingress{
|
ingress := Ingress{
|
||||||
Rules: []Rule{
|
Rules: []Rule{
|
||||||
{
|
{
|
||||||
|
@ -666,6 +711,13 @@ func TestFindMatchingRule(t *testing.T) {
|
||||||
Hostname: "tunnel-b.example.com",
|
Hostname: "tunnel-b.example.com",
|
||||||
Path: MustParsePath(t, "/health"),
|
Path: MustParsePath(t, "/health"),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Hostname: "tunnel-d.example.com",
|
||||||
|
Path: nil,
|
||||||
|
Config: OriginRequestConfig{
|
||||||
|
BastionMode: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Hostname: "*",
|
Hostname: "*",
|
||||||
},
|
},
|
||||||
|
@ -675,43 +727,62 @@ func TestFindMatchingRule(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
host string
|
host string
|
||||||
path string
|
path string
|
||||||
|
cfJumpDestinationHeader string
|
||||||
req *http.Request
|
req *http.Request
|
||||||
wantRuleIndex int
|
wantRuleIndex int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
host: "tunnel-a.example.com",
|
host: "tunnel-a.example.com",
|
||||||
path: "/",
|
path: "/",
|
||||||
|
cfJumpDestinationHeader: "",
|
||||||
wantRuleIndex: 0,
|
wantRuleIndex: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
host: "tunnel-a.example.com",
|
host: "tunnel-a.example.com",
|
||||||
path: "/pages/about",
|
path: "/pages/about",
|
||||||
|
cfJumpDestinationHeader: "",
|
||||||
wantRuleIndex: 0,
|
wantRuleIndex: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
host: "tunnel-a.example.com:443",
|
host: "tunnel-a.example.com:443",
|
||||||
path: "/pages/about",
|
path: "/pages/about",
|
||||||
|
cfJumpDestinationHeader: "",
|
||||||
wantRuleIndex: 0,
|
wantRuleIndex: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
host: "tunnel-b.example.com",
|
host: "tunnel-b.example.com",
|
||||||
path: "/health",
|
path: "/health",
|
||||||
|
cfJumpDestinationHeader: "",
|
||||||
wantRuleIndex: 1,
|
wantRuleIndex: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
host: "tunnel-b.example.com",
|
host: "tunnel-b.example.com",
|
||||||
path: "/index.html",
|
path: "/index.html",
|
||||||
|
cfJumpDestinationHeader: "",
|
||||||
|
wantRuleIndex: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
host: "tunnel-d.example.com",
|
||||||
|
path: "/",
|
||||||
|
cfJumpDestinationHeader: "https://tunnel-d.example.com",
|
||||||
|
wantRuleIndex: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
host: "tunnel-d.example.com",
|
||||||
|
path: "/",
|
||||||
|
cfJumpDestinationHeader: "https://tunnel-d.example.com",
|
||||||
wantRuleIndex: 2,
|
wantRuleIndex: 2,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
host: "tunnel-c.example.com",
|
host: "tunnel-c.example.com",
|
||||||
path: "/",
|
path: "/",
|
||||||
wantRuleIndex: 2,
|
cfJumpDestinationHeader: "",
|
||||||
|
wantRuleIndex: 3,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
_, ruleIndex := ingress.FindMatchingRule(test.host, test.path)
|
_, ruleIndex := ingress.FindMatchingRule(test.host, test.path, test.cfJumpDestinationHeader)
|
||||||
assert.Equal(t, test.wantRuleIndex, ruleIndex, fmt.Sprintf("Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, ruleIndex))
|
assert.Equal(t, test.wantRuleIndex, ruleIndex, fmt.Sprintf("Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, ruleIndex))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -786,9 +857,9 @@ ingress:
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
ing.FindMatchingRule("tunnel1.example.com", "")
|
ing.FindMatchingRule("tunnel1.example.com", "", "")
|
||||||
ing.FindMatchingRule("tunnel2.example.com", "")
|
ing.FindMatchingRule("tunnel2.example.com", "", "")
|
||||||
ing.FindMatchingRule("tunnel3.example.com", "")
|
ing.FindMatchingRule("tunnel3.example.com", "", "")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,8 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
|
|
||||||
bastionReq := baseReq.Clone(context.Background())
|
bastionReq := baseReq.Clone(context.Background())
|
||||||
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
||||||
|
u, err := url.Parse("https://place-holder1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
testCase string
|
testCase string
|
||||||
|
@ -81,12 +83,23 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
req: baseReq,
|
req: baseReq,
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
testCase: "bastion service",
|
||||||
|
service: newBastionServiceWithDest(u),
|
||||||
|
req: bastionReq,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testCase: "bastion service",
|
||||||
|
service: newBastionServiceWithDest(u),
|
||||||
|
req: bastionReq,
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.testCase, func(t *testing.T) {
|
t.Run(test.testCase, func(t *testing.T) {
|
||||||
if test.expectErr {
|
if test.expectErr {
|
||||||
bastionHost, _ := carrier.ResolveBastionDest(test.req)
|
bastionHost, _ := carrier.ResolveBastionDest(test.req, false, "bastion")
|
||||||
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
@ -98,7 +111,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
|
|
||||||
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
||||||
// Origin not listening for new connection, should return an error
|
// Origin not listening for new connection, should return an error
|
||||||
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
|
bastionHost, _ := carrier.ResolveBastionDest(bastionReq, false, "bastion")
|
||||||
_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
"github.com/cloudflare/cloudflared/ipaccess"
|
"github.com/cloudflare/cloudflared/ipaccess"
|
||||||
"github.com/cloudflare/cloudflared/management"
|
"github.com/cloudflare/cloudflared/management"
|
||||||
|
@ -151,6 +152,14 @@ func newBastionService() *tcpOverWSService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newBastionServiceWithDest(url *url.URL) *tcpOverWSService {
|
||||||
|
return &tcpOverWSService{
|
||||||
|
isBastion: true,
|
||||||
|
scheme: url.Scheme,
|
||||||
|
dest: url.Host,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func newSocksProxyOverWSService(accessPolicy *ipaccess.Policy) *socksProxyOverWSService {
|
func newSocksProxyOverWSService(accessPolicy *ipaccess.Policy) *socksProxyOverWSService {
|
||||||
proxy := socksProxyOverWSService{
|
proxy := socksProxyOverWSService{
|
||||||
conn: &socksProxyOverWSConnection{
|
conn: &socksProxyOverWSConnection{
|
||||||
|
@ -170,8 +179,8 @@ func addPortIfMissing(uri *url.URL, port int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *tcpOverWSService) String() string {
|
func (o *tcpOverWSService) String() string {
|
||||||
if o.isBastion {
|
if o.isBastion && len(o.dest) == 0 {
|
||||||
return ServiceBastion
|
return config.BastionFlag
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.scheme != "" {
|
if o.scheme != "" {
|
||||||
|
|
|
@ -59,6 +59,7 @@ func (r *Rule) Matches(hostname, path string) bool {
|
||||||
} else {
|
} else {
|
||||||
hostMatch = matchHost(r.Hostname, hostname)
|
hostMatch = matchHost(r.Hostname, hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
punycodeHostMatch := false
|
punycodeHostMatch := false
|
||||||
if r.punycodeHostname != "" {
|
if r.punycodeHostname != "" {
|
||||||
punycodeHostMatch = matchHost(r.punycodeHostname, hostname)
|
punycodeHostMatch = matchHost(r.punycodeHostname, hostname)
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/carrier"
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
"github.com/cloudflare/cloudflared/cfio"
|
"github.com/cloudflare/cloudflared/cfio"
|
||||||
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/stream"
|
"github.com/cloudflare/cloudflared/stream"
|
||||||
|
@ -86,7 +87,7 @@ func (p *Proxy) ProxyHTTP(
|
||||||
|
|
||||||
_, ruleSpan := tr.Tracer().Start(req.Context(), "ingress_match",
|
_, ruleSpan := tr.Tracer().Start(req.Context(), "ingress_match",
|
||||||
trace.WithAttributes(attribute.String("req-host", req.Host)))
|
trace.WithAttributes(attribute.String("req-host", req.Host)))
|
||||||
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path, req.Header.Get(carrier.CFJumpDestinationHeader))
|
||||||
ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum))
|
ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum))
|
||||||
ruleSpan.End()
|
ruleSpan.End()
|
||||||
logger := newHTTPLogger(p.log, tr.ConnIndex, req, ruleNum, rule.Service.String())
|
logger := newHTTPLogger(p.log, tr.ConnIndex, req, ruleNum, rule.Service.String())
|
||||||
|
@ -98,6 +99,29 @@ func (p *Proxy) ProxyHTTP(
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// Handling for StreamBasedOriginProxy or BastionMode
|
||||||
|
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode {
|
||||||
|
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode {
|
||||||
|
return fmt.Errorf("Unrecognized service: %s", rule.Service)
|
||||||
|
}
|
||||||
|
|
||||||
|
dest, err := getDestFromRule(rule, req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("response writer is not a flusher")
|
||||||
|
}
|
||||||
|
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
|
||||||
|
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
|
||||||
|
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil {
|
||||||
|
logRequestError(&logger, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
switch originProxy := rule.Service.(type) {
|
switch originProxy := rule.Service.(type) {
|
||||||
case ingress.HTTPOriginProxy:
|
case ingress.HTTPOriginProxy:
|
||||||
|
@ -113,22 +137,6 @@ func (p *Proxy) ProxyHTTP(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
case ingress.StreamBasedOriginProxy:
|
|
||||||
dest, err := getDestFromRule(rule, req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("response writer is not a flusher")
|
|
||||||
}
|
|
||||||
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
|
|
||||||
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
|
|
||||||
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy, &logger); err != nil {
|
|
||||||
logRequestError(&logger, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
case ingress.HTTPLocalProxy:
|
case ingress.HTTPLocalProxy:
|
||||||
p.proxyLocalRequest(originProxy, w, req, isWebsocket)
|
p.proxyLocalRequest(originProxy, w, req, isWebsocket)
|
||||||
return nil
|
return nil
|
||||||
|
@ -335,10 +343,9 @@ func copyTrailers(w connection.ResponseWriter, response *http.Response) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
|
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
|
||||||
switch rule.Service.String() {
|
if rule.Config.BastionMode || rule.Service.String() == config.BastionFlag {
|
||||||
case ingress.ServiceBastion:
|
return carrier.ResolveBastionDest(req, rule.Config.BastionMode, rule.Service.String())
|
||||||
return carrier.ResolveBastionDest(req)
|
} else {
|
||||||
default:
|
|
||||||
return rule.Service.String(), nil
|
return rule.Service.String(), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue