AUTH-2022: Adds ssh timeout configuration

This commit is contained in:
Michael Borkenstein 2019-08-28 10:48:30 -05:00
parent baec3e289e
commit 858ef29868
5 changed files with 43 additions and 15 deletions

View File

@ -40,7 +40,21 @@ import (
"gopkg.in/urfave/cli.v2/altsrc" "gopkg.in/urfave/cli.v2/altsrc"
) )
const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878" const (
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
sshPortFlag = "local-ssh-port"
// shortLivedCertFlag enables short lived cert authentication
shortLivedCertFlag = "short-lived-certs"
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
sshIdleTimeoutFlag = "ssh-idle-timeout"
// sshMaxTimeoutFlag defines the max duration a SSH session can remain open for
sshMaxTimeoutFlag = "ssh-max-timeout"
)
var ( var (
shutdownC chan struct{} shutdownC chan struct{}
@ -337,8 +351,8 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
logger.Infof("ssh-server set") logger.Infof("ssh-server set")
sshServerAddress := "127.0.0.1:" + c.String("local-ssh-port") sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool("short-lived-certs")) server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool(shortLivedCertFlag), c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
if err != nil { if err != nil {
logger.WithError(err).Error("Cannot create new SSH Server") logger.WithError(err).Error("Cannot create new SSH Server")
return errors.Wrap(err, "Cannot create new SSH Server") return errors.Wrap(err, "Cannot create new SSH Server")
@ -909,17 +923,29 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "local-ssh-port", Name: sshPortFlag,
Usage: "Localhost port that cloudflared SSH server will run on", Usage: "Localhost port that cloudflared SSH server will run on",
Value: "22", Value: "22",
EnvVars: []string{"LOCAL_SSH_PORT"}, EnvVars: []string{"LOCAL_SSH_PORT"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "short-lived-certs", Name: shortLivedCertFlag,
Usage: "Enable short lived cert authentication for SSH server", Usage: "Enable short lived cert authentication for SSH server",
EnvVars: []string{"SHORT_LIVED_CERTS"}, EnvVars: []string{"SHORT_LIVED_CERTS"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: sshIdleTimeoutFlag,
Usage: "Connection timeout after no activity",
EnvVars: []string{"SSH_IDLE_TIMEOUT"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: sshMaxTimeoutFlag,
Usage: "Absolute connection timeout",
EnvVars: []string{"SSH_MAX_TIMEOUT"},
Hidden: true,
}),
} }
} }

View File

@ -15,7 +15,7 @@ import (
var ( var (
systemConfigPath = "/etc/cloudflared/" systemConfigPath = "/etc/cloudflared/"
authorizeKeysPath = ".cloudflared/authorized_keys" authorizedKeysDir = ".cloudflared/authorized_keys"
) )
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
@ -25,9 +25,9 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
return false return false
} }
authorizedKeysPath := path.Join(sshUser.HomeDir, authorizeKeysPath) authorizedKeysPath := path.Join(sshUser.HomeDir, authorizedKeysDir)
if _, err := os.Stat(authorizedKeysPath); os.IsNotExist(err) { if _, err := os.Stat(authorizedKeysPath); os.IsNotExist(err) {
s.logger.Debugf("authorized_keys file %s not found", authorizeKeysPath) s.logger.Debugf("authorized_keys file %s not found", authorizedKeysPath)
return false return false
} }
@ -38,11 +38,12 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
} }
for len(authorizedKeysBytes) > 0 { for len(authorizedKeysBytes) > 0 {
// Skips invalid keys. Returns error if no valid keys remain. // Skips invalid keys. Returns error if no valid keys remain.
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
authorizedKeysBytes = rest authorizedKeysBytes = rest
if err != nil { if err != nil {
s.logger.WithError(err).Errorf("No valid keys found in %s", authorizeKeysPath) s.logger.Errorf("Invalid key(s) found in %s", authorizedKeysPath)
return false return false
} }
@ -51,7 +52,7 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
return true return true
} }
} }
s.logger.Debugf("Matching public key not found in %s", authorizeKeysPath) s.logger.Debugf("Matching public key not found in %s", authorizedKeysPath)
return false return false
} }

View File

@ -32,7 +32,7 @@ const (
var logger, hook = test.NewNullLogger() var logger, hook = test.NewNullLogger()
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
authorizeKeysPath = testUserKeyFilename authorizedKeysDir = testUserKeyFilename
logger.SetLevel(logrus.DebugLevel) logger.SetLevel(logrus.DebugLevel)
code := m.Run() code := m.Run()
os.Exit(code) os.Exit(code)

View File

@ -12,6 +12,7 @@ import (
"os/user" "os/user"
"strconv" "strconv"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"github.com/creack/pty" "github.com/creack/pty"
@ -27,7 +28,7 @@ type SSHServer struct {
getUserFunc func(string) (*User, error) getUserFunc func(string) (*User, error)
} }
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool) (*SSHServer, error) { func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
currentUser, err := user.Current() currentUser, err := user.Current()
if err != nil { if err != nil {
return nil, err return nil, err
@ -37,7 +38,7 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLi
} }
sshServer := SSHServer{ sshServer := SSHServer{
Server: ssh.Server{Addr: address}, Server: ssh.Server{Addr: address, MaxTimeout: maxTimeout, IdleTimeout: idleTimeout},
logger: logger, logger: logger,
shutdownC: shutdownC, shutdownC: shutdownC,
getUserFunc: lookupUser, getUserFunc: lookupUser,
@ -76,7 +77,6 @@ func (s *SSHServer) Start() error {
} }
func (s *SSHServer) connectionHandler(session ssh.Session) { func (s *SSHServer) connectionHandler(session ssh.Session) {
// Get uid and gid of user attempting to login // Get uid and gid of user attempting to login
sshUser, ok := session.Context().Value("sshUser").(*User) sshUser, ok := session.Context().Value("sshUser").(*User)
if !ok || sshUser == nil { if !ok || sshUser == nil {

View File

@ -6,11 +6,12 @@ import (
"errors" "errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"time"
) )
type SSHServer struct{} type SSHServer struct{}
func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool) (*SSHServer, error) { func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool, _, _ time.Duration) (*SSHServer, error) {
return nil, errors.New("cloudflared ssh server is not supported on windows") return nil, errors.New("cloudflared ssh server is not supported on windows")
} }