TUN-3084: Generate and store tunnel_secret value during tunnel creation
This commit is contained in:
		
							parent
							
								
									8f75feac94
								
							
						
					
					
						commit
						3ec500bdbb
					
				|  | @ -1,9 +1,11 @@ | |||
| package tunnel | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | @ -15,6 +17,7 @@ import ( | |||
| 	"github.com/cloudflare/cloudflared/certutil" | ||||
| 	"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" | ||||
| 	"github.com/cloudflare/cloudflared/logger" | ||||
| 	"github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||
| 	"github.com/cloudflare/cloudflared/tunnelstore" | ||||
| ) | ||||
| 
 | ||||
|  | @ -39,6 +42,13 @@ func buildCreateCommand() *cli.Command { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| // generateTunnelSecret as an array of 32 bytes using secure random number generator
 | ||||
| func generateTunnelSecret() ([]byte, error) { | ||||
| 	randomBytes := make([]byte, 32) | ||||
| 	_, err := rand.Read(randomBytes) | ||||
| 	return randomBytes, err | ||||
| } | ||||
| 
 | ||||
| func createTunnel(c *cli.Context) error { | ||||
| 	if c.NArg() != 1 { | ||||
| 		return cliutil.UsageError(`"cloudflared tunnel create" requires exactly 1 argument, the name of tunnel to create.`) | ||||
|  | @ -50,16 +60,40 @@ func createTunnel(c *cli.Context) error { | |||
| 		return errors.Wrap(err, "error setting up logger") | ||||
| 	} | ||||
| 
 | ||||
| 	client, err := newTunnelstoreClient(c, logger) | ||||
| 	tunnelSecret, err := generateTunnelSecret() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	tunnel, err := client.CreateTunnel(name) | ||||
| 	originCertPath, err := findOriginCert(c, logger) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "Error locating origin cert") | ||||
| 	} | ||||
| 	cert, err := getOriginCertFromContext(originCertPath, logger) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	client := newTunnelstoreClient(c, cert, logger) | ||||
| 
 | ||||
| 	tunnel, err := client.CreateTunnel(name, tunnelSecret) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "Error creating a new tunnel") | ||||
| 	} | ||||
| 
 | ||||
| 	if writeFileErr := writeTunnelCredentials(tunnel.ID, cert.AccountID, originCertPath, tunnelSecret, logger); err != nil { | ||||
| 		var errorLines []string | ||||
| 		errorLines = append(errorLines, fmt.Sprintf("Your tunnel '%v' was created with ID %v. However, cloudflared couldn't write to the tunnel credentials file at %v.json.", tunnel.Name, tunnel.ID, tunnel.ID)) | ||||
| 		errorLines = append(errorLines, fmt.Sprintf("The file-writing error is: %v", writeFileErr)) | ||||
| 		if deleteErr := client.DeleteTunnel(tunnel.ID); deleteErr != nil { | ||||
| 			errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID)) | ||||
| 			errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr)) | ||||
| 		} else { | ||||
| 			errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the tunnelfile")) | ||||
| 		} | ||||
| 		errorMsg := strings.Join(errorLines, "\n") | ||||
| 		return errors.New(errorMsg) | ||||
| 	} | ||||
| 
 | ||||
| 	if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" { | ||||
| 		return renderOutput(outputFormat, &tunnel) | ||||
| 	} | ||||
|  | @ -68,6 +102,34 @@ func createTunnel(c *cli.Context) error { | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func tunnelFilePath(tunnelID, originCertPath string) (string, error) { | ||||
| 	fileName := fmt.Sprintf("%v.json", tunnelID) | ||||
| 	return filepath.Clean(fmt.Sprintf("%v/../%v", originCertPath, fileName)), nil | ||||
| } | ||||
| 
 | ||||
| func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error { | ||||
| 	filePath, err := tunnelFilePath(tunnelID, originCertPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	logger.Infof("Writing tunnel credentials to %v. cloudflared chose this file based on where your origin certificate was found.", filePath) | ||||
| 	logger.Infof("Keep this file secret. To revoke these credentials, delete the tunnel.") | ||||
| 	file, err := os.Create(filePath) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, fmt.Sprintf("Unable to write to %s", filePath)) | ||||
| 	} | ||||
| 	defer file.Close() | ||||
| 	body, err := json.Marshal(pogs.TunnelAuth{ | ||||
| 		AccountTag:   accountID, | ||||
| 		TunnelSecret: tunnelSecret, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "Unable to marshal tunnel credentials to JSON") | ||||
| 	} | ||||
| 	fmt.Fprintf(file, "%d", body) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func buildListCommand() *cli.Command { | ||||
| 	return &cli.Command{ | ||||
| 		Name:      "list", | ||||
|  | @ -85,10 +147,15 @@ func listTunnels(c *cli.Context) error { | |||
| 		return errors.Wrap(err, "error setting up logger") | ||||
| 	} | ||||
| 
 | ||||
| 	client, err := newTunnelstoreClient(c, logger) | ||||
| 	originCertPath, err := findOriginCert(c, logger) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "Error locating origin cert") | ||||
| 	} | ||||
| 	cert, err := getOriginCertFromContext(originCertPath, logger) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	client := newTunnelstoreClient(c, cert, logger) | ||||
| 
 | ||||
| 	tunnels, err := client.ListTunnels() | ||||
| 	if err != nil { | ||||
|  | @ -155,10 +222,15 @@ func deleteTunnel(c *cli.Context) error { | |||
| 		return errors.Wrap(err, "error setting up logger") | ||||
| 	} | ||||
| 
 | ||||
| 	client, err := newTunnelstoreClient(c, logger) | ||||
| 	originCertPath, err := findOriginCert(c, logger) | ||||
| 	if err != nil { | ||||
| 		return errors.Wrap(err, "Error locating origin cert") | ||||
| 	} | ||||
| 	cert, err := getOriginCertFromContext(originCertPath, logger) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	client := newTunnelstoreClient(c, cert, logger) | ||||
| 
 | ||||
| 	if err := client.DeleteTunnel(id); err != nil { | ||||
| 		return errors.Wrapf(err, "Error deleting tunnel %s", id) | ||||
|  | @ -180,11 +252,12 @@ func renderOutput(format string, v interface{}) error { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func newTunnelstoreClient(c *cli.Context, logger logger.Service) (tunnelstore.Client, error) { | ||||
| 	originCertPath, err := findOriginCert(c, logger) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error locating origin cert") | ||||
| 	} | ||||
| func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client { | ||||
| 	client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger) | ||||
| 	return client | ||||
| } | ||||
| 
 | ||||
| func getOriginCertFromContext(originCertPath string, logger logger.Service) (*certutil.OriginCert, error) { | ||||
| 
 | ||||
| 	blocks, err := readOriginCert(originCertPath, logger) | ||||
| 	if err != nil { | ||||
|  | @ -199,8 +272,5 @@ func newTunnelstoreClient(c *cli.Context, logger logger.Service) (tunnelstore.Cl | |||
| 	if cert.AccountID == "" { | ||||
| 		return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) | ||||
| 	} | ||||
| 
 | ||||
| 	client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger) | ||||
| 
 | ||||
| 	return client, nil | ||||
| 	return cert, nil | ||||
| } | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"github.com/cloudflare/cloudflared/tunnelstore" | ||||
| 
 | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| func Test_fmtConnections(t *testing.T) { | ||||
|  | @ -69,3 +70,10 @@ func Test_fmtConnections(t *testing.T) { | |||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestTunnelfilePath(t *testing.T) { | ||||
| 	actual, err := tunnelFilePath("tunnel", "~/.cloudflared/cert.pem") | ||||
| 	assert.NoError(t, err) | ||||
| 	expected := "~/.cloudflared/tunnel.json" | ||||
| 	assert.Equal(t, expected, actual) | ||||
| } | ||||
|  |  | |||
|  | @ -39,7 +39,7 @@ type Connection struct { | |||
| } | ||||
| 
 | ||||
| type Client interface { | ||||
| 	CreateTunnel(name string) (*Tunnel, error) | ||||
| 	CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) | ||||
| 	GetTunnel(id string) (*Tunnel, error) | ||||
| 	DeleteTunnel(id string) error | ||||
| 	ListTunnels() ([]Tunnel, error) | ||||
|  | @ -74,15 +74,17 @@ func NewRESTClient(baseURL string, accountTag string, authToken string, logger l | |||
| } | ||||
| 
 | ||||
| type newTunnel struct { | ||||
| 	Name string `json:"name"` | ||||
| 	Name         string `json:"name"` | ||||
| 	TunnelSecret []byte `json:"tunnel_secret"` | ||||
| } | ||||
| 
 | ||||
| func (r *RESTClient) CreateTunnel(name string) (*Tunnel, error) { | ||||
| func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) { | ||||
| 	if name == "" { | ||||
| 		return nil, errors.New("tunnel name required") | ||||
| 	} | ||||
| 	body, err := json.Marshal(&newTunnel{ | ||||
| 		Name: name, | ||||
| 		Name:         name, | ||||
| 		TunnelSecret: tunnelSecret, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Failed to serialize new tunnel request") | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue