From 87e06100df6d41ee9ebe1b69575c149aa68f832a Mon Sep 17 00:00:00 2001 From: cthuang Date: Thu, 2 Jul 2020 17:31:12 +0800 Subject: [PATCH] TUN-3131: Allow user to specify tunnel credentials path, and remove it in tunnel delete command --- cmd/cloudflared/tunnel/subcommands.go | 78 +++++++++++++++++++--- cmd/cloudflared/tunnel/subcommands_test.go | 10 ++- 2 files changed, 75 insertions(+), 13 deletions(-) diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 837f5039..1a85f177 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -12,18 +12,24 @@ import ( "time" "github.com/google/uuid" + "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "gopkg.in/urfave/cli.v2" "gopkg.in/yaml.v2" "github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelstore" ) +const ( + credFileFlagAlias = "cred-file" +) + var ( showDeletedFlag = &cli.BoolFlag{ Name: "show-deleted", @@ -43,6 +49,11 @@ var ( "overwrite the previous tunnel. If you want to use a single hostname with multiple " + "tunnels, you can do so with Cloudflare's Load Balancer product.", } + credentialsFileFlag = &cli.StringFlag{ + Name: "credentials-file", + Aliases: []string{credFileFlagAlias}, + Usage: "File path of tunnel credentials", + } ) const hideSubcommands = true @@ -118,13 +129,15 @@ func createTunnel(c *cli.Context) error { return nil } -func tunnelFilePath(tunnelID, originCertPath string) (string, error) { +func tunnelFilePath(tunnelID, directory string) (string, error) { fileName := fmt.Sprintf("%v.json", tunnelID) - return filepath.Clean(fmt.Sprintf("%v/../%v", originCertPath, fileName)), nil + filePath := filepath.Clean(fmt.Sprintf("%s/%s", directory, fileName)) + return homedir.Expand(filePath) } func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error { - filePath, err := tunnelFilePath(tunnelID, originCertPath) + originCertDir := filepath.Dir(originCertPath) + filePath, err := tunnelFilePath(tunnelID, originCertDir) if err != nil { return err } @@ -140,8 +153,8 @@ func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSe return ioutil.WriteFile(filePath, body, 400) } -func readTunnelCredentials(tunnelID, originCertPath string) (*pogs.TunnelAuth, error) { - filePath, err := tunnelFilePath(tunnelID, originCertPath) +func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Service) (*pogs.TunnelAuth, error) { + filePath, err := tunnelCredentialsPath(c, tunnelID, logger) if err != nil { return nil, err } @@ -157,6 +170,42 @@ func readTunnelCredentials(tunnelID, originCertPath string) (*pogs.TunnelAuth, e return &auth, nil } +func tunnelCredentialsPath(c *cli.Context, tunnelID string, logger logger.Service) (string, error) { + if filePath := c.String("credentials-file"); filePath != "" { + if validFilePath(filePath) { + return filePath, nil + } + } + + // Fallback to look for tunnel credentials in the origin cert directory + if originCertPath, err := findOriginCert(c, logger); err == nil { + originCertDir := filepath.Dir(originCertPath) + if filePath, err := tunnelFilePath(tunnelID, originCertDir); err == nil { + if validFilePath(filePath) { + return filePath, nil + } + } + } + + // Last resort look under default config directories + for _, configDir := range config.DefaultConfigDirs { + if filePath, err := tunnelFilePath(tunnelID, configDir); err == nil { + if validFilePath(filePath) { + return filePath, nil + } + } + } + return "", fmt.Errorf("Tunnel credentials file not found") +} + +func validFilePath(path string) bool { + fileStat, err := os.Stat(path) + if err != nil { + return false + } + return !fileStat.IsDir() +} + func buildListCommand() *cli.Command { return &cli.Command{ Name: "list", @@ -246,6 +295,7 @@ func buildDeleteCommand() *cli.Command { Usage: "Delete existing tunnel with given ID", ArgsUsage: "TUNNEL-ID", Hidden: hideSubcommands, + Flags: []cli.Flag{credentialsFileFlag}, } } @@ -274,6 +324,16 @@ func deleteTunnel(c *cli.Context) error { return errors.Wrapf(err, "Error deleting tunnel %s", id) } + tunnelCredentialsPath, err := tunnelCredentialsPath(c, id, logger) + if err != nil { + logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err) + return nil + } + + if err = os.Remove(tunnelCredentialsPath); err != nil { + logger.Infof("Cannot delete tunnel credentials, error: %v. Please delete the file manually", err) + } + return nil } @@ -320,7 +380,7 @@ func buildRunCommand() *cli.Command { Usage: "Proxy a local web server by running the given tunnel", ArgsUsage: "TUNNEL-ID", Hidden: hideSubcommands, - Flags: []cli.Flag{forceFlag}, + Flags: []cli.Flag{forceFlag, credentialsFileFlag}, } } @@ -339,11 +399,7 @@ func runTunnel(c *cli.Context) error { return errors.Wrap(err, "error setting up logger") } - originCertPath, err := findOriginCert(c, logger) - if err != nil { - return errors.Wrap(err, "Error locating origin cert") - } - credentials, err := readTunnelCredentials(id, originCertPath) + credentials, err := readTunnelCredentials(c, id, logger) if err != nil { return err } diff --git a/cmd/cloudflared/tunnel/subcommands_test.go b/cmd/cloudflared/tunnel/subcommands_test.go index 36b86daf..a8b4ebe3 100644 --- a/cmd/cloudflared/tunnel/subcommands_test.go +++ b/cmd/cloudflared/tunnel/subcommands_test.go @@ -1,9 +1,12 @@ package tunnel import ( + "fmt" + "path/filepath" "testing" "github.com/cloudflare/cloudflared/tunnelstore" + "github.com/mitchellh/go-homedir" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -72,8 +75,11 @@ func Test_fmtConnections(t *testing.T) { } func TestTunnelfilePath(t *testing.T) { - actual, err := tunnelFilePath("tunnel", "~/.cloudflared/cert.pem") + originCertDir := filepath.Dir("~/.cloudflared/cert.pem") + actual, err := tunnelFilePath("tunnel", originCertDir) assert.NoError(t, err) - expected := "~/.cloudflared/tunnel.json" + homeDir, err := homedir.Dir() + assert.NoError(t, err) + expected := fmt.Sprintf("%s/.cloudflared/tunnel.json", homeDir) assert.Equal(t, expected, actual) }