diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index 157ed814..e0573d01 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -224,7 +224,7 @@ func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error { return fmt.Errorf("Tunnel %s has already been deleted", tunnel.ID) } if forceFlagSet { - if err := client.CleanupConnections(tunnel.ID); err != nil { + if err := client.CleanupConnections(tunnel.ID, tunnelstore.NewCleanupParams()); err != nil { return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnel.ID) } } @@ -276,13 +276,24 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { } func (sc *subcommandContext) cleanupConnections(tunnelIDs []uuid.UUID) error { + params := tunnelstore.NewCleanupParams() + extraLog := "" + if connector := sc.c.String("connector-id"); connector != "" { + connectorID, err := uuid.Parse(connector) + if err != nil { + return errors.Wrapf(err, "%s is not a valid client ID (must be a UUID)", connector) + } + params.ForClient(connectorID) + extraLog = fmt.Sprintf(" for connector-id %s", connectorID.String()) + } + client, err := sc.client() if err != nil { return err } for _, tunnelID := range tunnelIDs { - sc.log.Info().Msgf("Cleanup connection for tunnel %s", tunnelID) - if err := client.CleanupConnections(tunnelID); err != nil { + sc.log.Info().Msgf("Cleanup connection for tunnel %s%s", tunnelID, extraLog) + if err := client.CleanupConnections(tunnelID, params); err != nil { sc.log.Error().Msgf("Error cleaning up connections for tunnel %v, error :%v", tunnelID, err) } } diff --git a/cmd/cloudflared/tunnel/subcommand_context_test.go b/cmd/cloudflared/tunnel/subcommand_context_test.go index d2744366..e2c129bf 100644 --- a/cmd/cloudflared/tunnel/subcommand_context_test.go +++ b/cmd/cloudflared/tunnel/subcommand_context_test.go @@ -4,11 +4,12 @@ import ( "encoding/base64" "flag" "fmt" - "github.com/rs/zerolog" "reflect" "testing" "time" + "github.com/rs/zerolog" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/tunnelstore" "github.com/google/uuid" @@ -260,7 +261,7 @@ func (d *deleteMockTunnelStore) DeleteTunnel(tunnelID uuid.UUID) error { return nil } -func (d *deleteMockTunnelStore) CleanupConnections(tunnelID uuid.UUID) error { +func (d *deleteMockTunnelStore) CleanupConnections(tunnelID uuid.UUID, _ *tunnelstore.CleanupParams) error { tunnel, ok := d.mockTunnels[tunnelID] if !ok { return fmt.Errorf("Couldn't find tunnel: %v", tunnelID) diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 9c51cf9d..73b5dc32 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -126,6 +126,12 @@ var ( Usage: "Inverts the sort order of the tunnel info.", EnvVars: []string{"TUNNEL_INFO_INVERT_SORT"}, } + cleanupClientFlag = &cli.StringFlag{ + Name: "connector-id", + Aliases: []string{"c"}, + Usage: `Constraints the cleanup to stop the connections of a single Connector (by its ID). You can find the various Connectors (and their IDs) currently connected to your tunnel via 'cloudflared tunnel info '.`, + EnvVars: []string{"TUNNEL_CLEANUP_CONNECTOR"}, + } ) func buildCreateCommand() *cli.Command { @@ -600,6 +606,7 @@ func buildCleanupCommand() *cli.Command { Usage: "Cleanup tunnel connections", UsageText: "cloudflared tunnel [tunnel command options] cleanup [subcommand options] TUNNEL", Description: "Delete connections for tunnels with the given UUIDs or names.", + Flags: []cli.Flag{cleanupClientFlag}, CustomHelpTemplate: commandHelpTemplate(), } } diff --git a/tunnelstore/cleanup_params.go b/tunnelstore/cleanup_params.go new file mode 100644 index 00000000..4b24ba35 --- /dev/null +++ b/tunnelstore/cleanup_params.go @@ -0,0 +1,25 @@ +package tunnelstore + +import ( + "net/url" + + "github.com/google/uuid" +) + +type CleanupParams struct { + queryParams url.Values +} + +func NewCleanupParams() *CleanupParams { + return &CleanupParams{ + queryParams: url.Values{}, + } +} + +func (cp *CleanupParams) ForClient(clientID uuid.UUID) { + cp.queryParams.Set("client_id", clientID.String()) +} + +func (cp CleanupParams) encode() string { + return cp.queryParams.Encode() +} diff --git a/tunnelstore/client.go b/tunnelstore/client.go index dfed022c..db8330fe 100644 --- a/tunnelstore/client.go +++ b/tunnelstore/client.go @@ -204,7 +204,7 @@ type Client interface { DeleteTunnel(tunnelID uuid.UUID) error ListTunnels(filter *Filter) ([]*Tunnel, error) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) - CleanupConnections(tunnelID uuid.UUID) error + CleanupConnections(tunnelID uuid.UUID, params *CleanupParams) error RouteTunnel(tunnelID uuid.UUID, route Route) (RouteResult, error) // Teamnet endpoints @@ -370,8 +370,9 @@ func parseConnectionsDetails(reader io.Reader) ([]*ActiveClient, error) { return clients, err } -func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error { +func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID, params *CleanupParams) error { endpoint := r.baseEndpoints.accountLevel + endpoint.RawQuery = params.encode() endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/connections", tunnelID)) resp, err := r.sendRequest("DELETE", endpoint, nil) if err != nil {