TUN-7134: Acquire token for cloudflared tail

cloudflared tail will now fetch the management token from by making
a request to the Cloudflare API using the cert.pem (acquired from
cloudflared login).

Refactored some of the credentials code into it's own package as
to allow for easier use between subcommands outside of
`cloudflared tunnel`.
This commit is contained in:
Devin Carr 2023-04-12 09:43:38 -07:00
parent 8dc0697a8f
commit b89c092c1b
21 changed files with 497 additions and 250 deletions

View File

@ -1,58 +0,0 @@
package certutil
import (
"encoding/json"
"encoding/pem"
"fmt"
)
type namedTunnelToken struct {
ZoneID string `json:"zoneID"`
AccountID string `json:"accountID"`
APIToken string `json:"apiToken"`
}
type OriginCert struct {
ZoneID string
APIToken string
AccountID string
}
func DecodeOriginCert(blocks []byte) (*OriginCert, error) {
if len(blocks) == 0 {
return nil, fmt.Errorf("Cannot decode empty certificate")
}
originCert := OriginCert{}
block, rest := pem.Decode(blocks)
for {
if block == nil {
break
}
switch block.Type {
case "PRIVATE KEY", "CERTIFICATE":
// this is for legacy purposes.
break
case "ARGO TUNNEL TOKEN":
if originCert.ZoneID != "" || originCert.APIToken != "" {
return nil, fmt.Errorf("Found multiple tokens in the certificate")
}
// The token is a string,
// Try the newer JSON format
ntt := namedTunnelToken{}
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
originCert.ZoneID = ntt.ZoneID
originCert.APIToken = ntt.APIToken
originCert.AccountID = ntt.AccountID
}
default:
return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
}
block, rest = pem.Decode(rest)
}
if originCert.ZoneID == "" || originCert.APIToken == "" {
return nil, fmt.Errorf("Missing token in the certificate")
}
return &originCert, nil
}

View File

@ -1,51 +0,0 @@
package certutil
import (
"fmt"
"io/ioutil"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLoadOriginCert(t *testing.T) {
cert, err := DecodeOriginCert([]byte{})
assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
assert.Nil(t, cert)
blocks, err := ioutil.ReadFile("test-cert-unknown-block.pem")
assert.Nil(t, err)
cert, err = DecodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
cert, err := DecodeOriginCert([]byte{})
blocks, err := ioutil.ReadFile("test-cert-no-token.pem")
assert.Nil(t, err)
cert, err = DecodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelToken(t *testing.T) {
// The given cert's Argo Tunnel Token was generated by base64 encoding this JSON:
// {
// "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19",
// "apiToken": "test-service-key",
// "accountID": "abcdabcdabcdabcd1234567890abcdef"
// }
CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem")
}
func CloudflareTunnelTokenTest(t *testing.T, path string) {
blocks, err := ioutil.ReadFile(path)
assert.Nil(t, err)
cert, err := DecodeOriginCert(blocks)
assert.Nil(t, err)
assert.NotNil(t, cert)
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
key := "test-service-key"
assert.Equal(t, key, cert.APIToken)
}

View File

@ -8,6 +8,7 @@ type TunnelClient interface {
CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error) CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error)
GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
GetTunnelToken(tunnelID uuid.UUID) (string, error) GetTunnelToken(tunnelID uuid.UUID) (string, error)
GetManagementToken(tunnelID uuid.UUID) (string, error)
DeleteTunnel(tunnelID uuid.UUID) error DeleteTunnel(tunnelID uuid.UUID) error
ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) ListTunnels(filter *TunnelFilter) ([]*Tunnel, error)
ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error)

View File

@ -50,6 +50,10 @@ type newTunnel struct {
TunnelSecret []byte `json:"tunnel_secret"` TunnelSecret []byte `json:"tunnel_secret"`
} }
type managementRequest struct {
Resources []string `json:"resources"`
}
type CleanupParams struct { type CleanupParams struct {
queryParams url.Values queryParams url.Values
} }
@ -133,6 +137,28 @@ func (r *RESTClient) GetTunnelToken(tunnelID uuid.UUID) (token string, err error
return "", r.statusCodeToError("get tunnel token", resp) return "", r.statusCodeToError("get tunnel token", resp)
} }
func (r *RESTClient) GetManagementToken(tunnelID uuid.UUID) (token string, err error) {
endpoint := r.baseEndpoints.accountLevel
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/management", tunnelID))
body := &managementRequest{
Resources: []string{"logs"},
}
resp, err := r.sendRequest("POST", endpoint, body)
if err != nil {
return "", errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
err = parseResponse(resp.Body, &token)
return token, err
}
return "", r.statusCodeToError("get tunnel token", resp)
}
func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error { func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
endpoint := r.baseEndpoints.accountLevel endpoint := r.baseEndpoints.accountLevel
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID)) endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))

View File

@ -47,3 +47,7 @@ func (bi *BuildInfo) GetBuildTypeMsg() string {
} }
return fmt.Sprintf(" with %s", bi.BuildType) return fmt.Sprintf(" with %s", bi.BuildType)
} }
func (bi *BuildInfo) UserAgent() string {
return fmt.Sprintf("cloudflared/%s", bi.CloudflaredVersion)
}

View File

@ -90,7 +90,7 @@ func main() {
updater.Init(Version) updater.Init(Version)
tracing.Init(Version) tracing.Init(Version)
token.Init(Version) token.Init(Version)
tail.Init(Version) tail.Init(bInfo)
runApp(app, graceShutdownC) runApp(app, graceShutdownC)
} }

View File

