AUTH-2067: Log commands correctly

This commit is contained in:
Michael Borkenstein 2019-09-18 14:11:12 -05:00
parent 2789d0cf36
commit 979e5be8ab
3 changed files with 50 additions and 33 deletions

View File

@ -386,7 +386,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
return errors.Wrap(err, msg)
}
if err := os.Mkdir(sshLogFileDirectory, 0600); err != nil {
if err := os.MkdirAll(sshLogFileDirectory, 0600); err != nil {
msg := fmt.Sprintf("Cannot create SSH log file directory %s", sshLogFileDirectory)
logger.WithError(err).Errorf(msg)
return errors.Wrap(err, msg)

View File

@ -30,5 +30,5 @@ func (m *manager) NewLogger(name string, logger *logrus.Logger) (io.WriteCloser,
}
func (m *manager) NewSessionLogger(name string, logger *logrus.Logger) (io.WriteCloser, error) {
return NewSessionLogger(name, logger, time.Second, defaultFileSizeLimit)
return NewSessionLogger(filepath.Join(m.baseDirectory, name), logger, time.Second, defaultFileSizeLimit)
}

View File

@ -33,6 +33,8 @@ const (
auditEventExec = "exec"
auditEventScp = "scp"
auditEventResize = "resize"
sshContextSessionID = "sessionID"
sshContextEventLogger = "eventLogger"
)
type auditEvent struct {
@ -144,14 +146,42 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
return
}
sshContext, ok := session.Context().(ssh.Context)
if !ok {
s.logger.Error("Could not retrieve session context")
s.errorAndExit(session, "", nil)
}
sshContext.SetValue(sshContextSessionID, sessionID)
sshContext.SetValue(sshContextEventLogger, eventLogger)
// Get uid and gid of user attempting to login
sshUser, uidInt, gidInt, success := s.getSSHUser(session, sessionID, eventLogger)
sshUser, uidInt, gidInt, success := s.getSSHUser(session, eventLogger)
if !success {
return
}
// Spawn shell under user
cmd := s.spawnCmd(session, sshUser, uidInt, gidInt)
var cmd *exec.Cmd
if session.RawCommand() != "" {
cmd = exec.Command(sshUser.Shell, "-c", session.RawCommand())
event := auditEventExec
if strings.HasPrefix(session.RawCommand(), "scp") {
event = auditEventScp
}
s.logAuditEvent(session, event)
} else {
cmd = exec.Command(sshUser.Shell)
s.logAuditEvent(session, auditEventStart)
defer s.logAuditEvent(session, auditEventStop)
}
// Supplementary groups are not explicitly specified. They seem to be inherited by default.
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}, Setsid: true}
cmd.Env = append(cmd.Env, session.Environ()...)
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir))
cmd.Dir = sshUser.HomeDir
var shellInput io.WriteCloser
var shellOutput io.ReadCloser
@ -163,7 +193,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
if isPty {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
tty, err := s.startPtySession(cmd, winCh, func() {
s.logAuditEvent(eventLogger, session, sessionID, auditEventResize)
s.logAuditEvent(session, auditEventResize)
})
shellInput = tty
shellOutput = tty
@ -172,8 +202,6 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
close(s.shutdownC)
return
}
s.logAuditEvent(eventLogger, session, sessionID, auditEventStart)
defer s.logAuditEvent(eventLogger, session, sessionID, auditEventStop)
} else {
var shellError io.ReadCloser
shellInput, shellOutput, shellError, err = s.startNonPtySession(cmd)
@ -182,11 +210,6 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
close(s.shutdownC)
return
}
event := auditEventExec
if strings.HasPrefix(session.RawCommand(), "scp") {
event = auditEventScp
}
s.logAuditEvent(eventLogger, session, sessionID, event)
// Write stderr to both the command recorder, and remote user
go func() {
@ -205,8 +228,8 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
s.errorAndExit(session, "", nil)
return
}
defer sessionLogger.Close()
go func() {
defer sessionLogger.Close()
defer pr.Close()
_, err := io.Copy(sessionLogger, pr)
if err != nil {
@ -244,32 +267,15 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
}
}
// spawnCmd spawns a shell under the user
func (s *SSHServer) spawnCmd(session ssh.Session, sshUser *User, uidInt, gidInt uint32) *exec.Cmd {
var cmd *exec.Cmd
if session.RawCommand() != "" {
cmd = exec.Command(sshUser.Shell, "-c", session.RawCommand())
} else {
cmd = exec.Command(sshUser.Shell)
}
// Supplementary groups are not explicitly specified. They seem to be inherited by default.
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}, Setsid: true}
cmd.Env = append(cmd.Env, session.Environ()...)
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir))
cmd.Dir = sshUser.HomeDir
return cmd
}
// getSSHUser gets the ssh user, uid, and gid of the user attempting to login
func (s *SSHServer) getSSHUser(session ssh.Session, sessionID string, eventLogger io.WriteCloser) (*User, uint32, uint32, bool) {
func (s *SSHServer) getSSHUser(session ssh.Session, eventLogger io.WriteCloser) (*User, uint32, uint32, bool) {
// Get uid and gid of user attempting to login
sshUser, ok := session.Context().Value("sshUser").(*User)
if !ok || sshUser == nil {
s.errorAndExit(session, "Error retrieving credentials from session", nil)
return nil, 0, 0, false
}
s.logAuditEvent(eventLogger, session, sessionID, auditEventAuth)
s.logAuditEvent(session, auditEventAuth)
uidInt, err := stringToUint32(sshUser.Uid)
if err != nil {
@ -337,13 +343,24 @@ func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logC
return tty, nil
}
func (s *SSHServer) logAuditEvent(writer io.WriteCloser, session ssh.Session, sessionID string, eventType string) {
func (s *SSHServer) logAuditEvent(session ssh.Session, eventType string) {
username := "unknown"
sshUser, ok := session.Context().Value("sshUser").(*User)
if ok && sshUser != nil {
username = sshUser.Username
}
sessionID, ok := session.Context().Value(sshContextSessionID).(string)
if !ok {
s.logger.Error("Failed to retrieve sessionID from context")
return
}
writer, ok := session.Context().Value(sshContextEventLogger).(io.WriteCloser)
if !ok {
s.logger.Error("Failed to retrieve eventLogger from context")
return
}
event := auditEvent{
Event: session.RawCommand(),
EventType: eventType,