diff --git a/certutil/certutil.go b/certutil/certutil.go deleted file mode 100644 index 951926bb..00000000 --- a/certutil/certutil.go +++ /dev/null @@ -1,58 +0,0 @@ -package certutil - -import ( - "encoding/json" - "encoding/pem" - "fmt" -) - -type namedTunnelToken struct { - ZoneID string `json:"zoneID"` - AccountID string `json:"accountID"` - APIToken string `json:"apiToken"` -} - -type OriginCert struct { - ZoneID string - APIToken string - AccountID string -} - -func DecodeOriginCert(blocks []byte) (*OriginCert, error) { - if len(blocks) == 0 { - return nil, fmt.Errorf("Cannot decode empty certificate") - } - originCert := OriginCert{} - block, rest := pem.Decode(blocks) - for { - if block == nil { - break - } - switch block.Type { - case "PRIVATE KEY", "CERTIFICATE": - // this is for legacy purposes. - break - case "ARGO TUNNEL TOKEN": - if originCert.ZoneID != "" || originCert.APIToken != "" { - return nil, fmt.Errorf("Found multiple tokens in the certificate") - } - // The token is a string, - // Try the newer JSON format - ntt := namedTunnelToken{} - if err := json.Unmarshal(block.Bytes, &ntt); err == nil { - originCert.ZoneID = ntt.ZoneID - originCert.APIToken = ntt.APIToken - originCert.AccountID = ntt.AccountID - } - default: - return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type) - } - block, rest = pem.Decode(rest) - } - - if originCert.ZoneID == "" || originCert.APIToken == "" { - return nil, fmt.Errorf("Missing token in the certificate") - } - - return &originCert, nil -} diff --git a/certutil/certutil_test.go b/certutil/certutil_test.go deleted file mode 100644 index e48ffcf3..00000000 --- a/certutil/certutil_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package certutil - -import ( - "fmt" - "io/ioutil" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestLoadOriginCert(t *testing.T) { - cert, err := DecodeOriginCert([]byte{}) - assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err) - assert.Nil(t, cert) - - blocks, err := ioutil.ReadFile("test-cert-unknown-block.pem") - assert.Nil(t, err) - cert, err = DecodeOriginCert(blocks) - assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err) - assert.Nil(t, cert) -} - -func TestJSONArgoTunnelTokenEmpty(t *testing.T) { - cert, err := DecodeOriginCert([]byte{}) - blocks, err := ioutil.ReadFile("test-cert-no-token.pem") - assert.Nil(t, err) - cert, err = DecodeOriginCert(blocks) - assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err) - assert.Nil(t, cert) -} - -func TestJSONArgoTunnelToken(t *testing.T) { - // The given cert's Argo Tunnel Token was generated by base64 encoding this JSON: - // { - // "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19", - // "apiToken": "test-service-key", - // "accountID": "abcdabcdabcdabcd1234567890abcdef" - // } - CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem") -} - -func CloudflareTunnelTokenTest(t *testing.T, path string) { - blocks, err := ioutil.ReadFile(path) - assert.Nil(t, err) - cert, err := DecodeOriginCert(blocks) - assert.Nil(t, err) - assert.NotNil(t, cert) - assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID) - key := "test-service-key" - assert.Equal(t, key, cert.APIToken) -} diff --git a/cfapi/client.go b/cfapi/client.go index 192d64ce..f8c2a734 100644 --- a/cfapi/client.go +++ b/cfapi/client.go @@ -8,6 +8,7 @@ type TunnelClient interface { CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) GetTunnelToken(tunnelID uuid.UUID) (string, error) + GetManagementToken(tunnelID uuid.UUID) (string, error) DeleteTunnel(tunnelID uuid.UUID) error ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) diff --git a/cfapi/tunnel.go b/cfapi/tunnel.go index 87caf230..fa6f8f33 100644 --- a/cfapi/tunnel.go +++ b/cfapi/tunnel.go @@ -50,6 +50,10 @@ type newTunnel struct { TunnelSecret []byte `json:"tunnel_secret"` } +type managementRequest struct { + Resources []string `json:"resources"` +} + type CleanupParams struct { queryParams url.Values } @@ -133,6 +137,28 @@ func (r *RESTClient) GetTunnelToken(tunnelID uuid.UUID) (token string, err error return "", r.statusCodeToError("get tunnel token", resp) } +func (r *RESTClient) GetManagementToken(tunnelID uuid.UUID) (token string, err error) { + endpoint := r.baseEndpoints.accountLevel + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/management", tunnelID)) + + body := &managementRequest{ + Resources: []string{"logs"}, + } + + resp, err := r.sendRequest("POST", endpoint, body) + if err != nil { + return "", errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + err = parseResponse(resp.Body, &token) + return token, err + } + + return "", r.statusCodeToError("get tunnel token", resp) +} + func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error { endpoint := r.baseEndpoints.accountLevel endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID)) diff --git a/cmd/cloudflared/cliutil/build_info.go b/cmd/cloudflared/cliutil/build_info.go index 4d73701e..fff4febf 100644 --- a/cmd/cloudflared/cliutil/build_info.go +++ b/cmd/cloudflared/cliutil/build_info.go @@ -47,3 +47,7 @@ func (bi *BuildInfo) GetBuildTypeMsg() string { } return fmt.Sprintf(" with %s", bi.BuildType) } + +func (bi *BuildInfo) UserAgent() string { + return fmt.Sprintf("cloudflared/%s", bi.CloudflaredVersion) +} diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index f729d55a..3ed25a62 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -90,7 +90,7 @@ func main() { updater.Init(Version) tracing.Init(Version) token.Init(Version) - tail.Init(Version) + tail.Init(bInfo) runApp(app, graceShutdownC) } diff --git a/cmd/cloudflared/tail/cmd.go b/cmd/cloudflared/tail/cmd.go index 55864e90..24d7fb65 100644 --- a/cmd/cloudflared/tail/cmd.go +++ b/cmd/cloudflared/tail/cmd.go @@ -2,6 +2,7 @@ package tail import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -10,28 +11,32 @@ import ( "syscall" "time" + "github.com/google/uuid" "github.com/mattn/go-colorable" "github.com/rs/zerolog" "github.com/urfave/cli/v2" "nhooyr.io/websocket" + "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" + "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/management" ) var ( - version string + buildInfo *cliutil.BuildInfo ) -func Init(v string) { - version = v +func Init(bi *cliutil.BuildInfo) { + buildInfo = bi } func Command() *cli.Command { return &cli.Command{ - Name: "tail", - Action: Run, - Usage: "Stream logs from a remote cloudflared", + Name: "tail", + Action: Run, + Usage: "Stream logs from a remote cloudflared", + UsageText: "cloudflared tail [tail command options] [TUNNEL-ID]", Flags: []cli.Flag{ &cli.StringFlag{ Name: "connector-id", @@ -75,6 +80,12 @@ func Command() *cli.Command { Usage: "Application logging level {debug, info, warn, error, fatal}", EnvVars: []string{"TUNNEL_LOGLEVEL"}, }, + &cli.StringFlag{ + Name: credentials.OriginCertFlag, + Usage: "Path to the certificate generated for your origin when you run cloudflared login.", + EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, + Value: credentials.FindDefaultOriginCertPath(), + }, }, } } @@ -159,6 +170,59 @@ func parseFilters(c *cli.Context) (*management.StreamingFilters, error) { }, nil } +// getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel. +func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) { + userCreds, err := credentials.Read(c.String(credentials.OriginCertFlag), log) + if err != nil { + return "", err + } + + client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log) + if err != nil { + return "", err + } + + tunnelIDString := c.Args().First() + if tunnelIDString == "" { + return "", errors.New("no tunnel ID provided") + } + tunnelID, err := uuid.Parse(tunnelIDString) + if err != nil { + return "", errors.New("unable to parse provided tunnel id as a valid UUID") + } + + token, err := client.GetManagementToken(tunnelID) + if err != nil { + return "", err + } + + return token, nil +} + +// buildURL will build the management url to contain the required query parameters to authenticate the request. +func buildURL(c *cli.Context, log *zerolog.Logger) (url.URL, error) { + var err error + managementHostname := c.String("management-hostname") + token := c.String("token") + if token == "" { + token, err = getManagementToken(c, log) + if err != nil { + return url.URL{}, fmt.Errorf("unable to acquire management token for requested tunnel id: %w", err) + } + } + query := url.Values{} + query.Add("access_token", token) + connector := c.String("connector-id") + if connector != "" { + connectorID, err := uuid.Parse(connector) + if err != nil { + return url.URL{}, fmt.Errorf("unabled to parse 'connector-id' flag into a valid UUID: %w", err) + } + query.Add("connector_id", connectorID.String()) + } + return url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: query.Encode()}, nil +} + // Run implements a foreground runner func Run(c *cli.Context) error { log := createLogger(c) @@ -173,12 +237,14 @@ func Run(c *cli.Context) error { return nil } - managementHostname := c.String("management-hostname") - token := c.String("token") - u := url.URL{Scheme: "wss", Host: managementHostname, Path: "/logs", RawQuery: "access_token=" + token} + u, err := buildURL(c, log) + if err != nil { + log.Err(err).Msg("unable to construct management request URL") + return nil + } header := make(http.Header) - header.Add("User-Agent", "cloudflared/"+version) + header.Add("User-Agent", buildInfo.UserAgent()) trace := c.String("trace") if trace != "" { header["cf-trace-id"] = []string{trace} @@ -206,6 +272,11 @@ func Run(c *cli.Context) error { log.Error().Err(err).Msg("unable to request logs from management tunnel") return nil } + log.Debug(). + Str("tunnel-id", c.Args().First()). + Str("connector-id", c.String("connector-id")). + Interface("filters", filters). + Msg("connected") readerDone := make(chan struct{}) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 17c177f2..0eb5b328 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -28,6 +28,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" @@ -751,10 +752,10 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag { Hidden: shouldHide, }, altsrc.NewStringFlag(&cli.StringFlag{ - Name: "origincert", + Name: credentials.OriginCertFlag, Usage: "Path to the certificate generated for your origin when you run cloudflared login.", EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, - Value: findDefaultOriginCertPath(), + Value: credentials.FindDefaultOriginCertPath(), Hidden: shouldHide, }), altsrc.NewDurationFlag(&cli.DurationFlag{ diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index adcc82ef..89b76392 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -3,17 +3,14 @@ package tunnel import ( "crypto/tls" "fmt" - "io/ioutil" mathRand "math/rand" "net" "net/netip" "os" - "path/filepath" "strings" "time" "github.com/google/uuid" - homedir "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/urfave/cli/v2" @@ -33,7 +30,6 @@ import ( tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) -const LogFieldOriginCertPath = "originCertPath" const secretValue = "*****" var ( @@ -46,18 +42,6 @@ var ( configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"} ) -// returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories -// contains a cert.pem file, return empty string -func findDefaultOriginCertPath() string { - for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() { - originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, config.DefaultCredentialFile)) - if ok, _ := config.FileExists(originCertPath); ok { - return originCertPath - } - } - return "" -} - func generateRandomClientID(log *zerolog.Logger) (string, error) { u, err := uuid.NewRandom() if err != nil { @@ -128,62 +112,6 @@ func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelPrope namedTunnel != nil) // named tunnel } -func findOriginCert(originCertPath string, log *zerolog.Logger) (string, error) { - if originCertPath == "" { - log.Info().Msgf("Cannot determine default origin certificate path. No file %s in %v", config.DefaultCredentialFile, config.DefaultConfigSearchDirectories()) - if isRunningFromTerminal() { - log.Error().Msgf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl) - return "", fmt.Errorf("client didn't specify origincert path when running from terminal") - } else { - log.Error().Msgf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl) - return "", fmt.Errorf("client didn't specify origincert path") - } - } - var err error - originCertPath, err = homedir.Expand(originCertPath) - if err != nil { - log.Err(err).Msgf("Cannot resolve origin certificate path") - return "", fmt.Errorf("cannot resolve path %s", originCertPath) - } - // Check that the user has acquired a certificate using the login command - ok, err := config.FileExists(originCertPath) - if err != nil { - log.Error().Err(err).Msgf("Cannot check if origin cert exists at path %s", originCertPath) - return "", fmt.Errorf("cannot check if origin cert exists at path %s", originCertPath) - } - if !ok { - log.Error().Msgf(`Cannot find a valid certificate for your origin at the path: - - %s - -If the path above is wrong, specify the path with the -origincert option. -If you don't have a certificate signed by Cloudflare, run the command: - - %s login -`, originCertPath, os.Args[0]) - return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath) - } - - return originCertPath, nil -} - -func readOriginCert(originCertPath string) ([]byte, error) { - // Easier to send the certificate as []byte via RPC than decoding it at this point - originCert, err := ioutil.ReadFile(originCertPath) - if err != nil { - return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath) - } - return originCert, nil -} - -func getOriginCert(originCertPath string, log *zerolog.Logger) ([]byte, error) { - if originCertPath, err := findOriginCert(originCertPath, log); err != nil { - return nil, err - } else { - return readOriginCert(originCertPath) - } -} - func prepareTunnelConfig( c *cli.Context, info *cliutil.BuildInfo, diff --git a/cmd/cloudflared/tunnel/credential_finder.go b/cmd/cloudflared/tunnel/credential_finder.go index a2320af4..92e05495 100644 --- a/cmd/cloudflared/tunnel/credential_finder.go +++ b/cmd/cloudflared/tunnel/credential_finder.go @@ -5,6 +5,7 @@ import ( "path/filepath" "github.com/cloudflare/cloudflared/config" + "github.com/cloudflare/cloudflared/credentials" "github.com/google/uuid" "github.com/rs/zerolog" @@ -56,13 +57,13 @@ func newSearchByID(id uuid.UUID, c *cli.Context, log *zerolog.Logger, fs fileSys } func (s searchByID) Path() (string, error) { - originCertPath := s.c.String("origincert") + originCertPath := s.c.String(credentials.OriginCertFlag) originCertLog := s.log.With(). - Str(LogFieldOriginCertPath, originCertPath). + Str("originCertPath", originCertPath). Logger() // Fallback to look for tunnel credentials in the origin cert directory - if originCertPath, err := findOriginCert(originCertPath, &originCertLog); err == nil { + if originCertPath, err := credentials.FindOriginCert(originCertPath, &originCertLog); err == nil { originCertDir := filepath.Dir(originCertPath) if filePath, err := tunnelFilePath(s.id, originCertDir); err == nil { if s.fs.validFilePath(filePath) { diff --git a/cmd/cloudflared/tunnel/login.go b/cmd/cloudflared/tunnel/login.go index 8b519147..dd0f8fe9 100644 --- a/cmd/cloudflared/tunnel/login.go +++ b/cmd/cloudflared/tunnel/login.go @@ -14,6 +14,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/config" + "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/token" ) @@ -85,7 +86,7 @@ func checkForExistingCert() (string, bool, error) { if err != nil { return "", false, err } - path := filepath.Join(configPath, config.DefaultCredentialFile) + path := filepath.Join(configPath, credentials.DefaultCredentialFile) fileInfo, err := os.Stat(path) if err == nil && fileInfo.Size() > 0 { return path, true, nil diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index bc65aced..f49c15eb 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -13,9 +13,9 @@ import ( "github.com/rs/zerolog" "github.com/urfave/cli/v2" - "github.com/cloudflare/cloudflared/certutil" "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/logger" ) @@ -37,7 +37,7 @@ type subcommandContext struct { // These fields should be accessed using their respective Getter tunnelstoreClient cfapi.Client - userCredential *userCredential + userCredential *credentials.User } func newSubcommandContext(c *cli.Context) (*subcommandContext, error) { @@ -56,65 +56,28 @@ func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder { return newSearchByID(tunnelID, sc.c, sc.log, sc.fs) } -type userCredential struct { - cert *certutil.OriginCert - certPath string -} - func (sc *subcommandContext) client() (cfapi.Client, error) { if sc.tunnelstoreClient != nil { return sc.tunnelstoreClient, nil } - credential, err := sc.credential() + cred, err := sc.credential() if err != nil { return nil, err } - userAgent := fmt.Sprintf("cloudflared/%s", buildInfo.Version()) - client, err := cfapi.NewRESTClient( - sc.c.String("api-url"), - credential.cert.AccountID, - credential.cert.ZoneID, - credential.cert.APIToken, - userAgent, - sc.log, - ) - + sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log) if err != nil { return nil, err } - sc.tunnelstoreClient = client - return client, nil + return sc.tunnelstoreClient, nil } -func (sc *subcommandContext) credential() (*userCredential, error) { +func (sc *subcommandContext) credential() (*credentials.User, error) { if sc.userCredential == nil { - originCertPath := sc.c.String("origincert") - originCertLog := sc.log.With(). - Str(LogFieldOriginCertPath, originCertPath). - Logger() - - originCertPath, err := findOriginCert(originCertPath, &originCertLog) + uc, err := credentials.Read(sc.c.String(credentials.OriginCertFlag), sc.log) if err != nil { - return nil, errors.Wrap(err, "Error locating origin cert") - } - blocks, err := readOriginCert(originCertPath) - if err != nil { - return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) - } - - cert, err := certutil.DecodeOriginCert(blocks) - if err != nil { - return nil, errors.Wrap(err, "Error decoding origin cert") - } - - if cert.AccountID == "" { - return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) - } - - sc.userCredential = &userCredential{ - cert: cert, - certPath: originCertPath, + return nil, err } + sc.userCredential = uc } return sc.userCredential, nil } @@ -175,13 +138,13 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec return nil, err } tunnelCredentials := connection.Credentials{ - AccountTag: credential.cert.AccountID, + AccountTag: credential.AccountID(), TunnelSecret: tunnelSecret, TunnelID: tunnel.ID, } usedCertPath := false if credentialsFilePath == "" { - originCertDir := filepath.Dir(credential.certPath) + originCertDir := filepath.Dir(credential.CertPath()) credentialsFilePath, err = tunnelFilePath(tunnelCredentials.TunnelID, originCertDir) if err != nil { return nil, err diff --git a/cmd/cloudflared/tunnel/subcommand_context_test.go b/cmd/cloudflared/tunnel/subcommand_context_test.go index 35cc46e7..c2293463 100644 --- a/cmd/cloudflared/tunnel/subcommand_context_test.go +++ b/cmd/cloudflared/tunnel/subcommand_context_test.go @@ -16,6 +16,7 @@ import ( "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/credentials" ) type mockFileSystem struct { @@ -37,7 +38,7 @@ func Test_subcommandContext_findCredentials(t *testing.T) { log *zerolog.Logger fs fileSystem tunnelstoreClient cfapi.Client - userCredential *userCredential + userCredential *credentials.User } type args struct { tunnelID uuid.UUID @@ -249,7 +250,7 @@ func Test_subcommandContext_Delete(t *testing.T) { isUIEnabled bool fs fileSystem tunnelstoreClient *deleteMockTunnelStore - userCredential *userCredential + userCredential *credentials.User } type args struct { tunnelIDs []uuid.UUID diff --git a/config/configuration.go b/config/configuration.go index 70bf163a..73d45fbc 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -39,8 +39,6 @@ var ( ) const ( - DefaultCredentialFile = "cert.pem" - // BastionFlag is to enable bastion, or jump host, operation BastionFlag = "bastion" ) diff --git a/credentials/credentials.go b/credentials/credentials.go new file mode 100644 index 00000000..8d1d8908 --- /dev/null +++ b/credentials/credentials.go @@ -0,0 +1,83 @@ +package credentials + +import ( + "github.com/pkg/errors" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/cfapi" +) + +const ( + logFieldOriginCertPath = "originCertPath" +) + +type User struct { + cert *OriginCert + certPath string +} + +func (c User) AccountID() string { + return c.cert.AccountID +} + +func (c User) ZoneID() string { + return c.cert.ZoneID +} + +func (c User) APIToken() string { + return c.cert.APIToken +} + +func (c User) CertPath() string { + return c.certPath +} + +// Client uses the user credentials to create a Cloudflare API client +func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) { + if apiURL == "" { + return nil, errors.New("An api-url was not provided for the Cloudflare API client") + } + client, err := cfapi.NewRESTClient( + apiURL, + c.cert.AccountID, + c.cert.ZoneID, + c.cert.APIToken, + userAgent, + log, + ) + + if err != nil { + return nil, err + } + return client, nil +} + +// Read will load and read the origin cert.pem to load the user credentials +func Read(originCertPath string, log *zerolog.Logger) (*User, error) { + originCertLog := log.With(). + Str(logFieldOriginCertPath, originCertPath). + Logger() + + originCertPath, err := FindOriginCert(originCertPath, &originCertLog) + if err != nil { + return nil, errors.Wrap(err, "Error locating origin cert") + } + blocks, err := readOriginCert(originCertPath) + if err != nil { + return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) + } + + cert, err := decodeOriginCert(blocks) + if err != nil { + return nil, errors.Wrap(err, "Error decoding origin cert") + } + + if cert.AccountID == "" { + return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) + } + + return &User{ + cert: cert, + certPath: originCertPath, + }, nil +} diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go new file mode 100644 index 00000000..d9b2d7b7 --- /dev/null +++ b/credentials/credentials_test.go @@ -0,0 +1,38 @@ +package credentials + +import ( + "io/fs" + "os" + "path" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCredentialsRead(t *testing.T) { + file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem") + require.NoError(t, err) + dir := t.TempDir() + certPath := path.Join(dir, originCertFile) + os.WriteFile(certPath, file, fs.ModePerm) + user, err := Read(certPath, &nopLog) + require.NoError(t, err) + require.Equal(t, certPath, user.CertPath()) + require.Equal(t, "test-service-key", user.APIToken()) + require.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", user.ZoneID()) + require.Equal(t, "abcdabcdabcdabcd1234567890abcdef", user.AccountID()) +} + +func TestCredentialsClient(t *testing.T) { + user := User{ + certPath: "/tmp/cert.pem", + cert: &OriginCert{ + ZoneID: "7b0a4d77dfb881c1a3b7d61ea9443e19", + AccountID: "abcdabcdabcdabcd1234567890abcdef", + APIToken: "test-service-key", + }, + } + client, err := user.Client("example.com", "cloudflared/test", &nopLog) + require.NoError(t, err) + require.NotNil(t, client) +} diff --git a/credentials/origin_cert.go b/credentials/origin_cert.go new file mode 100644 index 00000000..73a59fa3 --- /dev/null +++ b/credentials/origin_cert.go @@ -0,0 +1,130 @@ +package credentials + +import ( + "encoding/json" + "encoding/pem" + "fmt" + "os" + "path/filepath" + + "github.com/mitchellh/go-homedir" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/config" +) + +const ( + DefaultCredentialFile = "cert.pem" + OriginCertFlag = "origincert" +) + +type namedTunnelToken struct { + ZoneID string `json:"zoneID"` + AccountID string `json:"accountID"` + APIToken string `json:"apiToken"` +} + +type OriginCert struct { + ZoneID string + APIToken string + AccountID string +} + +// FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the +// DefaultConfigSearchDirectories contains a cert.pem file, return empty string +func FindDefaultOriginCertPath() string { + for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() { + originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, DefaultCredentialFile)) + if ok := fileExists(originCertPath); ok { + return originCertPath + } + } + return "" +} + +func decodeOriginCert(blocks []byte) (*OriginCert, error) { + if len(blocks) == 0 { + return nil, fmt.Errorf("Cannot decode empty certificate") + } + originCert := OriginCert{} + block, rest := pem.Decode(blocks) + for { + if block == nil { + break + } + switch block.Type { + case "PRIVATE KEY", "CERTIFICATE": + // this is for legacy purposes. + break + case "ARGO TUNNEL TOKEN": + if originCert.ZoneID != "" || originCert.APIToken != "" { + return nil, fmt.Errorf("Found multiple tokens in the certificate") + } + // The token is a string, + // Try the newer JSON format + ntt := namedTunnelToken{} + if err := json.Unmarshal(block.Bytes, &ntt); err == nil { + originCert.ZoneID = ntt.ZoneID + originCert.APIToken = ntt.APIToken + originCert.AccountID = ntt.AccountID + } + default: + return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type) + } + block, rest = pem.Decode(rest) + } + + if originCert.ZoneID == "" || originCert.APIToken == "" { + return nil, fmt.Errorf("Missing token in the certificate") + } + + return &originCert, nil +} + +func readOriginCert(originCertPath string) ([]byte, error) { + originCert, err := os.ReadFile(originCertPath) + if err != nil { + return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath) + } + + return originCert, nil +} + +// FindOriginCert will check to make sure that the certificate exists at the specified file path. +func FindOriginCert(originCertPath string, log *zerolog.Logger) (string, error) { + if originCertPath == "" { + log.Error().Msgf("Cannot determine default origin certificate path. No file %s in %v. You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable", DefaultCredentialFile, config.DefaultConfigSearchDirectories()) + return "", fmt.Errorf("client didn't specify origincert path") + } + var err error + originCertPath, err = homedir.Expand(originCertPath) + if err != nil { + log.Err(err).Msgf("Cannot resolve origin certificate path") + return "", fmt.Errorf("cannot resolve path %s", originCertPath) + } + // Check that the user has acquired a certificate using the login command + ok := fileExists(originCertPath) + if !ok { + log.Error().Msgf(`Cannot find a valid certificate for your origin at the path: + + %s + +If the path above is wrong, specify the path with the -origincert option. +If you don't have a certificate signed by Cloudflare, run the command: + + cloudflared login +`, originCertPath) + return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath) + } + + return originCertPath, nil +} + +// FileExists checks to see if a file exist at the provided path. +func fileExists(path string) bool { + fileStat, err := os.Stat(path) + if err != nil { + return false + } + return !fileStat.IsDir() +} diff --git a/credentials/origin_cert_test.go b/credentials/origin_cert_test.go new file mode 100644 index 00000000..77a473e4 --- /dev/null +++ b/credentials/origin_cert_test.go @@ -0,0 +1,110 @@ +package credentials + +import ( + "fmt" + "io/fs" + "os" + "path" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + originCertFile = "cert.pem" +) + +var ( + nopLog = zerolog.Nop().With().Logger() +) + +func TestLoadOriginCert(t *testing.T) { + cert, err := decodeOriginCert([]byte{}) + assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err) + assert.Nil(t, cert) + + blocks, err := os.ReadFile("test-cert-unknown-block.pem") + assert.NoError(t, err) + cert, err = decodeOriginCert(blocks) + assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err) + assert.Nil(t, cert) +} + +func TestJSONArgoTunnelTokenEmpty(t *testing.T) { + blocks, err := os.ReadFile("test-cert-no-token.pem") + assert.NoError(t, err) + cert, err := decodeOriginCert(blocks) + assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err) + assert.Nil(t, cert) +} + +func TestJSONArgoTunnelToken(t *testing.T) { + // The given cert's Argo Tunnel Token was generated by base64 encoding this JSON: + // { + // "zoneID": "7b0a4d77dfb881c1a3b7d61ea9443e19", + // "apiToken": "test-service-key", + // "accountID": "abcdabcdabcdabcd1234567890abcdef" + // } + CloudflareTunnelTokenTest(t, "test-cloudflare-tunnel-cert-json.pem") +} + +func CloudflareTunnelTokenTest(t *testing.T, path string) { + blocks, err := os.ReadFile(path) + assert.NoError(t, err) + cert, err := decodeOriginCert(blocks) + assert.NoError(t, err) + assert.NotNil(t, cert) + assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID) + key := "test-service-key" + assert.Equal(t, key, cert.APIToken) +} + +type mockFile struct { + path string + data []byte + err error +} + +type mockFileSystem struct { + files map[string]mockFile +} + +func newMockFileSystem(files ...mockFile) *mockFileSystem { + fs := mockFileSystem{map[string]mockFile{}} + for _, f := range files { + fs.files[f.path] = f + } + return &fs +} + +func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) { + if f, ok := fs.files[path]; ok { + return f.data, f.err + } + return nil, os.ErrNotExist +} + +func (fs *mockFileSystem) ValidFilePath(path string) bool { + _, exists := fs.files[path] + return exists +} + +func TestFindOriginCert_Valid(t *testing.T) { + file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem") + require.NoError(t, err) + dir := t.TempDir() + certPath := path.Join(dir, originCertFile) + os.WriteFile(certPath, file, fs.ModePerm) + path, err := FindOriginCert(certPath, &nopLog) + require.NoError(t, err) + require.Equal(t, certPath, path) +} + +func TestFindOriginCert_Missing(t *testing.T) { + dir := t.TempDir() + certPath := path.Join(dir, originCertFile) + _, err := FindOriginCert(certPath, &nopLog) + require.Error(t, err) +} diff --git a/certutil/test-cert-no-token.pem b/credentials/test-cert-no-token.pem similarity index 100% rename from certutil/test-cert-no-token.pem rename to credentials/test-cert-no-token.pem diff --git a/certutil/test-cert-unknown-block.pem b/credentials/test-cert-unknown-block.pem similarity index 100% rename from certutil/test-cert-unknown-block.pem rename to credentials/test-cert-unknown-block.pem diff --git a/certutil/test-cloudflare-tunnel-cert-json.pem b/credentials/test-cloudflare-tunnel-cert-json.pem similarity index 100% rename from certutil/test-cloudflare-tunnel-cert-json.pem rename to credentials/test-cloudflare-tunnel-cert-json.pem