TUN-3286: Use either ID or name in Named Tunnel subcommands.

This commit is contained in:
Adam Chalmers 2020-08-18 16:54:05 -05:00
parent 60de05bfc1
commit 1a96889141
5 changed files with 160 additions and 31 deletions

View File

@ -138,17 +138,17 @@ func (sc *subcommandContext) tunnelCredentialsPath(tunnelID uuid.UUID) (string,
func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) {
client, err := sc.client()
if err != nil {
return nil, err
return nil, errors.Wrap(err, "couldn't create client to talk to Argo Tunnel backend")
}
tunnelSecret, err := generateTunnelSecret()
if err != nil {
return nil, err
return nil, errors.Wrap(err, "couldn't generate the secret for your new tunnel")
}
tunnel, err := client.CreateTunnel(name, tunnelSecret)
if err != nil {
return nil, err
return nil, errors.Wrap(err, "Create Tunnel API call failed")
}
credential, err := sc.credential()
@ -280,3 +280,61 @@ func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, boo
// There should only be 1 active tunnel for a given name
return tunnels[0], true, nil
}
// findID parses the input. If it's a UUID, return the UUID.
// Otherwise, assume it's a name, and look up the ID of that tunnel.
func (sc *subcommandContext) findID(input string) (uuid.UUID, error) {
if u, err := uuid.Parse(input); err == nil {
return u, nil
}
if tunnel, found, err := sc.tunnelActive(input); err != nil {
return uuid.Nil, err
} else if found {
return tunnel.ID, nil
}
return uuid.Nil, fmt.Errorf("%s is neither the ID nor the name of any of your tunnels", input)
}
// findIDs is just like mapping `findID` over a slice, but it only uses
// one Tunnelstore API call.
func (sc *subcommandContext) findIDs(inputs []string) ([]uuid.UUID, error) {
// First, look up all tunnels the user has
filter := tunnelstore.NewFilter()
filter.NoDeleted()
tunnels, err := sc.list(filter)
if err != nil {
return nil, err
}
// Do the pure list-processing in its own function, so that it can be
// unit tested easily.
return findIDs(tunnels, inputs)
}
func findIDs(tunnels []*tunnelstore.Tunnel, inputs []string) ([]uuid.UUID, error) {
// Put them into a dictionary for faster lookups
nameToID := make(map[string]uuid.UUID, len(tunnels))
for _, tunnel := range tunnels {
nameToID[tunnel.Name] = tunnel.ID
}
// For each input, try to find the tunnel ID.
tunnelIDs := make([]uuid.UUID, len(inputs))
var badInputs []string
for i, input := range inputs {
if id, err := uuid.Parse(input); err == nil {
tunnelIDs[i] = id
} else if id, ok := nameToID[input]; ok {
tunnelIDs[i] = id
} else {
badInputs = append(badInputs, input)
}
}
if len(badInputs) > 0 {
msg := "Please specify either the ID or name of a tunnel. The following inputs were neither: %s"
return nil, fmt.Errorf(msg, strings.Join(badInputs, ", "))
}
return tunnelIDs, nil
}

View File

