TUN-3294: Perform basic validation on arguments of route command; remove default pool name which wasn't valid

This commit is contained in:
Igor Postelnik 2020-09-14 16:41:02 -05:00
parent bfae12008d
commit 5753aa9f18
2 changed files with 54 additions and 17 deletions

View File

@ -7,6 +7,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"strings" "strings"
"text/tabwriter" "text/tabwriter"
@ -355,50 +356,65 @@ func buildRouteCommand() *cli.Command {
Name: "route", Name: "route",
Action: cliutil.ErrorHandler(routeCommand), Action: cliutil.ErrorHandler(routeCommand),
Usage: "Define what hostname or load balancer can route to this tunnel", Usage: "Define what hostname or load balancer can route to this tunnel",
Description: `The route defines what hostname or load balancer can route to this tunnel. Description: `The route defines what hostname or load balancer will proxy requests to this tunnel.
To route a hostname: cloudflared tunnel route dns <tunnel ID> <hostname> To route a hostname by creating a CNAME to tunnel's address:
To use this tunnel as a load balancer origin: cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>`, cloudflared tunnel route dns <tunnel ID> <hostname>
To use this tunnel as a load balancer origin, creating pool and load balancer if necessary:
cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>`,
ArgsUsage: "dns|lb TUNNEL HOSTNAME [LB-POOL]", ArgsUsage: "dns|lb TUNNEL HOSTNAME [LB-POOL]",
} }
} }
func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) { func dnsRouteFromArg(c *cli.Context) (tunnelstore.Route, error) {
const ( const (
userHostnameIndex = 2 userHostnameIndex = 2
expectArgs = 3 expectedNArgs = 3
) )
if c.NArg() != expectArgs { if c.NArg() != expectedNArgs {
return nil, cliutil.UsageError("Expect %d arguments, got %d", expectArgs, c.NArg()) return nil, cliutil.UsageError("Expected %d arguments, got %d", expectedNArgs, c.NArg())
} }
userHostname := c.Args().Get(userHostnameIndex) userHostname := c.Args().Get(userHostnameIndex)
if userHostname == "" { if userHostname == "" {
return nil, cliutil.UsageError("The third argument should be the hostname") return nil, cliutil.UsageError("The third argument should be the hostname")
} else if !validateName(userHostname) {
return nil, errors.Errorf("%s is not a valid hostname", userHostname)
} }
return tunnelstore.NewDNSRoute(userHostname), nil return tunnelstore.NewDNSRoute(userHostname), nil
} }
func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) { func lbRouteFromArg(c *cli.Context) (tunnelstore.Route, error) {
const ( const (
lbNameIndex = 2 lbNameIndex = 2
lbPoolIndex = 3 lbPoolIndex = 3
expectMinArgs = 3 expectedNArgs = 4
) )
if c.NArg() < expectMinArgs { if c.NArg() != expectedNArgs {
return nil, cliutil.UsageError("Expect at least %d arguments, got %d", expectMinArgs, c.NArg()) return nil, cliutil.UsageError("Expected %d arguments, got %d", expectedNArgs, c.NArg())
} }
lbName := c.Args().Get(lbNameIndex) lbName := c.Args().Get(lbNameIndex)
if lbName == "" { if lbName == "" {
return nil, cliutil.UsageError("The third argument should be the load balancer name") return nil, cliutil.UsageError("The third argument should be the load balancer name")
} else if !validateName(lbName) {
return nil, errors.Errorf("%s is not a valid load balancer name", lbName)
} }
lbPool := c.Args().Get(lbPoolIndex) lbPool := c.Args().Get(lbPoolIndex)
if lbPool == "" { if lbPool == "" {
lbPool = defaultPoolName(tunnelID) return nil, cliutil.UsageError("The fourth argument should be the pool name")
} else if !validateName(lbPool) {
return nil, errors.Errorf("%s is not a valid pool name", lbPool)
} }
return tunnelstore.NewLBRoute(lbName, lbPool), nil return tunnelstore.NewLBRoute(lbName, lbPool), nil
} }
var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
func validateName(s string) bool {
return nameRegex.MatchString(s)
}
func routeCommand(c *cli.Context) error { func routeCommand(c *cli.Context) error {
if c.NArg() < 2 { if c.NArg() < 2 {
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`) return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`)
@ -419,7 +435,7 @@ func routeCommand(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
r, err = dnsRouteFromArg(c, tunnelID) r, err = dnsRouteFromArg(c)
if err != nil { if err != nil {
return err return err
} }
@ -428,7 +444,7 @@ func routeCommand(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
r, err = lbRouteFromArg(c, tunnelID) r, err = lbRouteFromArg(c)
if err != nil { if err != nil {
return err return err
} }
@ -443,6 +459,3 @@ func routeCommand(c *cli.Context) error {
return nil return nil
} }
func defaultPoolName(tunnelID uuid.UUID) string {
return fmt.Sprintf("tunnel:%v", tunnelID)
}

View File

@ -98,3 +98,27 @@ func TestTunnelfilePath(t *testing.T) {
expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID) expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
} }
func TestValidateName(t *testing.T) {
tests := []struct {
name string
want bool
}{
{name: "", want: false},
{name: "-", want: false},
{name: ".", want: false},
{name: "a b", want: false},
{name: "a+b", want: false},
{name: "-ab", want: false},
{name: "ab", want: true},
{name: "ab-c", want: true},
{name: "abc.def", want: true},
{name: "_ab_c.-d-ef", want: true},
}
for _, tt := range tests {
if got := validateName(tt.name); got != tt.want {
t.Errorf("validateName() = %v, want %v", got, tt.want)
}
}
}