diff --git a/auth.go b/auth.go index f62b328..8fe59e3 100644 --- a/auth.go +++ b/auth.go @@ -26,15 +26,25 @@ import ( passlib "gopkg.in/hlandau/passlib.v1" ) +type AuthCtx struct { + reader func(string) ([]byte, error) // eg. ioutil.ReadFile() + userlookup func(string) (*user.User, error) // eg. os/user.Lookup() +} + +func NewAuthCtx(/*reader func(string) ([]byte, error), userlookup func(string) (*user.User, error)*/) (ret *AuthCtx) { + ret = &AuthCtx{ioutil.ReadFile, user.Lookup} + return +} + // --------- System passwd/shadow auth routine(s) -------------- // Verify a password against system standard shadow file // Note auxilliary fields for expiry policy are *not* inspected. -func VerifyPass(reader func(string) ([]byte, error), user, password string) (bool, error) { - if reader == nil { - reader = ioutil.ReadFile // dependency injection hides that this is required +func VerifyPass(ctx *AuthCtx, user, password string) (bool, error) { + if ctx.reader == nil { + ctx.reader = ioutil.ReadFile // dependency injection hides that this is required } passlib.UseDefaults(passlib.Defaults20180601) - pwFileData, e := reader("/etc/shadow") + pwFileData, e := ctx.reader("/etc/shadow") if e != nil { return false, e } @@ -73,11 +83,14 @@ func VerifyPass(reader func(string) ([]byte, error), user, password string) (boo // This checks /etc/xs.passwd for auth info, and system /etc/passwd // to cross-check the user actually exists. // nolint: gocyclo -func AuthUserByPasswd(reader func(string) ([]byte, error), userlookup func(string) (*user.User, error), username string, auth string, fname string) (valid bool, allowedCmds string) { - if reader == nil { - reader = ioutil.ReadFile // dependency injection hides that this is required +func AuthUserByPasswd(ctx *AuthCtx, username string, auth string, fname string) (valid bool, allowedCmds string) { + if ctx.reader == nil { + ctx.reader = ioutil.ReadFile // dependency injection hides that this is required } - b, e := reader(fname) // nolint: gosec + if ctx.userlookup == nil { + ctx.userlookup = user.Lookup // again for dependency injection as dep is now hidden + } + b, e := ctx.reader(fname) // nolint: gosec if e != nil { valid = false log.Printf("ERROR: Cannot read %s!\n", fname) @@ -121,7 +134,7 @@ func AuthUserByPasswd(reader func(string) ([]byte, error), userlookup func(strin r = nil runtime.GC() - _, userErr := userlookup(username) + _, userErr := ctx.userlookup(username) if userErr != nil { valid = false } @@ -135,21 +148,21 @@ func AuthUserByPasswd(reader func(string) ([]byte, error), userlookup func(strin // via the -g option. // The function also check system /etc/passwd to cross-check the user // actually exists. -func AuthUserByToken(reader func(string) ([]byte, error), userlookup func(string) (*user.User, error), username string, connhostname string, auth string) (valid bool) { - if reader == nil { - reader = ioutil.ReadFile // dependency injection hides that this is required +func AuthUserByToken(ctx *AuthCtx, username string, connhostname string, auth string) (valid bool) { + if ctx.reader == nil { + ctx.reader = ioutil.ReadFile // dependency injection hides that this is required } - if userlookup == nil { - userlookup = user.Lookup // again for dependency injection as dep is now hidden + if ctx.userlookup == nil { + ctx.userlookup = user.Lookup // again for dependency injection as dep is now hidden } auth = strings.TrimSpace(auth) - u, ue := userlookup(username) + u, ue := ctx.userlookup(username) if ue != nil { return false } - b, e := reader(fmt.Sprintf("%s/.xs_id", u.HomeDir)) + b, e := ctx.reader(fmt.Sprintf("%s/.xs_id", u.HomeDir)) if e != nil { log.Printf("INFO: Cannot read %s/.xs_id\n", u.HomeDir) return false @@ -176,7 +189,7 @@ func AuthUserByToken(reader func(string) ([]byte, error), userlookup func(string break } } - _, userErr := userlookup(username) + _, userErr := ctx.userlookup(username) if userErr != nil { valid = false } diff --git a/auth_test.go b/auth_test.go index 2d90e33..3672607 100644 --- a/auth_test.go +++ b/auth_test.go @@ -31,6 +31,11 @@ disableduser:!:18310::::::` readfile_arg_f string ) +func newMockAuthCtx(reader func(string) ([]byte, error), userlookup func(string) (*user.User, error)) (ret *AuthCtx) { + ret = &AuthCtx{reader, userlookup} + return +} + func _mock_user_Lookup(username string) (*user.User, error) { username = userlookup_arg_u if username == "baduser" { @@ -64,8 +69,9 @@ func _mock_ioutil_ReadFileHasError(f string) ([]byte, error) { func TestVerifyPass(t *testing.T) { readfile_arg_f = "/etc/shadow" + ctx := newMockAuthCtx(_mock_ioutil_ReadFile, nil) for idx, rec := range testGoodUsers { - stat, e := VerifyPass(_mock_ioutil_ReadFile, rec.user, rec.passwd) + stat, e := VerifyPass(ctx, rec.user, rec.passwd) if rec.good && (!stat || e != nil) { t.Fatalf("failed %d\n", idx) } @@ -73,21 +79,24 @@ func TestVerifyPass(t *testing.T) { } func TestVerifyPassFailsOnEmptyFile(t *testing.T) { - stat, e := VerifyPass(_mock_ioutil_ReadFileEmpty, "johndoe", "sompass") + ctx := newMockAuthCtx(_mock_ioutil_ReadFileEmpty, nil) + stat, e := VerifyPass(ctx, "johndoe", "somepass") if stat || (e == nil) { t.Fatal("failed to fail w/empty file") } } func TestVerifyPassFailsOnFileError(t *testing.T) { - stat, e := VerifyPass(_mock_ioutil_ReadFileEmpty, "johndoe", "somepass") + ctx := newMockAuthCtx(_mock_ioutil_ReadFileEmpty, nil) + stat, e := VerifyPass(ctx, "johndoe", "somepass") if stat || (e == nil) { t.Fatal("failed to fail on ioutil.ReadFile error") } } func TestVerifyPassFailsOnDisabledEntry(t *testing.T) { - stat, e := VerifyPass(_mock_ioutil_ReadFileEmpty, "disableduser", "!") + ctx := newMockAuthCtx(_mock_ioutil_ReadFileEmpty, nil) + stat, e := VerifyPass(ctx, "disableduser", "!") if stat || (e == nil) { t.Fatal("failed to fail on disabled user entry") } @@ -96,38 +105,43 @@ func TestVerifyPassFailsOnDisabledEntry(t *testing.T) { //// func TestAuthUserByTokenFailsOnMissingEntryForHost(t *testing.T) { - stat := AuthUserByToken(_mock_ioutil_ReadFile, _mock_user_Lookup, "johndoe", "hostZ", "abcdefg") + ctx := newMockAuthCtx(_mock_ioutil_ReadFile, _mock_user_Lookup) + stat := AuthUserByToken(ctx, "johndoe", "hostZ", "abcdefg") if stat { t.Fatal("failed to fail on missing/mismatched host entry") } } func TestAuthUserByTokenFailsOnMissingEntryForUser(t *testing.T) { - stat := AuthUserByToken(_mock_ioutil_ReadFile, _mock_user_Lookup, "unkuser", "hostA", "abcdefg") + ctx := newMockAuthCtx(_mock_ioutil_ReadFile, _mock_user_Lookup) + stat := AuthUserByToken(ctx, "unkuser", "hostA", "abcdefg") if stat { t.Fatal("failed to fail on wrong user") } } func TestAuthUserByTokenFailsOnUserLookupFailure(t *testing.T) { + ctx := newMockAuthCtx(_mock_ioutil_ReadFile, _mock_user_Lookup) userlookup_arg_u = "baduser" - stat := AuthUserByToken(_mock_ioutil_ReadFile, _mock_user_Lookup, "johndoe", "hostA", "abcdefg") + stat := AuthUserByToken(ctx, "johndoe", "hostA", "abcdefg") if stat { t.Fatal("failed to fail with bad return from user.Lookup()") } } func TestAuthUserByTokenFailsOnMismatchedTokenForUser(t *testing.T) { - stat := AuthUserByToken(_mock_ioutil_ReadFile, _mock_user_Lookup, "johndoe", "hostA", "badtoken") + ctx := newMockAuthCtx(_mock_ioutil_ReadFile, _mock_user_Lookup) + stat := AuthUserByToken(ctx, "johndoe", "hostA", "badtoken") if stat { t.Fatal("failed to fail with valid user, bad token") } } func TestAuthUserByTokenSucceedsWithMatchedUserAndToken(t *testing.T) { + ctx := newMockAuthCtx(_mock_ioutil_ReadFile, _mock_user_Lookup) userlookup_arg_u = "johndoe" readfile_arg_f = "/.xs_id" - stat := AuthUserByToken(_mock_ioutil_ReadFile, _mock_user_Lookup, userlookup_arg_u, "hostA", "hostA:abcdefg") + stat := AuthUserByToken(ctx, userlookup_arg_u, "hostA", "hostA:abcdefg") if !stat { t.Fatal("failed with valid user and token") } diff --git a/xsd/xsd.go b/xsd/xsd.go index 528d094..e46c2a7 100755 --- a/xsd/xsd.go +++ b/xsd/xsd.go @@ -708,14 +708,14 @@ func main() { var valid bool var allowedCmds string // Currently unused - if xs.AuthUserByToken(ioutil.ReadFile, user.Lookup, string(rec.Who()), string(rec.ConnHost()), string(rec.AuthCookie(true))) { + if xs.AuthUserByToken(xs.NewAuthCtx(), string(rec.Who()), string(rec.ConnHost()), string(rec.AuthCookie(true))) { valid = true } else { if useSystemPasswd { //var passErr error - valid, _ /*passErr*/ = xs.VerifyPass(ioutil.ReadFile, string(rec.Who()), string(rec.AuthCookie(true))) + valid, _ /*passErr*/ = xs.VerifyPass(xs.NewAuthCtx(), string(rec.Who()), string(rec.AuthCookie(true))) } else { - valid, allowedCmds = xs.AuthUserByPasswd(ioutil.ReadFile, user.Lookup, string(rec.Who()), string(rec.AuthCookie(true)), "/etc/xs.passwd") + valid, allowedCmds = xs.AuthUserByPasswd(xs.NewAuthCtx(), string(rec.Who()), string(rec.AuthCookie(true)), "/etc/xs.passwd") } }