From 979e5be8ab4756ca54e653d239f195b38cef1375 Mon Sep 17 00:00:00 2001 From: Michael Borkenstein Date: Wed, 18 Sep 2019 14:11:12 -0500 Subject: [PATCH] AUTH-2067: Log commands correctly --- cmd/cloudflared/tunnel/cmd.go | 2 +- sshlog/manager.go | 2 +- sshserver/sshserver_unix.go | 79 +++++++++++++++++++++-------------- 3 files changed, 50 insertions(+), 33 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index a9830c54..7d79ad5e 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -386,7 +386,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan return errors.Wrap(err, msg) } - if err := os.Mkdir(sshLogFileDirectory, 0600); err != nil { + if err := os.MkdirAll(sshLogFileDirectory, 0600); err != nil { msg := fmt.Sprintf("Cannot create SSH log file directory %s", sshLogFileDirectory) logger.WithError(err).Errorf(msg) return errors.Wrap(err, msg) diff --git a/sshlog/manager.go b/sshlog/manager.go index aa90fb1d..3178eb03 100644 --- a/sshlog/manager.go +++ b/sshlog/manager.go @@ -30,5 +30,5 @@ func (m *manager) NewLogger(name string, logger *logrus.Logger) (io.WriteCloser, } func (m *manager) NewSessionLogger(name string, logger *logrus.Logger) (io.WriteCloser, error) { - return NewSessionLogger(name, logger, time.Second, defaultFileSizeLimit) + return NewSessionLogger(filepath.Join(m.baseDirectory, name), logger, time.Second, defaultFileSizeLimit) } diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index c6b6cd4a..90eb8c19 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -33,6 +33,8 @@ const ( auditEventExec = "exec" auditEventScp = "scp" auditEventResize = "resize" + sshContextSessionID = "sessionID" + sshContextEventLogger = "eventLogger" ) type auditEvent struct { @@ -144,14 +146,42 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { return } + sshContext, ok := session.Context().(ssh.Context) + if !ok { + s.logger.Error("Could not retrieve session context") + s.errorAndExit(session, "", nil) + } + + sshContext.SetValue(sshContextSessionID, sessionID) + sshContext.SetValue(sshContextEventLogger, eventLogger) + // Get uid and gid of user attempting to login - sshUser, uidInt, gidInt, success := s.getSSHUser(session, sessionID, eventLogger) + sshUser, uidInt, gidInt, success := s.getSSHUser(session, eventLogger) if !success { return } // Spawn shell under user - cmd := s.spawnCmd(session, sshUser, uidInt, gidInt) + var cmd *exec.Cmd + if session.RawCommand() != "" { + cmd = exec.Command(sshUser.Shell, "-c", session.RawCommand()) + + event := auditEventExec + if strings.HasPrefix(session.RawCommand(), "scp") { + event = auditEventScp + } + s.logAuditEvent(session, event) + } else { + cmd = exec.Command(sshUser.Shell) + s.logAuditEvent(session, auditEventStart) + defer s.logAuditEvent(session, auditEventStop) + } + // Supplementary groups are not explicitly specified. They seem to be inherited by default. + cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}, Setsid: true} + cmd.Env = append(cmd.Env, session.Environ()...) + cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username)) + cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir)) + cmd.Dir = sshUser.HomeDir var shellInput io.WriteCloser var shellOutput io.ReadCloser @@ -163,7 +193,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { if isPty { cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) tty, err := s.startPtySession(cmd, winCh, func() { - s.logAuditEvent(eventLogger, session, sessionID, auditEventResize) + s.logAuditEvent(session, auditEventResize) }) shellInput = tty shellOutput = tty @@ -172,8 +202,6 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { close(s.shutdownC) return } - s.logAuditEvent(eventLogger, session, sessionID, auditEventStart) - defer s.logAuditEvent(eventLogger, session, sessionID, auditEventStop) } else { var shellError io.ReadCloser shellInput, shellOutput, shellError, err = s.startNonPtySession(cmd) @@ -182,11 +210,6 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { close(s.shutdownC) return } - event := auditEventExec - if strings.HasPrefix(session.RawCommand(), "scp") { - event = auditEventScp - } - s.logAuditEvent(eventLogger, session, sessionID, event) // Write stderr to both the command recorder, and remote user go func() { @@ -205,8 +228,8 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { s.errorAndExit(session, "", nil) return } - defer sessionLogger.Close() go func() { + defer sessionLogger.Close() defer pr.Close() _, err := io.Copy(sessionLogger, pr) if err != nil { @@ -244,32 +267,15 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { } } -// spawnCmd spawns a shell under the user -func (s *SSHServer) spawnCmd(session ssh.Session, sshUser *User, uidInt, gidInt uint32) *exec.Cmd { - var cmd *exec.Cmd - if session.RawCommand() != "" { - cmd = exec.Command(sshUser.Shell, "-c", session.RawCommand()) - } else { - cmd = exec.Command(sshUser.Shell) - } - // Supplementary groups are not explicitly specified. They seem to be inherited by default. - cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}, Setsid: true} - cmd.Env = append(cmd.Env, session.Environ()...) - cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username)) - cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir)) - cmd.Dir = sshUser.HomeDir - return cmd -} - // getSSHUser gets the ssh user, uid, and gid of the user attempting to login -func (s *SSHServer) getSSHUser(session ssh.Session, sessionID string, eventLogger io.WriteCloser) (*User, uint32, uint32, bool) { +func (s *SSHServer) getSSHUser(session ssh.Session, eventLogger io.WriteCloser) (*User, uint32, uint32, bool) { // Get uid and gid of user attempting to login sshUser, ok := session.Context().Value("sshUser").(*User) if !ok || sshUser == nil { s.errorAndExit(session, "Error retrieving credentials from session", nil) return nil, 0, 0, false } - s.logAuditEvent(eventLogger, session, sessionID, auditEventAuth) + s.logAuditEvent(session, auditEventAuth) uidInt, err := stringToUint32(sshUser.Uid) if err != nil { @@ -337,13 +343,24 @@ func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logC return tty, nil } -func (s *SSHServer) logAuditEvent(writer io.WriteCloser, session ssh.Session, sessionID string, eventType string) { +func (s *SSHServer) logAuditEvent(session ssh.Session, eventType string) { username := "unknown" sshUser, ok := session.Context().Value("sshUser").(*User) if ok && sshUser != nil { username = sshUser.Username } + sessionID, ok := session.Context().Value(sshContextSessionID).(string) + if !ok { + s.logger.Error("Failed to retrieve sessionID from context") + return + } + writer, ok := session.Context().Value(sshContextEventLogger).(io.WriteCloser) + if !ok { + s.logger.Error("Failed to retrieve eventLogger from context") + return + } + event := auditEvent{ Event: session.RawCommand(), EventType: eventType,