TUN-3008: Implement cloudflared tunnel cleanup command

This commit is contained in:
cthuang 2020-07-03 16:55:11 +08:00
parent 87e06100df
commit f5c8ff77e9
3 changed files with 73 additions and 33 deletions

View File

@ -173,6 +173,7 @@ func Commands() []*cli.Command {
subcommands = append(subcommands, buildListCommand()) subcommands = append(subcommands, buildListCommand())
subcommands = append(subcommands, buildDeleteCommand()) subcommands = append(subcommands, buildDeleteCommand())
subcommands = append(subcommands, buildRunCommand()) subcommands = append(subcommands, buildRunCommand())
subcommands = append(subcommands, buildCleanupCommand())
cmds = append(cmds, &cli.Command{ cmds = append(cmds, &cli.Command{
Name: "tunnel", Name: "tunnel",

View File

@ -92,11 +92,7 @@ func createTunnel(c *cli.Context) error {
return err return err
} }
originCertPath, err := findOriginCert(c, logger) cert, originCertPath, err := getOriginCertFromContext(c, logger)
if err != nil {
return errors.Wrap(err, "Error locating origin cert")
}
cert, err := getOriginCertFromContext(originCertPath, logger)
if err != nil { if err != nil {
return err return err
} }
@ -223,11 +219,7 @@ func listTunnels(c *cli.Context) error {
return errors.Wrap(err, "error setting up logger") return errors.Wrap(err, "error setting up logger")
} }
originCertPath, err := findOriginCert(c, logger) cert, _, err := getOriginCertFromContext(c, logger)
if err != nil {
return errors.Wrap(err, "Error locating origin cert")
}
cert, err := getOriginCertFromContext(originCertPath, logger)
if err != nil { if err != nil {
return err return err
} }
@ -310,11 +302,7 @@ func deleteTunnel(c *cli.Context) error {
return errors.Wrap(err, "error setting up logger") return errors.Wrap(err, "error setting up logger")
} }
originCertPath, err := findOriginCert(c, logger) cert, _, err := getOriginCertFromContext(c, logger)
if err != nil {
return errors.Wrap(err, "Error locating origin cert")
}
cert, err := getOriginCertFromContext(originCertPath, logger)
if err != nil { if err != nil {
return err return err
} }
@ -355,22 +343,25 @@ func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logg
return client return client
} }
func getOriginCertFromContext(originCertPath string, logger logger.Service) (*certutil.OriginCert, error) { func getOriginCertFromContext(c *cli.Context, logger logger.Service) (cert *certutil.OriginCert, originCertPath string, err error) {
originCertPath, err = findOriginCert(c, logger)
if err != nil {
return nil, "", errors.Wrap(err, "Error locating origin cert")
}
blocks, err := readOriginCert(originCertPath, logger) blocks, err := readOriginCert(originCertPath, logger)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) return nil, "", errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
} }
cert, err := certutil.DecodeOriginCert(blocks) cert, err = certutil.DecodeOriginCert(blocks)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Error decoding origin cert") return nil, "", errors.Wrap(err, "Error decoding origin cert")
} }
if cert.AccountID == "" { if cert.AccountID == "" {
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) return nil, "", errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
} }
return cert, nil return cert, originCertPath, nil
} }
func buildRunCommand() *cli.Command { func buildRunCommand() *cli.Command {
@ -406,3 +397,40 @@ func runTunnel(c *cli.Context) error {
logger.Debugf("Read credentials for %v", credentials.AccountTag) logger.Debugf("Read credentials for %v", credentials.AccountTag)
return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID}) return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID})
} }
func buildCleanupCommand() *cli.Command {
return &cli.Command{
Name: "cleanup",
Action: cliutil.ErrorHandler(cleanupConnections),
Usage: "Cleanup connections for the tunnel with given IDs",
ArgsUsage: "TUNNEL-IDS",
Hidden: hideSubcommands,
}
}
func cleanupConnections(c *cli.Context) error {
if c.NArg() < 1 {
return cliutil.UsageError(`"cloudflared tunnel cleanup" requires at least 1 argument, the IDs of the tunnels to cleanup connections.`)
}
logger, err := logger.New()
if err != nil {
return errors.Wrap(err, "error setting up logger")
}
cert, _, err := getOriginCertFromContext(c, logger)
if err != nil {
return err
}
client := newTunnelstoreClient(c, cert, logger)
for i := 0; i < c.NArg(); i++ {
id := c.Args().Get(i)
logger.Infof("Cleanup connection for tunnel %s", id)
if err := client.CleanupConnections(id); err != nil {
logger.Errorf("Error cleaning up connections for tunnel %s, error :%v", id, err)
}
}
return nil
}

View File

@ -41,9 +41,10 @@ type Connection struct {
type Client interface { type Client interface {
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
GetTunnel(id string) (*Tunnel, error) GetTunnel(tunnelID string) (*Tunnel, error)
DeleteTunnel(id string) error DeleteTunnel(tunnelID string) error
ListTunnels() ([]Tunnel, error) ListTunnels() ([]Tunnel, error)
CleanupConnections(tunnelID string) error
} }
type RESTClient struct { type RESTClient struct {
@ -104,11 +105,11 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
return nil, ErrTunnelNameConflict return nil, ErrTunnelNameConflict
} }
return nil, statusCodeToError("create", resp) return nil, statusCodeToError("create tunnel", resp)
} }
func (r *RESTClient) GetTunnel(id string) (*Tunnel, error) { func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) {
resp, err := r.sendRequest("GET", id, nil) resp, err := r.sendRequest("GET", tunnelID, nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "REST request failed") return nil, errors.Wrap(err, "REST request failed")
} }
@ -118,17 +119,17 @@ func (r *RESTClient) GetTunnel(id string) (*Tunnel, error) {
return unmarshalTunnel(resp.Body) return unmarshalTunnel(resp.Body)
} }
return nil, statusCodeToError("read", resp) return nil, statusCodeToError("get tunnel", resp)
} }
func (r *RESTClient) DeleteTunnel(id string) error { func (r *RESTClient) DeleteTunnel(tunnelID string) error {
resp, err := r.sendRequest("DELETE", id, nil) resp, err := r.sendRequest("DELETE", tunnelID, nil)
if err != nil { if err != nil {
return errors.Wrap(err, "REST request failed") return errors.Wrap(err, "REST request failed")
} }
defer resp.Body.Close() defer resp.Body.Close()
return statusCodeToError("delete", resp) return statusCodeToError("delete tunnel", resp)
} }
func (r *RESTClient) ListTunnels() ([]Tunnel, error) { func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
@ -146,7 +147,17 @@ func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
return tunnels, nil return tunnels, nil
} }
return nil, statusCodeToError("list", resp) return nil, statusCodeToError("list tunnels", resp)
}
func (r *RESTClient) CleanupConnections(tunnelID string) error {
resp, err := r.sendRequest("DELETE", fmt.Sprintf("%s/connections", tunnelID), nil)
if err != nil {
return errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
return statusCodeToError("cleanup connections", resp)
} }
func (r *RESTClient) resolve(target string) string { func (r *RESTClient) resolve(target string) string {
@ -189,6 +200,6 @@ func statusCodeToError(op string, resp *http.Response) error {
case http.StatusNotFound: case http.StatusNotFound:
return ErrNotFound return ErrNotFound
} }
return errors.Errorf("API call to %s tunnel failed with status %d: %s", op, return errors.Errorf("API call to %s failed with status %d: %s", op,
resp.StatusCode, http.StatusText(resp.StatusCode)) resp.StatusCode, http.StatusText(resp.StatusCode))
} }