diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 306ce13b..07740789 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -1,9 +1,11 @@ package tunnel import ( + "crypto/rand" "encoding/json" "fmt" "os" + "path/filepath" "sort" "strings" "time" @@ -15,6 +17,7 @@ import ( "github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelstore" ) @@ -39,6 +42,13 @@ func buildCreateCommand() *cli.Command { } } +// generateTunnelSecret as an array of 32 bytes using secure random number generator +func generateTunnelSecret() ([]byte, error) { + randomBytes := make([]byte, 32) + _, err := rand.Read(randomBytes) + return randomBytes, err +} + func createTunnel(c *cli.Context) error { if c.NArg() != 1 { return cliutil.UsageError(`"cloudflared tunnel create" requires exactly 1 argument, the name of tunnel to create.`) @@ -50,16 +60,40 @@ func createTunnel(c *cli.Context) error { return errors.Wrap(err, "error setting up logger") } - client, err := newTunnelstoreClient(c, logger) + tunnelSecret, err := generateTunnelSecret() if err != nil { return err } - tunnel, err := client.CreateTunnel(name) + originCertPath, err := findOriginCert(c, logger) + if err != nil { + return errors.Wrap(err, "Error locating origin cert") + } + cert, err := getOriginCertFromContext(originCertPath, logger) + if err != nil { + return err + } + client := newTunnelstoreClient(c, cert, logger) + + tunnel, err := client.CreateTunnel(name, tunnelSecret) if err != nil { return errors.Wrap(err, "Error creating a new tunnel") } + if writeFileErr := writeTunnelCredentials(tunnel.ID, cert.AccountID, originCertPath, tunnelSecret, logger); err != nil { + var errorLines []string + errorLines = append(errorLines, fmt.Sprintf("Your tunnel '%v' was created with ID %v. However, cloudflared couldn't write to the tunnel credentials file at %v.json.", tunnel.Name, tunnel.ID, tunnel.ID)) + errorLines = append(errorLines, fmt.Sprintf("The file-writing error is: %v", writeFileErr)) + if deleteErr := client.DeleteTunnel(tunnel.ID); deleteErr != nil { + errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID)) + errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr)) + } else { + errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the tunnelfile")) + } + errorMsg := strings.Join(errorLines, "\n") + return errors.New(errorMsg) + } + if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" { return renderOutput(outputFormat, &tunnel) } @@ -68,6 +102,34 @@ func createTunnel(c *cli.Context) error { return nil } +func tunnelFilePath(tunnelID, originCertPath string) (string, error) { + fileName := fmt.Sprintf("%v.json", tunnelID) + return filepath.Clean(fmt.Sprintf("%v/../%v", originCertPath, fileName)), nil +} + +func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error { + filePath, err := tunnelFilePath(tunnelID, originCertPath) + if err != nil { + return err + } + logger.Infof("Writing tunnel credentials to %v. cloudflared chose this file based on where your origin certificate was found.", filePath) + logger.Infof("Keep this file secret. To revoke these credentials, delete the tunnel.") + file, err := os.Create(filePath) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Unable to write to %s", filePath)) + } + defer file.Close() + body, err := json.Marshal(pogs.TunnelAuth{ + AccountTag: accountID, + TunnelSecret: tunnelSecret, + }) + if err != nil { + return errors.Wrap(err, "Unable to marshal tunnel credentials to JSON") + } + fmt.Fprintf(file, "%d", body) + return nil +} + func buildListCommand() *cli.Command { return &cli.Command{ Name: "list", @@ -85,10 +147,15 @@ func listTunnels(c *cli.Context) error { return errors.Wrap(err, "error setting up logger") } - client, err := newTunnelstoreClient(c, logger) + originCertPath, err := findOriginCert(c, logger) + if err != nil { + return errors.Wrap(err, "Error locating origin cert") + } + cert, err := getOriginCertFromContext(originCertPath, logger) if err != nil { return err } + client := newTunnelstoreClient(c, cert, logger) tunnels, err := client.ListTunnels() if err != nil { @@ -155,10 +222,15 @@ func deleteTunnel(c *cli.Context) error { return errors.Wrap(err, "error setting up logger") } - client, err := newTunnelstoreClient(c, logger) + originCertPath, err := findOriginCert(c, logger) + if err != nil { + return errors.Wrap(err, "Error locating origin cert") + } + cert, err := getOriginCertFromContext(originCertPath, logger) if err != nil { return err } + client := newTunnelstoreClient(c, cert, logger) if err := client.DeleteTunnel(id); err != nil { return errors.Wrapf(err, "Error deleting tunnel %s", id) @@ -180,11 +252,12 @@ func renderOutput(format string, v interface{}) error { } } -func newTunnelstoreClient(c *cli.Context, logger logger.Service) (tunnelstore.Client, error) { - originCertPath, err := findOriginCert(c, logger) - if err != nil { - return nil, errors.Wrap(err, "Error locating origin cert") - } +func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client { + client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger) + return client +} + +func getOriginCertFromContext(originCertPath string, logger logger.Service) (*certutil.OriginCert, error) { blocks, err := readOriginCert(originCertPath, logger) if err != nil { @@ -199,8 +272,5 @@ func newTunnelstoreClient(c *cli.Context, logger logger.Service) (tunnelstore.Cl if cert.AccountID == "" { return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) } - - client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger) - - return client, nil + return cert, nil } diff --git a/cmd/cloudflared/tunnel/subcommands_test.go b/cmd/cloudflared/tunnel/subcommands_test.go index 202aeeb4..36b86daf 100644 --- a/cmd/cloudflared/tunnel/subcommands_test.go +++ b/cmd/cloudflared/tunnel/subcommands_test.go @@ -6,6 +6,7 @@ import ( "github.com/cloudflare/cloudflared/tunnelstore" "github.com/google/uuid" + "github.com/stretchr/testify/assert" ) func Test_fmtConnections(t *testing.T) { @@ -69,3 +70,10 @@ func Test_fmtConnections(t *testing.T) { }) } } + +func TestTunnelfilePath(t *testing.T) { + actual, err := tunnelFilePath("tunnel", "~/.cloudflared/cert.pem") + assert.NoError(t, err) + expected := "~/.cloudflared/tunnel.json" + assert.Equal(t, expected, actual) +} diff --git a/tunnelstore/client.go b/tunnelstore/client.go index 58d332f8..2730cdb1 100644 --- a/tunnelstore/client.go +++ b/tunnelstore/client.go @@ -39,7 +39,7 @@ type Connection struct { } type Client interface { - CreateTunnel(name string) (*Tunnel, error) + CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) GetTunnel(id string) (*Tunnel, error) DeleteTunnel(id string) error ListTunnels() ([]Tunnel, error) @@ -74,15 +74,17 @@ func NewRESTClient(baseURL string, accountTag string, authToken string, logger l } type newTunnel struct { - Name string `json:"name"` + Name string `json:"name"` + TunnelSecret []byte `json:"tunnel_secret"` } -func (r *RESTClient) CreateTunnel(name string) (*Tunnel, error) { +func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) { if name == "" { return nil, errors.New("tunnel name required") } body, err := json.Marshal(&newTunnel{ - Name: name, + Name: name, + TunnelSecret: tunnelSecret, }) if err != nil { return nil, errors.Wrap(err, "Failed to serialize new tunnel request")