Skip to content

Instantly share code, notes, and snippets.

@cliffxuan
Created November 10, 2012 21:43
Show Gist options
  • Save cliffxuan/4052632 to your computer and use it in GitHub Desktop.
Save cliffxuan/4052632 to your computer and use it in GitHub Desktop.
Introspect the constraints on a table's primary key across the entire database
import unittest
import itertools
import sqlalchemy as sa
from sqlalchemy import Table, Column, String, ForeignKey
from newman.schema.branch import tables
class Introspector(object):
    """Inspects how the tables of a sqlalchemy MetaData reference one table.

    For a given target table it finds the foreign keys pointing at it and
    the primary key / unique constraints those foreign key columns are part
    of, and it generates the select / update statements that touch the
    referencing columns.
    """

    # one statement template per action understood by construct_sql
    _SQL_TEMPLATES = {
        'update': 'update %(table)s '
                  'set %(col)s = :new_id '
                  'where %(col)s = :old_id',
        'select': 'select %(col)s from %(table)s '
                  'where %(col)s = :id',
    }

    def __init__(self, metadata):
        """params:
        metadata -- sqlalchemy MetaData object whose tables are inspected
        """
        self.metadata = metadata

    @property
    def all_tables(self):
        """all table objects registered on the metadata"""
        return self.metadata.tables.values()

    @staticmethod
    def get_primary_key(table):
        """takes a sqlalchemy table object and
        returns the primary key string names.
        params:
        table -- sqlalchemy table object
        return:
        a list in this format [col_1, col_2]
        e.g. for CustomerMembership table ['id']
        """
        return table.primary_key.columns.keys()

    @staticmethod
    def construct_sql(action, table, col):
        """builds a sql statement with named bind parameters
        params:
        action -- 'update' (binds :new_id and :old_id) or
                  'select' (binds :id)
        table -- table name string
        col -- column name string
        returns:
        sql string
        raises:
        ValueError if action is neither 'update' nor 'select'
        """
        try:
            template = Introspector._SQL_TEMPLATES[action]
        except KeyError:
            # the original fell through and returned None, which only
            # surfaced later when the "statement" was executed
            raise ValueError('unsupported sql action: %r' % (action,))
        return template % dict(table=table, col=col)

    @staticmethod
    def _dedupe(items):
        """returns items without repeats, keeping first-seen order.

        Objects are compared by identity because sqlalchemy columns
        overload == to build sql expressions.
        """
        seen_ids = set()
        unique_items = []
        for item in items:
            if id(item) not in seen_ids:
                seen_ids.add(id(item))
                unique_items.append(item)
        return unique_items

    @staticmethod
    def _columns_referencing(constraints, table):
        """returns the columns of the given constraints that carry a
        foreign key to table, each column once, in constraint order
        """
        referencing = []
        for constraint in constraints:
            for col in constraint.columns:
                if any(fk.column.table == table for fk in col.foreign_keys):
                    referencing.append(col)
        return Introspector._dedupe(referencing)

    def find_foreign_keys_to_table(self, table):
        """takes a sqlalchemy table object and returns a list of
        foreign keys referencing this table
        params:
        table -- sqlalchemy table object
        returns:
        a list of sqlalchemy ForeignKey objects
        """
        foreign_key_list = []
        for candidate in self.all_tables:
            for fk in candidate.foreign_keys:
                if fk.column.table == table:
                    foreign_key_list.append(fk)
        return foreign_key_list

    def find_related_primary_key(self, table):
        """takes a sqlalchemy table object and returns a list of
        primary keys in the db where constituent column(s) also
        refers to the table
        params:
        table -- sqlalchemy table object
        returns:
        a list of sqlalchemy PrimaryKeyConstraint objects, each one once,
        in the order they were first found
        """
        related_primary_key_list = []
        for foreign_key in self.find_foreign_keys_to_table(table):
            pk = foreign_key.parent.table.primary_key
            if pk.columns.contains_column(foreign_key.parent):
                related_primary_key_list.append(pk)
        return self._dedupe(related_primary_key_list)

    def find_related_unique_constraints(self, table):
        """takes a sqlalchemy table object and returns a list of
        unique constraints in the db where constituent column(s) also
        refers to the table
        params:
        table -- sqlalchemy table object
        returns:
        a list of sqlalchemy UniqueConstraint objects, each one once,
        in the order they were first found
        """
        related_unique_constraints = []
        for foreign_key in self.find_foreign_keys_to_table(table):
            for con in foreign_key.parent.table.constraints:
                if not isinstance(con, sa.schema.UniqueConstraint):
                    continue
                if con.columns.contains_column(foreign_key.parent):
                    related_unique_constraints.append(con)
        # a constraint holding several foreign keys to the table would
        # otherwise be reported once per foreign key
        return self._dedupe(related_unique_constraints)

    def find_foreign_keys_with_other_constraint(self, table):
        """takes a sqlalchemy table object and returns a list of
        foreign keys in the db where constituent column(s) is
        a constituency of either a primary key constraint or a
        unique constraint
        params:
        table -- sqlalchemy table object
        returns:
        a list of sqlalchemy ForeignKey objects
        """
        fk_list = self.find_foreign_keys_to_table(table)
        free_fk_list = self.find_foreign_keys_without_other_constraint(table)
        return [fk for fk in fk_list if fk not in free_fk_list]

    def find_foreign_keys_without_other_constraint(self, table):
        """takes a sqlalchemy table object and returns a list of
        foreign keys in the db where constituent column(s) is not
        a constituency of either a primary key constraint or a
        unique constraint
        params:
        table -- sqlalchemy table object
        returns:
        a list of sqlalchemy ForeignKey objects
        """
        constraints = (self.find_related_primary_key(table)
                       + self.find_related_unique_constraints(table))
        # identity comparison: Column == Column builds a sql expression
        constrained_col_ids = set()
        for constraint in constraints:
            for col in constraint.columns:
                constrained_col_ids.add(id(col))
        return [fk for fk in self.find_foreign_keys_to_table(table)
                if id(fk.parent) not in constrained_col_ids]

    def _generate_sql(self, action, table, with_comment):
        """shared implementation of generate_select_sql and
        generate_update_sql
        params:
        action -- 'select' or 'update', passed on to construct_sql
        table -- sqlalchemy table object
        with_comment -- boolean, emit a sql comment before each section
        returns:
        a list of strings: unconstrained foreign key columns first, then
        columns inside primary keys, then columns inside unique constraints
        """
        free_fk_list = self.find_foreign_keys_without_other_constraint(table)
        sections = [
            ('the following queries are safe to run '
             'without violating any constraint',
             self._dedupe([fk.parent for fk in free_fk_list])),
            ('the following queries may violat primary key constraint',
             self._columns_referencing(
                 self.find_related_primary_key(table), table)),
            ('the following queries may violat unique constraint',
             self._columns_referencing(
                 self.find_related_unique_constraints(table), table)),
        ]
        output = []
        for comment, columns in sections:
            if with_comment:
                output.append('\n/*\n%s\n*/' % comment)
            for col in columns:
                output.append(
                    self.construct_sql(action, col.table.name, col.name))
        return output

    def generate_select_sql(self, table, with_comment=False):
        """takes a sqlalchemy table object and returns the sql statements
        for selecting related data from related tables to the object
        params:
        table -- sqlalchemy table object
        with_comment -- boolean
        returns:
        a list of sql strings (and sql comments when with_comment is set);
        every statement binds :id
        """
        return self._generate_sql('select', table, with_comment)

    def generate_update_sql(self, table, with_comment=False):
        """takes a sqlalchemy table object and returns the sql statements
        for re-pointing related data to another row of the object's table
        params:
        table -- sqlalchemy table object
        with_comment -- boolean
        returns:
        a list of sql strings (and sql comments when with_comment is set);
        every statement binds :old_id and :new_id
        """
        return self._generate_sql('update', table, with_comment)
class MergeError(Exception):
    """Raised when two data objects cannot be merged."""
    pass
class Merger(object):
    """Finds what stands in the way of merging two rows of one table.

    The table must have a single-column primary key; the two rows are
    identified by their values for that column.
    """

    def __init__(self, conn, introspector, table):
        """params:
        conn -- sqlalchemy connection the lookup queries are run on
        introspector -- instance of Introspector
        table -- sqlalchemy table holding the rows to merge
        """
        self.conn = conn
        self.table = table
        self.introspector = introspector

    @property
    def pk_column(self):
        """sole primary key column. if multiple primary key cols are found
        a MergeError is raised, as merge only happens on top level objects.
        a MergeError is also raised when the table has no primary key."""
        columns = self.table.primary_key.columns
        key_list = list(columns.keys())
        if len(key_list) > 1:
            raise MergeError('Table "%s" has multi primary columns.'
                             % self.table.name)
        if not key_list:
            # previously an IndexError escaped from key_list[0]
            raise MergeError('Table "%s" has no primary key column.'
                             % self.table.name)
        return columns[key_list[0]]

    @staticmethod
    def _table_of(constraint):
        """table owning the constraint, read from its first column because
        constraint.table is not always bound (sqlalchemy 0.6.5)"""
        return next(iter(constraint.columns)).table

    @staticmethod
    def _matching_rows(one_list, other_list, compare_cols, id_cols):
        """pairs up rows of both lists that agree on every compare column
        returns:
        dict mapping the shared compare values (tuple) to a pair of
        id tuples (ids of the row in one_list, ids of the row in other_list)
        """
        matches = {}
        for one, other in itertools.product(one_list, other_list):
            if all(one[col] == other[col] for col in compare_cols):
                shared = tuple([one[col] for col in compare_cols])
                matches[shared] = (tuple([one[col] for col in id_cols]),
                                   tuple([other[col] for col in id_cols]))
        return matches

    def _fetch_rows(self, select_cols, table_name, where_col, row_id):
        """selects select_cols from table_name where where_col = row_id"""
        # sa.text with a named parameter works for every DBAPI paramstyle,
        # unlike a raw '?' placeholder
        sql = sa.text('select %s from %s where %s = :id'
                      % (', '.join(select_cols), table_name, where_col))
        return self.conn.execute(sql, id=row_id).fetchall()

    def find_conflict(self, id_1, id_2):
        """finds everything that differs or would collide if the row with
        id_1 and the row with id_2 were merged into one
        params:
        id_1 -- primary key value of the first row
        id_2 -- primary key value of the second row
        returns:
        a dict with three kinds of keys
        UniqueConstraint / PrimaryKeyConstraint of a referencing table --
            dict mapping the values both rows share on the remaining
            constraint columns to a pair of id tuples of the two
            colliding referencing rows
        Column of the table itself --
            tuple (value in row id_1, value in row id_2) when they differ
        an empty dict means no conflict was found
        raises:
        MergeError if the table has no single primary key column or one of
        the two rows does not exist
        """
        pk_column = self.pk_column
        conflict = {}

        related_uc_list = self.introspector.find_related_unique_constraints(
            self.table)
        for uc in related_uc_list:
            related_table = self._table_of(uc)
            # the referencing column has to be looked up inside the
            # constraint itself: the related table may hold other foreign
            # keys to this table that are not part of the constraint
            where_col = next(
                col.name for col in uc.columns
                if any(fk.column is pk_column for fk in col.foreign_keys))
            compare_cols = [key for key in uc.columns.keys()
                            if key != where_col]
            id_cols = related_table.primary_key.columns.keys()
            select_cols = id_cols + compare_cols
            records_1 = self._fetch_rows(
                select_cols, related_table.name, where_col, id_1)
            records_2 = self._fetch_rows(
                select_cols, related_table.name, where_col, id_2)
            matches = self._matching_rows(
                records_1, records_2, compare_cols, id_cols)
            if matches:
                conflict[uc] = matches

        related_pk_list = self.introspector.find_related_primary_key(
            self.table)
        for pk in related_pk_list:
            table_name = self._table_of(pk).name
            where_col = next(
                col.name for col in pk.columns
                if any(fk.column.table == self.table
                       for fk in col.foreign_keys))
            pk_cols = pk.columns.keys()
            compare_cols = [col for col in pk_cols if col != where_col]
            records_1 = self._fetch_rows(pk_cols, table_name, where_col, id_1)
            records_2 = self._fetch_rows(pk_cols, table_name, where_col, id_2)
            matches = self._matching_rows(
                records_1, records_2, compare_cols, pk_cols)
            if matches:
                conflict[pk] = matches

        sql = sa.text('select * from %s where %s = :id'
                      % (self.table.name, pk_column.name))
        rec_1 = self.conn.execute(sql, id=id_1).fetchone()
        rec_2 = self.conn.execute(sql, id=id_2).fetchone()
        for row_id, rec in ((id_1, rec_1), (id_2, rec_2)):
            if rec is None:
                raise MergeError('Table "%s" has no row with id "%s".'
                                 % (self.table.name, row_id))
        for col in self.table.columns:
            # identity check: Column != Column builds a sql expression
            if col is pk_column:
                continue
            v_1 = rec_1[col.name]
            v_2 = rec_2[col.name]
            if v_1 != v_2:
                conflict[col] = (v_1, v_2)
        return conflict
class DBIndependent(unittest.TestCase):
    """Introspector tests that read table metadata only, no database.

    The expectations are tied to the tables defined in the project's
    branch schema module.
    """

    def setUp(self):
        """creates an Introspector over the branch schema metadata"""
        self.introspector = Introspector(tables.metadata)

    @staticmethod
    def _table_of(constraint):
        """returns the table a constraint belongs to"""
        try:
            return constraint.table
        except sa.exc.InvalidRequestError:
            #for sqlalchemy 0.6.5
            #this error occurs:
            #InvalidRequestError: This constraint is not bound to a table.
            #Did you mean to call table.add_constraint(constraint)?
            return next(iter(constraint.columns)).table

    def test_find_foreign_keys_to_table(self):
        t = tables.customer_table
        foreign_key_list = self.introspector.find_foreign_keys_to_table(t)
        dep_tables = [fk.parent.table for fk in foreign_key_list]
        print([fk.parent.name for fk in foreign_key_list])
        self.assertTrue(tables.customer_membership_table in dep_tables)
        self.assertTrue(tables.direct_debit_table in dep_tables)
        self.assertTrue(tables.customer_reservation_table in dep_tables)

    def test_multi_primary_keys(self):
        # exploratory: lists foreign keys that point at tables with a
        # composite primary key; nothing is asserted
        for t in self.introspector.all_tables:
            if len(t.primary_key.columns) > 1:
                deps = self.introspector.find_foreign_keys_to_table(t)
                if deps:
                    print(deps)

    def test_find_related_primary_key(self):
        t = tables.customer_table
        pk_list = self.introspector.find_related_primary_key(t)
        table_list = [self._table_of(pk) for pk in pk_list]
        print([tb.name for tb in table_list])
        self.assertTrue(tables.customer_reservation_table in table_list)

    def test_find_related_unique_constraints(self):
        t = tables.customer_table
        related_uc_list = self.introspector.find_related_unique_constraints(t)
        related_tables = [self._table_of(uc) for uc in related_uc_list]
        print([tb.name for tb in related_tables])
        self.assertTrue(tables.customer_membership_table in related_tables)
        self.assertTrue(tables.direct_debit_table in related_tables)
        self.assertTrue(
            tables.customer_product_credit_table not in related_tables)

    def test_find_foreign_keys_with_other_constraint(self):
        t = tables.customer_table
        non_free_fk_list = \
            self.introspector.find_foreign_keys_with_other_constraint(t)
        customer_id = tables.customer_membership_table.columns['customer_id']
        fk = next(k for k in customer_id.foreign_keys
                  if isinstance(k, ForeignKey))
        self.assertTrue(fk in non_free_fk_list)

    def test_find_foreign_keys_without_other_constraint(self):
        t = tables.customer_table
        free_fk_list = \
            self.introspector.find_foreign_keys_without_other_constraint(t)
        dep_tables = [fk.parent.table for fk in free_fk_list]
        print(['.'.join([fk.parent.table.name, fk.parent.name])
               for fk in free_fk_list])
        self.assertTrue(tables.customer_product_credit_table in dep_tables)

    def test_generate_update_sql(self):
        # smoke test: prints the generated statements for manual inspection
        t = tables.customer_table
        print('\n'.join(
            self.introspector.generate_update_sql(t, with_comment=True)))

    def test_generate_select_sql(self):
        # smoke test: prints the generated statements for manual inspection
        t = tables.customer_table
        print('\n'.join(
            self.introspector.generate_select_sql(t, with_comment=True)))
class SQLite(unittest.TestCase):
    """Merger.find_conflict test against an in-memory sqlite database.

    Two customers (c1, c2) share membership m1 and subscription s1 and
    differ in name, so merging them conflicts on the CustomerMembership
    unique constraint, the CustomerSubscription primary key and the
    Customer.name column.
    """

    # rows loaded by setUp, executed in this order
    FIXTURE_STATEMENTS = (
        "insert into Customer(id, name) values('c1', 'foo')",
        "insert into Customer(id, name) values('c2', 'bar')",
        "insert into Membership(id, name) values('m1', 'mem1')",
        "insert into Membership(id, name) values('m2', 'mem2')",
        "insert into Membership(id, name) values('m3', 'mem3')",
        "insert into Subscription(id, name) values('s1', 'sub1')",
        "insert into Subscription(id, name) values('s2', 'sub2')",
        "insert into Subscription(id, name) values('s3', 'sub3')",
        "insert into CustomerMembership(id, customer_id, membership_id) values('cm11', 'c1', 'm1')",
        "insert into CustomerMembership(id, customer_id, membership_id) values('cm12', 'c1', 'm2')",
        "insert into CustomerMembership(id, customer_id, membership_id) values('cm21', 'c2', 'm1')",
        "insert into CustomerMembership(id, customer_id, membership_id) values('cm23', 'c2', 'm3')",
        "insert into CustomerSubscription(customer_id, subscription_id) values('c1', 's1')",
        "insert into CustomerSubscription(customer_id, subscription_id) values('c1', 's2')",
        "insert into CustomerSubscription(customer_id, subscription_id) values('c2', 's1')",
        "insert into CustomerSubscription(customer_id, subscription_id) values('c2', 's3')",
    )

    def setUp(self):
        """creates the schema and fixture rows in a fresh in-memory db"""
        metadata = sa.MetaData()
        self.c_table = Table(
            'Customer',
            metadata,
            Column("id", String(3), primary_key=True),
            Column("name", String(4), nullable=False),
        )
        self.m_table = Table(
            'Membership',
            metadata,
            Column("id", String(2), primary_key=True),
            Column("name", String(4), nullable=False),
        )
        self.cm_table = Table(
            'CustomerMembership',
            metadata,
            Column("id", String(4), primary_key=True),
            Column("customer_id", String(3),
                   ForeignKey('Customer.id', use_alter=True,
                              name='FK_CustomerMembership_Customer'),
                   nullable=False),
            Column("membership_id", String(2),
                   ForeignKey('Membership.id', use_alter=True,
                              name='FK_CustomerMembership_Membership'),
                   nullable=False),
            sa.schema.UniqueConstraint(
                'customer_id', 'membership_id',
                name='uCustomerMembership_customer_id-CustomerMembership_membership_id')
        )
        self.s_table = Table(
            'Subscription',
            metadata,
            Column("id", String(2), primary_key=True),
            Column("name", String(4), nullable=False),
        )
        self.cs_table = Table(
            'CustomerSubscription',
            metadata,
            Column("subscription_id", String(3),
                   ForeignKey('Subscription.id', use_alter=True,
                              name='FK_CustomerSubscription_Subscription'),
                   nullable=False, primary_key=True),
            Column("customer_id", String(3),
                   ForeignKey('Customer.id', use_alter=True,
                              name='FK_CustomerSubscription_Customer'),
                   nullable=False, primary_key=True),
        )
        self.engine = sa.create_engine('sqlite:///:memory:')
        for tb in metadata.tables.values():
            tb.create(self.engine)
        self.conn = self.engine.connect()
        for statement in self.FIXTURE_STATEMENTS:
            self.conn.execute(statement)
        self.introspector = Introspector(metadata)

    def tearDown(self):
        """releases the connection and engine opened by setUp"""
        self.conn.close()
        self.engine.dispose()

    def test_find_conflict(self):
        merger = Merger(self.conn, self.introspector, self.c_table)
        uc = next(c for c in self.cm_table.constraints
                  if isinstance(c, sa.schema.UniqueConstraint))
        pk = self.cs_table.primary_key
        name_col = self.c_table.columns['name']
        conflict = merger.find_conflict('c1', 'c2')
        self.assertTrue(uc in conflict)
        self.assertTrue(pk in conflict)
        self.assertTrue(name_col in conflict)
        self.assertTrue(('m1',) in conflict[uc])
        self.assertTrue(('s1',) in conflict[pk])
        self.assertEqual(('foo', 'bar'), conflict[name_col])
class QuickRun(unittest.TestCase):
    """Manual integration runs against a live branch database.

    These tests modify data in the database named by DB_URL.
    """

    # connection url of the database the tests run against; override in a
    # subclass to point at another server
    # NOTE(review): credentials are hard-coded here -- confirm this is a
    # throwaway local database
    DB_URL = 'mssql://sa:Password1@localhost/branch_dedup_v3dot4'

    def setUp(self):
        """connects to the database and prepares a customer Merger"""
        self.conn = sa.create_engine(self.DB_URL).connect()
        self.introspector = Introspector(tables.metadata)
        self.merger = Merger(self.conn, self.introspector,
                             tables.customer_table)

    def tearDown(self):
        """closes the connection opened by setUp"""
        self.conn.close()

    @staticmethod
    def _report_conflict(conflict, table):
        """prints which columns and constraints block the merge
        params:
        conflict -- dict returned by Merger.find_conflict
        table -- the sqlalchemy table whose rows are being merged
        """
        col = []
        pk = []
        uc = []
        for k in conflict:
            if isinstance(k, sa.Column):
                col.append(k.name)
            elif isinstance(k, sa.schema.UniqueConstraint):
                uc.append(k.table.name + '=>' + ', '.join(k.columns.keys())
                          + str(conflict[k]))
            elif isinstance(k, sa.schema.PrimaryKeyConstraint):
                where_col = next(
                    c.name for c in k.columns
                    if any(fk.column.table == table for fk in c.foreign_keys))
                compare_cols = [c for c in k.columns.keys() if c != where_col]
                pk.append(next(c for c in k.columns).table.name + '=>'
                          + ', '.join(compare_cols) + str(conflict[k]))
        print('needs human input for:\n')
        if col:
            print('Columns of Customer table: ' + ', '.join(col))
        if uc:
            print('Unique Constraint: ' + ', '.join(uc))
        if pk:
            print('Primary Key Constraint: ' + ', '.join(pk))

    def _reassign_and_remove(self, old_id, new_id, table):
        """points every row referencing old_id at new_id, then deletes the
        old customer if no referencing row is left"""
        print('give all that belongs to "%s" to "%s"' % (old_id, new_id))
        for u_sql in self.introspector.generate_update_sql(table):
            self.conn.execute(sa.text(u_sql), old_id=old_id, new_id=new_id)
        nothing_left = True
        for sql in self.introspector.generate_select_sql(table):
            result = self.conn.execute(sa.text(sql), id=old_id).fetchall()
            if result:
                print(sql.replace(':id', "'" + old_id + "'"))
                print(result)
                nothing_left = False
        if nothing_left:
            self.conn.execute('delete from customer where id = ?', old_id)
            print('removed old customer')
        else:
            print('could not remove old customer')

    def _two_rows_sharing_a_membership(self):
        """returns two (id, customer_id) customermembership rows of the
        membership that has the most rows"""
        m_id = self.conn.execute("select membership_id, count(membership_id) from customermembership group by membership_id order by count(membership_id) desc").fetchone()[0]
        top2 = self.conn.execute('select top 2 id, customer_id from customermembership where membership_id = ?', m_id)
        first = top2.fetchone()
        second = top2.fetchone()
        return first, second

    def test_run_merge(self):
        t = tables.customer_table
        trans = self.conn.begin()
        try:
            top2 = self.conn.execute('select top 2 id from customer')
            old_id = top2.fetchone()[0]
            new_id = top2.fetchone()[0]
            conflict = self.merger.find_conflict(old_id, new_id)
            if conflict:
                self._report_conflict(conflict, t)
            else:
                self._reassign_and_remove(old_id, new_id, t)
            trans.commit()
        except:
            # deliberately catches everything: the transaction is rolled
            # back on any failure and the error is re-raised unchanged
            print('merge failed')
            trans.rollback()
            raise

    def test_raise_unique_constraint(self):
        first, second = self._two_rows_sharing_a_membership()
        sql = 'update customermembership set customer_id = :c_id where id = :id'
        self.assertRaises(Exception, self.conn.execute, sa.text(sql),
                          c_id=first[1], id=second[0])

    def test_raise_unique_constraint_no_raised_in_multi_update_statement(self):
        first, second = self._two_rows_sharing_a_membership()
        sql = []
        sql.append("update customer set firstname = 'foo3' where id = :first_c_id")
        sql.append('update customermembership set customer_id = :first_c_id where id = :first_id')
        #this line violate unique constraint,
        #but is ignored by sqlalchemy
        sql.append('update customermembership set customer_id = :second_c_id where id = :first_id')
        sql.append("update customer set middlename = 'mid2' where id = :first_c_id")
        sql.append("update customer set lastname = 'bar2' where id = :first_c_id")
        self.conn.execute(sa.text('\n'.join(sql)), first_c_id=first[1],
                          first_id=first[0], second_c_id=second[1])
# run every test case in this module when executed as a script
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment