AUTH-2030: Support both authorized_key and short lived cert authentication simultaniously without specifiying at start time
This commit is contained in:
parent
cf314ddb58
commit
7abbe91d41
|
@ -49,9 +49,6 @@ const (
|
||||||
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
|
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
|
||||||
sshPortFlag = "local-ssh-port"
|
sshPortFlag = "local-ssh-port"
|
||||||
|
|
||||||
// shortLivedCertFlag enables short lived cert authentication
|
|
||||||
shortLivedCertFlag = "short-lived-certs"
|
|
||||||
|
|
||||||
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
|
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
|
||||||
sshIdleTimeoutFlag = "ssh-idle-timeout"
|
sshIdleTimeoutFlag = "ssh-idle-timeout"
|
||||||
|
|
||||||
|
@ -387,8 +384,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
|
||||||
}
|
}
|
||||||
|
|
||||||
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
|
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
|
||||||
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool(shortLivedCertFlag), c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
|
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("Cannot create new SSH Server")
|
logger.WithError(err).Error("Cannot create new SSH Server")
|
||||||
return errors.Wrap(err, "Cannot create new SSH Server")
|
return errors.Wrap(err, "Cannot create new SSH Server")
|
||||||
|
@ -971,12 +967,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
EnvVars: []string{"LOCAL_SSH_PORT"},
|
EnvVars: []string{"LOCAL_SSH_PORT"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
|
||||||
Name: shortLivedCertFlag,
|
|
||||||
Usage: "Enable short lived cert authentication for SSH server",
|
|
||||||
EnvVars: []string{"SHORT_LIVED_CERTS"},
|
|
||||||
Hidden: true,
|
|
||||||
}),
|
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
Name: sshIdleTimeoutFlag,
|
Name: sshIdleTimeoutFlag,
|
||||||
Usage: "Connection timeout after no activity",
|
Usage: "Connection timeout after no activity",
|
||||||
|
|
|
@ -18,6 +18,23 @@ var (
|
||||||
authorizedKeysDir = ".cloudflared/authorized_keys"
|
authorizedKeysDir = ".cloudflared/authorized_keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s *SSHServer) configureAuthentication() {
|
||||||
|
caCert, err := getCACert()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Info(err)
|
||||||
|
}
|
||||||
|
s.caCert = caCert
|
||||||
|
s.PublicKeyHandler = s.authenticationHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
|
cert, ok := key.(*gossh.Certificate)
|
||||||
|
if !ok {
|
||||||
|
return s.authorizedKeyHandler(ctx, key)
|
||||||
|
}
|
||||||
|
return s.shortLivedCertHandler(ctx, cert)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
sshUser, err := s.getUserFunc(ctx.User())
|
sshUser, err := s.getUserFunc(ctx.User())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -56,20 +73,14 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certificate) bool {
|
||||||
userCert, ok := key.(*gossh.Certificate)
|
if !ssh.KeysEqual(s.caCert, cert.SignatureKey) {
|
||||||
if !ok {
|
|
||||||
s.logger.Debug("Received key is not an SSH certificate")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ssh.KeysEqual(s.caCert, userCert.SignatureKey) {
|
|
||||||
s.logger.Debug("CA certificate does not match user certificate signer")
|
s.logger.Debug("CA certificate does not match user certificate signer")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
checker := gossh.CertChecker{}
|
checker := gossh.CertChecker{}
|
||||||
if err := checker.CheckCert(ctx.User(), userCert); err != nil {
|
if err := checker.CheckCert(ctx.User(), cert); err != nil {
|
||||||
s.logger.Debug(err)
|
s.logger.Debug(err)
|
||||||
return false
|
return false
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/sirupsen/logrus/hooks/test"
|
"github.com/sirupsen/logrus/hooks/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
gossh "golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -90,7 +91,8 @@ func TestShortLivedCerts_Success(t *testing.T) {
|
||||||
caCert := getKey(t, testCAFilename)
|
caCert := getKey(t, testCAFilename)
|
||||||
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser}
|
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser}
|
||||||
|
|
||||||
userCert := getKey(t, testUserCertFilename)
|
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||||
|
require.True(t, ok)
|
||||||
assert.True(t, sshServer.shortLivedCertHandler(context, userCert))
|
assert.True(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +103,8 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
||||||
caCert := getKey(t, testOtherCAFilename)
|
caCert := getKey(t, testOtherCAFilename)
|
||||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser}
|
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser}
|
||||||
|
|
||||||
userCert := getKey(t, testUserCertFilename)
|
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||||
|
require.True(t, ok)
|
||||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||||
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
|
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
|
||||||
}
|
}
|
||||||
|
@ -113,7 +116,8 @@ func TestShortLivedCerts_UserDoesNotExist(t *testing.T) {
|
||||||
caCert := getKey(t, testCAFilename)
|
caCert := getKey(t, testCAFilename)
|
||||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
||||||
|
|
||||||
userCert := getKey(t, testUserCertFilename)
|
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||||
|
require.True(t, ok)
|
||||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||||
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
||||||
}
|
}
|
||||||
|
@ -125,7 +129,8 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
||||||
caCert := getKey(t, testCAFilename)
|
caCert := getKey(t, testCAFilename)
|
||||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
||||||
|
|
||||||
userCert := getKey(t, testUserCertFilename)
|
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||||
|
require.True(t, ok)
|
||||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||||
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
|
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ type SSHServer struct {
|
||||||
getUserFunc func(string) (*User, error)
|
getUserFunc func(string) (*User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -48,17 +48,7 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLi
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if shortLivedCertAuth {
|
sshServer.configureAuthentication()
|
||||||
caCert, err := getCACert()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sshServer.caCert = caCert
|
|
||||||
sshServer.PublicKeyHandler = sshServer.shortLivedCertHandler
|
|
||||||
} else {
|
|
||||||
sshServer.PublicKeyHandler = sshServer.authorizedKeyHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
return &sshServer, nil
|
return &sshServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,6 +101,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Supplementary groups are not explicitly specified. They seem to be inherited by default.
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}}
|
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}}
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||||
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
|
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
type SSHServer struct{}
|
type SSHServer struct{}
|
||||||
|
|
||||||
func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool, _, _ time.Duration) (*SSHServer, error) {
|
func New(_ *logrus.Logger, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
|
||||||
return nil, errors.New("cloudflared ssh server is not supported on windows")
|
return nil, errors.New("cloudflared ssh server is not supported on windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue