diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 8658670c..dd784fc8 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -387,7 +387,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan } sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag) - server, err := sshserver.New(logManager, logger, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) + server, err := sshserver.New(logManager, logger, version, sshServerAddress, shutdownC, c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) if err != nil { logger.WithError(err).Error("Cannot create new SSH Server") return errors.Wrap(err, "Cannot create new SSH Server") @@ -398,6 +398,8 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan if err = server.Start(); err != nil && err != ssh.ErrServerClosed { logger.WithError(err).Error("SSH server error") } + // TODO: remove when declarative tunnels are implemented. + close(shutdownC) }() c.Set("url", "ssh://"+sshServerAddress) } @@ -966,7 +968,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag { altsrc.NewStringFlag(&cli.StringFlag{ Name: sshPortFlag, Usage: "Localhost port that cloudflared SSH server will run on", - Value: "22", + Value: "2222", EnvVars: []string{"LOCAL_SSH_PORT"}, Hidden: true, }), diff --git a/sshserver/authentication.go b/sshserver/authentication.go index 7dab6d92..64ecc2b2 100644 --- a/sshserver/authentication.go +++ b/sshserver/authentication.go @@ -27,7 +27,15 @@ func (s *SSHServer) configureAuthentication() { s.PublicKeyHandler = s.authenticationHandler } +// authenticationHandler is a callback that returns true if the user attempting to connect is authenticated. func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool { + sshUser, err := lookupUser(ctx.User()) + if err != nil { + s.logger.Debugf("Invalid user: %s", ctx.User()) + return false + } + ctx.SetValue("sshUser", sshUser) + cert, ok := key.(*gossh.Certificate) if !ok { return s.authorizedKeyHandler(ctx, key) @@ -36,9 +44,9 @@ func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bo } func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { - sshUser, err := s.getUserFunc(ctx.User()) - if err != nil { - s.logger.Debugf("Invalid user: %s", ctx.User()) + sshUser, ok := ctx.Value("sshUser").(*User) + if !ok { + s.logger.Error("Failed to retrieve user from context") return false } @@ -55,7 +63,6 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo } for len(authorizedKeysBytes) > 0 { - // Skips invalid keys. Returns error if no valid keys remain. pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) authorizedKeysBytes = rest @@ -65,7 +72,6 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo } if ssh.KeysEqual(pubKey, key) { - ctx.SetValue("sshUser", sshUser) return true } } @@ -83,13 +89,6 @@ func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certifica if err := checker.CheckCert(ctx.User(), cert); err != nil { s.logger.Debug(err) return false - } else { - sshUser, err := s.getUserFunc(ctx.User()) - if err != nil { - s.logger.Debugf("Invalid user: %s", ctx.User()) - return false - } - ctx.SetValue("sshUser", sshUser) } return true } diff --git a/sshserver/authentication_test.go b/sshserver/authentication_test.go index 50661231..5a7f3369 100644 --- a/sshserver/authentication_test.go +++ b/sshserver/authentication_test.go @@ -22,7 +22,6 @@ import ( ) const ( - validPrincipal = "testUser" testDir = "testdata" testUserKeyFilename = "id_rsa.pub" testCAFilename = "ca.pub" @@ -30,7 +29,10 @@ const ( testUserCertFilename = "id_rsa-cert.pub" ) -var logger, hook = test.NewNullLogger() +var ( + logger, hook = test.NewNullLogger() + mockUser = &User{Username: "testUser", HomeDir: testDir} +) func TestMain(m *testing.M) { authorizedKeysDir = testUserKeyFilename @@ -40,20 +42,20 @@ func TestMain(m *testing.M) { } func TestPublicKeyAuth_Success(t *testing.T) { - context, cancel := newMockContext(validPrincipal) + context, cancel := newMockContext(mockUser) defer cancel() - sshServer := SSHServer{getUserFunc: getMockUser} + sshServer := SSHServer{logger: logger} pubKey := getKey(t, testUserKeyFilename) assert.True(t, sshServer.authorizedKeyHandler(context, pubKey)) } func TestPublicKeyAuth_MissingKey(t *testing.T) { - context, cancel := newMockContext(validPrincipal) + context, cancel := newMockContext(mockUser) defer cancel() - sshServer := SSHServer{logger: logger, getUserFunc: getMockUser} + sshServer := SSHServer{logger: logger} pubKey := getKey(t, testOtherCAFilename) assert.False(t, sshServer.authorizedKeyHandler(context, pubKey)) @@ -61,23 +63,27 @@ func TestPublicKeyAuth_MissingKey(t *testing.T) { } func TestPublicKeyAuth_InvalidUser(t *testing.T) { - context, cancel := newMockContext("notAUser") + context, cancel := newMockContext(&User{Username: "notAUser"}) defer cancel() - sshServer := SSHServer{logger: logger, getUserFunc: lookupUser} + sshServer := SSHServer{logger: logger} pubKey := getKey(t, testUserKeyFilename) - assert.False(t, sshServer.authorizedKeyHandler(context, pubKey)) + assert.False(t, sshServer.authenticationHandler(context, pubKey)) assert.Contains(t, hook.LastEntry().Message, "Invalid user") } func TestPublicKeyAuth_MissingFile(t *testing.T) { - currentUser, err := user.Current() + tempUser, err := user.Current() require.Nil(t, err) - context, cancel := newMockContext(currentUser.Username) + currentUser, err := lookupUser(tempUser.Username) + require.Nil(t, err) + + require.Nil(t, err) + context, cancel := newMockContext(currentUser) defer cancel() - sshServer := SSHServer{Server: ssh.Server{}, logger: logger, getUserFunc: lookupUser} + sshServer := SSHServer{Server: ssh.Server{}, logger: logger} pubKey := getKey(t, testUserKeyFilename) assert.False(t, sshServer.authorizedKeyHandler(context, pubKey)) @@ -85,11 +91,11 @@ func TestPublicKeyAuth_MissingFile(t *testing.T) { } func TestShortLivedCerts_Success(t *testing.T) { - context, cancel := newMockContext(validPrincipal) + context, cancel := newMockContext(mockUser) defer cancel() caCert := getKey(t, testCAFilename) - sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser} + sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert} userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) require.True(t, ok) @@ -97,11 +103,11 @@ func TestShortLivedCerts_Success(t *testing.T) { } func TestShortLivedCerts_CAsDontMatch(t *testing.T) { - context, cancel := newMockContext(validPrincipal) + context, cancel := newMockContext(mockUser) defer cancel() caCert := getKey(t, testOtherCAFilename) - sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser} + sshServer := SSHServer{logger: logger, caCert: caCert} userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) require.True(t, ok) @@ -109,25 +115,12 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) { assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message) } -func TestShortLivedCerts_UserDoesNotExist(t *testing.T) { - context, cancel := newMockContext(validPrincipal) - defer cancel() - - caCert := getKey(t, testCAFilename) - sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser} - - userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) - require.True(t, ok) - assert.False(t, sshServer.shortLivedCertHandler(context, userCert)) - assert.Contains(t, hook.LastEntry().Message, "Invalid user") -} - func TestShortLivedCerts_InvalidPrincipal(t *testing.T) { - context, cancel := newMockContext("notAUser") + context, cancel := newMockContext(&User{Username: "NotAUser"}) defer cancel() caCert := getKey(t, testCAFilename) - sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser} + sshServer := SSHServer{logger: logger, caCert: caCert} userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate) require.True(t, ok) @@ -135,14 +128,6 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) { assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate") } -func getMockUser(_ string) (*User, error) { - return &User{ - Username: validPrincipal, - HomeDir: testDir, - }, nil - -} - func getKey(t *testing.T, filename string) ssh.PublicKey { path := path.Join(testDir, filename) bytes, err := ioutil.ReadFile(path) @@ -157,10 +142,13 @@ type mockSSHContext struct { *sync.Mutex } -func newMockContext(user string) (*mockSSHContext, context.CancelFunc) { +func newMockContext(user *User) (*mockSSHContext, context.CancelFunc) { innerCtx, cancel := context.WithCancel(context.Background()) mockCtx := &mockSSHContext{innerCtx, &sync.Mutex{}} - mockCtx.SetValue("user", user) + mockCtx.SetValue("sshUser", user) + + // This naming is confusing but we cant change it because this mocks the SSHContext struct in gliderlabs/ssh + mockCtx.SetValue("user", user.Username) return mockCtx, cancel } diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index dcc0e26c..cb968ff7 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -9,6 +9,7 @@ import ( "os" "os/exec" "os/user" + "runtime" "strconv" "syscall" "time" @@ -24,14 +25,13 @@ import ( type SSHServer struct { ssh.Server - logger *logrus.Logger - shutdownC chan struct{} - caCert ssh.PublicKey - getUserFunc func(string) (*User, error) - logManager sshlog.Manager + logger *logrus.Logger + shutdownC chan struct{} + caCert ssh.PublicKey + logManager sshlog.Manager } -func New(logManager sshlog.Manager, logger *logrus.Logger, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) { +func New(logManager sshlog.Manager, logger *logrus.Logger, version, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) { currentUser, err := user.Current() if err != nil { return nil, err @@ -41,11 +41,15 @@ func New(logManager sshlog.Manager, logger *logrus.Logger, address string, shutd } sshServer := SSHServer{ - Server: ssh.Server{Addr: address, MaxTimeout: maxTimeout, IdleTimeout: idleTimeout}, - logger: logger, - shutdownC: shutdownC, - getUserFunc: lookupUser, - logManager: logManager, + Server: ssh.Server{ + Addr: address, + MaxTimeout: maxTimeout, + IdleTimeout: idleTimeout, + Version: fmt.Sprintf("SSH-2.0-Cloudflare-Access_%s_%s", version, runtime.GOOS), + }, + logger: logger, + shutdownC: shutdownC, + logManager: logManager, } if err := sshServer.configureHostKeys(); err != nil { diff --git a/sshserver/sshserver_windows.go b/sshserver/sshserver_windows.go index 5cf6c741..d5f89744 100644 --- a/sshserver/sshserver_windows.go +++ b/sshserver/sshserver_windows.go @@ -13,7 +13,7 @@ import ( type SSHServer struct{} -func New(_ sshlog.Manager, _ *logrus.Logger, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) { +func New(_ sshlog.Manager, _ *logrus.Logger, _, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) { return nil, errors.New("cloudflared ssh server is not supported on windows") }