Skip to content

Instantly share code, notes, and snippets.

@orf
Created May 25, 2023 11:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save orf/a52550cc18d49306f503acc461391f80 to your computer and use it in GitHub Desktop.
Save orf/a52550cc18d49306f503acc461391f80 to your computer and use it in GitHub Desktop.
from collections import defaultdict
from contextlib import closing

from django.core.management import base
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.migrations.executor import MigrationExecutor
from psycopg2.errors import ActiveSqlTransaction
from psycopg2.extensions import connection
from psycopg2.extras import DictCursor
# This is a proof-of-concept tool for finding the exact Postgres locks taken by a given migration.
# It works by extracting the SQL used to execute the migration in the same manner as `sqlmigrate`, then
# executing each statement inside its own transaction. Before that transaction is committed, a second
# connection inspects the pg_locks view to find out exactly which locks the migration's transaction holds.
# The transaction is then committed. Once all pending migrations have been applied this way, they are
# rolled back again, in reverse order, using their backward SQL.
# Note that this really applies (and then reverts) the migrations against the configured database.
# While there is documentation about this online, and people tend to build up internal knowledge about
# what is dangerous and what is not, Postgres itself can tell us exactly which locks each statement
# takes, so we should leverage that as well.
def execute_statement_and_get_locks(stmt_cursor, monitor_cursor, sql):
    """Execute ``sql`` one statement at a time and record the locks each one holds.

    Each statement runs inside its own explicit ``BEGIN``/``COMMIT`` on
    ``stmt_cursor``. While that transaction is still open, ``monitor_cursor``
    queries ``pg_locks`` for the backend pid of ``stmt_cursor``'s session, so
    the snapshot contains exactly the locks held at that point.

    Statements that fail with ``ActiveSqlTransaction`` (i.e. Postgres refuses
    to run them inside a transaction block) are retried after aborting the
    transaction opened for them. They are NOT included in the result, because
    their locks cannot be captured this way.

    Args:
        stmt_cursor: cursor used to execute the migration SQL.
        monitor_cursor: cursor used to query ``pg_locks``. It is expected to be
            on a separate connection. Its rows must be convertible with
            ``dict()`` (e.g. a ``DictCursor``).
        sql: iterable of SQL strings, such as the output of
            ``MigrationLoader.collect_sql``. Empty, whitespace-only and ``--``
            comment entries are skipped.

    Returns:
        A list of ``(statement, locks)`` tuples, one per executed statement.
        ``locks`` is a list of dicts with the keys ``relation``, ``classid``,
        ``locktype`` and ``mode``.

    Raises:
        Any error raised while executing a statement. The statement's
        transaction is rolled back first, so ``stmt_cursor``'s connection is
        left usable for further work (e.g. reverting earlier migrations).
    """
    locks = []
    stmt_cursor.execute("select pg_backend_pid()")
    backend_pid = stmt_cursor.fetchone()[0]
    for statement in sql:
        stripped = statement.strip() if statement else ""
        if not stripped or stripped.startswith("--"):
            continue
        stmt_cursor.execute("BEGIN")
        print(f"Executing {statement}")
        try:
            stmt_cursor.execute(statement)
        except ActiveSqlTransaction as e:
            # The statement cannot run in a transaction block: abort the
            # transaction opened above and run it on its own (no lock snapshot).
            print(f'Error: {e} - Aborting and retrying')
            stmt_cursor.execute("ABORT")
            stmt_cursor.execute(statement)
            continue
        except Exception:
            # Don't leave the session stuck in an aborted transaction; the
            # caller may still need this connection.
            stmt_cursor.execute("ROLLBACK")
            raise
        # Query from the other connection while this transaction is still open,
        # so every lock taken by `statement` is still held and visible.
        monitor_cursor.execute(
            "select relation::regclass, classid::regclass, locktype, mode from pg_locks where pid = %s",
            (backend_pid,),
        )
        locks.append((statement, [dict(v) for v in monitor_cursor.fetchall()]))
        stmt_cursor.execute("COMMIT")
    return locks
class Command(base.BaseCommand):
    """Show the Postgres locks acquired by each currently unapplied migration.

    The SQL of every unapplied migration is rendered the way ``sqlmigrate``
    renders it. It is then executed against the default database one statement
    at a time, while a second connection snapshots ``pg_locks``. Afterwards the
    applied migrations are reversed with their backward SQL. Statements that
    take an ``AccessExclusiveLock`` on a relation are reported on stderr.

    WARNING: this runs real DDL against the configured database, and nothing
    here records the migrations in ``django_migrations``. Use a disposable
    database.
    """

    help = "Show locks acquired by currently unapplied migrations"

    def handle(self, *args, **options) -> None:
        db_connection = connections[DEFAULT_DB_ALIAS]
        executor = MigrationExecutor(db_connection, None)
        executor.loader.check_consistent_history(db_connection)
        targets = executor.loader.graph.leaf_nodes()
        plan = executor.migration_plan(targets)
        if not plan:
            # Nothing to measure; don't open any extra connections.
            self.stdout.write("No unapplied migrations.")
            return

        # Render each migration's forward and backward SQL separately (as
        # `sqlmigrate` does) so every lock can be attributed to one migration.
        statements = []
        rollback_statements = []
        for migration, _backwards in plan:
            statements.append((migration, executor.loader.collect_sql([(migration, False)])))
            rollback_statements.append((migration, executor.loader.collect_sql([(migration, True)])))

        migration_to_locks = self._apply_and_collect_locks(
            db_connection, statements, rollback_statements
        )
        self._report_access_exclusive_locks(migration_to_locks)

    @staticmethod
    def _open_raw_connection(db_connection) -> connection:
        """Open a fresh raw driver connection using ``db_connection``'s settings."""
        return db_connection.get_new_connection(db_connection.get_connection_params())

    def _apply_and_collect_locks(self, db_connection, statements, rollback_statements):
        """Apply each migration's SQL while recording its locks, then reverse it.

        Returns a dict mapping each migration to the ``(statement, locks)``
        pairs produced by ``execute_statement_and_get_locks``. Every migration
        whose forward SQL ran to completion is reversed (newest first), even if
        a later migration fails part-way through. Both raw connections are
        always closed.
        """
        # Lock types (the `locktype` column):
        # https://www.postgresql.org/docs/current/monitoring-stats.html#WAIT-EVENT-LOCK-TABLE
        # Lock modes (the `mode` column):
        # https://www.postgresql.org/docs/current/explicit-locking.html#LOCKING-TABLES
        migration_to_locks = defaultdict(list)
        applied = 0  # number of migrations whose forward SQL fully ran
        with closing(self._open_raw_connection(db_connection)) as base_connection, closing(
            self._open_raw_connection(db_connection)
        ) as monitor_connection:
            # The monitor only reads pg_locks. Autocommit keeps it from sitting
            # idle inside one open transaction for the whole run.
            monitor_connection.autocommit = True
            with base_connection.cursor() as base_cursor, monitor_connection.cursor(
                cursor_factory=DictCursor
            ) as monitor_cursor:
                try:
                    for migration, sql in statements:
                        migration_to_locks[migration].extend(
                            execute_statement_and_get_locks(base_cursor, monitor_cursor, sql)
                        )
                        applied += 1
                except Exception:
                    # Clear any aborted transaction so the rollback SQL below can run.
                    base_connection.rollback()
                    self.stderr.write(
                        f"Failed while applying {statements[applied][0]}; its statements that "
                        "already ran were committed and are NOT rolled back automatically."
                    )
                    raise
                finally:
                    for _migration, rollback_sql in reversed(rollback_statements[:applied]):
                        execute_statement_and_get_locks(base_cursor, monitor_cursor, rollback_sql)
        return migration_to_locks

    def _report_access_exclusive_locks(self, migration_to_locks):
        """Write every statement that took an AccessExclusiveLock on a relation to stderr."""
        for migration, statement_locks in migration_to_locks.items():
            for statement, locks in statement_locks:
                for lock in locks:
                    if lock["locktype"] == "relation" and lock["mode"] == "AccessExclusiveLock":
                        self.stderr.write(
                            f"⛔️ The following SQL statement in {migration} takes an AccessExclusiveLock on {lock['relation']} ⛔️"
                        )
                        self.stderr.write(statement)
                        self.stderr.write("")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment