"""PostgreSQL schema cloner (including data)."""
from collections import defaultdict
from contextlib import contextmanager
from io import BytesIO

import psycopg2 as pg
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT as AUTO_COMMIT
from psycopg2.extensions import ISOLATION_LEVEL_READ_COMMITTED as READ_COMMIT
class SchemaCloner(object):
    """Clone a PostgreSQL schema (structure and data) into a new schema.

    Catalog metadata (tables, columns, sequences, indexes, primary keys and
    foreign-key constraints) is read lazily through properties and cached per
    schema name.  The schema being inspected is selected with the ``schema``
    property; ``clone()`` sets it to the source schema.

    NOTE: schema, table and column names are interpolated directly into SQL
    statements without quoting, so they must come from a trusted caller.
    """

    def __init__(self, dsn = None, *args, **kwargs):
        """Open the database connection.

        :param dsn: libpq connection string; when falsy, ``*args`` and
            ``**kwargs`` are passed to ``psycopg2.connect`` instead.
        """
        self.__connection = pg.connect(dsn) if dsn else pg.connect(*args, **kwargs)
        self.__cursor = None
        self.__schemas = None
        self.__schema = 'public'
        # Per-schema metadata caches, keyed by schema name.
        self.__tables = {}
        self.__columns = {}
        self.__constraints = {}
        self.__sequences = {}
        self.__indexes = {}
        self.__primary_keys = {}
        # Ensure we're using transactions.
        self.isolation = READ_COMMIT

    # ------------------------------------------------------------------
    # connection / transaction handling
    # ------------------------------------------------------------------

    @property
    def _cursor(self):
        """The single cursor used for every statement, created on first use."""
        if self.__cursor is None:
            self.__cursor = self.__connection.cursor()
        return self.__cursor

    @property
    def _connection(self):
        """The underlying psycopg2 connection."""
        return self.__connection

    @property
    def isolation(self):
        """Current isolation level of the connection."""
        return self._connection.isolation_level

    @isolation.setter
    def isolation(self, value):
        self._connection.set_isolation_level(value)

    @property
    def auto_commit(self):
        """Switch the connection to autocommit and return the new level."""
        self.isolation = AUTO_COMMIT
        return self.isolation

    @property
    def read_commit(self):
        """Switch the connection to READ COMMITTED and return the new level."""
        self.isolation = READ_COMMIT
        return self.isolation

    @contextmanager
    def isolation_context(self, level):
        """Run the ``with`` body under ``level``, then restore the old level.

        The previous level is restored even when the body raises.
        """
        original_level = self.isolation
        self.isolation = level
        try:
            yield
        finally:
            self.isolation = original_level

    def commit(self):
        """Commit the current transaction."""
        self._connection.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self._connection.rollback()

    # ------------------------------------------------------------------
    # schema selection and cached catalog metadata
    # ------------------------------------------------------------------

    @property
    def schema(self):
        """Name of the schema the metadata properties currently describe."""
        return self.__schema

    @schema.setter
    def schema(self, value):
        self.__schema = value

    @property
    def schemas(self):
        """Mapping of schema name -> ``(oid, owner role name)``."""
        if self.__schemas is None:
            results = self.query("""
                SELECT n.oid AS schema_id, n.nspname AS schema_name, r.rolname AS owner
                FROM   pg_namespace AS n
                JOIN   pg_roles     AS r ON n.nspowner = r.oid
            """)
            self.__schemas = {name: (oid, owner) for oid, name, owner in results}
        return self.__schemas

    @property
    def schema_oid(self):
        """OID of the current schema."""
        return self.schemas[self.schema][0]

    @property
    def schema_owner(self):
        """Role name owning the current schema."""
        return self.schemas[self.schema][1]

    @staticmethod
    def _group_by_table(rows, scalar=False):
        """Group result rows by their first column (the table name).

        Returns a ``defaultdict(list)`` so unknown tables yield ``[]``.  Each
        entry is the tuple of the remaining columns, or just the second
        column when ``scalar`` is true.  Row order is preserved.
        """
        grouped = defaultdict(list)
        for row in rows:
            grouped[row[0]].append(row[1] if scalar else tuple(row[1:]))
        return grouped

    @property
    def sequences(self):
        """Mapping of sequence name -> ``{table name: column name}``.

        Names are returned through ``quote_ident`` by the query.
        """
        if self.schema not in self.__sequences:
            rows = self.query("""
                SELECT quote_ident(S.relname) AS sequence_name,
                       quote_ident(T.relname) AS table_name,
                       quote_ident(C.attname) AS column_name
                FROM   pg_class AS S,
                       pg_depend AS D,
                       pg_class AS T,
                       pg_attribute AS C,
                       pg_tables AS PGT
                WHERE  S.relkind = 'S'
                  AND  S.oid = D.objid
                  AND  D.refobjid = T.oid
                  AND  D.refobjid = C.attrelid
                  AND  D.refobjsubid = C.attnum
                  AND  T.relname = PGT.tablename
                  AND  PGT.schemaname = %s
                ORDER BY sequence_name;
            """, (self.schema,))
            owners = defaultdict(dict)
            for sequence, table, column in rows:
                owners[sequence][table] = column
            self.__sequences[self.schema] = owners
        return self.__sequences[self.schema]

    @property
    def tables(self):
        """Mapping of table name -> ``relfilenode`` for the current schema."""
        if self.schema not in self.__tables:
            results = self.query("""
                SELECT relfilenode, relname
                FROM   pg_class
                WHERE  relnamespace = %s AND relkind = %s
            """, (self.schema_oid, 'r',))
            self.__tables[self.schema] = {name: node for node, name in results}
        return self.__tables[self.schema]

    @property
    def primary_keys(self):
        """Mapping of table name -> list of ``(constraint name, definition)``."""
        if self.schema not in self.__primary_keys:
            # if primaries haven't yet been loaded, get them all
            primaries = self.query("""
                SELECT pgct.relname AS table_name,
                       con.conname AS constraint_name,
                       pg_catalog.pg_get_constraintdef(con.oid) AS constraint_definition
                FROM   pg_catalog.pg_constraint AS con
                JOIN   pg_class AS pgct ON pgct.relnamespace = con.connamespace AND pgct.oid = con.conrelid
                WHERE  pgct.relnamespace = %s AND con.contype = %s;
            """, (self.schema_oid, 'p', ))
            self.__primary_keys[self.schema] = self._group_by_table(primaries)
        return self.__primary_keys[self.schema]

    @property
    def indexes(self):
        """Mapping of table name -> list of non-primary index definitions."""
        if self.schema not in self.__indexes:
            indexes = self.query("""
                SELECT pgct.relname AS table_name,
                       pg_catalog.pg_get_indexdef(pgi.indexrelid) AS index_definition
                FROM   pg_index pgi
                JOIN   pg_class AS pgci ON pgci.oid = pgi.indexrelid
                JOIN   pg_class AS pgct ON pgct.oid = pgi.indrelid
                WHERE  pgci.relnamespace = %s AND pgi.indisprimary = false
            """, (self.schema_oid,))
            self.__indexes[self.schema] = self._group_by_table(indexes, scalar=True)
        return self.__indexes[self.schema]

    @property
    def columns(self):
        """Mapping of table name -> list of ``(column name, column default)``."""
        if self.schema not in self.__columns:
            columns = self.query("""
                SELECT table_name, column_name, column_default
                FROM   information_schema.columns
                WHERE  table_schema = %s
            """, (self.schema,))
            self.__columns[self.schema] = self._group_by_table(columns)
        return self.__columns[self.schema]

    @property
    def constraints(self):
        """Mapping of table name -> list of foreign-key ``(name, definition)``."""
        if self.schema not in self.__constraints:
            # if constraints haven't yet been loaded, get them all
            constraints = self.query("""
                SELECT pgct.relname AS table_name,
                       con.conname AS constraint_name,
                       pg_catalog.pg_get_constraintdef(con.oid) AS constraint_definition
                FROM   pg_catalog.pg_constraint AS con
                JOIN   pg_class AS pgct ON pgct.relnamespace = con.connamespace AND pgct.oid = con.conrelid
                WHERE  pgct.relnamespace = %s AND con.contype = %s;
            """, (self.schema_oid, 'f', ))
            self.__constraints[self.schema] = self._group_by_table(constraints)
        return self.__constraints[self.schema]

    # ------------------------------------------------------------------
    # statement helpers
    # ------------------------------------------------------------------

    def query_one(self, sql, *args, **kwargs):
        """Execute ``sql`` (without echoing it) and return the first row."""
        self._cursor.execute(sql, *args, **kwargs)
        return self._cursor.fetchone()

    def query(self, sql, *args, **kwargs):
        """Execute ``sql`` and return all rows.

        On failure the statement and its arguments are printed before the
        original exception is re-raised.
        """
        try:
            self.execute(sql, *args, **kwargs)
            return self._cursor.fetchall()
        except Exception as e:
            print("Exception during query:  %s" % (e,))
            print("  sql   :  %s" % (sql,))
            print("  args  :  %s" % (args,))
            print("  kwargs:  %s" % (kwargs,))
            # Bare raise keeps the original traceback.
            raise

    def execute(self, sql, *args, **kwargs):
        """Print the statement as it will be sent to the server, then run it."""
        print(self._cursor.mogrify(sql, *args, **kwargs))
        self._cursor.execute(sql, *args, **kwargs)

    # ------------------------------------------------------------------
    # cloning
    # ------------------------------------------------------------------

    def clone(self, source, destination):
        """Copy schema ``source`` (structure and data) into new schema ``destination``.

        The work runs in one READ COMMITTED transaction which is committed on
        success and rolled back if any step raises.

        :param source: name of the existing schema to copy.
        :param destination: name of the schema to create.
        """
        with self.isolation_context(READ_COMMIT):
            self.schema = source
            try:
                self._create_schema(destination)
                self._create_sequences(destination)
                self._create_tables(source, destination)
                self._copy_data(source, destination)
                self._create_keys_and_indexes(source, destination)
                self._create_foreign_keys(source, destination)
                self._reset_sequences()
                # Commit before the isolation level is restored on leaving the
                # context. NOTE(review): set_isolation_level is understood to
                # discard a transaction still in progress - confirm against
                # the psycopg2 version in use.
                self.commit()
            except Exception:
                self.rollback()
                raise
        # and we're done...

    def _create_schema(self, destination):
        """Create the destination schema, owned like the source, and select it."""
        self.execute('CREATE SCHEMA %s' % destination)
        self.execute('ALTER SCHEMA %s OWNER TO "%s"' % (destination, self.schema_owner))
        # With the destination first on the search path, unqualified names in
        # later statements resolve to the new schema.
        self.execute('SET search_path = %s, pg_catalog' % destination)

    def _create_sequences(self, destination):
        """Create an empty counterpart of every source sequence."""
        for sequence in self.sequences:
            self.execute("CREATE SEQUENCE %s.%s" % (destination, sequence, ))

    def _create_tables(self, source, destination):
        """First table pass - create tables, defaults and ownerships."""
        for table in self.tables:
            self.execute('CREATE TABLE %s.%s (LIKE %s.%s INCLUDING DEFAULTS)' % (destination, table, source, table,))
            self.execute('ALTER TABLE %s.%s OWNER TO "%s"' % (destination, table, self.schema_owner,))

            # update sequences to use destination schema sequence instead of source
            for column, default_value in self.columns[table]:
                if not (default_value and default_value.startswith('nextval')):
                    continue
                default_value = default_value.replace('%s.' % source, '%s.' % destination)
                # The sequence name is the quoted argument of nextval('...').
                sequence_table = default_value.split("'")[1]
                self.execute('ALTER SEQUENCE %s OWNED BY %s.%s' % (sequence_table, table, column,))
                self.execute('ALTER TABLE ONLY %s ALTER COLUMN %s SET DEFAULT %s' % (table, column, default_value,))

    def _copy_data(self, source, destination):
        """Second table pass - copy data through an in-memory COPY buffer."""
        for table in self.tables:
            data = BytesIO()
            # copy_expert is used because copy_to/copy_from quote the table
            # name in psycopg2 >= 2.9, which breaks schema-qualified names.
            self._cursor.copy_expert(
                "COPY %s.%s TO STDOUT WITH DELIMITER AS '|'" % (source, table), data)
            size = data.tell()
            data.seek(0)  # rewind so the load reads from the beginning
            self._cursor.copy_expert(
                "COPY %s.%s FROM STDIN WITH DELIMITER AS '|'" % (destination, table), data)
            print("Copied %d bytes from %s.%s -> %s.%s" % (size, source, table, destination, table))

    def _create_keys_and_indexes(self, source, destination):
        """Third pass - create primary keys and indexes."""
        for table in self.tables:
            for key_name, key_definition in self.primary_keys[table]:
                key_definition = key_definition.replace('%s.' % source, '%s.' % destination)
                self.execute('ALTER TABLE ONLY %s ADD CONSTRAINT %s %s' % (table, key_name, key_definition))
            for index_definition in self.indexes[table]:
                index_definition = index_definition.replace('%s.' % source, '%s.' % destination)
                self.execute(index_definition)

    def _create_foreign_keys(self, source, destination):
        """Fourth pass - create constraints."""
        for table in self.tables:
            for constraint_name, constraint_definition in self.constraints[table]:
                constraint_definition = constraint_definition.replace('%s.' % source, '%s.' % destination)
                self.execute('ALTER TABLE ONLY %s ADD CONSTRAINT %s %s' % (table, constraint_name, constraint_definition))

    def _reset_sequences(self):
        """Fifth pass - fix sequences.

        Loading rows with COPY doesn't update the sequences, so each one is
        set from the current maximum of its owning column.
        """
        for sequence in self.sequences:
            for table, column in self.sequences[sequence].items():
                self.execute("""
                    SELECT setval('%s', (SELECT COALESCE(MAX(%s), 1) FROM %s), true)
                """.strip() % (sequence, column, table))
