AUTH-2030: Support both authorized_key and short lived cert authentication simultaniously without specifiying at start time
This commit is contained in:
parent
cf314ddb58
commit
7abbe91d41
|
@ -49,9 +49,6 @@ const (
|
|||
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
|
||||
sshPortFlag = "local-ssh-port"
|
||||
|
||||
// shortLivedCertFlag enables short lived cert authentication
|
||||
shortLivedCertFlag = "short-lived-certs"
|
||||
|
||||
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
|
||||
sshIdleTimeoutFlag = "ssh-idle-timeout"
|
||||
|
||||
|
@ -387,8 +384,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
|
|||
}
|
||||
|
||||
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
|
||||
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool(shortLivedCertFlag), c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
|
||||
|
||||
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Cannot create new SSH Server")
|
||||
return errors.Wrap(err, "Cannot create new SSH Server")
|
||||
|
@ -971,12 +967,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
EnvVars: []string{"LOCAL_SSH_PORT"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: shortLivedCertFlag,
|
||||
Usage: "Enable short lived cert authentication for SSH server",
|
||||
EnvVars: []string{"SHORT_LIVED_CERTS"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: sshIdleTimeoutFlag,
|
||||
Usage: "Connection timeout after no activity",
|
||||
|
|
|
@ -18,6 +18,23 @@ var (
|
|||
authorizedKeysDir = ".cloudflared/authorized_keys"
|
||||
)
|
||||
|
||||
func (s *SSHServer) configureAuthentication() {
|
||||
caCert, err := getCACert()
|
||||
if err != nil {
|
||||
s.logger.Info(err)
|
||||
}
|
||||
s.caCert = caCert
|
||||
s.PublicKeyHandler = s.authenticationHandler
|
||||
}
|
||||
|
||||
func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
cert, ok := key.(*gossh.Certificate)
|
||||
if !ok {
|
||||
return s.authorizedKeyHandler(ctx, key)
|
||||
}
|
||||
return s.shortLivedCertHandler(ctx, cert)
|
||||
}
|
||||
|
||||
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
sshUser, err := s.getUserFunc(ctx.User())
|
||||
if err != nil {
|
||||
|
@ -56,20 +73,14 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
|
|||
return false
|
||||
}
|
||||
|
||||
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
userCert, ok := key.(*gossh.Certificate)
|
||||
if !ok {
|
||||
s.logger.Debug("Received key is not an SSH certificate")
|
||||
return false
|
||||
}
|
||||
|
||||
if !ssh.KeysEqual(s.caCert, userCert.SignatureKey) {
|
||||
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certificate) bool {
|
||||
if !ssh.KeysEqual(s.caCert, cert.SignatureKey) {
|
||||
s.logger.Debug("CA certificate does not match user certificate signer")
|
||||
return false
|
||||
}
|
||||
|
||||
checker := gossh.CertChecker{}
|
||||
if err := checker.CheckCert(ctx.User(), userCert); err != nil {
|
||||
if err := checker.CheckCert(ctx.User(), cert); err != nil {
|
||||
s.logger.Debug(err)
|
||||
return false
|
||||
} else {
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -90,7 +91,8 @@ func TestShortLivedCerts_Success(t *testing.T) {
|
|||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser}
|
||||
|
||||
userCert := getKey(t, testUserCertFilename)
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.True(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
}
|
||||
|
||||
|
@ -101,7 +103,8 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
|||
caCert := getKey(t, testOtherCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser}
|
||||
|
||||
userCert := getKey(t, testUserCertFilename)
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
|
||||
}
|
||||
|
@ -113,7 +116,8 @@ func TestShortLivedCerts_UserDoesNotExist(t *testing.T) {
|
|||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
||||
|
||||
userCert := getKey(t, testUserCertFilename)
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
||||
}
|
||||
|
@ -125,7 +129,8 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
|||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
||||
|
||||
userCert := getKey(t, testUserCertFilename)
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ type SSHServer struct {
|
|||
getUserFunc func(string) (*User, error)
|
||||
}
|
||||
|
||||
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
||||
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -48,17 +48,7 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLi
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if shortLivedCertAuth {
|
||||
caCert, err := getCACert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sshServer.caCert = caCert
|
||||
sshServer.PublicKeyHandler = sshServer.shortLivedCertHandler
|
||||
} else {
|
||||
sshServer.PublicKeyHandler = sshServer.authorizedKeyHandler
|
||||
}
|
||||
|
||||
sshServer.configureAuthentication()
|
||||
return &sshServer, nil
|
||||
}
|
||||
|
||||
|
@ -111,6 +101,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
|
|||
return
|
||||
}
|
||||
|
||||
// Supplementary groups are not explicitly specified. They seem to be inherited by default.
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}}
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
|
||||
type SSHServer struct{}
|
||||
|
||||
func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool, _, _ time.Duration) (*SSHServer, error) {
|
||||
func New(_ *logrus.Logger, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
|
||||
return nil, errors.New("cloudflared ssh server is not supported on windows")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue