TUN-3156: Add route subcommand under tunnel

This commit is contained in:
cthuang 2020-07-06 16:01:48 +08:00
parent 7afde79600
commit 8836ee1dda
4 changed files with 230 additions and 49 deletions

View File

@ -174,6 +174,7 @@ func Commands() []*cli.Command {
subcommands = append(subcommands, buildDeleteCommand()) subcommands = append(subcommands, buildDeleteCommand())
subcommands = append(subcommands, buildRunCommand()) subcommands = append(subcommands, buildRunCommand())
subcommands = append(subcommands, buildCleanupCommand()) subcommands = append(subcommands, buildCleanupCommand())
subcommands = append(subcommands, buildRouteCommand())
cmds = append(cmds, &cli.Command{ cmds = append(cmds, &cli.Command{
Name: "tunnel", Name: "tunnel",

View File

@ -58,7 +58,7 @@ var (
forceDeleteFlag = &cli.BoolFlag{ forceDeleteFlag = &cli.BoolFlag{
Name: "force", Name: "force",
Aliases: []string{"f"}, Aliases: []string{"f"},
Usage: "Allows you to delete a tunnel, even if it has active connections.", Usage: "Allows you to delete a tunnel, even if it has active connections.",
} }
) )
@ -131,13 +131,13 @@ func createTunnel(c *cli.Context) error {
return nil return nil
} }
func tunnelFilePath(tunnelID, directory string) (string, error) { func tunnelFilePath(tunnelID uuid.UUID, directory string) (string, error) {
fileName := fmt.Sprintf("%v.json", tunnelID) fileName := fmt.Sprintf("%v.json", tunnelID)
filePath := filepath.Clean(fmt.Sprintf("%s/%s", directory, fileName)) filePath := filepath.Clean(fmt.Sprintf("%s/%s", directory, fileName))
return homedir.Expand(filePath) return homedir.Expand(filePath)
} }
func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error { func writeTunnelCredentials(tunnelID uuid.UUID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error {
originCertDir := filepath.Dir(originCertPath) originCertDir := filepath.Dir(originCertPath)
filePath, err := tunnelFilePath(tunnelID, originCertDir) filePath, err := tunnelFilePath(tunnelID, originCertDir)
if err != nil { if err != nil {
@ -155,7 +155,7 @@ func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSe
return ioutil.WriteFile(filePath, body, 400) return ioutil.WriteFile(filePath, body, 400)
} }
func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Service) (*pogs.TunnelAuth, error) { func readTunnelCredentials(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (*pogs.TunnelAuth, error) {
filePath, err := tunnelCredentialsPath(c, tunnelID, logger) filePath, err := tunnelCredentialsPath(c, tunnelID, logger)
if err != nil { if err != nil {
return nil, err return nil, err
@ -172,7 +172,7 @@ func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Servic
return &auth, nil return &auth, nil
} }
func tunnelCredentialsPath(c *cli.Context, tunnelID string, logger logger.Service) (string, error) { func tunnelCredentialsPath(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (string, error) {
if filePath := c.String("credentials-file"); filePath != "" { if filePath := c.String("credentials-file"); filePath != "" {
if validFilePath(filePath) { if validFilePath(filePath) {
return filePath, nil return filePath, nil
@ -322,7 +322,10 @@ func deleteTunnel(c *cli.Context) error {
if c.NArg() != 1 { if c.NArg() != 1 {
return cliutil.UsageError(`"cloudflared tunnel delete" requires exactly 1 argument, the ID of the tunnel to delete.`) return cliutil.UsageError(`"cloudflared tunnel delete" requires exactly 1 argument, the ID of the tunnel to delete.`)
} }
id := c.Args().First() tunnelID, err := uuid.Parse(c.Args().First())
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
}
logger, err := logger.New() logger, err := logger.New()
if err != nil { if err != nil {
@ -337,9 +340,9 @@ func deleteTunnel(c *cli.Context) error {
forceFlagSet := c.Bool("force") forceFlagSet := c.Bool("force")
tunnel, err := client.GetTunnel(id) tunnel, err := client.GetTunnel(tunnelID)
if err != nil { if err != nil {
return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", id) return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", tunnelID)
} }
// Check if tunnel DeletedAt field has already been set // Check if tunnel DeletedAt field has already been set
@ -351,17 +354,17 @@ func deleteTunnel(c *cli.Context) error {
if !forceFlagSet { if !forceFlagSet {
return errors.New("You can not delete this tunnel because it has active connections. To see connections run the 'list' command. If you believe the tunnel is not active, you can use a -f / --force flag with this command.") return errors.New("You can not delete this tunnel because it has active connections. To see connections run the 'list' command. If you believe the tunnel is not active, you can use a -f / --force flag with this command.")
} }
if err := client.CleanupConnections(id); err != nil { if err := client.CleanupConnections(tunnelID); err != nil {
return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", id) return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnelID)
} }
} }
if err := client.DeleteTunnel(id); err != nil { if err := client.DeleteTunnel(tunnelID); err != nil {
return errors.Wrapf(err, "Error deleting tunnel %s", id) return errors.Wrapf(err, "Error deleting tunnel %s", tunnelID)
} }
tunnelCredentialsPath, err := tunnelCredentialsPath(c, id, logger) tunnelCredentialsPath, err := tunnelCredentialsPath(c, tunnelID, logger)
if err != nil { if err != nil {
logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err) logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err)
return nil return nil
@ -388,7 +391,7 @@ func renderOutput(format string, v interface{}) error {
} }
func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client { func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client {
client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger) client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ZoneID, cert.ServiceKey, logger)
return client return client
} }
@ -428,8 +431,8 @@ func runTunnel(c *cli.Context) error {
if c.NArg() != 1 { if c.NArg() != 1 {
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`) return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
} }
id := c.Args().First()
tunnelID, err := uuid.Parse(id) tunnelID, err := uuid.Parse(c.Args().First())
if err != nil { if err != nil {
return errors.Wrap(err, "error parsing tunnel ID") return errors.Wrap(err, "error parsing tunnel ID")
} }
@ -439,7 +442,7 @@ func runTunnel(c *cli.Context) error {
return errors.Wrap(err, "error setting up logger") return errors.Wrap(err, "error setting up logger")
} }
credentials, err := readTunnelCredentials(c, id, logger) credentials, err := readTunnelCredentials(c, tunnelID, logger)
if err != nil { if err != nil {
return err return err
} }
@ -474,12 +477,108 @@ func cleanupConnections(c *cli.Context) error {
client := newTunnelstoreClient(c, cert, logger) client := newTunnelstoreClient(c, cert, logger)
for i := 0; i < c.NArg(); i++ { for i := 0; i < c.NArg(); i++ {
id := c.Args().Get(i) tunnelID, err := uuid.Parse(c.Args().Get(i))
logger.Infof("Cleanup connection for tunnel %s", id) if err != nil {
if err := client.CleanupConnections(id); err != nil { logger.Errorf("Failed to parse argument %d as tunnelID, error :%v", i, err)
logger.Errorf("Error cleaning up connections for tunnel %s, error :%v", id, err) continue
}
logger.Infof("Cleanup connection for tunnel %s", tunnelID)
if err := client.CleanupConnections(tunnelID); err != nil {
logger.Errorf("Error cleaning up connections for tunnel %v, error :%v", tunnelID, err)
} }
} }
return nil return nil
} }
func buildRouteCommand() *cli.Command {
return &cli.Command{
Name: "route",
Action: cliutil.ErrorHandler(routeTunnel),
Usage: "Define what hostname or load balancer can route to this tunnel",
Description: `The route defines what hostname or load balancer can route to this tunnel.
To route a hostname: cloudflared tunnel route dns <tunnel ID> <hostname>
To route a load balancer: cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>
If you don't specify a load balancer pool, we will create a new pool called tunnel:<tunnel ID>`,
ArgsUsage: "dns|lb TUNNEL-ID HOSTNAME [LB-POOL]",
Hidden: hideSubcommands,
}
}
func routeTunnel(c *cli.Context) error {
if c.NArg() < 2 {
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`)
}
const tunnelIDIndex = 1
tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex))
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
}
logger, err := logger.New()
if err != nil {
return errors.Wrap(err, "error setting up logger")
}
routeType := c.Args().First()
var route tunnelstore.Route
switch routeType {
case "dns":
route, err = dnsRouteFromArg(c, tunnelID)
if err != nil {
return err
}
case "lb":
route, err = lbRouteFromArg(c, tunnelID, logger)
if err != nil {
return err
}
default:
return cliutil.UsageError("%s is not a recognized route type. Supported route types are dns and lb", routeType)
}
cert, _, err := getOriginCertFromContext(c, logger)
if err != nil {
return err
}
client := newTunnelstoreClient(c, cert, logger)
return client.RouteTunnel(tunnelID, route)
}
func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) {
const (
userHostnameIndex = 2
expectArgs = 3
)
if c.NArg() != expectArgs {
return nil, cliutil.UsageError("Expect %d arguments, got %d", expectArgs, c.NArg())
}
userHostname := c.Args().Get(userHostnameIndex)
if userHostname == "" {
return nil, cliutil.UsageError("The third argument should be the hostname")
}
return tunnelstore.NewDNSRoute(userHostname), nil
}
func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (tunnelstore.Route, error) {
const (
lbNameIndex = 2
lbPoolIndex = 3
expectMinArgs = 3
)
if c.NArg() < expectMinArgs {
return nil, cliutil.UsageError("Expect at least %d arguments, got %d", expectMinArgs, c.NArg())
}
lbName := c.Args().Get(lbNameIndex)
if lbName == "" {
return nil, cliutil.UsageError("The third argument should be the load balancer name")
}
lbPool := c.Args().Get(lbPoolIndex)
if lbPool == "" {
lbPool = fmt.Sprintf("tunnel:%v", tunnelID)
logger.Infof("Generate pool name %s", lbPool)
}
return tunnelstore.NewLBRoute(lbName, lbPool), nil
}

View File

@ -75,11 +75,13 @@ func Test_fmtConnections(t *testing.T) {
} }
func TestTunnelfilePath(t *testing.T) { func TestTunnelfilePath(t *testing.T) {
tunnelID, err := uuid.Parse("f48d8918-bc23-4647-9d48-082c5b76de65")
assert.NoError(t, err)
originCertDir := filepath.Dir("~/.cloudflared/cert.pem") originCertDir := filepath.Dir("~/.cloudflared/cert.pem")
actual, err := tunnelFilePath("tunnel", originCertDir) actual, err := tunnelFilePath(tunnelID, originCertDir)
assert.NoError(t, err) assert.NoError(t, err)
homeDir, err := homedir.Dir() homeDir, err := homedir.Dir()
assert.NoError(t, err) assert.NoError(t, err)
expected := fmt.Sprintf("%s/.cloudflared/tunnel.json", homeDir) expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
} }

View File

@ -27,7 +27,7 @@ var (
) )
type Tunnel struct { type Tunnel struct {
ID string `json:"id"` ID uuid.UUID `json:"id"`
Name string `json:"name"` Name string `json:"name"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
DeletedAt time.Time `json:"deleted_at"` DeletedAt time.Time `json:"deleted_at"`
@ -39,30 +39,98 @@ type Connection struct {
ID uuid.UUID `json:"uuid"` ID uuid.UUID `json:"uuid"`
} }
// Route represents a record type that can route to a tunnel
type Route interface {
json.Marshaler
RecordType() string
}
type DNSRoute struct {
userHostname string
}
func NewDNSRoute(userHostname string) Route {
return &DNSRoute{
userHostname: userHostname,
}
}
func (dr *DNSRoute) MarshalJSON() ([]byte, error) {
s := struct {
Type string `json:"type"`
UserHostname string `json:"user_hostname"`
}{
Type: dr.RecordType(),
UserHostname: dr.userHostname,
}
return json.Marshal(&s)
}
func (dr *DNSRoute) RecordType() string {
return "dns"
}
type LBRoute struct {
lbName string
lbPool string
}
func NewLBRoute(lbName, lbPool string) Route {
return &LBRoute{
lbName: lbName,
lbPool: lbPool,
}
}
func (lr *LBRoute) MarshalJSON() ([]byte, error) {
s := struct {
Type string `json:"type"`
LBName string `json:"lb_name"`
LBPool string `json:"lb_pool"`
}{
Type: lr.RecordType(),
LBName: lr.lbName,
LBPool: lr.lbPool,
}
return json.Marshal(&s)
}
func (lr *LBRoute) RecordType() string {
return "lb"
}
type Client interface { type Client interface {
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
GetTunnel(tunnelID string) (*Tunnel, error) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
DeleteTunnel(tunnelID string) error DeleteTunnel(tunnelID uuid.UUID) error
ListTunnels() ([]Tunnel, error) ListTunnels() ([]Tunnel, error)
CleanupConnections(tunnelID string) error CleanupConnections(tunnelID uuid.UUID) error
RouteTunnel(tunnelID uuid.UUID, route Route) error
} }
type RESTClient struct { type RESTClient struct {
baseURL string baseEndpoints *baseEndpoints
authToken string authToken string
client http.Client client http.Client
logger logger.Service logger logger.Service
}
type baseEndpoints struct {
accountLevel string
zoneLevel string
} }
var _ Client = (*RESTClient)(nil) var _ Client = (*RESTClient)(nil)
func NewRESTClient(baseURL string, accountTag string, authToken string, logger logger.Service) *RESTClient { func NewRESTClient(baseURL string, accountTag, zoneTag string, authToken string, logger logger.Service) *RESTClient {
if strings.HasSuffix(baseURL, "/") { if strings.HasSuffix(baseURL, "/") {
baseURL = baseURL[:len(baseURL)-1] baseURL = baseURL[:len(baseURL)-1]
} }
url := fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag)
return &RESTClient{ return &RESTClient{
baseURL: url, baseEndpoints: &baseEndpoints{
accountLevel: fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag),
zoneLevel: fmt.Sprintf("%s/zones/%s/tunnels", baseURL, accountTag),
},
authToken: authToken, authToken: authToken,
client: http.Client{ client: http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
@ -92,7 +160,7 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
return nil, errors.Wrap(err, "Failed to serialize new tunnel request") return nil, errors.Wrap(err, "Failed to serialize new tunnel request")
} }
resp, err := r.sendRequest("POST", "", bytes.NewBuffer(body)) resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, bytes.NewBuffer(body))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "REST request failed") return nil, errors.Wrap(err, "REST request failed")
} }
@ -108,8 +176,9 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
return nil, statusCodeToError("create tunnel", resp) return nil, statusCodeToError("create tunnel", resp)
} }
func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) { func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) {
resp, err := r.sendRequest("GET", tunnelID, nil) endpoint := fmt.Sprintf("%s/%v", r.baseEndpoints.accountLevel, tunnelID)
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "REST request failed") return nil, errors.Wrap(err, "REST request failed")
} }
@ -122,8 +191,9 @@ func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) {
return nil, statusCodeToError("get tunnel", resp) return nil, statusCodeToError("get tunnel", resp)
} }
func (r *RESTClient) DeleteTunnel(tunnelID string) error { func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
resp, err := r.sendRequest("DELETE", tunnelID, nil) endpoint := fmt.Sprintf("%s/%v", r.baseEndpoints.accountLevel, tunnelID)
resp, err := r.sendRequest("DELETE", endpoint, nil)
if err != nil { if err != nil {
return errors.Wrap(err, "REST request failed") return errors.Wrap(err, "REST request failed")
} }
@ -133,7 +203,7 @@ func (r *RESTClient) DeleteTunnel(tunnelID string) error {
} }
func (r *RESTClient) ListTunnels() ([]Tunnel, error) { func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
resp, err := r.sendRequest("GET", "", nil) resp, err := r.sendRequest("GET", r.baseEndpoints.accountLevel, nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "REST request failed") return nil, errors.Wrap(err, "REST request failed")
} }
@ -150,8 +220,9 @@ func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
return nil, statusCodeToError("list tunnels", resp) return nil, statusCodeToError("list tunnels", resp)
} }
func (r *RESTClient) CleanupConnections(tunnelID string) error { func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error {
resp, err := r.sendRequest("DELETE", fmt.Sprintf("%s/connections", tunnelID), nil) endpoint := fmt.Sprintf("%s/%v/connections", r.baseEndpoints.accountLevel, tunnelID)
resp, err := r.sendRequest("DELETE", endpoint, nil)
if err != nil { if err != nil {
return errors.Wrap(err, "REST request failed") return errors.Wrap(err, "REST request failed")
} }
@ -160,15 +231,23 @@ func (r *RESTClient) CleanupConnections(tunnelID string) error {
return statusCodeToError("cleanup connections", resp) return statusCodeToError("cleanup connections", resp)
} }
func (r *RESTClient) resolve(target string) string { func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) error {
if target != "" { body, err := json.Marshal(route)
return r.baseURL + "/" + target if err != nil {
return errors.Wrap(err, "Failed to serialize Route")
} }
return r.baseURL
endpoint := fmt.Sprintf("%s/%v/routes", r.baseEndpoints.zoneLevel, tunnelID)
resp, err := r.sendRequest("PUT", endpoint, bytes.NewBuffer(body))
if err != nil {
return errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
return statusCodeToError("add route", resp)
} }
func (r *RESTClient) sendRequest(method string, target string, body io.Reader) (*http.Response, error) { func (r *RESTClient) sendRequest(method string, url string, body io.Reader) (*http.Response, error) {
url := r.resolve(target)
r.logger.Debugf("%s %s", method, url) r.logger.Debugf("%s %s", method, url)
req, err := http.NewRequest(method, url, body) req, err := http.NewRequest(method, url, body)
if err != nil { if err != nil {