cloudflared-mirror/ingress/ingress_test.go

843 lines
20 KiB
Go

package ingress
import (
"flag"
"fmt"
"net/http"
"net/url"
"regexp"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
yaml "gopkg.in/yaml.v3"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ipaccess"
"github.com/cloudflare/cloudflared/tlsconfig"
)
func TestParseUnixSocket(t *testing.T) {
rawYAML := `
ingress:
- service: unix:/tmp/echo.sock
`
ing, err := ParseIngress(MustReadIngress(rawYAML))
require.NoError(t, err)
s, ok := ing.Rules[0].Service.(*unixSocketPath)
require.True(t, ok)
require.Equal(t, "http", s.scheme)
}
func TestParseUnixSocketTLS(t *testing.T) {
rawYAML := `
ingress:
- service: unix+tls:/tmp/echo.sock
`
ing, err := ParseIngress(MustReadIngress(rawYAML))
require.NoError(t, err)
s, ok := ing.Rules[0].Service.(*unixSocketPath)
require.True(t, ok)
require.Equal(t, "https", s.scheme)
}
// TestParseIngressNilConfig checks that ParseIngress rejects a nil
// configuration with an error.
func TestParseIngressNilConfig(t *testing.T) {
	_, parseErr := ParseIngress(nil)
	require.Error(t, parseErr)
}
// TestParseIngress is a table test for ParseIngress. Each case feeds a YAML
// document through MustReadIngress and either expects a parse error or the
// exact list of resulting rules (hostname, punycode hostname, origin service
// and per-rule origin-request config).
//
// The YAML fixtures rely on indentation: keys of one list item (hostname,
// service, path, originRequest) must be indented under the "- " that starts
// the item, and nested maps must be indented under their parent key.
func TestParseIngress(t *testing.T) {
	localhost8000 := MustParseURL(t, "https://localhost:8000")
	localhost8001 := MustParseURL(t, "https://localhost:8001")
	fourOhFour := newStatusCode(404)
	defaultConfig := setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{})
	require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections)
	tr := true
	type args struct {
		rawYAML string
	}
	tests := []struct {
		name    string
		args    args
		want    []Rule
		wantErr bool
	}{
		{
			name:    "Empty file",
			args:    args{rawYAML: ""},
			wantErr: true,
		},
		{
			name: "Multiple rules",
			args: args{rawYAML: `
ingress:
- hostname: tunnel1.example.com
  service: https://localhost:8000
- hostname: "*"
  service: https://localhost:8001
`},
			want: []Rule{
				{
					Hostname: "tunnel1.example.com",
					Service:  &httpService{url: localhost8000},
					Config:   defaultConfig,
				},
				{
					Hostname: "*",
					Service:  &httpService{url: localhost8001},
					Config:   defaultConfig,
				},
			},
		},
		{
			name: "Extra keys",
			args: args{rawYAML: `
ingress:
- hostname: "*"
  service: https://localhost:8000
extraKey: extraValue
`},
			want: []Rule{
				{
					Hostname: "*",
					Service:  &httpService{url: localhost8000},
					Config:   defaultConfig,
				},
			},
		},
		{
			name: "ws service",
			args: args{rawYAML: `
ingress:
- hostname: "*"
  service: wss://localhost:8000
`},
			want: []Rule{
				{
					Hostname: "*",
					Service:  &httpService{url: MustParseURL(t, "wss://localhost:8000")},
					Config:   defaultConfig,
				},
			},
		},
		{
			name: "Hostname can be omitted",
			args: args{rawYAML: `
ingress:
- service: https://localhost:8000
`},
			want: []Rule{
				{
					Service: &httpService{url: localhost8000},
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "Unicode domain",
			args: args{rawYAML: `
ingress:
- hostname: môô.cloudflare.com
  service: https://localhost:8000
- service: https://localhost:8001
`},
			want: []Rule{
				{
					Hostname:         "môô.cloudflare.com",
					punycodeHostname: "xn--m-xgaa.cloudflare.com",
					Service:          &httpService{url: localhost8000},
					Config:           defaultConfig,
				},
				{
					Service: &httpService{url: localhost8001},
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "Invalid unicode domain",
			args: args{rawYAML: fmt.Sprintf(`
ingress:
- hostname: %s
  service: https://localhost:8000
`, string(rune(0xd8f3))+".cloudflare.com")},
			wantErr: true,
		},
		{
			name: "Invalid service",
			args: args{rawYAML: `
ingress:
- hostname: "*"
  service: https://local host:8000
`},
			wantErr: true,
		},
		{
			name: "Last rule isn't catchall",
			args: args{rawYAML: `
ingress:
- hostname: example.com
  service: https://localhost:8000
`},
			wantErr: true,
		},
		{
			name: "First rule is catchall",
			args: args{rawYAML: `
ingress:
- service: https://localhost:8000
- hostname: example.com
  service: https://localhost:8000
`},
			wantErr: true,
		},
		{
			name: "Catch-all rule can't have a path",
			args: args{rawYAML: `
ingress:
- service: https://localhost:8001
  path: /subpath1/(.*)/subpath2
`},
			wantErr: true,
		},
		{
			name: "Invalid regex",
			args: args{rawYAML: `
ingress:
- hostname: example.com
  service: https://localhost:8000
  path: "*/subpath2"
- service: https://localhost:8001
`},
			wantErr: true,
		},
		{
			name: "Service must have a scheme",
			args: args{rawYAML: `
ingress:
- service: localhost:8000
`},
			wantErr: true,
		},
		{
			name: "Wildcard not at start",
			args: args{rawYAML: `
ingress:
- hostname: "test.*.example.com"
  service: https://localhost:8000
`},
			wantErr: true,
		},
		{
			name: "Service can't have a path",
			args: args{rawYAML: `
ingress:
- service: https://localhost:8000/static/
`},
			wantErr: true,
		},
		{
			name: "Invalid HTTP status",
			args: args{rawYAML: `
ingress:
- service: http_status:asdf
`},
			wantErr: true,
		},
		{
			name: "Invalid HTTP status code",
			args: args{rawYAML: `
ingress:
- service: http_status:8080
`},
			wantErr: true,
		},
		{
			name: "Valid HTTP status",
			args: args{rawYAML: `
ingress:
- service: http_status:404
`},
			want: []Rule{
				{
					Hostname: "",
					Service:  &fourOhFour,
					Config:   defaultConfig,
				},
			},
		},
		{
			name: "Valid hello world service",
			args: args{rawYAML: `
ingress:
- service: hello_world
`},
			want: []Rule{
				{
					Hostname: "",
					Service:  new(helloWorld),
					Config:   defaultConfig,
				},
			},
		},
		{
			name: "TCP services",
			args: args{rawYAML: `
ingress:
- hostname: tcp.foo.com
  service: tcp://127.0.0.1
- hostname: tcp2.foo.com
  service: tcp://localhost:8000
- service: http_status:404
`},
			want: []Rule{
				{
					Hostname: "tcp.foo.com",
					// No port in the YAML: the expected service carries port 7864.
					Service: newTCPOverWSService(MustParseURL(t, "tcp://127.0.0.1:7864")),
					Config:  defaultConfig,
				},
				{
					Hostname: "tcp2.foo.com",
					Service:  newTCPOverWSService(MustParseURL(t, "tcp://localhost:8000")),
					Config:   defaultConfig,
				},
				{
					Service: &fourOhFour,
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "SSH services",
			args: args{rawYAML: `
ingress:
- service: ssh://127.0.0.1
`},
			want: []Rule{
				{
					Service: newTCPOverWSService(MustParseURL(t, "ssh://127.0.0.1:22")),
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "RDP services",
			args: args{rawYAML: `
ingress:
- service: rdp://127.0.0.1
`},
			want: []Rule{
				{
					Service: newTCPOverWSService(MustParseURL(t, "rdp://127.0.0.1:3389")),
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "SMB services",
			args: args{rawYAML: `
ingress:
- service: smb://127.0.0.1
`},
			want: []Rule{
				{
					Service: newTCPOverWSService(MustParseURL(t, "smb://127.0.0.1:445")),
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "Other TCP services",
			args: args{rawYAML: `
ingress:
- service: ftp://127.0.0.1
`},
			want: []Rule{
				{
					Service: newTCPOverWSService(MustParseURL(t, "ftp://127.0.0.1")),
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "SOCKS services",
			args: args{rawYAML: `
ingress:
- hostname: socks.foo.com
  service: socks-proxy
  originRequest:
    ipRules:
      - prefix: 1.1.1.0/24
        ports: [80, 443]
        allow: true
      - prefix: 0.0.0.0/0
        allow: false
- service: http_status:404
`},
			want: []Rule{
				{
					Hostname: "socks.foo.com",
					Service:  newSocksProxyOverWSService(accessPolicy()),
					Config: setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{IPRules: []config.IngressIPRule{
						{
							Prefix: ipRulePrefix("1.1.1.0/24"),
							Ports:  []int{80, 443},
							Allow:  true,
						},
						{
							Prefix: ipRulePrefix("0.0.0.0/0"),
							Allow:  false,
						},
					}}),
				},
				{
					Service: &fourOhFour,
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "URL isn't necessary if using bastion",
			args: args{rawYAML: `
ingress:
- hostname: bastion.foo.com
  originRequest:
    bastionMode: true
- service: http_status:404
`},
			want: []Rule{
				{
					Hostname: "bastion.foo.com",
					Service:  newBastionService(),
					Config:   setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
				},
				{
					Service: &fourOhFour,
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "Bastion service",
			args: args{rawYAML: `
ingress:
- hostname: bastion.foo.com
  service: bastion
- service: http_status:404
`},
			want: []Rule{
				{
					Hostname: "bastion.foo.com",
					Service:  newBastionService(),
					Config:   setConfig(originRequestFromConfig(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
				},
				{
					Service: &fourOhFour,
					Config:  defaultConfig,
				},
			},
		},
		{
			name: "Hostname contains port",
			args: args{rawYAML: `
ingress:
- hostname: "test.example.com:443"
  service: https://localhost:8000
- hostname: "*"
  service: https://localhost:8001
`},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseIngress(MustReadIngress(tt.args.rawYAML))
			if tt.wantErr {
				// Only the presence of an error is asserted for invalid inputs;
				// the returned Ingress is meaningless in that case.
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Equal(t, tt.want, got.Rules)
		})
	}
}
// ipRulePrefix returns a pointer to a fresh copy of s. It lets table fixtures
// fill the *string Prefix field of config.IngressIPRule from a literal.
func ipRulePrefix(s string) *string {
	prefix := new(string)
	*prefix = s
	return prefix
}
// TestSingleOriginSetsConfig verifies that, in single-origin (CLI) mode, each
// origin-request flag ends up in the Config of the first generated rule.
func TestSingleOriginSetsConfig(t *testing.T) {
	flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
	flagSet.Bool("hello-world", true, "")
	flagSet.Duration(ProxyConnectTimeoutFlag, time.Second, "")
	flagSet.Duration(ProxyTLSTimeoutFlag, time.Second, "")
	flagSet.Duration(ProxyTCPKeepAliveFlag, time.Second, "")
	flagSet.Bool(ProxyNoHappyEyeballsFlag, true, "")
	flagSet.Int(ProxyKeepAliveConnectionsFlag, 10, "")
	flagSet.Duration(ProxyKeepAliveTimeoutFlag, time.Second, "")
	flagSet.String(HTTPHostHeaderFlag, "example.com:8080", "")
	flagSet.String(OriginServerNameFlag, "example.com", "")
	flagSet.String(tlsconfig.OriginCAPoolFlag, "/etc/certs/ca.pem", "")
	flagSet.Bool(NoTLSVerifyFlag, true, "")
	flagSet.Bool(NoChunkedEncodingFlag, true, "")
	flagSet.Bool(config.BastionFlag, true, "")
	flagSet.String(ProxyAddressFlag, "localhost:8080", "")
	flagSet.Uint(ProxyPortFlag, 8080, "")
	flagSet.Bool(Socks5Flag, true, "")
	cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)

	// Besides declaring defaults above, explicitly Set every flag.
	// NOTE(review): presumably parseCLIIngress only honours flags the context
	// reports as set — confirm before dropping this step.
	flagValues := []struct{ name, value string }{
		{"hello-world", "true"},
		{ProxyConnectTimeoutFlag, "1s"},
		{ProxyTLSTimeoutFlag, "1s"},
		{ProxyTCPKeepAliveFlag, "1s"},
		{ProxyNoHappyEyeballsFlag, "true"},
		{ProxyKeepAliveConnectionsFlag, "10"},
		{ProxyKeepAliveTimeoutFlag, "1s"},
		{HTTPHostHeaderFlag, "example.com:8080"},
		{OriginServerNameFlag, "example.com"},
		{tlsconfig.OriginCAPoolFlag, "/etc/certs/ca.pem"},
		{NoTLSVerifyFlag, "true"},
		{NoChunkedEncodingFlag, "true"},
		{config.BastionFlag, "true"},
		{ProxyAddressFlag, "localhost:8080"},
		{ProxyPortFlag, "8080"},
		{Socks5Flag, "true"},
	}
	for _, fv := range flagValues {
		require.NoError(t, cliCtx.Set(fv.name, fv.value), "setting flag %q", fv.name)
	}

	allowURLFromArgs := false
	ing, err := parseCLIIngress(cliCtx, allowURLFromArgs)
	require.NoError(t, err)
	// Fail cleanly instead of panicking on the Rules[0] index below.
	require.NotEmpty(t, ing.Rules)

	cfg := ing.Rules[0].Config
	assert.Equal(t, config.CustomDuration{Duration: time.Second}, cfg.ConnectTimeout)
	assert.Equal(t, config.CustomDuration{Duration: time.Second}, cfg.TLSTimeout)
	assert.Equal(t, config.CustomDuration{Duration: time.Second}, cfg.TCPKeepAlive)
	assert.True(t, cfg.NoHappyEyeballs)
	assert.Equal(t, 10, cfg.KeepAliveConnections)
	assert.Equal(t, config.CustomDuration{Duration: time.Second}, cfg.KeepAliveTimeout)
	assert.Equal(t, "example.com:8080", cfg.HTTPHostHeader)
	assert.Equal(t, "example.com", cfg.OriginServerName)
	assert.Equal(t, "/etc/certs/ca.pem", cfg.CAPool)
	assert.True(t, cfg.NoTLSVerify)
	assert.True(t, cfg.DisableChunkedEncoding)
	assert.True(t, cfg.BastionMode)
	assert.Equal(t, "localhost:8080", cfg.ProxyAddress)
	assert.Equal(t, uint(8080), cfg.ProxyPort)
	assert.Equal(t, socksProxy, cfg.ProxyType)
}
// TestSingleOriginServices checks that each single-origin CLI flag
// (--hello-world, --bastion, --url, --unix-socket) produces exactly one rule
// with the matching origin service. It also checks that setting none of them
// yields ErrNoIngressRulesCLI.
func TestSingleOriginServices(t *testing.T) {
	host := "://localhost:8080"
	httpURL := urlMustParse("http" + host)
	tcpURL := urlMustParse("tcp" + host)
	unix := "unix://service"

	// newCli builds a cli.Context that declares the single-origin flags and then
	// sets them from params, read as alternating flag-name/value pairs. It is
	// only called while building the table, i.e. on the parent test goroutine,
	// so using the outer t for require is safe.
	newCli := func(params ...string) *cli.Context {
		require.Zero(t, len(params)%2, "newCli needs flag/value pairs, got %v", params)
		flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
		flagSet.Bool("hello-world", false, "")
		flagSet.Bool("bastion", false, "")
		flagSet.String("url", "", "")
		flagSet.String("unix-socket", "", "")
		cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
		for i := 0; i < len(params); i += 2 {
			require.NoError(t, cliCtx.Set(params[i], params[i+1]), "setting flag %q", params[i])
		}
		return cliCtx
	}
	tests := []struct {
		name            string
		cli             *cli.Context
		expectedService OriginService
		err             error
	}{
		{
			name:            "Valid hello-world",
			cli:             newCli("hello-world", "true"),
			expectedService: &helloWorld{},
		},
		{
			name:            "Valid bastion",
			cli:             newCli("bastion", "true"),
			expectedService: newBastionService(),
		},
		{
			name:            "Valid http url",
			cli:             newCli("url", httpURL.String()),
			expectedService: &httpService{url: httpURL},
		},
		{
			name:            "Valid tcp url",
			cli:             newCli("url", tcpURL.String()),
			expectedService: newTCPOverWSService(tcpURL),
		},
		{
			name:            "Valid unix-socket",
			cli:             newCli("unix-socket", unix),
			expectedService: &unixSocketPath{path: unix, scheme: "http"},
		},
		{
			name: "No origins defined",
			cli:  newCli(),
			err:  ErrNoIngressRulesCLI,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			ing, err := parseCLIIngress(test.cli, false)
			if test.err != nil {
				require.ErrorIs(t, err, test.err)
				return
			}
			require.NoError(t, err)
			require.Len(t, ing.Rules, 1)
			require.Equal(t, test.expectedService, ing.Rules[0].Service)
		})
	}
}
// urlMustParse parses s with url.Parse and panics if it is not a valid URL.
// Unlike MustParseURL it takes no *testing.T, so it can be called where no
// test handle is available.
func urlMustParse(s string) *url.URL {
	parsed, parseErr := url.Parse(s)
	if parseErr == nil {
		return parsed
	}
	panic(parseErr)
}
// TestSingleOriginServices_URL checks that --url maps each scheme to the right
// origin service: http/https produce an httpService, while ssh, rdp, smb and
// tcp produce a TCP-over-WebSocket service. In both cases exactly one rule is
// created.
func TestSingleOriginServices_URL(t *testing.T) {
	host := "://localhost:8080"
	// newCli returns a cli.Context that declares only --url and sets param to
	// value. It takes the subtest's t so require failures stop the right
	// goroutine.
	newCli := func(t *testing.T, param string, value string) *cli.Context {
		t.Helper()
		flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
		flagSet.String("url", "", "")
		cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
		require.NoError(t, cliCtx.Set(param, value), "setting flag %q", param)
		return cliCtx
	}

	httpTests := []string{"http", "https"}
	for _, scheme := range httpTests {
		t.Run(scheme, func(t *testing.T) {
			// Named originURL rather than url to avoid shadowing net/url.
			originURL := urlMustParse(scheme + host)
			ing, err := parseCLIIngress(newCli(t, "url", originURL.String()), false)
			require.NoError(t, err)
			require.Len(t, ing.Rules, 1)
			require.Equal(t, &httpService{url: originURL}, ing.Rules[0].Service)
		})
	}

	tcpTests := []string{"ssh", "rdp", "smb", "tcp"}
	for _, scheme := range tcpTests {
		t.Run(scheme, func(t *testing.T) {
			originURL := urlMustParse(scheme + host)
			ing, err := parseCLIIngress(newCli(t, "url", originURL.String()), false)
			require.NoError(t, err)
			require.Len(t, ing.Rules, 1)
			require.Equal(t, newTCPOverWSService(originURL), ing.Rules[0].Service)
		})
	}
}
// TestFindMatchingRule checks rule selection by hostname and path against
// three rules: an exact host (rule 0), an exact host restricted to /health
// (rule 1), and a "*" catch-all (rule 2). A Host carrying a port (":443") is
// still expected to match its host rule.
func TestFindMatchingRule(t *testing.T) {
	ing := Ingress{Rules: []Rule{
		{Hostname: "tunnel-a.example.com", Path: nil},
		{Hostname: "tunnel-b.example.com", Path: MustParsePath(t, "/health")},
		{Hostname: "*"},
	}}

	cases := []struct {
		host          string
		path          string
		req           *http.Request
		wantRuleIndex int
	}{
		{host: "tunnel-a.example.com", path: "/", wantRuleIndex: 0},
		{host: "tunnel-a.example.com", path: "/pages/about", wantRuleIndex: 0},
		{host: "tunnel-a.example.com:443", path: "/pages/about", wantRuleIndex: 0},
		{host: "tunnel-b.example.com", path: "/health", wantRuleIndex: 1},
		// Path does not match rule 1, so the catch-all wins.
		{host: "tunnel-b.example.com", path: "/index.html", wantRuleIndex: 2},
		{host: "tunnel-c.example.com", path: "/", wantRuleIndex: 2},
	}
	for i := range cases {
		tc := cases[i]
		_, matched := ing.FindMatchingRule(tc.host, tc.path)
		assert.Equal(t, tc.wantRuleIndex, matched,
			"Expect host=%s, path=%s to match rule %d, got %d",
			tc.host, tc.path, tc.wantRuleIndex, matched)
	}
}
// TestIsHTTPService checks that http, https, ws and wss origins are classified
// as HTTP services and that a tcp origin is not.
func TestIsHTTPService(t *testing.T) {
	cases := []struct {
		rawURL   string
		wantHTTP bool
	}{
		{"http://localhost", true},
		{"https://127.0.0.1:8000", true},
		{"ws://localhost", true},
		{"wss://localhost:8000", true},
		{"tcp://localhost:9000", false},
	}
	for _, c := range cases {
		origin := MustParseURL(t, c.rawURL)
		assert.Equal(t, c.wantHTTP, isHTTPService(origin), c.rawURL)
	}
}
// MustParsePath compiles path into the *Regexp wrapper used for Rule.Path.
// If path is not a valid regular expression it fails the test immediately,
// rather than returning a wrapper around a nil regexp.
func MustParsePath(t *testing.T, path string) *Regexp {
	t.Helper()
	// Named compiled (not regexp) so the regexp package is not shadowed.
	compiled, err := regexp.Compile(path)
	require.NoError(t, err)
	return &Regexp{Regexp: compiled}
}
// MustParseURL parses rawURL with url.Parse and fails the test immediately if
// it is invalid. It is marked as a helper so failures are reported at the
// caller's line.
func MustParseURL(t *testing.T, rawURL string) *url.URL {
	t.Helper()
	u, err := url.Parse(rawURL)
	require.NoError(t, err)
	return u
}
// accessPolicy builds the ipaccess.Policy expected for the "SOCKS services"
// fixture in TestParseIngress. Its rules mirror that fixture's ipRules:
// 1.1.1.0/24 on ports 80 and 443 with allow=true, then 0.0.0.0/0 with
// allow=false. Construction errors panic so that a broken fixture fails loudly
// instead of silently yielding a nil rule or policy.
func accessPolicy() *ipaccess.Policy {
	cidr1 := "1.1.1.0/24"
	cidr2 := "0.0.0.0/0"
	rule1, err := ipaccess.NewRuleByCIDR(&cidr1, []int{80, 443}, true)
	if err != nil {
		panic(fmt.Errorf("accessPolicy: rule for %s: %w", cidr1, err))
	}
	rule2, err := ipaccess.NewRuleByCIDR(&cidr2, nil, false)
	if err != nil {
		panic(fmt.Errorf("accessPolicy: rule for %s: %w", cidr2, err))
	}
	rules := []ipaccess.Rule{rule1, rule2}
	// NOTE(review): the first argument is presumably the default-allow flag — confirm.
	policy, err := ipaccess.NewPolicy(false, rules)
	if err != nil {
		panic(fmt.Errorf("accessPolicy: building policy: %w", err))
	}
	return policy
}
// BenchmarkFindMatch measures FindMatchingRule against a three-rule config.
// Each iteration does three lookups: one matching the first rule, one
// matching the second, and one falling through to the "*" catch-all.
func BenchmarkFindMatch(b *testing.B) {
	rulesYAML := `
ingress:
- hostname: tunnel1.example.com
  service: https://localhost:8000
- hostname: tunnel2.example.com
  service: https://localhost:8001
- hostname: "*"
  service: https://localhost:8002
`
	ing, err := ParseIngress(MustReadIngress(rulesYAML))
	if err != nil {
		// Fatal, not Error: benchmarking an unparsed, empty Ingress is meaningless.
		b.Fatal(err)
	}
	// Keep YAML parsing out of the measured time.
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		ing.FindMatchingRule("tunnel1.example.com", "")
		ing.FindMatchingRule("tunnel2.example.com", "")
		ing.FindMatchingRule("tunnel3.example.com", "")
	}
}
// TestParseAccessConfig is a table test for validateAccessConfiguration. The
// only rejected combination here is Required=true with audience tags but no
// team name. Required=true with an entirely empty config is accepted.
func TestParseAccessConfig(t *testing.T) {
	tests := []struct {
		name        string
		cfg         config.AccessConfig
		expectError bool
	}{
		{
			name:        "Config required with teamName only",
			cfg:         config.AccessConfig{Required: true, TeamName: "team"},
			expectError: false,
		},
		{
			name:        "required false",
			cfg:         config.AccessConfig{Required: false},
			expectError: false,
		},
		{
			name:        "required true but empty config",
			cfg:         config.AccessConfig{Required: true},
			expectError: false,
		},
		{
			name:        "complete config",
			cfg:         config.AccessConfig{Required: true, TeamName: "team", AudTag: []string{"a"}},
			expectError: false,
		},
		{
			name:        "required true with audTags but no teamName",
			cfg:         config.AccessConfig{Required: true, AudTag: []string{"a"}},
			expectError: true,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			err := validateAccessConfiguration(&test.cfg)
			// Explicit Error/NoError (instead of comparing err != nil to the
			// expectation) reports the actual error when it is unexpected.
			if test.expectError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// MustReadIngress decodes the YAML document s into a new config.Configuration
// and panics if yaml.Unmarshal reports an error. Callers pass the result
// straight to ParseIngress.
func MustReadIngress(s string) *config.Configuration {
	conf := new(config.Configuration)
	if err := yaml.Unmarshal([]byte(s), conf); err != nil {
		panic(err)
	}
	return conf
}