diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index ead0dedf..f328ba84 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -3,6 +3,7 @@ package access import ( "crypto/tls" "fmt" + "io" "net/http" "strings" @@ -13,6 +14,7 @@ import ( "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/stream" "github.com/cloudflare/cloudflared/validation" ) @@ -38,6 +40,7 @@ func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, log *z if forwarder.TokenSecret != "" { headers.Set(cfAccessClientSecretHeader, forwarder.TokenSecret) } + headers.Set("User-Agent", userAgent) carrier.SetBastionDest(headers, forwarder.Destination) @@ -58,7 +61,12 @@ func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, log *z // useful for proxying other protocols (like ssh) over websockets // (which you can put Access in front of) func ssh(c *cli.Context) error { - log := logger.CreateSSHLoggerFromContext(c, logger.EnableTerminalLog) + // If not running as a forwarder, disable terminal logs as it collides with the stdin/stdout of the parent process + outputTerminal := logger.DisableTerminalLog + if c.IsSet(sshURLFlag) { + outputTerminal = logger.EnableTerminalLog + } + log := logger.CreateSSHLoggerFromContext(c, outputTerminal) // get the hostname from the cmdline and error out if its not provided rawHostName := c.String(sshHostnameFlag) @@ -76,6 +84,7 @@ func ssh(c *cli.Context) error { if c.IsSet(sshTokenSecretFlag) { headers.Set(cfAccessClientSecretHeader, c.String(sshTokenSecretFlag)) } + headers.Set("User-Agent", userAgent) carrier.SetBastionDest(headers, c.String(sshDestinationFlag)) @@ -121,7 +130,19 @@ func ssh(c *cli.Context) error { return err } - return carrier.StartClient(wsConn, &carrier.StdinoutStream{}, options) + var s io.ReadWriter + s = &carrier.StdinoutStream{} + if c.IsSet(sshDebugStream) { + maxMessages := c.Uint64(sshDebugStream) + if maxMessages == 0 { + // default to 10 if provided but unset + maxMessages = 10 + } + logger := log.With().Str("host", hostname).Logger() + s = stream.NewDebugStream(s, &logger, maxMessages) + } + carrier.StartClient(wsConn, s, options) + return nil } func buildRequestHeaders(values []string) http.Header { diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index 07b6dc2f..9d81e375 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -34,6 +34,7 @@ const ( sshTokenSecretFlag = "service-token-secret" sshGenCertFlag = "short-lived-cert" sshConnectTo = "connect-to" + sshDebugStream = "debug-stream" sshConfigTemplate = ` Add to your {{.Home}}/.ssh/config: @@ -151,9 +152,12 @@ func Commands() []*cli.Command { EnvVars: []string{"TUNNEL_SERVICE_TOKEN_SECRET"}, }, &cli.StringFlag{ - Name: logger.LogSSHDirectoryFlag, - Aliases: []string{"logfile"}, //added to match the tunnel side - Usage: "Save application log to this directory for reporting issues.", + Name: logger.LogFileFlag, + Usage: "Save application log to this file for reporting issues.", + }, + &cli.StringFlag{ + Name: logger.LogSSHDirectoryFlag, + Usage: "Save application log to this directory for reporting issues.", }, &cli.StringFlag{ Name: logger.LogSSHLevelFlag, @@ -165,6 +169,11 @@ func Commands() []*cli.Command { Hidden: true, Usage: "Connect to alternate location for testing, value is host, host:port, or sni:port:host", }, + &cli.Uint64Flag{ + Name: sshDebugStream, + Hidden: true, + Usage: "Writes up-to the max provided stream payloads to the logger as debug statements.", + }, }, }, { diff --git a/logger/create.go b/logger/create.go index d7a987e1..83048e59 100644 --- a/logger/create.go +++ b/logger/create.go @@ -175,7 +175,7 @@ func createFromContext( log := newZerolog(loggerConfig) if incompatibleFlagsSet := logFile != "" && logDirectory != ""; incompatibleFlagsSet { - log.Error().Msgf("Your config includes values for both %s and %s, but they are incompatible. %s takes precedence.", LogFileFlag, logDirectoryFlagName, LogFileFlag) + log.Error().Msgf("Your config includes values for both %s (%s) and %s (%s), but they are incompatible. %s takes precedence.", LogFileFlag, logFile, logDirectoryFlagName, logDirectory, LogFileFlag) } return log } diff --git a/stream/debug.go b/stream/debug.go new file mode 100644 index 00000000..0c175e4f --- /dev/null +++ b/stream/debug.go @@ -0,0 +1,64 @@ +package stream + +import ( + "io" + "sync/atomic" + + "github.com/rs/zerolog" +) + +// DebugStream will tee each read and write to the output logger as a debug message +type DebugStream struct { + reader io.Reader + writer io.Writer + log *zerolog.Logger + max uint64 + count atomic.Uint64 +} + +func NewDebugStream(stream io.ReadWriter, logger *zerolog.Logger, max uint64) *DebugStream { + return &DebugStream{ + reader: stream, + writer: stream, + log: logger, + max: max, + } +} + +func (d *DebugStream) Read(p []byte) (n int, err error) { + n, err = d.reader.Read(p) + if n > 0 && d.max > d.count.Load() { + d.count.Add(1) + if err != nil { + d.log.Err(err). + Str("dir", "r"). + Int("count", n). + Msgf("%+q", p[:n]) + } else { + d.log.Debug(). + Str("dir", "r"). + Int("count", n). + Msgf("%+q", p[:n]) + } + } + return +} + +func (d *DebugStream) Write(p []byte) (n int, err error) { + n, err = d.writer.Write(p) + if n > 0 && d.max > d.count.Load() { + d.count.Add(1) + if err != nil { + d.log.Err(err). + Str("dir", "w"). + Int("count", n). + Msgf("%+q", p[:n]) + } else { + d.log.Debug(). + Str("dir", "w"). + Int("count", n). + Msgf("%+q", p[:n]) + } + } + return +}