TUN-3084: Generate and store tunnel_secret value during tunnel creation

This commit is contained in:
Adam Chalmers 2020-06-15 13:33:41 -05:00
parent 8f75feac94
commit 3ec500bdbb
3 changed files with 97 additions and 17 deletions

View File

@ -1,9 +1,11 @@
package tunnel package tunnel
import ( import (
"crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"path/filepath"
"sort" "sort"
"strings" "strings"
"time" "time"
@ -15,6 +17,7 @@ import (
"github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/certutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstore" "github.com/cloudflare/cloudflared/tunnelstore"
) )
@ -39,6 +42,13 @@ func buildCreateCommand() *cli.Command {
} }
} }
// generateTunnelSecret as an array of 32 bytes using secure random number generator
func generateTunnelSecret() ([]byte, error) {
randomBytes := make([]byte, 32)
_, err := rand.Read(randomBytes)
return randomBytes, err
}
func createTunnel(c *cli.Context) error { func createTunnel(c *cli.Context) error {
if c.NArg() != 1 { if c.NArg() != 1 {
return cliutil.UsageError(`"cloudflared tunnel create" requires exactly 1 argument, the name of tunnel to create.`) return cliutil.UsageError(`"cloudflared tunnel create" requires exactly 1 argument, the name of tunnel to create.`)
@ -50,16 +60,40 @@ func createTunnel(c *cli.Context) error {
return errors.Wrap(err, "error setting up logger") return errors.Wrap(err, "error setting up logger")
} }
client, err := newTunnelstoreClient(c, logger) tunnelSecret, err := generateTunnelSecret()
if err != nil { if err != nil {
return err return err
} }
tunnel, err := client.CreateTunnel(name) originCertPath, err := findOriginCert(c, logger)
if err != nil {
return errors.Wrap(err, "Error locating origin cert")
}
cert, err := getOriginCertFromContext(originCertPath, logger)
if err != nil {
return err
}
client := newTunnelstoreClient(c, cert, logger)
tunnel, err := client.CreateTunnel(name, tunnelSecret)
if err != nil { if err != nil {
return errors.Wrap(err, "Error creating a new tunnel") return errors.Wrap(err, "Error creating a new tunnel")
} }
if writeFileErr := writeTunnelCredentials(tunnel.ID, cert.AccountID, originCertPath, tunnelSecret, logger); err != nil {
var errorLines []string
errorLines = append(errorLines, fmt.Sprintf("Your tunnel '%v' was created with ID %v. However, cloudflared couldn't write to the tunnel credentials file at %v.json.", tunnel.Name, tunnel.ID, tunnel.ID))
errorLines = append(errorLines, fmt.Sprintf("The file-writing error is: %v", writeFileErr))
if deleteErr := client.DeleteTunnel(tunnel.ID); deleteErr != nil {
errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID))
errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr))
} else {
errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the tunnelfile"))
}
errorMsg := strings.Join(errorLines, "\n")
return errors.New(errorMsg)
}
if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" { if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" {
return renderOutput(outputFormat, &tunnel) return renderOutput(outputFormat, &tunnel)
} }
@ -68,6 +102,34 @@ func createTunnel(c *cli.Context) error {
return nil return nil
} }
func tunnelFilePath(tunnelID, originCertPath string) (string, error) {
fileName := fmt.Sprintf("%v.json", tunnelID)
return filepath.Clean(fmt.Sprintf("%v/../%v", originCertPath, fileName)), nil
}
func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error {
filePath, err := tunnelFilePath(tunnelID, originCertPath)
if err != nil {
return err
}
logger.Infof("Writing tunnel credentials to %v. cloudflared chose this file based on where your origin certificate was found.", filePath)
logger.Infof("Keep this file secret. To revoke these credentials, delete the tunnel.")
file, err := os.Create(filePath)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("Unable to write to %s", filePath))
}
defer file.Close()
body, err := json.Marshal(pogs.TunnelAuth{
AccountTag: accountID,
TunnelSecret: tunnelSecret,
})
if err != nil {
return errors.Wrap(err, "Unable to marshal tunnel credentials to JSON")
}
fmt.Fprintf(file, "%d", body)
return nil
}
func buildListCommand() *cli.Command { func buildListCommand() *cli.Command {
return &cli.Command{ return &cli.Command{
Name: "list", Name: "list",
@ -85,10 +147,15 @@ func listTunnels(c *cli.Context) error {
return errors.Wrap(err, "error setting up logger") return errors.Wrap(err, "error setting up logger")
} }
client, err := newTunnelstoreClient(c, logger) originCertPath, err := findOriginCert(c, logger)
if err != nil {
return errors.Wrap(err, "Error locating origin cert")
}
cert, err := getOriginCertFromContext(originCertPath, logger)
if err != nil { if err != nil {
return err return err
} }
client := newTunnelstoreClient(c, cert, logger)
tunnels, err := client.ListTunnels() tunnels, err := client.ListTunnels()
if err != nil { if err != nil {
@ -155,10 +222,15 @@ func deleteTunnel(c *cli.Context) error {
return errors.Wrap(err, "error setting up logger") return errors.Wrap(err, "error setting up logger")
} }
client, err := newTunnelstoreClient(c, logger) originCertPath, err := findOriginCert(c, logger)
if err != nil {
return errors.Wrap(err, "Error locating origin cert")
}
cert, err := getOriginCertFromContext(originCertPath, logger)
if err != nil { if err != nil {
return err return err
} }
client := newTunnelstoreClient(c, cert, logger)
if err := client.DeleteTunnel(id); err != nil { if err := client.DeleteTunnel(id); err != nil {
return errors.Wrapf(err, "Error deleting tunnel %s", id) return errors.Wrapf(err, "Error deleting tunnel %s", id)
@ -180,11 +252,12 @@ func renderOutput(format string, v interface{}) error {
} }
} }
func newTunnelstoreClient(c *cli.Context, logger logger.Service) (tunnelstore.Client, error) { func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client {
originCertPath, err := findOriginCert(c, logger) client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger)
if err != nil { return client
return nil, errors.Wrap(err, "Error locating origin cert") }
}
func getOriginCertFromContext(originCertPath string, logger logger.Service) (*certutil.OriginCert, error) {
blocks, err := readOriginCert(originCertPath, logger) blocks, err := readOriginCert(originCertPath, logger)
if err != nil { if err != nil {
@ -199,8 +272,5 @@ func newTunnelstoreClient(c *cli.Context, logger logger.Service) (tunnelstore.Cl
if cert.AccountID == "" { if cert.AccountID == "" {
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
} }
return cert, nil
client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger)
return client, nil
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/cloudflare/cloudflared/tunnelstore" "github.com/cloudflare/cloudflared/tunnelstore"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert"
) )
func Test_fmtConnections(t *testing.T) { func Test_fmtConnections(t *testing.T) {
@ -69,3 +70,10 @@ func Test_fmtConnections(t *testing.T) {
}) })
} }
} }
func TestTunnelfilePath(t *testing.T) {
actual, err := tunnelFilePath("tunnel", "~/.cloudflared/cert.pem")
assert.NoError(t, err)
expected := "~/.cloudflared/tunnel.json"
assert.Equal(t, expected, actual)
}

View File

@ -39,7 +39,7 @@ type Connection struct {
} }
type Client interface { type Client interface {
CreateTunnel(name string) (*Tunnel, error) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
GetTunnel(id string) (*Tunnel, error) GetTunnel(id string) (*Tunnel, error)
DeleteTunnel(id string) error DeleteTunnel(id string) error
ListTunnels() ([]Tunnel, error) ListTunnels() ([]Tunnel, error)
@ -75,14 +75,16 @@ func NewRESTClient(baseURL string, accountTag string, authToken string, logger l
type newTunnel struct { type newTunnel struct {
Name string `json:"name"` Name string `json:"name"`
TunnelSecret []byte `json:"tunnel_secret"`
} }
func (r *RESTClient) CreateTunnel(name string) (*Tunnel, error) { func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) {
if name == "" { if name == "" {
return nil, errors.New("tunnel name required") return nil, errors.New("tunnel name required")
} }
body, err := json.Marshal(&newTunnel{ body, err := json.Marshal(&newTunnel{
Name: name, Name: name,
TunnelSecret: tunnelSecret,
}) })
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Failed to serialize new tunnel request") return nil, errors.Wrap(err, "Failed to serialize new tunnel request")