cloudflared-mirror/sshserver/sshserver_unix.go

345 lines
9.9 KiB
Go

//+build !windows
package sshserver
import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"unsafe"
"github.com/cloudflare/cloudflared/sshlog"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
// Audit event types. logAuditEvent stores the given type in the
// "event_type" field of each audit record it writes.
const (
// auditEventAuth is logged once the session's user has been retrieved.
auditEventAuth = "auth"
// auditEventStart is logged after a pty session has been started.
auditEventStart = "session_start"
// auditEventStop is logged (via defer) when a pty session's handler returns.
auditEventStop = "session_stop"
// auditEventExec is logged for non-pty sessions whose command does not start with "scp".
auditEventExec = "exec"
// auditEventScp is logged for non-pty sessions whose raw command starts with "scp".
auditEventScp = "scp"
// auditEventResize is logged after each successful pty window size change.
auditEventResize = "resize"
// auditEventTamper is not referenced in this part of the file.
auditEventTamper = "tamper"
)
// auditEvent is a single audit record; logAuditEvent serializes it as one
// line of JSON. Empty fields are omitted from the output.
type auditEvent struct {
// Event is the raw command of the session (empty for interactive shells).
Event string `json:"event,omitempty"`
// EventType is one of the auditEvent* constants.
EventType string `json:"event_type,omitempty"`
// SessionID is the identifier generated for the session by connectionHandler.
SessionID string `json:"session_id,omitempty"`
// User and Login are both set to the session user's name, or "unknown".
User string `json:"user,omitempty"`
Login string `json:"login,omitempty"`
// Datetime is the event time in UTC, formatted as RFC 3339.
Datetime string `json:"datetime,omitempty"`
// IPAddress is the string form of the session's remote address.
IPAddress string `json:"ip_address,omitempty"`
}
// SSHServer extends the ssh.Server of the gliderlabs package with logging,
// shutdown signalling and per-session audit/session log creation.
type SSHServer struct {
ssh.Server
// logger receives the server's operational log messages.
logger *logrus.Logger
// shutdownC signals shutdown: Start closes the server once a receive on
// this channel completes. Session code in this file also closes it on
// certain failures.
shutdownC chan struct{}
// caCert is not used in this part of the file.
// NOTE(review): presumably set while configuring authentication — confirm.
caCert ssh.PublicKey
// logManager creates the per-session event and session loggers.
logManager sshlog.Manager
}
// New creates a new SSHServer for the given address and configures its host
// keys and authentication.
//
// It returns an error if the current user cannot be determined, if the
// process is not running as root (uid "0"), or if configuring the host keys
// fails. The port forwarding callbacks are only installed when
// enablePortForwarding is true.
func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration, enablePortForwarding bool) (*SSHServer, error) {
	current, err := user.Current()
	if err != nil {
		return nil, err
	}
	if current.Uid != "0" {
		return nil, errors.New("cloudflared SSH server needs to run as root")
	}

	// SSH global request handlers that respond to tcpip forwarding.
	forwarder := &ssh.ForwardedTCPHandler{}
	requestHandlers := map[string]ssh.RequestHandler{
		"tcpip-forward":        forwarder.HandleSSHRequest,
		"cancel-tcpip-forward": forwarder.HandleSSHRequest,
	}

	// SSH channel types served by this server.
	channelHandlers := map[string]ssh.ChannelHandler{
		"session":      ssh.DefaultSessionHandler,
		"direct-tcpip": ssh.DirectTCPIPHandler,
	}

	srv := &SSHServer{
		Server: ssh.Server{
			Addr:            address,
			MaxTimeout:      maxTimeout,
			IdleTimeout:     idleTimeout,
			Version:         fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS),
			RequestHandlers: requestHandlers,
			ChannelHandlers: channelHandlers,
		},
		logger:     logger,
		shutdownC:  shutdownC,
		logManager: logManager,
	}

	if enablePortForwarding {
		srv.LocalPortForwardingCallback = allowForward
		srv.ReversePortForwardingCallback = allowForward
	}

	if err := srv.configureHostKeys(); err != nil {
		return nil, err
	}
	srv.configureAuthentication()

	return srv, nil
}
// Start installs the session handler and starts the SSH server listener,
// returning the result of ListenAndServe. A background goroutine closes the
// server once a receive on shutdownC completes.
func (s *SSHServer) Start() error {
	s.logger.Infof("Starting SSH server at %s", s.Addr)

	closeOnShutdown := func() {
		<-s.shutdownC
		err := s.Close()
		if err == nil {
			return
		}
		s.logger.WithError(err).Error("Cannot close SSH server")
	}
	go closeOnShutdown()

	s.Handle(s.connectionHandler)
	return s.ListenAndServe()
}
// connectionHandler serves a single SSH session. It generates a session ID,
// creates the per-session event and session loggers, resolves the user the
// session runs as, starts that user's shell (with a pty when the client
// requested one), and then relays data between the client and the shell
// while recording the shell's output to the session log. It returns once
// the shell's output is exhausted and the command has been waited for.
func (s *SSHServer) connectionHandler(session ssh.Session) {
	sessionUUID, err := uuid.NewRandom()
	if err != nil {
		// Record the root cause; the client only gets a generic message.
		s.logger.WithError(err).Error("Failed to generate session ID")
		if _, err := io.WriteString(session, "Failed to generate session ID\n"); err != nil {
			s.logger.WithError(err).Error("Failed to generate session ID: Failed to write to SSH session")
		}
		s.errorAndExit(session, "", nil)
		return
	}
	sessionID := sessionUUID.String()

	eventLogger, err := s.logManager.NewLogger(fmt.Sprintf("%s-event.log", sessionID), s.logger)
	if err != nil {
		s.logger.WithError(err).Error("Failed to create event log")
		if _, err := io.WriteString(session, "Failed to create event log\n"); err != nil {
			s.logger.WithError(err).Error("Failed to create event log: Failed to write to SSH session")
		}
		s.errorAndExit(session, "", nil)
		return
	}
	// Registered first so it runs last, after the deferred session_stop
	// audit event has been written.
	defer eventLogger.Close()

	// Get uid and gid of user attempting to login
	sshUser, uidInt, gidInt, success := s.getSSHUser(session, sessionID, eventLogger)
	if !success {
		return
	}

	// Create the session recorder before the shell is started, so that a
	// failure here cannot leave a running, un-waited child process behind.
	sessionLogger, err := s.logManager.NewSessionLogger(fmt.Sprintf("%s-session.log", sessionID), s.logger)
	if err != nil {
		s.logger.WithError(err).Error("Failed to create session log")
		if _, err := io.WriteString(session, "Failed to create log\n"); err != nil {
			s.logger.WithError(err).Error("Failed to create log: Failed to write to SSH session")
		}
		s.errorAndExit(session, "", nil)
		return
	}
	defer sessionLogger.Close()

	// Spawn shell under user
	cmd := s.spawnCmd(session, sshUser, uidInt, gidInt)

	var shellInput io.WriteCloser
	var shellOutput io.ReadCloser

	ptyReq, winCh, isPty := session.Pty()
	if isPty {
		cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
		shellInput, shellOutput, err = s.startPtySession(cmd, winCh, func() {
			s.logAuditEvent(eventLogger, session, sessionID, auditEventResize)
		})
		if err != nil {
			s.logger.WithError(err).Error("Failed to start pty session")
			// NOTE(review): this shuts the whole server down because one
			// session failed, and panics if shutdownC is already closed.
			close(s.shutdownC)
			return
		}
		s.logAuditEvent(eventLogger, session, sessionID, auditEventStart)
		defer s.logAuditEvent(eventLogger, session, sessionID, auditEventStop)
	} else {
		shellInput, shellOutput, err = s.startNonPtySession(cmd)
		if err != nil {
			s.logger.WithError(err).Error("Failed to start non-pty session")
			// NOTE(review): same server-wide shutdown concern as above.
			close(s.shutdownC)
			return
		}
		event := auditEventExec
		if strings.HasPrefix(session.RawCommand(), "scp") {
			event = auditEventScp
		}
		s.logAuditEvent(eventLogger, session, sessionID, event)
	}

	// Write incoming commands to shell
	go func() {
		if _, err := io.Copy(shellInput, session); err != nil {
			s.logger.WithError(err).Error("Failed to write incoming command to pty")
		}
		if !isPty {
			// Propagate the client's EOF to the command's stdin so commands
			// that read until EOF can terminate. Best effort: cmd.Wait may
			// already have closed the pipe, so the error is ignored.
			_ = shellInput.Close()
		}
	}()

	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()

	// Record everything written to pw in the session log.
	recorded := make(chan struct{})
	go func() {
		defer close(recorded)
		if _, err := io.Copy(sessionLogger, pr); err != nil {
			s.logger.WithError(err).Error("Failed to write session log")
			// Fail pending and future pipe writes instead of leaving the
			// output copy below blocked on a pipe nobody reads anymore.
			_ = pr.CloseWithError(err)
		}
	}()

	// Write outgoing command output to both the command recorder, and remote user
	mw := io.MultiWriter(pw, session)
	if _, err := io.Copy(mw, shellOutput); err != nil {
		s.logger.WithError(err).Error("Failed to write command output to user")
	}

	// Let the recorder see EOF and finish writing before the deferred
	// sessionLogger.Close runs.
	pw.Close()
	<-recorded

	// Wait for all resources associated with cmd to be released
	// Returns error if shell exited with a non-zero status or received a signal
	if err := cmd.Wait(); err != nil {
		s.logger.WithError(err).Debug("Shell did not close correctly")
	}
}
// spawnCmd prepares (but does not start) the user's shell as a command that
// runs under the given uid and gid in a new session. When the SSH session
// carries a raw command it is passed to the shell via "-c"; otherwise the
// shell is run without arguments. The environment consists of the session's
// environment followed by USER and HOME, and the working directory is the
// user's home directory.
func (s *SSHServer) spawnCmd(session ssh.Session, sshUser *User, uidInt, gidInt uint32) *exec.Cmd {
	var shellArgs []string
	if raw := session.RawCommand(); raw != "" {
		shellArgs = []string{"-c", raw}
	}
	cmd := exec.Command(sshUser.Shell, shellArgs...)

	// Only Uid and Gid are set; Credential.Groups is left unspecified.
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt},
		Setsid:     true,
	}

	cmd.Env = append(cmd.Env, session.Environ()...)
	cmd.Env = append(cmd.Env,
		fmt.Sprintf("USER=%s", sshUser.Username),
		fmt.Sprintf("HOME=%s", sshUser.HomeDir),
	)
	cmd.Dir = sshUser.HomeDir

	return cmd
}
// getSSHUser returns the user stored under "sshUser" in the session context
// together with that user's numeric uid and gid. The final result is false
// when the user is missing or its uid/gid cannot be parsed; in that case the
// session has already been exited via errorAndExit. An auth audit event is
// logged as soon as the user has been retrieved.
func (s *SSHServer) getSSHUser(session ssh.Session, sessionID string, eventLogger io.WriteCloser) (*User, uint32, uint32, bool) {
	// A failed type assertion yields a nil *User, so one nil check covers
	// both the "missing" and the "wrong type" case.
	sshUser, _ := session.Context().Value("sshUser").(*User)
	if sshUser == nil {
		s.errorAndExit(session, "Error retrieving credentials from session", nil)
		return nil, 0, 0, false
	}
	s.logAuditEvent(eventLogger, session, sessionID, auditEventAuth)

	uid, uidErr := stringToUint32(sshUser.Uid)
	if uidErr != nil {
		s.errorAndExit(session, "Invalid user", uidErr)
		return sshUser, 0, 0, false
	}

	gid, gidErr := stringToUint32(sshUser.Gid)
	if gidErr != nil {
		s.errorAndExit(session, "Invalid user group", gidErr)
		return sshUser, 0, 0, false
	}

	return sshUser, uid, gid, true
}
// errorAndExit exits the session with status 1 and logs the reported
// problem: err together with errText when err is non-nil, otherwise errText
// alone when it is non-empty. A failure to exit the session is logged as
// well and does not suppress the reported problem.
func (s *SSHServer) errorAndExit(session ssh.Session, errText string, err error) {
	// The exit error gets its own name: declaring it as `err` in the if
	// initializer would shadow the parameter for the rest of the chain and
	// the caller's error would never be logged.
	if exitErr := session.Exit(1); exitErr != nil {
		s.logger.WithError(exitErr).Error("Failed to close SSH session")
	}

	switch {
	case err != nil:
		s.logger.WithError(err).Error(errText)
	case errText != "":
		s.logger.Error(errText)
	}
}
// startNonPtySession starts cmd without a pty. It returns a writer connected
// to the command's stdin and a reader for its output; stderr is directed to
// the same destination as stdout, so the reader yields both streams.
func (s *SSHServer) startNonPtySession(cmd *exec.Cmd) (io.WriteCloser, io.ReadCloser, error) {
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, nil, err
	}

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return nil, nil, err
	}
	// cmd.Stdout has just been set up by StdoutPipe; share it with stderr.
	cmd.Stderr = cmd.Stdout

	if startErr := cmd.Start(); startErr != nil {
		return nil, nil, startErr
	}
	return stdin, stdout, nil
}
// startPtySession starts cmd attached to a new pty and returns the pty as
// both the shell's input and its output. A background goroutine applies each
// window size received on winCh to the pty and calls logCallback after every
// successful resize; it ends when winCh is closed or a resize fails.
func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logCallback func()) (io.WriteCloser, io.ReadCloser, error) {
	tty, err := pty.Start(cmd)
	if err != nil {
		return nil, nil, err
	}

	// Handle terminal window size changes
	go func() {
		for win := range winCh {
			if errNo := setWinsize(tty, win.Width, win.Height); errNo != 0 {
				// Log the errno from the ioctl; the outer err is always nil
				// at this point.
				s.logger.WithError(errNo).Error("Failed to set pty window size")
				// NOTE(review): this shuts the whole server down on a resize
				// failure and panics if shutdownC is already closed.
				close(s.shutdownC)
				return
			}
			logCallback()
		}
	}()

	return tty, tty, nil
}
// logAuditEvent writes one audit record of the given eventType to writer as
// a single line of JSON. The record carries the session's raw command, the
// session ID, the user name from the session context ("unknown" when none is
// present), the current UTC time in RFC 3339 format and the remote address.
// Marshalling and write failures are logged; nothing is returned.
func (s *SSHServer) logAuditEvent(writer io.WriteCloser, session ssh.Session, sessionID string, eventType string) {
	username := "unknown"
	if sshUser, ok := session.Context().Value("sshUser").(*User); ok && sshUser != nil {
		username = sshUser.Username
	}

	event := auditEvent{
		Event:     session.RawCommand(),
		EventType: eventType,
		SessionID: sessionID,
		User:      username,
		Login:     username,
		Datetime:  time.Now().UTC().Format(time.RFC3339),
		IPAddress: session.RemoteAddr().String(),
	}

	data, err := json.Marshal(&event)
	if err != nil {
		s.logger.WithError(err).Error("Failed to log audit event. malformed audit object")
		return
	}

	// A lost audit record must not go unnoticed, so report write failures.
	if _, err := writer.Write(append(data, '\n')); err != nil {
		s.logger.WithError(err).Error("Failed to write audit event")
	}
}
// setWinsize sets the window size of the terminal behind f to w columns by
// h rows using the TIOCSWINSZ ioctl. It returns the errno of the system
// call, which is 0 on success.
func setWinsize(f *os.File, w, h int) syscall.Errno {
	// Argument for TIOCSWINSZ: height, width, then two fields left at zero.
	size := struct{ rows, cols, xPixels, yPixels uint16 }{
		rows: uint16(h),
		cols: uint16(w),
	}
	_, _, errNo := syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&size)),
	)
	return errNo
}
// stringToUint32 parses str as a base-10 unsigned integer that fits in 32
// bits. On failure it returns the parse error; the numeric result is then
// whatever strconv.ParseUint reported and should not be relied upon.
func stringToUint32(str string) (uint32, error) {
	const (
		base    = 10
		bitSize = 32
	)
	parsed, parseErr := strconv.ParseUint(str, base, bitSize)
	return uint32(parsed), parseErr
}
// allowForward is a port forwarding callback that accepts every request,
// ignoring all of its arguments. New installs it for both local and reverse
// port forwarding when port forwarding is enabled.
func allowForward(_ ssh.Context, _ string, _ uint32) bool {
return true
}