diff --git a/cmd/cloudflared/flags/flags.go b/cmd/cloudflared/flags/flags.go index 5350d22f..7c919f05 100644 --- a/cmd/cloudflared/flags/flags.go +++ b/cmd/cloudflared/flags/flags.go @@ -149,4 +149,7 @@ const ( // MetricsUpdateFreq is the command line flag to define how frequently tunnel metrics are updated MetricsUpdateFreq = "metrics-update-freq" + + // ApiURL is the command line flag used to define the base URL of the API + ApiURL = "api-url" ) diff --git a/cmd/cloudflared/tail/cmd.go b/cmd/cloudflared/tail/cmd.go index 4715a191..d7f5a429 100644 --- a/cmd/cloudflared/tail/cmd.go +++ b/cmd/cloudflared/tail/cmd.go @@ -23,9 +23,7 @@ import ( "github.com/cloudflare/cloudflared/management" ) -var ( - buildInfo *cliutil.BuildInfo -) +var buildInfo *cliutil.BuildInfo func Init(bi *cliutil.BuildInfo) { buildInfo = bi @@ -56,7 +54,7 @@ func managementTokenCommand(c *cli.Context) error { if err != nil { return err } - var tokenResponse = struct { + tokenResponse := struct { Token string `json:"token"` }{Token: token} @@ -231,7 +229,7 @@ func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) { return "", err } - client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log) + client, err := userCreds.Client(c.String(cfdflags.ApiURL), buildInfo.UserAgent(), log) if err != nil { return "", err } diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 19afccbf..535c8bea 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -131,7 +131,7 @@ var ( "hostname", "id", cfdflags.LBPool, - "api-url", + cfdflags.ApiURL, cfdflags.MetricsUpdateFreq, cfdflags.Tag, "heartbeat-interval", @@ -716,7 +716,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag { Hidden: true, }), altsrc.NewStringFlag(&cli.StringFlag{ - Name: "api-url", + Name: cfdflags.ApiURL, Usage: "Base URL for Cloudflare API v4", EnvVars: []string{"TUNNEL_API_URL"}, Value: "https://api.cloudflare.com/client/v4", diff --git a/cmd/cloudflared/tunnel/login.go b/cmd/cloudflared/tunnel/login.go index 632e622a..a5cf7813 100644 --- a/cmd/cloudflared/tunnel/login.go +++ b/cmd/cloudflared/tunnel/login.go @@ -67,7 +67,7 @@ func login(c *cli.Context) error { path, ok, err := checkForExistingCert() if ok { - fmt.Fprintf(os.Stdout, "You have an existing certificate at %s which login would overwrite.\nIf this is intentional, please move or delete that file then run this command again.\n", path) + log.Error().Err(err).Msgf("You have an existing certificate at %s which login would overwrite.\nIf this is intentional, please move or delete that file then run this command again.\n", path) return nil } else if err != nil { return err @@ -78,7 +78,8 @@ func login(c *cli.Context) error { callbackStoreURL = c.String(callbackURLParamName) ) - if c.Bool(fedRAMPParamName) { + isFEDRamp := c.Bool(fedRAMPParamName) + if isFEDRamp { baseloginURL = fedBaseLoginURL callbackStoreURL = fedCallbackStoreURL } @@ -99,7 +100,23 @@ func login(c *cli.Context) error { log, ) if err != nil { - fmt.Fprintf(os.Stderr, "Failed to write the certificate due to the following error:\n%v\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", err, path) + log.Error().Err(err).Msgf("Failed to write the certificate.\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", path) + return err + } + + cert, err := credentials.DecodeOriginCert(resourceData) + if err != nil { + log.Error().Err(err).Msg("failed to decode origin certificate") + return err + } + + if isFEDRamp { + cert.Endpoint = credentials.FedEndpoint + } + + resourceData, err = cert.EncodeOriginCert() + if err != nil { + log.Error().Err(err).Msg("failed to encode origin certificate") return err } @@ -107,7 +124,7 @@ func login(c *cli.Context) error { return errors.Wrap(err, fmt.Sprintf("error writing cert to %s", path)) } - fmt.Fprintf(os.Stdout, "You have successfully logged in.\nIf you wish to copy your credentials to a server, they have been saved to:\n%s\n", path) + log.Info().Msgf("You have successfully logged in.\nIf you wish to copy your credentials to a server, they have been saved to:\n%s\n", path) return nil } diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index 63ee6532..553cb83b 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -20,6 +20,8 @@ import ( "github.com/cloudflare/cloudflared/logger" ) +const fedRampBaseApiURL = "https://api.fed.cloudflare.com/client/v4" + type invalidJSONCredentialError struct { err error path string @@ -65,7 +67,16 @@ func (sc *subcommandContext) client() (cfapi.Client, error) { if err != nil { return nil, err } - sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log) + + var apiURL string + if cred.IsFEDEndpoint() { + sc.log.Info().Str("api-url", fedRampBaseApiURL).Msg("using fedramp base api") + apiURL = fedRampBaseApiURL + } else { + apiURL = sc.c.String(cfdflags.ApiURL) + } + + sc.tunnelstoreClient, err = cred.Client(apiURL, buildInfo.UserAgent(), sc.log) if err != nil { return nil, err } diff --git a/credentials/credentials.go b/credentials/credentials.go index 8d1d8908..f5679b25 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -9,6 +9,7 @@ import ( const ( logFieldOriginCertPath = "originCertPath" + FedEndpoint = "fed" ) type User struct { @@ -32,6 +33,10 @@ func (c User) CertPath() string { return c.certPath } +func (c User) IsFEDEndpoint() bool { + return c.cert.Endpoint == FedEndpoint +} + // Client uses the user credentials to create a Cloudflare API client func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) { if apiURL == "" { @@ -45,7 +50,6 @@ func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfa userAgent, log, ) - if err != nil { return nil, err } diff --git a/credentials/origin_cert.go b/credentials/origin_cert.go index f22ad01d..e8181b36 100644 --- a/credentials/origin_cert.go +++ b/credentials/origin_cert.go @@ -1,11 +1,13 @@ package credentials import ( + "bytes" "encoding/json" "encoding/pem" "fmt" "os" "path/filepath" + "strings" "github.com/mitchellh/go-homedir" "github.com/rs/zerolog" @@ -17,16 +19,28 @@ const ( DefaultCredentialFile = "cert.pem" ) -type namedTunnelToken struct { +type OriginCert struct { ZoneID string `json:"zoneID"` AccountID string `json:"accountID"` APIToken string `json:"apiToken"` + Endpoint string `json:"endpoint,omitempty"` } -type OriginCert struct { - ZoneID string - APIToken string - AccountID string +func (oc *OriginCert) UnmarshalJSON(data []byte) error { + var aux struct { + ZoneID string `json:"zoneID"` + AccountID string `json:"accountID"` + APIToken string `json:"apiToken"` + Endpoint string `json:"endpoint,omitempty"` + } + if err := json.Unmarshal(data, &aux); err != nil { + return fmt.Errorf("error parsing OriginCert: %v", err) + } + oc.ZoneID = aux.ZoneID + oc.AccountID = aux.AccountID + oc.APIToken = aux.APIToken + oc.Endpoint = strings.ToLower(aux.Endpoint) + return nil } // FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the @@ -41,40 +55,56 @@ func FindDefaultOriginCertPath() string { return "" } +func DecodeOriginCert(blocks []byte) (*OriginCert, error) { + return decodeOriginCert(blocks) +} + +func (cert *OriginCert) EncodeOriginCert() ([]byte, error) { + if cert == nil { + return nil, fmt.Errorf("originCert cannot be nil") + } + buffer, err := json.Marshal(cert) + if err != nil { + return nil, fmt.Errorf("originCert marshal failed: %v", err) + } + block := pem.Block{ + Type: "ARGO TUNNEL TOKEN", + Headers: map[string]string{}, + Bytes: buffer, + } + var out bytes.Buffer + err = pem.Encode(&out, &block) + if err != nil { + return nil, fmt.Errorf("pem encoding failed: %v", err) + } + return out.Bytes(), nil +} + func decodeOriginCert(blocks []byte) (*OriginCert, error) { if len(blocks) == 0 { - return nil, fmt.Errorf("Cannot decode empty certificate") + return nil, fmt.Errorf("cannot decode empty certificate") } originCert := OriginCert{} block, rest := pem.Decode(blocks) - for { - if block == nil { - break - } + for block != nil { switch block.Type { case "PRIVATE KEY", "CERTIFICATE": // this is for legacy purposes. - break case "ARGO TUNNEL TOKEN": if originCert.ZoneID != "" || originCert.APIToken != "" { - return nil, fmt.Errorf("Found multiple tokens in the certificate") + return nil, fmt.Errorf("found multiple tokens in the certificate") } // The token is a string, // Try the newer JSON format - ntt := namedTunnelToken{} - if err := json.Unmarshal(block.Bytes, &ntt); err == nil { - originCert.ZoneID = ntt.ZoneID - originCert.APIToken = ntt.APIToken - originCert.AccountID = ntt.AccountID - } + _ = json.Unmarshal(block.Bytes, &originCert) default: - return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type) + return nil, fmt.Errorf("unknown block %s in the certificate", block.Type) } block, rest = pem.Decode(rest) } if originCert.ZoneID == "" || originCert.APIToken == "" { - return nil, fmt.Errorf("Missing token in the certificate") + return nil, fmt.Errorf("missing token in the certificate") } return &originCert, nil diff --git a/credentials/origin_cert_test.go b/credentials/origin_cert_test.go index 77a473e4..7e2a90a0 100644 --- a/credentials/origin_cert_test.go +++ b/credentials/origin_cert_test.go @@ -16,27 +16,25 @@ const ( originCertFile = "cert.pem" ) -var ( - nopLog = zerolog.Nop().With().Logger() -) +var nopLog = zerolog.Nop().With().Logger() func TestLoadOriginCert(t *testing.T) { cert, err := decodeOriginCert([]byte{}) - assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err) + assert.Equal(t, fmt.Errorf("cannot decode empty certificate"), err) assert.Nil(t, cert) blocks, err := os.ReadFile("test-cert-unknown-block.pem") - assert.NoError(t, err) + require.NoError(t, err) cert, err = decodeOriginCert(blocks) - assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err) + assert.Equal(t, fmt.Errorf("unknown block RSA PRIVATE KEY in the certificate"), err) assert.Nil(t, cert) } func TestJSONArgoTunnelTokenEmpty(t *testing.T) { blocks, err := os.ReadFile("test-cert-no-token.pem") - assert.NoError(t, err) + require.NoError(t, err) cert, err := decodeOriginCert(blocks) - assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err) + assert.Equal(t, fmt.Errorf("missing token in the certificate"), err) assert.Nil(t, cert) } @@ -52,51 +50,21 @@ func TestJSONArgoTunnelToken(t *testing.T) { func CloudflareTunnelTokenTest(t *testing.T, path string) { blocks, err := os.ReadFile(path) - assert.NoError(t, err) + require.NoError(t, err) cert, err := decodeOriginCert(blocks) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, cert) assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID) key := "test-service-key" assert.Equal(t, key, cert.APIToken) } -type mockFile struct { - path string - data []byte - err error -} - -type mockFileSystem struct { - files map[string]mockFile -} - -func newMockFileSystem(files ...mockFile) *mockFileSystem { - fs := mockFileSystem{map[string]mockFile{}} - for _, f := range files { - fs.files[f.path] = f - } - return &fs -} - -func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) { - if f, ok := fs.files[path]; ok { - return f.data, f.err - } - return nil, os.ErrNotExist -} - -func (fs *mockFileSystem) ValidFilePath(path string) bool { - _, exists := fs.files[path] - return exists -} - func TestFindOriginCert_Valid(t *testing.T) { file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem") require.NoError(t, err) dir := t.TempDir() certPath := path.Join(dir, originCertFile) - os.WriteFile(certPath, file, fs.ModePerm) + _ = os.WriteFile(certPath, file, fs.ModePerm) path, err := FindOriginCert(certPath, &nopLog) require.NoError(t, err) require.Equal(t, certPath, path) @@ -108,3 +76,28 @@ func TestFindOriginCert_Missing(t *testing.T) { _, err := FindOriginCert(certPath, &nopLog) require.Error(t, err) } + +func TestEncodeDecodeOriginCert(t *testing.T) { + cert := OriginCert{ + ZoneID: "zone", + AccountID: "account", + APIToken: "token", + Endpoint: "FED", + } + blocks, err := cert.EncodeOriginCert() + require.NoError(t, err) + decodedCert, err := DecodeOriginCert(blocks) + require.NoError(t, err) + assert.NotNil(t, cert) + assert.Equal(t, "zone", decodedCert.ZoneID) + assert.Equal(t, "account", decodedCert.AccountID) + assert.Equal(t, "token", decodedCert.APIToken) + assert.Equal(t, FedEndpoint, decodedCert.Endpoint) +} + +func TestEncodeDecodeNilOriginCert(t *testing.T) { + var cert *OriginCert + blocks, err := cert.EncodeOriginCert() + assert.Equal(t, fmt.Errorf("originCert cannot be nil"), err) + require.Nil(t, blocks) +} diff --git a/credentials/test-cert-unknown-block.pem b/credentials/test-cert-unknown-block.pem index 4a847eb0..86fd4a40 100644 --- a/credentials/test-cert-unknown-block.pem +++ b/credentials/test-cert-unknown-block.pem @@ -87,3 +87,4 @@ M2i4QoOFcSKIG+v4SuvgEJHgG8vGvxh2qlSxnMWuPV+7/1P5ATLqDj1PlKms+BNR y7sc5AT9PclkL3Y9MNzOu0LXyBkGYcl8M0EQfLv9VPbWT+NXiMg/O2CHiT02pAAz uQicoQq3yzeQh20wtrtaXzTNmA== -----END RSA PRIVATE KEY----- + diff --git a/token/transfer.go b/token/transfer.go index 9b035537..fd5d80ed 100644 --- a/token/transfer.go +++ b/token/transfer.go @@ -70,7 +70,6 @@ func RunTransfer(transferURL *url.URL, appAUD, resourceName, key, value string, } return resourceData, nil - } // BuildRequestURL creates a request suitable for a resource transfer.