Add --pidfile support with signal-aware cleanup to cloudflared access tcp
Add --pidfile flag to the access tcp subcommand (also applies to rdp, ssh, and smb aliases). This writes the process ID to a file after startup and removes it on exit. Signal handling (SIGTERM/SIGINT) ensures the PID file is cleaned up when the process is killed, not just on graceful shutdown. Includes unit tests for writePidFile and removePidFile. Closes #723
This commit is contained in:
parent
a0bcbf6a44
commit
348c48c02d
|
|
@ -5,8 +5,12 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
|
@ -69,6 +73,26 @@ func ssh(c *cli.Context) error {
|
|||
}
|
||||
log := logger.CreateSSHLoggerFromContext(c, outputTerminal)
|
||||
|
||||
if c.IsSet(sshPidFileFlag) {
|
||||
pidFile := c.String(sshPidFileFlag)
|
||||
writePidFile(pidFile, log)
|
||||
defer removePidFile(pidFile, log)
|
||||
|
||||
// Trap SIGTERM/SIGINT to clean up the PID file before exiting.
|
||||
// Without this, signals kill the process before defers can run.
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||
go func() {
|
||||
<-sigCh
|
||||
removePidFile(pidFile, log)
|
||||
signal.Reset(syscall.SIGTERM, syscall.SIGINT)
|
||||
// Re-raise so the process exits with the default signal behavior
|
||||
if p, err := os.FindProcess(os.Getpid()); err == nil {
|
||||
_ = p.Signal(syscall.SIGTERM)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// get the hostname from the cmdline and error out if its not provided
|
||||
rawHostName := c.String(sshHostnameFlag)
|
||||
url, err := parseURL(rawHostName)
|
||||
|
|
@ -145,3 +169,32 @@ func ssh(c *cli.Context) error {
|
|||
}
|
||||
return carrier.StartClient(wsConn, s, options)
|
||||
}
|
||||
|
||||
// writePidFile writes the current process ID to a given file path.
|
||||
// It expands ~ in paths using go-homedir.
|
||||
func writePidFile(pidPathname string, log *zerolog.Logger) {
|
||||
expandedPath, err := homedir.Expand(pidPathname)
|
||||
if err != nil {
|
||||
log.Err(err).Str("pidPath", pidPathname).Msg("Unable to expand the path, try to use absolute path in --pidfile")
|
||||
return
|
||||
}
|
||||
file, err := os.Create(expandedPath)
|
||||
if err != nil {
|
||||
log.Err(err).Str("pidPath", expandedPath).Msg("Unable to write pid")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
fmt.Fprintf(file, "%d", os.Getpid())
|
||||
}
|
||||
|
||||
// removePidFile removes the PID file at the given path.
|
||||
// Errors are logged but do not cause a failure.
|
||||
func removePidFile(pidPathname string, log *zerolog.Logger) {
|
||||
expandedPath, err := homedir.Expand(pidPathname)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := os.Remove(expandedPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Err(err).Str("pidPath", expandedPath).Msg("Unable to remove pid file")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
package access
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWritePidFile(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
|
||||
t.Run("writes current PID to file", func(t *testing.T) {
|
||||
pidFile := filepath.Join(t.TempDir(), "test.pid")
|
||||
|
||||
writePidFile(pidFile, &log)
|
||||
|
||||
content, err := os.ReadFile(pidFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
pid, err := strconv.Atoi(string(content))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, os.Getpid(), pid)
|
||||
})
|
||||
|
||||
t.Run("handles invalid path gracefully", func(t *testing.T) {
|
||||
// Should not panic on a path that can't be created
|
||||
writePidFile("/nonexistent/directory/test.pid", &log)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemovePidFile(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
|
||||
t.Run("removes existing pid file", func(t *testing.T) {
|
||||
pidFile := filepath.Join(t.TempDir(), "test.pid")
|
||||
|
||||
writePidFile(pidFile, &log)
|
||||
assert.FileExists(t, pidFile)
|
||||
|
||||
removePidFile(pidFile, &log)
|
||||
assert.NoFileExists(t, pidFile)
|
||||
})
|
||||
|
||||
t.Run("handles missing file gracefully", func(t *testing.T) {
|
||||
// Should not panic when removing a file that doesn't exist
|
||||
removePidFile("/tmp/nonexistent-cloudflared-test.pid", &log)
|
||||
})
|
||||
}
|
||||
|
|
@ -38,6 +38,7 @@ const (
|
|||
sshGenCertFlag = "short-lived-cert"
|
||||
sshConnectTo = "connect-to"
|
||||
sshDebugStream = "debug-stream"
|
||||
sshPidFileFlag = "pidfile"
|
||||
sshConfigTemplate = `
|
||||
Add to your {{.Home}}/.ssh/config:
|
||||
|
||||
|
|
@ -204,6 +205,11 @@ func Commands() []*cli.Command {
|
|||
Hidden: true,
|
||||
Usage: "Writes up-to the max provided stream payloads to the logger as debug statements.",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: sshPidFileFlag,
|
||||
Usage: "Write the application's PID to this file after startup.",
|
||||
EnvVars: []string{"TUNNEL_ACCESS_PIDFILE"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue