diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index 6e4067f1..c83b567d 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -138,17 +138,17 @@ func (sc *subcommandContext) tunnelCredentialsPath(tunnelID uuid.UUID) (string, func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) { client, err := sc.client() if err != nil { - return nil, err + return nil, errors.Wrap(err, "couldn't create client to talk to Argo Tunnel backend") } tunnelSecret, err := generateTunnelSecret() if err != nil { - return nil, err + return nil, errors.Wrap(err, "couldn't generate the secret for your new tunnel") } tunnel, err := client.CreateTunnel(name, tunnelSecret) if err != nil { - return nil, err + return nil, errors.Wrap(err, "Create Tunnel API call failed") } credential, err := sc.credential() @@ -280,3 +280,61 @@ func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, boo // There should only be 1 active tunnel for a given name return tunnels[0], true, nil } + +// findID parses the input. If it's a UUID, return the UUID. +// Otherwise, assume it's a name, and look up the ID of that tunnel. +func (sc *subcommandContext) findID(input string) (uuid.UUID, error) { + if u, err := uuid.Parse(input); err == nil { + return u, nil + } + + if tunnel, found, err := sc.tunnelActive(input); err != nil { + return uuid.Nil, err + } else if found { + return tunnel.ID, nil + } + + return uuid.Nil, fmt.Errorf("%s is neither the ID nor the name of any of your tunnels", input) +} + +// findIDs is just like mapping `findID` over a slice, but it only uses +// one Tunnelstore API call. +func (sc *subcommandContext) findIDs(inputs []string) ([]uuid.UUID, error) { + + // First, look up all tunnels the user has + filter := tunnelstore.NewFilter() + filter.NoDeleted() + tunnels, err := sc.list(filter) + if err != nil { + return nil, err + } + // Do the pure list-processing in its own function, so that it can be + // unit tested easily. + return findIDs(tunnels, inputs) +} + +func findIDs(tunnels []*tunnelstore.Tunnel, inputs []string) ([]uuid.UUID, error) { + // Put them into a dictionary for faster lookups + nameToID := make(map[string]uuid.UUID, len(tunnels)) + for _, tunnel := range tunnels { + nameToID[tunnel.Name] = tunnel.ID + } + + // For each input, try to find the tunnel ID. + tunnelIDs := make([]uuid.UUID, len(inputs)) + var badInputs []string + for i, input := range inputs { + if id, err := uuid.Parse(input); err == nil { + tunnelIDs[i] = id + } else if id, ok := nameToID[input]; ok { + tunnelIDs[i] = id + } else { + badInputs = append(badInputs, input) + } + } + if len(badInputs) > 0 { + msg := "Please specify either the ID or name of a tunnel. The following inputs were neither: %s" + return nil, fmt.Errorf(msg, strings.Join(badInputs, ", ")) + } + return tunnelIDs, nil +} diff --git a/cmd/cloudflared/tunnel/subcommand_context_test.go b/cmd/cloudflared/tunnel/subcommand_context_test.go new file mode 100644 index 00000000..6a736838 --- /dev/null +++ b/cmd/cloudflared/tunnel/subcommand_context_test.go @@ -0,0 +1,82 @@ +package tunnel + +import ( + "reflect" + "testing" + + "github.com/cloudflare/cloudflared/tunnelstore" + "github.com/google/uuid" +) + +func Test_findIDs(t *testing.T) { + type args struct { + tunnels []*tunnelstore.Tunnel + inputs []string + } + tests := []struct { + name string + args args + want []uuid.UUID + wantErr bool + }{ + { + name: "input not found", + args: args{ + inputs: []string{"asdf"}, + }, + wantErr: true, + }, + { + name: "only UUID", + args: args{ + inputs: []string{"a8398a0b-876d-48ed-b609-3fcfd67a4950"}, + }, + want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")}, + }, + { + name: "only name", + args: args{ + tunnels: []*tunnelstore.Tunnel{ + { + ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), + Name: "tunnel1", + }, + }, + inputs: []string{"tunnel1"}, + }, + want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")}, + }, + { + name: "both UUID and name", + args: args{ + tunnels: []*tunnelstore.Tunnel{ + { + ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), + Name: "tunnel1", + }, + { + ID: uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"), + Name: "tunnel2", + }, + }, + inputs: []string{"tunnel1", "bf028b68-744f-466e-97f8-c46161d80aa5"}, + }, + want: []uuid.UUID{ + uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), + uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := findIDs(tt.args.tunnels, tt.args.inputs) + if (err != nil) != tt.wantErr { + t.Errorf("findIDs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("findIDs() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index f0b7d51b..a9fb1cae 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -272,10 +272,10 @@ func deleteCommand(c *cli.Context) error { } if c.NArg() < 1 { - return cliutil.UsageError(`"cloudflared tunnel delete" requires at least argument, the ID of the tunnel to delete.`) + return cliutil.UsageError(`"cloudflared tunnel delete" requires at least 1 argument, the ID or name of the tunnel to delete.`) } - tunnelIDs, err := tunnelIDsFromArgs(c) + tunnelIDs, err := sc.findIDs(c.Args().Slice()) if err != nil { return err } @@ -283,19 +283,6 @@ func deleteCommand(c *cli.Context) error { return sc.delete(tunnelIDs) } -func tunnelIDsFromArgs(c *cli.Context) ([]uuid.UUID, error) { - tunnelIDs := make([]uuid.UUID, 0, c.NArg()) - for i := 0; i < c.NArg(); i++ { - tunnelID, err := uuid.Parse(c.Args().Get(i)) - if err != nil { - return nil, err - } - tunnelIDs = append(tunnelIDs, tunnelID) - } - return tunnelIDs, nil - -} - func renderOutput(format string, v interface{}) error { switch format { case "json": @@ -327,7 +314,7 @@ func runCommand(c *cli.Context) error { } if c.NArg() != 1 { - return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`) + return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID or name of the tunnel to run.`) } tunnelID, err := uuid.Parse(c.Args().First()) if err != nil { @@ -357,7 +344,7 @@ func cleanupCommand(c *cli.Context) error { return err } - tunnelIDs, err := tunnelIDsFromArgs(c) + tunnelIDs, err := sc.findIDs(c.Args().Slice()) if err != nil { return err } @@ -417,15 +404,15 @@ func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, erro func routeCommand(c *cli.Context) error { if c.NArg() < 2 { - return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`) + return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID or name of the tunnel`) } - const tunnelIDIndex = 1 - tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex)) + sc, err := newSubcommandContext(c) if err != nil { - return errors.Wrap(err, "error parsing tunnel ID") + return err } - sc, err := newSubcommandContext(c) + const tunnelIDIndex = 1 + tunnelID, err := sc.findID(c.Args().Get(tunnelIDIndex)) if err != nil { return err } diff --git a/h2mux/header_test.go b/h2mux/header_test.go index 6b2411b2..0d3f5b06 100644 --- a/h2mux/header_test.go +++ b/h2mux/header_test.go @@ -372,18 +372,18 @@ func randomASCIIPrintableChar(rand *rand.Rand) int { // between 1 and `maxLength`. func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string { length := minLength + rand.Intn(maxLength) - result := "" + var result strings.Builder for i := 0; i < length; i++ { c := randomASCIIPrintableChar(rand) // 1/4 chance of using percent encoding when not necessary if c == '%' || rand.Intn(4) == 0 { - result += fmt.Sprintf("%%%02X", c) + result.WriteString(fmt.Sprintf("%%%02X", c)) } else { - result += string(c) + result.WriteByte(byte(c)) } } - return result + return result.String() } // Calls `randomASCIIText` and ensures the result is a valid URL path, @@ -663,7 +663,7 @@ func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) { h1resp := &http.Response{ StatusCode: 200, - Header: h1, + Header: h1, } b.ReportAllocs() @@ -672,4 +672,3 @@ func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) { _ = H1ResponseToH2ResponseHeaders(h1resp) } } - diff --git a/tunnelstore/client.go b/tunnelstore/client.go index f211a02b..da151cd0 100644 --- a/tunnelstore/client.go +++ b/tunnelstore/client.go @@ -174,6 +174,9 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er if name == "" { return nil, errors.New("tunnel name required") } + if _, err := uuid.Parse(name); err == nil { + return nil, errors.New("you cannot use UUIDs as tunnel names") + } body := &newTunnel{ Name: name, TunnelSecret: tunnelSecret,