Skip to content

Instantly share code, notes, and snippets.

@aldy505
Created November 25, 2021 16:05
Show Gist options
  • Save aldy505/367d9220451f98b76af5eb2f3ef5d2ae to your computer and use it in GitHub Desktop.
Save aldy505/367d9220451f98b76af5eb2f3ef5d2ae to your computer and use it in GitHub Desktop.
Quick, simple, and straightforward database migrator in Go
package main
import (
"context"
"database/sql"
"log"
"net/url"
"os"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
)
// main applies every statement found in a SQL file to a database, all inside
// one transaction.
//
// Usage:
//
//	<program> --conn <connection URL> <path to .sql file>
//
// The SQL file path must be the last argument. The connection URL is the
// argument that directly follows "--conn" and is translated into a
// driver-specific DSN by parseConnectionString.
//
// The file is split on ";" naively, so a semicolon inside a string literal or
// a stored-procedure body will be treated as a statement boundary.
func main() {
	args := os.Args
	// os.Args always contains the program name, so fewer than two elements
	// means that no argument was supplied at all.
	if len(args) < 2 {
		log.Fatal("You need to specify URL connection string and the sql file path!")
	}

	filePath := args[len(args)-1]
	log.Println("using the file:", filePath)

	// findStr terminates the program when "--conn" is absent; the bounds
	// check covers the case where "--conn" is the very last argument.
	connStrIndex := findStr(args, "--conn") + 1
	if connStrIndex >= len(args) {
		log.Fatal("Connection string should be provided")
	}

	// NOTE(review): the two log lines below print the connection string,
	// credentials included. Consider redacting if logs are shared.
	log.Println(args[connStrIndex])

	driver, connStr, err := parseConnectionString(args[connStrIndex])
	if err != nil {
		log.Fatal(err)
	}
	log.Println("using the driver:", driver)
	log.Println("and the connection string of:", connStr)

	file, err := os.ReadFile(filePath)
	if err != nil {
		log.Fatal(err)
	}

	db, err := sql.Open(driver, connStr)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10)
	defer cancel()

	tx, err := db.BeginTx(ctx, &sql.TxOptions{})
	if err != nil {
		log.Fatal(err)
	}

	for _, stmt := range strings.Split(string(file), ";") {
		// Splitting on ";" leaves an empty (or whitespace-only) fragment
		// after the final statement; some servers reject an empty query, so
		// such fragments are skipped instead of being executed.
		stmt = strings.TrimSpace(stmt)
		if stmt == "" {
			continue
		}

		if _, err := tx.ExecContext(ctx, stmt); err != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				log.Println("rollback failed:", rbErr)
			}
			log.Fatal(err)
		}
	}

	// A transaction is finished once Commit has been called, whether or not
	// it succeeded, so there is nothing left to roll back on failure.
	if err := tx.Commit(); err != nil {
		log.Fatal(err)
	}

	log.Println("Yey berhasil.")
}
// findStr reports the position of the first element of arr that equals str.
//
// When no element matches, the process is terminated through log.Fatal; the
// trailing "return -1" only exists to satisfy the compiler.
func findStr(arr []string, str string) int {
	for idx := 0; idx < len(arr); idx++ {
		if arr[idx] != str {
			continue
		}
		return idx
	}

	log.Fatal("Connection string should be provided")
	return -1
}
// parseConnectionString converts a URL-style connection string into the
// database/sql driver name and the DSN layout that driver expects.
//
// Supported schemes:
//   - mysql://                     -> driver "mysql",
//     DSN "user:password@tcp(host:port)/dbname?params"
//   - postgres://, postgresql://   -> driver "postgres",
//     DSN "user=... password=... host=... port=... dbname=... key=value"
//
// Any other input is returned unchanged with an empty driver name and a nil
// error. An error is returned only when the URL cannot be parsed.
//
// Adapted from
// https://github.com/teknologi-umum/polarite/blob/master/resources/sql.go
func parseConnectionString(connstr string) (driver string, out string, err error) {
	switch {
	case strings.HasPrefix(connstr, "mysql://"):
		driver = "mysql"
	case strings.HasPrefix(connstr, "postgres://"), strings.HasPrefix(connstr, "postgresql://"):
		driver = "postgres"
	default:
		return "", connstr, nil
	}

	parsedConn, err := url.Parse(connstr)
	if err != nil {
		return "", "", err
	}

	if driver == "mysql" {
		return driver, buildMySQLDSN(parsedConn), nil
	}
	return driver, buildPostgresDSN(parsedConn), nil
}

// buildMySQLDSN renders u as "user:password@tcp(host:port)/dbname?params".
func buildMySQLDSN(u *url.URL) string {
	var dsn strings.Builder

	// Credentials are written decoded: the DSN is not a URL, so
	// percent-encoded characters would be taken literally by the driver.
	dsn.WriteString(u.User.Username())
	if password, ok := u.User.Password(); ok {
		dsn.WriteString(":")
		dsn.WriteString(password)
	}

	// u.Host keeps the brackets around IPv6 literals, which the tcp(...)
	// address needs. Fall back to the default MySQL port when none is given.
	addr := u.Host
	if u.Port() == "" {
		addr = strings.TrimSuffix(addr, ":") + ":3306"
	}
	dsn.WriteString("@tcp(")
	dsn.WriteString(addr)
	dsn.WriteString(")")

	// The slash in front of the database name is mandatory even when the
	// name itself is empty.
	path := u.EscapedPath()
	if path == "" {
		path = "/"
	}
	dsn.WriteString(path)

	if query := u.Query().Encode(); query != "" {
		dsn.WriteString("?")
		dsn.WriteString(query)
	}
	return dsn.String()
}

// buildPostgresDSN renders u as a space-separated "key=value" list.
func buildPostgresDSN(u *url.URL) string {
	var dsn strings.Builder

	password, _ := u.User.Password()
	fixed := [][2]string{
		{"user", u.User.Username()},
		{"password", password},
		{"host", u.Hostname()},
		{"port", u.Port()},
		{"dbname", strings.TrimPrefix(u.Path, "/")},
	}
	for _, param := range fixed {
		// An empty "key=" would make the parser swallow the next pair as
		// its value, so absent parts are left out and driver defaults apply.
		if param[1] != "" {
			writePostgresParam(&dsn, param[0], param[1])
		}
	}

	// Encode yields the pairs sorted by key, which keeps the output
	// deterministic; each side is decoded again because the key/value
	// format does not use URL escaping.
	for _, pair := range strings.Split(u.Query().Encode(), "&") {
		kv := strings.SplitN(pair, "=", 2)
		if len(kv) != 2 {
			continue
		}
		key, keyErr := url.QueryUnescape(kv[0])
		value, valueErr := url.QueryUnescape(kv[1])
		if keyErr != nil || valueErr != nil || key == "" {
			continue
		}
		writePostgresParam(&dsn, key, value)
	}
	return dsn.String()
}

// writePostgresParam appends "key=value" to dsn, separated from any previous
// pair by one space. Values that are empty or contain whitespace, a single
// quote, or a backslash are wrapped in single quotes with "\" and "'"
// backslash-escaped.
func writePostgresParam(dsn *strings.Builder, key, value string) {
	if dsn.Len() > 0 {
		dsn.WriteString(" ")
	}
	dsn.WriteString(key)
	dsn.WriteString("=")

	if value != "" && !strings.ContainsAny(value, " \t\n\r\v\f\u0085\u00a0'\\") {
		dsn.WriteString(value)
		return
	}

	// Backslashes must be escaped first so the ones added for quotes are
	// not doubled.
	escaped := strings.ReplaceAll(value, `\`, `\\`)
	escaped = strings.ReplaceAll(escaped, `'`, `\'`)
	dsn.WriteString("'")
	dsn.WriteString(escaped)
	dsn.WriteString("'")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment