diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index b05b0c80..dbb949b2 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -174,6 +174,7 @@ func Commands() []*cli.Command { subcommands = append(subcommands, buildDeleteCommand()) subcommands = append(subcommands, buildRunCommand()) subcommands = append(subcommands, buildCleanupCommand()) + subcommands = append(subcommands, buildRouteCommand()) cmds = append(cmds, &cli.Command{ Name: "tunnel", diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 75bbac31..51cde9ea 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -58,7 +58,7 @@ var ( forceDeleteFlag = &cli.BoolFlag{ Name: "force", Aliases: []string{"f"}, - Usage: "Allows you to delete a tunnel, even if it has active connections.", + Usage: "Allows you to delete a tunnel, even if it has active connections.", } ) @@ -131,13 +131,13 @@ func createTunnel(c *cli.Context) error { return nil } -func tunnelFilePath(tunnelID, directory string) (string, error) { +func tunnelFilePath(tunnelID uuid.UUID, directory string) (string, error) { fileName := fmt.Sprintf("%v.json", tunnelID) filePath := filepath.Clean(fmt.Sprintf("%s/%s", directory, fileName)) return homedir.Expand(filePath) } -func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error { +func writeTunnelCredentials(tunnelID uuid.UUID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error { originCertDir := filepath.Dir(originCertPath) filePath, err := tunnelFilePath(tunnelID, originCertDir) if err != nil { @@ -155,7 +155,7 @@ func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSe return ioutil.WriteFile(filePath, body, 400) } -func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Service) (*pogs.TunnelAuth, error) { +func readTunnelCredentials(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (*pogs.TunnelAuth, error) { filePath, err := tunnelCredentialsPath(c, tunnelID, logger) if err != nil { return nil, err @@ -172,7 +172,7 @@ func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Servic return &auth, nil } -func tunnelCredentialsPath(c *cli.Context, tunnelID string, logger logger.Service) (string, error) { +func tunnelCredentialsPath(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (string, error) { if filePath := c.String("credentials-file"); filePath != "" { if validFilePath(filePath) { return filePath, nil @@ -322,7 +322,10 @@ func deleteTunnel(c *cli.Context) error { if c.NArg() != 1 { return cliutil.UsageError(`"cloudflared tunnel delete" requires exactly 1 argument, the ID of the tunnel to delete.`) } - id := c.Args().First() + tunnelID, err := uuid.Parse(c.Args().First()) + if err != nil { + return errors.Wrap(err, "error parsing tunnel ID") + } logger, err := logger.New() if err != nil { @@ -337,9 +340,9 @@ func deleteTunnel(c *cli.Context) error { forceFlagSet := c.Bool("force") - tunnel, err := client.GetTunnel(id) + tunnel, err := client.GetTunnel(tunnelID) if err != nil { - return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", id) + return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", tunnelID) } // Check if tunnel DeletedAt field has already been set @@ -351,17 +354,17 @@ func deleteTunnel(c *cli.Context) error { if !forceFlagSet { return errors.New("You can not delete this tunnel because it has active connections. To see connections run the 'list' command. If you believe the tunnel is not active, you can use a -f / --force flag with this command.") } - - if err := client.CleanupConnections(id); err != nil { - return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", id) + + if err := client.CleanupConnections(tunnelID); err != nil { + return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnelID) } } - if err := client.DeleteTunnel(id); err != nil { - return errors.Wrapf(err, "Error deleting tunnel %s", id) + if err := client.DeleteTunnel(tunnelID); err != nil { + return errors.Wrapf(err, "Error deleting tunnel %s", tunnelID) } - tunnelCredentialsPath, err := tunnelCredentialsPath(c, id, logger) + tunnelCredentialsPath, err := tunnelCredentialsPath(c, tunnelID, logger) if err != nil { logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err) return nil @@ -388,7 +391,7 @@ func renderOutput(format string, v interface{}) error { } func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client { - client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger) + client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ZoneID, cert.ServiceKey, logger) return client } @@ -428,8 +431,8 @@ func runTunnel(c *cli.Context) error { if c.NArg() != 1 { return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`) } - id := c.Args().First() - tunnelID, err := uuid.Parse(id) + + tunnelID, err := uuid.Parse(c.Args().First()) if err != nil { return errors.Wrap(err, "error parsing tunnel ID") } @@ -439,7 +442,7 @@ func runTunnel(c *cli.Context) error { return errors.Wrap(err, "error setting up logger") } - credentials, err := readTunnelCredentials(c, id, logger) + credentials, err := readTunnelCredentials(c, tunnelID, logger) if err != nil { return err } @@ -474,12 +477,108 @@ func cleanupConnections(c *cli.Context) error { client := newTunnelstoreClient(c, cert, logger) for i := 0; i < c.NArg(); i++ { - id := c.Args().Get(i) - logger.Infof("Cleanup connection for tunnel %s", id) - if err := client.CleanupConnections(id); err != nil { - logger.Errorf("Error cleaning up connections for tunnel %s, error :%v", id, err) + tunnelID, err := uuid.Parse(c.Args().Get(i)) + if err != nil { + logger.Errorf("Failed to parse argument %d as tunnelID, error :%v", i, err) + continue + } + logger.Infof("Cleanup connection for tunnel %s", tunnelID) + if err := client.CleanupConnections(tunnelID); err != nil { + logger.Errorf("Error cleaning up connections for tunnel %v, error :%v", tunnelID, err) } } return nil } + +func buildRouteCommand() *cli.Command { + return &cli.Command{ + Name: "route", + Action: cliutil.ErrorHandler(routeTunnel), + Usage: "Define what hostname or load balancer can route to this tunnel", + Description: `The route defines what hostname or load balancer can route to this tunnel. + To route a hostname: cloudflared tunnel route dns + To route a load balancer: cloudflared tunnel route lb + If you don't specify a load balancer pool, we will create a new pool called tunnel:`, + ArgsUsage: "dns|lb TUNNEL-ID HOSTNAME [LB-POOL]", + Hidden: hideSubcommands, + } +} + +func routeTunnel(c *cli.Context) error { + if c.NArg() < 2 { + return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`) + } + const tunnelIDIndex = 1 + tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex)) + if err != nil { + return errors.Wrap(err, "error parsing tunnel ID") + } + + logger, err := logger.New() + if err != nil { + return errors.Wrap(err, "error setting up logger") + } + + routeType := c.Args().First() + var route tunnelstore.Route + switch routeType { + case "dns": + route, err = dnsRouteFromArg(c, tunnelID) + if err != nil { + return err + } + case "lb": + route, err = lbRouteFromArg(c, tunnelID, logger) + if err != nil { + return err + } + default: + return cliutil.UsageError("%s is not a recognized route type. Supported route types are dns and lb", routeType) + } + + cert, _, err := getOriginCertFromContext(c, logger) + if err != nil { + return err + } + + client := newTunnelstoreClient(c, cert, logger) + return client.RouteTunnel(tunnelID, route) +} + +func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) { + const ( + userHostnameIndex = 2 + expectArgs = 3 + ) + if c.NArg() != expectArgs { + return nil, cliutil.UsageError("Expect %d arguments, got %d", expectArgs, c.NArg()) + } + userHostname := c.Args().Get(userHostnameIndex) + if userHostname == "" { + return nil, cliutil.UsageError("The third argument should be the hostname") + } + return tunnelstore.NewDNSRoute(userHostname), nil +} + +func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (tunnelstore.Route, error) { + const ( + lbNameIndex = 2 + lbPoolIndex = 3 + expectMinArgs = 3 + ) + if c.NArg() < expectMinArgs { + return nil, cliutil.UsageError("Expect at least %d arguments, got %d", expectMinArgs, c.NArg()) + } + lbName := c.Args().Get(lbNameIndex) + if lbName == "" { + return nil, cliutil.UsageError("The third argument should be the load balancer name") + } + lbPool := c.Args().Get(lbPoolIndex) + if lbPool == "" { + lbPool = fmt.Sprintf("tunnel:%v", tunnelID) + logger.Infof("Generate pool name %s", lbPool) + } + + return tunnelstore.NewLBRoute(lbName, lbPool), nil +} diff --git a/cmd/cloudflared/tunnel/subcommands_test.go b/cmd/cloudflared/tunnel/subcommands_test.go index a8b4ebe3..fb5ea914 100644 --- a/cmd/cloudflared/tunnel/subcommands_test.go +++ b/cmd/cloudflared/tunnel/subcommands_test.go @@ -75,11 +75,13 @@ func Test_fmtConnections(t *testing.T) { } func TestTunnelfilePath(t *testing.T) { + tunnelID, err := uuid.Parse("f48d8918-bc23-4647-9d48-082c5b76de65") + assert.NoError(t, err) originCertDir := filepath.Dir("~/.cloudflared/cert.pem") - actual, err := tunnelFilePath("tunnel", originCertDir) + actual, err := tunnelFilePath(tunnelID, originCertDir) assert.NoError(t, err) homeDir, err := homedir.Dir() assert.NoError(t, err) - expected := fmt.Sprintf("%s/.cloudflared/tunnel.json", homeDir) + expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID) assert.Equal(t, expected, actual) } diff --git a/tunnelstore/client.go b/tunnelstore/client.go index 56077699..7cfad0ae 100644 --- a/tunnelstore/client.go +++ b/tunnelstore/client.go @@ -27,7 +27,7 @@ var ( ) type Tunnel struct { - ID string `json:"id"` + ID uuid.UUID `json:"id"` Name string `json:"name"` CreatedAt time.Time `json:"created_at"` DeletedAt time.Time `json:"deleted_at"` @@ -39,30 +39,98 @@ type Connection struct { ID uuid.UUID `json:"uuid"` } +// Route represents a record type that can route to a tunnel +type Route interface { + json.Marshaler + RecordType() string +} + +type DNSRoute struct { + userHostname string +} + +func NewDNSRoute(userHostname string) Route { + return &DNSRoute{ + userHostname: userHostname, + } +} + +func (dr *DNSRoute) MarshalJSON() ([]byte, error) { + s := struct { + Type string `json:"type"` + UserHostname string `json:"user_hostname"` + }{ + Type: dr.RecordType(), + UserHostname: dr.userHostname, + } + return json.Marshal(&s) +} + +func (dr *DNSRoute) RecordType() string { + return "dns" +} + +type LBRoute struct { + lbName string + lbPool string +} + +func NewLBRoute(lbName, lbPool string) Route { + return &LBRoute{ + lbName: lbName, + lbPool: lbPool, + } +} + +func (lr *LBRoute) MarshalJSON() ([]byte, error) { + s := struct { + Type string `json:"type"` + LBName string `json:"lb_name"` + LBPool string `json:"lb_pool"` + }{ + Type: lr.RecordType(), + LBName: lr.lbName, + LBPool: lr.lbPool, + } + return json.Marshal(&s) +} + +func (lr *LBRoute) RecordType() string { + return "lb" +} + type Client interface { CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) - GetTunnel(tunnelID string) (*Tunnel, error) - DeleteTunnel(tunnelID string) error + GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) + DeleteTunnel(tunnelID uuid.UUID) error ListTunnels() ([]Tunnel, error) - CleanupConnections(tunnelID string) error + CleanupConnections(tunnelID uuid.UUID) error + RouteTunnel(tunnelID uuid.UUID, route Route) error } type RESTClient struct { - baseURL string - authToken string - client http.Client - logger logger.Service + baseEndpoints *baseEndpoints + authToken string + client http.Client + logger logger.Service +} + +type baseEndpoints struct { + accountLevel string + zoneLevel string } var _ Client = (*RESTClient)(nil) -func NewRESTClient(baseURL string, accountTag string, authToken string, logger logger.Service) *RESTClient { +func NewRESTClient(baseURL string, accountTag, zoneTag string, authToken string, logger logger.Service) *RESTClient { if strings.HasSuffix(baseURL, "/") { baseURL = baseURL[:len(baseURL)-1] } - url := fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag) return &RESTClient{ - baseURL: url, + baseEndpoints: &baseEndpoints{ + accountLevel: fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag), + zoneLevel: fmt.Sprintf("%s/zones/%s/tunnels", baseURL, accountTag), + }, authToken: authToken, client: http.Client{ Transport: &http.Transport{ @@ -92,7 +160,7 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er return nil, errors.Wrap(err, "Failed to serialize new tunnel request") } - resp, err := r.sendRequest("POST", "", bytes.NewBuffer(body)) + resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, bytes.NewBuffer(body)) if err != nil { return nil, errors.Wrap(err, "REST request failed") } @@ -108,8 +176,9 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er return nil, statusCodeToError("create tunnel", resp) } -func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) { - resp, err := r.sendRequest("GET", tunnelID, nil) +func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) { + endpoint := fmt.Sprintf("%s/%v", r.baseEndpoints.accountLevel, tunnelID) + resp, err := r.sendRequest("GET", endpoint, nil) if err != nil { return nil, errors.Wrap(err, "REST request failed") } @@ -122,8 +191,9 @@ func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) { return nil, statusCodeToError("get tunnel", resp) } -func (r *RESTClient) DeleteTunnel(tunnelID string) error { - resp, err := r.sendRequest("DELETE", tunnelID, nil) +func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error { + endpoint := fmt.Sprintf("%s/%v", r.baseEndpoints.accountLevel, tunnelID) + resp, err := r.sendRequest("DELETE", endpoint, nil) if err != nil { return errors.Wrap(err, "REST request failed") } @@ -133,7 +203,7 @@ func (r *RESTClient) DeleteTunnel(tunnelID string) error { } func (r *RESTClient) ListTunnels() ([]Tunnel, error) { - resp, err := r.sendRequest("GET", "", nil) + resp, err := r.sendRequest("GET", r.baseEndpoints.accountLevel, nil) if err != nil { return nil, errors.Wrap(err, "REST request failed") } @@ -150,8 +220,9 @@ func (r *RESTClient) ListTunnels() ([]Tunnel, error) { return nil, statusCodeToError("list tunnels", resp) } -func (r *RESTClient) CleanupConnections(tunnelID string) error { - resp, err := r.sendRequest("DELETE", fmt.Sprintf("%s/connections", tunnelID), nil) +func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error { + endpoint := fmt.Sprintf("%s/%v/connections", r.baseEndpoints.accountLevel, tunnelID) + resp, err := r.sendRequest("DELETE", endpoint, nil) if err != nil { return errors.Wrap(err, "REST request failed") } @@ -160,15 +231,23 @@ func (r *RESTClient) CleanupConnections(tunnelID string) error { return statusCodeToError("cleanup connections", resp) } -func (r *RESTClient) resolve(target string) string { - if target != "" { - return r.baseURL + "/" + target +func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) error { + body, err := json.Marshal(route) + if err != nil { + return errors.Wrap(err, "Failed to serialize Route") } - return r.baseURL + + endpoint := fmt.Sprintf("%s/%v/routes", r.baseEndpoints.zoneLevel, tunnelID) + resp, err := r.sendRequest("PUT", endpoint, bytes.NewBuffer(body)) + if err != nil { + return errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + return statusCodeToError("add route", resp) } -func (r *RESTClient) sendRequest(method string, target string, body io.Reader) (*http.Response, error) { - url := r.resolve(target) +func (r *RESTClient) sendRequest(method string, url string, body io.Reader) (*http.Response, error) { r.logger.Debugf("%s %s", method, url) req, err := http.NewRequest(method, url, body) if err != nil {