AUTH-2056: Writes stderr to its own stream for non-pty connections

This commit is contained in:
Michael Borkenstein 2019-09-10 18:50:04 -05:00
parent 40d9370bb6
commit ff795a7beb
4 changed files with 72 additions and 41 deletions

2
.gitignore vendored
View File

@ -9,5 +9,7 @@ guide/public
\#*\#
cscope.*
cloudflared
cloudflared.exe
!cmd/cloudflared/
.DS_Store
*-session.log

View File

@ -366,8 +366,9 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
if c.IsSet("ssh-server") {
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
logger.Errorf("--ssh-server is not supported on %s", runtime.GOOS)
return errors.New(fmt.Sprintf("--ssh-server is not supported on %s", runtime.GOOS))
msg := fmt.Sprintf("--ssh-server is not supported on %s", runtime.GOOS)
logger.Error(msg)
return errors.New(msg)
}
@ -378,11 +379,17 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
uploader, err := awsuploader.NewFileUploader(c.String(bucketNameFlag), c.String(regionNameFlag),
c.String(accessKeyIDFlag), c.String(secretIDFlag), c.String(sessionTokenIDFlag), c.String(s3URLFlag))
if err != nil {
logger.WithError(err).Error("Cannot create uploader for SSH Server")
return errors.Wrap(err, "Cannot create uploader for SSH Server")
msg := "Cannot create uploader for SSH Server"
logger.WithError(err).Error(msg)
return errors.Wrap(err, msg)
}
if err := os.Mkdir(sshLogFileDirectory, 0600); err != nil {
msg := fmt.Sprintf("Cannot create SSH log file directory %s", sshLogFileDirectory)
logger.WithError(err).Errorf(msg)
return errors.Wrap(err, msg)
}
os.Mkdir(sshLogFileDirectory, 0600)
logManager = sshlog.New(sshLogFileDirectory)
uploadManager := awsuploader.NewDirectoryUploadManager(logger, uploader, sshLogFileDirectory, 30*time.Minute, shutdownC)
@ -392,8 +399,9 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
server, err := sshserver.New(logManager, logger, version, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag), c.Bool(enablePortForwardingFlag))
if err != nil {
logger.WithError(err).Error("Cannot create new SSH Server")
return errors.Wrap(err, "Cannot create new SSH Server")
msg := "Cannot create new SSH Server"
logger.WithError(err).Error(msg)
return errors.Wrap(err, msg)
}
wg.Add(1)
go func() {

View File

@ -29,7 +29,7 @@ func (m *emptyManager) NewSessionLogger(name string, logger *logrus.Logger) (io.
// emptyWriteCloser
func (w *emptyWriteCloser) Write(p []byte) (n int, err error) {
return 0, nil
return len(p), nil
}
func (w *emptyWriteCloser) Close() error {

View File

@ -33,7 +33,6 @@ const (
auditEventExec = "exec"
auditEventScp = "scp"
auditEventResize = "resize"
auditEventTamper = "tamper"
)
type auditEvent struct {
@ -156,14 +155,18 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
var shellInput io.WriteCloser
var shellOutput io.ReadCloser
pr, pw := io.Pipe()
defer pw.Close()
ptyReq, winCh, isPty := session.Pty()
if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
shellInput, shellOutput, err = s.startPtySession(cmd, winCh, func() {
tty, err := s.startPtySession(cmd, winCh, func() {
s.logAuditEvent(eventLogger, session, sessionID, auditEventResize)
})
shellInput = tty
shellOutput = tty
if err != nil {
s.logger.WithError(err).Error("Failed to start pty session")
close(s.shutdownC)
@ -172,7 +175,8 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
s.logAuditEvent(eventLogger, session, sessionID, auditEventStart)
defer s.logAuditEvent(eventLogger, session, sessionID, auditEventStop)
} else {
shellInput, shellOutput, err = s.startNonPtySession(cmd)
var shellError io.ReadCloser
shellInput, shellOutput, shellError, err = s.startNonPtySession(cmd)
if err != nil {
s.logger.WithError(err).Error("Failed to start non-pty session")
close(s.shutdownC)
@ -183,18 +187,15 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
event = auditEventScp
}
s.logAuditEvent(eventLogger, session, sessionID, event)
}
// Write incoming commands to shell
// Write stderr to both the command recorder, and remote user
go func() {
if _, err := io.Copy(shellInput, session); err != nil {
s.logger.WithError(err).Error("Failed to write incoming command to pty")
mw := io.MultiWriter(pw, session.Stderr())
if _, err := io.Copy(mw, shellError); err != nil {
s.logger.WithError(err).Error("Failed to write stderr to user")
}
}()
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
}
sessionLogger, err := s.logManager.NewSessionLogger(fmt.Sprintf("%s-session.log", sessionID), s.logger)
if err != nil {
@ -206,13 +207,25 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
}
defer sessionLogger.Close()
go func() {
io.Copy(sessionLogger, pr)
defer pr.Close()
_, err := io.Copy(sessionLogger, pr)
if err != nil {
s.logger.WithError(err).Error("Failed to write session log")
}
}()
// Write outgoing command output to both the command recorder, and remote user
// Write stdin to shell
go func() {
defer shellInput.Close()
if _, err := io.Copy(shellInput, session); err != nil {
s.logger.WithError(err).Error("Failed to write incoming command to pty")
}
}()
// Write stdout to both the command recorder, and remote user
mw := io.MultiWriter(pw, session)
if _, err := io.Copy(mw, shellOutput); err != nil {
s.logger.WithError(err).Error("Failed to write command output to user")
s.logger.WithError(err).Error("Failed to write stdout to user")
}
// Wait for all resources associated with cmd to be released
@ -264,8 +277,8 @@ func (s *SSHServer) getSSHUser(session ssh.Session, sessionID string, eventLogge
// errorAndExit reports an error with the session and exits
func (s *SSHServer) errorAndExit(session ssh.Session, errText string, err error) {
if err := session.Exit(1); err != nil {
s.logger.WithError(err).Error("Failed to close SSH session")
if exitError := session.Exit(1); exitError != nil {
s.logger.WithError(exitError).Error("Failed to close SSH session")
} else if err != nil {
s.logger.WithError(err).Error(errText)
} else if errText != "" {
@ -273,26 +286,31 @@ func (s *SSHServer) errorAndExit(session ssh.Session, errText string, err error)
}
}
func (s *SSHServer) startNonPtySession(cmd *exec.Cmd) (io.WriteCloser, io.ReadCloser, error) {
in, err := cmd.StdinPipe()
func (s *SSHServer) startNonPtySession(cmd *exec.Cmd) (stdin io.WriteCloser, stdout io.ReadCloser, stderr io.ReadCloser, err error) {
stdin, err = cmd.StdinPipe()
if err != nil {
return nil, nil, err
return
}
out, err := cmd.StdoutPipe()
stdout, err = cmd.StdoutPipe()
if err != nil {
return nil, nil, err
return
}
cmd.Stderr = cmd.Stdout
stderr, err = cmd.StderrPipe()
if err != nil {
return
}
if err = cmd.Start(); err != nil {
return nil, nil, err
return
}
return in, out, nil
return
}
func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logCallback func()) (io.WriteCloser, io.ReadCloser, error) {
func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logCallback func()) (io.ReadWriteCloser, error) {
tty, err := pty.Start(cmd)
if err != nil {
return nil, nil, err
return nil, err
}
// Handle terminal window size changes
@ -307,7 +325,7 @@ func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logC
}
}()
return tty, tty, nil
return tty, nil
}
func (s *SSHServer) logAuditEvent(writer io.WriteCloser, session ssh.Session, sessionID string, eventType string) {
@ -328,11 +346,14 @@ func (s *SSHServer) logAuditEvent(writer io.WriteCloser, session ssh.Session, se
}
data, err := json.Marshal(&event)
if err != nil {
s.logger.WithError(err).Error("Failed to log audit event. malformed audit object")
s.logger.WithError(err).Error("Failed to marshal audit event. malformed audit object")
return
}
line := string(data) + "\n"
writer.Write([]byte(line))
if _, err := writer.Write([]byte(line)); err != nil {
s.logger.WithError(err).Error("Failed to write audit event.")
}
}
// Sets PTY window size for terminal