Skip to content

Instantly share code, notes, and snippets.

@samuraisam
Created August 5, 2016 23:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samuraisam/3eb7f0120e9bb24e97f16abfd92e9ba1 to your computer and use it in GitHub Desktop.
Save samuraisam/3eb7f0120e9bb24e97f16abfd92e9ba1 to your computer and use it in GitHub Desktop.
database migrator in go
package dbutils
import (
"database/sql"
"log"
"regexp"
"sort"
)
// Direction is the migration direction (up or down usually).
type Direction int

// String returns "UP" for Up and "DOWN" for every other value, so it is
// suitable for log output only, not for validating a Direction.
func (d Direction) String() string {
	if d == Up {
		return "UP"
	}
	return "DOWN"
}

const (
	// Up indicates to Run that it should be migrating up.
	Up Direction = 1
	// Down indicates to Run that it should migrate down.
	//
	// The type is spelled out on purpose: a constant declared with its own
	// expression does not inherit the type of the previous line, so without
	// it Down would be an untyped integer with no String method.
	Down Direction = 2
)
// Migration describes a single schema change.
type Migration struct {
	// Version identifies the migration; it is the value recorded in the
	// version-tracking table and the key migrations are ordered by.
	Version int
	// Description is free text included in the log line for the migration.
	Description string
	// Up is the SQL executed when migrating up. May be empty.
	Up string
	// Down is the SQL executed when migrating down. May be empty.
	Down string
}

// byVersion implements sort.Interface, ordering migrations by ascending
// Version.
type byVersion []Migration

// Len returns the number of migrations.
func (m byVersion) Len() int { return len(m) }

// Swap exchanges the migrations at positions i and j.
func (m byVersion) Swap(i, j int) { m[i], m[j] = m[j], m[i] }

