TUN-7134: Acquire token for cloudflared tail
cloudflared tail will now fetch the management token from by making a request to the Cloudflare API using the cert.pem (acquired from cloudflared login). Refactored some of the credentials code into it's own package as to allow for easier use between subcommands outside of `cloudflared tunnel`.
This commit is contained in:
parent
8dc0697a8f
commit
b89c092c1b
|
@ -1,58 +0,0 @@
|
||||||
package certutil
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type namedTunnelToken struct {
|
|
||||||
ZoneID string `json:"zoneID"`
|
|
||||||
AccountID string `json:"accountID"`
|
|
||||||
APIToken string `json:"apiToken"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OriginCert struct {
|
|
||||||
ZoneID string
|
|
||||||
APIToken string
|
|
||||||
AccountID string
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeOriginCert(blocks []byte) (*OriginCert, error) {
|
|
||||||
if len(blocks) == 0 {
|
|
||||||
return nil, fmt.Errorf("Cannot decode empty certificate")
|
|
||||||
}
|
|
||||||
originCert := OriginCert{}
|
|
||||||
block, rest := pem.Decode(blocks)
|
|
||||||
for {
|
|
||||||
if block == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
switch block.Type {
|
|
||||||
case "PRIVATE KEY", "CERTIFICATE":
|
|
||||||
// this is for legacy purposes.
|
|
||||||
break
|
|
||||||
case "ARGO TUNNEL TOKEN":
|
|
||||||
if originCert.ZoneID != "" || originCert.APIToken != "" {
|
|
||||||
return nil, fmt.Errorf("Found multiple tokens in the certificate")
|
|
||||||
}
|
|
||||||
// The token is a string,
|
|
||||||
// Try the newer JSON format
|
|
||||||
ntt := namedTunnelToken{}
|
|
||||||
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
|
|
||||||
originCert.ZoneID = ntt.ZoneID
|
|
||||||
originCert.APIToken = ntt.APIToken
|
|
||||||
originCert.AccountID = ntt.AccountID
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
|
|
||||||
}
|
|
||||||
block, rest = pem.Decode(rest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if originCert.ZoneID == "" || originCert.APIToken == "" {
|
|
||||||
return nil, fmt.Errorf("Missing token in the certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &originCert, nil
|
|
||||||
}
|
|
|
@ -1,51 +0,0 @@
|
||||||
package certutil
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLoadOriginCert(t *testing.T) {
|
|
||||||
cert, err := DecodeOriginCert([]byte{})
|
|
||||||
assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
|
|
||||||
assert.Nil(t, cert)
|
|
||||||
|
|
||||||
blocks, err := ioutil.ReadFile("test-cert-unknown-block.pem")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
cert, err = DecodeOriginCert(blocks)
|
|
||||||
assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
|
|
||||||
assert.Nil(t, cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
|
|
||||||
cert, err := DecodeOriginCert([]byte{})
|
|
||||||
blocks, err := ioutil.ReadFile("test-cert-no-token.pem")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
cert, err = DecodeOriginCert(blocks)
|
|
||||||
assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
|
|
||||||
assert.Nil(t, cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJSONArgoTunnelToken(t *testing.T) {
|
|
||||||
// The given cert's Argo Tunnel Token was generated by base64 encoding this JSON:
|
|
||||||
// {
|
|
||||||
// "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19",
|
|
||||||
// "apiToken": "test-service-key",
|
|
||||||
// "accountID": "abcdabcdabcdabcd1234567890abcdef"
|
|
||||||
// }
|
|
||||||
CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem")
|
|
||||||
}
|
|
||||||
|
|
||||||
func CloudflareTunnelTokenTest(t *testing.T, path string) {
|
|
||||||
blocks, err := ioutil.ReadFile(path)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
cert, err := DecodeOriginCert(blocks)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.NotNil(t, cert)
|
|
||||||
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
|
|
||||||
key := "test-service-key"
|
|
||||||
assert.Equal(t, key, cert.APIToken)
|
|
||||||
}
|
|
|
@ -8,6 +8,7 @@ type TunnelClient interface {
|
||||||
CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error)
|
CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error)
|
||||||
GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
|
GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
|
||||||
GetTunnelToken(tunnelID uuid.UUID) (string, error)
|
GetTunnelToken(tunnelID uuid.UUID) (string, error)
|
||||||
|
GetManagementToken(tunnelID uuid.UUID) (string, error)
|
||||||
DeleteTunnel(tunnelID uuid.UUID) error
|
DeleteTunnel(tunnelID uuid.UUID) error
|
||||||
ListTunnels(filter *TunnelFilter) ([]*Tunnel, error)
|
ListTunnels(filter *TunnelFilter) ([]*Tunnel, error)
|
||||||
ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error)
|
ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error)
|
||||||
|
|
|
@ -50,6 +50,10 @@ type newTunnel struct {
|
||||||
TunnelSecret []byte `json:"tunnel_secret"`
|
TunnelSecret []byte `json:"tunnel_secret"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type managementRequest struct {
|
||||||
|
Resources []string `json:"resources"`
|
||||||
|
}
|
||||||
|
|
||||||
type CleanupParams struct {
|
type CleanupParams struct {
|
||||||
queryParams url.Values
|
queryParams url.Values
|
||||||
}
|
}
|
||||||
|
@ -133,6 +137,28 @@ func (r *RESTClient) GetTunnelToken(tunnelID uuid.UUID) (token string, err error
|
||||||
return "", r.statusCodeToError("get tunnel token", resp)
|
return "", r.statusCodeToError("get tunnel token", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *RESTClient) GetManagementToken(tunnelID uuid.UUID) (token string, err error) {
|
||||||
|
endpoint := r.baseEndpoints.accountLevel
|
||||||
|
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/management", tunnelID))
|
||||||
|
|
||||||
|
body := &managementRequest{
|
||||||
|
Resources: []string{"logs"},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := r.sendRequest("POST", endpoint, body)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.Wrap(err, "REST request failed")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
err = parseResponse(resp.Body, &token)
|
||||||
|
return token, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", r.statusCodeToError("get tunnel token", resp)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
|
func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
|
||||||
endpoint := r.baseEndpoints.accountLevel
|
endpoint := r.baseEndpoints.accountLevel
|
||||||
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
|
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
|
||||||
|
|
|
@ -47,3 +47,7 @@ func (bi *BuildInfo) GetBuildTypeMsg() string {
|
||||||
}
|
}
|
||||||
return fmt.Sprintf(" with %s", bi.BuildType)
|
return fmt.Sprintf(" with %s", bi.BuildType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bi *BuildInfo) UserAgent() string {
|
||||||
|
return fmt.Sprintf("cloudflared/%s", bi.CloudflaredVersion)
|
||||||
|
}
|
||||||
|
|
|
@ -90,7 +90,7 @@ func main() {
|
||||||
updater.Init(Version)
|
updater.Init(Version)
|
||||||
tracing.Init(Version)
|
tracing.Init(Version)
|
||||||
token.Init(Version)
|
token.Init(Version)
|
||||||
tail.Init(Version)
|
tail.Init(bInfo)
|
||||||
runApp(app, graceShutdownC)
|
runApp(app, graceShutdownC)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package tail
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -10,28 +11,32 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/mattn/go-colorable"
|
"github.com/mattn/go-colorable"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"nhooyr.io/websocket"
|
"nhooyr.io/websocket"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||||
|
"github.com/cloudflare/cloudflared/credentials"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/management"
|
"github.com/cloudflare/cloudflared/management"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
version string
|
buildInfo *cliutil.BuildInfo
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init(v string) {
|
func Init(bi *cliutil.BuildInfo) {
|
||||||
version = v
|
buildInfo = bi
|
||||||
}
|
}
|
||||||
|
|
||||||
func Command() *cli.Command {
|
func Command() *cli.Command {
|
||||||
return &cli.Command{
|
return &cli.Command{
|
||||||
Name: "tail",
|
Name: "tail",
|
||||||
Action: Run,
|
Action: Run,
|
||||||
Usage: "Stream logs from a remote cloudflared",
|
Usage: "Stream logs from a remote cloudflared",
|
||||||
|
UsageText: "cloudflared tail [tail command options] [TUNNEL-ID]",
|
||||||
Flags: []cli.Flag{
|
Flags: []cli.Flag{
|
||||||
&cli.StringFlag{
|
&cli.StringFlag{
|
||||||
Name: "connector-id",
|
Name: "connector-id",
|
||||||
|
@ -75,6 +80,12 @@ func Command() *cli.Command {
|
||||||
Usage: "Application logging level {debug, info, warn, error, fatal}",
|
Usage: "Application logging level {debug, info, warn, error, fatal}",
|
||||||
EnvVars: []string{"TUNNEL_LOGLEVEL"},
|
EnvVars: []string{"TUNNEL_LOGLEVEL"},
|
||||||
},
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: credentials.OriginCertFlag,
|
||||||
|
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
|
||||||
|
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
|
||||||
|
Value: credentials.FindDefaultOriginCertPath(),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -159,6 +170,59 @@ func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel.
|
||||||
|
func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) {
|
||||||
|
userCreds, err := credentials.Read(c.String(credentials.OriginCertFlag), log)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelIDString := c.Args().First()
|
||||||
|
if tunnelIDString == "" {
|
||||||
|
return "", errors.New("no tunnel ID provided")
|
||||||
|
}
|
||||||
|
tunnelID, err := uuid.Parse(tunnelIDString)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.New("unable to parse provided tunnel id as a valid UUID")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := client.GetManagementToken(tunnelID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildURL will build the management url to contain the required query parameters to authenticate the request.
|
||||||
|
func buildURL(c *cli.Context, log *zerolog.Logger) (url.URL, error) {
|
||||||
|
var err error
|
||||||
|
managementHostname := c.String("management-hostname")
|
||||||
|
token := c.String("token")
|
||||||
|
if token == "" {
|
||||||
|
token, err = getManagementToken(c, log)
|
||||||
|
if err != nil {
|
||||||
|
return url.URL{}, fmt.Errorf("unable to acquire management token for requested tunnel id: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
query := url.Values{}
|
||||||
|
query.Add("access_token", token)
|
||||||
|
connector := c.String("connector-id")
|
||||||
|
if connector != "" {
|
||||||
|
connectorID, err := uuid.Parse(connector)
|
||||||
|
if err != nil {
|
||||||
|
return url.URL{}, fmt.Errorf("unabled to parse 'connector-id' flag into a valid UUID: %w", err)
|
||||||
|
}
|
||||||
|
query.Add("connector_id", connectorID.String())
|
||||||
|
}
|
||||||
|
return url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: query.Encode()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Run implements a foreground runner
|
// Run implements a foreground runner
|
||||||
func Run(c *cli.Context) error {
|
func Run(c *cli.Context) error {
|
||||||
log := createLogger(c)
|
log := createLogger(c)
|
||||||
|
@ -173,12 +237,14 @@ func Run(c *cli.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
managementHostname := c.String("management-hostname")
|
u, err := buildURL(c, log)
|
||||||
token := c.String("token")
|
if err != nil {
|
||||||
u := url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: "access_token=" + token}
|
log.Err(err).Msg("unable to construct management request URL")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
header := make(http.Header)
|
header := make(http.Header)
|
||||||
header.Add("User-Agent", "cloudflared/"+version)
|
header.Add("User-Agent", buildInfo.UserAgent())
|
||||||
trace := c.String("trace")
|
trace := c.String("trace")
|
||||||
if trace != "" {
|
if trace != "" {
|
||||||
header["cf-trace-id"] = []string{trace}
|
header["cf-trace-id"] = []string{trace}
|
||||||
|
@ -206,6 +272,11 @@ func Run(c *cli.Context) error {
|
||||||
log.Error().Err(err).Msg("unable to request logs from management tunnel")
|
log.Error().Err(err).Msg("unable to request logs from management tunnel")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
log.Debug().
|
||||||
|
Str("tunnel-id", c.Args().First()).
|
||||||
|
Str("connector-id", c.String("connector-id")).
|
||||||
|
Interface("filters", filters).
|
||||||
|
Msg("connected")
|
||||||
|
|
||||||
readerDone := make(chan struct{})
|
readerDone := make(chan struct{})
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/credentials"
|
||||||
"github.com/cloudflare/cloudflared/features"
|
"github.com/cloudflare/cloudflared/features"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
@ -751,10 +752,10 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
},
|
},
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "origincert",
|
Name: credentials.OriginCertFlag,
|
||||||
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
|
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
|
||||||
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
|
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
|
||||||
Value: findDefaultOriginCertPath(),
|
Value: credentials.FindDefaultOriginCertPath(),
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
}),
|
}),
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
|
|
@ -3,17 +3,14 @@ package tunnel
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
mathRand "math/rand"
|
mathRand "math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
homedir "github.com/mitchellh/go-homedir"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
@ -33,7 +30,6 @@ import (
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
const LogFieldOriginCertPath = "originCertPath"
|
|
||||||
const secretValue = "*****"
|
const secretValue = "*****"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -46,18 +42,6 @@ var (
|
||||||
configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
|
configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories
|
|
||||||
// contains a cert.pem file, return empty string
|
|
||||||
func findDefaultOriginCertPath() string {
|
|
||||||
for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() {
|
|
||||||
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, config.DefaultCredentialFile))
|
|
||||||
if ok, _ := config.FileExists(originCertPath); ok {
|
|
||||||
return originCertPath
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateRandomClientID(log *zerolog.Logger) (string, error) {
|
func generateRandomClientID(log *zerolog.Logger) (string, error) {
|
||||||
u, err := uuid.NewRandom()
|
u, err := uuid.NewRandom()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -128,62 +112,6 @@ func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelPrope
|
||||||
namedTunnel != nil) // named tunnel
|
namedTunnel != nil) // named tunnel
|
||||||
}
|
}
|
||||||
|
|
||||||
func findOriginCert(originCertPath string, log *zerolog.Logger) (string, error) {
|
|
||||||
if originCertPath == "" {
|
|
||||||
log.Info().Msgf("Cannot determine default origin certificate path. No file %s in %v", config.DefaultCredentialFile, config.DefaultConfigSearchDirectories())
|
|
||||||
if isRunningFromTerminal() {
|
|
||||||
log.Error().Msgf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl)
|
|
||||||
return "", fmt.Errorf("client didn't specify origincert path when running from terminal")
|
|
||||||
} else {
|
|
||||||
log.Error().Msgf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl)
|
|
||||||
return "", fmt.Errorf("client didn't specify origincert path")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
originCertPath, err = homedir.Expand(originCertPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Err(err).Msgf("Cannot resolve origin certificate path")
|
|
||||||
return "", fmt.Errorf("cannot resolve path %s", originCertPath)
|
|
||||||
}
|
|
||||||
// Check that the user has acquired a certificate using the login command
|
|
||||||
ok, err := config.FileExists(originCertPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msgf("Cannot check if origin cert exists at path %s", originCertPath)
|
|
||||||
return "", fmt.Errorf("cannot check if origin cert exists at path %s", originCertPath)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
log.Error().Msgf(`Cannot find a valid certificate for your origin at the path:
|
|
||||||
|
|
||||||
%s
|
|
||||||
|
|
||||||
If the path above is wrong, specify the path with the -origincert option.
|
|
||||||
If you don't have a certificate signed by Cloudflare, run the command:
|
|
||||||
|
|
||||||
%s login
|
|
||||||
`, originCertPath, os.Args[0])
|
|
||||||
return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
return originCertPath, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func readOriginCert(originCertPath string) ([]byte, error) {
|
|
||||||
// Easier to send the certificate as []byte via RPC than decoding it at this point
|
|
||||||
originCert, err := ioutil.ReadFile(originCertPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath)
|
|
||||||
}
|
|
||||||
return originCert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getOriginCert(originCertPath string, log *zerolog.Logger) ([]byte, error) {
|
|
||||||
if originCertPath, err := findOriginCert(originCertPath, log); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else {
|
|
||||||
return readOriginCert(originCertPath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareTunnelConfig(
|
func prepareTunnelConfig(
|
||||||
c *cli.Context,
|
c *cli.Context,
|
||||||
info *cliutil.BuildInfo,
|
info *cliutil.BuildInfo,
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
|
"github.com/cloudflare/cloudflared/credentials"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -56,13 +57,13 @@ func newSearchByID(id uuid.UUID, c *cli.Context, log *zerolog.Logger, fs fileSys
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s searchByID) Path() (string, error) {
|
func (s searchByID) Path() (string, error) {
|
||||||
originCertPath := s.c.String("origincert")
|
originCertPath := s.c.String(credentials.OriginCertFlag)
|
||||||
originCertLog := s.log.With().
|
originCertLog := s.log.With().
|
||||||
Str(LogFieldOriginCertPath, originCertPath).
|
Str("originCertPath", originCertPath).
|
||||||
Logger()
|
Logger()
|
||||||
|
|
||||||
// Fallback to look for tunnel credentials in the origin cert directory
|
// Fallback to look for tunnel credentials in the origin cert directory
|
||||||
if originCertPath, err := findOriginCert(originCertPath, &originCertLog); err == nil {
|
if originCertPath, err := credentials.FindOriginCert(originCertPath, &originCertLog); err == nil {
|
||||||
originCertDir := filepath.Dir(originCertPath)
|
originCertDir := filepath.Dir(originCertPath)
|
||||||
if filePath, err := tunnelFilePath(s.id, originCertDir); err == nil {
|
if filePath, err := tunnelFilePath(s.id, originCertDir); err == nil {
|
||||||
if s.fs.validFilePath(filePath) {
|
if s.fs.validFilePath(filePath) {
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
|
"github.com/cloudflare/cloudflared/credentials"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/token"
|
"github.com/cloudflare/cloudflared/token"
|
||||||
)
|
)
|
||||||
|
@ -85,7 +86,7 @@ func checkForExistingCert() (string, bool, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", false, err
|
return "", false, err
|
||||||
}
|
}
|
||||||
path := filepath.Join(configPath, config.DefaultCredentialFile)
|
path := filepath.Join(configPath, credentials.DefaultCredentialFile)
|
||||||
fileInfo, err := os.Stat(path)
|
fileInfo, err := os.Stat(path)
|
||||||
if err == nil && fileInfo.Size() > 0 {
|
if err == nil && fileInfo.Size() > 0 {
|
||||||
return path, true, nil
|
return path, true, nil
|
||||||
|
|
|
@ -13,9 +13,9 @@ import (
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/certutil"
|
|
||||||
"github.com/cloudflare/cloudflared/cfapi"
|
"github.com/cloudflare/cloudflared/cfapi"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/credentials"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ type subcommandContext struct {
|
||||||
|
|
||||||
// These fields should be accessed using their respective Getter
|
// These fields should be accessed using their respective Getter
|
||||||
tunnelstoreClient cfapi.Client
|
tunnelstoreClient cfapi.Client
|
||||||
userCredential *userCredential
|
userCredential *credentials.User
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
|
func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
|
||||||
|
@ -56,65 +56,28 @@ func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder {
|
||||||
return newSearchByID(tunnelID, sc.c, sc.log, sc.fs)
|
return newSearchByID(tunnelID, sc.c, sc.log, sc.fs)
|
||||||
}
|
}
|
||||||
|
|
||||||
type userCredential struct {
|
|
||||||
cert *certutil.OriginCert
|
|
||||||
certPath string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sc *subcommandContext) client() (cfapi.Client, error) {
|
func (sc *subcommandContext) client() (cfapi.Client, error) {
|
||||||
if sc.tunnelstoreClient != nil {
|
if sc.tunnelstoreClient != nil {
|
||||||
return sc.tunnelstoreClient, nil
|
return sc.tunnelstoreClient, nil
|
||||||
}
|
}
|
||||||
credential, err := sc.credential()
|
cred, err := sc.credential()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
userAgent := fmt.Sprintf("cloudflared/%s", buildInfo.Version())
|
sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log)
|
||||||
client, err := cfapi.NewRESTClient(
|
|
||||||
sc.c.String("api-url"),
|
|
||||||
credential.cert.AccountID,
|
|
||||||
credential.cert.ZoneID,
|
|
||||||
credential.cert.APIToken,
|
|
||||||
userAgent,
|
|
||||||
sc.log,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sc.tunnelstoreClient = client
|
return sc.tunnelstoreClient, nil
|
||||||
return client, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *subcommandContext) credential() (*userCredential, error) {
|
func (sc *subcommandContext) credential() (*credentials.User, error) {
|
||||||
if sc.userCredential == nil {
|
if sc.userCredential == nil {
|
||||||
originCertPath := sc.c.String("origincert")
|
uc, err := credentials.Read(sc.c.String(credentials.OriginCertFlag), sc.log)
|
||||||
originCertLog := sc.log.With().
|
|
||||||
Str(LogFieldOriginCertPath, originCertPath).
|
|
||||||
Logger()
|
|
||||||
|
|
||||||
originCertPath, err := findOriginCert(originCertPath, &originCertLog)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error locating origin cert")
|
return nil, err
|
||||||
}
|
|
||||||
blocks, err := readOriginCert(originCertPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := certutil.DecodeOriginCert(blocks)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "Error decoding origin cert")
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert.AccountID == "" {
|
|
||||||
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.userCredential = &userCredential{
|
|
||||||
cert: cert,
|
|
||||||
certPath: originCertPath,
|
|
||||||
}
|
}
|
||||||
|
sc.userCredential = uc
|
||||||
}
|
}
|
||||||
return sc.userCredential, nil
|
return sc.userCredential, nil
|
||||||
}
|
}
|
||||||
|
@ -175,13 +138,13 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tunnelCredentials := connection.Credentials{
|
tunnelCredentials := connection.Credentials{
|
||||||
AccountTag: credential.cert.AccountID,
|
AccountTag: credential.AccountID(),
|
||||||
TunnelSecret: tunnelSecret,
|
TunnelSecret: tunnelSecret,
|
||||||
TunnelID: tunnel.ID,
|
TunnelID: tunnel.ID,
|
||||||
}
|
}
|
||||||
usedCertPath := false
|
usedCertPath := false
|
||||||
if credentialsFilePath == "" {
|
if credentialsFilePath == "" {
|
||||||
originCertDir := filepath.Dir(credential.certPath)
|
originCertDir := filepath.Dir(credential.CertPath())
|
||||||
credentialsFilePath, err = tunnelFilePath(tunnelCredentials.TunnelID, originCertDir)
|
credentialsFilePath, err = tunnelFilePath(tunnelCredentials.TunnelID, originCertDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cfapi"
|
"github.com/cloudflare/cloudflared/cfapi"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/credentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockFileSystem struct {
|
type mockFileSystem struct {
|
||||||
|
@ -37,7 +38,7 @@ func Test_subcommandContext_findCredentials(t *testing.T) {
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
fs fileSystem
|
fs fileSystem
|
||||||
tunnelstoreClient cfapi.Client
|
tunnelstoreClient cfapi.Client
|
||||||
userCredential *userCredential
|
userCredential *credentials.User
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
tunnelID uuid.UUID
|
tunnelID uuid.UUID
|
||||||
|
@ -249,7 +250,7 @@ func Test_subcommandContext_Delete(t *testing.T) {
|
||||||
isUIEnabled bool
|
isUIEnabled bool
|
||||||
fs fileSystem
|
fs fileSystem
|
||||||
tunnelstoreClient *deleteMockTunnelStore
|
tunnelstoreClient *deleteMockTunnelStore
|
||||||
userCredential *userCredential
|
userCredential *credentials.User
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
tunnelIDs []uuid.UUID
|
tunnelIDs []uuid.UUID
|
||||||
|
|
|
@ -39,8 +39,6 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DefaultCredentialFile = "cert.pem"
|
|
||||||
|
|
||||||
// BastionFlag is to enable bastion, or jump host, operation
|
// BastionFlag is to enable bastion, or jump host, operation
|
||||||
BastionFlag = "bastion"
|
BastionFlag = "bastion"
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/cfapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
logFieldOriginCertPath = "originCertPath"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
cert *OriginCert
|
||||||
|
certPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c User) AccountID() string {
|
||||||
|
return c.cert.AccountID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c User) ZoneID() string {
|
||||||
|
return c.cert.ZoneID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c User) APIToken() string {
|
||||||
|
return c.cert.APIToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c User) CertPath() string {
|
||||||
|
return c.certPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client uses the user credentials to create a Cloudflare API client
|
||||||
|
func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) {
|
||||||
|
if apiURL == "" {
|
||||||
|
return nil, errors.New("An api-url was not provided for the Cloudflare API client")
|
||||||
|
}
|
||||||
|
client, err := cfapi.NewRESTClient(
|
||||||
|
apiURL,
|
||||||
|
c.cert.AccountID,
|
||||||
|
c.cert.ZoneID,
|
||||||
|
c.cert.APIToken,
|
||||||
|
userAgent,
|
||||||
|
log,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read will load and read the origin cert.pem to load the user credentials
|
||||||
|
func Read(originCertPath string, log *zerolog.Logger) (*User, error) {
|
||||||
|
originCertLog := log.With().
|
||||||
|
Str(logFieldOriginCertPath, originCertPath).
|
||||||
|
Logger()
|
||||||
|
|
||||||
|
originCertPath, err := FindOriginCert(originCertPath, &originCertLog)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error locating origin cert")
|
||||||
|
}
|
||||||
|
blocks, err := readOriginCert(originCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := decodeOriginCert(blocks)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error decoding origin cert")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cert.AccountID == "" {
|
||||||
|
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &User{
|
||||||
|
cert: cert,
|
||||||
|
certPath: originCertPath,
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -0,0 +1,38 @@
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCredentialsRead(t *testing.T) {
|
||||||
|
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
|
||||||
|
require.NoError(t, err)
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath := path.Join(dir, originCertFile)
|
||||||
|
os.WriteFile(certPath, file, fs.ModePerm)
|
||||||
|
user, err := Read(certPath, &nopLog)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, certPath, user.CertPath())
|
||||||
|
require.Equal(t, "test-service-key", user.APIToken())
|
||||||
|
require.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", user.ZoneID())
|
||||||
|
require.Equal(t, "abcdabcdabcdabcd1234567890abcdef", user.AccountID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCredentialsClient(t *testing.T) {
|
||||||
|
user := User{
|
||||||
|
certPath: "/tmp/cert.pem",
|
||||||
|
cert: &OriginCert{
|
||||||
|
ZoneID: "7b0a4d77dfb881c1a3b7d61ea9443e19",
|
||||||
|
AccountID: "abcdabcdabcdabcd1234567890abcdef",
|
||||||
|
APIToken: "test-service-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client, err := user.Client("example.com", "cloudflared/test", &nopLog)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, client)
|
||||||
|
}
|
|
@ -0,0 +1,130 @@
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/mitchellh/go-homedir"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultCredentialFile = "cert.pem"
|
||||||
|
OriginCertFlag = "origincert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type namedTunnelToken struct {
|
||||||
|
ZoneID string `json:"zoneID"`
|
||||||
|
AccountID string `json:"accountID"`
|
||||||
|
APIToken string `json:"apiToken"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OriginCert struct {
|
||||||
|
ZoneID string
|
||||||
|
APIToken string
|
||||||
|
AccountID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the
|
||||||
|
// DefaultConfigSearchDirectories contains a cert.pem file, return empty string
|
||||||
|
func FindDefaultOriginCertPath() string {
|
||||||
|
for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() {
|
||||||
|
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, DefaultCredentialFile))
|
||||||
|
if ok := fileExists(originCertPath); ok {
|
||||||
|
return originCertPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeOriginCert(blocks []byte) (*OriginCert, error) {
|
||||||
|
if len(blocks) == 0 {
|
||||||
|
return nil, fmt.Errorf("Cannot decode empty certificate")
|
||||||
|
}
|
||||||
|
originCert := OriginCert{}
|
||||||
|
block, rest := pem.Decode(blocks)
|
||||||
|
for {
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
switch block.Type {
|
||||||
|
case "PRIVATE KEY", "CERTIFICATE":
|
||||||
|
// this is for legacy purposes.
|
||||||
|
break
|
||||||
|
case "ARGO TUNNEL TOKEN":
|
||||||
|
if originCert.ZoneID != "" || originCert.APIToken != "" {
|
||||||
|
return nil, fmt.Errorf("Found multiple tokens in the certificate")
|
||||||
|
}
|
||||||
|
// The token is a string,
|
||||||
|
// Try the newer JSON format
|
||||||
|
ntt := namedTunnelToken{}
|
||||||
|
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
|
||||||
|
originCert.ZoneID = ntt.ZoneID
|
||||||
|
originCert.APIToken = ntt.APIToken
|
||||||
|
originCert.AccountID = ntt.AccountID
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
|
||||||
|
}
|
||||||
|
block, rest = pem.Decode(rest)
|
||||||
|
}
|
||||||
|
|
||||||
|
if originCert.ZoneID == "" || originCert.APIToken == "" {
|
||||||
|
return nil, fmt.Errorf("Missing token in the certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &originCert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readOriginCert(originCertPath string) ([]byte, error) {
|
||||||
|
originCert, err := os.ReadFile(originCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return originCert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindOriginCert will check to make sure that the certificate exists at the specified file path.
|
||||||
|
func FindOriginCert(originCertPath string, log *zerolog.Logger) (string, error) {
|
||||||
|
if originCertPath == "" {
|
||||||
|
log.Error().Msgf("Cannot determine default origin certificate path. No file %s in %v. You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable", DefaultCredentialFile, config.DefaultConfigSearchDirectories())
|
||||||
|
return "", fmt.Errorf("client didn't specify origincert path")
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
originCertPath, err = homedir.Expand(originCertPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).Msgf("Cannot resolve origin certificate path")
|
||||||
|
return "", fmt.Errorf("cannot resolve path %s", originCertPath)
|
||||||
|
}
|
||||||
|
// Check that the user has acquired a certificate using the login command
|
||||||
|
ok := fileExists(originCertPath)
|
||||||
|
if !ok {
|
||||||
|
log.Error().Msgf(`Cannot find a valid certificate for your origin at the path:
|
||||||
|
|
||||||
|
%s
|
||||||
|
|
||||||
|
If the path above is wrong, specify the path with the -origincert option.
|
||||||
|
If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
|
|
||||||
|
cloudflared login
|
||||||
|
`, originCertPath)
|
||||||
|
return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return originCertPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileExists checks to see if a file exist at the provided path.
|
||||||
|
func fileExists(path string) bool {
|
||||||
|
fileStat, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !fileStat.IsDir()
|
||||||
|
}
|
|
@ -0,0 +1,110 @@
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
originCertFile = "cert.pem"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
nopLog = zerolog.Nop().With().Logger()
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadOriginCert(t *testing.T) {
|
||||||
|
cert, err := decodeOriginCert([]byte{})
|
||||||
|
assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
|
||||||
|
blocks, err := os.ReadFile("test-cert-unknown-block.pem")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
cert, err = decodeOriginCert(blocks)
|
||||||
|
assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
|
||||||
|
blocks, err := os.ReadFile("test-cert-no-token.pem")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
cert, err := decodeOriginCert(blocks)
|
||||||
|
assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONArgoTunnelToken(t *testing.T) {
|
||||||
|
// The given cert's Argo Tunnel Token was generated by base64 encoding this JSON:
|
||||||
|
// {
|
||||||
|
// "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19",
|
||||||
|
// "apiToken": "test-service-key",
|
||||||
|
// "accountID": "abcdabcdabcdabcd1234567890abcdef"
|
||||||
|
// }
|
||||||
|
CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem")
|
||||||
|
}
|
||||||
|
|
||||||
|
func CloudflareTunnelTokenTest(t *testing.T, path string) {
|
||||||
|
blocks, err := os.ReadFile(path)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
cert, err := decodeOriginCert(blocks)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, cert)
|
||||||
|
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
|
||||||
|
key := "test-service-key"
|
||||||
|
assert.Equal(t, key, cert.APIToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockFile struct {
|
||||||
|
path string
|
||||||
|
data []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockFileSystem struct {
|
||||||
|
files map[string]mockFile
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockFileSystem(files ...mockFile) *mockFileSystem {
|
||||||
|
fs := mockFileSystem{map[string]mockFile{}}
|
||||||
|
for _, f := range files {
|
||||||
|
fs.files[f.path] = f
|
||||||
|
}
|
||||||
|
return &fs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) {
|
||||||
|
if f, ok := fs.files[path]; ok {
|
||||||
|
return f.data, f.err
|
||||||
|
}
|
||||||
|
return nil, os.ErrNotExist
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *mockFileSystem) ValidFilePath(path string) bool {
|
||||||
|
_, exists := fs.files[path]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindOriginCert_Valid(t *testing.T) {
|
||||||
|
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
|
||||||
|
require.NoError(t, err)
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath := path.Join(dir, originCertFile)
|
||||||
|
os.WriteFile(certPath, file, fs.ModePerm)
|
||||||
|
path, err := FindOriginCert(certPath, &nopLog)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, certPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindOriginCert_Missing(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath := path.Join(dir, originCertFile)
|
||||||
|
_, err := FindOriginCert(certPath, &nopLog)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
Loading…
Reference in New Issue