Skip to content

Instantly share code, notes, and snippets.

@etes
Forked from rabbitt/schema_clone.py
Created April 4, 2018 08:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save etes/b3838c6af9e977db051a05534ad70f22 to your computer and use it in GitHub Desktop.
Save etes/b3838c6af9e977db051a05534ad70f22 to your computer and use it in GitHub Desktop.
PostgreSQL schema cloner (including data).
import re
from collections import defaultdict
from contextlib import contextmanager
from io import BytesIO

import psycopg2 as pg
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED
# Short aliases for the psycopg2 isolation-level constants imported above;
# SchemaCloner uses them when switching the connection's isolation level.
READ_COMMIT = ISOLATION_LEVEL_READ_COMMITTED
AUTO_COMMIT = ISOLATION_LEVEL_AUTOCOMMIT
class SchemaCloner(object):
    """Clone a PostgreSQL schema (structure and data) into a new schema.

    The cloner owns a single psycopg2 connection and cursor.  Catalog
    metadata (tables, columns, sequences, keys, indexes, constraints) is
    loaded lazily and cached per schema name; the schema being inspected is
    selected through the ``schema`` property.

    NOTE: schema, table and column names are interpolated into SQL text
    verbatim.  Callers must only pass trusted identifiers that are valid
    without additional quoting.
    """

    def __init__(self, dsn=None, *args, **kwargs):
        """Open the connection.

        ``dsn`` is passed to ``psycopg2.connect`` when truthy; otherwise the
        remaining positional and keyword arguments are forwarded instead.
        """
        self.__connection = pg.connect(dsn) if dsn else pg.connect(*args, **kwargs)
        self.__cursor = None
        self.__schemas = None
        self.__schema = 'public'
        # Per-schema metadata caches, keyed by schema name.
        self.__tables = {}
        self.__columns = {}
        self.__constraints = {}
        self.__sequences = {}
        self.__indexes = {}
        self.__primary_keys = {}
        self.isolation = READ_COMMIT  # ensure we're using transactions

    # ------------------------------------------------------------------
    # connection plumbing
    # ------------------------------------------------------------------
    @property
    def _cursor(self):
        """The shared cursor, created on first use."""
        if self.__cursor is None:
            self.__cursor = self.__connection.cursor()
        return self.__cursor

    @property
    def _connection(self):
        """The underlying psycopg2 connection."""
        return self.__connection

    @property
    def isolation(self):
        """Current isolation level of the connection (assignable)."""
        return self._connection.isolation_level

    @isolation.setter
    def isolation(self, value):
        self._connection.set_isolation_level(value)

    @property
    def auto_commit(self):
        """Switch the connection to autocommit and return the new level."""
        self.isolation = AUTO_COMMIT
        return self.isolation

    @property
    def read_commit(self):
        """Switch the connection to READ COMMITTED and return the new level."""
        self.isolation = READ_COMMIT
        return self.isolation

    @contextmanager
    def isolation_context(self, level):
        """Run the ``with`` body at ``level``, then restore the prior level."""
        original_level = self.isolation
        try:
            self.isolation = level
            yield
        finally:
            self.isolation = original_level

    # ------------------------------------------------------------------
    # schema selection and catalog metadata
    # ------------------------------------------------------------------
    @property
    def schema(self):
        """Name of the schema the metadata properties describe."""
        return self.__schema

    @schema.setter
    def schema(self, value):
        self.__schema = value

    @property
    def schemas(self):
        """Map of schema name -> ``(oid, owner role name)``, loaded once."""
        if not self.__schemas:
            results = self.query("""
                SELECT n.oid AS schema_id, n.nspname AS schema_name, r.rolname AS owner
                FROM pg_namespace AS n
                JOIN pg_roles AS r ON n.nspowner = r.oid
            """)
            self.__schemas = dict(
                (_name, (_id, _owner)) for _id, _name, _owner in results
            )
        return self.__schemas

    @property
    def schema_oid(self):
        """OID of the current schema; ``KeyError`` if it does not exist."""
        return self.schemas[self.schema][0]

    @property
    def schema_owner(self):
        """Owner role of the current schema; ``KeyError`` if it does not exist."""
        return self.schemas[self.schema][1]

    @property
    def sequences(self):
        """Map of sequence name -> ``{table name: column name}``.

        Names are returned through ``quote_ident``.  Unknown keys yield an
        empty dict.
        """
        if self.schema not in self.__sequences:
            # NOTE(review): the table is matched to the schema by *name* via
            # pg_tables, so a same-named table in another schema could also
            # match -- confirm whether T.relnamespace should be checked.
            rows = self.query("""
                SELECT quote_ident(S.relname) AS sequence_name,
                       quote_ident(T.relname) AS table_name,
                       quote_ident(C.attname) AS column_name
                FROM pg_class AS S,
                     pg_depend AS D,
                     pg_class AS T,
                     pg_attribute AS C,
                     pg_tables AS PGT
                WHERE S.relkind = 'S'
                  AND S.oid = D.objid
                  AND D.refobjid = T.oid
                  AND D.refobjid = C.attrelid
                  AND D.refobjsubid = C.attnum
                  AND T.relname = PGT.tablename
                  AND PGT.schemaname = %s
                ORDER BY sequence_name;
            """, (self.schema,))
            owners = defaultdict(dict)
            for seq, tbl, col in rows:
                owners[seq] = {tbl: col}
            self.__sequences[self.schema] = owners
        return self.__sequences[self.schema]

    @property
    def tables(self):
        """Map of table name -> ``relfilenode`` for ordinary tables."""
        if self.schema not in self.__tables:
            results = self.query("""
                SELECT relfilenode, relname
                FROM pg_class
                WHERE relnamespace = %s AND relkind = %s
            """, (self.schema_oid, 'r',))
            self.__tables[self.schema] = dict((_name, _id) for _id, _name in results)
        return self.__tables[self.schema]

    @property
    def primary_keys(self):
        """Map of table name -> list of ``(constraint name, definition)``."""
        if self.schema not in self.__primary_keys:
            self.__primary_keys[self.schema] = self._constraints_of_type('p')
        return self.__primary_keys[self.schema]

    @property
    def indexes(self):
        """Map of table name -> list of ``CREATE INDEX`` statements.

        Indexes backing primary keys are excluded.
        """
        if self.schema not in self.__indexes:
            # The cache entry is only written after the query succeeded, so a
            # failed query cannot leave an empty result cached.
            rows = self.query("""
                SELECT pgct.relname AS table_name,
                       pg_catalog.pg_get_indexdef(pgi.indexrelid) AS index_definition
                FROM pg_index pgi
                JOIN pg_class AS pgci ON pgci.oid = pgi.indexrelid
                JOIN pg_class AS pgct ON pgct.oid = pgi.indrelid
                WHERE pgci.relnamespace = %s AND pgi.indisprimary = false
            """, (self.schema_oid,))
            self.__indexes[self.schema] = self._group_rows(rows, lambda r: r[1])
        return self.__indexes[self.schema]

    @property
    def columns(self):
        """Map of table name -> list of ``(column name, column default)``."""
        if self.schema not in self.__columns:
            rows = self.query("""
                SELECT table_name, column_name, column_default
                FROM information_schema.columns
                WHERE table_schema = %s
            """, (self.schema,))
            self.__columns[self.schema] = self._group_rows(
                rows, lambda r: (r[1], r[2]))
        return self.__columns[self.schema]

    @property
    def constraints(self):
        """Map of table name -> list of foreign-key ``(name, definition)``."""
        if self.schema not in self.__constraints:
            self.__constraints[self.schema] = self._constraints_of_type('f')
        return self.__constraints[self.schema]

    def _constraints_of_type(self, contype):
        """Load constraints of ``pg_constraint.contype`` grouped by table."""
        rows = self.query("""
            SELECT pgct.relname AS table_name,
                   con.conname AS constraint_name,
                   pg_catalog.pg_get_constraintdef(con.oid) AS constraint_definition
            FROM pg_catalog.pg_constraint AS con
            JOIN pg_class AS pgct ON pgct.relnamespace = con.connamespace AND pgct.oid = con.conrelid
            WHERE pgct.relnamespace = %s AND con.contype = %s;
        """, (self.schema_oid, contype,))
        return self._group_rows(rows, lambda r: (r[1], r[2]))

    @staticmethod
    def _group_rows(rows, value):
        """Group ``rows`` by their first column in one pass.

        Returns a ``defaultdict(list)`` mapping ``row[0]`` to the list of
        ``value(row)`` in row order; unknown keys yield an empty list.
        """
        grouped = defaultdict(list)
        for row in rows:
            grouped[row[0]].append(value(row))
        return grouped

    @staticmethod
    def _requalify(definition, source, destination):
        """Rewrite ``source.`` schema qualifiers in ``definition``.

        Only occurrences that are not the tail of a longer identifier are
        replaced, so with source ``data`` the text ``metadata.x`` is left
        untouched while ``data.x`` becomes ``<destination>.x``.
        """
        pattern = r'(?<![\w$])' + re.escape(source) + r'\.'
        # A callable replacement keeps backslashes in ``destination`` literal.
        return re.sub(pattern, lambda _match: destination + '.', definition)

    # ------------------------------------------------------------------
    # query helpers
    # ------------------------------------------------------------------
    def query_one(self, sql, *args, **kwargs):
        """Execute ``sql`` and return the first result row."""
        self._cursor.execute(sql, *args, **kwargs)
        return self._cursor.fetchone()

    def query(self, sql, *args, **kwargs):
        """Execute ``sql`` and return all rows.

        On failure the statement and its arguments are printed and the
        original exception is re-raised with its traceback intact.
        """
        try:
            self.execute(sql, *args, **kwargs)
            return self._cursor.fetchall()
        except Exception as e:
            print("Exception during query:  %s" % (e,))
            print(" sql :  %s" % (sql,))
            print(" args :  %s" % (args,))
            print(" kwargs:  %s" % (kwargs,))
            raise

    def execute(self, sql, *args, **kwargs):
        """Print the statement as it will be sent, then execute it."""
        print(self._cursor.mogrify(sql, *args, **kwargs))
        self._cursor.execute(sql, *args, **kwargs)

    def commit(self):
        """Commit the current transaction."""
        self._connection.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self._connection.rollback()

    # ------------------------------------------------------------------
    # cloning
    # ------------------------------------------------------------------
    def clone(self, source, destination):
        """Copy schema ``source`` into a newly created schema ``destination``.

        Runs as one READ COMMITTED transaction that is committed at the end
        and rolled back (exception re-raised) if any step fails.  Side
        effects that remain afterwards: ``self.schema`` is set to ``source``
        and the session ``search_path`` is set to ``destination, pg_catalog``.

        Raises ``KeyError`` if ``source`` is not an existing schema.
        """
        with self.isolation_context(READ_COMMIT):
            self.schema = source
            try:
                # Resolve the owner first so an unknown source schema fails
                # before anything has been created.
                owner = self.schema_owner

                # create schema
                self.execute('CREATE SCHEMA %s' % destination)
                self.execute('ALTER SCHEMA %s OWNER TO "%s"' % (destination, owner))
                self.execute('SET search_path = %s, pg_catalog' % destination)

                # create sequences
                for sequence in self.sequences:
                    self.execute("CREATE SEQUENCE %s.%s" % (destination, sequence,))

                table_names = list(self.tables)

                # first table pass - create tables, defaults and ownerships
                for table in table_names:
                    self.execute(
                        'CREATE TABLE %s.%s (LIKE %s.%s INCLUDING DEFAULTS)'
                        % (destination, table, source, table,))
                    self.execute(
                        'ALTER TABLE %s.%s OWNER TO "%s"' % (destination, table, owner,))

                    # point nextval() defaults at the destination sequences
                    # instead of the source ones copied by LIKE
                    serial_columns = [
                        (name, default)
                        for name, default in self.columns[table]
                        if default and default.startswith('nextval')
                    ]
                    for column, default_value in serial_columns:
                        default_value = self._requalify(default_value, source, destination)
                        # the sequence name is the first quoted token of the default
                        sequence_name = default_value.split("'")[1]
                        self.execute(
                            'ALTER SEQUENCE %s OWNED BY %s.%s'
                            % (sequence_name, table, column,))
                        self.execute(
                            'ALTER TABLE ONLY %s ALTER COLUMN %s SET DEFAULT %s'
                            % (table, column, default_value,))

                # second table pass - copy data
                for table in table_names:
                    data = BytesIO()
                    # copy_expert is used because copy_to/copy_from quote the
                    # table name in psycopg2 >= 2.9 and therefore reject
                    # schema-qualified names.
                    self._cursor.copy_expert(
                        "COPY %s.%s TO STDOUT WITH DELIMITER '|'" % (source, table), data)
                    size = data.tell()
                    data.seek(0)
                    self._cursor.copy_expert(
                        "COPY %s.%s FROM STDIN WITH DELIMITER '|'" % (destination, table), data)
                    print("Copied %d bytes from %s.%s -> %s.%s"
                          % (size, source, table, destination, table))

                # third pass - create primary keys and indexes
                for table in table_names:
                    for key_name, key_definition in self.primary_keys[table]:
                        key_definition = self._requalify(key_definition, source, destination)
                        self.execute(
                            'ALTER TABLE ONLY %s ADD CONSTRAINT %s %s'
                            % (table, key_name, key_definition))
                    for index_definition in self.indexes[table]:
                        self.execute(
                            self._requalify(index_definition, source, destination))

                # fourth pass - create constraints
                for table in table_names:
                    for constraint_name, constraint_definition in self.constraints[table]:
                        constraint_definition = self._requalify(
                            constraint_definition, source, destination)
                        self.execute(
                            'ALTER TABLE ONLY %s ADD CONSTRAINT %s %s'
                            % (table, constraint_name, constraint_definition))

                # fifth pass - fix sequences. Bulk-copying rows doesn't
                # advance the sequences, so we do that here.
                for sequence, owners in list(self.sequences.items()):
                    for table, column in owners.items():
                        self.execute(
                            "SELECT setval('%s', (SELECT COALESCE(MAX(%s), 1) FROM %s), true)"
                            % (sequence, column, table))

                # and we're done...
                self.commit()
            except Exception:
                # Roll back explicitly so the failed transaction is not left
                # open when the isolation level is restored on exit.
                self.rollback()
                raise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment