cloudflared-mirror/sshserver/sshserver_unix.go

400 lines
11 KiB
Go

//+build !windows
package sshserver
import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"unsafe"
"github.com/cloudflare/cloudflared/sshlog"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
// Audit event types recorded in a session's event log (auditEvent.EventType).
const (
	auditEventAuth   = "auth"
	auditEventStart  = "session_start"
	auditEventStop   = "session_stop"
	auditEventExec   = "exec"
	auditEventScp    = "scp"
	auditEventResize = "resize"
)

// Keys under which per-session values are stored in the ssh.Context.
const (
	sshContextSessionID   = "sessionID"
	sshContextEventLogger = "eventLogger"
)
// auditEvent is a single JSON record in a session's event log. Empty fields
// are omitted from the encoded JSON. Field order here is the JSON key order.
type auditEvent struct {
	Event     string `json:"event,omitempty"`      // the raw command of the session, if any
	EventType string `json:"event_type,omitempty"` // one of the auditEvent* constants
	SessionID string `json:"session_id,omitempty"` // random per-session UUID
	User      string `json:"user,omitempty"`       // authenticated username
	Login     string `json:"login,omitempty"`      // currently the same value as User
	Datetime  string `json:"datetime,omitempty"`   // event time, UTC, RFC 3339
	IPAddress string `json:"ip_address,omitempty"` // remote address string (host:port form)
}
// SSHServer embeds the gliderlabs ssh.Server and adds the state cloudflared
// needs to run sessions: a logger, a shutdown signal, and the manager that
// creates per-session event and session logs.
type SSHServer struct {
	ssh.Server

	// logger receives operational log output (audit events go to the
	// per-session event log instead).
	logger *logrus.Logger
	// shutdownC is watched by Start; once it is closed the server is closed.
	shutdownC chan struct{}
	// NOTE(review): presumably the CA public key trusted for certificate
	// authentication; it is populated and read outside the code reviewed
	// here — confirm.
	caCert ssh.PublicKey
	// logManager creates the per-session event and session loggers.
	logManager sshlog.Manager
}
// New creates an SSHServer listening on address and configures its host keys
// and authentication from the data provided.
//
// It refuses to run unless the current user is root (uid "0"). idleTimeout
// and maxTimeout are passed to the embedded ssh.Server, and version is
// embedded in the SSH version banner. Session and direct-tcpip channels are
// always registered. The local and reverse port-forwarding callbacks are only
// installed (accepting every request) when enablePortForwarding is true.
//
// An error is returned if the current user cannot be determined, is not root,
// or if host keys cannot be configured.
func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration, enablePortForwarding bool) (*SSHServer, error) {
	currentUser, err := user.Current()
	if err != nil {
		return nil, err
	}
	if currentUser.Uid != "0" {
		return nil, errors.New("cloudflared SSH server needs to run as root")
	}

	srv := &SSHServer{
		logger:     logger,
		shutdownC:  shutdownC,
		logManager: logManager,
	}
	srv.Addr = address
	srv.MaxTimeout = maxTimeout
	srv.IdleTimeout = idleTimeout
	srv.Version = fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS)

	// Global requests: set up and tear down remote (reverse) TCP/IP forwarding.
	forwardHandler := &ssh.ForwardedTCPHandler{}
	srv.RequestHandlers = map[string]ssh.RequestHandler{
		"tcpip-forward":        forwardHandler.HandleSSHRequest,
		"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
	}
	// Channel types: interactive/exec sessions and direct TCP/IP tunnels.
	srv.ChannelHandlers = map[string]ssh.ChannelHandler{
		"session":      ssh.DefaultSessionHandler,
		"direct-tcpip": ssh.DirectTCPIPHandler,
	}

	// AUTH-2050: This is a temporary workaround of a timing issue in the tunnel muxer to allow further testing.
	// TODO: Remove this
	srv.ConnCallback = func(conn net.Conn) net.Conn {
		time.Sleep(10 * time.Millisecond)
		return conn
	}

	if enablePortForwarding {
		srv.LocalPortForwardingCallback = allowForward
		srv.ReversePortForwardingCallback = allowForward
	}

	if err := srv.configureHostKeys(); err != nil {
		return nil, err
	}
	srv.configureAuthentication()
	return srv, nil
}
// Start serves SSH connections on s.Addr, handling every session with
// connectionHandler, and returns whatever ListenAndServe returns.
//
// A background goroutine waits for shutdownC to be closed and then closes
// the server. A failure to close is logged.
func (s *SSHServer) Start() error {
	s.logger.Infof("Starting SSH server at %s", s.Addr)

	closeOnShutdown := func() {
		<-s.shutdownC
		if err := s.Close(); err != nil {
			s.logger.WithError(err).Error("Cannot close SSH server")
		}
	}
	go closeOnShutdown()

	s.Handle(s.connectionHandler)
	return s.ListenAndServe()
}
// connectionHandler serves one SSH session end to end.
//
// It assigns the session a random ID and opens a per-session event log for
// audit records. It resolves the authenticated user and opens a per-session
// session log that records everything the process writes. It then runs
// either the requested command (via "<shell> -c") or an interactive shell.
// The process runs under the user's uid/gid in their home directory.
// Client input is relayed to the process. Process output is relayed both to
// the client and to the session log.
//
// Any failure is confined to this session: the session is ended with exit
// status 1 and the server keeps running.
func (s *SSHServer) connectionHandler(session ssh.Session) {
	sessionUUID, err := uuid.NewRandom()
	if err != nil {
		if _, err := io.WriteString(session, "Failed to generate session ID\n"); err != nil {
			s.logger.WithError(err).Error("Failed to generate session ID: Failed to write to SSH session")
		}
		s.errorAndExit(session, "", nil)
		return
	}
	sessionID := sessionUUID.String()

	eventLogger, err := s.logManager.NewLogger(fmt.Sprintf("%s-event.log", sessionID), s.logger)
	if err != nil {
		if _, err := io.WriteString(session, "Failed to create event log\n"); err != nil {
			s.logger.WithError(err).Error("Failed to create event log: Failed to write to create event logger")
		}
		s.errorAndExit(session, "", nil)
		return
	}
	// Release the event log when the session ends. This defer is registered
	// before the session_stop audit defer below, so it runs after it and the
	// final audit record is still written.
	defer func() {
		if err := eventLogger.Close(); err != nil {
			s.logger.WithError(err).Error("Failed to close event log")
		}
	}()

	sshContext, ok := session.Context().(ssh.Context)
	if !ok {
		s.logger.Error("Could not retrieve session context")
		s.errorAndExit(session, "", nil)
		return // sshContext is a nil interface here; using it would panic
	}
	sshContext.SetValue(sshContextSessionID, sessionID)
	sshContext.SetValue(sshContextEventLogger, eventLogger)

	// Get uid and gid of user attempting to login
	sshUser, uidInt, gidInt, success := s.getSSHUser(session, eventLogger)
	if !success {
		return
	}

	// Run the requested command through the user's shell, or an interactive
	// shell when no command was given.
	var cmd *exec.Cmd
	if session.RawCommand() != "" {
		cmd = exec.Command(sshUser.Shell, "-c", session.RawCommand())
		event := auditEventExec
		if strings.HasPrefix(session.RawCommand(), "scp") {
			event = auditEventScp
		}
		s.logAuditEvent(session, event)
	} else {
		cmd = exec.Command(sshUser.Shell)
		s.logAuditEvent(session, auditEventStart)
		defer s.logAuditEvent(session, auditEventStop)
	}
	// NOTE(review): Credential.Groups is left empty. Per the syscall package,
	// unless NoSetGroups is set the child then calls setgroups with an empty
	// list, so supplementary groups are dropped rather than inherited from
	// this (root) process — confirm this is the intended behavior.
	cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}, Setsid: true}
	cmd.Env = append(cmd.Env, session.Environ()...)
	cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
	cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=%s", sshUser.HomeDir))
	cmd.Dir = sshUser.HomeDir

	// Everything written to pw is copied into the session log.
	pr, pw := io.Pipe()
	defer pw.Close()

	// Open the session log before spawning the process. If this fails, no
	// shell is left running that nobody will ever wait for.
	sessionLogger, err := s.logManager.NewSessionLogger(fmt.Sprintf("%s-session.log", sessionID), s.logger)
	if err != nil {
		if _, err := io.WriteString(session, "Failed to create log\n"); err != nil {
			s.logger.WithError(err).Error("Failed to create log: Failed to write to SSH session")
		}
		s.errorAndExit(session, "", nil)
		return
	}
	go func() {
		defer sessionLogger.Close()
		defer pr.Close()
		if _, err := io.Copy(sessionLogger, pr); err != nil {
			s.logger.WithError(err).Error("Failed to write session log")
		}
	}()

	var shellInput io.WriteCloser
	var shellOutput io.ReadCloser
	// stderrDone is non-nil only for non-pty sessions and is closed once the
	// process's stderr has been fully relayed. os/exec requires all reads from
	// StderrPipe to finish before cmd.Wait, because Wait closes the pipe.
	var stderrDone chan struct{}

	ptyReq, winCh, isPty := session.Pty()
	if isPty {
		cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
		tty, err := s.startPtySession(cmd, winCh, func() {
			s.logAuditEvent(session, auditEventResize)
		})
		if err != nil {
			// A spawn failure (e.g. a missing login shell) only concerns this
			// session. End it instead of shutting down the whole server.
			s.errorAndExit(session, "Failed to start pty session", err)
			return
		}
		// A pty carries stdin, stdout and stderr over one file.
		shellInput = tty
		shellOutput = tty
	} else {
		var shellError io.ReadCloser
		shellInput, shellOutput, shellError, err = s.startNonPtySession(cmd)
		if err != nil {
			s.errorAndExit(session, "Failed to start non-pty session", err)
			return
		}
		// Write stderr to both the command recorder, and remote user
		stderrDone = make(chan struct{})
		go func() {
			defer close(stderrDone)
			mw := io.MultiWriter(pw, session.Stderr())
			if _, err := io.Copy(mw, shellError); err != nil {
				s.logger.WithError(err).Error("Failed to write stderr to user")
			}
		}()
	}

	// Write stdin to shell
	go func() {
		/*
			Only close shell stdin for non-pty sessions because they have distinct stdin, stdout, and stderr.
			This is done to prevent commands like SCP from hanging after all data has been sent.
			PTY sessions share one file for all three streams and the shell process closes it.
			Closing it here also closes shellOutput and causes an error on copy().
		*/
		if !isPty {
			defer shellInput.Close()
		}
		if _, err := io.Copy(shellInput, session); err != nil {
			s.logger.WithError(err).Error("Failed to write incoming command to pty")
		}
	}()

	// Write stdout to both the command recorder, and remote user
	mw := io.MultiWriter(pw, session)
	if _, err := io.Copy(mw, shellOutput); err != nil {
		s.logger.WithError(err).Error("Failed to write stdout to user")
	}

	// Finish relaying stderr before Wait closes its pipe.
	if stderrDone != nil {
		<-stderrDone
	}
	// Wait for all resources associated with cmd to be released
	// Returns error if shell exited with a non-zero status or received a signal
	if err := cmd.Wait(); err != nil {
		s.logger.WithError(err).Debug("Shell did not close correctly")
	}
}
// getSSHUser returns the authenticated user stored under the "sshUser"
// context key, together with that user's parsed uid and gid.
//
// Once the user is found, an auth audit event is recorded. If the user is
// missing, or the uid/gid cannot be parsed as 32-bit unsigned integers, the
// session is ended via errorAndExit and the final result is false. The
// eventLogger parameter is currently unused.
func (s *SSHServer) getSSHUser(session ssh.Session, eventLogger io.WriteCloser) (*User, uint32, uint32, bool) {
	authenticated, found := session.Context().Value("sshUser").(*User)
	if !(found && authenticated != nil) {
		s.errorAndExit(session, "Error retrieving credentials from session", nil)
		return nil, 0, 0, false
	}
	s.logAuditEvent(session, auditEventAuth)

	uid, err := stringToUint32(authenticated.Uid)
	if err != nil {
		s.errorAndExit(session, "Invalid user", err)
		return authenticated, 0, 0, false
	}

	gid, err := stringToUint32(authenticated.Gid)
	if err != nil {
		s.errorAndExit(session, "Invalid user group", err)
		return authenticated, 0, 0, false
	}

	return authenticated, uid, gid, true
}
// errorAndExit logs why a session is being aborted and then ends it with exit
// status 1.
//
// When err is non-nil it is logged with errText as the message. Otherwise a
// non-empty errText is logged on its own, and "" with a nil err logs nothing.
// The cause is logged before the exit is attempted, so it is not lost if
// sending the exit status fails. That failure is logged separately.
func (s *SSHServer) errorAndExit(session ssh.Session, errText string, err error) {
	switch {
	case err != nil:
		s.logger.WithError(err).Error(errText)
	case errText != "":
		s.logger.Error(errText)
	}
	if exitError := session.Exit(1); exitError != nil {
		s.logger.WithError(exitError).Error("Failed to close SSH session")
	}
}
// startNonPtySession attaches separate stdin, stdout and stderr pipes to cmd
// and starts it.
//
// On success the caller owns the returned pipe ends. The read ends (stdout,
// stderr) are closed by cmd.Wait, so Wait must only be called once they have
// been fully drained.
//
// On failure, any pipe end already obtained is closed (best effort) and nil
// streams are returned with the error.
func (s *SSHServer) startNonPtySession(cmd *exec.Cmd) (stdin io.WriteCloser, stdout io.ReadCloser, stderr io.ReadCloser, err error) {
	var opened []io.Closer
	fail := func(cause error) (io.WriteCloser, io.ReadCloser, io.ReadCloser, error) {
		for _, c := range opened {
			// Best-effort cleanup: the original cause is what gets reported, and
			// a second Close after a failed Start is harmless.
			_ = c.Close()
		}
		return nil, nil, nil, cause
	}

	if stdin, err = cmd.StdinPipe(); err != nil {
		return fail(err)
	}
	opened = append(opened, stdin)

	if stdout, err = cmd.StdoutPipe(); err != nil {
		return fail(err)
	}
	opened = append(opened, stdout)

	if stderr, err = cmd.StderrPipe(); err != nil {
		return fail(err)
	}
	opened = append(opened, stderr)

	if err = cmd.Start(); err != nil {
		return fail(err)
	}
	return stdin, stdout, stderr, nil
}
// startPtySession starts cmd attached to a new pseudo-terminal. It returns
// the terminal file, which carries the process's input and output.
//
// A background goroutine applies every window size received on winCh to the
// terminal and calls logCallback after each successful resize. The goroutine
// ends when winCh is closed. A failed resize is logged (with the real errno)
// and skipped. It only affects this session, so it must not shut the server
// down. The loop keeps consuming winCh so the sender is not left without a
// reader (NOTE(review): presumably the sender blocks otherwise — confirm
// against gliderlabs/ssh).
func (s *SSHServer) startPtySession(cmd *exec.Cmd, winCh <-chan ssh.Window, logCallback func()) (io.ReadWriteCloser, error) {
	tty, err := pty.Start(cmd)
	if err != nil {
		return nil, err
	}

	// Handle terminal window size changes
	go func() {
		for win := range winCh {
			if errNo := setWinsize(tty, win.Width, win.Height); errNo != 0 {
				s.logger.WithError(errNo).Error("Failed to set pty window size")
				continue
			}
			logCallback()
		}
	}()
	return tty, nil
}
// logAuditEvent appends one newline-terminated JSON auditEvent of the given
// type to the session's event log.
//
// The session ID and event-log writer are read from the session context
// (connectionHandler stores them there). If either is missing, or encoding or
// writing fails, the problem is logged and the event is dropped. The username
// falls back to "unknown" when no authenticated user is in the context.
func (s *SSHServer) logAuditEvent(session ssh.Session, eventType string) {
	ctx := session.Context()

	username := "unknown"
	if u, found := ctx.Value("sshUser").(*User); found && u != nil {
		username = u.Username
	}

	sessionID, found := ctx.Value(sshContextSessionID).(string)
	if !found {
		s.logger.Error("Failed to retrieve sessionID from context")
		return
	}

	out, found := ctx.Value(sshContextEventLogger).(io.WriteCloser)
	if !found {
		s.logger.Error("Failed to retrieve eventLogger from context")
		return
	}

	encoded, err := json.Marshal(&auditEvent{
		Event:     session.RawCommand(),
		EventType: eventType,
		SessionID: sessionID,
		User:      username,
		Login:     username,
		Datetime:  time.Now().UTC().Format(time.RFC3339),
		IPAddress: session.RemoteAddr().String(),
	})
	if err != nil {
		s.logger.WithError(err).Error("Failed to marshal audit event. malformed audit object")
		return
	}

	if _, err := out.Write(append(encoded, '\n')); err != nil {
		s.logger.WithError(err).Error("Failed to write audit event.")
	}
}
// setWinsize resizes the terminal behind f to w columns by h rows using the
// TIOCSWINSZ ioctl. It returns the raw errno from the syscall (0 on success).
func setWinsize(f *os.File, w, h int) syscall.Errno {
	// Field order matches the kernel's struct winsize:
	// ws_row, ws_col, ws_xpixel, ws_ypixel.
	size := &struct{ h, w, x, y uint16 }{h: uint16(h), w: uint16(w)}
	_, _, errNo := syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(size)),
	)
	return errNo
}
// stringToUint32 parses str as a base-10 unsigned integer that must fit in
// 32 bits. On error the value strconv.ParseUint reported (truncated to
// uint32) is still returned alongside the error.
func stringToUint32(str string) (result uint32, err error) {
	var wide uint64
	wide, err = strconv.ParseUint(str, 10, 32)
	result = uint32(wide)
	return result, err
}
// allowForward is a permissive port-forwarding policy. It approves every
// request regardless of the destination host or port. New installs it as
// both the local and reverse forwarding callback, but only when port
// forwarding is enabled.
func allowForward(ssh.Context, string, uint32) bool { return true }