Created
August 5, 2016 23:52
-
-
Save samuraisam/3eb7f0120e9bb24e97f16abfd92e9ba1 to your computer and use it in GitHub Desktop.
Database migrator in Go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package dbutils | |
import ( | |
"database/sql" | |
"log" | |
"regexp" | |
"sort" | |
) | |
// Direction is the migration direction: Up or Down.
type Direction int

// String returns "UP" for Up, "DOWN" for Down and "UNKNOWN" for any other
// value, so that an invalid direction is never mislabeled in log output.
func (d Direction) String() string {
	switch d {
	case Up:
		return "UP"
	case Down:
		return "DOWN"
	default:
		return "UNKNOWN"
	}
}

const (
	// Up indicates to the RunMigrations function that it should be migrating up
	Up Direction = 1
	// Down indicates to the RunMigrations function that it should migrate down.
	// It is explicitly typed: in a const block an untyped `Down = 2` would be a
	// plain int constant, not a Direction.
	Down Direction = 2
)
// Migration describes one versioned schema change.
type Migration struct {
	// Version orders migrations and is the value recorded in the version table.
	Version int
	// Description is a human-readable summary; it is included in log output.
	Description string
	// Up is the SQL that applies the migration. If empty, only the version
	// table is updated.
	Up string
	// Down is the SQL that reverts the migration. If empty, only the version
	// table is updated.
	Down string
}

// byVersion implements sort.Interface, ordering migrations by ascending Version.
type byVersion []Migration

func (m byVersion) Len() int      { return len(m) }
func (m byVersion) Swap(i, j int) { m[i], m[j] = m[j], m[i] }

// Less reports whether the migration at i has a lower Version than the one at j.
func (m byVersion) Less(i, j int) bool { return m[i].Version < m[j].Version }
// Migrateable describes an object which can run database migrations | |
type Migrateable interface { | |
Name() string | |
Run(db *sql.DB, direction Direction) error | |
Put(m Migration) | |
GetAll() []Migration | |
} | |
// Migrator is a base implementation of Migrateable | |
type Migrator struct { | |
name string | |
migrations []Migration | |
} | |
// NewMigrator will generate a new Migrator with the provided name and any up-front provided migrations | |
// the name must be able to be used a SQL table name (no special chars, spaces, etc) | |
func NewMigrator(name string, migrations ...Migration) *Migrator { | |
return &Migrator{name: name, migrations: migrations} | |
} | |
// Name is the name of this migrator. typically named after a sub-app or whatever | |
func (m *Migrator) Name() string { | |
return m.name | |
} | |
// Put will add a new migration to the migration set | |
func (m *Migrator) Put(n Migration) { | |
m.migrations = append(m.migrations, n) | |
} | |
// GetAll will return all migrations that were either provided at New() time or via Put() | |
func (m *Migrator) GetAll() []Migration { | |
return m.migrations | |
} | |
// Run will run all migrations that were either provided at New() time or via Put() | |
func (m *Migrator) Run(db *sql.DB, direction Direction) error { | |
return WrapTxn(db, func(tx *sql.Tx) error { | |
return m.run(tx, direction) | |
}) | |
} | |
func (m *Migrator) run(tx *sql.Tx, direction Direction) error { | |
tblName := m.tableName() | |
// create a version tracker | |
if _, err := tx.Exec("CREATE TABLE IF NOT EXISTS " + tblName + " (version int not null primary key);"); err != nil { | |
log.Printf("Error creating versions table `m_versions`: %s", err) | |
return err | |
} | |
version, vErr := m.Version(tx) | |
if vErr != nil { | |
log.Printf("Error retrieving database version: %s", vErr) | |
return vErr | |
} | |
// gather desired migrations | |
migrations := byVersion([]Migration{}) | |
all := byVersion(m.GetAll()) | |
sort.Sort(all) | |
for _, mi := range all { | |
if direction == Up && uint64(mi.Version) > version { | |
migrations = append(migrations, mi) | |
} else if direction == Down && uint64(mi.Version) <= version { | |
migrations = append(migrations, mi) | |
} | |
} | |
if len(migrations) < 1 { | |
log.Printf("No migrations to run for %s", m.Name()) | |
return nil | |
} | |
if direction == Down { | |
// run down migrations in reverse | |
sort.Sort(sort.Reverse(migrations)) | |
} | |
for _, mi := range migrations { | |
log.Printf("Running migration %s %s %d %s", direction.String(), m.Name(), mi.Version, mi.Description) | |
var setVeresionQuery string | |
var runQuery string | |
if direction == Up { | |
setVeresionQuery = "INSERT INTO " + tblName + " (version) VALUES ($1)" | |
runQuery = mi.Up | |
} else if direction == Down { | |
setVeresionQuery = "DELETE FROM " + tblName + " WHERE version=$1" | |
runQuery = mi.Down | |
} | |
if _, err := tx.Exec(setVeresionQuery, mi.Version); err != nil { | |
log.Printf("Error updating version metadata: %s", err) | |
return err | |
} | |
if runQuery == "" { | |
continue | |
} | |
if _, err := tx.Exec(runQuery); err != nil { | |
log.Printf("Error running migration: %s\nQuery Was: \n%s", err, runQuery) | |
return err | |
} | |
} | |
return nil | |
} | |
// Version returns the current version of the database | |
func (m *Migrator) Version(tx *sql.Tx) (uint64, error) { | |
var version uint64 | |
err := tx.QueryRow("SELECT version FROM " + m.tableName() + " ORDER BY version DESC LIMIT 1").Scan(&version) | |
switch { | |
case err == sql.ErrNoRows: | |
return 0, nil | |
case err != nil: | |
return 0, err | |
default: | |
return version, nil | |
} | |
} | |
// RollbackToVersion manually changes the version back to a specific number | |
// without making ANY other changes in the database (doesn't actually run migrations) | |
// this function should only be used for testing migrations | |
func (m *Migrator) RollbackToVersion(db *sql.DB, to int) error { | |
log.Printf("Rolling back to version %d", to) | |
_, err := db.Exec("DELETE FROM "+m.tableName()+" WHERE version > $1", to) | |
return err | |
} | |
func (m *Migrator) tableName() string { | |
re := regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) | |
if !re.MatchString(m.Name()) { | |
log.Fatalf("Name() must return a valid SQL table name. `%s` is not a valid name.", m.Name()) | |
} | |
return "m_versions_" + m.Name() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment