AUTH-2036: Refactor user retrieval, shutdown after ssh server stops, add custom version string
This commit is contained in:
parent
ee588eeeaa
commit
d3b254f9ae
|
@ -387,7 +387,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
|
||||||
}
|
}
|
||||||
|
|
||||||
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
|
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
|
||||||
server, err := sshserver.New(logManager, logger, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
|
server, err := sshserver.New(logManager, logger, version, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("Cannot create new SSH Server")
|
logger.WithError(err).Error("Cannot create new SSH Server")
|
||||||
return errors.Wrap(err, "Cannot create new SSH Server")
|
return errors.Wrap(err, "Cannot create new SSH Server")
|
||||||
|
@ -398,6 +398,8 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
|
||||||
if err = server.Start(); err != nil && err != ssh.ErrServerClosed {
|
if err = server.Start(); err != nil && err != ssh.ErrServerClosed {
|
||||||
logger.WithError(err).Error("SSH server error")
|
logger.WithError(err).Error("SSH server error")
|
||||||
}
|
}
|
||||||
|
// TODO: remove when declarative tunnels are implemented.
|
||||||
|
close(shutdownC)
|
||||||
}()
|
}()
|
||||||
c.Set("url", "ssh://"+sshServerAddress)
|
c.Set("url", "ssh://"+sshServerAddress)
|
||||||
}
|
}
|
||||||
|
@ -966,7 +968,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: sshPortFlag,
|
Name: sshPortFlag,
|
||||||
Usage: "Localhost port that cloudflared SSH server will run on",
|
Usage: "Localhost port that cloudflared SSH server will run on",
|
||||||
Value: "22",
|
Value: "2222",
|
||||||
EnvVars: []string{"LOCAL_SSH_PORT"},
|
EnvVars: []string{"LOCAL_SSH_PORT"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -27,7 +27,15 @@ func (s *SSHServer) configureAuthentication() {
|
||||||
s.PublicKeyHandler = s.authenticationHandler
|
s.PublicKeyHandler = s.authenticationHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authenticationHandler is a callback that returns true if the user attempting to connect is authenticated.
|
||||||
func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
|
sshUser, err := lookupUser(ctx.User())
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Debugf("Invalid user: %s", ctx.User())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ctx.SetValue("sshUser", sshUser)
|
||||||
|
|
||||||
cert, ok := key.(*gossh.Certificate)
|
cert, ok := key.(*gossh.Certificate)
|
||||||
if !ok {
|
if !ok {
|
||||||
return s.authorizedKeyHandler(ctx, key)
|
return s.authorizedKeyHandler(ctx, key)
|
||||||
|
@ -36,9 +44,9 @@ func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
sshUser, err := s.getUserFunc(ctx.User())
|
sshUser, ok := ctx.Value("sshUser").(*User)
|
||||||
if err != nil {
|
if !ok {
|
||||||
s.logger.Debugf("Invalid user: %s", ctx.User())
|
s.logger.Error("Failed to retrieve user from context")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +63,6 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
|
||||||
}
|
}
|
||||||
|
|
||||||
for len(authorizedKeysBytes) > 0 {
|
for len(authorizedKeysBytes) > 0 {
|
||||||
|
|
||||||
// Skips invalid keys. Returns error if no valid keys remain.
|
// Skips invalid keys. Returns error if no valid keys remain.
|
||||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||||
authorizedKeysBytes = rest
|
authorizedKeysBytes = rest
|
||||||
|
@ -65,7 +72,6 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
|
||||||
}
|
}
|
||||||
|
|
||||||
if ssh.KeysEqual(pubKey, key) {
|
if ssh.KeysEqual(pubKey, key) {
|
||||||
ctx.SetValue("sshUser", sshUser)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,13 +89,6 @@ func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certifica
|
||||||
if err := checker.CheckCert(ctx.User(), cert); err != nil {
|
if err := checker.CheckCert(ctx.User(), cert); err != nil {
|
||||||
s.logger.Debug(err)
|
s.logger.Debug(err)
|
||||||
return false
|
return false
|
||||||
} else {
|
|
||||||
sshUser, err := s.getUserFunc(ctx.User())
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Debugf("Invalid user: %s", ctx.User())
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
ctx.SetValue("sshUser", sshUser)
|
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
validPrincipal = "testUser"
|
|
||||||
testDir = "testdata"
|
testDir = "testdata"
|
||||||
testUserKeyFilename = "id_rsa.pub"
|
testUserKeyFilename = "id_rsa.pub"
|
||||||
testCAFilename = "ca.pub"
|
testCAFilename = "ca.pub"
|
||||||
|
@ -30,7 +29,10 @@ const (
|
||||||
testUserCertFilename = "id_rsa-cert.pub"
|
testUserCertFilename = "id_rsa-cert.pub"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger, hook = test.NewNullLogger()
|
var (
|
||||||
|
logger, hook = test.NewNullLogger()
|
||||||
|
mockUser = &User{Username: "testUser", HomeDir: testDir}
|
||||||
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
authorizedKeysDir = testUserKeyFilename
|
authorizedKeysDir = testUserKeyFilename
|
||||||
|
@ -40,20 +42,20 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPublicKeyAuth_Success(t *testing.T) {
|
func TestPublicKeyAuth_Success(t *testing.T) {
|
||||||
context, cancel := newMockContext(validPrincipal)
|
context, cancel := newMockContext(mockUser)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sshServer := SSHServer{getUserFunc: getMockUser}
|
sshServer := SSHServer{logger: logger}
|
||||||
|
|
||||||
pubKey := getKey(t, testUserKeyFilename)
|
pubKey := getKey(t, testUserKeyFilename)
|
||||||
assert.True(t, sshServer.authorizedKeyHandler(context, pubKey))
|
assert.True(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPublicKeyAuth_MissingKey(t *testing.T) {
|
func TestPublicKeyAuth_MissingKey(t *testing.T) {
|
||||||
context, cancel := newMockContext(validPrincipal)
|
context, cancel := newMockContext(mockUser)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sshServer := SSHServer{logger: logger, getUserFunc: getMockUser}
|
sshServer := SSHServer{logger: logger}
|
||||||
|
|
||||||
pubKey := getKey(t, testOtherCAFilename)
|
pubKey := getKey(t, testOtherCAFilename)
|
||||||
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||||
|
@ -61,23 +63,27 @@ func TestPublicKeyAuth_MissingKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPublicKeyAuth_InvalidUser(t *testing.T) {
|
func TestPublicKeyAuth_InvalidUser(t *testing.T) {
|
||||||
context, cancel := newMockContext("notAUser")
|
context, cancel := newMockContext(&User{Username: "notAUser"})
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sshServer := SSHServer{logger: logger, getUserFunc: lookupUser}
|
sshServer := SSHServer{logger: logger}
|
||||||
|
|
||||||
pubKey := getKey(t, testUserKeyFilename)
|
pubKey := getKey(t, testUserKeyFilename)
|
||||||
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
assert.False(t, sshServer.authenticationHandler(context, pubKey))
|
||||||
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPublicKeyAuth_MissingFile(t *testing.T) {
|
func TestPublicKeyAuth_MissingFile(t *testing.T) {
|
||||||
currentUser, err := user.Current()
|
tempUser, err := user.Current()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
context, cancel := newMockContext(currentUser.Username)
|
currentUser, err := lookupUser(tempUser.Username)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, err)
|
||||||
|
context, cancel := newMockContext(currentUser)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
sshServer := SSHServer{Server: ssh.Server{}, logger: logger, getUserFunc: lookupUser}
|
sshServer := SSHServer{Server: ssh.Server{}, logger: logger}
|
||||||
|
|
||||||
pubKey := getKey(t, testUserKeyFilename)
|
pubKey := getKey(t, testUserKeyFilename)
|
||||||
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||||
|
@ -85,11 +91,11 @@ func TestPublicKeyAuth_MissingFile(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShortLivedCerts_Success(t *testing.T) {
|
func TestShortLivedCerts_Success(t *testing.T) {
|
||||||
context, cancel := newMockContext(validPrincipal)
|
context, cancel := newMockContext(mockUser)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
caCert := getKey(t, testCAFilename)
|
caCert := getKey(t, testCAFilename)
|
||||||
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser}
|
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert}
|
||||||
|
|
||||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
@ -97,11 +103,11 @@ func TestShortLivedCerts_Success(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
||||||
context, cancel := newMockContext(validPrincipal)
|
context, cancel := newMockContext(mockUser)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
caCert := getKey(t, testOtherCAFilename)
|
caCert := getKey(t, testOtherCAFilename)
|
||||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser}
|
sshServer := SSHServer{logger: logger, caCert: caCert}
|
||||||
|
|
||||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
@ -109,25 +115,12 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
||||||
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
|
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShortLivedCerts_UserDoesNotExist(t *testing.T) {
|
|
||||||
context, cancel := newMockContext(validPrincipal)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
caCert := getKey(t, testCAFilename)
|
|
||||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
|
||||||
|
|
||||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
|
||||||
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
||||||
context, cancel := newMockContext("notAUser")
|
context, cancel := newMockContext(&User{Username: "NotAUser"})
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
caCert := getKey(t, testCAFilename)
|
caCert := getKey(t, testCAFilename)
|
||||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
sshServer := SSHServer{logger: logger, caCert: caCert}
|
||||||
|
|
||||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
@ -135,14 +128,6 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
||||||
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
|
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMockUser(_ string) (*User, error) {
|
|
||||||
return &User{
|
|
||||||
Username: validPrincipal,
|
|
||||||
HomeDir: testDir,
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func getKey(t *testing.T, filename string) ssh.PublicKey {
|
func getKey(t *testing.T, filename string) ssh.PublicKey {
|
||||||
path := path.Join(testDir, filename)
|
path := path.Join(testDir, filename)
|
||||||
bytes, err := ioutil.ReadFile(path)
|
bytes, err := ioutil.ReadFile(path)
|
||||||
|
@ -157,10 +142,13 @@ type mockSSHContext struct {
|
||||||
*sync.Mutex
|
*sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockContext(user string) (*mockSSHContext, context.CancelFunc) {
|
func newMockContext(user *User) (*mockSSHContext, context.CancelFunc) {
|
||||||
innerCtx, cancel := context.WithCancel(context.Background())
|
innerCtx, cancel := context.WithCancel(context.Background())
|
||||||
mockCtx := &mockSSHContext{innerCtx, &sync.Mutex{}}
|
mockCtx := &mockSSHContext{innerCtx, &sync.Mutex{}}
|
||||||
mockCtx.SetValue("user", user)
|
mockCtx.SetValue("sshUser", user)
|
||||||
|
|
||||||
|
// This naming is confusing but we cant change it because this mocks the SSHContext struct in gliderlabs/ssh
|
||||||
|
mockCtx.SetValue("user", user.Username)
|
||||||
return mockCtx, cancel
|
return mockCtx, cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
@ -27,11 +28,10 @@ type SSHServer struct {
|
||||||
logger *logrus.Logger
|
logger *logrus.Logger
|
||||||
shutdownC chan struct{}
|
shutdownC chan struct{}
|
||||||
caCert ssh.PublicKey
|
caCert ssh.PublicKey
|
||||||
getUserFunc func(string) (*User, error)
|
|
||||||
logManager sshlog.Manager
|
logManager sshlog.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(logManager sshlog.Manager, logger *logrus.Logger, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -41,10 +41,14 @@ func New(logManager sshlog.Manager, logger *logrus.Logger, address string, shutd
|
||||||
}
|
}
|
||||||
|
|
||||||
sshServer := SSHServer{
|
sshServer := SSHServer{
|
||||||
Server: ssh.Server{Addr: address, MaxTimeout: maxTimeout, IdleTimeout: idleTimeout},
|
Server: ssh.Server{
|
||||||
|
Addr: address,
|
||||||
|
MaxTimeout: maxTimeout,
|
||||||
|
IdleTimeout: idleTimeout,
|
||||||
|
Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS),
|
||||||
|
},
|
||||||
logger: logger,
|
logger: logger,
|
||||||
shutdownC: shutdownC,
|
shutdownC: shutdownC,
|
||||||
getUserFunc: lookupUser,
|
|
||||||
logManager: logManager,
|
logManager: logManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
|
|
||||||
type SSHServer struct{}
|
type SSHServer struct{}
|
||||||
|
|
||||||
func New(_ sshlog.Manager, _ *logrus.Logger, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
|
func New(_ sshlog.Manager, _ *logrus.Logger, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
|
||||||
return nil, errors.New("cloudflared ssh server is not supported on windows")
|
return nil, errors.New("cloudflared ssh server is not supported on windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue