TUN-7134: Acquire token for cloudflared tail

cloudflared tail will now fetch the management token from by making
a request to the Cloudflare API using the cert.pem (acquired from
cloudflared login).

Refactored some of the credentials code into it's own package as
to allow for easier use between subcommands outside of
`cloudflared tunnel`.
This commit is contained in:
Devin Carr 2023-04-12 09:43:38 -07:00
parent 8dc0697a8f
commit b89c092c1b
21 changed files with 497 additions and 250 deletions

View File

@ -1,58 +0,0 @@
package certutil
import (
"encoding/json"
"encoding/pem"
"fmt"
)
type namedTunnelToken struct {
ZoneID string `json:"zoneID"`
AccountID string `json:"accountID"`
APIToken string `json:"apiToken"`
}
type OriginCert struct {
ZoneID string
APIToken string
AccountID string
}
func DecodeOriginCert(blocks []byte) (*OriginCert, error) {
if len(blocks) == 0 {
return nil, fmt.Errorf("Cannot decode empty certificate")
}
originCert := OriginCert{}
block, rest := pem.Decode(blocks)
for {
if block == nil {
break
}
switch block.Type {
case "PRIVATE KEY", "CERTIFICATE":
// this is for legacy purposes.
break
case "ARGO TUNNEL TOKEN":
if originCert.ZoneID != "" || originCert.APIToken != "" {
return nil, fmt.Errorf("Found multiple tokens in the certificate")
}
// The token is a string,
// Try the newer JSON format
ntt := namedTunnelToken{}
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
originCert.ZoneID = ntt.ZoneID
originCert.APIToken = ntt.APIToken
originCert.AccountID = ntt.AccountID
}
default:
return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
}
block, rest = pem.Decode(rest)
}
if originCert.ZoneID == "" || originCert.APIToken == "" {
return nil, fmt.Errorf("Missing token in the certificate")
}
return &originCert, nil
}

View File

@ -1,51 +0,0 @@
package certutil
import (
"fmt"
"io/ioutil"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLoadOriginCert(t *testing.T) {
cert, err := DecodeOriginCert([]byte{})
assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
assert.Nil(t, cert)
blocks, err := ioutil.ReadFile("test-cert-unknown-block.pem")
assert.Nil(t, err)
cert, err = DecodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
cert, err := DecodeOriginCert([]byte{})
blocks, err := ioutil.ReadFile("test-cert-no-token.pem")
assert.Nil(t, err)
cert, err = DecodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelToken(t *testing.T) {
// The given cert's Argo Tunnel Token was generated by base64 encoding this JSON:
// {
// "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19",
// "apiToken": "test-service-key",
// "accountID": "abcdabcdabcdabcd1234567890abcdef"
// }
CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem")
}
func CloudflareTunnelTokenTest(t *testing.T, path string) {
blocks, err := ioutil.ReadFile(path)
assert.Nil(t, err)
cert, err := DecodeOriginCert(blocks)
assert.Nil(t, err)
assert.NotNil(t, cert)
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
key := "test-service-key"
assert.Equal(t, key, cert.APIToken)
}

View File

@ -8,6 +8,7 @@ type TunnelClient interface {
CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error)
GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
GetTunnelToken(tunnelID uuid.UUID) (string, error)
GetManagementToken(tunnelID uuid.UUID) (string, error)
DeleteTunnel(tunnelID uuid.UUID) error
ListTunnels(filter *TunnelFilter) ([]*Tunnel, error)
ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error)

View File

@ -50,6 +50,10 @@ type newTunnel struct {
TunnelSecret []byte `json:"tunnel_secret"`
}
type managementRequest struct {
Resources []string `json:"resources"`
}
type CleanupParams struct {
queryParams url.Values
}
@ -133,6 +137,28 @@ func (r *RESTClient) GetTunnelToken(tunnelID uuid.UUID) (token string, err error
return "", r.statusCodeToError("get tunnel token", resp)
}
func (r *RESTClient) GetManagementToken(tunnelID uuid.UUID) (token string, err error) {
endpoint := r.baseEndpoints.accountLevel
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/management", tunnelID))
body := &managementRequest{
Resources: []string{"logs"},
}
resp, err := r.sendRequest("POST", endpoint, body)
if err != nil {
return "", errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
err = parseResponse(resp.Body, &token)
return token, err
}
return "", r.statusCodeToError("get tunnel token", resp)
}
func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
endpoint := r.baseEndpoints.accountLevel
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))

View File

@ -47,3 +47,7 @@ func (bi *BuildInfo) GetBuildTypeMsg() string {
}
return fmt.Sprintf(" with %s", bi.BuildType)
}
func (bi *BuildInfo) UserAgent() string {
return fmt.Sprintf("cloudflared/%s", bi.CloudflaredVersion)
}

View File

@ -90,7 +90,7 @@ func main() {
updater.Init(Version)
tracing.Init(Version)
token.Init(Version)
tail.Init(Version)
tail.Init(bInfo)
runApp(app, graceShutdownC)
}

View File

@ -2,6 +2,7 @@ package tail
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
@ -10,28 +11,32 @@ import (
"syscall"
"time"
"github.com/google/uuid"
"github.com/mattn/go-colorable"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"nhooyr.io/websocket"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/management"
)
var (
version string
buildInfo *cliutil.BuildInfo
)
func Init(v string) {
version = v
func Init(bi *cliutil.BuildInfo) {
buildInfo = bi
}
func Command() *cli.Command {
return &cli.Command{
Name: "tail",
Action: Run,
Usage: "Stream logs from a remote cloudflared",
Name: "tail",
Action: Run,
Usage: "Stream logs from a remote cloudflared",
UsageText: "cloudflared tail [tail command options] [TUNNEL-ID]",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "connector-id",
@ -75,6 +80,12 @@ func Command() *cli.Command {
Usage: "Application logging level {debug, info, warn, error, fatal}",
EnvVars: []string{"TUNNEL_LOGLEVEL"},
},
&cli.StringFlag{
Name: credentials.OriginCertFlag,
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: credentials.FindDefaultOriginCertPath(),
},
},
}
}
@ -159,6 +170,59 @@ func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
}, nil
}
// getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel.
func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) {
userCreds, err := credentials.Read(c.String(credentials.OriginCertFlag), log)
if err != nil {
return "", err
}
client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log)
if err != nil {
return "", err
}
tunnelIDString := c.Args().First()
if tunnelIDString == "" {
return "", errors.New("no tunnel ID provided")
}
tunnelID, err := uuid.Parse(tunnelIDString)
if err != nil {
return "", errors.New("unable to parse provided tunnel id as a valid UUID")
}
token, err := client.GetManagementToken(tunnelID)
if err != nil {
return "", err
}
return token, nil
}
// buildURL will build the management url to contain the required query parameters to authenticate the request.
func buildURL(c *cli.Context, log *zerolog.Logger) (url.URL, error) {
var err error
managementHostname := c.String("management-hostname")
token := c.String("token")
if token == "" {
token, err = getManagementToken(c, log)
if err != nil {
return url.URL{}, fmt.Errorf("unable to acquire management token for requested tunnel id: %w", err)
}
}
query := url.Values{}
query.Add("access_token", token)
connector := c.String("connector-id")
if connector != "" {
connectorID, err := uuid.Parse(connector)
if err != nil {
return url.URL{}, fmt.Errorf("unabled to parse 'connector-id' flag into a valid UUID: %w", err)
}
query.Add("connector_id", connectorID.String())
}
return url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: query.Encode()}, nil
}
// Run implements a foreground runner
func Run(c *cli.Context) error {
log := createLogger(c)
@ -173,12 +237,14 @@ func Run(c *cli.Context) error {
return nil
}
managementHostname := c.String("management-hostname")
token := c.String("token")
u := url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: "access_token=" + token}
u, err := buildURL(c, log)
if err != nil {
log.Err(err).Msg("unable to construct management request URL")
return nil
}
header := make(http.Header)
header.Add("User-Agent", "cloudflared/"+version)
header.Add("User-Agent", buildInfo.UserAgent())
trace := c.String("trace")
if trace != "" {
header["cf-trace-id"] = []string{trace}
@ -206,6 +272,11 @@ func Run(c *cli.Context) error {
log.Error().Err(err).Msg("unable to request logs from management tunnel")
return nil
}
log.Debug().
Str("tunnel-id", c.Args().First()).
Str("connector-id", c.String("connector-id")).
Interface("filters", filters).
Msg("connected")
readerDone := make(chan struct{})

View File

@ -28,6 +28,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/features"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger"
@ -751,10 +752,10 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
},
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origincert",
Name: credentials.OriginCertFlag,
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: findDefaultOriginCertPath(),
Value: credentials.FindDefaultOriginCertPath(),
Hidden: shouldHide,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{

View File

@ -3,17 +3,14 @@ package tunnel
import (
"crypto/tls"
"fmt"
"io/ioutil"
mathRand "math/rand"
"net"
"net/netip"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
@ -33,7 +30,6 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const LogFieldOriginCertPath = "originCertPath"
const secretValue = "*****"
var (
@ -46,18 +42,6 @@ var (
configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
)
// returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories
// contains a cert.pem file, return empty string
func findDefaultOriginCertPath() string {
for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() {
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, config.DefaultCredentialFile))
if ok, _ := config.FileExists(originCertPath); ok {
return originCertPath
}
}
return ""
}
func generateRandomClientID(log *zerolog.Logger) (string, error) {
u, err := uuid.NewRandom()
if err != nil {
@ -128,62 +112,6 @@ func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelPrope
namedTunnel != nil) // named tunnel
}
func findOriginCert(originCertPath string, log *zerolog.Logger) (string, error) {
if originCertPath == "" {
log.Info().Msgf("Cannot determine default origin certificate path. No file %s in %v", config.DefaultCredentialFile, config.DefaultConfigSearchDirectories())
if isRunningFromTerminal() {
log.Error().Msgf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl)
return "", fmt.Errorf("client didn't specify origincert path when running from terminal")
} else {
log.Error().Msgf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl)
return "", fmt.Errorf("client didn't specify origincert path")
}
}
var err error
originCertPath, err = homedir.Expand(originCertPath)
if err != nil {
log.Err(err).Msgf("Cannot resolve origin certificate path")
return "", fmt.Errorf("cannot resolve path %s", originCertPath)
}
// Check that the user has acquired a certificate using the login command
ok, err := config.FileExists(originCertPath)
if err != nil {
log.Error().Err(err).Msgf("Cannot check if origin cert exists at path %s", originCertPath)
return "", fmt.Errorf("cannot check if origin cert exists at path %s", originCertPath)
}
if !ok {
log.Error().Msgf(`Cannot find a valid certificate for your origin at the path:
%s
If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:
%s login
`, originCertPath, os.Args[0])
return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath)
}
return originCertPath, nil
}
func readOriginCert(originCertPath string) ([]byte, error) {
// Easier to send the certificate as []byte via RPC than decoding it at this point
originCert, err := ioutil.ReadFile(originCertPath)
if err != nil {
return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath)
}
return originCert, nil
}
func getOriginCert(originCertPath string, log *zerolog.Logger) ([]byte, error) {
if originCertPath, err := findOriginCert(originCertPath, log); err != nil {
return nil, err
} else {
return readOriginCert(originCertPath)
}
}
func prepareTunnelConfig(
c *cli.Context,
info *cliutil.BuildInfo,

View File

@ -5,6 +5,7 @@ import (
"path/filepath"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/credentials"
"github.com/google/uuid"
"github.com/rs/zerolog"
@ -56,13 +57,13 @@ func newSearchByID(id uuid.UUID, c *cli.Context, log *zerolog.Logger, fs fileSys
}
func (s searchByID) Path() (string, error) {
originCertPath := s.c.String("origincert")
originCertPath := s.c.String(credentials.OriginCertFlag)
originCertLog := s.log.With().
Str(LogFieldOriginCertPath, originCertPath).
Str("originCertPath", originCertPath).
Logger()
// Fallback to look for tunnel credentials in the origin cert directory
if originCertPath, err := findOriginCert(originCertPath, &originCertLog); err == nil {
if originCertPath, err := credentials.FindOriginCert(originCertPath, &originCertLog); err == nil {
originCertDir := filepath.Dir(originCertPath)
if filePath, err := tunnelFilePath(s.id, originCertDir); err == nil {
if s.fs.validFilePath(filePath) {

View File

@ -14,6 +14,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/token"
)
@ -85,7 +86,7 @@ func checkForExistingCert() (string, bool, error) {
if err != nil {
return "", false, err
}
path := filepath.Join(configPath, config.DefaultCredentialFile)
path := filepath.Join(configPath, credentials.DefaultCredentialFile)
fileInfo, err := os.Stat(path)
if err == nil && fileInfo.Size() > 0 {
return path, true, nil

View File

@ -13,9 +13,9 @@ import (
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/certutil"
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger"
)
@ -37,7 +37,7 @@ type subcommandContext struct {
// These fields should be accessed using their respective Getter
tunnelstoreClient cfapi.Client
userCredential *userCredential
userCredential *credentials.User
}
func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
@ -56,65 +56,28 @@ func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder {
return newSearchByID(tunnelID, sc.c, sc.log, sc.fs)
}
type userCredential struct {
cert *certutil.OriginCert
certPath string
}
func (sc *subcommandContext) client() (cfapi.Client, error) {
if sc.tunnelstoreClient != nil {
return sc.tunnelstoreClient, nil
}
credential, err := sc.credential()
cred, err := sc.credential()
if err != nil {
return nil, err
}
userAgent := fmt.Sprintf("cloudflared/%s", buildInfo.Version())
client, err := cfapi.NewRESTClient(
sc.c.String("api-url"),
credential.cert.AccountID,
credential.cert.ZoneID,
credential.cert.APIToken,
userAgent,
sc.log,
)
sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log)
if err != nil {
return nil, err
}
sc.tunnelstoreClient = client
return client, nil
return sc.tunnelstoreClient, nil
}
func (sc *subcommandContext) credential() (*userCredential, error) {
func (sc *subcommandContext) credential() (*credentials.User, error) {
if sc.userCredential == nil {
originCertPath := sc.c.String("origincert")
originCertLog := sc.log.With().
Str(LogFieldOriginCertPath, originCertPath).
Logger()
originCertPath, err := findOriginCert(originCertPath, &originCertLog)
uc, err := credentials.Read(sc.c.String(credentials.OriginCertFlag), sc.log)
if err != nil {
return nil, errors.Wrap(err, "Error locating origin cert")
}
blocks, err := readOriginCert(originCertPath)
if err != nil {
return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
}
cert, err := certutil.DecodeOriginCert(blocks)
if err != nil {
return nil, errors.Wrap(err, "Error decoding origin cert")
}
if cert.AccountID == "" {
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
}
sc.userCredential = &userCredential{
cert: cert,
certPath: originCertPath,
return nil, err
}
sc.userCredential = uc
}
return sc.userCredential, nil
}
@ -175,13 +138,13 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
return nil, err
}
tunnelCredentials := connection.Credentials{
AccountTag: credential.cert.AccountID,
AccountTag: credential.AccountID(),
TunnelSecret: tunnelSecret,
TunnelID: tunnel.ID,
}
usedCertPath := false
if credentialsFilePath == "" {
originCertDir := filepath.Dir(credential.certPath)
originCertDir := filepath.Dir(credential.CertPath())
credentialsFilePath, err = tunnelFilePath(tunnelCredentials.TunnelID, originCertDir)
if err != nil {
return nil, err

View File

@ -16,6 +16,7 @@ import (
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials"
)
type mockFileSystem struct {
@ -37,7 +38,7 @@ func Test_subcommandContext_findCredentials(t *testing.T) {
log *zerolog.Logger
fs fileSystem
tunnelstoreClient cfapi.Client
userCredential *userCredential
userCredential *credentials.User
}
type args struct {
tunnelID uuid.UUID
@ -249,7 +250,7 @@ func Test_subcommandContext_Delete(t *testing.T) {
isUIEnabled bool
fs fileSystem
tunnelstoreClient *deleteMockTunnelStore
userCredential *userCredential
userCredential *credentials.User
}
type args struct {
tunnelIDs []uuid.UUID

View File

@ -39,8 +39,6 @@ var (
)
const (
DefaultCredentialFile = "cert.pem"
// BastionFlag is to enable bastion, or jump host, operation
BastionFlag = "bastion"
)

View File

@ -0,0 +1,83 @@
package credentials
import (
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/cfapi"
)
const (
logFieldOriginCertPath = "originCertPath"
)
type User struct {
cert *OriginCert
certPath string
}
func (c User) AccountID() string {
return c.cert.AccountID
}
func (c User) ZoneID() string {
return c.cert.ZoneID
}
func (c User) APIToken() string {
return c.cert.APIToken
}
func (c User) CertPath() string {
return c.certPath
}
// Client uses the user credentials to create a Cloudflare API client
func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) {
if apiURL == "" {
return nil, errors.New("An api-url was not provided for the Cloudflare API client")
}
client, err := cfapi.NewRESTClient(
apiURL,
c.cert.AccountID,
c.cert.ZoneID,
c.cert.APIToken,
userAgent,
log,
)
if err != nil {
return nil, err
}
return client, nil
}
// Read will load and read the origin cert.pem to load the user credentials
func Read(originCertPath string, log *zerolog.Logger) (*User, error) {
originCertLog := log.With().
Str(logFieldOriginCertPath, originCertPath).
Logger()
originCertPath, err := FindOriginCert(originCertPath, &originCertLog)
if err != nil {
return nil, errors.Wrap(err, "Error locating origin cert")
}
blocks, err := readOriginCert(originCertPath)
if err != nil {
return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
}
cert, err := decodeOriginCert(blocks)
if err != nil {
return nil, errors.Wrap(err, "Error decoding origin cert")
}
if cert.AccountID == "" {
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
}
return &User{
cert: cert,
certPath: originCertPath,
}, nil
}

View File

@ -0,0 +1,38 @@
package credentials
import (
"io/fs"
"os"
"path"
"testing"
"github.com/stretchr/testify/require"
)
func TestCredentialsRead(t *testing.T) {
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
require.NoError(t, err)
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
os.WriteFile(certPath, file, fs.ModePerm)
user, err := Read(certPath, &nopLog)
require.NoError(t, err)
require.Equal(t, certPath, user.CertPath())
require.Equal(t, "test-service-key", user.APIToken())
require.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", user.ZoneID())
require.Equal(t, "abcdabcdabcdabcd1234567890abcdef", user.AccountID())
}
func TestCredentialsClient(t *testing.T) {
user := User{
certPath: "/tmp/cert.pem",
cert: &OriginCert{
ZoneID: "7b0a4d77dfb881c1a3b7d61ea9443e19",
AccountID: "abcdabcdabcdabcd1234567890abcdef",
APIToken: "test-service-key",
},
}
client, err := user.Client("example.com", "cloudflared/test", &nopLog)
require.NoError(t, err)
require.NotNil(t, client)
}

130
credentials/origin_cert.go Normal file
View File

@ -0,0 +1,130 @@
package credentials
import (
"encoding/json"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"github.com/mitchellh/go-homedir"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/config"
)
const (
DefaultCredentialFile = "cert.pem"
OriginCertFlag = "origincert"
)
type namedTunnelToken struct {
ZoneID string `json:"zoneID"`
AccountID string `json:"accountID"`
APIToken string `json:"apiToken"`
}
type OriginCert struct {
ZoneID string
APIToken string
AccountID string
}
// FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the
// DefaultConfigSearchDirectories contains a cert.pem file, return empty string
func FindDefaultOriginCertPath() string {
for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() {
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, DefaultCredentialFile))
if ok := fileExists(originCertPath); ok {
return originCertPath
}
}
return ""
}
func decodeOriginCert(blocks []byte) (*OriginCert, error) {
if len(blocks) == 0 {
return nil, fmt.Errorf("Cannot decode empty certificate")
}
originCert := OriginCert{}
block, rest := pem.Decode(blocks)
for {
if block == nil {
break
}
switch block.Type {
case "PRIVATE KEY", "CERTIFICATE":
// this is for legacy purposes.
break
case "ARGO TUNNEL TOKEN":
if originCert.ZoneID != "" || originCert.APIToken != "" {
return nil, fmt.Errorf("Found multiple tokens in the certificate")
}
// The token is a string,
// Try the newer JSON format
ntt := namedTunnelToken{}
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
originCert.ZoneID = ntt.ZoneID
originCert.APIToken = ntt.APIToken
originCert.AccountID = ntt.AccountID
}
default:
return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
}
block, rest = pem.Decode(rest)
}
if originCert.ZoneID == "" || originCert.APIToken == "" {
return nil, fmt.Errorf("Missing token in the certificate")
}
return &originCert, nil
}
func readOriginCert(originCertPath string) ([]byte, error) {
originCert, err := os.ReadFile(originCertPath)
if err != nil {
return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath)
}
return originCert, nil
}
// FindOriginCert will check to make sure that the certificate exists at the specified file path.
func FindOriginCert(originCertPath string, log *zerolog.Logger) (string, error) {
if originCertPath == "" {
log.Error().Msgf("Cannot determine default origin certificate path. No file %s in %v. You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable", DefaultCredentialFile, config.DefaultConfigSearchDirectories())
return "", fmt.Errorf("client didn't specify origincert path")
}
var err error
originCertPath, err = homedir.Expand(originCertPath)
if err != nil {
log.Err(err).Msgf("Cannot resolve origin certificate path")
return "", fmt.Errorf("cannot resolve path %s", originCertPath)
}
// Check that the user has acquired a certificate using the login command
ok := fileExists(originCertPath)
if !ok {
log.Error().Msgf(`Cannot find a valid certificate for your origin at the path:
%s
If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:
cloudflared login
`, originCertPath)
return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath)
}
return originCertPath, nil
}
// FileExists checks to see if a file exist at the provided path.
func fileExists(path string) bool {
fileStat, err := os.Stat(path)
if err != nil {
return false
}
return !fileStat.IsDir()
}

View File

@ -0,0 +1,110 @@
package credentials
import (
"fmt"
"io/fs"
"os"
"path"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
originCertFile = "cert.pem"
)
var (
nopLog = zerolog.Nop().With().Logger()
)
func TestLoadOriginCert(t *testing.T) {
cert, err := decodeOriginCert([]byte{})
assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
assert.Nil(t, cert)
blocks, err := os.ReadFile("test-cert-unknown-block.pem")
assert.NoError(t, err)
cert, err = decodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
blocks, err := os.ReadFile("test-cert-no-token.pem")
assert.NoError(t, err)
cert, err := decodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelToken(t *testing.T) {
// The given cert's Argo Tunnel Token was generated by base64 encoding this JSON:
// {
// "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19",
// "apiToken": "test-service-key",
// "accountID": "abcdabcdabcdabcd1234567890abcdef"
// }
CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem")
}
func CloudflareTunnelTokenTest(t *testing.T, path string) {
blocks, err := os.ReadFile(path)
assert.NoError(t, err)
cert, err := decodeOriginCert(blocks)
assert.NoError(t, err)
assert.NotNil(t, cert)
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
key := "test-service-key"
assert.Equal(t, key, cert.APIToken)
}
type mockFile struct {
path string
data []byte
err error
}
type mockFileSystem struct {
files map[string]mockFile
}
func newMockFileSystem(files ...mockFile) *mockFileSystem {
fs := mockFileSystem{map[string]mockFile{}}
for _, f := range files {
fs.files[f.path] = f
}
return &fs
}
func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) {
if f, ok := fs.files[path]; ok {
return f.data, f.err
}
return nil, os.ErrNotExist
}
func (fs *mockFileSystem) ValidFilePath(path string) bool {
_, exists := fs.files[path]
return exists
}
func TestFindOriginCert_Valid(t *testing.T) {
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
require.NoError(t, err)
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
os.WriteFile(certPath, file, fs.ModePerm)
path, err := FindOriginCert(certPath, &nopLog)
require.NoError(t, err)
require.Equal(t, certPath, path)
}
func TestFindOriginCert_Missing(t *testing.T) {
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
_, err := FindOriginCert(certPath, &nopLog)
require.Error(t, err)
}