// SPDX-FileCopyrightText: Chris Waldon // // SPDX-License-Identifier: Apache-2.0 package db import ( "context" "database/sql" _ "embed" "fmt" ) type migration struct { upQuery string downQuery string postHook func(*sql.Tx) error } var ( //go:embed sql/1_add_project_ids.up.sql migration1Up string //go:embed sql/1_add_project_ids.down.sql migration1Down string //go:embed sql/2_swap_project_url_for_id.up.sql migration2Up string //go:embed sql/2_swap_project_url_for_id.down.sql migration2Down string ) var migrations = [...]migration{ 0: { upQuery: `CREATE TABLE schema_migrations (version uint64, dirty bool); INSERT INTO schema_migrations (version, dirty) VALUES (0, 0);`, downQuery: `DROP TABLE schema_migrations;`, }, 1: { upQuery: migration1Up, downQuery: migration1Down, postHook: generateAndInsertProjectIDs, }, 2: { upQuery: migration2Up, downQuery: migration2Down, }, 3: { postHook: correctProjectIDs, }, } // Migrate runs all pending migrations func Migrate(db *sql.DB) error { version := getSchemaVersion(db) for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ { if err := runMigration(db, nextMigration); err != nil { return fmt.Errorf("migrations failed: %w", err) } if version := getSchemaVersion(db); version != nextMigration { return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version) } } return nil } // runMigration runs a single migration inside a transaction, updates the schema // version and commits the transaction if successful, and rolls back the // transaction if unsuccessful. func runMigration(db *sql.DB, migrationIdx int) (err error) { current := migrations[migrationIdx] tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err) } defer func() { if err == nil { err = tx.Commit() } if err != nil { if rbErr := tx.Rollback(); rbErr != nil { err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err) } } }() if len(current.upQuery) > 0 { if _, err := tx.Exec(current.upQuery); err != nil { return fmt.Errorf("failed running migration %d: %w", migrationIdx, err) } } if current.postHook != nil { if err := current.postHook(tx); err != nil { return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err) } } return updateSchemaVersion(tx, migrationIdx) } // undoMigration rolls the single most recent migration back inside a // transaction, updates the schema version and commits the transaction if // successful, and rolls back the transaction if unsuccessful. // //lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34) func undoMigration(db *sql.DB, migrationIdx int) (err error) { current := migrations[migrationIdx] tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err) } defer func() { if err == nil { err = tx.Commit() } if err != nil { if rbErr := tx.Rollback(); rbErr != nil { err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err) } } }() if len(current.downQuery) > 0 { if _, err := tx.Exec(current.downQuery); err != nil { return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err) } } return updateSchemaVersion(tx, migrationIdx-1) } // getSchemaVersion returns the schema version from the database func getSchemaVersion(db *sql.DB) int { row := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`) var version int if err := row.Scan(&version); err != nil { version = -1 } return version } // updateSchemaVersion sets the version to the provided int func updateSchemaVersion(tx *sql.Tx, version int) error { if version < 0 { // Do not try to use the schema_migrations table in a schema version where it doesn't exist return nil } _, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version)) return err }