AUTH-2030: Support both authorized_key and short lived cert authentication simultaniously without specifiying at start time

This commit is contained in:
Michael Borkenstein 2019-08-29 15:36:45 -05:00
parent cf314ddb58
commit 7abbe91d41
5 changed files with 34 additions and 37 deletions

View File

@ -49,9 +49,6 @@ const (
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
sshPortFlag = "local-ssh-port"
// shortLivedCertFlag enables short lived cert authentication
shortLivedCertFlag = "short-lived-certs"
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
sshIdleTimeoutFlag = "ssh-idle-timeout"
@ -387,8 +384,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
}
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool(shortLivedCertFlag), c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
if err != nil {
logger.WithError(err).Error("Cannot create new SSH Server")
return errors.Wrap(err, "Cannot create new SSH Server")
@ -971,12 +967,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"LOCAL_SSH_PORT"},
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: shortLivedCertFlag,
Usage: "Enable short lived cert authentication for SSH server",
EnvVars: []string{"SHORT_LIVED_CERTS"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: sshIdleTimeoutFlag,
Usage: "Connection timeout after no activity",

View File

@ -18,6 +18,23 @@ var (
authorizedKeysDir = ".cloudflared/authorized_keys"
)
func (s *SSHServer) configureAuthentication() {
caCert, err := getCACert()
if err != nil {
s.logger.Info(err)
}
s.caCert = caCert
s.PublicKeyHandler = s.authenticationHandler
}
func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool {
cert, ok := key.(*gossh.Certificate)
if !ok {
return s.authorizedKeyHandler(ctx, key)
}
return s.shortLivedCertHandler(ctx, cert)
}
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
sshUser, err := s.getUserFunc(ctx.User())
if err != nil {
@ -56,20 +73,14 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
return false
}
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, key ssh.PublicKey) bool {
userCert, ok := key.(*gossh.Certificate)
if !ok {
s.logger.Debug("Received key is not an SSH certificate")
return false
}
if !ssh.KeysEqual(s.caCert, userCert.SignatureKey) {
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certificate) bool {
if !ssh.KeysEqual(s.caCert, cert.SignatureKey) {
s.logger.Debug("CA certificate does not match user certificate signer")
return false
}
checker := gossh.CertChecker{}
if err := checker.CheckCert(ctx.User(), userCert); err != nil {
if err := checker.CheckCert(ctx.User(), cert); err != nil {
s.logger.Debug(err)
return false
} else {

View File

@ -18,6 +18,7 @@ import (
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
gossh "golang.org/x/crypto/ssh"
)
const (
@ -90,7 +91,8 @@ func TestShortLivedCerts_Success(t *testing.T) {
caCert := getKey(t, testCAFilename)
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser}
userCert := getKey(t, testUserCertFilename)
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
require.True(t, ok)
assert.True(t, sshServer.shortLivedCertHandler(context, userCert))
}
@ -101,7 +103,8 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
caCert := getKey(t, testOtherCAFilename)
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser}
userCert := getKey(t, testUserCertFilename)
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
require.True(t, ok)
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
}
@ -113,7 +116,8 @@ func TestShortLivedCerts_UserDoesNotExist(t *testing.T) {
caCert := getKey(t, testCAFilename)
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
userCert := getKey(t, testUserCertFilename)
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
require.True(t, ok)
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
}
@ -125,7 +129,8 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
caCert := getKey(t, testCAFilename)
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
userCert := getKey(t, testUserCertFilename)
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
require.True(t, ok)
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
}

View File

@ -28,7 +28,7 @@ type SSHServer struct {
getUserFunc func(string) (*User, error)
}
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
currentUser, err := user.Current()
if err != nil {
return nil, err
@ -48,17 +48,7 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLi
return nil, err
}
if shortLivedCertAuth {
caCert, err := getCACert()
if err != nil {
return nil, err
}
sshServer.caCert = caCert
sshServer.PublicKeyHandler = sshServer.shortLivedCertHandler
} else {
sshServer.PublicKeyHandler = sshServer.authorizedKeyHandler
}
sshServer.configureAuthentication()
return &sshServer, nil
}
@ -111,6 +101,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
return
}
// Supplementary groups are not explicitly specified. They seem to be inherited by default.
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}}
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))

View File

@ -11,7 +11,7 @@ import (
type SSHServer struct{}
func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool, _, _ time.Duration) (*SSHServer, error) {
func New(_ *logrus.Logger, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
return nil, errors.New("cloudflared ssh server is not supported on windows")
}