Skip to content

Instantly share code, notes, and snippets.

@mortenson
Last active June 19, 2023 13:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mortenson/c3c1e7f2a1b10c5c3674f8c91123c3e0 to your computer and use it in GitHub Desktop.
Save mortenson/c3c1e7f2a1b10c5c3674f8c91123c3e0 to your computer and use it in GitHub Desktop.
Create sqlc migrations automatically using pg-schema-diff
package main
import (
"context"
"database/sql"
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
_ "github.com/lib/pq"
"github.com/stripe/pg-schema-diff/pkg/diff"
"github.com/stripe/pg-schema-diff/pkg/tempdb"
)
// main is the CLI entry point. It diffs the database reachable via -dsn against
// the desired schema (every .sql file in -schemaDir) using pg-schema-diff and
// writes the resulting DDL as one or more goose migration files into
// -migrationDir, named after -name.
//
// Example:
//
//	go run create_migration.go -dsn "postgres://..." -schemaDir schema -migrationDir migrations -name create_users
//
// If any flag is missing, main prints the usage example to stderr and exits
// with status 2. Any other failure is reported on stderr with exit status 1.
func main() {
	namePtr := flag.String("name", "", "The migration name")
	dsnPtr := flag.String("dsn", "", "The connection string")
	schemaDirPtr := flag.String("schemaDir", "", "The schema directory")
	migrationDirPtr := flag.String("migrationDir", "", "The migration directory")
	flag.Parse()

	if *namePtr == "" || *dsnPtr == "" || *schemaDirPtr == "" || *migrationDirPtr == "" {
		fmt.Fprintln(os.Stderr, "Example usage: go run create_migration.go -dsn \"postgres://...\" -schemaDir schema -migrationDir migrations -name create_users")
		os.Exit(2)
	}

	// os.Exit skips deferred calls, so all resources are owned (and deferred)
	// inside run; exiting only after run returns guarantees they are closed.
	if err := run(context.Background(), *dsnPtr, *schemaDirPtr, *migrationDirPtr, *namePtr); err != nil {
		fmt.Fprintf(os.Stderr, "error: %v\n", err)
		os.Exit(1)
	}
}

// run generates a pg-schema-diff plan that takes the database at dsn to the
// schema described by the .sql files in schemaDir. It prints every planned
// statement (with hazards highlighted) and writes the statements as goose
// migrations into migrationDir. It returns the first error encountered.
func run(ctx context.Context, dsn, schemaDir, migrationDir, name string) error {
	connConfig, err := pgx.ParseConfig(dsn)
	if err != nil {
		return fmt.Errorf("parsing dsn: %w", err)
	}

	// The factory receives a callback that opens a database by name using the
	// same connection settings as dsn; pg-schema-diff presumably uses it to
	// create temporary databases on the same server.
	tempDbFactory, err := tempdb.NewOnInstanceFactory(ctx, func(ctx context.Context, dbName string) (*sql.DB, error) {
		copiedConfig := connConfig.Copy()
		copiedConfig.Database = dbName
		return openDbWithPgxConfig(copiedConfig)
	})
	if err != nil {
		return fmt.Errorf("creating temp database factory: %w", err)
	}
	defer tempDbFactory.Close()

	ddl, err := getDDLFromPath(schemaDir)
	if err != nil {
		return fmt.Errorf("reading schema directory %q: %w", schemaDir, err)
	}

	connPool, err := openDbWithPgxConfig(connConfig)
	if err != nil {
		return fmt.Errorf("connecting to database: %w", err)
	}
	defer connPool.Close()

	conn, err := connPool.Conn(ctx)
	if err != nil {
		return fmt.Errorf("acquiring database connection: %w", err)
	}
	defer conn.Close()

	plan, err := diff.GeneratePlan(ctx, conn, tempDbFactory, ddl, diff.WithDoNotValidatePlan())
	if err != nil {
		return fmt.Errorf("generating migration plan: %w", err)
	}

	var sqls []string
	for _, statement := range plan.Statements {
		statementStr := statement.ToSQL()
		// Drop any statement that mentions goose. Presumably these target
		// goose's own version-tracking table, which is not in the schema files.
		// NOTE(review): this is a plain substring match and would also drop
		// unrelated statements containing "goose".
		if strings.Contains(statementStr, "goose") {
			continue
		}
		fmt.Printf("[STATEMENT] %s\n", statementStr)
		for _, hazard := range statement.Hazards {
			fmt.Printf("\033[31m[WARNING] %s\033[0m\n", hazard.String())
		}
		sqls = append(sqls, statementStr)
	}

	groups := groupStatements(sqls)
	if len(groups) == 0 {
		fmt.Println("No schema changes detected; no migration created.")
		return nil
	}
	// UTC keeps version timestamps independent of the local time zone, matching
	// the versions generated by goose's own `create` command.
	return writeMigrations(groups, migrationDir, name, time.Now().UTC())
}

// groupStatements splits sqls into migration-sized groups, preserving order.
// Consecutive ordinary statements share a group. Each statement containing
// CONCURRENTLY gets a group of its own, because PostgreSQL refuses to run
// e.g. CREATE INDEX CONCURRENTLY inside a transaction block. Empty groups are
// never produced.
func groupStatements(sqls []string) [][]string {
	var groups [][]string
	var current []string
	for _, s := range sqls {
		if !strings.Contains(s, "CONCURRENTLY") {
			current = append(current, s)
			continue
		}
		if len(current) > 0 {
			groups = append(groups, current)
			current = nil
		}
		groups = append(groups, []string{s})
	}
	if len(current) > 0 {
		groups = append(groups, current)
	}
	return groups
}

// writeMigrations writes one goose "Up" migration file per group into dir.
// A single group is written as <timestamp>_<name>.sql. Multiple groups are
// written as <timestamp>_<name>_NN.sql, with now advanced by one second before
// each file so the timestamps sort in generation order. A group whose SQL
// contains CONCURRENTLY is marked "-- +goose NO TRANSACTION".
func writeMigrations(groups [][]string, dir, name string, now time.Time) error {
	for i, group := range groups {
		contents := strings.Join(group, "\n\n")
		migration := fmt.Sprintf("-- +goose Up\n-- +goose StatementBegin\n%s\n-- +goose StatementEnd\n", contents)
		if strings.Contains(contents, "CONCURRENTLY") {
			migration = "-- +goose NO TRANSACTION\n" + migration
		}

		var filename string
		if len(groups) > 1 {
			now = now.Add(time.Second)
			filename = fmt.Sprintf("%s_%s_%02d", now.Format("20060102150405"), name, i+1)
		} else {
			filename = fmt.Sprintf("%s_%s", now.Format("20060102150405"), name)
		}
		filePath := filepath.Join(dir, filename+".sql")

		if err := os.WriteFile(filePath, []byte(migration), 0644); err != nil {
			return fmt.Errorf("writing migration %q: %w", filePath, err)
		}
		fmt.Printf("\033[32mCreated %s\033[0m\n", filePath)
	}
	return nil
}
// getDDLFromPath returns the contents of every non-directory ".sql" entry
// directly inside path, one string per file. Files come in the filename order
// that os.ReadDir guarantees. Non-.sql files and subdirectories (including
// directories whose names end in ".sql") are ignored; the directory is not
// walked recursively. Errors from os.ReadDir / os.ReadFile are returned
// unchanged; they already name the offending path.
func getDDLFromPath(path string) ([]string, error) {
	entries, err := os.ReadDir(path)
	if err != nil {
		return nil, err
	}
	var ddl []string
	for _, entry := range entries {
		// A directory named "*.sql" would make os.ReadFile fail with
		// "is a directory" and abort the whole run, so skip it.
		if entry.IsDir() || filepath.Ext(entry.Name()) != ".sql" {
			continue
		}
		contents, err := os.ReadFile(filepath.Join(path, entry.Name()))
		if err != nil {
			return nil, err
		}
		ddl = append(ddl, string(contents))
	}
	return ddl, nil
}
// openDbWithPgxConfig opens a database/sql handle through pgx's stdlib driver
// using config, then checks connectivity with a Ping. On success the open
// handle is returned. If the ping fails, the handle is closed and the ping
// error is returned, so callers never receive a handle that failed its
// connectivity check.
func openDbWithPgxConfig(config *pgx.ConnConfig) (*sql.DB, error) {
	db := stdlib.OpenDB(*config)
	pingErr := db.Ping()
	if pingErr == nil {
		return db, nil
	}
	// Release the handle before reporting the failure; Close's own error is
	// not inspected, and the ping error is what gets returned.
	db.Close()
	return nil, pingErr
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment