TUN-3286: Use either ID or name in Named Tunnel subcommands.

This commit is contained in:
Adam Chalmers 2020-08-18 16:54:05 -05:00
parent 60de05bfc1
commit 1a96889141
5 changed files with 160 additions and 31 deletions

View File

@ -138,17 +138,17 @@ func (sc *subcommandContext) tunnelCredentialsPath(tunnelID uuid.UUID) (string,
func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) { func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) {
client, err := sc.client() client, err := sc.client()
if err != nil { if err != nil {
return nil, err return nil, errors.Wrap(err, "couldn't create client to talk to Argo Tunnel backend")
} }
tunnelSecret, err := generateTunnelSecret() tunnelSecret, err := generateTunnelSecret()
if err != nil { if err != nil {
return nil, err return nil, errors.Wrap(err, "couldn't generate the secret for your new tunnel")
} }
tunnel, err := client.CreateTunnel(name, tunnelSecret) tunnel, err := client.CreateTunnel(name, tunnelSecret)
if err != nil { if err != nil {
return nil, err return nil, errors.Wrap(err, "Create Tunnel API call failed")
} }
credential, err := sc.credential() credential, err := sc.credential()
@ -280,3 +280,61 @@ func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, boo
// There should only be 1 active tunnel for a given name // There should only be 1 active tunnel for a given name
return tunnels[0], true, nil return tunnels[0], true, nil
} }
// findID parses the input. If it's a UUID, return the UUID.
// Otherwise, assume it's a name, and look up the ID of that tunnel.
func (sc *subcommandContext) findID(input string) (uuid.UUID, error) {
if u, err := uuid.Parse(input); err == nil {
return u, nil
}
if tunnel, found, err := sc.tunnelActive(input); err != nil {
return uuid.Nil, err
} else if found {
return tunnel.ID, nil
}
return uuid.Nil, fmt.Errorf("%s is neither the ID nor the name of any of your tunnels", input)
}
// findIDs is just like mapping `findID` over a slice, but it only uses
// one Tunnelstore API call.
func (sc *subcommandContext) findIDs(inputs []string) ([]uuid.UUID, error) {
// First, look up all tunnels the user has
filter := tunnelstore.NewFilter()
filter.NoDeleted()
tunnels, err := sc.list(filter)
if err != nil {
return nil, err
}
// Do the pure list-processing in its own function, so that it can be
// unit tested easily.
return findIDs(tunnels, inputs)
}
func findIDs(tunnels []*tunnelstore.Tunnel, inputs []string) ([]uuid.UUID, error) {
// Put them into a dictionary for faster lookups
nameToID := make(map[string]uuid.UUID, len(tunnels))
for _, tunnel := range tunnels {
nameToID[tunnel.Name] = tunnel.ID
}
// For each input, try to find the tunnel ID.
tunnelIDs := make([]uuid.UUID, len(inputs))
var badInputs []string
for i, input := range inputs {
if id, err := uuid.Parse(input); err == nil {
tunnelIDs[i] = id
} else if id, ok := nameToID[input]; ok {
tunnelIDs[i] = id
} else {
badInputs = append(badInputs, input)
}
}
if len(badInputs) > 0 {
msg := "Please specify either the ID or name of a tunnel. The following inputs were neither: %s"
return nil, fmt.Errorf(msg, strings.Join(badInputs, ", "))
}
return tunnelIDs, nil
}

View File

@ -0,0 +1,82 @@
package tunnel
import (
"reflect"
"testing"
"github.com/cloudflare/cloudflared/tunnelstore"
"github.com/google/uuid"
)
func Test_findIDs(t *testing.T) {
type args struct {
tunnels []*tunnelstore.Tunnel
inputs []string
}
tests := []struct {
name string
args args
want []uuid.UUID
wantErr bool
}{
{
name: "input not found",
args: args{
inputs: []string{"asdf"},
},
wantErr: true,
},
{
name: "only UUID",
args: args{
inputs: []string{"a8398a0b-876d-48ed-b609-3fcfd67a4950"},
},
want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")},
},
{
name: "only name",
args: args{
tunnels: []*tunnelstore.Tunnel{
{
ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
Name: "tunnel1",
},
},
inputs: []string{"tunnel1"},
},
want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")},
},
{
name: "both UUID and name",
args: args{
tunnels: []*tunnelstore.Tunnel{
{
ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
Name: "tunnel1",
},
{
ID: uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"),
Name: "tunnel2",
},
},
inputs: []string{"tunnel1", "bf028b68-744f-466e-97f8-c46161d80aa5"},
},
want: []uuid.UUID{
uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := findIDs(tt.args.tunnels, tt.args.inputs)
if (err != nil) != tt.wantErr {
t.Errorf("findIDs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("findIDs() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -272,10 +272,10 @@ func deleteCommand(c *cli.Context) error {
} }
if c.NArg() < 1 { if c.NArg() < 1 {
return cliutil.UsageError(`"cloudflared tunnel delete" requires at least argument, the ID of the tunnel to delete.`) return cliutil.UsageError(`"cloudflared tunnel delete" requires at least 1 argument, the ID or name of the tunnel to delete.`)
} }
tunnelIDs, err := tunnelIDsFromArgs(c) tunnelIDs, err := sc.findIDs(c.Args().Slice())
if err != nil { if err != nil {
return err return err
} }
@ -283,19 +283,6 @@ func deleteCommand(c *cli.Context) error {
return sc.delete(tunnelIDs) return sc.delete(tunnelIDs)
} }
func tunnelIDsFromArgs(c *cli.Context) ([]uuid.UUID, error) {
tunnelIDs := make([]uuid.UUID, 0, c.NArg())
for i := 0; i < c.NArg(); i++ {
tunnelID, err := uuid.Parse(c.Args().Get(i))
if err != nil {
return nil, err
}
tunnelIDs = append(tunnelIDs, tunnelID)
}
return tunnelIDs, nil
}
func renderOutput(format string, v interface{}) error { func renderOutput(format string, v interface{}) error {
switch format { switch format {
case "json": case "json":
@ -327,7 +314,7 @@ func runCommand(c *cli.Context) error {
} }
if c.NArg() != 1 { if c.NArg() != 1 {
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`) return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID or name of the tunnel to run.`)
} }
tunnelID, err := uuid.Parse(c.Args().First()) tunnelID, err := uuid.Parse(c.Args().First())
if err != nil { if err != nil {
@ -357,7 +344,7 @@ func cleanupCommand(c *cli.Context) error {
return err return err
} }
tunnelIDs, err := tunnelIDsFromArgs(c) tunnelIDs, err := sc.findIDs(c.Args().Slice())
if err != nil { if err != nil {
return err return err
} }
@ -417,15 +404,15 @@ func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, erro
func routeCommand(c *cli.Context) error { func routeCommand(c *cli.Context) error {
if c.NArg() < 2 { if c.NArg() < 2 {
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`) return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`)
} }
const tunnelIDIndex = 1 sc, err := newSubcommandContext(c)
tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex))
if err != nil { if err != nil {
return errors.Wrap(err, "error parsing tunnel ID") return err
} }
sc, err := newSubcommandContext(c) const tunnelIDIndex = 1
tunnelID, err := sc.findID(c.Args().Get(tunnelIDIndex))
if err != nil { if err != nil {
return err return err
} }

View File

@ -372,18 +372,18 @@ func randomASCIIPrintableChar(rand *rand.Rand) int {
// between 1 and `maxLength`. // between 1 and `maxLength`.
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string { func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
length := minLength + rand.Intn(maxLength) length := minLength + rand.Intn(maxLength)
result := "" var result strings.Builder
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
c := randomASCIIPrintableChar(rand) c := randomASCIIPrintableChar(rand)
// 1/4 chance of using percent encoding when not necessary // 1/4 chance of using percent encoding when not necessary
if c == '%' || rand.Intn(4) == 0 { if c == '%' || rand.Intn(4) == 0 {
result += fmt.Sprintf("%%%02X", c) result.WriteString(fmt.Sprintf("%%%02X", c))
} else { } else {
result += string(c) result.WriteByte(byte(c))
} }
} }
return result return result.String()
} }
// Calls `randomASCIIText` and ensures the result is a valid URL path, // Calls `randomASCIIText` and ensures the result is a valid URL path,
@ -672,4 +672,3 @@ func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
_ = H1ResponseToH2ResponseHeaders(h1resp) _ = H1ResponseToH2ResponseHeaders(h1resp)
} }
} }

View File

@ -174,6 +174,9 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
if name == "" { if name == "" {
return nil, errors.New("tunnel name required") return nil, errors.New("tunnel name required")
} }
if _, err := uuid.Parse(name); err == nil {
return nil, errors.New("you cannot use UUIDs as tunnel names")
}
body := &newTunnel{ body := &newTunnel{
Name: name, Name: name,
TunnelSecret: tunnelSecret, TunnelSecret: tunnelSecret,