Skip to content

Instantly share code, notes, and snippets.

@chadselph
Created April 18, 2020 18:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save chadselph/35bd7691db2b8ccf66d57391cad43ed8 to your computer and use it in GitHub Desktop.
Save chadselph/35bd7691db2b8ccf66d57391cad43ed8 to your computer and use it in GitHub Desktop.
Compare two db schemas
from typing import Generator, Tuple, Set, List
from sqlalchemy import inspect, create_engine
import sys
# Return type shared by the compare* generators below: each yielded item is a
# (discrepancy kind, human-readable detail) pair of strings.
GeneratorOfDiscrepancies = Generator[Tuple[str, str], None, None]
class InspectorWrapper:
    """Reflect a database and keep a TableWrapper for each of its tables.

    Attributes:
        inspector: the SQLAlchemy inspector created for the engine built
            from ``url``.
        tables: mapping of table name -> TableWrapper, one entry per name
            returned by ``inspector.get_table_names()``.
    """

    def __init__(self, url: str):
        engine = create_engine(url)
        self.inspector = inspect(engine)

        # Every table is reflected eagerly here, so later comparisons work
        # purely on the collected wrappers.
        reflected = {}
        for table_name in self.inspector.get_table_names():
            reflected[table_name] = TableWrapper(table_name, self.inspector)
        self.tables = reflected
class TableWrapper:
    """Reflected description of one table, stored in easily comparable form.

    Attributes:
        name: the table name passed to the constructor.
        columns: column name -> column dict as returned by
            ``ins.get_columns(name)``.
        primary_key: the ``constrained_columns`` list of the table's primary
            key constraint (order preserved).
        indexes, uniques, checked, fkeys: sets of strings, one string per
            reflected index / unique constraint / check constraint / foreign
            key, rendered by ``_fmt``.
    """

    def __init__(self, name: str, ins):
        """Reflect table ``name`` through ``ins``.

        ``ins`` must provide ``get_columns``, ``get_pk_constraint``,
        ``get_indexes``, ``get_unique_constraints``,
        ``get_check_constraints`` and ``get_foreign_keys``.
        """
        self.name = name
        self.columns = {c['name']: c for c in ins.get_columns(name)}
        self.primary_key = ins.get_pk_constraint(name)['constrained_columns']
        self.indexes = self._fmt(
            "{name}{unique_str} on {cols}", ins.get_indexes(name))
        self.uniques = self._fmt(
            "{name} on {cols}", ins.get_unique_constraints(name))
        self.checked = self._fmt(
            "{name} {sqltext}", ins.get_check_constraints(name))
        self.fkeys = self._fmt(
            "{name} {cols} -> {referred_table}", ins.get_foreign_keys(name))

    def _fmt(self, fmt: str, values: List[dict]) -> Set[str]:
        """Render each reflected dict in ``values`` through ``fmt``.

        Besides the dict's own keys, ``fmt`` may use ``{table}``, ``{cols}``
        (``table.col`` for one column, ``table.[a,b]`` otherwise) and
        ``{unique_str}`` (``" unique"`` when the dict has a truthy
        ``unique`` entry). ``{name}`` and ``{sqltext}`` default to ``""``
        when the dict does not carry them.

        Returns:
            The set of rendered strings (duplicates collapse).
        """
        # TODO: set of str makes it easy to compare these fields but
        # if we want to get more specific errors, like index is on
        # the same keys but the order is different, we might want
        # a class for each of these that knows how to compare itself
        def make_params(d):
            raw_cols = d.get("constrained_columns") or d.get("column_names") or []
            expressions = d.get("expressions") or []

            # SQLAlchemy reports None in "column_names" for index elements
            # that are expressions; joining those would raise TypeError.
            # Use the matching "expressions" entry when it is available and
            # a fixed placeholder otherwise.
            cols = []
            for position, col in enumerate(raw_cols):
                if col is None:
                    if position < len(expressions):
                        col = expressions[position]
                    else:
                        col = "<expression>"
                cols.append(str(col))

            if len(cols) == 1:
                table_cols = "{}.{}".format(self.name, cols[0])
            else:
                table_cols = "{}.[{}]".format(self.name, ",".join(cols))

            defaults = {
                "table": self.name,
                "name": "",
                "sqltext": "",
                "unique_str": " unique" if d.get("unique") else "",
                "cols": table_cols,
            }
            # Keys present in the reflected dict win over the defaults.
            return dict(defaults, **d)

        return {fmt.format(**make_params(v)) for v in values}
# Compare metas, using meta1 as "expected"
def compare(meta1: InspectorWrapper, meta2: InspectorWrapper) -> GeneratorOfDiscrepancies:
    """Yield every discrepancy between two reflected databases.

    Args:
        meta1: the expected schema.
        meta2: the actual schema being checked against ``meta1``.

    Yields:
        ``("extra table", name)`` for tables only in ``meta2``,
        ``("missing table", name)`` for tables only in ``meta1``, and the
        results of ``compare_table`` for tables present in both.
    """
    # Extra tables. The key difference is a set, whose iteration order for
    # strings can change between interpreter runs; sort it so the report is
    # reproducible.
    for extra_table in sorted(meta2.tables.keys() - meta1.tables.keys()):
        yield ("extra table", extra_table)

    for table_name, table in meta1.tables.items():
        if table_name not in meta2.tables:
            yield ("missing table", table_name)
        else:
            yield from compare_table(table, meta2.tables[table_name])
def compare_table(table1: TableWrapper, table2: TableWrapper) -> GeneratorOfDiscrepancies:
    """Yield the discrepancies between two reflections of the same table.

    Args:
        table1: the expected table.
        table2: the actual table.

    Yields:
        Extra/missing columns (qualified as ``table.column``), per-column
        differences from ``compare_column``, a primary key difference if
        any, then index / unique / check constraint / foreign key
        differences from ``compare_simple``.
    """
    name = table1.name

    # Sorted so the report order does not depend on set iteration order.
    for column in sorted(table2.columns.keys() - table1.columns.keys()):
        # Qualify with the table name; a bare column name does not say
        # which table the discrepancy belongs to.
        yield ("extra column", "{}.{}".format(name, column))

    for column_name, column in table1.columns.items():
        if column_name not in table2.columns:
            yield ("missing column", "{}.{}".format(name, column_name))
        else:
            yield from compare_column(column, table2.columns[column_name], name)

    pk1, pk2 = table1.primary_key, table2.primary_key
    if pk1 != pk2:
        # Same columns in a different order is reported separately from a
        # genuinely different set of key columns.
        if set(pk1) == set(pk2):
            what = "different primary key order"
        else:
            what = "different primary key"
        # "<actual> instead of <expected>", the same direction that
        # compare_column uses for its messages.
        yield (what, "on {}: {} instead of {}".format(name, ",".join(pk2), ",".join(pk1)))

    yield from compare_simple("index", table1.indexes, table2.indexes)
    yield from compare_simple("unique", table1.uniques, table2.uniques)
    yield from compare_simple("checked constraint", table1.checked, table2.checked)
    yield from compare_simple("foreignkey", table1.fkeys, table2.fkeys)
def compare_simple(what: str, expected: Set[str], actual: Set[str]) -> GeneratorOfDiscrepancies:
    """Yield the entries that are in only one of two sets of strings.

    Args:
        what: label for the kind of object compared (e.g. ``"index"``).
        expected: rendered descriptions from the expected schema.
        actual: rendered descriptions from the actual schema.

    Yields:
        ``("missing <what>", entry)`` for entries only in ``expected``,
        followed by ``("extra <what>", entry)`` for entries only in
        ``actual``. Each group is sorted so output is stable across runs.
    """
    for missing in sorted(expected - actual):
        yield ("missing " + what, missing)
    for extra in sorted(actual - expected):
        yield ("extra " + what, extra)
def compare_column(col1: dict, col2: dict, table_name: str) -> GeneratorOfDiscrepancies:
    """Yield the attribute differences between two reflected column dicts.

    Args:
        col1: the expected column dict.
        col2: the actual column dict; its ``"name"`` is used in messages.
        table_name: table the column belongs to, used in messages.

    Yields:
        ``("column different", "<table>.<col> is different <attr>: <actual>
        instead of <expected>")`` for each of type, autoincrement, default
        and nullable whose string forms differ.
    """
    checks = ("type", "autoincrement", "default", "nullable")
    for ch in checks:
        # .get() rather than [] because a reflected column dict is not
        # guaranteed to carry every key (SQLAlchemy documents
        # "autoincrement" as optional); an absent key compares as "None".
        # Values are compared via str() so type objects compare by their
        # rendered form.
        col1_value = str(col1.get(ch))
        col2_value = str(col2.get(ch))
        if col1_value != col2_value:
            yield (
                "column different",
                "{}.{} is different {}: {} instead of {}".format(
                    table_name, col2['name'], ch, col2_value, col1_value))
if __name__ == "__main__":
    # usage: python checkdb.py postgresql://user:passwd@host/somedb postgresql://u2:p2@host/otherdb
    # The first URL is the expected schema, the second the one being checked.
    if len(sys.argv) != 3:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit("usage: {} EXPECTED_DB_URL ACTUAL_DB_URL".format(sys.argv[0]))

    m1, m2 = InspectorWrapper(sys.argv[1]), InspectorWrapper(sys.argv[2])
    print("Loaded {} tables.".format(len(m1.tables)))
    print("Loaded {} tables.".format(len(m2.tables)))
    for diff in compare(m1, m2):
        print(diff)
@chadselph
Copy link
Author

Basically a simpler, less full-featured version of https://github.com/djrobstep/migra

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment