TUN-3581: Tunnels can be run by name using only --credentials-file, no
origin cert necessary.
This commit is contained in:
parent
fcc393e2f0
commit
69fd502db3
|
@ -12,6 +12,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -250,7 +251,7 @@ func installLinuxService(c *cli.Context) error {
|
||||||
val, err := src.String(s)
|
val, err := src.String(s)
|
||||||
return err == nil && val != ""
|
return err == nil && val != ""
|
||||||
}
|
}
|
||||||
if src.TunnelID == "" || !configPresent("credentials-file") {
|
if src.TunnelID == "" || !configPresent(tunnel.CredFileFlag) {
|
||||||
return fmt.Errorf(`Configuration file %s must contain entries for the tunnel to run and its associated credentials:
|
return fmt.Errorf(`Configuration file %s must contain entries for the tunnel to run and its associated credentials:
|
||||||
tunnel: TUNNEL-UUID
|
tunnel: TUNNEL-UUID
|
||||||
credentials-file: CREDENTIALS-FILE
|
credentials-file: CREDENTIALS-FILE
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredFinder can find the tunnel credentials file.
|
||||||
|
type CredFinder interface {
|
||||||
|
Path() (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements CredFinder and looks for the credentials file at the given
|
||||||
|
// filepath.
|
||||||
|
type staticPath struct {
|
||||||
|
filePath string
|
||||||
|
fs fileSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStaticPath(filePath string, fs fileSystem) CredFinder {
|
||||||
|
return staticPath{
|
||||||
|
filePath: filePath,
|
||||||
|
fs: fs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a staticPath) Path() (string, error) {
|
||||||
|
if a.filePath != "" && a.fs.validFilePath(a.filePath) {
|
||||||
|
return a.filePath, nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("Tunnel credentials file '%s' doesn't exist or is not a file", a.filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements CredFinder and looks for the credentials file in several directories
|
||||||
|
// searching for a file named <id>.json
|
||||||
|
type searchByID struct {
|
||||||
|
id uuid.UUID
|
||||||
|
c *cli.Context
|
||||||
|
logger logger.Service
|
||||||
|
fs fileSystem
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSearchByID(id uuid.UUID, c *cli.Context, logger logger.Service, fs fileSystem) CredFinder {
|
||||||
|
return searchByID{
|
||||||
|
id: id,
|
||||||
|
c: c,
|
||||||
|
logger: logger,
|
||||||
|
fs: fs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s searchByID) Path() (string, error) {
|
||||||
|
|
||||||
|
// Fallback to look for tunnel credentials in the origin cert directory
|
||||||
|
if originCertPath, err := findOriginCert(s.c, s.logger); err == nil {
|
||||||
|
originCertDir := filepath.Dir(originCertPath)
|
||||||
|
if filePath, err := tunnelFilePath(s.id, originCertDir); err == nil {
|
||||||
|
if s.fs.validFilePath(filePath) {
|
||||||
|
return filePath, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last resort look under default config directories
|
||||||
|
for _, configDir := range config.DefaultConfigSearchDirectories() {
|
||||||
|
if filePath, err := tunnelFilePath(s.id, configDir); err == nil {
|
||||||
|
if s.fs.validFilePath(filePath) {
|
||||||
|
return filePath, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("Tunnel credentials file not found")
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Abstract away details of reading files, so that SubcommandContext can read
|
||||||
|
// from either the real filesystem, or a mock (when running unit tests).
|
||||||
|
type fileSystem interface {
|
||||||
|
readFile(filePath string) ([]byte, error)
|
||||||
|
validFilePath(path string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type realFileSystem struct{}
|
||||||
|
|
||||||
|
func (fs realFileSystem) validFilePath(path string) bool {
|
||||||
|
fileStat, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !fileStat.IsDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs realFileSystem) readFile(filePath string) ([]byte, error) {
|
||||||
|
return ioutil.ReadFile(filePath)
|
||||||
|
}
|
|
@ -3,9 +3,7 @@ package tunnel
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
@ -13,10 +11,8 @@ import (
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/certutil"
|
"github.com/cloudflare/cloudflared/certutil"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelstore"
|
"github.com/cloudflare/cloudflared/tunnelstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,14 +28,14 @@ func (e errInvalidJSONCredential) Error() string {
|
||||||
// subcommandContext carries structs shared between subcommands, to reduce number of arguments needed to
|
// subcommandContext carries structs shared between subcommands, to reduce number of arguments needed to
|
||||||
// pass between subcommands, and make sure they are only initialized once
|
// pass between subcommands, and make sure they are only initialized once
|
||||||
type subcommandContext struct {
|
type subcommandContext struct {
|
||||||
c *cli.Context
|
c *cli.Context
|
||||||
logger logger.Service
|
logger logger.Service
|
||||||
|
isUIEnabled bool
|
||||||
|
fs fileSystem
|
||||||
|
|
||||||
// These fields should be accessed using their respective Getter
|
// These fields should be accessed using their respective Getter
|
||||||
tunnelstoreClient tunnelstore.Client
|
tunnelstoreClient tunnelstore.Client
|
||||||
userCredential *userCredential
|
userCredential *userCredential
|
||||||
|
|
||||||
isUIEnabled bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
|
func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
|
||||||
|
@ -55,9 +51,18 @@ func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
|
||||||
c: c,
|
c: c,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
isUIEnabled: isUIEnabled,
|
isUIEnabled: isUIEnabled,
|
||||||
|
fs: realFileSystem{},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns something that can find the given tunnel's credentials file.
|
||||||
|
func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder {
|
||||||
|
if path := sc.c.String(CredFileFlag); path != "" {
|
||||||
|
return newStaticPath(path, sc.fs)
|
||||||
|
}
|
||||||
|
return newSearchByID(tunnelID, sc.c, sc.logger, sc.fs)
|
||||||
|
}
|
||||||
|
|
||||||
type userCredential struct {
|
type userCredential struct {
|
||||||
cert *certutil.OriginCert
|
cert *certutil.OriginCert
|
||||||
certPath string
|
certPath string
|
||||||
|
@ -108,56 +113,27 @@ func (sc *subcommandContext) credential() (*userCredential, error) {
|
||||||
return sc.userCredential, nil
|
return sc.userCredential, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *subcommandContext) readTunnelCredentials(tunnelID uuid.UUID) (*pogs.TunnelAuth, error) {
|
func (sc *subcommandContext) readTunnelCredentials(credFinder CredFinder) (connection.Credentials, error) {
|
||||||
filePath, err := sc.tunnelCredentialsPath(tunnelID)
|
filePath, err := credFinder.Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return connection.Credentials{}, err
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadFile(filePath)
|
body, err := sc.fs.readFile(filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath)
|
return connection.Credentials{}, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
var auth pogs.TunnelAuth
|
var credentials connection.Credentials
|
||||||
if err = json.Unmarshal(body, &auth); err != nil {
|
if err = json.Unmarshal(body, &credentials); err != nil {
|
||||||
if strings.HasSuffix(filePath, ".pem") {
|
if strings.HasSuffix(filePath, ".pem") {
|
||||||
return nil, fmt.Errorf("The tunnel credentials file should be .json but you gave a .pem. "+
|
return connection.Credentials{}, fmt.Errorf("The tunnel credentials file should be .json but you gave a .pem. " +
|
||||||
"The tunnel credentials file was originally created by `cloudflared tunnel create` and named %s.json."+
|
"The tunnel credentials file was originally created by `cloudflared tunnel create`. " +
|
||||||
"You may have accidentally used the filepath to cert.pem, which is generated by `cloudflared tunnel "+
|
"You may have accidentally used the filepath to cert.pem, which is generated by `cloudflared tunnel " +
|
||||||
"login`.", tunnelID)
|
"login`.")
|
||||||
}
|
}
|
||||||
return nil, errInvalidJSONCredential{path: filePath, err: err}
|
return connection.Credentials{}, errInvalidJSONCredential{path: filePath, err: err}
|
||||||
}
|
}
|
||||||
return &auth, nil
|
return credentials, nil
|
||||||
}
|
|
||||||
|
|
||||||
func (sc *subcommandContext) tunnelCredentialsPath(tunnelID uuid.UUID) (string, error) {
|
|
||||||
if filePath := sc.c.String("credentials-file"); filePath != "" {
|
|
||||||
if validFilePath(filePath) {
|
|
||||||
return filePath, nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("Tunnel credentials file %s doesn't exist or is not a file", filePath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to look for tunnel credentials in the origin cert directory
|
|
||||||
if originCertPath, err := findOriginCert(sc.c, sc.logger); err == nil {
|
|
||||||
originCertDir := filepath.Dir(originCertPath)
|
|
||||||
if filePath, err := tunnelFilePath(tunnelID, originCertDir); err == nil {
|
|
||||||
if validFilePath(filePath) {
|
|
||||||
return filePath, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Last resort look under default config directories
|
|
||||||
for _, configDir := range config.DefaultConfigSearchDirectories() {
|
|
||||||
if filePath, err := tunnelFilePath(tunnelID, configDir); err == nil {
|
|
||||||
if validFilePath(filePath) {
|
|
||||||
return filePath, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("Tunnel credentials file not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) {
|
func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) {
|
||||||
|
@ -180,7 +156,14 @@ func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if writeFileErr := writeTunnelCredentials(tunnel.ID, credential.cert.AccountID, credential.certPath, tunnelSecret, sc.logger); err != nil {
|
tunnelCredentials := connection.Credentials{
|
||||||
|
AccountTag: credential.cert.AccountID,
|
||||||
|
TunnelSecret: tunnelSecret,
|
||||||
|
TunnelID: tunnel.ID,
|
||||||
|
TunnelName: name,
|
||||||
|
}
|
||||||
|
filePath, writeFileErr := writeTunnelCredentials(credential.certPath, &tunnelCredentials)
|
||||||
|
if err != nil {
|
||||||
var errorLines []string
|
var errorLines []string
|
||||||
errorLines = append(errorLines, fmt.Sprintf("Your tunnel '%v' was created with ID %v. However, cloudflared couldn't write to the tunnel credentials file at %v.json.", tunnel.Name, tunnel.ID, tunnel.ID))
|
errorLines = append(errorLines, fmt.Sprintf("Your tunnel '%v' was created with ID %v. However, cloudflared couldn't write to the tunnel credentials file at %v.json.", tunnel.Name, tunnel.ID, tunnel.ID))
|
||||||
errorLines = append(errorLines, fmt.Sprintf("The file-writing error is: %v", writeFileErr))
|
errorLines = append(errorLines, fmt.Sprintf("The file-writing error is: %v", writeFileErr))
|
||||||
|
@ -193,6 +176,7 @@ func (sc *subcommandContext) create(name string) (*tunnelstore.Tunnel, error) {
|
||||||
errorMsg := strings.Join(errorLines, "\n")
|
errorMsg := strings.Join(errorLines, "\n")
|
||||||
return nil, errors.New(errorMsg)
|
return nil, errors.New(errorMsg)
|
||||||
}
|
}
|
||||||
|
sc.logger.Infof("Tunnel credentials written to %v. cloudflared chose this file based on where your origin certificate was found. Keep this file secret. To revoke these credentials, delete the tunnel.", filePath)
|
||||||
|
|
||||||
if outputFormat := sc.c.String(outputFormatFlag.Name); outputFormat != "" {
|
if outputFormat := sc.c.String(outputFormatFlag.Name); outputFormat != "" {
|
||||||
return nil, renderOutput(outputFormat, &tunnel)
|
return nil, renderOutput(outputFormat, &tunnel)
|
||||||
|
@ -243,7 +227,8 @@ func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error {
|
||||||
return errors.Wrapf(err, "Error deleting tunnel %s", tunnel.ID)
|
return errors.Wrapf(err, "Error deleting tunnel %s", tunnel.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelCredentialsPath, err := sc.tunnelCredentialsPath(tunnel.ID)
|
credFinder := sc.credentialFinder(id)
|
||||||
|
tunnelCredentialsPath, err := credFinder.Path()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sc.logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err)
|
sc.logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err)
|
||||||
return nil
|
return nil
|
||||||
|
@ -256,8 +241,21 @@ func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findCredentials will choose the right way to find the credentials file, find it,
|
||||||
|
// and add the TunnelID into any old credentials (generated before TUN-3581 added the `TunnelID`
|
||||||
|
// field to credentials files)
|
||||||
|
func (sc *subcommandContext) findCredentials(tunnelID uuid.UUID) (connection.Credentials, error) {
|
||||||
|
credFinder := sc.credentialFinder(tunnelID)
|
||||||
|
credentials, err := sc.readTunnelCredentials(credFinder)
|
||||||
|
// This line ensures backwards compatibility with credentials files generated before
|
||||||
|
// TUN-3581. Those old credentials files don't have a TunnelID field, so we enrich the struct
|
||||||
|
// with the ID, which we have already resolved from the user input.
|
||||||
|
credentials.TunnelID = tunnelID
|
||||||
|
return credentials, err
|
||||||
|
}
|
||||||
|
|
||||||
func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
|
func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
|
||||||
credentials, err := sc.readTunnelCredentials(tunnelID)
|
credentials, err := sc.findCredentials(tunnelID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e, ok := err.(errInvalidJSONCredential); ok {
|
if e, ok := err.(errInvalidJSONCredential); ok {
|
||||||
sc.logger.Errorf("The credentials file at %s contained invalid JSON. This is probably caused by passing the wrong filepath. Reminder: the credentials file is a .json file created via `cloudflared tunnel create`.", e.path)
|
sc.logger.Errorf("The credentials file at %s contained invalid JSON. This is probably caused by passing the wrong filepath. Reminder: the credentials file is a .json file created via `cloudflared tunnel create`.", e.path)
|
||||||
|
@ -265,13 +263,12 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return StartServer(
|
return StartServer(
|
||||||
sc.c,
|
sc.c,
|
||||||
version,
|
version,
|
||||||
shutdownC,
|
shutdownC,
|
||||||
graceShutdownC,
|
graceShutdownC,
|
||||||
&connection.NamedTunnelConfig{Auth: *credentials, ID: tunnelID},
|
&connection.NamedTunnelConfig{Credentials: credentials},
|
||||||
sc.logger,
|
sc.logger,
|
||||||
sc.isUIEnabled,
|
sc.isUIEnabled,
|
||||||
)
|
)
|
||||||
|
@ -300,6 +297,7 @@ func (sc *subcommandContext) route(tunnelID uuid.UUID, r tunnelstore.Route) (tun
|
||||||
return client.RouteTunnel(tunnelID, r)
|
return client.RouteTunnel(tunnelID, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Query Tunnelstore to find the active tunnel with the given name.
|
||||||
func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, bool, error) {
|
func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, bool, error) {
|
||||||
filter := tunnelstore.NewFilter()
|
filter := tunnelstore.NewFilter()
|
||||||
filter.NoDeleted()
|
filter.NoDeleted()
|
||||||
|
@ -322,6 +320,15 @@ func (sc *subcommandContext) findID(input string) (uuid.UUID, error) {
|
||||||
return u, nil
|
return u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Look up name in the credentials file.
|
||||||
|
credFinder := newStaticPath(sc.c.String(CredFileFlag), sc.fs)
|
||||||
|
if credentials, err := sc.readTunnelCredentials(credFinder); err == nil {
|
||||||
|
if credentials.TunnelID != uuid.Nil && input == credentials.TunnelName {
|
||||||
|
return credentials.TunnelID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to querying Tunnelstore.
|
||||||
if tunnel, found, err := sc.tunnelActive(input); err != nil {
|
if tunnel, found, err := sc.tunnelActive(input); err != nil {
|
||||||
return uuid.Nil, err
|
return uuid.Nil, err
|
||||||
} else if found {
|
} else if found {
|
||||||
|
|
|
@ -1,11 +1,19 @@
|
||||||
package tunnel
|
package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/tunnelstore"
|
"github.com/cloudflare/cloudflared/tunnelstore"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_findIDs(t *testing.T) {
|
func Test_findIDs(t *testing.T) {
|
||||||
|
@ -80,3 +88,128 @@ func Test_findIDs(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockFileSystem struct {
|
||||||
|
rf func(string) ([]byte, error)
|
||||||
|
vfp func(string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs mockFileSystem) validFilePath(path string) bool {
|
||||||
|
return fs.vfp(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs mockFileSystem) readFile(filePath string) ([]byte, error) {
|
||||||
|
return fs.rf(filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_subcommandContext_findCredentials(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
c *cli.Context
|
||||||
|
logger logger.Service
|
||||||
|
isUIEnabled bool
|
||||||
|
fs fileSystem
|
||||||
|
tunnelstoreClient tunnelstore.Client
|
||||||
|
userCredential *userCredential
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
tunnelID uuid.UUID
|
||||||
|
}
|
||||||
|
oldCertPath := "old_cert.json"
|
||||||
|
newCertPath := "new_cert.json"
|
||||||
|
accountTag := "0000d4d14e84bd4ae5a6a02e0000ac63"
|
||||||
|
secret := []byte{211, 79, 177, 245, 179, 194, 152, 127, 140, 71, 18, 46, 183, 209, 10, 24, 192, 150, 55, 249, 211, 16, 167, 30, 113, 51, 152, 168, 72, 100, 205, 144}
|
||||||
|
secretB64 := base64.StdEncoding.EncodeToString(secret)
|
||||||
|
tunnelID := uuid.MustParse("df5ed608-b8b4-4109-89f3-9f2cf199df64")
|
||||||
|
name := "mytunnel"
|
||||||
|
|
||||||
|
fs := mockFileSystem{
|
||||||
|
rf: func(filePath string) ([]byte, error) {
|
||||||
|
if filePath == oldCertPath {
|
||||||
|
// An old credentials file created before TUN-3581 added the new fields
|
||||||
|
return []byte(fmt.Sprintf(`{"AccountTag":"%s","TunnelSecret":"%s"}`, accountTag, secretB64)), nil
|
||||||
|
}
|
||||||
|
if filePath == newCertPath {
|
||||||
|
// A new credentials file created after TUN-3581 with its new fields.
|
||||||
|
return []byte(fmt.Sprintf(`{"AccountTag":"%s","TunnelSecret":"%s","TunnelID":"%s","TunnelName":"%s"}`, accountTag, secretB64, tunnelID, name)), nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("file not found")
|
||||||
|
},
|
||||||
|
vfp: func(string) bool { return true },
|
||||||
|
}
|
||||||
|
logger, err := logger.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want connection.Credentials
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Filepath given leads to old credentials file",
|
||||||
|
fields: fields{
|
||||||
|
logger: logger,
|
||||||
|
fs: fs,
|
||||||
|
c: func() *cli.Context {
|
||||||
|
flagSet := flag.NewFlagSet("test0", flag.PanicOnError)
|
||||||
|
flagSet.String(CredFileFlag, oldCertPath, "")
|
||||||
|
c := cli.NewContext(cli.NewApp(), flagSet, nil)
|
||||||
|
err = c.Set(CredFileFlag, oldCertPath)
|
||||||
|
return c
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
tunnelID: tunnelID,
|
||||||
|
},
|
||||||
|
want: connection.Credentials{
|
||||||
|
AccountTag: accountTag,
|
||||||
|
TunnelID: tunnelID,
|
||||||
|
TunnelSecret: secret,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Filepath given leads to new credentials file",
|
||||||
|
fields: fields{
|
||||||
|
logger: logger,
|
||||||
|
fs: fs,
|
||||||
|
c: func() *cli.Context {
|
||||||
|
flagSet := flag.NewFlagSet("test0", flag.PanicOnError)
|
||||||
|
flagSet.String(CredFileFlag, newCertPath, "")
|
||||||
|
c := cli.NewContext(cli.NewApp(), flagSet, nil)
|
||||||
|
err = c.Set(CredFileFlag, newCertPath)
|
||||||
|
return c
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
tunnelID: tunnelID,
|
||||||
|
},
|
||||||
|
want: connection.Credentials{
|
||||||
|
AccountTag: accountTag,
|
||||||
|
TunnelID: tunnelID,
|
||||||
|
TunnelSecret: secret,
|
||||||
|
TunnelName: name,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
sc := &subcommandContext{
|
||||||
|
c: tt.fields.c,
|
||||||
|
logger: tt.fields.logger,
|
||||||
|
isUIEnabled: tt.fields.isUIEnabled,
|
||||||
|
fs: tt.fields.fs,
|
||||||
|
tunnelstoreClient: tt.fields.tunnelstoreClient,
|
||||||
|
userCredential: tt.fields.userCredential,
|
||||||
|
}
|
||||||
|
got, err := sc.findCredentials(tt.args.tunnelID)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("subcommandContext.findCredentials() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("subcommandContext.findCredentials() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -24,13 +24,12 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelstore"
|
"github.com/cloudflare/cloudflared/tunnelstore"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
credFileFlagAlias = "cred-file"
|
CredFileFlagAlias = "cred-file"
|
||||||
|
CredFileFlag = "credentials-file"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -75,8 +74,8 @@ var (
|
||||||
"tunnels, you can do so with Cloudflare's Load Balancer product.",
|
"tunnels, you can do so with Cloudflare's Load Balancer product.",
|
||||||
})
|
})
|
||||||
credentialsFileFlag = altsrc.NewStringFlag(&cli.StringFlag{
|
credentialsFileFlag = altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "credentials-file",
|
Name: CredFileFlag,
|
||||||
Aliases: []string{credFileFlagAlias},
|
Aliases: []string{CredFileFlagAlias},
|
||||||
Usage: "File path of tunnel credentials",
|
Usage: "File path of tunnel credentials",
|
||||||
EnvVars: []string{"TUNNEL_CRED_FILE"},
|
EnvVars: []string{"TUNNEL_CRED_FILE"},
|
||||||
})
|
})
|
||||||
|
@ -141,30 +140,21 @@ func tunnelFilePath(tunnelID uuid.UUID, directory string) (string, error) {
|
||||||
return homedir.Expand(filePath)
|
return homedir.Expand(filePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeTunnelCredentials(tunnelID uuid.UUID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error {
|
func writeTunnelCredentials(
|
||||||
|
originCertPath string,
|
||||||
|
credentials *connection.Credentials,
|
||||||
|
) (filePath string, err error) {
|
||||||
originCertDir := filepath.Dir(originCertPath)
|
originCertDir := filepath.Dir(originCertPath)
|
||||||
filePath, err := tunnelFilePath(tunnelID, originCertDir)
|
filePath, err = tunnelFilePath(credentials.TunnelID, originCertDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
body, err := json.Marshal(pogs.TunnelAuth{
|
// Write the name and ID to the file too
|
||||||
AccountTag: accountID,
|
body, err := json.Marshal(credentials)
|
||||||
TunnelSecret: tunnelSecret,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Unable to marshal tunnel credentials to JSON")
|
return "", errors.Wrap(err, "Unable to marshal tunnel credentials to JSON")
|
||||||
}
|
}
|
||||||
logger.Infof("Writing tunnel credentials to %v. cloudflared chose this file based on where your origin certificate was found.", filePath)
|
return filePath, ioutil.WriteFile(filePath, body, 400)
|
||||||
logger.Infof("Keep this file secret. To revoke these credentials, delete the tunnel.")
|
|
||||||
return ioutil.WriteFile(filePath, body, 400)
|
|
||||||
}
|
|
||||||
|
|
||||||
func validFilePath(path string) bool {
|
|
||||||
fileStat, err := os.Stat(path)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return !fileStat.IsDir()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildListCommand() *cli.Command {
|
func buildListCommand() *cli.Command {
|
||||||
|
|
|
@ -22,9 +22,23 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type NamedTunnelConfig struct {
|
type NamedTunnelConfig struct {
|
||||||
Auth pogs.TunnelAuth
|
Credentials Credentials
|
||||||
ID uuid.UUID
|
Client pogs.ClientInfo
|
||||||
Client pogs.ClientInfo
|
}
|
||||||
|
|
||||||
|
// Credentials are stored in the credentials file and contain all info needed to run a tunnel.
|
||||||
|
type Credentials struct {
|
||||||
|
AccountTag string
|
||||||
|
TunnelSecret []byte
|
||||||
|
TunnelID uuid.UUID
|
||||||
|
TunnelName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Credentials) Auth() pogs.TunnelAuth {
|
||||||
|
return pogs.TunnelAuth{
|
||||||
|
AccountTag: c.AccountTag,
|
||||||
|
TunnelSecret: c.TunnelSecret,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClassicTunnelConfig struct {
|
type ClassicTunnelConfig struct {
|
||||||
|
|
|
@ -165,7 +165,7 @@ func NewProtocolSelector(protocolFlag string, namedTunnel *NamedTunnelConfig, fe
|
||||||
if protocolFlag != autoSelectFlag {
|
if protocolFlag != autoSelectFlag {
|
||||||
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
||||||
}
|
}
|
||||||
threshold := switchThreshold(namedTunnel.Auth.AccountTag)
|
threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
|
||||||
if threshold < http2Percentage {
|
if threshold < http2Percentage {
|
||||||
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, logger), nil
|
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, logger), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,7 +15,7 @@ const (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
testNamedTunnelConfig = &NamedTunnelConfig{
|
testNamedTunnelConfig = &NamedTunnelConfig{
|
||||||
Auth: pogs.TunnelAuth{
|
Credentials: Credentials{
|
||||||
AccountTag: "testAccountTag",
|
AccountTag: "testAccountTag",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,8 +92,8 @@ func (rsc *registrationServerClient) RegisterConnection(
|
||||||
) error {
|
) error {
|
||||||
conn, err := rsc.client.RegisterConnection(
|
conn, err := rsc.client.RegisterConnection(
|
||||||
ctx,
|
ctx,
|
||||||
config.Auth,
|
config.Credentials.Auth(),
|
||||||
config.ID,
|
config.Credentials.TunnelID,
|
||||||
connIndex,
|
connIndex,
|
||||||
options,
|
options,
|
||||||
)
|
)
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -36,7 +35,7 @@ func TestWaitForBackoffFallback(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
resolveTTL := time.Duration(0)
|
resolveTTL := time.Duration(0)
|
||||||
namedTunnel := &connection.NamedTunnelConfig{
|
namedTunnel := &connection.NamedTunnelConfig{
|
||||||
Auth: pogs.TunnelAuth{
|
Credentials: connection.Credentials{
|
||||||
AccountTag: "test-account",
|
AccountTag: "test-account",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue