AUTH-2036: Refactor user retrieval, shutdown after ssh server stops, add custom version string

This commit is contained in:
Michael Borkenstein 2019-09-04 10:37:53 -05:00
parent ee588eeeaa
commit d3b254f9ae
5 changed files with 60 additions and 67 deletions

View File

@ -387,7 +387,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
} }
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag) sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
server, err := sshserver.New(logManager, logger, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) server, err := sshserver.New(logManager, logger, version, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
if err != nil { if err != nil {
logger.WithError(err).Error("Cannot create new SSH Server") logger.WithError(err).Error("Cannot create new SSH Server")
return errors.Wrap(err, "Cannot create new SSH Server") return errors.Wrap(err, "Cannot create new SSH Server")
@ -398,6 +398,8 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
if err = server.Start(); err != nil && err != ssh.ErrServerClosed { if err = server.Start(); err != nil && err != ssh.ErrServerClosed {
logger.WithError(err).Error("SSH server error") logger.WithError(err).Error("SSH server error")
} }
// TODO: remove when declarative tunnels are implemented.
close(shutdownC)
}() }()
c.Set("url", "ssh://"+sshServerAddress) c.Set("url", "ssh://"+sshServerAddress)
} }
@ -966,7 +968,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: sshPortFlag, Name: sshPortFlag,
Usage: "Localhost port that cloudflared SSH server will run on", Usage: "Localhost port that cloudflared SSH server will run on",
Value: "22", Value: "2222",
EnvVars: []string{"LOCAL_SSH_PORT"}, EnvVars: []string{"LOCAL_SSH_PORT"},
Hidden: true, Hidden: true,
}), }),

View File

@ -27,7 +27,15 @@ func (s *SSHServer) configureAuthentication() {
s.PublicKeyHandler = s.authenticationHandler s.PublicKeyHandler = s.authenticationHandler
} }
// authenticationHandler is a callback that returns true if the user attempting to connect is authenticated.
func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool { func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool {
sshUser, err := lookupUser(ctx.User())
if err != nil {
s.logger.Debugf("Invalid user: %s", ctx.User())
return false
}
ctx.SetValue("sshUser", sshUser)
cert, ok := key.(*gossh.Certificate) cert, ok := key.(*gossh.Certificate)
if !ok { if !ok {
return s.authorizedKeyHandler(ctx, key) return s.authorizedKeyHandler(ctx, key)
@ -36,9 +44,9 @@ func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bo
} }
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
sshUser, err := s.getUserFunc(ctx.User()) sshUser, ok := ctx.Value("sshUser").(*User)
if err != nil { if !ok {
s.logger.Debugf("Invalid user: %s", ctx.User()) s.logger.Error("Failed to retrieve user from context")
return false return false
} }
@ -55,7 +63,6 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
} }
for len(authorizedKeysBytes) > 0 { for len(authorizedKeysBytes) > 0 {
// Skips invalid keys. Returns error if no valid keys remain. // Skips invalid keys. Returns error if no valid keys remain.
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
authorizedKeysBytes = rest authorizedKeysBytes = rest
@ -65,7 +72,6 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
} }
if ssh.KeysEqual(pubKey, key) { if ssh.KeysEqual(pubKey, key) {
ctx.SetValue("sshUser", sshUser)
return true return true
} }
} }
@ -83,13 +89,6 @@ func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certifica
if err := checker.CheckCert(ctx.User(), cert); err != nil { if err := checker.CheckCert(ctx.User(), cert); err != nil {
s.logger.Debug(err) s.logger.Debug(err)
return false return false
} else {
sshUser, err := s.getUserFunc(ctx.User())
if err != nil {
s.logger.Debugf("Invalid user: %s", ctx.User())
return false
}
ctx.SetValue("sshUser", sshUser)
} }
return true return true
} }

View File

@ -22,7 +22,6 @@ import (
) )
const ( const (
validPrincipal = "testUser"
testDir = "testdata" testDir = "testdata"
testUserKeyFilename = "id_rsa.pub" testUserKeyFilename = "id_rsa.pub"
testCAFilename = "ca.pub" testCAFilename = "ca.pub"
@ -30,7 +29,10 @@ const (
testUserCertFilename = "id_rsa-cert.pub" testUserCertFilename = "id_rsa-cert.pub"
) )
var logger, hook = test.NewNullLogger() var (
logger, hook = test.NewNullLogger()
mockUser = &User{Username: "testUser", HomeDir: testDir}
)
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
authorizedKeysDir = testUserKeyFilename authorizedKeysDir = testUserKeyFilename
@ -40,20 +42,20 @@ func TestMain(m *testing.M) {
} }
func TestPublicKeyAuth_Success(t *testing.T) { func TestPublicKeyAuth_Success(t *testing.T) {
context, cancel := newMockContext(validPrincipal) context, cancel := newMockContext(mockUser)
defer cancel() defer cancel()
sshServer := SSHServer{getUserFunc: getMockUser} sshServer := SSHServer{logger: logger}
pubKey := getKey(t, testUserKeyFilename) pubKey := getKey(t, testUserKeyFilename)
assert.True(t, sshServer.authorizedKeyHandler(context, pubKey)) assert.True(t, sshServer.authorizedKeyHandler(context, pubKey))
} }
func TestPublicKeyAuth_MissingKey(t *testing.T) { func TestPublicKeyAuth_MissingKey(t *testing.T) {
context, cancel := newMockContext(validPrincipal) context, cancel := newMockContext(mockUser)
defer cancel() defer cancel()
sshServer := SSHServer{logger: logger, getUserFunc: getMockUser} sshServer := SSHServer{logger: logger}
pubKey := getKey(t, testOtherCAFilename) pubKey := getKey(t, testOtherCAFilename)
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey)) assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
@ -61,23 +63,27 @@ func TestPublicKeyAuth_MissingKey(t *testing.T) {
} }
func TestPublicKeyAuth_InvalidUser(t *testing.T) { func TestPublicKeyAuth_InvalidUser(t *testing.T) {
context, cancel := newMockContext("notAUser") context, cancel := newMockContext(&User{Username: "notAUser"})
defer cancel() defer cancel()
sshServer := SSHServer{logger: logger, getUserFunc: lookupUser} sshServer := SSHServer{logger: logger}
pubKey := getKey(t, testUserKeyFilename) pubKey := getKey(t, testUserKeyFilename)
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey)) assert.False(t, sshServer.authenticationHandler(context, pubKey))
assert.Contains(t, hook.LastEntry().Message, "Invalid user") assert.Contains(t, hook.LastEntry().Message, "Invalid user")
} }
func TestPublicKeyAuth_MissingFile(t *testing.T) { func TestPublicKeyAuth_MissingFile(t *testing.T) {
currentUser, err := user.Current() tempUser, err := user.Current()
require.Nil(t, err) require.Nil(t, err)
context, cancel := newMockContext(currentUser.Username) currentUser, err := lookupUser(tempUser.Username)
require.Nil(t, err)
require.Nil(t, err)
context, cancel := newMockContext(currentUser)
defer cancel() defer cancel()
sshServer := SSHServer{Server: ssh.Server{}, logger: logger, getUserFunc: lookupUser} sshServer := SSHServer{Server: ssh.Server{}, logger: logger}
pubKey := getKey(t, testUserKeyFilename) pubKey := getKey(t, testUserKeyFilename)
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey)) assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
@ -85,11 +91,11 @@ func TestPublicKeyAuth_MissingFile(t *testing.T) {
} }
func TestShortLivedCerts_Success(t *testing.T) { func TestShortLivedCerts_Success(t *testing.T) {
context, cancel := newMockContext(validPrincipal) context, cancel := newMockContext(mockUser)
defer cancel() defer cancel()
caCert := getKey(t, testCAFilename) caCert := getKey(t, testCAFilename)
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser} sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert}
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
require.True(t, ok) require.True(t, ok)
@ -97,11 +103,11 @@ func TestShortLivedCerts_Success(t *testing.T) {
} }
func TestShortLivedCerts_CAsDontMatch(t *testing.T) { func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
context, cancel := newMockContext(validPrincipal) context, cancel := newMockContext(mockUser)
defer cancel() defer cancel()
caCert := getKey(t, testOtherCAFilename) caCert := getKey(t, testOtherCAFilename)
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser} sshServer := SSHServer{logger: logger, caCert: caCert}
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
require.True(t, ok) require.True(t, ok)
@ -109,25 +115,12 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message) assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
} }
func TestShortLivedCerts_UserDoesNotExist(t *testing.T) {
context, cancel := newMockContext(validPrincipal)
defer cancel()
caCert := getKey(t, testCAFilename)
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
require.True(t, ok)
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
}
func TestShortLivedCerts_InvalidPrincipal(t *testing.T) { func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
context, cancel := newMockContext("notAUser") context, cancel := newMockContext(&User{Username: "NotAUser"})
defer cancel() defer cancel()
caCert := getKey(t, testCAFilename) caCert := getKey(t, testCAFilename)
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser} sshServer := SSHServer{logger: logger, caCert: caCert}
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
require.True(t, ok) require.True(t, ok)
@ -135,14 +128,6 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate") assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
} }
func getMockUser(_ string) (*User, error) {
return &User{
Username: validPrincipal,
HomeDir: testDir,
}, nil
}
func getKey(t *testing.T, filename string) ssh.PublicKey { func getKey(t *testing.T, filename string) ssh.PublicKey {
path := path.Join(testDir, filename) path := path.Join(testDir, filename)
bytes, err := ioutil.ReadFile(path) bytes, err := ioutil.ReadFile(path)
@ -157,10 +142,13 @@ type mockSSHContext struct {
*sync.Mutex *sync.Mutex
} }
func newMockContext(user string) (*mockSSHContext, context.CancelFunc) { func newMockContext(user *User) (*mockSSHContext, context.CancelFunc) {
innerCtx, cancel := context.WithCancel(context.Background()) innerCtx, cancel := context.WithCancel(context.Background())
mockCtx := &mockSSHContext{innerCtx, &sync.Mutex{}} mockCtx := &mockSSHContext{innerCtx, &sync.Mutex{}}
mockCtx.SetValue("user", user) mockCtx.SetValue("sshUser", user)
// This naming is confusing but we cant change it because this mocks the SSHContext struct in gliderlabs/ssh
mockCtx.SetValue("user", user.Username)
return mockCtx, cancel return mockCtx, cancel
} }

View File

@ -9,6 +9,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
"runtime"
"strconv" "strconv"
"syscall" "syscall"
"time" "time"
@ -24,14 +25,13 @@ import (
type SSHServer struct { type SSHServer struct {
ssh.Server ssh.Server
logger *logrus.Logger logger *logrus.Logger
shutdownC chan struct{} shutdownC chan struct{}
caCert ssh.PublicKey caCert ssh.PublicKey
getUserFunc func(string) (*User, error) logManager sshlog.Manager
logManager sshlog.Manager
} }
func New(logManager sshlog.Manager, logger *logrus.Logger, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) { func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
currentUser, err := user.Current() currentUser, err := user.Current()
if err != nil { if err != nil {
return nil, err return nil, err
@ -41,11 +41,15 @@ func New(logManager sshlog.Manager, logger *logrus.Logger, address string, shutd
} }
sshServer := SSHServer{ sshServer := SSHServer{
Server: ssh.Server{Addr: address, MaxTimeout: maxTimeout, IdleTimeout: idleTimeout}, Server: ssh.Server{
logger: logger, Addr: address,
shutdownC: shutdownC, MaxTimeout: maxTimeout,
getUserFunc: lookupUser, IdleTimeout: idleTimeout,
logManager: logManager, Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS),
},
logger: logger,
shutdownC: shutdownC,
logManager: logManager,
} }
if err := sshServer.configureHostKeys(); err != nil { if err := sshServer.configureHostKeys(); err != nil {

View File

@ -13,7 +13,7 @@ import (
type SSHServer struct{} type SSHServer struct{}
func New(_ sshlog.Manager, _ *logrus.Logger, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) { func New(_ sshlog.Manager, _ *logrus.Logger, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
return nil, errors.New("cloudflared ssh server is not supported on windows") return nil, errors.New("cloudflared ssh server is not supported on windows")
} }