From 348c48c02d0c01503bf7cf22a60bcf906bef60ce Mon Sep 17 00:00:00 2001 From: Travis Crumb Date: Fri, 13 Feb 2026 13:35:07 -0500 Subject: [PATCH] Add --pidfile support with signal-aware cleanup to cloudflared access tcp Add --pidfile flag to the access tcp subcommand (also applies to rdp, ssh, and smb aliases). This writes the process ID to a file after startup and removes it on exit. Signal handling (SIGTERM/SIGINT) ensures the PID file is cleaned up when the process is killed, not just on graceful shutdown. Includes unit tests for writePidFile and removePidFile. Closes #723 --- cmd/cloudflared/access/carrier.go | 53 ++++++++++++++++++++++++++ cmd/cloudflared/access/carrier_test.go | 53 ++++++++++++++++++++++++++ cmd/cloudflared/access/cmd.go | 6 +++ 3 files changed, 112 insertions(+) create mode 100644 cmd/cloudflared/access/carrier_test.go diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index 52aacc56..b3e2a768 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -5,8 +5,12 @@ import ( "fmt" "io" "net/http" + "os" + "os/signal" "strings" + "syscall" + "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/urfave/cli/v2" @@ -69,6 +73,26 @@ func ssh(c *cli.Context) error { } log := logger.CreateSSHLoggerFromContext(c, outputTerminal) + if c.IsSet(sshPidFileFlag) { + pidFile := c.String(sshPidFileFlag) + writePidFile(pidFile, log) + defer removePidFile(pidFile, log) + + // Trap SIGTERM/SIGINT to clean up the PID file before exiting. + // Without this, signals kill the process before defers can run. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + go func() { + <-sigCh + removePidFile(pidFile, log) + signal.Reset(syscall.SIGTERM, syscall.SIGINT) + // Re-raise so the process exits with the default signal behavior + if p, err := os.FindProcess(os.Getpid()); err == nil { + _ = p.Signal(syscall.SIGTERM) + } + }() + } + // get the hostname from the cmdline and error out if its not provided rawHostName := c.String(sshHostnameFlag) url, err := parseURL(rawHostName) @@ -145,3 +169,32 @@ func ssh(c *cli.Context) error { } return carrier.StartClient(wsConn, s, options) } + +// writePidFile writes the current process ID to a given file path. +// It expands ~ in paths using go-homedir. +func writePidFile(pidPathname string, log *zerolog.Logger) { + expandedPath, err := homedir.Expand(pidPathname) + if err != nil { + log.Err(err).Str("pidPath", pidPathname).Msg("Unable to expand the path, try to use absolute path in --pidfile") + return + } + file, err := os.Create(expandedPath) + if err != nil { + log.Err(err).Str("pidPath", expandedPath).Msg("Unable to write pid") + return + } + defer file.Close() + fmt.Fprintf(file, "%d", os.Getpid()) +} + +// removePidFile removes the PID file at the given path. +// Errors are logged but do not cause a failure. +func removePidFile(pidPathname string, log *zerolog.Logger) { + expandedPath, err := homedir.Expand(pidPathname) + if err != nil { + return + } + if err := os.Remove(expandedPath); err != nil && !os.IsNotExist(err) { + log.Err(err).Str("pidPath", expandedPath).Msg("Unable to remove pid file") + } +} diff --git a/cmd/cloudflared/access/carrier_test.go b/cmd/cloudflared/access/carrier_test.go new file mode 100644 index 00000000..f7176a62 --- /dev/null +++ b/cmd/cloudflared/access/carrier_test.go @@ -0,0 +1,53 @@ +package access + +import ( + "os" + "path/filepath" + "strconv" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWritePidFile(t *testing.T) { + log := zerolog.Nop() + + t.Run("writes current PID to file", func(t *testing.T) { + pidFile := filepath.Join(t.TempDir(), "test.pid") + + writePidFile(pidFile, &log) + + content, err := os.ReadFile(pidFile) + require.NoError(t, err) + + pid, err := strconv.Atoi(string(content)) + require.NoError(t, err) + assert.Equal(t, os.Getpid(), pid) + }) + + t.Run("handles invalid path gracefully", func(t *testing.T) { + // Should not panic on a path that can't be created + writePidFile("/nonexistent/directory/test.pid", &log) + }) +} + +func TestRemovePidFile(t *testing.T) { + log := zerolog.Nop() + + t.Run("removes existing pid file", func(t *testing.T) { + pidFile := filepath.Join(t.TempDir(), "test.pid") + + writePidFile(pidFile, &log) + assert.FileExists(t, pidFile) + + removePidFile(pidFile, &log) + assert.NoFileExists(t, pidFile) + }) + + t.Run("handles missing file gracefully", func(t *testing.T) { + // Should not panic when removing a file that doesn't exist + removePidFile("/tmp/nonexistent-cloudflared-test.pid", &log) + }) +} diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index 636b9288..7ad87c03 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -38,6 +38,7 @@ const ( sshGenCertFlag = "short-lived-cert" sshConnectTo = "connect-to" sshDebugStream = "debug-stream" + sshPidFileFlag = "pidfile" sshConfigTemplate = ` Add to your {{.Home}}/.ssh/config: @@ -204,6 +205,11 @@ func Commands() []*cli.Command { Hidden: true, Usage: "Writes up-to the max provided stream payloads to the logger as debug statements.", }, + &cli.StringFlag{ + Name: sshPidFileFlag, + Usage: "Write the application's PID to this file after startup.", + EnvVars: []string{"TUNNEL_ACCESS_PIDFILE"}, + }, }, }, {