parent
b3bcce97da
commit
dbde3870da
@ -1,108 +0,0 @@
|
||||
//+build !windows
|
||||
|
||||
package sshserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
systemConfigPath = "/etc/cloudflared/"
|
||||
authorizedKeysDir = ".cloudflared/authorized_keys"
|
||||
)
|
||||
|
||||
func (s *SSHServer) configureAuthentication() {
|
||||
caCert, err := getCACert()
|
||||
if err != nil {
|
||||
s.logger.Info(err)
|
||||
}
|
||||
s.caCert = caCert
|
||||
s.PublicKeyHandler = s.authenticationHandler
|
||||
}
|
||||
|
||||
// authenticationHandler is a callback that returns true if the user attempting to connect is authenticated.
|
||||
func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
sshUser, err := lookupUser(ctx.User())
|
||||
if err != nil {
|
||||
s.logger.Debugf("Invalid user: %s", ctx.User())
|
||||
return false
|
||||
}
|
||||
ctx.SetValue("sshUser", sshUser)
|
||||
|
||||
cert, ok := key.(*gossh.Certificate)
|
||||
if !ok {
|
||||
return s.authorizedKeyHandler(ctx, key)
|
||||
}
|
||||
return s.shortLivedCertHandler(ctx, cert)
|
||||
}
|
||||
|
||||
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
sshUser, ok := ctx.Value("sshUser").(*User)
|
||||
if !ok {
|
||||
s.logger.Error("Failed to retrieve user from context")
|
||||
return false
|
||||
}
|
||||
|
||||
authorizedKeysPath := path.Join(sshUser.HomeDir, authorizedKeysDir)
|
||||
if _, err := os.Stat(authorizedKeysPath); os.IsNotExist(err) {
|
||||
s.logger.Debugf("authorized_keys file %s not found", authorizedKeysPath)
|
||||
return false
|
||||
}
|
||||
|
||||
authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysPath)
|
||||
if err != nil {
|
||||
s.logger.WithError(err).Errorf("Failed to load authorized_keys %s", authorizedKeysPath)
|
||||
return false
|
||||
}
|
||||
|
||||
for len(authorizedKeysBytes) > 0 {
|
||||
// Skips invalid keys. Returns error if no valid keys remain.
|
||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||
authorizedKeysBytes = rest
|
||||
if err != nil {
|
||||
s.logger.Errorf("Invalid key(s) found in %s", authorizedKeysPath)
|
||||
return false
|
||||
}
|
||||
|
||||
if ssh.KeysEqual(pubKey, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
s.logger.Debugf("Matching public key not found in %s", authorizedKeysPath)
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certificate) bool {
|
||||
if !ssh.KeysEqual(s.caCert, cert.SignatureKey) {
|
||||
s.logger.Debug("CA certificate does not match user certificate signer")
|
||||
return false
|
||||
}
|
||||
|
||||
checker := gossh.CertChecker{}
|
||||
if err := checker.CheckCert(ctx.User(), cert); err != nil {
|
||||
s.logger.Debug(err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func getCACert() (ssh.PublicKey, error) {
|
||||
caCertPath := path.Join(systemConfigPath, "ca.pub")
|
||||
caCertBytes, err := ioutil.ReadFile(caCertPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("Failed to load CA certificate %s", caCertPath))
|
||||
}
|
||||
caCert, _, _, _, err := ssh.ParseAuthorizedKey(caCertBytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to parse CA Certificate")
|
||||
}
|
||||
|
||||
return caCert, nil
|
||||
}
|
@ -1,185 +0,0 @@
|
||||
//+build !windows
|
||||
|
||||
package sshserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"path"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudflare/cloudflared/log"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
testDir = "testdata"
|
||||
testUserKeyFilename = "id_rsa.pub"
|
||||
testCAFilename = "ca.pub"
|
||||
testOtherCAFilename = "other_ca.pub"
|
||||
testUserCertFilename = "id_rsa-cert.pub"
|
||||
)
|
||||
|
||||
var (
|
||||
logger, hook = test.NewNullLogger()
|
||||
mockUser = &User{Username: "testUser", HomeDir: testDir}
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
authorizedKeysDir = testUserKeyFilename
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestPublicKeyAuth_Success(t *testing.T) {
|
||||
context, cancel := newMockContext(mockUser)
|
||||
defer cancel()
|
||||
|
||||
sshServer := SSHServer{logger: logger}
|
||||
|
||||
pubKey := getKey(t, testUserKeyFilename)
|
||||
assert.True(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||
}
|
||||
|
||||
func TestPublicKeyAuth_MissingKey(t *testing.T) {
|
||||
context, cancel := newMockContext(mockUser)
|
||||
defer cancel()
|
||||
|
||||
sshServer := SSHServer{logger: logger}
|
||||
|
||||
pubKey := getKey(t, testOtherCAFilename)
|
||||
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||
assert.Contains(t, hook.LastEntry().Message, "Matching public key not found in")
|
||||
}
|
||||
|
||||
func TestPublicKeyAuth_InvalidUser(t *testing.T) {
|
||||
context, cancel := newMockContext(&User{Username: "notAUser"})
|
||||
defer cancel()
|
||||
|
||||
sshServer := SSHServer{logger: logger}
|
||||
|
||||
pubKey := getKey(t, testUserKeyFilename)
|
||||
assert.False(t, sshServer.authenticationHandler(context, pubKey))
|
||||
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
||||
}
|
||||
|
||||
func TestPublicKeyAuth_MissingFile(t *testing.T) {
|
||||
tempUser, err := user.Current()
|
||||
require.Nil(t, err)
|
||||
currentUser, err := lookupUser(tempUser.Username)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, err)
|
||||
context, cancel := newMockContext(currentUser)
|
||||
defer cancel()
|
||||
|
||||
sshServer := SSHServer{Server: ssh.Server{}, logger: logger}
|
||||
|
||||
pubKey := getKey(t, testUserKeyFilename)
|
||||
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||
assert.Contains(t, hook.LastEntry().Message, "not found")
|
||||
}
|
||||
|
||||
func TestShortLivedCerts_Success(t *testing.T) {
|
||||
context, cancel := newMockContext(mockUser)
|
||||
defer cancel()
|
||||
|
||||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert}
|
||||
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.True(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
}
|
||||
|
||||
func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
||||
context, cancel := newMockContext(mockUser)
|
||||
defer cancel()
|
||||
|
||||
caCert := getKey(t, testOtherCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert}
|
||||
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
|
||||
}
|
||||
|
||||
func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
||||
context, cancel := newMockContext(&User{Username: "NotAUser"})
|
||||
defer cancel()
|
||||
|
||||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert}
|
||||
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
|
||||
}
|
||||
|
||||
func getKey(t *testing.T, filename string) ssh.PublicKey {
|
||||
path := path.Join(testDir, filename)
|
||||
bytes, err := ioutil.ReadFile(path)
|
||||
require.Nil(t, err)
|
||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(bytes)
|
||||
require.Nil(t, err)
|
||||
return pubKey
|
||||
}
|
||||
|
||||
type mockSSHContext struct {
|
||||
context.Context
|
||||
*sync.Mutex
|
||||
}
|
||||
|
||||
func newMockContext(user *User) (*mockSSHContext, context.CancelFunc) {
|
||||
innerCtx, cancel := context.WithCancel(context.Background())
|
||||
mockCtx := &mockSSHContext{innerCtx, &sync.Mutex{}}
|
||||
mockCtx.SetValue("sshUser", user)
|
||||
|
||||
// This naming is confusing but we cant change it because this mocks the SSHContext struct in gliderlabs/ssh
|
||||
mockCtx.SetValue("user", user.Username)
|
||||
return mockCtx, cancel
|
||||
}
|
||||
|
||||
func (ctx *mockSSHContext) SetValue(key, value interface{}) {
|
||||
ctx.Context = context.WithValue(ctx.Context, key, value)
|
||||
}
|
||||
|
||||
func (ctx *mockSSHContext) User() string {
|
||||
return ctx.Value("user").(string)
|
||||
}
|
||||
|
||||
func (ctx *mockSSHContext) SessionID() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ctx *mockSSHContext) ClientVersion() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ctx *mockSSHContext) ServerVersion() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ctx *mockSSHContext) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ctx *mockSSHContext) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ctx *mockSSHContext) Permissions() *ssh.Permissions {
|
||||
return nil
|
||||
}
|
@ -1,226 +0,0 @@
|
||||
// Taken from https://github.com/golang/go/blob/ad644d2e86bab85787879d41c2d2aebbd7c57db8/src/os/user/user.go
|
||||
// and modified to return login shell in User struct. cloudflared requires cgo for compilation because of this addition.
|
||||
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris
|
||||
// +build cgo,!osusergo
|
||||
|
||||
package sshserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
/*
|
||||
#cgo solaris CFLAGS: -D_POSIX_PTHREAD_SEMANTICS
|
||||
#include <unistd.h>
|
||||
#include <sys/types.h>
|
||||
#include <pwd.h>
|
||||
#include <grp.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
static int mygetpwuid_r(int uid, struct passwd *pwd,
|
||||
char *buf, size_t buflen, struct passwd **result) {
|
||||
return getpwuid_r(uid, pwd, buf, buflen, result);
|
||||
}
|
||||
|
||||
static int mygetpwnam_r(const char *name, struct passwd *pwd,
|
||||
char *buf, size_t buflen, struct passwd **result) {
|
||||
return getpwnam_r(name, pwd, buf, buflen, result);
|
||||
}
|
||||
|
||||
static int mygetgrgid_r(int gid, struct group *grp,
|
||||
char *buf, size_t buflen, struct group **result) {
|
||||
return getgrgid_r(gid, grp, buf, buflen, result);
|
||||
}
|
||||
|
||||
static int mygetgrnam_r(const char *name, struct group *grp,
|
||||
char *buf, size_t buflen, struct group **result) {
|
||||
return getgrnam_r(name, grp, buf, buflen, result);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
type UnknownUserIdError int
|
||||
|
||||
func (e UnknownUserIdError) Error() string {
|
||||
return "user: unknown userid " + strconv.Itoa(int(e))
|
||||
}
|
||||
|
||||
// UnknownUserError is returned by Lookup when
|
||||
// a user cannot be found.
|
||||
type UnknownUserError string
|
||||
|
||||
func (e UnknownUserError) Error() string {
|
||||
return "user: unknown user " + string(e)
|
||||
}
|
||||
|
||||
// UnknownGroupIdError is returned by LookupGroupId when
|
||||
// a group cannot be found.
|
||||
type UnknownGroupIdError string
|
||||
|
||||
func (e UnknownGroupIdError) Error() string {
|
||||
return "group: unknown groupid " + string(e)
|
||||
}
|
||||
|
||||
// UnknownGroupError is returned by LookupGroup when
|
||||
// a group cannot be found.
|
||||
type UnknownGroupError string
|
||||
|
||||
func (e UnknownGroupError) Error() string {
|
||||
return "group: unknown group " + string(e)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
// Uid is the user ID.
|
||||
// On POSIX systems, this is a decimal number representing the uid.
|
||||
// On Windows, this is a security identifier (SID) in a string format.
|
||||
// On Plan 9, this is the contents of /dev/user.
|
||||
Uid string
|
||||
// Gid is the primary group ID.
|
||||
// On POSIX systems, this is a decimal number representing the gid.
|
||||
// On Windows, this is a SID in a string format.
|
||||
// On Plan 9, this is the contents of /dev/user.
|
||||
Gid string
|
||||
// Username is the login name.
|
||||
Username string
|
||||
// Name is the user's real or display name.
|
||||
// It might be blank.
|
||||
// On POSIX systems, this is the first (or only) entry in the GECOS field
|
||||
// list.
|
||||
// On Windows, this is the user's display name.
|
||||
// On Plan 9, this is the contents of /dev/user.
|
||||
Name string
|
||||
// HomeDir is the path to the user's home directory (if they have one).
|
||||
HomeDir string
|
||||
|
||||
/****************** Begin added code ******************/
|
||||
// Login shell
|
||||
Shell string
|
||||
/****************** End added code ******************/
|
||||
}
|
||||
|
||||
func lookupUser(username string) (*User, error) {
|
||||
var pwd C.struct_passwd
|
||||
var result *C.struct_passwd
|
||||
nameC := make([]byte, len(username)+1)
|
||||
copy(nameC, username)
|
||||
|
||||
buf := alloc(userBuffer)
|
||||
defer buf.free()
|
||||
|
||||
err := retryWithBuffer(buf, func() syscall.Errno {
|
||||
// mygetpwnam_r is a wrapper around getpwnam_r to avoid
|
||||
// passing a size_t to getpwnam_r, because for unknown
|
||||
// reasons passing a size_t to getpwnam_r doesn't work on
|
||||
// Solaris.
|
||||
return syscall.Errno(C.mygetpwnam_r((*C.char)(unsafe.Pointer(&nameC[0])),
|
||||
&pwd,
|
||||
(*C.char)(buf.ptr),
|
||||
C.size_t(buf.size),
|
||||
&result))
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("user: lookup username %s: %v", username, err)
|
||||
}
|
||||
if result == nil {
|
||||
return nil, UnknownUserError(username)
|
||||
}
|
||||
return buildUser(&pwd), err
|
||||
}
|
||||
|
||||
func buildUser(pwd *C.struct_passwd) *User {
|
||||
u := &User{
|
||||
Uid: strconv.FormatUint(uint64(pwd.pw_uid), 10),
|
||||
Gid: strconv.FormatUint(uint64(pwd.pw_gid), 10),
|
||||
Username: C.GoString(pwd.pw_name),
|
||||
Name: C.GoString(pwd.pw_gecos),
|
||||
HomeDir: C.GoString(pwd.pw_dir),
|
||||
/****************** Begin added code ******************/
|
||||
Shell: C.GoString(pwd.pw_shell),
|
||||
/****************** End added code ******************/
|
||||
}
|
||||
// The pw_gecos field isn't quite standardized. Some docs
|
||||
// say: "It is expected to be a comma separated list of
|
||||
// personal data where the first item is the full name of the
|
||||
// user."
|
||||
if i := strings.Index(u.Name, ","); i >= 0 {
|
||||
u.Name = u.Name[:i]
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
type bufferKind C.int
|
||||
|
||||
const (
|
||||
userBuffer = bufferKind(C._SC_GETPW_R_SIZE_MAX)
|
||||
)
|
||||
|
||||
func (k bufferKind) initialSize() C.size_t {
|
||||
sz := C.sysconf(C.int(k))
|
||||
if sz == -1 {
|
||||
// DragonFly and FreeBSD do not have _SC_GETPW_R_SIZE_MAX.
|
||||
// Additionally, not all Linux systems have it, either. For
|
||||
// example, the musl libc returns -1.
|
||||
return 1024
|
||||
}
|
||||
if !isSizeReasonable(int64(sz)) {
|
||||
// Truncate. If this truly isn't enough, retryWithBuffer will error on the first run.
|
||||
return maxBufferSize
|
||||
}
|
||||
return C.size_t(sz)
|
||||
}
|
||||
|
||||
type memBuffer struct {
|
||||
ptr unsafe.Pointer
|
||||
size C.size_t
|
||||
}
|
||||
|
||||
func alloc(kind bufferKind) *memBuffer {
|
||||
sz := kind.initialSize()
|
||||
return &memBuffer{
|
||||
ptr: C.malloc(sz),
|
||||
size: sz,
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *memBuffer) resize(newSize C.size_t) {
|
||||
mb.ptr = C.realloc(mb.ptr, newSize)
|
||||
mb.size = newSize
|
||||
}
|
||||
|
||||
func (mb *memBuffer) free() {
|
||||
C.free(mb.ptr)
|
||||
}
|
||||
|
||||
// retryWithBuffer repeatedly calls f(), increasing the size of the
|
||||
// buffer each time, until f succeeds, fails with a non-ERANGE error,
|
||||
// or the buffer exceeds a reasonable limit.
|
||||
func retryWithBuffer(buf *memBuffer, f func() syscall.Errno) error {
|
||||
for {
|
||||
errno := f()
|
||||
if errno == 0 {
|
||||
return nil
|
||||
} else if errno != syscall.ERANGE {
|
||||
return errno
|
||||
}
|
||||
newSize := buf.size * 2
|
||||
if !isSizeReasonable(int64(newSize)) {
|
||||
return fmt.Errorf("internal buffer exceeds %d bytes", maxBufferSize)
|
||||
}
|
||||
buf.resize(newSize)
|
||||
}
|
||||
}
|
||||
|
||||
const maxBufferSize = 1 << 20
|
||||
|
||||
func isSizeReasonable(sz int64) bool {
|
||||
return sz > 0 && sz <= maxBufferSize
|
||||
}
|
Loading…
Reference in new issue