// Less reports whether the migration at i has a lower Version than the one
// at j. It must compare two different elements: comparing an element with
// itself is always false and leaves the slice unsorted.
func (m byVersion) Less(i, j int) bool { return m[i].Version < m[j].Version }
// Migrateable describes an object which can run database migrations.
// Migrator is the implementation provided in this file.
type Migrateable interface {
// Name returns the name identifying this set of migrations.
Name() string
// Run runs the migrations against db in the given direction.
Run(db *sql.DB, direction Direction) error
// Put adds m to the set of migrations.
Put(m Migration)
// GetAll returns every migration in the set.
GetAll() []Migration
}
// Migrator is a base implementation of Migrateable.
type Migrator struct {
// name identifies the migrator; tableName appends it to "m_versions_" to
// form the version-tracking table name.
name string
// migrations holds everything passed to NewMigrator or added with Put.
migrations []Migration
}
// NewMigrator builds a Migrator called name, pre-loaded with any migrations
// supplied up front. More can be added later with Put.
//
// The name ends up inside a SQL table name, so it must be usable as one
// (no special characters, spaces, etc).
func NewMigrator(name string, migrations ...Migration) *Migrator {
	mg := new(Migrator)
	mg.name = name
	mg.migrations = migrations
	return mg
}
// Name returns the name this migrator was created with, typically the name of
// a sub-app. The same value is used to build the version-tracking table name.
func (m *Migrator) Name() string {
return m.name
}
// Put adds a new migration to the end of the migration set. Nothing is
// sorted or checked for duplicate versions here.
func (m *Migrator) Put(n Migration) {
m.migrations = append(m.migrations, n)
}
// GetAll returns all migrations that were either provided to NewMigrator or
// added via Put. The returned slice is the migrator's own slice, not a copy.
func (m *Migrator) GetAll() []Migration {
return m.migrations
}
// Run applies the migrations that were provided to NewMigrator or added via
// Put, in the given direction, against db.
//
// The work is handed to WrapTxn, which supplies the *sql.Tx the migrations
// execute on (presumably committing on success and rolling back on error;
// confirm against WrapTxn's definition).
func (m *Migrator) Run(db *sql.DB, direction Direction) error {
	migrate := func(tx *sql.Tx) error {
		return m.run(tx, direction)
	}
	return WrapTxn(db, migrate)
}
// run applies every pending migration in the requested direction using tx.
//
// It first makes sure the version-tracking table exists and reads the highest
// recorded version. Going Up it selects the migrations whose Version is
// greater than the recorded one and applies them in byVersion sort order;
// going Down it selects the migrations whose Version is less than or equal to
// the recorded one and applies them in the opposite order. Any other
// direction selects nothing.
//
// For each migration the version row is inserted (Up) or deleted (Down)
// before the migration's own SQL is executed; a migration with an empty SQL
// string only updates the version row. The first error is logged and
// returned, and no further migrations are attempted.
func (m *Migrator) run(tx *sql.Tx, direction Direction) error {
	tblName := m.tableName()

	// create a version tracker
	if _, err := tx.Exec("CREATE TABLE IF NOT EXISTS " + tblName + " (version int not null primary key);"); err != nil {
		log.Printf("Error creating versions table `%s`: %s", tblName, err)
		return err
	}
	version, err := m.Version(tx)
	if err != nil {
		log.Printf("Error retrieving database version: %s", err)
		return err
	}

	// gather desired migrations
	all := byVersion(m.GetAll())
	sort.Sort(all)
	var pending []Migration
	for _, mi := range all {
		switch {
		case direction == Up && uint64(mi.Version) > version,
			direction == Down && uint64(mi.Version) <= version:
			pending = append(pending, mi)
		}
	}
	if len(pending) == 0 {
		log.Printf("No migrations to run for %s", m.Name())
		return nil
	}

	// pending is non-empty only for Up or Down, so anything that is not Down
	// here is Up. The bookkeeping statement depends only on the direction, so
	// it is chosen once rather than on every iteration.
	setVersionQuery := "INSERT INTO " + tblName + " (version) VALUES ($1)"
	if direction == Down {
		setVersionQuery = "DELETE FROM " + tblName + " WHERE version=$1"
		// run down migrations in reverse: pending keeps the order produced
		// by the sort above, so flipping it in place yields the opposite
		// order without a second comparison-based sort.
		for i, j := 0, len(pending)-1; i < j; i, j = i+1, j-1 {
			pending[i], pending[j] = pending[j], pending[i]
		}
	}

	for _, mi := range pending {
		log.Printf("Running migration %s %s %d %s", direction.String(), m.Name(), mi.Version, mi.Description)

		runQuery := mi.Up
		if direction == Down {
			runQuery = mi.Down
		}

		// The version row is updated before the migration SQL runs; both
		// statements go through the same tx.
		if _, err := tx.Exec(setVersionQuery, mi.Version); err != nil {
			log.Printf("Error updating version metadata: %s", err)
			return err
		}
		if runQuery == "" {
			// Nothing to execute for this direction; only the version moves.
			continue
		}
		if _, err := tx.Exec(runQuery); err != nil {
			log.Printf("Error running migration: %s\nQuery Was: \n%s", err, runQuery)
			return err
		}
	}
	return nil
}
// Version returns the current version of the database: the highest value
// stored in this migrator's version-tracking table, read through tx.
// An empty table is reported as version 0 with a nil error; any other query
// or scan failure is returned as-is alongside 0.
func (m *Migrator) Version(tx *sql.Tx) (uint64, error) {
	var current uint64
	query := "SELECT version FROM " + m.tableName() + " ORDER BY version DESC LIMIT 1"
	if err := tx.QueryRow(query).Scan(&current); err != nil {
		if err == sql.ErrNoRows {
			// No migration has been recorded yet.
			return 0, nil
		}
		return 0, err
	}
	return current, nil
}
// RollbackToVersion rewinds the recorded version to the given number by
// deleting every version row greater than it. Nothing else in the database
// is touched: no Down migration is executed. It runs directly on db, outside
// any transaction managed by Run.
//
// This function should only be used for testing migrations.
func (m *Migrator) RollbackToVersion(db *sql.DB, to int) error {
	log.Printf("Rolling back to version %d", to)
	stmt := "DELETE FROM " + m.tableName() + " WHERE version > $1"
	if _, err := db.Exec(stmt, to); err != nil {
		return err
	}
	return nil
}
// validMigratorName matches names that are safe to splice into a SQL table
// name: an ASCII letter or underscore followed by ASCII letters, digits or
// underscores. It is compiled once at package initialisation instead of on
// every tableName call.
var validMigratorName = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// tableName returns the name of this migrator's version-tracking table:
// "m_versions_" followed by Name().
//
// The result is concatenated directly into SQL statements, so the name is
// validated first. An invalid name terminates the process via log.Fatalf.
func (m *Migrator) tableName() string {
	name := m.Name()
	if !validMigratorName.MatchString(name) {
		log.Fatalf("Name() must return a valid SQL table name. `%s` is not a valid name.", name)
	}
	return "m_versions_" + name
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment