diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index 52aacc56..97206960 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -5,8 +5,10 @@ import ( "fmt" "io" "net/http" + "os" "strings" + "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/urfave/cli/v2" @@ -69,6 +71,10 @@ func ssh(c *cli.Context) error { } log := logger.CreateSSHLoggerFromContext(c, outputTerminal) + if c.IsSet(sshPidFileFlag) { + writePidFile(c.String(sshPidFileFlag), log) + } + // get the hostname from the cmdline and error out if its not provided rawHostName := c.String(sshHostnameFlag) url, err := parseURL(rawHostName) @@ -145,3 +151,18 @@ func ssh(c *cli.Context) error { } return carrier.StartClient(wsConn, s, options) } + +func writePidFile(path string, log *zerolog.Logger) { + expandedPath, err := homedir.Expand(path) + if err != nil { + log.Err(err).Str("path", path).Msg("Unable to expand pidfile path") + return + } + file, err := os.Create(expandedPath) + if err != nil { + log.Err(err).Str("path", expandedPath).Msg("Unable to write pidfile") + return + } + defer file.Close() + fmt.Fprintf(file, "%d", os.Getpid()) +} diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index 636b9288..07c66dbc 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -38,6 +38,7 @@ const ( sshGenCertFlag = "short-lived-cert" sshConnectTo = "connect-to" sshDebugStream = "debug-stream" + sshPidFileFlag = "pidfile" sshConfigTemplate = ` Add to your {{.Home}}/.ssh/config: @@ -204,6 +205,11 @@ func Commands() []*cli.Command { Hidden: true, Usage: "Writes up-to the max provided stream payloads to the logger as debug statements.", }, + &cli.StringFlag{ + Name: sshPidFileFlag, + Usage: "Write the application's PID to this file", + EnvVars: []string{"TUNNEL_SERVICE_PIDFILE"}, + }, }, }, {