TUN-3294: Perform basic validation on arguments of route command; remove default pool name which wasn't valid

This commit is contained in:
Igor Postelnik 2020-09-14 16:41:02 -05:00
parent bfae12008d
commit 5753aa9f18
2 changed files with 54 additions and 17 deletions

View File

@ -7,6 +7,7 @@ import (
"io/ioutil"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"text/tabwriter"
@ -355,50 +356,65 @@ func buildRouteCommand() *cli.Command {
Name: "route",
Action: cliutil.ErrorHandler(routeCommand),
Usage: "Define what hostname or load balancer can route to this tunnel",
Description: `The route defines what hostname or load balancer can route to this tunnel.
Description: `The route defines what hostname or load balancer will proxy requests to this tunnel.
To route a hostname: cloudflared tunnel route dns <tunnel ID> <hostname>
To use this tunnel as a load balancer origin: cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>`,
To route a hostname by creating a CNAME to tunnel's address:
cloudflared tunnel route dns <tunnel ID> <hostname>
To use this tunnel as a load balancer origin, creating pool and load balancer if necessary:
cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>`,
ArgsUsage: "dns|lb TUNNEL HOSTNAME [LB-POOL]",
}
}
func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) {
func dnsRouteFromArg(c *cli.Context) (tunnelstore.Route, error) {
const (
userHostnameIndex = 2
expectArgs = 3
expectedNArgs = 3
)
if c.NArg() != expectArgs {
return nil, cliutil.UsageError("Expect %d arguments, got %d", expectArgs, c.NArg())
if c.NArg() != expectedNArgs {
return nil, cliutil.UsageError("Expected %d arguments, got %d", expectedNArgs, c.NArg())
}
userHostname := c.Args().Get(userHostnameIndex)
if userHostname == "" {
return nil, cliutil.UsageError("The third argument should be the hostname")
} else if !validateName(userHostname) {
return nil, errors.Errorf("%s is not a valid hostname", userHostname)
}
return tunnelstore.NewDNSRoute(userHostname), nil
}
func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) {
func lbRouteFromArg(c *cli.Context) (tunnelstore.Route, error) {
const (
lbNameIndex = 2
lbPoolIndex = 3
expectMinArgs = 3
expectedNArgs = 4
)
if c.NArg() < expectMinArgs {
return nil, cliutil.UsageError("Expect at least %d arguments, got %d", expectMinArgs, c.NArg())
if c.NArg() != expectedNArgs {
return nil, cliutil.UsageError("Expected %d arguments, got %d", expectedNArgs, c.NArg())
}
lbName := c.Args().Get(lbNameIndex)
if lbName == "" {
return nil, cliutil.UsageError("The third argument should be the load balancer name")
} else if !validateName(lbName) {
return nil, errors.Errorf("%s is not a valid load balancer name", lbName)
}
lbPool := c.Args().Get(lbPoolIndex)
if lbPool == "" {
lbPool = defaultPoolName(tunnelID)
return nil, cliutil.UsageError("The fourth argument should be the pool name")
} else if !validateName(lbPool) {
return nil, errors.Errorf("%s is not a valid pool name", lbPool)
}
return tunnelstore.NewLBRoute(lbName, lbPool), nil
}
var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
func validateName(s string) bool {
return nameRegex.MatchString(s)
}
func routeCommand(c *cli.Context) error {
if c.NArg() < 2 {
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`)
@ -419,7 +435,7 @@ func routeCommand(c *cli.Context) error {
if err != nil {
return err
}
r, err = dnsRouteFromArg(c, tunnelID)
r, err = dnsRouteFromArg(c)
if err != nil {
return err
}
@ -428,7 +444,7 @@ func routeCommand(c *cli.Context) error {
if err != nil {
return err
}
r, err = lbRouteFromArg(c, tunnelID)
r, err = lbRouteFromArg(c)
if err != nil {
return err
}
@ -443,6 +459,3 @@ func routeCommand(c *cli.Context) error {
return nil
}
func defaultPoolName(tunnelID uuid.UUID) string {
return fmt.Sprintf("tunnel:%v", tunnelID)
}

View File

@ -98,3 +98,27 @@ func TestTunnelfilePath(t *testing.T) {
expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID)
assert.Equal(t, expected, actual)
}
func TestValidateName(t *testing.T) {
tests := []struct {
name string
want bool
}{
{name: "", want: false},
{name: "-", want: false},
{name: ".", want: false},
{name: "a b", want: false},
{name: "a+b", want: false},
{name: "-ab", want: false},
{name: "ab", want: true},
{name: "ab-c", want: true},
{name: "abc.def", want: true},
{name: "_ab_c.-d-ef", want: true},
}
for _, tt := range tests {
if got := validateName(tt.name); got != tt.want {
t.Errorf("validateName() = %v, want %v", got, tt.want)
}
}
}