diff --git a/cmd/cloudflared/cliutil/errors.go b/cmd/cloudflared/cliutil/errors.go index c0938349..e85e396b 100644 --- a/cmd/cloudflared/cliutil/errors.go +++ b/cmd/cloudflared/cliutil/errors.go @@ -1,12 +1,35 @@ package cliutil -import "gopkg.in/urfave/cli.v2" +import ( + "fmt" + + "gopkg.in/urfave/cli.v2" +) + +type usageError string + +func (ue usageError) Error() string { + return string(ue) +} + +func UsageError(format string, args ...interface{}) error { + if len(args) == 0 { + return usageError(format) + } else { + msg := fmt.Sprintf(format, args...) + return usageError(msg) + } +} // Ensures exit with error code if actionFunc returns an error func ErrorHandler(actionFunc cli.ActionFunc) cli.ActionFunc { return func(ctx *cli.Context) error { err := actionFunc(ctx) if err != nil { + if _, ok := err.(usageError); ok { + msg := fmt.Sprintf("%s\nSee 'cloudflared %s --help'.", err.Error(), ctx.Command.FullName()) + return cli.Exit(msg, -1) + } // os.Exits with error code if err is cli.ExitCoder or cli.MultiError cli.HandleExitCoder(err) err = cli.Exit(err.Error(), 1) @@ -14,4 +37,3 @@ func ErrorHandler(actionFunc cli.ActionFunc) cli.ActionFunc { return err } } - diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 843462ba..5a9bdb85 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -168,6 +168,9 @@ func Commands() []*cli.Command { c.Hidden = false subcommands = append(subcommands, &c) } + subcommands = append(subcommands, buildCreateCommand()) + subcommands = append(subcommands, buildListCommand()) + subcommands = append(subcommands, buildDeleteCommand()) cmds = append(cmds, &cli.Command{ Name: "tunnel", @@ -175,7 +178,7 @@ func Commands() []*cli.Command { Before: Before, Category: "Tunnel", Usage: "Make a locally-running web service accessible over the internet using Argo Tunnel.", - ArgsUsage: "[origin-url]", + ArgsUsage: " ", Description: `Argo Tunnel asks you to specify a hostname on a Cloudflare-powered domain you control and a local address. Traffic from that hostname is routed (optionally via a Cloudflare Load Balancer) to this machine and appears on the @@ -843,6 +846,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"TUNNEL_API_CA_KEY"}, Hidden: true, }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "api-url", + Usage: "Base URL for Cloudflare API v4", + EnvVars: []string{"TUNNEL_API_URL"}, + Value: "https://api.cloudflare.com/client/v4", + Hidden: true, + }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "metrics", Value: "localhost:", diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 39772845..883d3542 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -102,27 +102,29 @@ func dnsProxyStandAlone(c *cli.Context) bool { return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world")) } -func getOriginCert(c *cli.Context) ([]byte, error) { - if c.String("origincert") == "" { +func findOriginCert(c *cli.Context) (string, error) { + originCertPath := c.String("origincert") + if originCertPath == "" { logger.Warnf("Cannot determine default origin certificate path. No file %s in %v", config.DefaultCredentialFile, config.DefaultConfigDirs) if isRunningFromTerminal() { logger.Errorf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl) - return nil, fmt.Errorf("Client didn't specify origincert path when running from terminal") + return "", fmt.Errorf("Client didn't specify origincert path when running from terminal") } else { logger.Errorf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl) - return nil, fmt.Errorf("Client didn't specify origincert path") + return "", fmt.Errorf("Client didn't specify origincert path") } } - // Check that the user has acquired a certificate using the login command - originCertPath, err := homedir.Expand(c.String("origincert")) + var err error + originCertPath, err = homedir.Expand(originCertPath) if err != nil { - logger.WithError(err).Errorf("Cannot resolve path %s", c.String("origincert")) - return nil, fmt.Errorf("Cannot resolve path %s", c.String("origincert")) + logger.WithError(err).Errorf("Cannot resolve path %s", originCertPath) + return "", fmt.Errorf("Cannot resolve path %s", originCertPath) } + // Check that the user has acquired a certificate using the login command ok, err := config.FileExists(originCertPath) if err != nil { - logger.Errorf("Cannot check if origin cert exists at path %s", c.String("origincert")) - return nil, fmt.Errorf("Cannot check if origin cert exists at path %s", c.String("origincert")) + logger.Errorf("Cannot check if origin cert exists at path %s", originCertPath) + return "", fmt.Errorf("Cannot check if origin cert exists at path %s", originCertPath) } if !ok { logger.Errorf(`Cannot find a valid certificate for your origin at the path: @@ -134,8 +136,15 @@ If you don't have a certificate signed by Cloudflare, run the command: %s login `, originCertPath, os.Args[0]) - return nil, fmt.Errorf("Cannot find a valid certificate at the path %s", originCertPath) + return "", fmt.Errorf("Cannot find a valid certificate at the path %s", originCertPath) } + + return originCertPath, nil +} + +func readOriginCert(originCertPath string) ([]byte, error) { + logger.Debugf("Reading origin cert from %s", originCertPath) + // Easier to send the certificate as []byte via RPC than decoding it at this point originCert, err := ioutil.ReadFile(originCertPath) if err != nil { @@ -145,6 +154,14 @@ If you don't have a certificate signed by Cloudflare, run the command: return originCert, nil } +func getOriginCert(c *cli.Context) ([]byte, error) { + if originCertPath, err := findOriginCert(c); err != nil { + return nil, err + } else { + return readOriginCert(originCertPath) + } +} + func prepareTunnelConfig( c *cli.Context, buildInfo *buildinfo.BuildInfo, diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go new file mode 100644 index 00000000..f9493f57 --- /dev/null +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -0,0 +1,165 @@ +package tunnel + +import ( + "encoding/json" + "fmt" + "os" + "time" + + "github.com/pkg/errors" + "gopkg.in/urfave/cli.v2" + "gopkg.in/yaml.v2" + + "github.com/cloudflare/cloudflared/certutil" + "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" + "github.com/cloudflare/cloudflared/tunnelstore" +) + +var ( + outputFormatFlag = &cli.StringFlag{ + Name: "output", + Aliases: []string{"o"}, + Usage: "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'", + } +) + +const hideSubcommands = true + +func buildCreateCommand() *cli.Command { + return &cli.Command{ + Name: "create", + Action: cliutil.ErrorHandler(createTunnel), + Usage: "Create a new tunnel with given name", + ArgsUsage: "TUNNEL-NAME", + Hidden: hideSubcommands, + Flags: []cli.Flag{outputFormatFlag}, + } +} + +func createTunnel(c *cli.Context) error { + if c.NArg() != 1 { + return cliutil.UsageError(`"cloudflared tunnel create" requires exactly 1 argument, the name of tunnel to create.`) + } + name := c.Args().First() + + client, err := newTunnelstoreClient(c) + if err != nil { + return err + } + + tunnel, err := client.CreateTunnel(name) + if err != nil { + return errors.Wrap(err, "Error creating a new tunnel") + } + + if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" { + return renderOutput(outputFormat, &tunnel) + } + + logger.Infof("Created tunnel %s with id %s", tunnel.Name, tunnel.ID) + return nil +} + +func buildListCommand() *cli.Command { + return &cli.Command{ + Name: "list", + Action: cliutil.ErrorHandler(listTunnels), + Usage: "List existing tunnels", + ArgsUsage: " ", + Hidden: hideSubcommands, + Flags: []cli.Flag{outputFormatFlag}, + } +} + +func listTunnels(c *cli.Context) error { + client, err := newTunnelstoreClient(c) + if err != nil { + return err + } + + tunnels, err := client.ListTunnels() + if err != nil { + return errors.Wrap(err, "Error listing tunnels") + } + + if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" { + return renderOutput(outputFormat, tunnels) + } + if len(tunnels) > 0 { + const listFormat = "%-40s%-40s%s\n" + fmt.Printf(listFormat, "ID", "NAME", "CREATED") + for _, t := range tunnels { + fmt.Printf(listFormat, t.ID, t.Name, t.CreatedAt.Format(time.RFC3339)) + } + } else { + fmt.Println("You have no tunnels, use 'cloudflared tunnel create' to define a new tunnel") + } + + return nil +} + +func buildDeleteCommand() *cli.Command { + return &cli.Command{ + Name: "delete", + Action: cliutil.ErrorHandler(deleteTunnel), + Usage: "Delete existing tunnel with given ID", + ArgsUsage: "TUNNEL-ID", + Hidden: hideSubcommands, + } +} + +func deleteTunnel(c *cli.Context) error { + if c.NArg() != 1 { + return cliutil.UsageError(`"cloudflared tunnel delete" requires exactly 1 argument, the ID of the tunnel to delete.`) + } + id := c.Args().First() + + client, err := newTunnelstoreClient(c) + if err != nil { + return err + } + + if err := client.DeleteTunnel(id); err != nil { + return errors.Wrapf(err, "Error deleting tunnel %s", id) + } + + return nil +} + +func renderOutput(format string, v interface{}) error { + switch format { + case "json": + encoder := json.NewEncoder(os.Stdout) + encoder.SetIndent("", " ") + return encoder.Encode(v) + case "yaml": + return yaml.NewEncoder(os.Stdout).Encode(v) + default: + return errors.Errorf("Unknown output format '%s'", format) + } +} + +func newTunnelstoreClient(c *cli.Context) (tunnelstore.Client, error) { + originCertPath, err := findOriginCert(c) + if err != nil { + return nil, errors.Wrap(err, "Error locating origin cert") + } + + blocks, err := readOriginCert(originCertPath) + if err != nil { + return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) + } + + cert, err := certutil.DecodeOriginCert(blocks) + if err != nil { + return nil, errors.Wrap(err, "Error decoding origin cert") + } + + if cert.AccountID == "" { + return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) + } + + client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey) + + return client, nil +} diff --git a/tunnelstore/client.go b/tunnelstore/client.go new file mode 100644 index 00000000..fcd5f92c --- /dev/null +++ b/tunnelstore/client.go @@ -0,0 +1,184 @@ +package tunnelstore + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const ( + defaultTimeout = 15 * time.Second + jsonContentType = "application/json" +) + +var ( + ErrTunnelNameConflict = errors.New("tunnel with name already exists") + ErrUnauthorized = errors.New("unauthorized") + ErrBadRequest = errors.New("incorrect request parameters") + ErrNotFound = errors.New("not found") +) + +type Tunnel struct { + ID string `json:"id"` + Name string `json:"name"` + CreatedAt time.Time `json:"created_at"` +} + +type Client interface { + CreateTunnel(name string) (*Tunnel, error) + GetTunnel(id string) (*Tunnel, error) + DeleteTunnel(id string) error + ListTunnels() ([]Tunnel, error) +} + +type RESTClient struct { + baseURL string + authToken string + client http.Client +} + +var _ Client = (*RESTClient)(nil) + +func NewRESTClient(baseURL string, accountTag string, authToken string) *RESTClient { + if strings.HasSuffix(baseURL, "/") { + baseURL = baseURL[:len(baseURL)-1] + } + url := fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag) + return &RESTClient{ + baseURL: url, + authToken: authToken, + client: http.Client{ + Transport: &http.Transport{ + TLSHandshakeTimeout: defaultTimeout, + ResponseHeaderTimeout: defaultTimeout, + }, + Timeout: defaultTimeout, + }, + } +} + +type newTunnel struct { + Name string `json:"name"` +} + +func (r *RESTClient) CreateTunnel(name string) (*Tunnel, error) { + if name == "" { + return nil, errors.New("tunnel name required") + } + body, err := json.Marshal(&newTunnel{ + Name: name, + }) + if err != nil { + return nil, errors.Wrap(err, "Failed to serialize new tunnel request") + } + + resp, err := r.sendRequest("POST", "", bytes.NewBuffer(body)) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + return unmarshalTunnel(resp.Body) + case http.StatusConflict: + return nil, ErrTunnelNameConflict + } + + return nil, statusCodeToError("create", resp) +} + +func (r *RESTClient) GetTunnel(id string) (*Tunnel, error) { + resp, err := r.sendRequest("GET", id, nil) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return unmarshalTunnel(resp.Body) + } + + return nil, statusCodeToError("read", resp) +} + +func (r *RESTClient) DeleteTunnel(id string) error { + resp, err := r.sendRequest("DELETE", id, nil) + if err != nil { + return errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + return statusCodeToError("delete", resp) +} + +func (r *RESTClient) ListTunnels() ([]Tunnel, error) { + resp, err := r.sendRequest("GET", "", nil) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + var tunnels []Tunnel + if err := json.NewDecoder(resp.Body).Decode(&tunnels); err != nil { + return nil, errors.Wrap(err, "failed to decode response") + } + return tunnels, nil + } + + return nil, statusCodeToError("list", resp) +} + +func (r *RESTClient) resolve(target string) string { + if target != "" { + return r.baseURL + "/" + target + } + return r.baseURL +} + +func (r *RESTClient) sendRequest(method string, target string, body io.Reader) (*http.Response, error) { + url := r.resolve(target) + logrus.Debugf("%s %s", method, url) + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, errors.Wrapf(err, "can't create %s request", method) + } + if body != nil { + req.Header.Set("Content-Type", jsonContentType) + } + req.Header.Add("X-Auth-User-Service-Key", r.authToken) + return r.client.Do(req) +} + +func unmarshalTunnel(reader io.Reader) (*Tunnel, error) { + var tunnel Tunnel + if err := json.NewDecoder(reader).Decode(&tunnel); err != nil { + return nil, errors.Wrap(err, "failed to decode response") + } + return &tunnel, nil +} + +func statusCodeToError(op string, resp *http.Response) error { + switch resp.StatusCode { + case http.StatusOK: + return nil + case http.StatusBadRequest: + return ErrBadRequest + case http.StatusUnauthorized, http.StatusForbidden: + return ErrUnauthorized + case http.StatusNotFound: + return ErrNotFound + } + return errors.Errorf("API call to %s tunnel failed with status %d: %s", op, + resp.StatusCode, http.StatusText(resp.StatusCode)) +} + +