From 188f4667cb940fb00e3978737ca9d4162970f4a1 Mon Sep 17 00:00:00 2001 From: Michael Borkenstein Date: Mon, 19 Aug 2019 13:51:59 -0500 Subject: [PATCH] AUTH-2004: Adds static host key support --- sshserver/host_keys.go | 110 ++++++++++++++++++++++++++++++++++++ sshserver/sshserver_unix.go | 15 +++-- 2 files changed, 121 insertions(+), 4 deletions(-) create mode 100644 sshserver/host_keys.go diff --git a/sshserver/host_keys.go b/sshserver/host_keys.go new file mode 100644 index 00000000..a2d22bc4 --- /dev/null +++ b/sshserver/host_keys.go @@ -0,0 +1,110 @@ +//+build !windows + +package sshserver + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "os" + "path/filepath" + + "github.com/gliderlabs/ssh" + "github.com/pkg/errors" +) + +const ( + rsaFilename = "ssh_host_rsa_key" + ecdsaFilename = "ssh_host_ecdsa_key" +) + +func (s *SSHServer) configureHostKeys() error { + if _, err := os.Stat(configDir); os.IsNotExist(err) { + if err := os.MkdirAll(configDir, 0755); err != nil { + return errors.Wrap(err, fmt.Sprintf("Error creating %s directory", configDir)) + } + } + + if err := s.configureHostKey(s.ensureECDSAKeyExists); err != nil { + return err + } + + if err := s.configureHostKey(s.ensureRSAKeyExists); err != nil { + return err + } + + return nil +} + +func (s *SSHServer) configureHostKey(keyFunc func() (string, error)) error { + path, err := keyFunc() + if err != nil { + return err + } + + if err := s.SetOption(ssh.HostKeyFile(path)); err != nil { + return errors.Wrap(err, "Could not set SSH host key") + } + return nil +} + +func (s *SSHServer) ensureRSAKeyExists() (string, error) { + keyPath := filepath.Join(configDir, rsaFilename) + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", errors.Wrap(err, "Error generating RSA host key") + } + + privateKey := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + if err = writePrivateKey(keyPath, privateKey); err != nil { + return "", err + } + + s.logger.Debug("Created new RSA SSH host key: ", keyPath) + } + return keyPath, nil +} + +func (s *SSHServer) ensureECDSAKeyExists() (string, error) { + keyPath := filepath.Join(configDir, ecdsaFilename) + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return "", errors.Wrap(err, "Error generating ECDSA host key") + } + + keyBytes, err := x509.MarshalECPrivateKey(key) + if err != nil { + return "", errors.Wrap(err, "Error marshalling ECDSA key") + } + + privateKey := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyBytes, + } + + if err = writePrivateKey(keyPath, privateKey); err != nil { + return "", err + } + + s.logger.Debug("Created new ECDSA SSH host key: ", keyPath) + } + return keyPath, nil +} + +func writePrivateKey(keyPath string, privateKey *pem.Block) error { + if err := ioutil.WriteFile(keyPath, pem.EncodeToMemory(privateKey), 0600); err != nil { + return errors.Wrap(err, fmt.Sprintf("Error writing host key to %s", keyPath)) + } + return nil +} diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index 91d68ebe..3fb760b4 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -19,24 +19,31 @@ import ( "github.com/sirupsen/logrus" ) +const ( + defaultShellPrompt = `\e[0;31m[\u@\h \W]\$ \e[m ` + configDir = "/etc/cloudflared/" +) + type SSHServer struct { ssh.Server logger *logrus.Logger shutdownC chan struct{} } -const DefaultShellPrompt = `\e[0;31m[\u@\h \W]\$ \e[m ` - func New(logger *logrus.Logger, address string, shutdownC chan struct{}) (*SSHServer, error) { currentUser, err := user.Current() if err != nil { return nil, err } if currentUser.Uid != "0" { - return nil, errors.New("cloudflared ssh server needs to run as root") + return nil, errors.New("cloudflared SSH server needs to run as root") } sshServer := SSHServer{ssh.Server{Addr: address}, logger, shutdownC} + if err := sshServer.configureHostKeys(); err != nil { + return nil, err + } + return &sshServer, nil } @@ -85,7 +92,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { } cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) - cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", DefaultShellPrompt)) + cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", defaultShellPrompt)) psuedoTTY, err := pty.Start(cmd) if err != nil { s.logger.WithError(err).Error("Failed to start pty session")