diff --git a/ingress/ingress.go b/ingress/ingress.go index 9c910c7d..86549a24 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -2,6 +2,7 @@ package ingress import ( "fmt" + "net" "net/url" "regexp" "strconv" @@ -19,6 +20,7 @@ var ( ErrNoIngressRules = errors.New("No ingress rules were specified in the config file") errLastRuleNotCatchAll = errors.New("The last ingress rule must match all hostnames (i.e. it must be missing, or must be \"*\")") errBadWildcard = errors.New("Hostname patterns can have at most one wildcard character (\"*\") and it can only be used for subdomains, e.g. \"*.example.com\"") + errHostnameContainsPort = errors.New("Hostname cannot contain port") ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules") ) @@ -26,6 +28,11 @@ var ( // hostname and path. This function assumes the last rule matches everything, // which is the case if the rules were instantiated via the ingress#Validate method func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) { + // The hostname might contain port. We only want to compare the host part with the rule + host, _, err := net.SplitHostPort(hostname) + if err == nil { + hostname = host + } for i, rule := range ing.Rules { if rule.Matches(hostname, path) { return &rule, i @@ -157,21 +164,8 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon service = &serviceURL } - // Ensure that there are no wildcards anywhere except the first character - // of the hostname. - if strings.LastIndex(r.Hostname, "*") > 0 { - return Ingress{}, errBadWildcard - } - - // The last rule should catch all hostnames. - isCatchAllRule := (r.Hostname == "" || r.Hostname == "*") && r.Path == "" - isLastRule := i == len(ingress)-1 - if isLastRule && !isCatchAllRule { - return Ingress{}, errLastRuleNotCatchAll - } - // ONLY the last rule should catch all hostnames. - if !isLastRule && isCatchAllRule { - return Ingress{}, errRuleShouldNotBeCatchAll{i: i, hostname: r.Hostname} + if err := validateHostname(r, i, len(ingress)); err != nil { + return Ingress{}, err } var pathRegex *regexp.Regexp @@ -193,15 +187,40 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon return Ingress{Rules: rules, defaults: defaults}, nil } +func validateHostname(r config.UnvalidatedIngressRule, ruleIndex, totalRules int) error { + // Ensure that the hostname doesn't contain port + _, _, err := net.SplitHostPort(r.Hostname) + if err == nil { + return errHostnameContainsPort + } + // Ensure that there are no wildcards anywhere except the first character + // of the hostname. + if strings.LastIndex(r.Hostname, "*") > 0 { + return errBadWildcard + } + + // The last rule should catch all hostnames. + isCatchAllRule := (r.Hostname == "" || r.Hostname == "*") && r.Path == "" + isLastRule := ruleIndex == totalRules-1 + if isLastRule && !isCatchAllRule { + return errLastRuleNotCatchAll + } + // ONLY the last rule should catch all hostnames. + if !isLastRule && isCatchAllRule { + return errRuleShouldNotBeCatchAll{index: ruleIndex, hostname: r.Hostname} + } + return nil +} + type errRuleShouldNotBeCatchAll struct { - i int + index int hostname string } func (e errRuleShouldNotBeCatchAll) Error() string { return fmt.Sprintf("Rule #%d is matching the hostname '%s', but "+ "this will match every hostname, meaning the rules which follow it "+ - "will never be triggered.", e.i+1, e.hostname) + "will never be triggered.", e.index+1, e.hostname) } // ParseIngress parses ingress rules, but does not send HTTP requests to the origins. diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 88b79d0e..7c5a945d 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -1,9 +1,12 @@ package ingress import ( + "fmt" "net/url" + "regexp" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" @@ -201,6 +204,17 @@ ingress: }, }, }, + { + name: "Hostname contains port", + args: args{rawYAML: ` +ingress: + - hostname: "test.example.com:443" + service: https://localhost:8000 + - hostname: "*" + service: https://localhost:8001 +`}, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -214,6 +228,72 @@ ingress: } } +func TestFindMatchingRule(t *testing.T) { + ingress := Ingress{ + Rules: []Rule{ + { + Hostname: "tunnel-a.example.com", + Path: nil, + }, + { + Hostname: "tunnel-b.example.com", + Path: mustParsePath(t, "/health"), + }, + { + Hostname: "*", + }, + }, + } + + tests := []struct { + host string + path string + wantRuleIndex int + }{ + { + host: "tunnel-a.example.com", + path: "/", + wantRuleIndex: 0, + }, + { + host: "tunnel-a.example.com", + path: "/pages/about", + wantRuleIndex: 0, + }, + { + host: "tunnel-a.example.com:443", + path: "/pages/about", + wantRuleIndex: 0, + }, + { + host: "tunnel-b.example.com", + path: "/health", + wantRuleIndex: 1, + }, + { + host: "tunnel-b.example.com", + path: "/index.html", + wantRuleIndex: 2, + }, + { + host: "tunnel-c.example.com", + path: "/", + wantRuleIndex: 2, + }, + } + + for i, test := range tests { + _, ruleIndex := ingress.FindMatchingRule(test.host, test.path) + assert.Equal(t, test.wantRuleIndex, ruleIndex, fmt.Sprintf("Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, i)) + } +} + +func mustParsePath(t *testing.T, path string) *regexp.Regexp { + regexp, err := regexp.Compile(path) + assert.NoError(t, err) + return regexp +} + func MustParseURL(t *testing.T, rawURL string) *url.URL { u, err := url.Parse(rawURL) require.NoError(t, err) diff --git a/origin/tunnel.go b/origin/tunnel.go index 4558f3c4..564e17e0 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -697,6 +697,7 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, return nil, nil, errors.Wrap(err, "invalid request received") } h.AppendTagHeaders(req) + // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) return req, rule, nil }