TUN-3084: Generate and store tunnel_secret value during tunnel creation
This commit is contained in:
parent
8f75feac94
commit
3ec500bdbb
|
@ -1,9 +1,11 @@
|
||||||
package tunnel
|
package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -15,6 +17,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/certutil"
|
"github.com/cloudflare/cloudflared/certutil"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflared/tunnelstore"
|
"github.com/cloudflare/cloudflared/tunnelstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,6 +42,13 @@ func buildCreateCommand() *cli.Command {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateTunnelSecret as an array of 32 bytes using secure random number generator
|
||||||
|
func generateTunnelSecret() ([]byte, error) {
|
||||||
|
randomBytes := make([]byte, 32)
|
||||||
|
_, err := rand.Read(randomBytes)
|
||||||
|
return randomBytes, err
|
||||||
|
}
|
||||||
|
|
||||||
func createTunnel(c *cli.Context) error {
|
func createTunnel(c *cli.Context) error {
|
||||||
if c.NArg() != 1 {
|
if c.NArg() != 1 {
|
||||||
return cliutil.UsageError(`"cloudflared tunnel create" requires exactly 1 argument, the name of tunnel to create.`)
|
return cliutil.UsageError(`"cloudflared tunnel create" requires exactly 1 argument, the name of tunnel to create.`)
|
||||||
|
@ -50,16 +60,40 @@ func createTunnel(c *cli.Context) error {
|
||||||
return errors.Wrap(err, "error setting up logger")
|
return errors.Wrap(err, "error setting up logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := newTunnelstoreClient(c, logger)
|
tunnelSecret, err := generateTunnelSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnel, err := client.CreateTunnel(name)
|
originCertPath, err := findOriginCert(c, logger)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Error locating origin cert")
|
||||||
|
}
|
||||||
|
cert, err := getOriginCertFromContext(originCertPath, logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
client := newTunnelstoreClient(c, cert, logger)
|
||||||
|
|
||||||
|
tunnel, err := client.CreateTunnel(name, tunnelSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Error creating a new tunnel")
|
return errors.Wrap(err, "Error creating a new tunnel")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if writeFileErr := writeTunnelCredentials(tunnel.ID, cert.AccountID, originCertPath, tunnelSecret, logger); err != nil {
|
||||||
|
var errorLines []string
|
||||||
|
errorLines = append(errorLines, fmt.Sprintf("Your tunnel '%v' was created with ID %v. However, cloudflared couldn't write to the tunnel credentials file at %v.json.", tunnel.Name, tunnel.ID, tunnel.ID))
|
||||||
|
errorLines = append(errorLines, fmt.Sprintf("The file-writing error is: %v", writeFileErr))
|
||||||
|
if deleteErr := client.DeleteTunnel(tunnel.ID); deleteErr != nil {
|
||||||
|
errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID))
|
||||||
|
errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr))
|
||||||
|
} else {
|
||||||
|
errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the tunnelfile"))
|
||||||
|
}
|
||||||
|
errorMsg := strings.Join(errorLines, "\n")
|
||||||
|
return errors.New(errorMsg)
|
||||||
|
}
|
||||||
|
|
||||||
if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" {
|
if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" {
|
||||||
return renderOutput(outputFormat, &tunnel)
|
return renderOutput(outputFormat, &tunnel)
|
||||||
}
|
}
|
||||||
|
@ -68,6 +102,34 @@ func createTunnel(c *cli.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func tunnelFilePath(tunnelID, originCertPath string) (string, error) {
|
||||||
|
fileName := fmt.Sprintf("%v.json", tunnelID)
|
||||||
|
return filepath.Clean(fmt.Sprintf("%v/../%v", originCertPath, fileName)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error {
|
||||||
|
filePath, err := tunnelFilePath(tunnelID, originCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Infof("Writing tunnel credentials to %v. cloudflared chose this file based on where your origin certificate was found.", filePath)
|
||||||
|
logger.Infof("Keep this file secret. To revoke these credentials, delete the tunnel.")
|
||||||
|
file, err := os.Create(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, fmt.Sprintf("Unable to write to %s", filePath))
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
body, err := json.Marshal(pogs.TunnelAuth{
|
||||||
|
AccountTag: accountID,
|
||||||
|
TunnelSecret: tunnelSecret,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Unable to marshal tunnel credentials to JSON")
|
||||||
|
}
|
||||||
|
fmt.Fprintf(file, "%d", body)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func buildListCommand() *cli.Command {
|
func buildListCommand() *cli.Command {
|
||||||
return &cli.Command{
|
return &cli.Command{
|
||||||
Name: "list",
|
Name: "list",
|
||||||
|
@ -85,10 +147,15 @@ func listTunnels(c *cli.Context) error {
|
||||||
return errors.Wrap(err, "error setting up logger")
|
return errors.Wrap(err, "error setting up logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := newTunnelstoreClient(c, logger)
|
originCertPath, err := findOriginCert(c, logger)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Error locating origin cert")
|
||||||
|
}
|
||||||
|
cert, err := getOriginCertFromContext(originCertPath, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
client := newTunnelstoreClient(c, cert, logger)
|
||||||
|
|
||||||
tunnels, err := client.ListTunnels()
|
tunnels, err := client.ListTunnels()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -155,10 +222,15 @@ func deleteTunnel(c *cli.Context) error {
|
||||||
return errors.Wrap(err, "error setting up logger")
|
return errors.Wrap(err, "error setting up logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := newTunnelstoreClient(c, logger)
|
originCertPath, err := findOriginCert(c, logger)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Error locating origin cert")
|
||||||
|
}
|
||||||
|
cert, err := getOriginCertFromContext(originCertPath, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
client := newTunnelstoreClient(c, cert, logger)
|
||||||
|
|
||||||
if err := client.DeleteTunnel(id); err != nil {
|
if err := client.DeleteTunnel(id); err != nil {
|
||||||
return errors.Wrapf(err, "Error deleting tunnel %s", id)
|
return errors.Wrapf(err, "Error deleting tunnel %s", id)
|
||||||
|
@ -180,11 +252,12 @@ func renderOutput(format string, v interface{}) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTunnelstoreClient(c *cli.Context, logger logger.Service) (tunnelstore.Client, error) {
|
func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client {
|
||||||
originCertPath, err := findOriginCert(c, logger)
|
client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger)
|
||||||
if err != nil {
|
return client
|
||||||
return nil, errors.Wrap(err, "Error locating origin cert")
|
}
|
||||||
}
|
|
||||||
|
func getOriginCertFromContext(originCertPath string, logger logger.Service) (*certutil.OriginCert, error) {
|
||||||
|
|
||||||
blocks, err := readOriginCert(originCertPath, logger)
|
blocks, err := readOriginCert(originCertPath, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -199,8 +272,5 @@ func newTunnelstoreClient(c *cli.Context, logger logger.Service) (tunnelstore.Cl
|
||||||
if cert.AccountID == "" {
|
if cert.AccountID == "" {
|
||||||
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
|
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
|
||||||
}
|
}
|
||||||
|
return cert, nil
|
||||||
client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger)
|
|
||||||
|
|
||||||
return client, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/tunnelstore"
|
"github.com/cloudflare/cloudflared/tunnelstore"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_fmtConnections(t *testing.T) {
|
func Test_fmtConnections(t *testing.T) {
|
||||||
|
@ -69,3 +70,10 @@ func Test_fmtConnections(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTunnelfilePath(t *testing.T) {
|
||||||
|
actual, err := tunnelFilePath("tunnel", "~/.cloudflared/cert.pem")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
expected := "~/.cloudflared/tunnel.json"
|
||||||
|
assert.Equal(t, expected, actual)
|
||||||
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ type Connection struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
CreateTunnel(name string) (*Tunnel, error)
|
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
|
||||||
GetTunnel(id string) (*Tunnel, error)
|
GetTunnel(id string) (*Tunnel, error)
|
||||||
DeleteTunnel(id string) error
|
DeleteTunnel(id string) error
|
||||||
ListTunnels() ([]Tunnel, error)
|
ListTunnels() ([]Tunnel, error)
|
||||||
|
@ -75,14 +75,16 @@ func NewRESTClient(baseURL string, accountTag string, authToken string, logger l
|
||||||
|
|
||||||
type newTunnel struct {
|
type newTunnel struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
TunnelSecret []byte `json:"tunnel_secret"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RESTClient) CreateTunnel(name string) (*Tunnel, error) {
|
func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return nil, errors.New("tunnel name required")
|
return nil, errors.New("tunnel name required")
|
||||||
}
|
}
|
||||||
body, err := json.Marshal(&newTunnel{
|
body, err := json.Marshal(&newTunnel{
|
||||||
Name: name,
|
Name: name,
|
||||||
|
TunnelSecret: tunnelSecret,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Failed to serialize new tunnel request")
|
return nil, errors.Wrap(err, "Failed to serialize new tunnel request")
|
||||||
|
|
Loading…
Reference in New Issue