@ -2,6 +2,7 @@ package tail
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -10,21 +11,24 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/google/uuid"
"github.com/mattn/go-colorable" "github.com/mattn/go-colorable"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"nhooyr.io/websocket" "nhooyr.io/websocket"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
) )
var ( var (
version string buildInfo *cliutil.BuildInfo
) )
func Init(v string) { func Init(bi *cliutil.BuildInfo) {
version = v buildInfo = bi
} }
func Command() *cli.Command { func Command() *cli.Command {
@ -32,6 +36,7 @@ func Command() *cli.Command {
Name: "tail", Name: "tail",
Action: Run, Action: Run,
Usage: "Stream logs from a remote cloudflared", Usage: "Stream logs from a remote cloudflared",
UsageText: "cloudflared tail [tail command options] [TUNNEL-ID]",
Flags: []cli.Flag{ Flags: []cli.Flag{
&cli.StringFlag{ &cli.StringFlag{
Name: "connector-id", Name: "connector-id",
@ -75,6 +80,12 @@ func Command() *cli.Command {
Usage: "Application logging level {debug, info, warn, error, fatal}", Usage: "Application logging level {debug, info, warn, error, fatal}",
EnvVars: []string{"TUNNEL_LOGLEVEL"}, EnvVars: []string{"TUNNEL_LOGLEVEL"},
}, },
&cli.StringFlag{
Name: credentials.OriginCertFlag,
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: credentials.FindDefaultOriginCertPath(),
},
}, },
} }
} }
@ -159,6 +170,59 @@ func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
}, nil }, nil
} }
// getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel.
func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) {
userCreds, err := credentials.Read(c.String(credentials.OriginCertFlag), log)
if err != nil {
return "", err
}
client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log)
if err != nil {
return "", err
}
tunnelIDString := c.Args().First()
if tunnelIDString == "" {
return "", errors.New("no tunnel ID provided")
}
tunnelID, err := uuid.Parse(tunnelIDString)
if err != nil {
return "", errors.New("unable to parse provided tunnel id as a valid UUID")
}
token, err := client.GetManagementToken(tunnelID)
if err != nil {
return "", err
}
return token, nil
}
// buildURL will build the management url to contain the required query parameters to authenticate the request.
func buildURL(c *cli.Context, log *zerolog.Logger) (url.URL, error) {
var err error
managementHostname := c.String("management-hostname")
token := c.String("token")
if token == "" {
token, err = getManagementToken(c, log)
if err != nil {
return url.URL{}, fmt.Errorf("unable to acquire management token for requested tunnel id: %w", err)
}
}
query := url.Values{}
query.Add("access_token", token)
connector := c.String("connector-id")
if connector != "" {
connectorID, err := uuid.Parse(connector)
if err != nil {
return url.URL{}, fmt.Errorf("unabled to parse 'connector-id' flag into a valid UUID: %w", err)
}
query.Add("connector_id", connectorID.String())
}
return url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: query.Encode()}, nil
}
// Run implements a foreground runner // Run implements a foreground runner
func Run(c *cli.Context) error { func Run(c *cli.Context) error {
log := createLogger(c) log := createLogger(c)
@ -173,12 +237,14 @@ func Run(c *cli.Context) error {
return nil return nil
} }
managementHostname := c.String("management-hostname") u, err := buildURL(c, log)
token := c.String("token") if err != nil {
u := url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: "access_token=" + token} log.Err(err).Msg("unable to construct management request URL")
return nil
}
header := make(http.Header) header := make(http.Header)
header.Add("User-Agent", "cloudflared/"+version) header.Add("User-Agent", buildInfo.UserAgent())
trace := c.String("trace") trace := c.String("trace")
if trace != "" { if trace != "" {
header["cf-trace-id"] = []string{trace} header["cf-trace-id"] = []string{trace}
@ -206,6 +272,11 @@ func Run(c *cli.Context) error {
log.Error().Err(err).Msg("unable to request logs from management tunnel") log.Error().Err(err).Msg("unable to request logs from management tunnel")
return nil return nil
} }
log.Debug().
Str("tunnel-id", c.Args().First()).
Str("connector-id", c.String("connector-id")).
Interface("filters", filters).
Msg("connected")
readerDone := make(chan struct{}) readerDone := make(chan struct{})

View File

@ -28,6 +28,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/features"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
@ -751,10 +752,10 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide, Hidden: shouldHide,
}, },
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "origincert", Name: credentials.OriginCertFlag,
Usage: "Path to the certificate generated for your origin when you run cloudflared login.", Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: findDefaultOriginCertPath(), Value: credentials.FindDefaultOriginCertPath(),
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{

View File

@ -3,17 +3,14 @@ package tunnel
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io/ioutil"
mathRand "math/rand" mathRand "math/rand"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"path/filepath"
"strings" "strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@ -33,7 +30,6 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
const LogFieldOriginCertPath = "originCertPath"
const secretValue = "*****" const secretValue = "*****"
var ( var (
@ -46,18 +42,6 @@ var (
configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"} configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
) )
// returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories
// contains a cert.pem file, return empty string
func findDefaultOriginCertPath() string {
for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() {
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, config.DefaultCredentialFile))
if ok, _ := config.FileExists(originCertPath); ok {
return originCertPath
}
}
return ""
}
func generateRandomClientID(log *zerolog.Logger) (string, error) { func generateRandomClientID(log *zerolog.Logger) (string, error) {
u, err := uuid.NewRandom() u, err := uuid.NewRandom()
if err != nil { if err != nil {
@ -128,62 +112,6 @@ func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelPrope
namedTunnel != nil) // named tunnel namedTunnel != nil) // named tunnel
} }
func findOriginCert(originCertPath string, log *zerolog.Logger) (string, error) {
if originCertPath == "" {
log.Info().Msgf("Cannot determine default origin certificate path. No file %s in %v", config.DefaultCredentialFile, config.DefaultConfigSearchDirectories())
if isRunningFromTerminal() {
log.Error().Msgf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl)
return "", fmt.Errorf("client didn't specify origincert path when running from terminal")
} else {
log.Error().Msgf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl)
return "", fmt.Errorf("client didn't specify origincert path")
}
}
var err error
originCertPath, err = homedir.Expand(originCertPath)
if err != nil {
log.Err(err).Msgf("Cannot resolve origin certificate path")
return "", fmt.Errorf("cannot resolve path %s", originCertPath)
}
// Check that the user has acquired a certificate using the login command
ok, err := config.FileExists(originCertPath)
if err != nil {
log.Error().Err(err).Msgf("Cannot check if origin cert exists at path %s", originCertPath)
return "", fmt.Errorf("cannot check if origin cert exists at path %s", originCertPath)
}
if !ok {
log.Error().Msgf(`Cannot find a valid certificate for your origin at the path:
%s
If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:
%s login
`, originCertPath, os.Args[0])
return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath)
}
return originCertPath, nil
}
func readOriginCert(originCertPath string) ([]byte, error) {
// Easier to send the certificate as []byte via RPC than decoding it at this point
originCert, err := ioutil.ReadFile(originCertPath)
if err != nil {
return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath)
}
return originCert, nil
}
func getOriginCert(originCertPath string, log *zerolog.Logger) ([]byte, error) {
if originCertPath, err := findOriginCert(originCertPath, log); err != nil {
return nil, err
} else {
return readOriginCert(originCertPath)
}
}
func prepareTunnelConfig( func prepareTunnelConfig(
c *cli.Context, c *cli.Context,
info *cliutil.BuildInfo, info *cliutil.BuildInfo,

View File

@ -5,6 +5,7 @@ import (
"path/filepath" "path/filepath"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/credentials"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -56,13 +57,13 @@ func newSearchByID(id uuid.UUID, c *cli.Context, log *zerolog.Logger, fs fileSys
} }
func (s searchByID) Path() (string, error) { func (s searchByID) Path() (string, error) {
originCertPath := s.c.String("origincert") originCertPath := s.c.String(credentials.OriginCertFlag)
originCertLog := s.log.With(). originCertLog := s.log.With().
Str(LogFieldOriginCertPath, originCertPath). Str("originCertPath", originCertPath).
Logger() Logger()
// Fallback to look for tunnel credentials in the origin cert directory // Fallback to look for tunnel credentials in the origin cert directory
if originCertPath, err := findOriginCert(originCertPath, &originCertLog); err == nil { if originCertPath, err := credentials.FindOriginCert(originCertPath, &originCertLog); err == nil {
originCertDir := filepath.Dir(originCertPath) originCertDir := filepath.Dir(originCertPath)
if filePath, err := tunnelFilePath(s.id, originCertDir); err == nil { if filePath, err := tunnelFilePath(s.id, originCertDir); err == nil {
if s.fs.validFilePath(filePath) { if s.fs.validFilePath(filePath) {

View File

@ -14,6 +14,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/token" "github.com/cloudflare/cloudflared/token"
) )
@ -85,7 +86,7 @@ func checkForExistingCert() (string, bool, error) {
if err != nil { if err != nil {
return "", false, err return "", false, err
} }
path := filepath.Join(configPath, config.DefaultCredentialFile) path := filepath.Join(configPath, credentials.DefaultCredentialFile)
fileInfo, err := os.Stat(path) fileInfo, err := os.Stat(path)
if err == nil && fileInfo.Size() > 0 { if err == nil && fileInfo.Size() > 0 {
return path, true, nil return path, true, nil

View File

@ -13,9 +13,9 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/certutil"
"github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
) )
@ -37,7 +37,7 @@ type subcommandContext struct {
// These fields should be accessed using their respective Getter // These fields should be accessed using their respective Getter
tunnelstoreClient cfapi.Client tunnelstoreClient cfapi.Client
userCredential *userCredential userCredential *credentials.User
} }
func newSubcommandContext(c *cli.Context) (*subcommandContext, error) { func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
@ -56,65 +56,28 @@ func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder {
return newSearchByID(tunnelID, sc.c, sc.log, sc.fs) return newSearchByID(tunnelID, sc.c, sc.log, sc.fs)
} }
type userCredential struct {
cert *certutil.OriginCert
certPath string
}
func (sc *subcommandContext) client() (cfapi.Client, error) { func (sc *subcommandContext) client() (cfapi.Client, error) {
if sc.tunnelstoreClient != nil { if sc.tunnelstoreClient != nil {
return sc.tunnelstoreClient, nil return sc.tunnelstoreClient, nil
} }
credential, err := sc.credential() cred, err := sc.credential()
if err != nil { if err != nil {
return nil, err return nil, err
} }
userAgent := fmt.Sprintf("cloudflared/%s", buildInfo.Version()) sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log)
client, err := cfapi.NewRESTClient(
sc.c.String("api-url"),
credential.cert.AccountID,
credential.cert.ZoneID,
credential.cert.APIToken,
userAgent,
sc.log,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sc.tunnelstoreClient = client return sc.tunnelstoreClient, nil
return client, nil
} }
func (sc *subcommandContext) credential() (*userCredential, error) { func (sc *subcommandContext) credential() (*credentials.User, error) {
if sc.userCredential == nil { if sc.userCredential == nil {
originCertPath := sc.c.String("origincert") uc, err := credentials.Read(sc.c.String(credentials.OriginCertFlag), sc.log)
originCertLog := sc.log.With().
Str(LogFieldOriginCertPath, originCertPath).
Logger()
originCertPath, err := findOriginCert(originCertPath, &originCertLog)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error locating origin cert") return nil, err
}
blocks, err := readOriginCert(originCertPath)
if err != nil {
return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
}
cert, err := certutil.DecodeOriginCert(blocks)
if err != nil {
return nil, errors.Wrap(err, "Error decoding origin cert")
}
if cert.AccountID == "" {
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
}
sc.userCredential = &userCredential{
cert: cert,
certPath: originCertPath,
} }
sc.userCredential = uc
} }
return sc.userCredential, nil return sc.userCredential, nil
} }
@ -175,13 +138,13 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
return nil, err return nil, err
} }
tunnelCredentials := connection.Credentials{ tunnelCredentials := connection.Credentials{
AccountTag: credential.cert.AccountID, AccountTag: credential.AccountID(),
TunnelSecret: tunnelSecret, TunnelSecret: tunnelSecret,
TunnelID: tunnel.ID, TunnelID: tunnel.ID,
} }
usedCertPath := false usedCertPath := false
if credentialsFilePath == "" { if credentialsFilePath == "" {
originCertDir := filepath.Dir(credential.certPath) originCertDir := filepath.Dir(credential.CertPath())
credentialsFilePath, err = tunnelFilePath(tunnelCredentials.TunnelID, originCertDir) credentialsFilePath, err = tunnelFilePath(tunnelCredentials.TunnelID, originCertDir)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -16,6 +16,7 @@ import (
"github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials"
) )
type mockFileSystem struct { type mockFileSystem struct {
@ -37,7 +38,7 @@ func Test_subcommandContext_findCredentials(t *testing.T) {
log *zerolog.Logger log *zerolog.Logger
fs fileSystem fs fileSystem
tunnelstoreClient cfapi.Client tunnelstoreClient cfapi.Client
userCredential *userCredential userCredential *credentials.User
} }
type args struct { type args struct {
tunnelID uuid.UUID tunnelID uuid.UUID
@ -249,7 +250,7 @@ func Test_subcommandContext_Delete(t *testing.T) {
isUIEnabled bool isUIEnabled bool
fs fileSystem fs fileSystem
tunnelstoreClient *deleteMockTunnelStore tunnelstoreClient *deleteMockTunnelStore
userCredential *userCredential userCredential *credentials.User
} }
type args struct { type args struct {
tunnelIDs []uuid.UUID tunnelIDs []uuid.UUID

View File

@ -39,8 +39,6 @@ var (
) )
const ( const (
DefaultCredentialFile = "cert.pem"
// BastionFlag is to enable bastion, or jump host, operation // BastionFlag is to enable bastion, or jump host, operation
BastionFlag = "bastion" BastionFlag = "bastion"
) )

View File

@ -0,0 +1,83 @@
package credentials
import (
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/cfapi"
)
const (
logFieldOriginCertPath = "originCertPath"
)
type User struct {
cert *OriginCert
certPath string
}
func (c User) AccountID() string {
return c.cert.AccountID
}
func (c User) ZoneID() string {
return c.cert.ZoneID
}
func (c User) APIToken() string {
return c.cert.APIToken
}
func (c User) CertPath() string {
return c.certPath
}
// Client uses the user credentials to create a Cloudflare API client
func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) {
if apiURL == "" {
return nil, errors.New("An api-url was not provided for the Cloudflare API client")
}
client, err := cfapi.NewRESTClient(
apiURL,
c.cert.AccountID,
c.cert.ZoneID,
c.cert.APIToken,
userAgent,
log,
)
if err != nil {
return nil, err
}
return client, nil
}
// Read will load and read the origin cert.pem to load the user credentials
func Read(originCertPath string, log *zerolog.Logger) (*User, error) {
originCertLog := log.With().
Str(logFieldOriginCertPath, originCertPath).
Logger()
originCertPath, err := FindOriginCert(originCertPath, &originCertLog)
if err != nil {
return nil, errors.Wrap(err, "Error locating origin cert")
}
blocks, err := readOriginCert(originCertPath)
if err != nil {
return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
}
cert, err := decodeOriginCert(blocks)
if err != nil {
return nil, errors.Wrap(err, "Error decoding origin cert")
}
if cert.AccountID == "" {
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
}
return &User{
cert: cert,
certPath: originCertPath,
}, nil
}

View File

@ -0,0 +1,38 @@
package credentials
import (
"io/fs"
"os"
"path"
"testing"
"github.com/stretchr/testify/require"
)
func TestCredentialsRead(t *testing.T) {
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
require.NoError(t, err)
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
os.WriteFile(certPath, file, fs.ModePerm)
user, err := Read(certPath, &nopLog)
require.NoError(t, err)
require.Equal(t, certPath, user.CertPath())
require.Equal(t, "test-service-key", user.APIToken())
require.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", user.ZoneID())
require.Equal(t, "abcdabcdabcdabcd1234567890abcdef", user.AccountID())
}
func TestCredentialsClient(t *testing.T) {
user := User{
certPath: "/tmp/cert.pem",
cert: &OriginCert{
ZoneID: "7b0a4d77dfb881c1a3b7d61ea9443e19",
AccountID: "abcdabcdabcdabcd1234567890abcdef",
APIToken: "test-service-key",
},
}
client, err := user.Client("example.com", "cloudflared/test", &nopLog)
require.NoError(t, err)
require.NotNil(t, client)
}

130
credentials/origin_cert.go Normal file
View File

@ -0,0 +1,130 @@
package credentials
import (
"encoding/json"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"github.com/mitchellh/go-homedir"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/config"
)
const (
DefaultCredentialFile = "cert.pem"
OriginCertFlag = "origincert"
)
type namedTunnelToken struct {
ZoneID string `json:"zoneID"`
AccountID string `json:"accountID"`
APIToken string `json:"apiToken"`
}
type OriginCert struct {
ZoneID string
APIToken string
AccountID string
}
// FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the
// DefaultConfigSearchDirectories contains a cert.pem file, return empty string
func FindDefaultOriginCertPath() string {
for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() {
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, DefaultCredentialFile))
if ok := fileExists(originCertPath); ok {
return originCertPath
}
}
return ""
}
func decodeOriginCert(blocks []byte) (*OriginCert, error) {
if len(blocks) == 0 {
return nil, fmt.Errorf("Cannot decode empty certificate")
}
originCert := OriginCert{}
block, rest := pem.Decode(blocks)
for {
if block == nil {
break
}
switch block.Type {
case "PRIVATE KEY", "CERTIFICATE":
// this is for legacy purposes.
break
case "ARGO TUNNEL TOKEN":
if originCert.ZoneID != "" || originCert.APIToken != "" {
return nil, fmt.Errorf("Found multiple tokens in the certificate")
}
// The token is a string,
// Try the newer JSON format
ntt := namedTunnelToken{}
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
originCert.ZoneID = ntt.ZoneID
originCert.APIToken = ntt.APIToken
originCert.AccountID = ntt.AccountID
}
default:
return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
}
block, rest = pem.Decode(rest)
}
if originCert.ZoneID == "" || originCert.APIToken == "" {
return nil, fmt.Errorf("Missing token in the certificate")
}
return &originCert, nil
}
func readOriginCert(originCertPath string) ([]byte, error) {
originCert, err := os.ReadFile(originCertPath)
if err != nil {
return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath)
}
return originCert, nil
}
// FindOriginCert will check to make sure that the certificate exists at the specified file path.
func FindOriginCert(originCertPath string, log *zerolog.Logger) (string, error) {
if originCertPath == "" {
log.Error().Msgf("Cannot determine default origin certificate path. No file %s in %v. You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable", DefaultCredentialFile, config.DefaultConfigSearchDirectories())
return "", fmt.Errorf("client didn't specify origincert path")
}
var err error
originCertPath, err = homedir.Expand(originCertPath)
if err != nil {
log.Err(err).Msgf("Cannot resolve origin certificate path")
return "", fmt.Errorf("cannot resolve path %s", originCertPath)
}
// Check that the user has acquired a certificate using the login command
ok := fileExists(originCertPath)
if !ok {
log.Error().Msgf(`Cannot find a valid certificate for your origin at the path:
%s
If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:
cloudflared login
`, originCertPath)
return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath)
}
return originCertPath, nil
}
// FileExists checks to see if a file exist at the provided path.
func fileExists(path string) bool {
fileStat, err := os.Stat(path)
if err != nil {
return false
}
return !fileStat.IsDir()
}

View File

@ -0,0 +1,110 @@
package credentials
import (
"fmt"
"io/fs"
"os"
"path"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
originCertFile = "cert.pem"
)
var (
nopLog = zerolog.Nop().With().Logger()
)
func TestLoadOriginCert(t *testing.T) {
cert, err := decodeOriginCert([]byte{})
assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
assert.Nil(t, cert)
blocks, err := os.ReadFile("test-cert-unknown-block.pem")
assert.NoError(t, err)
cert, err = decodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
blocks, err := os.ReadFile("test-cert-no-token.pem")
assert.NoError(t, err)
cert, err := decodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelToken(t *testing.T) {
// The given cert's Argo Tunnel Token was generated by base64 encoding this JSON:
// {
// "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19",
// "apiToken": "test-service-key",
// "accountID": "abcdabcdabcdabcd1234567890abcdef"
// }
CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem")
}
func CloudflareTunnelTokenTest(t *testing.T, path string) {
blocks, err := os.ReadFile(path)
assert.NoError(t, err)
cert, err := decodeOriginCert(blocks)
assert.NoError(t, err)
assert.NotNil(t, cert)
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
key := "test-service-key"
assert.Equal(t, key, cert.APIToken)
}
type mockFile struct {
path string
data []byte
err error
}
type mockFileSystem struct {
files map[string]mockFile
}
func newMockFileSystem(files ...mockFile) *mockFileSystem {
fs := mockFileSystem{map[string]mockFile{}}
for _, f := range files {
fs.files[f.path] = f
}
return &fs
}
func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) {
if f, ok := fs.files[path]; ok {
return f.data, f.err
}
return nil, os.ErrNotExist
}
func (fs *mockFileSystem) ValidFilePath(path string) bool {
_, exists := fs.files[path]
return exists
}
func TestFindOriginCert_Valid(t *testing.T) {
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
require.NoError(t, err)
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
os.WriteFile(certPath, file, fs.ModePerm)
path, err := FindOriginCert(certPath, &nopLog)
require.NoError(t, err)
require.Equal(t, certPath, path)
}
func TestFindOriginCert_Missing(t *testing.T) {
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
_, err := FindOriginCert(certPath, &nopLog)
require.Error(t, err)
}