From 292a7f07a2432acb621d7a6330220ed8f87c5932 Mon Sep 17 00:00:00 2001 From: cthuang Date: Fri, 7 Aug 2020 13:29:53 +0100 Subject: [PATCH] TUN-3243: Refactor tunnel subcommands to allow commands to compose better --- cmd/cloudflared/tunnel/subcommand_context.go | 282 ++++++++++++++ cmd/cloudflared/tunnel/subcommands.go | 376 +++++-------------- tunnelstore/client.go | 6 +- tunnelstore/filter.go | 2 +- 4 files changed, 383 insertions(+), 283 deletions(-) create mode 100644 cmd/cloudflared/tunnel/subcommand_context.go diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go new file mode 100644 index 00000000..98d76e6c --- /dev/null +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -0,0 +1,282 @@ +package tunnel + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + + "github.com/cloudflare/cloudflared/certutil" + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/origin" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/tunnelstore" + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/urfave/cli/v2" +) + +// subcommandContext carries structs shared between subcommands, to reduce number of arguments needed to pass between subcommands, +// and make sure they are only initialized once +type subcommandContext struct { + c *cli.Context + logger logger.Service + + // These fields should be accessed using their respective Getter + tunnelstoreClient tunnelstore.Client + userCredential *userCredential +} + +func newSubcommandContext(c *cli.Context) (*subcommandContext, error) { + logger, err := createLogger(c, false) + if err != nil { + return nil, errors.Wrap(err, "error setting up logger") + } + return &subcommandContext{ + c: c, + logger: logger, + }, nil +} + +type userCredential struct { + cert *certutil.OriginCert + certPath string +} + +func (sc *subcommandContext) client() (tunnelstore.Client, error) { + if sc.tunnelstoreClient != nil { + return sc.tunnelstoreClient, nil + } + credential, err := sc.credential() + if err != nil { + return nil, err + } + client, err := tunnelstore.NewRESTClient(sc.c.String("api-url"), credential.cert.AccountID, credential.cert.ZoneID, credential.cert.ServiceKey, sc.logger) + if err != nil { + return nil, err + } + sc.tunnelstoreClient = client + return client, nil +} + +func (sc *subcommandContext) credential() (*userCredential, error) { + if sc.userCredential == nil { + originCertPath, err := findOriginCert(sc.c, sc.logger) + if err != nil { + return nil, errors.Wrap(err, "Error locating origin cert") + } + blocks, err := readOriginCert(originCertPath, sc.logger) + if err != nil { + return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) + } + + cert, err := certutil.DecodeOriginCert(blocks) + if err != nil { + return nil, errors.Wrap(err, "Error decoding origin cert") + } + + if cert.AccountID == "" { + return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) + } + + sc.userCredential = &userCredential{ + cert: cert, + certPath: originCertPath, + } + } + return sc.userCredential, nil +} + +func (sc *subcommandContext) readTunnelCredentials(tunnelID uuid.UUID) (*pogs.TunnelAuth, error) { + filePath, err := sc.tunnelCredentialsPath(tunnelID) + if err != nil { + return nil, err + } + body, err := ioutil.ReadFile(filePath) + if err != nil { + return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath) + } + + var auth pogs.TunnelAuth + if err = json.Unmarshal(body, &auth); err != nil { + return nil, err + } + return &auth, nil +} + +func (sc *subcommandContext) tunnelCredentialsPath(tunnelID uuid.UUID) (string, error) { + if filePath := sc.c.String("credentials-file"); filePath != "" { + if validFilePath(filePath) { + return filePath, nil + } + } + + // Fallback to look for tunnel credentials in the origin cert directory + if originCertPath, err := findOriginCert(sc.c, sc.logger); err == nil { + originCertDir := filepath.Dir(originCertPath) + if filePath, err := tunnelFilePath(tunnelID, originCertDir); err == nil { + if validFilePath(filePath) { + return filePath, nil + } + } + } + + // Last resort look under default config directories + for _, configDir := range config.DefaultConfigDirs { + if filePath, err := tunnelFilePath(tunnelID, configDir); err == nil { + if validFilePath(filePath) { + return filePath, nil + } + } + } + return "", fmt.Errorf("Tunnel credentials file not found") +} + +func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) { + client, err := sc.client() + if err != nil { + return nil, err + } + + tunnelSecret, err := generateTunnelSecret() + if err != nil { + return nil, err + } + + tunnel, err := client.CreateTunnel(name, tunnelSecret) + if err != nil { + return nil, err + } + + credential, err := sc.credential() + if err != nil { + return nil, err + } + if writeFileErr := writeTunnelCredentials(tunnel.ID, credential.cert.AccountID, credential.certPath, tunnelSecret, sc.logger); err != nil { + var errorLines []string + errorLines = append(errorLines, fmt.Sprintf("Your tunnel '%v' was created with ID %v. However, cloudflared couldn't write to the tunnel credentials file at %v.json.", tunnel.Name, tunnel.ID, tunnel.ID)) + errorLines = append(errorLines, fmt.Sprintf("The file-writing error is: %v", writeFileErr)) + if deleteErr := client.DeleteTunnel(tunnel.ID); deleteErr != nil { + errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID)) + errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr)) + } else { + errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the tunnelfile")) + } + errorMsg := strings.Join(errorLines, "\n") + return nil, errors.New(errorMsg) + } + + if outputFormat := sc.c.String(outputFormatFlag.Name); outputFormat != "" { + return nil, renderOutput(outputFormat, &tunnel) + } + + sc.logger.Infof("Created tunnel %s with id %s", tunnel.Name, tunnel.ID) + return tunnel, nil +} + +func (sc *subcommandContext) list(filter *tunnelstore.Filter) ([]*tunnelstore.Tunnel, error) { + client, err := sc.client() + if err != nil { + return nil, err + } + return client.ListTunnels(filter) +} + +func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error { + forceFlagSet := sc.c.Bool("force") + + client, err := sc.client() + if err != nil { + return err + } + + for _, id := range tunnelIDs { + tunnel, err := client.GetTunnel(id) + if err != nil { + return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", tunnel.ID) + } + + // Check if tunnel DeletedAt field has already been set + if !tunnel.DeletedAt.IsZero() { + return fmt.Errorf("Tunnel %s has already been deleted", tunnel.ID) + } + // Check if tunnel has existing connections and if force flag is set, cleanup connections + if len(tunnel.Connections) > 0 { + if !forceFlagSet { + return fmt.Errorf("You can not delete tunnel %s because it has active connections. To see connections run the 'list' command. If you believe the tunnel is not active, you can use a -f / --force flag with this command.", id) + } + + if err := client.CleanupConnections(tunnel.ID); err != nil { + return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnel.ID) + } + } + + if err := client.DeleteTunnel(tunnel.ID); err != nil { + return errors.Wrapf(err, "Error deleting tunnel %s", tunnel.ID) + } + + tunnelCredentialsPath, err := sc.tunnelCredentialsPath(tunnel.ID) + if err != nil { + sc.logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err) + return nil + } + + if err = os.Remove(tunnelCredentialsPath); err != nil { + sc.logger.Infof("Cannot delete tunnel credentials, error: %v. Please delete the file manually", err) + } + } + return nil +} + +func (sc *subcommandContext) run(tunnelID uuid.UUID) error { + credentials, err := sc.readTunnelCredentials(tunnelID) + if err != nil { + return err + } + return StartServer(sc.c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID}) +} + +func (sc *subcommandContext) cleanupConnections(tunnelIDs []uuid.UUID) error { + client, err := sc.client() + if err != nil { + return err + } + for _, tunnelID := range tunnelIDs { + sc.logger.Infof("Cleanup connection for tunnel %s", tunnelID) + if err := client.CleanupConnections(tunnelID); err != nil { + sc.logger.Errorf("Error cleaning up connections for tunnel %v, error :%v", tunnelID, err) + } + } + return nil +} + +func (sc *subcommandContext) route(tunnelID uuid.UUID, r tunnelstore.Route) error { + client, err := sc.client() + if err != nil { + return err + } + + if err := client.RouteTunnel(tunnelID, r); err != nil { + return err + } + + return nil +} + +func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, bool, error) { + filter := tunnelstore.NewFilter() + filter.NoDeleted() + filter.ByName(name) + tunnels, err := sc.list(filter) + if err != nil { + return nil, false, err + } + if len(tunnels) == 0 { + return nil, false, nil + } + // There should only be 1 active tunnel for a given name + return tunnels[0], true, nil +} diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 25858d0c..f0b7d51b 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -18,11 +18,8 @@ import ( "github.com/urfave/cli/v2" "gopkg.in/yaml.v2" - "github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" - "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/logger" - "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelstore" ) @@ -88,7 +85,7 @@ const hideSubcommands = true func buildCreateCommand() *cli.Command { return &cli.Command{ Name: "create", - Action: cliutil.ErrorHandler(createTunnel), + Action: cliutil.ErrorHandler(createCommand), Usage: "Create a new tunnel with given name", ArgsUsage: "TUNNEL-NAME", Hidden: hideSubcommands, @@ -103,56 +100,19 @@ func generateTunnelSecret() ([]byte, error) { return randomBytes, err } -func createTunnel(c *cli.Context) error { +func createCommand(c *cli.Context) error { + sc, err := newSubcommandContext(c) + if err != nil { + return errors.Wrap(err, "error setting up logger") + } + if c.NArg() != 1 { return cliutil.UsageError(`"cloudflared tunnel create" requires exactly 1 argument, the name of tunnel to create.`) } name := c.Args().First() - logger, err := createLogger(c, false) - if err != nil { - return errors.Wrap(err, "error setting up logger") - } - - tunnelSecret, err := generateTunnelSecret() - if err != nil { - return err - } - - cert, originCertPath, err := getOriginCertFromContext(c, logger) - if err != nil { - return err - } - client, err := newTunnelstoreClient(c, cert, logger) - if err != nil { - return err - } - - tunnel, err := client.CreateTunnel(name, tunnelSecret) - if err != nil { - return errors.Wrap(err, "Error creating a new tunnel") - } - - if writeFileErr := writeTunnelCredentials(tunnel.ID, cert.AccountID, originCertPath, tunnelSecret, logger); err != nil { - var errorLines []string - errorLines = append(errorLines, fmt.Sprintf("Your tunnel '%v' was created with ID %v. However, cloudflared couldn't write to the tunnel credentials file at %v.json.", tunnel.Name, tunnel.ID, tunnel.ID)) - errorLines = append(errorLines, fmt.Sprintf("The file-writing error is: %v", writeFileErr)) - if deleteErr := client.DeleteTunnel(tunnel.ID); deleteErr != nil { - errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID)) - errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr)) - } else { - errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the tunnelfile")) - } - errorMsg := strings.Join(errorLines, "\n") - return errors.New(errorMsg) - } - - if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" { - return renderOutput(outputFormat, &tunnel) - } - - logger.Infof("Created tunnel %s with id %s", tunnel.Name, tunnel.ID) - return nil + _, err = sc.create(name) + return errors.Wrap(err, "failed to create tunnel") } func tunnelFilePath(tunnelID uuid.UUID, directory string) (string, error) { @@ -179,51 +139,6 @@ func writeTunnelCredentials(tunnelID uuid.UUID, accountID, originCertPath string return ioutil.WriteFile(filePath, body, 400) } -func readTunnelCredentials(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (*pogs.TunnelAuth, error) { - filePath, err := tunnelCredentialsPath(c, tunnelID, logger) - if err != nil { - return nil, err - } - body, err := ioutil.ReadFile(filePath) - if err != nil { - return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath) - } - - var auth pogs.TunnelAuth - if err = json.Unmarshal(body, &auth); err != nil { - return nil, err - } - return &auth, nil -} - -func tunnelCredentialsPath(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (string, error) { - if filePath := c.String("credentials-file"); filePath != "" { - if validFilePath(filePath) { - return filePath, nil - } - } - - // Fallback to look for tunnel credentials in the origin cert directory - if originCertPath, err := findOriginCert(c, logger); err == nil { - originCertDir := filepath.Dir(originCertPath) - if filePath, err := tunnelFilePath(tunnelID, originCertDir); err == nil { - if validFilePath(filePath) { - return filePath, nil - } - } - } - - // Last resort look under default config directories - for _, configDir := range config.DefaultConfigDirs { - if filePath, err := tunnelFilePath(tunnelID, configDir); err == nil { - if validFilePath(filePath) { - return filePath, nil - } - } - } - return "", fmt.Errorf("Tunnel credentials file not found") -} - func validFilePath(path string) bool { fileStat, err := os.Stat(path) if err != nil { @@ -235,7 +150,7 @@ func validFilePath(path string) bool { func buildListCommand() *cli.Command { return &cli.Command{ Name: "list", - Action: cliutil.ErrorHandler(listTunnels), + Action: cliutil.ErrorHandler(listCommand), Usage: "List existing tunnels", ArgsUsage: " ", Hidden: hideSubcommands, @@ -243,24 +158,15 @@ func buildListCommand() *cli.Command { } } -func listTunnels(c *cli.Context) error { - logger, err := createLogger(c, false) - if err != nil { - return errors.Wrap(err, "error setting up logger") - } - - cert, _, err := getOriginCertFromContext(c, logger) - if err != nil { - return err - } - client, err := newTunnelstoreClient(c, cert, logger) +func listCommand(c *cli.Context) error { + sc, err := newSubcommandContext(c) if err != nil { return err } filter := tunnelstore.NewFilter() if !c.Bool("show-deleted") { - filter.ShowDeleted() + filter.NoDeleted() } if name := c.String("name"); name != "" { filter.ByName(name) @@ -276,9 +182,9 @@ func listTunnels(c *cli.Context) error { filter.ByTunnelID(tunnelID) } - tunnels, err := client.ListTunnels(filter) + tunnels, err := sc.list(filter) if err != nil { - return errors.Wrap(err, "Error listing tunnels") + return err } if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" { @@ -290,11 +196,10 @@ func listTunnels(c *cli.Context) error { } else { fmt.Println("You have no tunnels, use 'cloudflared tunnel create' to define a new tunnel") } - return nil } -func fmtAndPrintTunnelList(tunnels []tunnelstore.Tunnel, showRecentlyDisconnected bool) { +func fmtAndPrintTunnelList(tunnels []*tunnelstore.Tunnel, showRecentlyDisconnected bool) { const ( minWidth = 0 tabWidth = 8 @@ -352,74 +257,43 @@ func fmtConnections(connections []tunnelstore.Connection, showRecentlyDisconnect func buildDeleteCommand() *cli.Command { return &cli.Command{ Name: "delete", - Action: cliutil.ErrorHandler(deleteTunnel), - Usage: "Delete existing tunnel with given ID", + Action: cliutil.ErrorHandler(deleteCommand), + Usage: "Delete existing tunnel with given IDs", ArgsUsage: "TUNNEL-ID", Hidden: hideSubcommands, Flags: []cli.Flag{credentialsFileFlag, forceDeleteFlag}, } } -func deleteTunnel(c *cli.Context) error { - if c.NArg() != 1 { - return cliutil.UsageError(`"cloudflared tunnel delete" requires exactly 1 argument, the ID of the tunnel to delete.`) - } - tunnelID, err := uuid.Parse(c.Args().First()) - if err != nil { - return errors.Wrap(err, "error parsing tunnel ID") - } - - logger, err := createLogger(c, false) - if err != nil { - return errors.Wrap(err, "error setting up logger") - } - - cert, _, err := getOriginCertFromContext(c, logger) - if err != nil { - return err - } - client, err := newTunnelstoreClient(c, cert, logger) +func deleteCommand(c *cli.Context) error { + sc, err := newSubcommandContext(c) if err != nil { return err } - forceFlagSet := c.Bool("force") + if c.NArg() < 1 { + return cliutil.UsageError(`"cloudflared tunnel delete" requires at least argument, the ID of the tunnel to delete.`) + } - tunnel, err := client.GetTunnel(tunnelID) + tunnelIDs, err := tunnelIDsFromArgs(c) if err != nil { - return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", tunnelID) + return err } - // Check if tunnel DeletedAt field has already been set - if !tunnel.DeletedAt.IsZero() { - return errors.New("This tunnel has already been deleted.") - } - // Check if tunnel has existing connections and if force flag is set, cleanup connections - if len(tunnel.Connections) > 0 { - if !forceFlagSet { - return errors.New("You can not delete this tunnel because it has active connections. To see connections run the 'list' command. If you believe the tunnel is not active, you can use a -f / --force flag with this command.") - } - - if err := client.CleanupConnections(tunnelID); err != nil { - return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnelID) + return sc.delete(tunnelIDs) +} + +func tunnelIDsFromArgs(c *cli.Context) ([]uuid.UUID, error) { + tunnelIDs := make([]uuid.UUID, 0, c.NArg()) + for i := 0; i < c.NArg(); i++ { + tunnelID, err := uuid.Parse(c.Args().Get(i)) + if err != nil { + return nil, err } + tunnelIDs = append(tunnelIDs, tunnelID) } + return tunnelIDs, nil - if err := client.DeleteTunnel(tunnelID); err != nil { - return errors.Wrapf(err, "Error deleting tunnel %s", tunnelID) - } - - tunnelCredentialsPath, err := tunnelCredentialsPath(c, tunnelID, logger) - if err != nil { - logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err) - return nil - } - - if err = os.Remove(tunnelCredentialsPath); err != nil { - logger.Infof("Cannot delete tunnel credentials, error: %v. Please delete the file manually", err) - } - - return nil } func renderOutput(format string, v interface{}) error { @@ -435,35 +309,10 @@ func renderOutput(format string, v interface{}) error { } } -func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) (tunnelstore.Client, error) { - return tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ZoneID, cert.ServiceKey, logger) -} - -func getOriginCertFromContext(c *cli.Context, logger logger.Service) (cert *certutil.OriginCert, originCertPath string, err error) { - originCertPath, err = findOriginCert(c, logger) - if err != nil { - return nil, "", errors.Wrap(err, "Error locating origin cert") - } - blocks, err := readOriginCert(originCertPath, logger) - if err != nil { - return nil, "", errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) - } - - cert, err = certutil.DecodeOriginCert(blocks) - if err != nil { - return nil, "", errors.Wrap(err, "Error decoding origin cert") - } - - if cert.AccountID == "" { - return nil, "", errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) - } - return cert, originCertPath, nil -} - func buildRunCommand() *cli.Command { return &cli.Command{ Name: "run", - Action: cliutil.ErrorHandler(runTunnel), + Action: cliutil.ErrorHandler(runCommand), Usage: "Proxy a local web server by running the given tunnel", ArgsUsage: "TUNNEL-ID", Hidden: hideSubcommands, @@ -471,77 +320,55 @@ func buildRunCommand() *cli.Command { } } -func runTunnel(c *cli.Context) error { +func runCommand(c *cli.Context) error { + sc, err := newSubcommandContext(c) + if err != nil { + return err + } + if c.NArg() != 1 { return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`) } - tunnelID, err := uuid.Parse(c.Args().First()) if err != nil { return errors.Wrap(err, "error parsing tunnel ID") } - logger, err := createLogger(c, false) - if err != nil { - return errors.Wrap(err, "error setting up logger") - } - - credentials, err := readTunnelCredentials(c, tunnelID, logger) - if err != nil { - return err - } - logger.Debugf("Read credentials for %v", credentials.AccountTag) - return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID}) + return sc.run(tunnelID) } func buildCleanupCommand() *cli.Command { return &cli.Command{ Name: "cleanup", - Action: cliutil.ErrorHandler(cleanupConnections), + Action: cliutil.ErrorHandler(cleanupCommand), Usage: "Cleanup connections for the tunnel with given IDs", ArgsUsage: "TUNNEL-IDS", Hidden: hideSubcommands, } } -func cleanupConnections(c *cli.Context) error { +func cleanupCommand(c *cli.Context) error { if c.NArg() < 1 { return cliutil.UsageError(`"cloudflared tunnel cleanup" requires at least 1 argument, the IDs of the tunnels to cleanup connections.`) } - logger, err := createLogger(c, false) - if err != nil { - return errors.Wrap(err, "error setting up logger") - } - - cert, _, err := getOriginCertFromContext(c, logger) - if err != nil { - return err - } - client, err := newTunnelstoreClient(c, cert, logger) + sc, err := newSubcommandContext(c) if err != nil { return err } - for i := 0; i < c.NArg(); i++ { - tunnelID, err := uuid.Parse(c.Args().Get(i)) - if err != nil { - logger.Errorf("Failed to parse argument %d as tunnelID, error :%v", i, err) - continue - } - logger.Infof("Cleanup connection for tunnel %s", tunnelID) - if err := client.CleanupConnections(tunnelID); err != nil { - logger.Errorf("Error cleaning up connections for tunnel %v, error :%v", tunnelID, err) - } + tunnelIDs, err := tunnelIDsFromArgs(c) + if err != nil { + return err } - return nil + return sc.cleanupConnections(tunnelIDs) } func buildRouteCommand() *cli.Command { return &cli.Command{ Name: "route", - Action: cliutil.ErrorHandler(routeTunnel), + Action: cliutil.ErrorHandler(routeCommand), Usage: "Define what hostname or load balancer can route to this tunnel", Description: `The route defines what hostname or load balancer can route to this tunnel. To route a hostname: cloudflared tunnel route dns @@ -552,57 +379,6 @@ func buildRouteCommand() *cli.Command { } } -func routeTunnel(c *cli.Context) error { - if c.NArg() < 2 { - return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`) - } - const tunnelIDIndex = 1 - tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex)) - if err != nil { - return errors.Wrap(err, "error parsing tunnel ID") - } - - logger, err := createLogger(c, false) - if err != nil { - return errors.Wrap(err, "error setting up logger") - } - - routeType := c.Args().First() - var route tunnelstore.Route - switch routeType { - case "dns": - route, err = dnsRouteFromArg(c, tunnelID) - if err != nil { - return err - } - case "lb": - route, err = lbRouteFromArg(c, tunnelID, logger) - if err != nil { - return err - } - default: - return cliutil.UsageError("%s is not a recognized route type. Supported route types are dns and lb", routeType) - } - - cert, _, err := getOriginCertFromContext(c, logger) - if err != nil { - return err - } - - client, err := newTunnelstoreClient(c, cert, logger) - if err != nil { - return err - } - - if err := client.RouteTunnel(tunnelID, route); err != nil { - return errors.Wrap(err, "Failed to route tunnel") - } - - logger.Infof(route.SuccessSummary()) - - return nil -} - func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) { const ( userHostnameIndex = 2 @@ -618,7 +394,7 @@ func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, err return tunnelstore.NewDNSRoute(userHostname), nil } -func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (tunnelstore.Route, error) { +func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) { const ( lbNameIndex = 2 lbPoolIndex = 3 @@ -633,9 +409,51 @@ func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) ( } lbPool := c.Args().Get(lbPoolIndex) if lbPool == "" { - lbPool = fmt.Sprintf("tunnel:%v", tunnelID) - logger.Infof("Generate pool name %s", lbPool) + lbPool = defaultPoolName(tunnelID) } return tunnelstore.NewLBRoute(lbName, lbPool), nil } + +func routeCommand(c *cli.Context) error { + if c.NArg() < 2 { + return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`) + } + const tunnelIDIndex = 1 + tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex)) + if err != nil { + return errors.Wrap(err, "error parsing tunnel ID") + } + + sc, err := newSubcommandContext(c) + if err != nil { + return err + } + + routeType := c.Args().First() + var r tunnelstore.Route + switch routeType { + case "dns": + r, err = dnsRouteFromArg(c, tunnelID) + if err != nil { + return err + } + case "lb": + r, err = lbRouteFromArg(c, tunnelID) + if err != nil { + return err + } + default: + return cliutil.UsageError("%s is not a recognized route type. Supported route types are dns and lb", routeType) + } + + if err := sc.route(tunnelID, r); err != nil { + return err + } + sc.logger.Infof(r.SuccessSummary()) + return nil +} + +func defaultPoolName(tunnelID uuid.UUID) string { + return fmt.Sprintf("tunnel:%v", tunnelID) +} diff --git a/tunnelstore/client.go b/tunnelstore/client.go index 991e59ec..f211a02b 100644 --- a/tunnelstore/client.go +++ b/tunnelstore/client.go @@ -117,7 +117,7 @@ type Client interface { CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) DeleteTunnel(tunnelID uuid.UUID) error - ListTunnels(filter *Filter) ([]Tunnel, error) + ListTunnels(filter *Filter) ([]*Tunnel, error) CleanupConnections(tunnelID uuid.UUID) error RouteTunnel(tunnelID uuid.UUID, route Route) error } @@ -223,7 +223,7 @@ func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error { return r.statusCodeToError("delete tunnel", resp) } -func (r *RESTClient) ListTunnels(filter *Filter) ([]Tunnel, error) { +func (r *RESTClient) ListTunnels(filter *Filter) ([]*Tunnel, error) { endpoint := r.baseEndpoints.accountLevel endpoint.RawQuery = filter.encode() resp, err := r.sendRequest("GET", endpoint, nil) @@ -233,7 +233,7 @@ func (r *RESTClient) ListTunnels(filter *Filter) ([]Tunnel, error) { defer resp.Body.Close() if resp.StatusCode == http.StatusOK { - var tunnels []Tunnel + var tunnels []*Tunnel if err := json.NewDecoder(resp.Body).Decode(&tunnels); err != nil { return nil, errors.Wrap(err, "failed to decode response") } diff --git a/tunnelstore/filter.go b/tunnelstore/filter.go index 1c24e63d..141997f7 100644 --- a/tunnelstore/filter.go +++ b/tunnelstore/filter.go @@ -25,7 +25,7 @@ func (f *Filter) ByName(name string) { f.queryParams.Set("name", name) } -func (f *Filter) ShowDeleted() { +func (f *Filter) NoDeleted() { f.queryParams.Set("is_deleted", "false") }