From 7abbe91d41237cb3c01ba2e6949fd260a5774454 Mon Sep 17 00:00:00 2001 From: Michael Borkenstein Date: Thu, 29 Aug 2019 15:36:45 -0500 Subject: [PATCH] AUTH-2030: Support both authorized_key and short lived cert authentication simultaniously without specifiying at start time --- cmd/cloudflared/tunnel/cmd.go | 12 +----------- sshserver/authentication.go | 29 ++++++++++++++++++++--------- sshserver/authentication_test.go | 13 +++++++++---- sshserver/sshserver_unix.go | 15 +++------------ sshserver/sshserver_windows.go | 2 +- 5 files changed, 34 insertions(+), 37 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 96524142..8e5c2400 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -49,9 +49,6 @@ const ( // sshPortFlag is the port on localhost the cloudflared ssh server will run on sshPortFlag = "local-ssh-port" - // shortLivedCertFlag enables short lived cert authentication - shortLivedCertFlag = "short-lived-certs" - // sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed sshIdleTimeoutFlag = "ssh-idle-timeout" @@ -387,8 +384,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan } sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag) - server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool(shortLivedCertFlag), c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) - + server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) if err != nil { logger.WithError(err).Error("Cannot create new SSH Server") return errors.Wrap(err, "Cannot create new SSH Server") @@ -971,12 +967,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"LOCAL_SSH_PORT"}, Hidden: true, }), - altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: shortLivedCertFlag, - Usage: "Enable short lived cert authentication for SSH server", - EnvVars: []string{"SHORT_LIVED_CERTS"}, - Hidden: true, - }), altsrc.NewDurationFlag(&cli.DurationFlag{ Name: sshIdleTimeoutFlag, Usage: "Connection timeout after no activity", diff --git a/sshserver/authentication.go b/sshserver/authentication.go index c97264bf..19f143ad 100644 --- a/sshserver/authentication.go +++ b/sshserver/authentication.go @@ -18,6 +18,23 @@ var ( authorizedKeysDir = ".cloudflared/authorized_keys" ) +func (s *SSHServer) configureAuthentication() { + caCert, err := getCACert() + if err != nil { + s.logger.Info(err) + } + s.caCert = caCert + s.PublicKeyHandler = s.authenticationHandler +} + +func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool { + cert, ok := key.(*gossh.Certificate) + if !ok { + return s.authorizedKeyHandler(ctx, key) + } + return s.shortLivedCertHandler(ctx, cert) +} + func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { sshUser, err := s.getUserFunc(ctx.User()) if err != nil { @@ -56,20 +73,14 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo return false } -func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, key ssh.PublicKey) bool { - userCert, ok := key.(*gossh.Certificate) - if !ok { - s.logger.Debug("Received key is not an SSH certificate") - return false - } - - if !ssh.KeysEqual(s.caCert, userCert.SignatureKey) { +func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certificate) bool { + if !ssh.KeysEqual(s.caCert, cert.SignatureKey) { s.logger.Debug("CA certificate does not match user certificate signer") return false } checker := gossh.CertChecker{} - if err := checker.CheckCert(ctx.User(), userCert); err != nil { + if err := checker.CheckCert(ctx.User(), cert); err != nil { s.logger.Debug(err) return false } else { diff --git a/sshserver/authentication_test.go b/sshserver/authentication_test.go index 40a1696f..50661231 100644 --- a/sshserver/authentication_test.go +++ b/sshserver/authentication_test.go @@ -18,6 +18,7 @@ import ( "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + gossh "golang.org/x/crypto/ssh" ) const ( @@ -90,7 +91,8 @@ func TestShortLivedCerts_Success(t *testing.T) { caCert := getKey(t, testCAFilename) sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser} - userCert := getKey(t, testUserCertFilename) + userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) + require.True(t, ok) assert.True(t, sshServer.shortLivedCertHandler(context, userCert)) } @@ -101,7 +103,8 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) { caCert := getKey(t, testOtherCAFilename) sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser} - userCert := getKey(t, testUserCertFilename) + userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) + require.True(t, ok) assert.False(t, sshServer.shortLivedCertHandler(context, userCert)) assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message) } @@ -113,7 +116,8 @@ func TestShortLivedCerts_UserDoesNotExist(t *testing.T) { caCert := getKey(t, testCAFilename) sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser} - userCert := getKey(t, testUserCertFilename) + userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) + require.True(t, ok) assert.False(t, sshServer.shortLivedCertHandler(context, userCert)) assert.Contains(t, hook.LastEntry().Message, "Invalid user") } @@ -125,7 +129,8 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) { caCert := getKey(t, testCAFilename) sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser} - userCert := getKey(t, testUserCertFilename) + userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) + require.True(t, ok) assert.False(t, sshServer.shortLivedCertHandler(context, userCert)) assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate") } diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index 34fc417e..9b879f40 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -28,7 +28,7 @@ type SSHServer struct { getUserFunc func(string) (*User, error) } -func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) { +func New(logger *logrus.Logger, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) { currentUser, err := user.Current() if err != nil { return nil, err @@ -48,17 +48,7 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLi return nil, err } - if shortLivedCertAuth { - caCert, err := getCACert() - if err != nil { - return nil, err - } - sshServer.caCert = caCert - sshServer.PublicKeyHandler = sshServer.shortLivedCertHandler - } else { - sshServer.PublicKeyHandler = sshServer.authorizedKeyHandler - } - + sshServer.configureAuthentication() return &sshServer, nil } @@ -111,6 +101,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) { return } + // Supplementary groups are not explicitly specified. They seem to be inherited by default. cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}} cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username)) diff --git a/sshserver/sshserver_windows.go b/sshserver/sshserver_windows.go index 89faba04..ed72025a 100644 --- a/sshserver/sshserver_windows.go +++ b/sshserver/sshserver_windows.go @@ -11,7 +11,7 @@ import ( type SSHServer struct{} -func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool, _, _ time.Duration) (*SSHServer, error) { +func New(_ *logrus.Logger, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) { return nil, errors.New("cloudflared ssh server is not supported on windows") }