From 0675278fe25fc72234f82d5309f2fd718af8a551 Mon Sep 17 00:00:00 2001 From: Amolith Date: Fri, 22 Dec 2023 17:59:19 -0500 Subject: [PATCH] Implement migration system, add first migration Thank you for the help Chris! https://github.com/whereswaldon --- cmd/willow.go | 20 ++--- db/db.go | 40 +++++---- db/migrations.go | 133 ++++++++++++++++++++++++++++++ db/posthooks.go | 57 +++++++++++++ db/sql/1_add_project_ids.down.sql | 26 ++++++ db/sql/1_add_project_ids.up.sql | 14 ++++ project/project.go | 3 + 7 files changed, 267 insertions(+), 26 deletions(-) create mode 100644 db/migrations.go create mode 100644 db/posthooks.go create mode 100644 db/sql/1_add_project_ids.down.sql create mode 100644 db/sql/1_add_project_ids.up.sql diff --git a/cmd/willow.go b/cmd/willow.go index 46c83d0..0c3c7ea 100644 --- a/cmd/willow.go +++ b/cmd/willow.go @@ -63,18 +63,18 @@ func main() { os.Exit(1) } - fmt.Println("Verifying database schema") - err = db.VerifySchema(dbConn) + fmt.Println("Checking whether database needs initialising") + err = db.InitialiseDatabase(dbConn) if err != nil { - fmt.Println("Error verifying database schema:", err) - fmt.Println("Attempting to load schema") - err = db.LoadSchema(dbConn) - if err != nil { - fmt.Println("Error loading schema:", err) - os.Exit(1) - } + fmt.Println("Error initialising database:", err) + os.Exit(1) + } + fmt.Println("Checking whether there are pending migrations") + err = db.Migrate(dbConn) + if err != nil { + fmt.Println("Error migrating database schema:", err) + os.Exit(1) } - fmt.Println("Database schema verified") if len(*flagAddUser) > 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 { createUser(dbConn, *flagAddUser) diff --git a/db/db.go b/db/db.go index 8fc694a..e5c9d60 100644 --- a/db/db.go +++ b/db/db.go @@ -6,46 +6,54 @@ package db import ( "database/sql" - "embed" + _ "embed" _ "modernc.org/sqlite" ) -// Embed the schema into the binary -// -//go:embed sql -var embeddedSQL embed.FS +//go:embed sql/schema.sql +var schema string // Open opens a connection to the SQLite database func Open(dbPath string) (*sql.DB, error) { return sql.Open("sqlite", dbPath) } -func VerifySchema(dbConn *sql.DB) error { +// VerifySchema checks whether the schema has been initalised and initialises it +// if not +func InitialiseDatabase(dbConn *sql.DB) error { + var name string + err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'").Scan(&name) + if err == nil { + return nil + } + tables := []string{ "users", "sessions", "projects", + "releases", } for _, table := range tables { name := "" - err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name) + err := dbConn.QueryRow( + "SELECT name FROM sqlite_master WHERE type='table' AND name=@table", + sql.Named("table", table), + ).Scan(&name) if err != nil { - return err + if err = loadSchema(dbConn); err != nil { + return err + } } } return nil } -// LoadSchema loads the schema into the database -func LoadSchema(dbConn *sql.DB) error { - schema, err := embeddedSQL.ReadFile("sql/schema.sql") - if err != nil { +// loadSchema loads the initial schema into the database +func loadSchema(dbConn *sql.DB) error { + if _, err := dbConn.Exec(schema); err != nil { return err } - - _, err = dbConn.Exec(string(schema)) - - return err + return nil } diff --git a/db/migrations.go b/db/migrations.go new file mode 100644 index 0000000..9b3c369 --- /dev/null +++ b/db/migrations.go @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: Chris Waldon +// +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "database/sql" + _ "embed" + "fmt" +) + +type migration struct { + upQuery string + downQuery string + postHook func(*sql.Tx) error +} + +var ( + //go:embed sql/1_add_project_ids.up.sql + migration1Up string + //go:embed sql/1_add_project_ids.down.sql + migration1Down string +) + +var migrations = [...]migration{ + 0: { + upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool); + INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`, + downQuery: `DROP TABLE schema_migrations;`, + }, + 1: { + upQuery: migration1Up, + downQuery: migration1Down, + postHook: generateAndInsertProjectIDs, + }, +} + +// Migrate runs all pending migrations +func Migrate(db *sql.DB) error { + version := getSchemaVersion(db) + for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ { + if err := runMigration(db, nextMigration); err != nil { + return fmt.Errorf("migrations failed: %w", err) + } + if version := getSchemaVersion(db); version != nextMigration { + return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version) + } + } + return nil +} + +// runMigration runs a single migration inside a transaction, updates the schema +// version and commits the transaction if successful, and rolls back the +// transaction if unsuccessful. +func runMigration(db *sql.DB, migrationIdx int) (err error) { + current := migrations[migrationIdx] + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err) + } + defer func() { + if err == nil { + err = tx.Commit() + } + if err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err) + } + } + }() + if len(current.upQuery) > 0 { + if _, err := tx.Exec(current.upQuery); err != nil { + return fmt.Errorf("failed running migration %d: %w", migrationIdx, err) + } + } + if current.postHook != nil { + if err := current.postHook(tx); err != nil { + return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err) + } + } + return updateSchemaVersion(tx, migrationIdx) +} + +// undoMigration rolls the single most recent migration back inside a +// transaction, updates the schema version and commits the transaction if +// successful, and rolls back the transaction if unsuccessful. +// +//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34) +func undoMigration(db *sql.DB, migrationIdx int) (err error) { + current := migrations[migrationIdx] + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err) + } + defer func() { + if err == nil { + err = tx.Commit() + } + if err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err) + } + } + }() + if len(current.downQuery) > 0 { + if _, err := tx.Exec(current.downQuery); err != nil { + return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err) + } + } + return updateSchemaVersion(tx, migrationIdx-1) +} + +// getSchemaVersion returns the schema version from the database +func getSchemaVersion(db *sql.DB) int { + row := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`) + var version int + if err := row.Scan(&version); err != nil { + version = -1 + } + return version +} + +// updateSchemaVersion sets the version to the provided int +func updateSchemaVersion(tx *sql.Tx, version int) error { + if version < 0 { + // Do not try to use the schema_migrations table in a schema version where it doesn't exist + return nil + } + _, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version)) + return err +} diff --git a/db/posthooks.go b/db/posthooks.go new file mode 100644 index 0000000..c5be805 --- /dev/null +++ b/db/posthooks.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: Amolith +// +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "crypto/sha256" + "database/sql" + "fmt" +) + +// generateAndInsertProjectIDs runs during migration 1, fetches all rows from +// projects_tmp, loops through the rows generating a repeatable ID for each +// project, and inserting it into the new table along with the data from the old +// table. +func generateAndInsertProjectIDs(tx *sql.Tx) error { + // Loop through projects_tmp, generate a project_id for each, and insert + // into projects + rows, err := tx.Query("SELECT url, name, forge, version, created_at FROM projects_tmp") + if err != nil { + return fmt.Errorf("failed to list projects in projects_tmp: %w", err) + } + defer rows.Close() + + for rows.Next() { + var ( + url string + name string + forge string + version string + created_at string + ) + if err := rows.Scan(&url, &name, &forge, &version, &created_at); err != nil { + return fmt.Errorf("failed to scan row from projects_tmp: %w", err) + } + id := fmt.Sprintf("%x", sha256.Sum256([]byte(url+name+forge+created_at))) + _, err = tx.Exec( + "INSERT INTO projects (id, url, name, forge, version, created_at) VALUES (@id, @url, @name, @forge, @version, @created_at)", + sql.Named("id", id), + sql.Named("url", url), + sql.Named("name", name), + sql.Named("forge", forge), + sql.Named("version", version), + sql.Named("created_at", created_at), + ) + if err != nil { + return fmt.Errorf("failed to insert project into projects: %w", err) + } + } + + if _, err := tx.Exec("DROP TABLE projects_tmp"); err != nil { + return fmt.Errorf("failed to drop projects_tmp: %w", err) + } + + return nil +} diff --git a/db/sql/1_add_project_ids.down.sql b/db/sql/1_add_project_ids.down.sql new file mode 100644 index 0000000..18d04cc --- /dev/null +++ b/db/sql/1_add_project_ids.down.sql @@ -0,0 +1,26 @@ +-- SPDX-FileCopyrightText: Amolith +-- +-- SPDX-License-Identifier: CC0-1.0 + +--ALTER TABLE projects RENAME TO projects_tmp; -- noqa + +ALTER TABLE projects RENAME TO projects_tmp; + +CREATE TABLE IF NOT EXISTS projects ( + url TEXT NOT NULL PRIMARY KEY, + name TEXT NOT NULL, + forge TEXT NOT NULL, + version TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +INSERT INTO projects (url, name, forge, version, created_at) +SELECT + url, + name, + forge, + version, + created_at +FROM projects_tmp; + +DROP TABLE projects_tmp; diff --git a/db/sql/1_add_project_ids.up.sql b/db/sql/1_add_project_ids.up.sql new file mode 100644 index 0000000..6821563 --- /dev/null +++ b/db/sql/1_add_project_ids.up.sql @@ -0,0 +1,14 @@ +-- SPDX-FileCopyrightText: Amolith +-- +-- SPDX-License-Identifier: CC0-1.0 + +ALTER TABLE projects RENAME TO projects_tmp; + +CREATE TABLE IF NOT EXISTS projects ( + id TEXT NOT NULL PRIMARY KEY, + url TEXT NOT NULL, + name TEXT NOT NULL, + forge TEXT NOT NULL, + version TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/project/project.go b/project/project.go index 9b0bc56..a5078de 100644 --- a/project/project.go +++ b/project/project.go @@ -29,6 +29,7 @@ type Project struct { } type Release struct { + ID string URL string Tag string Content string @@ -70,6 +71,7 @@ func fetchReleases(dbConn *sql.DB, p Project) (Project, error) { } for _, release := range rssReleases { p.Releases = append(p.Releases, Release{ + ID: genReleaseID(p.URL, release.URL, release.Tag), Tag: release.Tag, Content: release.Content, URL: release.URL, @@ -88,6 +90,7 @@ func fetchReleases(dbConn *sql.DB, p Project) (Project, error) { } for _, release := range gitReleases { p.Releases = append(p.Releases, Release{ + ID: genReleaseID(p.URL, release.URL, release.Tag), Tag: release.Tag, Content: release.Content, URL: release.URL,