Add --pidfile support with signal-aware cleanup to cloudflared access tcp
Add --pidfile flag to the access tcp subcommand (also applies to rdp, ssh, and smb aliases). This writes the process ID to a file after startup and removes it on exit. Signal handling (SIGTERM/SIGINT) ensures the PID file is cleaned up when the process is killed, not just on graceful shutdown. Includes unit tests for writePidFile and removePidFile. Closes #723
This commit is contained in:
parent
a0bcbf6a44
commit
348c48c02d
|
|
@ -5,8 +5,12 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/mitchellh/go-homedir"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
@ -69,6 +73,26 @@ func ssh(c *cli.Context) error {
|
||||||
}
|
}
|
||||||
log := logger.CreateSSHLoggerFromContext(c, outputTerminal)
|
log := logger.CreateSSHLoggerFromContext(c, outputTerminal)
|
||||||
|
|
||||||
|
if c.IsSet(sshPidFileFlag) {
|
||||||
|
pidFile := c.String(sshPidFileFlag)
|
||||||
|
writePidFile(pidFile, log)
|
||||||
|
defer removePidFile(pidFile, log)
|
||||||
|
|
||||||
|
// Trap SIGTERM/SIGINT to clean up the PID file before exiting.
|
||||||
|
// Without this, signals kill the process before defers can run.
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
go func() {
|
||||||
|
<-sigCh
|
||||||
|
removePidFile(pidFile, log)
|
||||||
|
signal.Reset(syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
// Re-raise so the process exits with the default signal behavior
|
||||||
|
if p, err := os.FindProcess(os.Getpid()); err == nil {
|
||||||
|
_ = p.Signal(syscall.SIGTERM)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// get the hostname from the cmdline and error out if its not provided
|
// get the hostname from the cmdline and error out if its not provided
|
||||||
rawHostName := c.String(sshHostnameFlag)
|
rawHostName := c.String(sshHostnameFlag)
|
||||||
url, err := parseURL(rawHostName)
|
url, err := parseURL(rawHostName)
|
||||||
|
|
@ -145,3 +169,32 @@ func ssh(c *cli.Context) error {
|
||||||
}
|
}
|
||||||
return carrier.StartClient(wsConn, s, options)
|
return carrier.StartClient(wsConn, s, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writePidFile writes the current process ID to a given file path.
|
||||||
|
// It expands ~ in paths using go-homedir.
|
||||||
|
func writePidFile(pidPathname string, log *zerolog.Logger) {
|
||||||
|
expandedPath, err := homedir.Expand(pidPathname)
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).Str("pidPath", pidPathname).Msg("Unable to expand the path, try to use absolute path in --pidfile")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, err := os.Create(expandedPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).Str("pidPath", expandedPath).Msg("Unable to write pid")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
fmt.Fprintf(file, "%d", os.Getpid())
|
||||||
|
}
|
||||||
|
|
||||||
|
// removePidFile removes the PID file at the given path.
|
||||||
|
// Errors are logged but do not cause a failure.
|
||||||
|
func removePidFile(pidPathname string, log *zerolog.Logger) {
|
||||||
|
expandedPath, err := homedir.Expand(pidPathname)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := os.Remove(expandedPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
log.Err(err).Str("pidPath", expandedPath).Msg("Unable to remove pid file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
package access
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWritePidFile(t *testing.T) {
|
||||||
|
log := zerolog.Nop()
|
||||||
|
|
||||||
|
t.Run("writes current PID to file", func(t *testing.T) {
|
||||||
|
pidFile := filepath.Join(t.TempDir(), "test.pid")
|
||||||
|
|
||||||
|
writePidFile(pidFile, &log)
|
||||||
|
|
||||||
|
content, err := os.ReadFile(pidFile)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pid, err := strconv.Atoi(string(content))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, os.Getpid(), pid)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles invalid path gracefully", func(t *testing.T) {
|
||||||
|
// Should not panic on a path that can't be created
|
||||||
|
writePidFile("/nonexistent/directory/test.pid", &log)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemovePidFile(t *testing.T) {
|
||||||
|
log := zerolog.Nop()
|
||||||
|
|
||||||
|
t.Run("removes existing pid file", func(t *testing.T) {
|
||||||
|
pidFile := filepath.Join(t.TempDir(), "test.pid")
|
||||||
|
|
||||||
|
writePidFile(pidFile, &log)
|
||||||
|
assert.FileExists(t, pidFile)
|
||||||
|
|
||||||
|
removePidFile(pidFile, &log)
|
||||||
|
assert.NoFileExists(t, pidFile)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles missing file gracefully", func(t *testing.T) {
|
||||||
|
// Should not panic when removing a file that doesn't exist
|
||||||
|
removePidFile("/tmp/nonexistent-cloudflared-test.pid", &log)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -38,6 +38,7 @@ const (
|
||||||
sshGenCertFlag = "short-lived-cert"
|
sshGenCertFlag = "short-lived-cert"
|
||||||
sshConnectTo = "connect-to"
|
sshConnectTo = "connect-to"
|
||||||
sshDebugStream = "debug-stream"
|
sshDebugStream = "debug-stream"
|
||||||
|
sshPidFileFlag = "pidfile"
|
||||||
sshConfigTemplate = `
|
sshConfigTemplate = `
|
||||||
Add to your {{.Home}}/.ssh/config:
|
Add to your {{.Home}}/.ssh/config:
|
||||||
|
|
||||||
|
|
@ -204,6 +205,11 @@ func Commands() []*cli.Command {
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
Usage: "Writes up-to the max provided stream payloads to the logger as debug statements.",
|
Usage: "Writes up-to the max provided stream payloads to the logger as debug statements.",
|
||||||
},
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: sshPidFileFlag,
|
||||||
|
Usage: "Write the application's PID to this file after startup.",
|
||||||
|
EnvVars: []string{"TUNNEL_ACCESS_PIDFILE"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue