TUN-3156: Add route subcommand under tunnel
This commit is contained in:
parent
7afde79600
commit
8836ee1dda
|
@ -174,6 +174,7 @@ func Commands() []*cli.Command {
|
|||
subcommands = append(subcommands, buildDeleteCommand())
|
||||
subcommands = append(subcommands, buildRunCommand())
|
||||
subcommands = append(subcommands, buildCleanupCommand())
|
||||
subcommands = append(subcommands, buildRouteCommand())
|
||||
|
||||
cmds = append(cmds, &cli.Command{
|
||||
Name: "tunnel",
|
||||
|
|
|
@ -58,7 +58,7 @@ var (
|
|||
forceDeleteFlag = &cli.BoolFlag{
|
||||
Name: "force",
|
||||
Aliases: []string{"f"},
|
||||
Usage: "Allows you to delete a tunnel, even if it has active connections.",
|
||||
Usage: "Allows you to delete a tunnel, even if it has active connections.",
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -131,13 +131,13 @@ func createTunnel(c *cli.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func tunnelFilePath(tunnelID, directory string) (string, error) {
|
||||
func tunnelFilePath(tunnelID uuid.UUID, directory string) (string, error) {
|
||||
fileName := fmt.Sprintf("%v.json", tunnelID)
|
||||
filePath := filepath.Clean(fmt.Sprintf("%s/%s", directory, fileName))
|
||||
return homedir.Expand(filePath)
|
||||
}
|
||||
|
||||
func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error {
|
||||
func writeTunnelCredentials(tunnelID uuid.UUID, accountID, originCertPath string, tunnelSecret []byte, logger logger.Service) error {
|
||||
originCertDir := filepath.Dir(originCertPath)
|
||||
filePath, err := tunnelFilePath(tunnelID, originCertDir)
|
||||
if err != nil {
|
||||
|
@ -155,7 +155,7 @@ func writeTunnelCredentials(tunnelID, accountID, originCertPath string, tunnelSe
|
|||
return ioutil.WriteFile(filePath, body, 400)
|
||||
}
|
||||
|
||||
func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Service) (*pogs.TunnelAuth, error) {
|
||||
func readTunnelCredentials(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (*pogs.TunnelAuth, error) {
|
||||
filePath, err := tunnelCredentialsPath(c, tunnelID, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -172,7 +172,7 @@ func readTunnelCredentials(c *cli.Context, tunnelID string, logger logger.Servic
|
|||
return &auth, nil
|
||||
}
|
||||
|
||||
func tunnelCredentialsPath(c *cli.Context, tunnelID string, logger logger.Service) (string, error) {
|
||||
func tunnelCredentialsPath(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (string, error) {
|
||||
if filePath := c.String("credentials-file"); filePath != "" {
|
||||
if validFilePath(filePath) {
|
||||
return filePath, nil
|
||||
|
@ -322,7 +322,10 @@ func deleteTunnel(c *cli.Context) error {
|
|||
if c.NArg() != 1 {
|
||||
return cliutil.UsageError(`"cloudflared tunnel delete" requires exactly 1 argument, the ID of the tunnel to delete.`)
|
||||
}
|
||||
id := c.Args().First()
|
||||
tunnelID, err := uuid.Parse(c.Args().First())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing tunnel ID")
|
||||
}
|
||||
|
||||
logger, err := logger.New()
|
||||
if err != nil {
|
||||
|
@ -337,9 +340,9 @@ func deleteTunnel(c *cli.Context) error {
|
|||
|
||||
forceFlagSet := c.Bool("force")
|
||||
|
||||
tunnel, err := client.GetTunnel(id)
|
||||
tunnel, err := client.GetTunnel(tunnelID)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", id)
|
||||
return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", tunnelID)
|
||||
}
|
||||
|
||||
// Check if tunnel DeletedAt field has already been set
|
||||
|
@ -351,17 +354,17 @@ func deleteTunnel(c *cli.Context) error {
|
|||
if !forceFlagSet {
|
||||
return errors.New("You can not delete this tunnel because it has active connections. To see connections run the 'list' command. If you believe the tunnel is not active, you can use a -f / --force flag with this command.")
|
||||
}
|
||||
|
||||
if err := client.CleanupConnections(id); err != nil {
|
||||
return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", id)
|
||||
|
||||
if err := client.CleanupConnections(tunnelID); err != nil {
|
||||
return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnelID)
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.DeleteTunnel(id); err != nil {
|
||||
return errors.Wrapf(err, "Error deleting tunnel %s", id)
|
||||
if err := client.DeleteTunnel(tunnelID); err != nil {
|
||||
return errors.Wrapf(err, "Error deleting tunnel %s", tunnelID)
|
||||
}
|
||||
|
||||
tunnelCredentialsPath, err := tunnelCredentialsPath(c, id, logger)
|
||||
tunnelCredentialsPath, err := tunnelCredentialsPath(c, tunnelID, logger)
|
||||
if err != nil {
|
||||
logger.Infof("Cannot locate tunnel credentials to delete, error: %v. Please delete the file manually", err)
|
||||
return nil
|
||||
|
@ -388,7 +391,7 @@ func renderOutput(format string, v interface{}) error {
|
|||
}
|
||||
|
||||
func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logger.Service) tunnelstore.Client {
|
||||
client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ServiceKey, logger)
|
||||
client := tunnelstore.NewRESTClient(c.String("api-url"), cert.AccountID, cert.ZoneID, cert.ServiceKey, logger)
|
||||
return client
|
||||
}
|
||||
|
||||
|
@ -428,8 +431,8 @@ func runTunnel(c *cli.Context) error {
|
|||
if c.NArg() != 1 {
|
||||
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
|
||||
}
|
||||
id := c.Args().First()
|
||||
tunnelID, err := uuid.Parse(id)
|
||||
|
||||
tunnelID, err := uuid.Parse(c.Args().First())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing tunnel ID")
|
||||
}
|
||||
|
@ -439,7 +442,7 @@ func runTunnel(c *cli.Context) error {
|
|||
return errors.Wrap(err, "error setting up logger")
|
||||
}
|
||||
|
||||
credentials, err := readTunnelCredentials(c, id, logger)
|
||||
credentials, err := readTunnelCredentials(c, tunnelID, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -474,12 +477,108 @@ func cleanupConnections(c *cli.Context) error {
|
|||
client := newTunnelstoreClient(c, cert, logger)
|
||||
|
||||
for i := 0; i < c.NArg(); i++ {
|
||||
id := c.Args().Get(i)
|
||||
logger.Infof("Cleanup connection for tunnel %s", id)
|
||||
if err := client.CleanupConnections(id); err != nil {
|
||||
logger.Errorf("Error cleaning up connections for tunnel %s, error :%v", id, err)
|
||||
tunnelID, err := uuid.Parse(c.Args().Get(i))
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to parse argument %d as tunnelID, error :%v", i, err)
|
||||
continue
|
||||
}
|
||||
logger.Infof("Cleanup connection for tunnel %s", tunnelID)
|
||||
if err := client.CleanupConnections(tunnelID); err != nil {
|
||||
logger.Errorf("Error cleaning up connections for tunnel %v, error :%v", tunnelID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildRouteCommand() *cli.Command {
|
||||
return &cli.Command{
|
||||
Name: "route",
|
||||
Action: cliutil.ErrorHandler(routeTunnel),
|
||||
Usage: "Define what hostname or load balancer can route to this tunnel",
|
||||
Description: `The route defines what hostname or load balancer can route to this tunnel.
|
||||
To route a hostname: cloudflared tunnel route dns <tunnel ID> <hostname>
|
||||
To route a load balancer: cloudflared tunnel route lb <tunnel ID> <load balancer name> <load balancer pool>
|
||||
If you don't specify a load balancer pool, we will create a new pool called tunnel:<tunnel ID>`,
|
||||
ArgsUsage: "dns|lb TUNNEL-ID HOSTNAME [LB-POOL]",
|
||||
Hidden: hideSubcommands,
|
||||
}
|
||||
}
|
||||
|
||||
func routeTunnel(c *cli.Context) error {
|
||||
if c.NArg() < 2 {
|
||||
return cliutil.UsageError(`"cloudflared tunnel route" requires the first argument to be the route type(dns or lb), followed by the ID of the tunnel`)
|
||||
}
|
||||
const tunnelIDIndex = 1
|
||||
tunnelID, err := uuid.Parse(c.Args().Get(tunnelIDIndex))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing tunnel ID")
|
||||
}
|
||||
|
||||
logger, err := logger.New()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error setting up logger")
|
||||
}
|
||||
|
||||
routeType := c.Args().First()
|
||||
var route tunnelstore.Route
|
||||
switch routeType {
|
||||
case "dns":
|
||||
route, err = dnsRouteFromArg(c, tunnelID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case "lb":
|
||||
route, err = lbRouteFromArg(c, tunnelID, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return cliutil.UsageError("%s is not a recognized route type. Supported route types are dns and lb", routeType)
|
||||
}
|
||||
|
||||
cert, _, err := getOriginCertFromContext(c, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := newTunnelstoreClient(c, cert, logger)
|
||||
return client.RouteTunnel(tunnelID, route)
|
||||
}
|
||||
|
||||
func dnsRouteFromArg(c *cli.Context, tunnelID uuid.UUID) (tunnelstore.Route, error) {
|
||||
const (
|
||||
userHostnameIndex = 2
|
||||
expectArgs = 3
|
||||
)
|
||||
if c.NArg() != expectArgs {
|
||||
return nil, cliutil.UsageError("Expect %d arguments, got %d", expectArgs, c.NArg())
|
||||
}
|
||||
userHostname := c.Args().Get(userHostnameIndex)
|
||||
if userHostname == "" {
|
||||
return nil, cliutil.UsageError("The third argument should be the hostname")
|
||||
}
|
||||
return tunnelstore.NewDNSRoute(userHostname), nil
|
||||
}
|
||||
|
||||
func lbRouteFromArg(c *cli.Context, tunnelID uuid.UUID, logger logger.Service) (tunnelstore.Route, error) {
|
||||
const (
|
||||
lbNameIndex = 2
|
||||
lbPoolIndex = 3
|
||||
expectMinArgs = 3
|
||||
)
|
||||
if c.NArg() < expectMinArgs {
|
||||
return nil, cliutil.UsageError("Expect at least %d arguments, got %d", expectMinArgs, c.NArg())
|
||||
}
|
||||
lbName := c.Args().Get(lbNameIndex)
|
||||
if lbName == "" {
|
||||
return nil, cliutil.UsageError("The third argument should be the load balancer name")
|
||||
}
|
||||
lbPool := c.Args().Get(lbPoolIndex)
|
||||
if lbPool == "" {
|
||||
lbPool = fmt.Sprintf("tunnel:%v", tunnelID)
|
||||
logger.Infof("Generate pool name %s", lbPool)
|
||||
}
|
||||
|
||||
return tunnelstore.NewLBRoute(lbName, lbPool), nil
|
||||
}
|
||||
|
|
|
@ -75,11 +75,13 @@ func Test_fmtConnections(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTunnelfilePath(t *testing.T) {
|
||||
tunnelID, err := uuid.Parse("f48d8918-bc23-4647-9d48-082c5b76de65")
|
||||
assert.NoError(t, err)
|
||||
originCertDir := filepath.Dir("~/.cloudflared/cert.pem")
|
||||
actual, err := tunnelFilePath("tunnel", originCertDir)
|
||||
actual, err := tunnelFilePath(tunnelID, originCertDir)
|
||||
assert.NoError(t, err)
|
||||
homeDir, err := homedir.Dir()
|
||||
assert.NoError(t, err)
|
||||
expected := fmt.Sprintf("%s/.cloudflared/tunnel.json", homeDir)
|
||||
expected := fmt.Sprintf("%s/.cloudflared/%v.json", homeDir, tunnelID)
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ var (
|
|||
)
|
||||
|
||||
type Tunnel struct {
|
||||
ID string `json:"id"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DeletedAt time.Time `json:"deleted_at"`
|
||||
|
@ -39,30 +39,98 @@ type Connection struct {
|
|||
ID uuid.UUID `json:"uuid"`
|
||||
}
|
||||
|
||||
// Route represents a record type that can route to a tunnel
|
||||
type Route interface {
|
||||
json.Marshaler
|
||||
RecordType() string
|
||||
}
|
||||
|
||||
type DNSRoute struct {
|
||||
userHostname string
|
||||
}
|
||||
|
||||
func NewDNSRoute(userHostname string) Route {
|
||||
return &DNSRoute{
|
||||
userHostname: userHostname,
|
||||
}
|
||||
}
|
||||
|
||||
func (dr *DNSRoute) MarshalJSON() ([]byte, error) {
|
||||
s := struct {
|
||||
Type string `json:"type"`
|
||||
UserHostname string `json:"user_hostname"`
|
||||
}{
|
||||
Type: dr.RecordType(),
|
||||
UserHostname: dr.userHostname,
|
||||
}
|
||||
return json.Marshal(&s)
|
||||
}
|
||||
|
||||
func (dr *DNSRoute) RecordType() string {
|
||||
return "dns"
|
||||
}
|
||||
|
||||
type LBRoute struct {
|
||||
lbName string
|
||||
lbPool string
|
||||
}
|
||||
|
||||
func NewLBRoute(lbName, lbPool string) Route {
|
||||
return &LBRoute{
|
||||
lbName: lbName,
|
||||
lbPool: lbPool,
|
||||
}
|
||||
}
|
||||
|
||||
func (lr *LBRoute) MarshalJSON() ([]byte, error) {
|
||||
s := struct {
|
||||
Type string `json:"type"`
|
||||
LBName string `json:"lb_name"`
|
||||
LBPool string `json:"lb_pool"`
|
||||
}{
|
||||
Type: lr.RecordType(),
|
||||
LBName: lr.lbName,
|
||||
LBPool: lr.lbPool,
|
||||
}
|
||||
return json.Marshal(&s)
|
||||
}
|
||||
|
||||
func (lr *LBRoute) RecordType() string {
|
||||
return "lb"
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
|
||||
GetTunnel(tunnelID string) (*Tunnel, error)
|
||||
DeleteTunnel(tunnelID string) error
|
||||
GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
|
||||
DeleteTunnel(tunnelID uuid.UUID) error
|
||||
ListTunnels() ([]Tunnel, error)
|
||||
CleanupConnections(tunnelID string) error
|
||||
CleanupConnections(tunnelID uuid.UUID) error
|
||||
RouteTunnel(tunnelID uuid.UUID, route Route) error
|
||||
}
|
||||
|
||||
type RESTClient struct {
|
||||
baseURL string
|
||||
authToken string
|
||||
client http.Client
|
||||
logger logger.Service
|
||||
baseEndpoints *baseEndpoints
|
||||
authToken string
|
||||
client http.Client
|
||||
logger logger.Service
|
||||
}
|
||||
|
||||
type baseEndpoints struct {
|
||||
accountLevel string
|
||||
zoneLevel string
|
||||
}
|
||||
|
||||
var _ Client = (*RESTClient)(nil)
|
||||
|
||||
func NewRESTClient(baseURL string, accountTag string, authToken string, logger logger.Service) *RESTClient {
|
||||
func NewRESTClient(baseURL string, accountTag, zoneTag string, authToken string, logger logger.Service) *RESTClient {
|
||||
if strings.HasSuffix(baseURL, "/") {
|
||||
baseURL = baseURL[:len(baseURL)-1]
|
||||
}
|
||||
url := fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag)
|
||||
return &RESTClient{
|
||||
baseURL: url,
|
||||
baseEndpoints: &baseEndpoints{
|
||||
accountLevel: fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag),
|
||||
zoneLevel: fmt.Sprintf("%s/zones/%s/tunnels", baseURL, accountTag),
|
||||
},
|
||||
authToken: authToken,
|
||||
client: http.Client{
|
||||
Transport: &http.Transport{
|
||||
|
@ -92,7 +160,7 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
|
|||
return nil, errors.Wrap(err, "Failed to serialize new tunnel request")
|
||||
}
|
||||
|
||||
resp, err := r.sendRequest("POST", "", bytes.NewBuffer(body))
|
||||
resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "REST request failed")
|
||||
}
|
||||
|
@ -108,8 +176,9 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
|
|||
return nil, statusCodeToError("create tunnel", resp)
|
||||
}
|
||||
|
||||
func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) {
|
||||
resp, err := r.sendRequest("GET", tunnelID, nil)
|
||||
func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) {
|
||||
endpoint := fmt.Sprintf("%s/%v", r.baseEndpoints.accountLevel, tunnelID)
|
||||
resp, err := r.sendRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "REST request failed")
|
||||
}
|
||||
|
@ -122,8 +191,9 @@ func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) {
|
|||
return nil, statusCodeToError("get tunnel", resp)
|
||||
}
|
||||
|
||||
func (r *RESTClient) DeleteTunnel(tunnelID string) error {
|
||||
resp, err := r.sendRequest("DELETE", tunnelID, nil)
|
||||
func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
|
||||
endpoint := fmt.Sprintf("%s/%v", r.baseEndpoints.accountLevel, tunnelID)
|
||||
resp, err := r.sendRequest("DELETE", endpoint, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "REST request failed")
|
||||
}
|
||||
|
@ -133,7 +203,7 @@ func (r *RESTClient) DeleteTunnel(tunnelID string) error {
|
|||
}
|
||||
|
||||
func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
|
||||
resp, err := r.sendRequest("GET", "", nil)
|
||||
resp, err := r.sendRequest("GET", r.baseEndpoints.accountLevel, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "REST request failed")
|
||||
}
|
||||
|
@ -150,8 +220,9 @@ func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
|
|||
return nil, statusCodeToError("list tunnels", resp)
|
||||
}
|
||||
|
||||
func (r *RESTClient) CleanupConnections(tunnelID string) error {
|
||||
resp, err := r.sendRequest("DELETE", fmt.Sprintf("%s/connections", tunnelID), nil)
|
||||
func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error {
|
||||
endpoint := fmt.Sprintf("%s/%v/connections", r.baseEndpoints.accountLevel, tunnelID)
|
||||
resp, err := r.sendRequest("DELETE", endpoint, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "REST request failed")
|
||||
}
|
||||
|
@ -160,15 +231,23 @@ func (r *RESTClient) CleanupConnections(tunnelID string) error {
|
|||
return statusCodeToError("cleanup connections", resp)
|
||||
}
|
||||
|
||||
func (r *RESTClient) resolve(target string) string {
|
||||
if target != "" {
|
||||
return r.baseURL + "/" + target
|
||||
func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) error {
|
||||
body, err := json.Marshal(route)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to serialize Route")
|
||||
}
|
||||
return r.baseURL
|
||||
|
||||
endpoint := fmt.Sprintf("%s/%v/routes", r.baseEndpoints.zoneLevel, tunnelID)
|
||||
resp, err := r.sendRequest("PUT", endpoint, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "REST request failed")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return statusCodeToError("add route", resp)
|
||||
}
|
||||
|
||||
func (r *RESTClient) sendRequest(method string, target string, body io.Reader) (*http.Response, error) {
|
||||
url := r.resolve(target)
|
||||
func (r *RESTClient) sendRequest(method string, url string, body io.Reader) (*http.Response, error) {
|
||||
r.logger.Debugf("%s %s", method, url)
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue