Skip to content

Instantly share code, notes, and snippets.

@stbuehler
Created May 22, 2024 16:48
Show Gist options
  • Save stbuehler/545bbcdbbd5a3be83b31d8432807ad7a to your computer and use it in GitHub Desktop.
Save stbuehler/545bbcdbbd5a3be83b31d8432807ad7a to your computer and use it in GitHub Desktop.
postgres: track renamed constraints and indices using apgdiff
#!/usr/bin/env python3
# Track CONSTRAINT and INDEX renames and emit them as ALTER ... RENAME statements.
# Uses [`apgdiff`](https://www.apgdiff.com/).
import dataclasses
import re
import subprocess
# Schema dumps to compare; produce each one with: `pg_dump -sxO dbname > schema.sql`
NEW: str = "new-schema.sql"  # target schema
OLD: str = "old-schema.sql"  # current schema
@dataclasses.dataclass(slots=True, kw_only=True)
class CreateConstraint:
table: str
name: str
definition: str
def __str__(self) -> str:
return f'ALTER TABLE {self.table} ADD CONSTRAINT {self.name} {self.definition};'
@dataclasses.dataclass(slots=True, kw_only=True)
class DropConstraint:
table: str
name: str
def __str__(self) -> str:
return f'ALTER TABLE {self.table} DROP CONSTRAINT {self.name};'
@dataclasses.dataclass(slots=True, kw_only=True)
class CreateIndex:
table: str
name: str
definition: str
def __str__(self) -> str:
return f'CREATE INDEX {self.name} ON {self.table} {self.definition};'
@dataclasses.dataclass(slots=True, kw_only=True)
class DropIndex:
name: str
def __str__(self) -> str:
return f'DROP INDEX {self.name};'
@dataclasses.dataclass(slots=True, kw_only=True)
class CreateTable:
    """A ``CREATE TABLE`` statement with its ``CONSTRAINT`` entries split out.

    When rendered, all columns come first, followed by the constraints, one
    tab-indented entry per line.
    """

    # table name
    table: str
    # non-CONSTRAINT entries of the table body, in their original order
    columns: list[str]
    # "CONSTRAINT <name> <definition>" entries of the table body
    constraints: list[CreateConstraint]

    def __str__(self) -> str:
        entries = [*self.columns]
        for constraint in self.constraints:
            entries.append("CONSTRAINT {} {}".format(constraint.name, constraint.definition))
        body = ",\n".join("\t{}".format(entry) for entry in entries)
        return "CREATE TABLE {} (\n{}\n);".format(self.table, body)
@dataclasses.dataclass(slots=True, kw_only=True)
class RawStatement:
statement: str
def __str__(self) -> str:
return self.statement + ";"
# Every kind of object parse_statement() can return (it also returns CreateTable,
# which the original union left out).
Statement = CreateConstraint | DropConstraint | CreateIndex | DropIndex | CreateTable | RawStatement
def _split_top_level(text: str) -> list[str] | None:
    """Split *text* on commas that are not inside parentheses or quotes.

    Returns ``None`` when the parentheses in *text* are unbalanced, e.g. for
    ``a int) INHERITS (b`` (the inside of ``(a int) INHERITS (b)``), so the
    caller can tell that the outer parentheses are not one enclosing group.
    """
    parts: list[str] = []
    depth = 0
    start = 0
    quote: str | None = None
    for i, ch in enumerate(text):
        if quote is not None:
            # A doubled quote ('' or "") closes and immediately reopens, so it
            # correctly stays "inside" the quoted token.
            if ch == quote:
                quote = None
        elif ch in ("'", '"'):
            quote = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                return None
        elif ch == "," and depth == 0:
            parts.append(text[start:i])
            start = i + 1
    if depth != 0 or quote is not None:
        return None
    parts.append(text[start:])
    return parts


def parse_statement(stmt: str) -> Statement:
    """Parse one SQL statement into the matching statement dataclass.

    *stmt* is expected to be a single statement as produced by parse_sql:
    whitespace-folded and without the terminating ``;``. Recognised forms:

    - ``CREATE TABLE t (...)`` -> CreateTable; ``CONSTRAINT name def`` entries
      become CreateConstraint, every other entry is kept as a column string.
      Entries are separated only at top-level commas, so ``numeric(10,2)`` or
      ``CHECK (x IN (1, 2))`` stay intact.
    - ``ALTER TABLE [ONLY] t ADD CONSTRAINT name def`` -> CreateConstraint
    - ``ALTER TABLE [ONLY] t DROP CONSTRAINT name`` -> DropConstraint
    - ``CREATE INDEX name ON t def`` -> CreateIndex
    - ``DROP INDEX name`` -> DropIndex

    A leading ``public.`` is removed from table names, so schema-qualified and
    unqualified names compare equal. Anything else is returned as
    RawStatement. That includes forms with extra clauses (``IF EXISTS``,
    ``CASCADE``, ``CONCURRENTLY``) and ``CREATE TABLE`` bodies that are not a
    single parenthesised list. Such statements are passed through instead of
    raising.

    NOTE: ``CREATE UNIQUE INDEX`` is not recognised and stays a RawStatement,
    because CreateIndex always renders as plain ``CREATE INDEX``.
    """
    if stmt.startswith("CREATE TABLE "):
        head = stmt.removeprefix("CREATE TABLE ").split(maxsplit=1)
        if len(head) == 2 and head[1].startswith("(") and head[1].endswith(")"):
            tbl = head[0].removeprefix("public.")
            items = _split_top_level(head[1][1:-1])
            if items is not None:
                crt = CreateTable(table=tbl, columns=[], constraints=[])
                for item in map(str.strip, items):
                    if not item:
                        continue  # e.g. the empty body of "CREATE TABLE t ()"
                    if item.startswith("CONSTRAINT "):
                        name, _, definition = item.removeprefix("CONSTRAINT ").partition(" ")
                        crt.constraints.append(
                            CreateConstraint(table=tbl, name=name, definition=definition)
                        )
                    else:
                        crt.columns.append(item)
                return crt
    elif stmt.startswith("ALTER TABLE "):
        head = stmt.removeprefix("ALTER TABLE ").removeprefix("ONLY ").split(maxsplit=1)
        if len(head) == 2:
            tbl, rem = head
            tbl = tbl.removeprefix("public.")
            if rem.startswith("DROP CONSTRAINT "):
                # only a bare name is understood; anything else stays raw
                names = rem.removeprefix("DROP CONSTRAINT ").split()
                if len(names) == 1:
                    return DropConstraint(table=tbl, name=names[0])
            elif rem.startswith("ADD CONSTRAINT "):
                parts = rem.removeprefix("ADD CONSTRAINT ").split(maxsplit=1)
                if len(parts) == 2:
                    return CreateConstraint(table=tbl, name=parts[0], definition=parts[1])
    elif stmt.startswith("DROP INDEX "):
        # only a bare name is understood; anything else stays raw
        names = stmt.removeprefix("DROP INDEX ").split()
        if len(names) == 1:
            return DropIndex(name=names[0])
    elif stmt.startswith("CREATE INDEX "):
        parts = stmt.removeprefix("CREATE INDEX ").split(maxsplit=1)
        if len(parts) == 2 and parts[1].startswith("ON "):
            target = parts[1].removeprefix("ON ").split(maxsplit=1)
            if len(target) == 2:
                return CreateIndex(
                    table=target[0].removeprefix("public."),
                    name=parts[0],
                    definition=target[1],
                )
    return RawStatement(statement=stmt)
# Opening/closing delimiter of a dollar-quoted string: `$$` or `$tag$`.
_DOLLAR_TAG = re.compile(r"\$(?:[A-Za-z_][A-Za-z0-9_]*)?\$")


def _is_word_char(ch: str) -> bool:
    """Return True if *ch* can continue an SQL identifier (letter, digit, ``_``, ``$``)."""
    return ch.isalnum() or ch in ("_", "$")


def _quoted_end(text: str, start: int, backslash_escapes: bool) -> int:
    """Return the index just past the quoted token that opens at ``text[start]``.

    ``text[start]`` is the quote character. A doubled quote inside the token
    stands for a literal quote. With *backslash_escapes* (``E'...'``
    strings), a backslash also escapes the following character. An
    unterminated token extends to the end of *text*.
    """
    quote = text[start]
    i = start + 1
    while i < len(text):
        ch = text[i]
        if backslash_escapes and ch == "\\":
            i += 2
        elif ch != quote:
            i += 1
        elif text.startswith(quote, i + 1):
            i += 2  # doubled quote: literal quote character
        else:
            return i + 1
    return len(text)


def _split_sql(text: str) -> list[str]:
    """Split *text* into normalised, non-empty statements without their ``;``."""
    statements: list[str] = []
    parts: list[str] = []
    i = 0
    n = len(text)
    while i < n:
        ch = text[i]
        prev = text[i - 1] if i > 0 else ""
        if ch.isspace():
            # fold each whitespace run into a single space (none at the start)
            if parts and parts[-1] != " ":
                parts.append(" ")
            i += 1
        elif text.startswith("--", i):
            # line comment: skip up to, not including, the newline
            end = text.find("\n", i)
            i = n if end < 0 else end
        elif ch in ("'", '"'):
            before = text[i - 2] if i > 1 else ""
            e_string = ch == "'" and prev in ("E", "e") and not _is_word_char(before)
            end = _quoted_end(text, i, e_string)
            parts.append(text[i:end])  # kept verbatim, including whitespace
            i = end
        elif ch == "$" and not _is_word_char(prev) and (m := _DOLLAR_TAG.match(text, i)):
            # dollar-quoted body (e.g. a function body): kept verbatim
            tag = m.group()
            close = text.find(tag, m.end())
            end = n if close < 0 else close + len(tag)
            parts.append(text[i:end])
            i = end
        elif ch == ";":
            statements.append("".join(parts).strip())
            parts = []
            i += 1
        else:
            parts.append(ch)
            i += 1
    statements.append("".join(parts).strip())
    return [stmt for stmt in statements if stmt]


def parse_sql(text: str) -> list[Statement]:
    """Split an SQL script into statements and parse each with parse_statement.

    Outside quoted tokens, ``--`` comments are removed and every whitespace
    run becomes a single space, so definitions can be compared as plain
    strings. The following are copied verbatim, and a ``;`` or ``--`` inside
    them has no effect:

    - single-quoted strings (including ``E'...'``);
    - double-quoted identifiers;
    - dollar-quoted bodies (``$$...$$`` / ``$tag$...$tag$``).

    Empty statements (e.g. a stray ``;``) are skipped.
    """
    return [parse_statement(stmt) for stmt in _split_sql(text)]
def main():
    """Print to stdout a migration script for OLD -> NEW built from ``apgdiff``.

    apgdiff expresses a renamed constraint or index as a DROP of the old name
    plus an ADD/CREATE of the new one. A DROP and an ADD/CREATE on the same
    table whose definitions are textually identical (after parse_sql's
    normalisation) are collapsed into one RENAME statement.

    Everything is printed in apgdiff's original order, so whatever dependency
    ordering apgdiff chose is kept. For example, a DROP of a name comes before
    its re-creation, and a CREATE TABLE comes before constraints added to that
    table. Each RENAME is printed where its ADD/CREATE would have been, and
    the matching DROP is omitted.
    """
    with open(OLD, encoding="utf-8") as f:
        oldstruct = parse_sql(f.read())

    # definitions in the old schema:
    # constraints by ...[table][name] = definition
    old_constraints: dict[str, dict[str, str]] = {}
    # indices by ...[name] = (table, definition)
    old_indices: dict[str, tuple[str, str]] = {}
    for stmt in oldstruct:
        if isinstance(stmt, CreateConstraint):
            old_constraints.setdefault(stmt.table, {})[stmt.name] = stmt.definition
        elif isinstance(stmt, CreateTable):
            for constr in stmt.constraints:
                old_constraints.setdefault(constr.table, {})[constr.name] = constr.definition
        elif isinstance(stmt, CreateIndex):
            old_indices[stmt.name] = (stmt.table, stmt.definition)

    diff = parse_sql(subprocess.check_output(["apgdiff", OLD, NEW], encoding="utf-8"))

    # Dropped objects that could be renamed instead, by
    # ...[table][old definition] = [old names]. This is a list because several
    # dropped objects on one table may share a definition. A DROP whose old
    # definition was not found in OLD is not a candidate; it is printed as-is
    # instead of crashing with KeyError.
    replaceable_constraints: dict[str, dict[str, list[str]]] = {}
    replaceable_indices: dict[str, dict[str, list[str]]] = {}
    for stmt in diff:
        if isinstance(stmt, DropConstraint):
            old_def = old_constraints.get(stmt.table, {}).get(stmt.name)
            if old_def is not None:
                by_def = replaceable_constraints.setdefault(stmt.table, {})
                by_def.setdefault(old_def, []).append(stmt.name)
        elif isinstance(stmt, DropIndex) and stmt.name in old_indices:
            old_table, old_def = old_indices[stmt.name]
            by_def = replaceable_indices.setdefault(old_table, {})
            by_def.setdefault(old_def, []).append(stmt.name)

    # Pair each ADD/CREATE with a dropped object of identical definition.
    renames: dict[int, str] = {}  # position of the ADD/CREATE in diff -> RENAME
    renamed_constraints: set[tuple[str, str]] = set()  # (table, old name)
    renamed_indices: set[str] = set()  # old index names
    for pos, stmt in enumerate(diff):
        if isinstance(stmt, CreateConstraint):
            candidates = replaceable_constraints.get(stmt.table, {}).get(stmt.definition)
            if candidates:
                old = candidates.pop(0)
                renamed_constraints.add((stmt.table, old))
                renames[pos] = (
                    f"ALTER TABLE {stmt.table} RENAME CONSTRAINT {old} TO {stmt.name};"
                )
        elif isinstance(stmt, CreateIndex):
            candidates = replaceable_indices.get(stmt.table, {}).get(stmt.definition)
            if candidates:
                old = candidates.pop(0)
                renamed_indices.add(old)
                renames[pos] = (
                    f"ALTER INDEX {old} RENAME TO {stmt.name}; -- on table {stmt.table}"
                )

    # Emit in apgdiff's order, substituting the renames.
    for pos, stmt in enumerate(diff):
        if pos in renames:
            print(renames[pos])
        elif isinstance(stmt, DropConstraint) and (stmt.table, stmt.name) in renamed_constraints:
            pass  # replaced by the RENAME printed at the matching ADD CONSTRAINT
        elif isinstance(stmt, DropIndex) and stmt.name in renamed_indices:
            pass  # replaced by the RENAME printed at the matching CREATE INDEX
        else:
            print(stmt)
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment