TUN-3008: Implement cloudflared tunnel cleanup command
This commit is contained in:
parent
87e06100df
commit
f5c8ff77e9
|
@ -173,6 +173,7 @@ func Commands() []*cli.Command {
|
||||||
subcommands = append(subcommands, buildListCommand())
|
subcommands = append(subcommands, buildListCommand())
|
||||||
subcommands = append(subcommands, buildDeleteCommand())
|
subcommands = append(subcommands, buildDeleteCommand())
|
||||||
subcommands = append(subcommands, buildRunCommand())
|
subcommands = append(subcommands, buildRunCommand())
|
||||||
|
subcommands = append(subcommands, buildCleanupCommand())
|
||||||
|
|
||||||
cmds = append(cmds, &cli.Command{
|
cmds = append(cmds, &cli.Command{
|
||||||
Name: "tunnel",
|
Name: "tunnel",
|
||||||
|
|
|
@ -92,11 +92,7 @@ func createTunnel(c *cli.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
originCertPath, err := findOriginCert(c, logger)
|
cert, originCertPath, err := getOriginCertFromContext(c, logger)
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Error locating origin cert")
|
|
||||||
}
|
|
||||||
cert, err := getOriginCertFromContext(originCertPath, logger)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -223,11 +219,7 @@ func listTunnels(c *cli.Context) error {
|
||||||
return errors.Wrap(err, "error setting up logger")
|
return errors.Wrap(err, "error setting up logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
originCertPath, err := findOriginCert(c, logger)
|
cert, _, err := getOriginCertFromContext(c, logger)
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Error locating origin cert")
|
|
||||||
}
|
|
||||||
cert, err := getOriginCertFromContext(originCertPath, logger)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -310,11 +302,7 @@ func deleteTunnel(c *cli.Context) error {
|
||||||
return errors.Wrap(err, "error setting up logger")
|
return errors.Wrap(err, "error setting up logger")
|
||||||
}
|
}
|
||||||
|
|
||||||
originCertPath, err := findOriginCert(c, logger)
|
cert, _, err := getOriginCertFromContext(c, logger)
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Error locating origin cert")
|
|
||||||
}
|
|
||||||
cert, err := getOriginCertFromContext(originCertPath, logger)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -355,22 +343,25 @@ func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logg
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOriginCertFromContext(originCertPath string, logger logger.Service) (*certutil.OriginCert, error) {
|
func getOriginCertFromContext(c *cli.Context, logger logger.Service) (cert *certutil.OriginCert, originCertPath string, err error) {
|
||||||
|
originCertPath, err = findOriginCert(c, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", errors.Wrap(err, "Error locating origin cert")
|
||||||
|
}
|
||||||
blocks, err := readOriginCert(originCertPath, logger)
|
blocks, err := readOriginCert(originCertPath, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
|
return nil, "", errors.Wrapf(err, "Can't read origin cert from %s", originCertPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := certutil.DecodeOriginCert(blocks)
|
cert, err = certutil.DecodeOriginCert(blocks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error decoding origin cert")
|
return nil, "", errors.Wrap(err, "Error decoding origin cert")
|
||||||
}
|
}
|
||||||
|
|
||||||
if cert.AccountID == "" {
|
if cert.AccountID == "" {
|
||||||
return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
|
return nil, "", errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath)
|
||||||
}
|
}
|
||||||
return cert, nil
|
return cert, originCertPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildRunCommand() *cli.Command {
|
func buildRunCommand() *cli.Command {
|
||||||
|
@ -406,3 +397,40 @@ func runTunnel(c *cli.Context) error {
|
||||||
logger.Debugf("Read credentials for %v", credentials.AccountTag)
|
logger.Debugf("Read credentials for %v", credentials.AccountTag)
|
||||||
return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID})
|
return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildCleanupCommand() *cli.Command {
|
||||||
|
return &cli.Command{
|
||||||
|
Name: "cleanup",
|
||||||
|
Action: cliutil.ErrorHandler(cleanupConnections),
|
||||||
|
Usage: "Cleanup connections for the tunnel with given IDs",
|
||||||
|
ArgsUsage: "TUNNEL-IDS",
|
||||||
|
Hidden: hideSubcommands,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupConnections(c *cli.Context) error {
|
||||||
|
if c.NArg() < 1 {
|
||||||
|
return cliutil.UsageError(`"cloudflared tunnel cleanup" requires at least 1 argument, the IDs of the tunnels to cleanup connections.`)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger, err := logger.New()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error setting up logger")
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, _, err := getOriginCertFromContext(c, logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
client := newTunnelstoreClient(c, cert, logger)
|
||||||
|
|
||||||
|
for i := 0; i < c.NArg(); i++ {
|
||||||
|
id := c.Args().Get(i)
|
||||||
|
logger.Infof("Cleanup connection for tunnel %s", id)
|
||||||
|
if err := client.CleanupConnections(id); err != nil {
|
||||||
|
logger.Errorf("Error cleaning up connections for tunnel %s, error :%v", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -41,9 +41,10 @@ type Connection struct {
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
|
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
|
||||||
GetTunnel(id string) (*Tunnel, error)
|
GetTunnel(tunnelID string) (*Tunnel, error)
|
||||||
DeleteTunnel(id string) error
|
DeleteTunnel(tunnelID string) error
|
||||||
ListTunnels() ([]Tunnel, error)
|
ListTunnels() ([]Tunnel, error)
|
||||||
|
CleanupConnections(tunnelID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type RESTClient struct {
|
type RESTClient struct {
|
||||||
|
@ -104,11 +105,11 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
|
||||||
return nil, ErrTunnelNameConflict
|
return nil, ErrTunnelNameConflict
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, statusCodeToError("create", resp)
|
return nil, statusCodeToError("create tunnel", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RESTClient) GetTunnel(id string) (*Tunnel, error) {
|
func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) {
|
||||||
resp, err := r.sendRequest("GET", id, nil)
|
resp, err := r.sendRequest("GET", tunnelID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "REST request failed")
|
return nil, errors.Wrap(err, "REST request failed")
|
||||||
}
|
}
|
||||||
|
@ -118,17 +119,17 @@ func (r *RESTClient) GetTunnel(id string) (*Tunnel, error) {
|
||||||
return unmarshalTunnel(resp.Body)
|
return unmarshalTunnel(resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, statusCodeToError("read", resp)
|
return nil, statusCodeToError("get tunnel", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RESTClient) DeleteTunnel(id string) error {
|
func (r *RESTClient) DeleteTunnel(tunnelID string) error {
|
||||||
resp, err := r.sendRequest("DELETE", id, nil)
|
resp, err := r.sendRequest("DELETE", tunnelID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "REST request failed")
|
return errors.Wrap(err, "REST request failed")
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
return statusCodeToError("delete", resp)
|
return statusCodeToError("delete tunnel", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
|
func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
|
||||||
|
@ -146,7 +147,17 @@ func (r *RESTClient) ListTunnels() ([]Tunnel, error) {
|
||||||
return tunnels, nil
|
return tunnels, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, statusCodeToError("list", resp)
|
return nil, statusCodeToError("list tunnels", resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RESTClient) CleanupConnections(tunnelID string) error {
|
||||||
|
resp, err := r.sendRequest("DELETE", fmt.Sprintf("%s/connections", tunnelID), nil)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "REST request failed")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
return statusCodeToError("cleanup connections", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RESTClient) resolve(target string) string {
|
func (r *RESTClient) resolve(target string) string {
|
||||||
|
@ -189,6 +200,6 @@ func statusCodeToError(op string, resp *http.Response) error {
|
||||||
case http.StatusNotFound:
|
case http.StatusNotFound:
|
||||||
return ErrNotFound
|
return ErrNotFound
|
||||||
}
|
}
|
||||||
return errors.Errorf("API call to %s tunnel failed with status %d: %s", op,
|
return errors.Errorf("API call to %s failed with status %d: %s", op,
|
||||||
resp.StatusCode, http.StatusText(resp.StatusCode))
|
resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue