TUN-3294: Perform basic validation on arguments of route command; remove default pool name which wasn't valid
This commit is contained in:
parent
bfae12008d
commit
5753aa9f18
|
@ -7,6 +7,7 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"text/tabwriter"
|
"text/tabwriter"
|
||||||
|
@ -355,50 +356,65 @@ func buildRouteCommand() *cli.Command {
|
||||||
Name: "route",
|
Name: "route",
|
||||||
Action: cliutil.ErrorHandler(routeCommand),
|
Action: cliutil.ErrorHandler(routeCommand),
|
||||||
Usage: "Define what hostname or load balancer can route to this tunnel",
|
Usage: "Define what hostname or load balancer can route to this tunnel",
|
||||||
Description: `The route defines what hostname or load balancer can route to this tunnel.
|
Description: `The route defines what hostname or load balancer will proxy requests to this tunnel.
|
||||||
|
|
||||||
To route a hostname: cloudflared tunnel route dns <tunnel ID> <hostname>
|
To route a hostname by creating a CNAME to tunnel's address:
|
||||||
To use this tunnel as a load balancer origin: cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>`,
|
cloudflared tunnel route dns <tunnel ID> <hostname>
|
||||||
|
To use this tunnel as a load balancer origin, creating pool and load balancer if necessary:
|
||||||
|
cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>`,
|
||||||
ArgsUsage: "dns|lb TUNNEL HOSTNAME [LB-POOL]",
|
ArgsUsage: "dns|lb TUNNEL HOSTNAME [LB-POOL]",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) {
|
func dnsRouteFromArg(c *cli.Context) (tunnelstore.Route, error) {
|
||||||
const (
|
const (
|
||||||
userHostnameIndex = 2
|
userHostnameIndex = 2
|
||||||
expectArgs = 3
|
expectedNArgs = 3
|
||||||
)
|
)
|
||||||
if c.NArg() != expectArgs {
|
if c.NArg() != expectedNArgs {
|
||||||
return nil, cliutil.UsageError("Expect %d arguments, got %d", expectArgs, c.NArg())
|
return nil, cliutil.UsageError("Expected %d arguments, got %d", expectedNArgs, c.NArg())
|
||||||
}
|
}
|
||||||
userHostname := c.Args().Get(userHostnameIndex)
|
userHostname := c.Args().Get(userHostnameIndex)
|
||||||
if userHostname == "" {
|
if userHostname == "" {
|
||||||
return nil, cliutil.UsageError("The third argument should be the hostname")
|
return nil, cliutil.UsageError("The third argument should be the hostname")
|
||||||
|
} else if !validateName(userHostname) {
|
||||||
|
return nil, errors.Errorf("%s is not a valid hostname", userHostname)
|
||||||
}
|
}
|
||||||
return tunnelstore.NewDNSRoute(userHostname), nil
|
return tunnelstore.NewDNSRoute(userHostname), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) {
|
func lbRouteFromArg(c *cli.Context) (tunnelstore.Route, error) {
|
||||||
const (
|
const (
|
||||||
lbNameIndex = 2
|
lbNameIndex = 2
|
||||||
lbPoolIndex = 3
|
lbPoolIndex = 3
|
||||||
expectMinArgs = 3
|
expectedNArgs = 4
|
||||||
)
|
)
|
||||||
if c.NArg() < expectMinArgs {
|
if c.NArg() != expectedNArgs {
|
||||||
return nil, cliutil.UsageError("Expect at least %d arguments, got %d", expectMinArgs, c.NArg())
|
return nil, cliutil.UsageError("Expected %d arguments, got %d", expectedNArgs, c.NArg())
|
||||||
}
|
}
|
||||||
lbName := c.Args().Get(lbNameIndex)
|
lbName := c.Args().Get(lbNameIndex)
|
||||||
if lbName == "" {
|
if lbName == "" {
|
||||||
return nil, cliutil.UsageError("The third argument should be the load balancer name")
|
return nil, cliutil.UsageError("The third argument should be the load balancer name")
|
||||||
|
} else if !validateName(lbName) {
|
||||||
|
return nil, errors.Errorf("%s is not a valid load balancer name", lbName)
|
||||||
}
|
}
|
||||||
|
|
||||||
lbPool := c.Args().Get(lbPoolIndex)
|
lbPool := c.Args().Get(lbPoolIndex)
|
||||||
if lbPool == "" {
|
if lbPool == "" {
|
||||||
lbPool = defaultPoolName(tunnelID)
|
return nil, cliutil.UsageError("The fourth argument should be the pool name")
|
||||||
|
} else if !validateName(lbPool) {
|
||||||
|
return nil, errors.Errorf("%s is not a valid pool name", lbPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
return tunnelstore.NewLBRoute(lbName, lbPool), nil
|
return tunnelstore.NewLBRoute(lbName, lbPool), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||||
|
|
||||||
|
func validateName(s string) bool {
|
||||||
|
return nameRegex.MatchString(s)
|
||||||
|
}
|
||||||
|
|
||||||
func routeCommand(c *cli.Context) error {
|
func routeCommand(c *cli.Context) error {
|
||||||
if c.NArg() < 2 {
|
if c.NArg() < 2 {
|
||||||
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`)
|
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`)
|
||||||
|
@ -419,7 +435,7 @@ func routeCommand(c *cli.Context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r, err = dnsRouteFromArg(c, tunnelID)
|
r, err = dnsRouteFromArg(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -428,7 +444,7 @@ func routeCommand(c *cli.Context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r, err = lbRouteFromArg(c, tunnelID)
|
r, err = lbRouteFromArg(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -443,6 +459,3 @@ func routeCommand(c *cli.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultPoolName(tunnelID uuid.UUID) string {
|
|
||||||
return fmt.Sprintf("tunnel:%v", tunnelID)
|
|
||||||
}
|
|
||||||
|
|
|
@ -98,3 +98,27 @@ func TestTunnelfilePath(t *testing.T) {
|
||||||
expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID)
|
expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID)
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{name: "", want: false},
|
||||||
|
{name: "-", want: false},
|
||||||
|
{name: ".", want: false},
|
||||||
|
{name: "a b", want: false},
|
||||||
|
{name: "a+b", want: false},
|
||||||
|
{name: "-ab", want: false},
|
||||||
|
|
||||||
|
{name: "ab", want: true},
|
||||||
|
{name: "ab-c", want: true},
|
||||||
|
{name: "abc.def", want: true},
|
||||||
|
{name: "_ab_c.-d-ef", want: true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := validateName(tt.name); got != tt.want {
|
||||||
|
t.Errorf("validateName() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue