TUN-3156: Add route subcommand under tunnel

This commit is contained in:
cthuang 2020-07-06 16:01:48 +08:00
parent 7afde79600
commit 8836ee1dda
4 changed files with 230 additions and 49 deletions

View File

@ -174,6 +174,7 @@ func Commands() []*cli.Command {
subcommands = append(subcommands, buildDeleteCommand())
subcommands = append(subcommands, buildRunCommand())
subcommands = append(subcommands, buildCleanupCommand())
subcommands = append(subcommands, buildRouteCommand())
cmds = append(cmds, &cli.Command{
Name: "tunnel",

View File

@ -58,7 +58,7 @@ var (
forceDeleteFlag = &cli.BoolFlag{
Name: "force",
Aliases: []string{"f"},
Usage: "Allows you to delete a tunnel, even if it has active connections.",
Usage: "Allows you to delete a tunnel, even if it has active connections.",
}
)
@ -131,13 +131,13 @@ func createTunnel(c *cli.Context) error {
return nil
}
func tunnelFilePath(tunnelID, directory string) (string, error) {
func tunnelFilePath(tunnelID uuid.UUID, directory string) (string, error) {
fileName := fmt.Sprintf("%v.json", tunnelID)
filePath := filepath.Clean(fmt.Sprintf("%s/%s", directory, fileName))
return homedir.Expand(filePath)
}
func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error {
func writeTunnelCredentials(tunnelID uuid.UUID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error {
originCertDir := filepath.Dir(originCertPath)
filePath, err := tunnelFilePath(tunnelID, originCertDir)
if err != nil {
@ -155,7 +155,7 @@ func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSe
return ioutil.WriteFile(filePath, body, 400)
}
func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Service) (*pogs.TunnelAuth, error) {
func readTunnelCredentials(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (*pogs.TunnelAuth, error) {
filePath, err := tunnelCredentialsPath(c, tunnelID, logger)
if err != nil {
return nil, err
@ -172,7 +172,7 @@ func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Servic
return &auth, nil
}
func tunnelCredentialsPath(c *cli.Context, tunnelID string, logger logger.Service) (string, error) {
func tunnelCredentialsPath(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (string, error) {
if filePath := c.String("credentials-file"); filePath != "" {
if validFilePath(filePath) {
return filePath, nil
@ -322,7 +322,10 @@ func deleteTunnel(c *cli.Context) error {
if c.NArg() != 1 {
return cliutil.UsageError(`"cloudflared tunnel delete" requires exactly 1 argument, the ID of the tunnel to delete.`)
}
id := c.Args().First()
tunnelID, err := uuid.Parse(c.Args().First())
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
}
logger, err := logger.New()
if err != nil {
@ -337,9 +340,9 @@ func deleteTunnel(c *cli.Context) error {
forceFlagSet := c.Bool("force")
tunnel, err := client.GetTunnel(id)
tunnel, err := client.GetTunnel(tunnelID)
if err != nil {
return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", id)
return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", tunnelID)
}
// Check if tunnel DeletedAt field has already been set
@ -351,17 +354,17 @@ func deleteTunnel(c *cli.Context) error {
if !forceFlagSet {
return errors.New("You can not delete this tunnel because it has active connections. To see connections run the 'list' command. If you believe the tunnel is not active, you can use a -f / --force flag with this command.")
}
if err := client.CleanupConnections(id); err != nil {
return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", id)
if err := client.CleanupConnections(tunnelID); err != nil {
return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnelID)
}
}
if err := client.DeleteTunnel(id); err != nil {
return errors.Wrapf(err, "Error deleting tunnel %s", id)
if err := client.DeleteTunnel(tunnelID); err != nil {
return errors.Wrapf(err, "Error deleting tunnel %s", tunnelID)
}
tunnelCredentialsPath, err := tunnelCredentialsPath(c, id, logger)
tunnelCredentialsPath, err := tunnelCredentialsPath(c, tunnelID, logger)
if err != nil {
logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err)
return nil
@ -388,7 +391,7 @@ func renderOutput(format string, v interface{}) error {
}
func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client {
client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger)
client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ZoneID, cert.ServiceKey, logger)
return client
}
@ -428,8 +431,8 @@ func runTunnel(c *cli.Context) error {
if c.NArg() != 1 {
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
}
id := c.Args().First()
tunnelID, err := uuid.Parse(id)
tunnelID, err := uuid.Parse(c.Args().First())
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
}
@ -439,7 +442,7 @@ func runTunnel(c *cli.Context) error {
return errors.Wrap(err, "error setting up logger")
}
credentials, err := readTunnelCredentials(c, id, logger)
credentials, err := readTunnelCredentials(c, tunnelID, logger)
if err != nil {
return err
}
@ -474,12 +477,108 @@ func cleanupConnections(c *cli.Context) error {
client := newTunnelstoreClient(c, cert, logger)
for i := 0; i < c.NArg(); i++ {
id := c.Args().Get(i)
logger.Infof("Cleanup connection for tunnel %s", id)
if err := client.CleanupConnections(id); err != nil {
logger.Errorf("Error cleaning up connections for tunnel %s, error :%v", id, err)
tunnelID, err := uuid.Parse(c.Args().Get(i))
if err != nil {
logger.Errorf("Failed to parse argument %d as tunnelID, error :%v", i, err)
continue
}
logger.Infof("Cleanup connection for tunnel %s", tunnelID)
if err := client.CleanupConnections(tunnelID); err != nil {
logger.Errorf("Error cleaning up connections for tunnel %v, error :%v", tunnelID, err)
}
}
return nil
}
func buildRouteCommand() *cli.Command {
return &cli.Command{
Name: "route",
Action: cliutil.ErrorHandler(routeTunnel),
Usage: "Define what hostname or load balancer can route to this tunnel",
Description: `The route defines what hostname or load balancer can route to this tunnel.
To route a hostname: cloudflared tunnel route dns <tunnel ID> <hostname>
To route a load balancer: cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>
If you don't specify a load balancer pool, we will create a new pool called tunnel:<tunnel ID>`,
ArgsUsage: "dns|lb TUNNEL-ID HOSTNAME [LB-POOL]",
Hidden: hideSubcommands,
}
}
func routeTunnel(c *cli.Context) error {
if c.NArg() < 2 {
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`)
}
const tunnelIDIndex = 1
tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex))
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
}
logger, err := logger.New()
if err != nil {
return errors.Wrap(err, "error setting up logger")
}
routeType := c.Args().First()
var route tunnelstore.Route
switch routeType {
case "dns":
route, err = dnsRouteFromArg(c, tunnelID)
if err != nil {
return err
}
case "lb":
route, err = lbRouteFromArg(c, tunnelID, logger)
if err != nil {
return err
}
default:
return cliutil.UsageError("%s is not a recognized route type. Supported route types are dns and lb", routeType)
}
cert, _, err := getOriginCertFromContext(c, logger)
if err != nil {
return err
}
client := newTunnelstoreClient(c, cert, logger)
return client.RouteTunnel(tunnelID, route)
}
func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) {
const (
userHostnameIndex = 2
expectArgs = 3
)
if c.NArg() != expectArgs {
return nil, cliutil.UsageError("Expect %d arguments, got %d", expectArgs, c.NArg())
}
userHostname := c.Args().Get(userHostnameIndex)
if userHostname == "" {
return nil, cliutil.UsageError("The third argument should be the hostname")
}
return tunnelstore.NewDNSRoute(userHostname), nil
}
func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (tunnelstore.Route, error) {
const (
lbNameIndex = 2
lbPoolIndex = 3
expectMinArgs = 3
)
if c.NArg() < expectMinArgs {
return nil, cliutil.UsageError("Expect at least %d arguments, got %d", expectMinArgs, c.NArg())
}
lbName := c.Args().Get(lbNameIndex)
if lbName == "" {
return nil, cliutil.UsageError("The third argument should be the load balancer name")
}
lbPool := c.Args().Get(lbPoolIndex)
if lbPool == "" {
lbPool = fmt.Sprintf("tunnel:%v", tunnelID)
logger.Infof("Generate pool name %s", lbPool)
}
return tunnelstore.NewLBRoute(lbName, lbPool), nil
}

View File

@ -75,11 +75,13 @@ func Test_fmtConnections(t *testing.T) {
}
func TestTunnelfilePath(t *testing.T) {
tunnelID, err := uuid.Parse("f48d8918-bc23-4647-9d48-082c5b76de65")
assert.NoError(t, err)
originCertDir := filepath.Dir("~/.cloudflared/cert.pem")
actual, err := tunnelFilePath("tunnel", originCertDir)
actual, err := tunnelFilePath(tunnelID, originCertDir)
assert.NoError(t, err)
homeDir, err := homedir.Dir()
assert.NoError(t, err)
expected := fmt.Sprintf("%s/.cloudflared/tunnel.json", homeDir)
expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID)
assert.Equal(t, expected, actual)
}

View File

@ -27,7 +27,7 @@ var (
)
type Tunnel struct {
ID string `json:"id"`
ID uuid.UUID `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"created_at"`
DeletedAt time.Time `json:"deleted_at"`
@ -39,30 +39,98 @@ type Connection struct {
ID uuid.UUID `json:"uuid"`
}
// Route represents a record type that can route to a tunnel
type Route interface {
json.Marshaler
RecordType() string
}
type DNSRoute struct {
userHostname string
}
func NewDNSRoute(userHostname string) Route {
return &DNSRoute{
userHostname: userHostname,
}
}
func (dr *DNSRoute) MarshalJSON() ([]byte, error) {
s := struct {
Type string `json:"type"`
UserHostname string `json:"user_hostname"`
}{
Type: dr.RecordType(),
UserHostname: dr.userHostname,
}
return json.Marshal(&s)
}
func (dr *DNSRoute) RecordType() string {
return "dns"
}
type LBRoute struct {
lbName string
lbPool string
}
func NewLBRoute(lbName, lbPool string) Route {
return &LBRoute{
lbName: lbName,
lbPool: lbPool,
}
}
func (lr *LBRoute) MarshalJSON() ([]byte, error) {
s := struct {
Type string `json:"type"`
LBName string `json:"lb_name"`
LBPool string `json:"lb_pool"`
}{
Type: lr.RecordType(),
LBName: lr.lbName,
LBPool: lr.lbPool,
}
return json.Marshal(&s)
}
func (lr *LBRoute) RecordType() string {
return "lb"
}
type Client interface {
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
GetTunnel(tunnelID string) (*Tunnel, error)
DeleteTunnel(tunnelID string) error
GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
DeleteTunnel(tunnelID uuid.UUID) error
ListTunnels() ([]Tunnel, error)
CleanupConnections(tunnelID string) error
CleanupConnections(tunnelID uuid.UUID) error
RouteTunnel(tunnelID uuid.UUID, route Route) error
}
type RESTClient struct {
baseURL string
authToken string
client http.Client
logger logger.Service
baseEndpoints *baseEndpoints
authToken string
client http.Client
logger logger.Service
}
type baseEndpoints struct {
accountLevel string
zoneLevel string
}
var _ Client = (*RESTClient)(nil)
func NewRESTClient(baseURL string, accountTag string, authToken string, logger logger.Service) *RESTClient {
func NewRESTClient(baseURL string, accountTag, zoneTag string, authToken string, logger logger.Service) *RESTClient {
if strings.HasSuffix(baseURL, "/") {
baseURL = baseURL[:len(baseURL)-1]
}
url := fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag)
return &RESTClient{
baseURL: url,
baseEndpoints: &baseEndpoints{
accountLevel: fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag),
zoneLevel: fmt.Sprintf("%s/zones/%s/tunnels", baseURL, accountTag),
},
authToken: authToken,
client: http.Client{
Transport: &http.Transport{
@ -92,7 +160,7 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
return nil, errors.Wrap(err, "Failed to serialize new tunnel request")
}
resp, err := r.sendRequest("POST", "", bytes.NewBuffer(body))
resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, bytes.NewBuffer(body))
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
@ -108,8 +176,9 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
return nil, statusCodeToError("create tunnel", resp)
}
func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) {
resp, err := r.sendRequest("GET", tunnelID, nil)
func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) {
endpoint := fmt.Sprintf("%s/%v", r.baseEndpoints.accountLevel, tunnelID)
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
@ -122,8 +191,9 @@ func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) {
return nil, statusCodeToError("get tunnel", resp)
}
func (r *RESTClient) DeleteTunnel(tunnelID string) error {
resp, err := r.sendRequest("DELETE", tunnelID, nil)
func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
endpoint := fmt.Sprintf("%s/%v", r.baseEndpoints.accountLevel, tunnelID)
resp, err := r.sendRequest("DELETE", endpoint, nil)
if err != nil {
return errors.Wrap(err, "REST request failed")
}
@ -133,7 +203,7 @@ func (r *RESTClient) DeleteTunnel(tunnelID string) error {
}
func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
resp, err := r.sendRequest("GET", "", nil)
resp, err := r.sendRequest("GET", r.baseEndpoints.accountLevel, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
@ -150,8 +220,9 @@ func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
return nil, statusCodeToError("list tunnels", resp)
}
func (r *RESTClient) CleanupConnections(tunnelID string) error {
resp, err := r.sendRequest("DELETE", fmt.Sprintf("%s/connections", tunnelID), nil)
func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error {
endpoint := fmt.Sprintf("%s/%v/connections", r.baseEndpoints.accountLevel, tunnelID)
resp, err := r.sendRequest("DELETE", endpoint, nil)
if err != nil {
return errors.Wrap(err, "REST request failed")
}
@ -160,15 +231,23 @@ func (r *RESTClient) CleanupConnections(tunnelID string) error {
return statusCodeToError("cleanup connections", resp)
}
func (r *RESTClient) resolve(target string) string {
if target != "" {
return r.baseURL + "/" + target
func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) error {
body, err := json.Marshal(route)
if err != nil {
return errors.Wrap(err, "Failed to serialize Route")
}
return r.baseURL
endpoint := fmt.Sprintf("%s/%v/routes", r.baseEndpoints.zoneLevel, tunnelID)
resp, err := r.sendRequest("PUT", endpoint, bytes.NewBuffer(body))
if err != nil {
return errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
return statusCodeToError("add route", resp)
}
func (r *RESTClient) sendRequest(method string, target string, body io.Reader) (*http.Response, error) {
url := r.resolve(target)
func (r *RESTClient) sendRequest(method string, url string, body io.Reader) (*http.Response, error) {
r.logger.Debugf("%s %s", method, url)
req, err := http.NewRequest(method, url, body)
if err != nil {