@ -0,0 +1,82 @@
package tunnel
import (
"reflect"
"testing"
"github.com/cloudflare/cloudflared/tunnelstore"
"github.com/google/uuid"
)
func Test_findIDs(t *testing.T) {
type args struct {
tunnels []*tunnelstore.Tunnel
inputs []string
}
tests := []struct {
name string
args args
want []uuid.UUID
wantErr bool
}{
{
name: "input not found",
args: args{
inputs: []string{"asdf"},
},
wantErr: true,
},
{
name: "only UUID",
args: args{
inputs: []string{"a8398a0b-876d-48ed-b609-3fcfd67a4950"},
},
want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")},
},
{
name: "only name",
args: args{
tunnels: []*tunnelstore.Tunnel{
{
ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
Name: "tunnel1",
},
},
inputs: []string{"tunnel1"},
},
want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")},
},
{
name: "both UUID and name",
args: args{
tunnels: []*tunnelstore.Tunnel{
{
ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
Name: "tunnel1",
},
{
ID: uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"),
Name: "tunnel2",
},
},
inputs: []string{"tunnel1", "bf028b68-744f-466e-97f8-c46161d80aa5"},
},
want: []uuid.UUID{
uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := findIDs(tt.args.tunnels, tt.args.inputs)
if (err != nil) != tt.wantErr {
t.Errorf("findIDs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("findIDs() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -272,10 +272,10 @@ func deleteCommand(c *cli.Context) error {
}
if c.NArg() < 1 {
return cliutil.UsageError(`"cloudflared tunnel delete" requires at least argument, the ID of the tunnel to delete.`)
return cliutil.UsageError(`"cloudflared tunnel delete" requires at least 1 argument, the ID or name of the tunnel to delete.`)
}
tunnelIDs, err := tunnelIDsFromArgs(c)
tunnelIDs, err := sc.findIDs(c.Args().Slice())
if err != nil {
return err
}
@ -283,19 +283,6 @@ func deleteCommand(c *cli.Context) error {
return sc.delete(tunnelIDs)
}
func tunnelIDsFromArgs(c *cli.Context) ([]uuid.UUID, error) {
tunnelIDs := make([]uuid.UUID, 0, c.NArg())
for i := 0; i < c.NArg(); i++ {
tunnelID, err := uuid.Parse(c.Args().Get(i))
if err != nil {
return nil, err
}
tunnelIDs = append(tunnelIDs, tunnelID)
}
return tunnelIDs, nil
}
func renderOutput(format string, v interface{}) error {
switch format {
case "json":
@ -327,7 +314,7 @@ func runCommand(c *cli.Context) error {
}
if c.NArg() != 1 {
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID or name of the tunnel to run.`)
}
tunnelID, err := uuid.Parse(c.Args().First())
if err != nil {
@ -357,7 +344,7 @@ func cleanupCommand(c *cli.Context) error {
return err
}
tunnelIDs, err := tunnelIDsFromArgs(c)
tunnelIDs, err := sc.findIDs(c.Args().Slice())
if err != nil {
return err
}
@ -417,15 +404,15 @@ func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, erro
func routeCommand(c *cli.Context) error {
if c.NArg() < 2 {
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`)
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`)
}
const tunnelIDIndex = 1
tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex))
sc, err := newSubcommandContext(c)
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
return err
}
sc, err := newSubcommandContext(c)
const tunnelIDIndex = 1
tunnelID, err := sc.findID(c.Args().Get(tunnelIDIndex))
if err != nil {
return err
}

View File

@ -372,18 +372,18 @@ func randomASCIIPrintableChar(rand *rand.Rand) int {
// between 1 and `maxLength`.
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
length := minLength + rand.Intn(maxLength)
result := ""
var result strings.Builder
for i := 0; i < length; i++ {
c := randomASCIIPrintableChar(rand)
// 1/4 chance of using percent encoding when not necessary
if c == '%' || rand.Intn(4) == 0 {
result += fmt.Sprintf("%%%02X", c)
result.WriteString(fmt.Sprintf("%%%02X", c))
} else {
result += string(c)
result.WriteByte(byte(c))
}
}
return result
return result.String()
}
// Calls `randomASCIIText` and ensures the result is a valid URL path,
@ -663,7 +663,7 @@ func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
h1resp := &http.Response{
StatusCode: 200,
Header: h1,
Header: h1,
}
b.ReportAllocs()
@ -672,4 +672,3 @@ func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
_ = H1ResponseToH2ResponseHeaders(h1resp)
}
}

View File

@ -174,6 +174,9 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
if name == "" {
return nil, errors.New("tunnel name required")
}
if _, err := uuid.Parse(name); err == nil {
return nil, errors.New("you cannot use UUIDs as tunnel names")
}
body := &newTunnel{
Name: name,
TunnelSecret: tunnelSecret,