From 858ef2986813852216da8ac7f332d1b024d6d1db Mon Sep 17 00:00:00 2001 From: Michael Borkenstein Date: Wed, 28 Aug 2019 10:48:30 -0500 Subject: [PATCH] AUTH-2022: Adds ssh timeout configuration --- cmd/cloudflared/tunnel/cmd.go | 36 +++++++++++++++++++++++++++----- sshserver/authentication.go | 11 +++++----- sshserver/authentication_test.go | 2 +- sshserver/sshserver_unix.go | 6 +++--- sshserver/sshserver_windows.go | 3 ++- 5 files changed, 43 insertions(+), 15 deletions(-) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 4191ccbf..2dcc443d 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -40,7 +40,21 @@ import ( "gopkg.in/urfave/cli.v2/altsrc" ) -const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878" +const ( + sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878" + + // sshPortFlag is the port on localhost the cloudflared ssh server will run on + sshPortFlag = "local-ssh-port" + + // shortLivedCertFlag enables short lived cert authentication + shortLivedCertFlag = "short-lived-certs" + + // sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed + sshIdleTimeoutFlag = "ssh-idle-timeout" + + // sshMaxTimeoutFlag defines the max duration a SSH session can remain open for + sshMaxTimeoutFlag = "ssh-max-timeout" +) var ( shutdownC chan struct{} @@ -337,8 +351,8 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan logger.Infof("ssh-server set") - sshServerAddress := "127.0.0.1:" + c.String("local-ssh-port") - server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool("short-lived-certs")) + sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag) + server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool(shortLivedCertFlag), c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag)) if err != nil { logger.WithError(err).Error("Cannot create new SSH Server") return errors.Wrap(err, "Cannot create new SSH Server") @@ -909,17 +923,29 @@ func tunnelFlags(shouldHide bool) []cli.Flag { Hidden: true, }), altsrc.NewStringFlag(&cli.StringFlag{ - Name: "local-ssh-port", + Name: sshPortFlag, Usage: "Localhost port that cloudflared SSH server will run on", Value: "22", EnvVars: []string{"LOCAL_SSH_PORT"}, Hidden: true, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: "short-lived-certs", + Name: shortLivedCertFlag, Usage: "Enable short lived cert authentication for SSH server", EnvVars: []string{"SHORT_LIVED_CERTS"}, Hidden: true, }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: sshIdleTimeoutFlag, + Usage: "Connection timeout after no activity", + EnvVars: []string{"SSH_IDLE_TIMEOUT"}, + Hidden: true, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: sshMaxTimeoutFlag, + Usage: "Absolute connection timeout", + EnvVars: []string{"SSH_MAX_TIMEOUT"}, + Hidden: true, + }), } } diff --git a/sshserver/authentication.go b/sshserver/authentication.go index 217c700d..c97264bf 100644 --- a/sshserver/authentication.go +++ b/sshserver/authentication.go @@ -15,7 +15,7 @@ import ( var ( systemConfigPath = "/etc/cloudflared/" - authorizeKeysPath = ".cloudflared/authorized_keys" + authorizedKeysDir = ".cloudflared/authorized_keys" ) func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { @@ -25,9 +25,9 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo return false } - authorizedKeysPath := path.Join(sshUser.HomeDir, authorizeKeysPath) + authorizedKeysPath := path.Join(sshUser.HomeDir, authorizedKeysDir) if _, err := os.Stat(authorizedKeysPath); os.IsNotExist(err) { - s.logger.Debugf("authorized_keys file %s not found", authorizeKeysPath) + s.logger.Debugf("authorized_keys file %s not found", authorizedKeysPath) return false } @@ -38,11 +38,12 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo } for len(authorizedKeysBytes) > 0 { + // Skips invalid keys. Returns error if no valid keys remain. pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) authorizedKeysBytes = rest if err != nil { - s.logger.WithError(err).Errorf("No valid keys found in %s", authorizeKeysPath) + s.logger.Errorf("Invalid key(s) found in %s", authorizedKeysPath) return false } @@ -51,7 +52,7 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo return true } } - s.logger.Debugf("Matching public key not found in %s", authorizeKeysPath) + s.logger.Debugf("Matching public key not found in %s", authorizedKeysPath) return false } diff --git a/sshserver/authentication_test.go b/sshserver/authentication_test.go index 2773709c..40a1696f 100644 --- a/sshserver/authentication_test.go +++ b/sshserver/authentication_test.go @@ -32,7 +32,7 @@ const ( var logger, hook = test.NewNullLogger() func TestMain(m *testing.M) { - authorizeKeysPath = testUserKeyFilename + authorizedKeysDir = testUserKeyFilename logger.SetLevel(logrus.DebugLevel) code := m.Run() os.Exit(code) diff --git a/sshserver/sshserver_unix.go b/sshserver/sshserver_unix.go index ae948922..34fc417e 100644 --- a/sshserver/sshserver_unix.go +++ b/sshserver/sshserver_unix.go @@ -12,6 +12,7 @@ import ( "os/user" "strconv" "syscall" + "time" "unsafe" "github.com/creack/pty" @@ -27,7 +28,7 @@ type SSHServer struct { getUserFunc func(string) (*User, error) } -func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool) (*SSHServer, error) { +func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) { currentUser, err := user.Current() if err != nil { return nil, err @@ -37,7 +38,7 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLi } sshServer := SSHServer{ - Server: ssh.Server{Addr: address}, + Server: ssh.Server{Addr: address, MaxTimeout: maxTimeout, IdleTimeout: idleTimeout}, logger: logger, shutdownC: shutdownC, getUserFunc: lookupUser, @@ -76,7 +77,6 @@ func (s *SSHServer) Start() error { } func (s *SSHServer) connectionHandler(session ssh.Session) { - // Get uid and gid of user attempting to login sshUser, ok := session.Context().Value("sshUser").(*User) if !ok || sshUser == nil { diff --git a/sshserver/sshserver_windows.go b/sshserver/sshserver_windows.go index a2002e01..89faba04 100644 --- a/sshserver/sshserver_windows.go +++ b/sshserver/sshserver_windows.go @@ -6,11 +6,12 @@ import ( "errors" "github.com/sirupsen/logrus" + "time" ) type SSHServer struct{} -func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool) (*SSHServer, error) { +func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool, _, _ time.Duration) (*SSHServer, error) { return nil, errors.New("cloudflared ssh server is not supported on windows") }