Implement migration system, add first migration
Thank you for the help Chris! https://github.com/whereswaldon
This commit is contained in:
parent
fc27ee8438
commit
0675278fe2
|
@ -63,18 +63,18 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("Verifying database schema")
|
fmt.Println("Checking whether database needs initialising")
|
||||||
err = db.VerifySchema(dbConn)
|
err = db.InitialiseDatabase(dbConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error verifying database schema:", err)
|
fmt.Println("Error initialising database:", err)
|
||||||
fmt.Println("Attempting to load schema")
|
os.Exit(1)
|
||||||
err = db.LoadSchema(dbConn)
|
}
|
||||||
if err != nil {
|
fmt.Println("Checking whether there are pending migrations")
|
||||||
fmt.Println("Error loading schema:", err)
|
err = db.Migrate(dbConn)
|
||||||
os.Exit(1)
|
if err != nil {
|
||||||
}
|
fmt.Println("Error migrating database schema:", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
fmt.Println("Database schema verified")
|
|
||||||
|
|
||||||
if len(*flagAddUser) > 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
|
if len(*flagAddUser) > 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
|
||||||
createUser(dbConn, *flagAddUser)
|
createUser(dbConn, *flagAddUser)
|
||||||
|
|
40
db/db.go
40
db/db.go
|
@ -6,46 +6,54 @@ package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"embed"
|
_ "embed"
|
||||||
|
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Embed the schema into the binary
|
//go:embed sql/schema.sql
|
||||||
//
|
var schema string
|
||||||
//go:embed sql
|
|
||||||
var embeddedSQL embed.FS
|
|
||||||
|
|
||||||
// Open opens a connection to the SQLite database
|
// Open opens a connection to the SQLite database
|
||||||
func Open(dbPath string) (*sql.DB, error) {
|
func Open(dbPath string) (*sql.DB, error) {
|
||||||
return sql.Open("sqlite", dbPath)
|
return sql.Open("sqlite", dbPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifySchema(dbConn *sql.DB) error {
|
// VerifySchema checks whether the schema has been initalised and initialises it
|
||||||
|
// if not
|
||||||
|
func InitialiseDatabase(dbConn *sql.DB) error {
|
||||||
|
var name string
|
||||||
|
err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'").Scan(&name)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
tables := []string{
|
tables := []string{
|
||||||
"users",
|
"users",
|
||||||
"sessions",
|
"sessions",
|
||||||
"projects",
|
"projects",
|
||||||
|
"releases",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
name := ""
|
name := ""
|
||||||
err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name)
|
err := dbConn.QueryRow(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=@table",
|
||||||
|
sql.Named("table", table),
|
||||||
|
).Scan(&name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if err = loadSchema(dbConn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSchema loads the schema into the database
|
// loadSchema loads the initial schema into the database
|
||||||
func LoadSchema(dbConn *sql.DB) error {
|
func loadSchema(dbConn *sql.DB) error {
|
||||||
schema, err := embeddedSQL.ReadFile("sql/schema.sql")
|
if _, err := dbConn.Exec(schema); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
_, err = dbConn.Exec(string(schema))
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,133 @@
|
||||||
|
// SPDX-FileCopyrightText: Chris Waldon <christopher.waldon.dev@gmail.com>
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
_ "embed"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type migration struct {
|
||||||
|
upQuery string
|
||||||
|
downQuery string
|
||||||
|
postHook func(*sql.Tx) error
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed sql/1_add_project_ids.up.sql
|
||||||
|
migration1Up string
|
||||||
|
//go:embed sql/1_add_project_ids.down.sql
|
||||||
|
migration1Down string
|
||||||
|
)
|
||||||
|
|
||||||
|
var migrations = [...]migration{
|
||||||
|
0: {
|
||||||
|
upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool);
|
||||||
|
INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`,
|
||||||
|
downQuery: `DROP TABLE schema_migrations;`,
|
||||||
|
},
|
||||||
|
1: {
|
||||||
|
upQuery: migration1Up,
|
||||||
|
downQuery: migration1Down,
|
||||||
|
postHook: generateAndInsertProjectIDs,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Migrate runs all pending migrations
|
||||||
|
func Migrate(db *sql.DB) error {
|
||||||
|
version := getSchemaVersion(db)
|
||||||
|
for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
|
||||||
|
if err := runMigration(db, nextMigration); err != nil {
|
||||||
|
return fmt.Errorf("migrations failed: %w", err)
|
||||||
|
}
|
||||||
|
if version := getSchemaVersion(db); version != nextMigration {
|
||||||
|
return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runMigration runs a single migration inside a transaction, updates the schema
|
||||||
|
// version and commits the transaction if successful, and rolls back the
|
||||||
|
// transaction if unsuccessful.
|
||||||
|
func runMigration(db *sql.DB, migrationIdx int) (err error) {
|
||||||
|
current := migrations[migrationIdx]
|
||||||
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
|
err = tx.Commit()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if rbErr := tx.Rollback(); rbErr != nil {
|
||||||
|
err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if len(current.upQuery) > 0 {
|
||||||
|
if _, err := tx.Exec(current.upQuery); err != nil {
|
||||||
|
return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if current.postHook != nil {
|
||||||
|
if err := current.postHook(tx); err != nil {
|
||||||
|
return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return updateSchemaVersion(tx, migrationIdx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// undoMigration rolls the single most recent migration back inside a
|
||||||
|
// transaction, updates the schema version and commits the transaction if
|
||||||
|
// successful, and rolls back the transaction if unsuccessful.
|
||||||
|
//
|
||||||
|
//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
|
||||||
|
func undoMigration(db *sql.DB, migrationIdx int) (err error) {
|
||||||
|
current := migrations[migrationIdx]
|
||||||
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
|
err = tx.Commit()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if rbErr := tx.Rollback(); rbErr != nil {
|
||||||
|
err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if len(current.downQuery) > 0 {
|
||||||
|
if _, err := tx.Exec(current.downQuery); err != nil {
|
||||||
|
return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return updateSchemaVersion(tx, migrationIdx-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSchemaVersion returns the schema version from the database
|
||||||
|
func getSchemaVersion(db *sql.DB) int {
|
||||||
|
row := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`)
|
||||||
|
var version int
|
||||||
|
if err := row.Scan(&version); err != nil {
|
||||||
|
version = -1
|
||||||
|
}
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateSchemaVersion sets the version to the provided int
|
||||||
|
func updateSchemaVersion(tx *sql.Tx, version int) error {
|
||||||
|
if version < 0 {
|
||||||
|
// Do not try to use the schema_migrations table in a schema version where it doesn't exist
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version))
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// generateAndInsertProjectIDs runs during migration 1, fetches all rows from
|
||||||
|
// projects_tmp, loops through the rows generating a repeatable ID for each
|
||||||
|
// project, and inserting it into the new table along with the data from the old
|
||||||
|
// table.
|
||||||
|
func generateAndInsertProjectIDs(tx *sql.Tx) error {
|
||||||
|
// Loop through projects_tmp, generate a project_id for each, and insert
|
||||||
|
// into projects
|
||||||
|
rows, err := tx.Query("SELECT url, name, forge, version, created_at FROM projects_tmp")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list projects in projects_tmp: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
url string
|
||||||
|
name string
|
||||||
|
forge string
|
||||||
|
version string
|
||||||
|
created_at string
|
||||||
|
)
|
||||||
|
if err := rows.Scan(&url, &name, &forge, &version, &created_at); err != nil {
|
||||||
|
return fmt.Errorf("failed to scan row from projects_tmp: %w", err)
|
||||||
|
}
|
||||||
|
id := fmt.Sprintf("%x", sha256.Sum256([]byte(url+name+forge+created_at)))
|
||||||
|
_, err = tx.Exec(
|
||||||
|
"INSERT INTO projects (id, url, name, forge, version, created_at) VALUES (@id, @url, @name, @forge, @version, @created_at)",
|
||||||
|
sql.Named("id", id),
|
||||||
|
sql.Named("url", url),
|
||||||
|
sql.Named("name", name),
|
||||||
|
sql.Named("forge", forge),
|
||||||
|
sql.Named("version", version),
|
||||||
|
sql.Named("created_at", created_at),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to insert project into projects: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec("DROP TABLE projects_tmp"); err != nil {
|
||||||
|
return fmt.Errorf("failed to drop projects_tmp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
-- SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
|
||||||
|
--
|
||||||
|
-- SPDX-License-Identifier: CC0-1.0
|
||||||
|
|
||||||
|
--ALTER TABLE projects RENAME TO projects_tmp; -- noqa
|
||||||
|
|
||||||
|
ALTER TABLE projects RENAME TO projects_tmp;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS projects (
|
||||||
|
url TEXT NOT NULL PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
forge TEXT NOT NULL,
|
||||||
|
version TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO projects (url, name, forge, version, created_at)
|
||||||
|
SELECT
|
||||||
|
url,
|
||||||
|
name,
|
||||||
|
forge,
|
||||||
|
version,
|
||||||
|
created_at
|
||||||
|
FROM projects_tmp;
|
||||||
|
|
||||||
|
DROP TABLE projects_tmp;
|
|
@ -0,0 +1,14 @@
|
||||||
|
-- SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
|
||||||
|
--
|
||||||
|
-- SPDX-License-Identifier: CC0-1.0
|
||||||
|
|
||||||
|
ALTER TABLE projects RENAME TO projects_tmp;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS projects (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
url TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
forge TEXT NOT NULL,
|
||||||
|
version TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
|
@ -29,6 +29,7 @@ type Project struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Release struct {
|
type Release struct {
|
||||||
|
ID string
|
||||||
URL string
|
URL string
|
||||||
Tag string
|
Tag string
|
||||||
Content string
|
Content string
|
||||||
|
@ -70,6 +71,7 @@ func fetchReleases(dbConn *sql.DB, p Project) (Project, error) {
|
||||||
}
|
}
|
||||||
for _, release := range rssReleases {
|
for _, release := range rssReleases {
|
||||||
p.Releases = append(p.Releases, Release{
|
p.Releases = append(p.Releases, Release{
|
||||||
|
ID: genReleaseID(p.URL, release.URL, release.Tag),
|
||||||
Tag: release.Tag,
|
Tag: release.Tag,
|
||||||
Content: release.Content,
|
Content: release.Content,
|
||||||
URL: release.URL,
|
URL: release.URL,
|
||||||
|
@ -88,6 +90,7 @@ func fetchReleases(dbConn *sql.DB, p Project) (Project, error) {
|
||||||
}
|
}
|
||||||
for _, release := range gitReleases {
|
for _, release := range gitReleases {
|
||||||
p.Releases = append(p.Releases, Release{
|
p.Releases = append(p.Releases, Release{
|
||||||
|
ID: genReleaseID(p.URL, release.URL, release.Tag),
|
||||||
Tag: release.Tag,
|
Tag: release.Tag,
|
||||||
Content: release.Content,
|
Content: release.Content,
|
||||||
URL: release.URL,
|
URL: release.URL,
|
||||||
|
|
Loading…
Reference in New Issue