diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 347e243d..b05b0c80 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -173,6 +173,7 @@ func Commands() []*cli.Command { subcommands = append(subcommands, buildListCommand()) subcommands = append(subcommands, buildDeleteCommand()) subcommands = append(subcommands, buildRunCommand()) + subcommands = append(subcommands, buildCleanupCommand()) cmds = append(cmds, &cli.Command{ Name: "tunnel", diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 1a85f177..51bd54d9 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -92,11 +92,7 @@ func createTunnel(c *cli.Context) error { return err } - originCertPath, err := findOriginCert(c, logger) - if err != nil { - return errors.Wrap(err, "Error locating origin cert") - } - cert, err := getOriginCertFromContext(originCertPath, logger) + cert, originCertPath, err := getOriginCertFromContext(c, logger) if err != nil { return err } @@ -223,11 +219,7 @@ func listTunnels(c *cli.Context) error { return errors.Wrap(err, "error setting up logger") } - originCertPath, err := findOriginCert(c, logger) - if err != nil { - return errors.Wrap(err, "Error locating origin cert") - } - cert, err := getOriginCertFromContext(originCertPath, logger) + cert, _, err := getOriginCertFromContext(c, logger) if err != nil { return err } @@ -310,11 +302,7 @@ func deleteTunnel(c *cli.Context) error { return errors.Wrap(err, "error setting up logger") } - originCertPath, err := findOriginCert(c, logger) - if err != nil { - return errors.Wrap(err, "Error locating origin cert") - } - cert, err := getOriginCertFromContext(originCertPath, logger) + cert, _, err := getOriginCertFromContext(c, logger) if err != nil { return err } @@ -355,22 +343,25 @@ func newTunnelstoreClient(c *cli.Context, cert *certutil.OriginCert, logger logg return client } -func getOriginCertFromContext(originCertPath string, logger logger.Service) (*certutil.OriginCert, error) { - +func getOriginCertFromContext(c *cli.Context, logger logger.Service) (cert *certutil.OriginCert, originCertPath string, err error) { + originCertPath, err = findOriginCert(c, logger) + if err != nil { + return nil, "", errors.Wrap(err, "Error locating origin cert") + } blocks, err := readOriginCert(originCertPath, logger) if err != nil { - return nil, errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) + return nil, "", errors.Wrapf(err, "Can't read origin cert from %s", originCertPath) } - cert, err := certutil.DecodeOriginCert(blocks) + cert, err = certutil.DecodeOriginCert(blocks) if err != nil { - return nil, errors.Wrap(err, "Error decoding origin cert") + return nil, "", errors.Wrap(err, "Error decoding origin cert") } if cert.AccountID == "" { - return nil, errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) + return nil, "", errors.Errorf(`Origin certificate needs to be refreshed before creating new tunnels.\nDelete %s and run "cloudflared login" to obtain a new cert.`, originCertPath) } - return cert, nil + return cert, originCertPath, nil } func buildRunCommand() *cli.Command { @@ -406,3 +397,40 @@ func runTunnel(c *cli.Context) error { logger.Debugf("Read credentials for %v", credentials.AccountTag) return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID}) } + +func buildCleanupCommand() *cli.Command { + return &cli.Command{ + Name: "cleanup", + Action: cliutil.ErrorHandler(cleanupConnections), + Usage: "Cleanup connections for the tunnel with given IDs", + ArgsUsage: "TUNNEL-IDS", + Hidden: hideSubcommands, + } +} + +func cleanupConnections(c *cli.Context) error { + if c.NArg() < 1 { + return cliutil.UsageError(`"cloudflared tunnel cleanup" requires at least 1 argument, the IDs of the tunnels to cleanup connections.`) + } + + logger, err := logger.New() + if err != nil { + return errors.Wrap(err, "error setting up logger") + } + + cert, _, err := getOriginCertFromContext(c, logger) + if err != nil { + return err + } + client := newTunnelstoreClient(c, cert, logger) + + for i := 0; i < c.NArg(); i++ { + id := c.Args().Get(i) + logger.Infof("Cleanup connection for tunnel %s", id) + if err := client.CleanupConnections(id); err != nil { + logger.Errorf("Error cleaning up connections for tunnel %s, error :%v", id, err) + } + } + + return nil +} diff --git a/tunnelstore/client.go b/tunnelstore/client.go index 3f772852..56077699 100644 --- a/tunnelstore/client.go +++ b/tunnelstore/client.go @@ -41,9 +41,10 @@ type Connection struct { type Client interface { CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) - GetTunnel(id string) (*Tunnel, error) - DeleteTunnel(id string) error + GetTunnel(tunnelID string) (*Tunnel, error) + DeleteTunnel(tunnelID string) error ListTunnels() ([]Tunnel, error) + CleanupConnections(tunnelID string) error } type RESTClient struct { @@ -104,11 +105,11 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er return nil, ErrTunnelNameConflict } - return nil, statusCodeToError("create", resp) + return nil, statusCodeToError("create tunnel", resp) } -func (r *RESTClient) GetTunnel(id string) (*Tunnel, error) { - resp, err := r.sendRequest("GET", id, nil) +func (r *RESTClient) GetTunnel(tunnelID string) (*Tunnel, error) { + resp, err := r.sendRequest("GET", tunnelID, nil) if err != nil { return nil, errors.Wrap(err, "REST request failed") } @@ -118,17 +119,17 @@ func (r *RESTClient) GetTunnel(id string) (*Tunnel, error) { return unmarshalTunnel(resp.Body) } - return nil, statusCodeToError("read", resp) + return nil, statusCodeToError("get tunnel", resp) } -func (r *RESTClient) DeleteTunnel(id string) error { - resp, err := r.sendRequest("DELETE", id, nil) +func (r *RESTClient) DeleteTunnel(tunnelID string) error { + resp, err := r.sendRequest("DELETE", tunnelID, nil) if err != nil { return errors.Wrap(err, "REST request failed") } defer resp.Body.Close() - return statusCodeToError("delete", resp) + return statusCodeToError("delete tunnel", resp) } func (r *RESTClient) ListTunnels() ([]Tunnel, error) { @@ -146,7 +147,17 @@ func (r *RESTClient) ListTunnels() ([]Tunnel, error) { return tunnels, nil } - return nil, statusCodeToError("list", resp) + return nil, statusCodeToError("list tunnels", resp) +} + +func (r *RESTClient) CleanupConnections(tunnelID string) error { + resp, err := r.sendRequest("DELETE", fmt.Sprintf("%s/connections", tunnelID), nil) + if err != nil { + return errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + return statusCodeToError("cleanup connections", resp) } func (r *RESTClient) resolve(target string) string { @@ -189,6 +200,6 @@ func statusCodeToError(op string, resp *http.Response) error { case http.StatusNotFound: return ErrNotFound } - return errors.Errorf("API call to %s tunnel failed with status %d: %s", op, + return errors.Errorf("API call to %s failed with status %d: %s", op, resp.StatusCode, http.StatusText(resp.StatusCode)) }