Skip to content

Instantly share code, notes, and snippets.

@maratori
Last active April 3, 2024 12:22
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maratori/9105427ceab5a8ab6bd798c34235a66c to your computer and use it in GitHub Desktop.
Save maratori/9105427ceab5a8ab6bd798c34235a66c to your computer and use it in GitHub Desktop.
Clone the PostgreSQL schema for each test
package my_package_test
import (
"fmt"
"testing"
"github.com/gtforge/global_payment_manager_service/testdb"
"github.com/stretchr/testify/require"
)
// TestUsingTestdbPrepare starts ten parallel subtests named "1" through "10".
// Each one asks testdb.Prepare for its own cloned schema and then creates and
// fills the same table "abc". Presumably this only succeeds because every
// subtest works in a separate schema; creating "abc" twice in one schema would fail.
func TestUsingTestdbPrepare(t *testing.T) {
	t.Parallel()

	const subtests = 10
	for n := 1; n <= subtests; n++ {
		name := fmt.Sprint(n)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			conn, _ := testdb.Prepare(t)

			_, err := conn.Exec("CREATE TABLE abc (id INT PRIMARY KEY)")
			require.NoError(t, err)

			_, err = conn.Exec("INSERT INTO abc (id) VALUES ($1)", 123)
			require.NoError(t, err)
		})
	}
}
package testdb
import (
"crypto/md5"
"database/sql"
"fmt"
"regexp"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// once makes Prepare install the clone_schema SQL function only for the first
// test that calls it; all tests in the binary share it.
var (
	once sync.Once
)
// DSN returns the key/value connection string for the local PostgreSQL
// instance used by the tests (database my_database on localhost:5432).
func DSN() string {
	const dsn = "host=localhost port=5432 user=postgres password=postgres " +
		"dbname=my_database sslmode=disable timezone=UTC"
	return dsn
}
// nonWordChars matches every character outside [0-9A-Za-z_]. It is used to
// turn test names into plain identifiers, and is compiled once instead of on
// every Prepare call.
var nonWordChars = regexp.MustCompile(`\W`)

// pqIdentifierMaxLength is the maximum identifier length (in bytes) that
// PostgreSQL keeps with its default build settings; longer names get truncated.
const pqIdentifierMaxLength = 63

// installCloneSchemaErr holds the result of installing the clone_schema SQL
// function. It is written inside once.Do and read only after once.Do returns,
// so every test sees the real installation error. Otherwise later tests would
// only see a confusing "function does not exist" failure.
var installCloneSchemaErr error

// Prepare clones the public schema (structure only, without data) into a
// schema named after the running test. It returns a connection whose
// search_path is that new schema, and the schema name.
//
// The returned connection is closed via t.Cleanup. The cloned schema itself
// is left in place; clone_schema drops and recreates it the next time a test
// with the same name calls Prepare.
func Prepare(t *testing.T) (*sql.DB, string) {
	t.Helper()

	dsn := DSN()
	schema := schemaNameForTest(t.Name())

	cloneSchemaForTest(t, dsn, schema)

	db, err := sql.Open("postgres", dsn+" search_path="+schema)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, db.Close())
	})
	return db, schema
}

// schemaNameForTest turns a test name such as "TestFoo/case 1" into a
// lowercase identifier (e.g. "testfoo_case_1") of at most
// pqIdentifierMaxLength bytes.
func schemaNameForTest(testName string) string {
	name := strings.ToLower(nonWordChars.ReplaceAllString(testName, "_"))
	if len(name) <= pqIdentifierMaxLength {
		return name
	}
	// 30-byte prefix + "_" + 32 hex digits of the MD5 = exactly 63 bytes. The
	// digest covers the full name, so long names sharing a prefix almost
	// certainly still map to different schemas.
	prefix := name[:pqIdentifierMaxLength-md5.Size*2-1]
	return fmt.Sprintf("%s_%x", prefix, md5.Sum([]byte(name)))
}

// cloneSchemaForTest installs the clone_schema function (once per test binary)
// and uses it to recreate schema as a structure-only copy of public. The
// temporary connection pool is closed on every path, including when a require
// assertion aborts the test.
func cloneSchemaForTest(t *testing.T, dsn, schema string) {
	t.Helper()

	admin, err := sql.Open("postgres", dsn)
	require.NoError(t, err)
	defer func() {
		assert.NoError(t, admin.Close())
	}()

	once.Do(func() {
		// Do not use exec here: a require failure would abort inside once.Do,
		// and later callers would never learn why the function is missing.
		_, installCloneSchemaErr = admin.Exec(createFunctionCloneSchema)
	})
	require.NoError(t, installCloneSchemaErr, "installing clone_schema function")

	exec(t, admin, "SELECT clone_schema('public', $1);", schema)
}
// exec runs query with args on db and fails the test immediately if it
// returns an error. t.Helper makes the failure report the caller's line, and
// the message names the query that failed.
func exec(t *testing.T, db *sql.DB, query string, args ...interface{}) {
	t.Helper()
	_, err := db.Exec(query, args...)
	require.NoError(t, err, "executing query: %s", query)
}
// createFunctionCloneSchema is a SQL script that creates (or replaces) the
// PL/pgSQL function public.clone_schema(source_schema, dest_schema). The
// script runs in one explicit transaction that first takes a
// transaction-scoped advisory lock. Presumably this serializes concurrent
// installers, e.g. several test binaries running at once.
//
// clone_schema drops dest_schema (CASCADE) if it exists, recreates it, and
// copies from source_schema, without any row data:
//   - sequences, created fresh with default options;
//   - every relation listed in information_schema.tables, via CREATE TABLE
//     ... LIKE ... INCLUDING CONSTRAINTS INCLUDING INDEXES INCLUDING DEFAULTS,
//     with nextval defaults rewritten to point at dest_schema;
//   - triggers;
//   - foreign keys, added in a second pass once all tables exist.
//
// NOTE(review): identifiers are concatenated without quote_ident or
// format('%I'). Mixed-case or otherwise quoted names in source_schema will
// likely fail, and dest_schema must be a plain lowercase identifier (Prepare
// sanitizes it).
// NOTE(review): information_schema.tables also lists views, so views would be
// cloned as empty tables. Consider filtering on table_type = 'BASE TABLE'.
// NOTE(review): nextval defaults and REFERENCES targets are re-pointed only
// when written as "<source_schema>.<name>". If source_schema (e.g. public) is
// on the search_path, PostgreSQL probably renders those names unqualified. Then
// cloned columns and foreign keys would keep referring to source_schema
// objects. Confirm.
const createFunctionCloneSchema string = `--
BEGIN;
SELECT pg_advisory_xact_lock(2142616474639426746);
CREATE OR REPLACE FUNCTION public.clone_schema(source_schema text, dest_schema text)
RETURNS void AS
$BODY$
DECLARE
object text;
buffer text;
default_ text;
column_ text;
constraint_name_ text;
constraint_def_ text;
trigger_name_ text;
trigger_timing_ text;
trigger_events_ text;
trigger_orientation_ text;
trigger_action_ text;
BEGIN
-- replace existing schema
EXECUTE 'DROP SCHEMA IF EXISTS ' || dest_schema || ' CASCADE';
-- create schema
EXECUTE 'CREATE SCHEMA ' || dest_schema;
-- create sequences
FOR object IN
SELECT sequence_name::text
FROM information_schema.SEQUENCES
WHERE sequence_schema = source_schema
LOOP
EXECUTE 'CREATE SEQUENCE ' || dest_schema || '.' || object;
END LOOP;
-- create tables
FOR object IN
SELECT table_name::text FROM information_schema.TABLES WHERE table_schema = source_schema
LOOP
buffer := dest_schema || '.' || object;
-- create table
EXECUTE 'CREATE TABLE ' || buffer || ' (LIKE ' || source_schema || '.' || object ||
' INCLUDING CONSTRAINTS INCLUDING INDEXES INCLUDING DEFAULTS)';
-- fix sequence defaults
FOR column_, default_ IN
SELECT column_name::text,
REPLACE(column_default::text, source_schema || '.', dest_schema || '.')
FROM information_schema.COLUMNS
WHERE table_schema = dest_schema
AND table_name = object
AND column_default LIKE 'nextval(%' || source_schema || '.%::regclass)'
LOOP
EXECUTE 'ALTER TABLE ' || buffer || ' ALTER COLUMN ' || column_ ||
' SET DEFAULT ' || default_;
END LOOP;
-- create triggers
FOR
trigger_name_, trigger_timing_, trigger_events_, trigger_orientation_, trigger_action_ IN
SELECT trigger_name::text,
action_timing::text,
string_agg(event_manipulation::text, ' OR '),
action_orientation::text,
action_statement::text
FROM information_schema.TRIGGERS
WHERE event_object_schema = source_schema
AND event_object_table = object
GROUP BY trigger_name, action_timing, action_orientation, action_statement
LOOP
EXECUTE 'CREATE TRIGGER ' || trigger_name_ || ' ' || trigger_timing_ || ' ' ||
trigger_events_ || ' ON ' || buffer || ' FOR EACH ' ||
trigger_orientation_ || ' ' || trigger_action_;
END LOOP;
END LOOP;
-- reiterate tables and create foreign keys
FOR object IN
SELECT table_name::text FROM information_schema.TABLES WHERE table_schema = source_schema
LOOP
buffer := dest_schema || '.' || object;
-- create foreign keys
FOR constraint_name_, constraint_def_ IN
SELECT conname::text,
REPLACE(pg_get_constraintdef(pg_constraint.oid), source_schema || '.',
dest_schema || '.')
FROM pg_constraint
INNER JOIN pg_class ON conrelid = pg_class.oid
INNER JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace
WHERE contype = 'f'
AND relname = object
AND nspname = source_schema
LOOP
EXECUTE 'ALTER TABLE ' || buffer || ' ADD CONSTRAINT ' || constraint_name_ ||
' ' || constraint_def_;
END LOOP;
END LOOP;
END;
$BODY$
LANGUAGE plpgsql VOLATILE
COST 100;
COMMIT;
`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment