diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 7cb352bf..5d3f27e5 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -646,12 +646,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag { Value: "https://api.trycloudflare.com", Hidden: true, }), - &cli.UintFlag{ + altsrc.NewIntFlag(&cli.IntFlag{ Name: "max-fetch-size", Usage: `The maximum number of results that cloudflared can fetch from Cloudflare API for any listing operations needed`, EnvVars: []string{"TUNNEL_MAX_FETCH_SIZE"}, Hidden: true, - }, + }), selectProtocolFlag, overwriteDNSFlag, }...) diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index 5ba40255..0335bd27 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -336,10 +336,6 @@ func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, boo filter := tunnelstore.NewFilter() filter.NoDeleted() filter.ByName(name) - if maxFetch := sc.c.Uint("max-fetch-size"); maxFetch > 0 { - filter.MaxFetchSize(maxFetch) - } - tunnels, err := sc.list(filter) if err != nil { return nil, false, err @@ -377,56 +373,42 @@ func (sc *subcommandContext) findID(input string) (uuid.UUID, error) { } // findIDs is just like mapping `findID` over a slice, but it only uses -// one Tunnelstore API call. +// one Tunnelstore API call per non-UUID input provided. func (sc *subcommandContext) findIDs(inputs []string) ([]uuid.UUID, error) { + uuids, names := splitUuids(inputs) - // Shortcut without Tunnelstore call if we find that all inputs are already UUIDs. - uuids, err := convertNamesToUuids(inputs, make(map[string]uuid.UUID)) - if err == nil { - return uuids, nil + for _, name := range names { + filter := tunnelstore.NewFilter() + filter.NoDeleted() + filter.ByName(name) + + tunnels, err := sc.list(filter) + if err != nil { + return nil, err + } + + if len(tunnels) != 1 { + return nil, fmt.Errorf("there should only be 1 non-deleted Tunnel named %s", name) + } + + uuids = append(uuids, tunnels[0].ID) } - // First, look up all tunnels the user has - filter := tunnelstore.NewFilter() - filter.NoDeleted() - if maxFetch := sc.c.Uint("max-fetch-size"); maxFetch > 0 { - filter.MaxFetchSize(maxFetch) - } - - tunnels, err := sc.list(filter) - if err != nil { - return nil, err - } - // Do the pure list-processing in its own function, so that it can be - // unit tested easily. - return findIDs(tunnels, inputs) + return uuids, nil } -func findIDs(tunnels []*tunnelstore.Tunnel, inputs []string) ([]uuid.UUID, error) { - // Put them into a dictionary for faster lookups - nameToID := make(map[string]uuid.UUID, len(tunnels)) - for _, tunnel := range tunnels { - nameToID[tunnel.Name] = tunnel.ID - } +func splitUuids(inputs []string) ([]uuid.UUID, []string) { + uuids := make([]uuid.UUID, 0) + names := make([]string, 0) - return convertNamesToUuids(inputs, nameToID) -} - -func convertNamesToUuids(inputs []string, nameToID map[string]uuid.UUID) ([]uuid.UUID, error) { - tunnelIDs := make([]uuid.UUID, len(inputs)) - var badInputs []string - for i, input := range inputs { - if id, err := uuid.Parse(input); err == nil { - tunnelIDs[i] = id - } else if id, ok := nameToID[input]; ok { - tunnelIDs[i] = id + for _, input := range inputs { + id, err := uuid.Parse(input) + if err != nil { + names = append(names, input) } else { - badInputs = append(badInputs, input) + uuids = append(uuids, id) } } - if len(badInputs) > 0 { - msg := "Please specify either the ID or name of a tunnel. The following inputs were neither: %s" - return nil, fmt.Errorf(msg, strings.Join(badInputs, ", ")) - } - return tunnelIDs, nil + + return uuids, names } diff --git a/cmd/cloudflared/tunnel/subcommand_context_test.go b/cmd/cloudflared/tunnel/subcommand_context_test.go index 2155b3dc..cb9a9512 100644 --- a/cmd/cloudflared/tunnel/subcommand_context_test.go +++ b/cmd/cloudflared/tunnel/subcommand_context_test.go @@ -17,79 +17,6 @@ import ( "github.com/cloudflare/cloudflared/tunnelstore" ) -func Test_findIDs(t *testing.T) { - type args struct { - tunnels []*tunnelstore.Tunnel - inputs []string - } - tests := []struct { - name string - args args - want []uuid.UUID - wantErr bool - }{ - { - name: "input not found", - args: args{ - inputs: []string{"asdf"}, - }, - wantErr: true, - }, - { - name: "only UUID", - args: args{ - inputs: []string{"a8398a0b-876d-48ed-b609-3fcfd67a4950"}, - }, - want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")}, - }, - { - name: "only name", - args: args{ - tunnels: []*tunnelstore.Tunnel{ - { - ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), - Name: "tunnel1", - }, - }, - inputs: []string{"tunnel1"}, - }, - want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")}, - }, - { - name: "both UUID and name", - args: args{ - tunnels: []*tunnelstore.Tunnel{ - { - ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), - Name: "tunnel1", - }, - { - ID: uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"), - Name: "tunnel2", - }, - }, - inputs: []string{"tunnel1", "bf028b68-744f-466e-97f8-c46161d80aa5"}, - }, - want: []uuid.UUID{ - uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), - uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := findIDs(tt.args.tunnels, tt.args.inputs) - if (err != nil) != tt.wantErr { - t.Errorf("findIDs() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("findIDs() = %v, want %v", got, tt.want) - } - }) - } -} - type mockFileSystem struct { rf func(string) ([]byte, error) vfp func(string) bool diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 19ba1886..af44d182 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -277,8 +277,8 @@ func listCommand(c *cli.Context) error { } filter.ByTunnelID(tunnelID) } - if maxFetch := c.Uint("max-fetch-size"); maxFetch > 0 { - filter.MaxFetchSize(maxFetch) + if maxFetch := c.Int("max-fetch-size"); maxFetch > 0 { + filter.MaxFetchSize(uint(maxFetch)) } tunnels, err := sc.list(filter)