AUTH-2022: Adds ssh timeout configuration
This commit is contained in:
parent
baec3e289e
commit
858ef29868
|
@ -40,7 +40,21 @@ import (
|
||||||
"gopkg.in/urfave/cli.v2/altsrc"
|
"gopkg.in/urfave/cli.v2/altsrc"
|
||||||
)
|
)
|
||||||
|
|
||||||
const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
|
const (
|
||||||
|
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
|
||||||
|
|
||||||
|
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
|
||||||
|
sshPortFlag = "local-ssh-port"
|
||||||
|
|
||||||
|
// shortLivedCertFlag enables short lived cert authentication
|
||||||
|
shortLivedCertFlag = "short-lived-certs"
|
||||||
|
|
||||||
|
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
|
||||||
|
sshIdleTimeoutFlag = "ssh-idle-timeout"
|
||||||
|
|
||||||
|
// sshMaxTimeoutFlag defines the max duration a SSH session can remain open for
|
||||||
|
sshMaxTimeoutFlag = "ssh-max-timeout"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
shutdownC chan struct{}
|
shutdownC chan struct{}
|
||||||
|
@ -337,8 +351,8 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
|
||||||
|
|
||||||
logger.Infof("ssh-server set")
|
logger.Infof("ssh-server set")
|
||||||
|
|
||||||
sshServerAddress := "127.0.0.1:" + c.String("local-ssh-port")
|
sshServerAddress := "127.0.0.1:" + c.String(sshPortFlag)
|
||||||
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool("short-lived-certs"))
|
server, err := sshserver.New(logger, sshServerAddress, shutdownC, c.Bool(shortLivedCertFlag), c.Duration(sshIdleTimeoutFlag), c.Duration(sshMaxTimeoutFlag))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("Cannot create new SSH Server")
|
logger.WithError(err).Error("Cannot create new SSH Server")
|
||||||
return errors.Wrap(err, "Cannot create new SSH Server")
|
return errors.Wrap(err, "Cannot create new SSH Server")
|
||||||
|
@ -909,17 +923,29 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "local-ssh-port",
|
Name: sshPortFlag,
|
||||||
Usage: "Localhost port that cloudflared SSH server will run on",
|
Usage: "Localhost port that cloudflared SSH server will run on",
|
||||||
Value: "22",
|
Value: "22",
|
||||||
EnvVars: []string{"LOCAL_SSH_PORT"},
|
EnvVars: []string{"LOCAL_SSH_PORT"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
Name: "short-lived-certs",
|
Name: shortLivedCertFlag,
|
||||||
Usage: "Enable short lived cert authentication for SSH server",
|
Usage: "Enable short lived cert authentication for SSH server",
|
||||||
EnvVars: []string{"SHORT_LIVED_CERTS"},
|
EnvVars: []string{"SHORT_LIVED_CERTS"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: sshIdleTimeoutFlag,
|
||||||
|
Usage: "Connection timeout after no activity",
|
||||||
|
EnvVars: []string{"SSH_IDLE_TIMEOUT"},
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: sshMaxTimeoutFlag,
|
||||||
|
Usage: "Absolute connection timeout",
|
||||||
|
EnvVars: []string{"SSH_MAX_TIMEOUT"},
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
|
|
||||||
var (
|
var (
|
||||||
systemConfigPath = "/etc/cloudflared/"
|
systemConfigPath = "/etc/cloudflared/"
|
||||||
authorizeKeysPath = ".cloudflared/authorized_keys"
|
authorizedKeysDir = ".cloudflared/authorized_keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||||
|
@ -25,9 +25,9 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
authorizedKeysPath := path.Join(sshUser.HomeDir, authorizeKeysPath)
|
authorizedKeysPath := path.Join(sshUser.HomeDir, authorizedKeysDir)
|
||||||
if _, err := os.Stat(authorizedKeysPath); os.IsNotExist(err) {
|
if _, err := os.Stat(authorizedKeysPath); os.IsNotExist(err) {
|
||||||
s.logger.Debugf("authorized_keys file %s not found", authorizeKeysPath)
|
s.logger.Debugf("authorized_keys file %s not found", authorizedKeysPath)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,11 +38,12 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
|
||||||
}
|
}
|
||||||
|
|
||||||
for len(authorizedKeysBytes) > 0 {
|
for len(authorizedKeysBytes) > 0 {
|
||||||
|
|
||||||
// Skips invalid keys. Returns error if no valid keys remain.
|
// Skips invalid keys. Returns error if no valid keys remain.
|
||||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||||
authorizedKeysBytes = rest
|
authorizedKeysBytes = rest
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.WithError(err).Errorf("No valid keys found in %s", authorizeKeysPath)
|
s.logger.Errorf("Invalid key(s) found in %s", authorizedKeysPath)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,7 +52,7 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.logger.Debugf("Matching public key not found in %s", authorizeKeysPath)
|
s.logger.Debugf("Matching public key not found in %s", authorizedKeysPath)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ const (
|
||||||
var logger, hook = test.NewNullLogger()
|
var logger, hook = test.NewNullLogger()
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
authorizeKeysPath = testUserKeyFilename
|
authorizedKeysDir = testUserKeyFilename
|
||||||
logger.SetLevel(logrus.DebugLevel)
|
logger.SetLevel(logrus.DebugLevel)
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"os/user"
|
"os/user"
|
||||||
"strconv"
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/creack/pty"
|
"github.com/creack/pty"
|
||||||
|
@ -27,7 +28,7 @@ type SSHServer struct {
|
||||||
getUserFunc func(string) (*User, error)
|
getUserFunc func(string) (*User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool) (*SSHServer, error) {
|
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -37,7 +38,7 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLi
|
||||||
}
|
}
|
||||||
|
|
||||||
sshServer := SSHServer{
|
sshServer := SSHServer{
|
||||||
Server: ssh.Server{Addr: address},
|
Server: ssh.Server{Addr: address, MaxTimeout: maxTimeout, IdleTimeout: idleTimeout},
|
||||||
logger: logger,
|
logger: logger,
|
||||||
shutdownC: shutdownC,
|
shutdownC: shutdownC,
|
||||||
getUserFunc: lookupUser,
|
getUserFunc: lookupUser,
|
||||||
|
@ -76,7 +77,6 @@ func (s *SSHServer) Start() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHServer) connectionHandler(session ssh.Session) {
|
func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||||
|
|
||||||
// Get uid and gid of user attempting to login
|
// Get uid and gid of user attempting to login
|
||||||
sshUser, ok := session.Context().Value("sshUser").(*User)
|
sshUser, ok := session.Context().Value("sshUser").(*User)
|
||||||
if !ok || sshUser == nil {
|
if !ok || sshUser == nil {
|
||||||
|
|
|
@ -6,11 +6,12 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SSHServer struct{}
|
type SSHServer struct{}
|
||||||
|
|
||||||
func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool) (*SSHServer, error) {
|
func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool, _, _ time.Duration) (*SSHServer, error) {
|
||||||
return nil, errors.New("cloudflared ssh server is not supported on windows")
|
return nil, errors.New("cloudflared ssh server is not supported on windows")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue