diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index 9b879f40..965104c9 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -75,19 +75,6 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { return } - // Spawn shell under user - cmd := exec.Command(sshUser.Shell) - - ptyReq, winCh, isPty := session.Pty() - if !isPty { - if _, err := io.WriteString(session, "No PTY requested.\n"); err != nil { - s.logger.WithError(err).Error("No PTY requested: Failed to write to SSH session") - } - - s.CloseSession(session) - return - } - uidInt, err := stringToUint32(sshUser.Uid) if err != nil { s.logger.WithError(err).Error("Invalid user") @@ -101,39 +88,52 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { return } + // Spawn shell under user + var cmd *exec.Cmd + if session.RawCommand() != "" { + cmd = exec.Command(sshUser.Shell, "-c", session.RawCommand()) + } else { + cmd = exec.Command(sshUser.Shell) + } // Supplementary groups are not explicitly specified. They seem to be inherited by default. - cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}} - cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) + cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}, Setsid: true} + cmd.Env = append(cmd.Env, session.Environ()...) cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username)) cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir)) cmd.Dir = sshUser.HomeDir - psuedoTTY, err := pty.Start(cmd) - if err != nil { - s.logger.WithError(err).Error("Failed to start pty session") - s.CloseSession(session) - close(s.shutdownC) - return + + ptyReq, winCh, isPty := session.Pty() + var shellInput io.WriteCloser + var shellOutput io.ReadCloser + + if isPty { + cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) + shellInput, shellOutput, err = s.startPtySession(cmd, winCh) + if err != nil { + s.logger.WithError(err).Error("Failed to start pty session") + close(s.shutdownC) + return + } + } else { + shellInput, shellOutput, err = s.startNonPtySession(cmd) + if err != nil { + s.logger.WithError(err).Error("Failed to start non-pty session") + close(s.shutdownC) + return + } } - // Handle terminal window size changes + // Write incoming commands to shell go func() { - for win := range winCh { - if errNo := setWinsize(psuedoTTY, win.Width, win.Height); errNo != 0 { - s.logger.WithError(err).Error("Failed to set pty window size: ", err.Error()) - s.CloseSession(session) - close(s.shutdownC) - return - } - } - }() - - // Write incoming commands to PTY - go func() { - if _, err := io.Copy(psuedoTTY, session); err != nil { + if _, err := io.Copy(shellInput, session); err != nil { s.logger.WithError(err).Error("Failed to write incoming command to pty") } }() + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + scanner := bufio.NewScanner(pr) go func() { for scanner.Scan() { @@ -143,18 +143,10 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { // Write outgoing command output to both the command recorder, and remote user mw := io.MultiWriter(pw, session) - if _, err := io.Copy(mw, psuedoTTY); err != nil { + if _, err := io.Copy(mw, shellOutput); err != nil { s.logger.WithError(err).Error("Failed to write command output to user") } - if err := pw.Close(); err != nil { - s.logger.WithError(err).Error("Failed to close pipe writer") - } - - if err := pr.Close(); err != nil { - s.logger.WithError(err).Error("Failed to close pipe reader") - } - // Wait for all resources associated with cmd to be released // Returns error if shell exited with a non-zero status or received a signal if err := cmd.Wait(); err != nil { @@ -168,6 +160,42 @@ func (s *SSHServer) CloseSession(session ssh.Session) { } } +func (s *SSHServer) startNonPtySession(cmd *exec.Cmd) (io.WriteCloser, io.ReadCloser, error) { + in, err := cmd.StdinPipe() + if err != nil { + return nil, nil, err + } + out, err := cmd.StdoutPipe() + if err != nil { + return nil, nil, err + } + cmd.Stderr = cmd.Stdout + if err = cmd.Start(); err != nil { + return nil, nil, err + } + return in, out, nil +} + +func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window) (io.WriteCloser, io.ReadCloser, error) { + tty, err := pty.Start(cmd) + if err != nil { + return nil, nil, err + } + + // Handle terminal window size changes + go func() { + for win := range winCh { + if errNo := setWinsize(tty, win.Width, win.Height); errNo != 0 { + s.logger.WithError(err).Error("Failed to set pty window size") + close(s.shutdownC) + return + } + } + }() + + return tty, tty, nil +} + // Sets PTY window size for terminal func setWinsize(f *os.File, w, h int) syscall.Errno { _, _, errNo := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),