Skip to content

Instantly share code, notes, and snippets.

@adamchainz
Created October 18, 2022 14:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adamchainz/3d0f17c4516cd7ba89e17e659ae61edb to your computer and use it in GitHub Desktop.
Save adamchainz/3d0f17c4516cd7ba89e17e659ae61edb to your computer and use it in GitHub Desktop.
Django 4.1 serial to identity migration script, alternative version
"""
Alternative version that uses internal-poking technique from:
https://www.enterprisedb.com/blog/postgresql-10-identity-columns-explained
"""
from __future__ import annotations
import argparse
from typing import Any
from django.core.management.base import BaseCommand
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.backends.utils import CursorWrapper
from django.db.transaction import atomic
class Command(BaseCommand):
    """Find 'serial' columns in a database and convert them to identity columns.

    Without ``--write`` the command runs in dry-run mode: it still looks up
    and lists the candidate columns, and passes ``write=False`` on to
    ``migrate_serial_to_identity`` for each of them.
    """

    help = "Migrate all tables using 'serial' columns to use 'identity' instead."

    def add_arguments(self, parser: argparse.ArgumentParser) -> None:
        """Register the ``--database`` and ``--write`` command-line options."""
        parser.add_argument(
            "--database",
            default=DEFAULT_DB_ALIAS,
            help='Which database to update. Defaults to the "default" database.',
        )
        parser.add_argument(
            "--write",
            action="store_true",
            default=False,
            help="Actually edit the database",
        )

    def handle(self, *args: Any, database: str, write: bool, **kwargs: Any) -> None:
        """Run the ``find_serial_columns`` query and process each result row.

        Args:
            database: Alias of the connection in ``connections`` to operate on.
            write: When false, only report; passed through unchanged to
                ``migrate_serial_to_identity``.
        """
        if not write:
            self.stdout.write("In dry run mode (--write not passed)")

        with connections[database].cursor() as cursor:
            cursor.execute(find_serial_columns)
            # Fetch everything up front: the same cursor is reused for the
            # per-column queries below, which would discard pending rows.
            column_specs = cursor.fetchall()
            self.stdout.write(f"Found {len(column_specs)} columns to update")

            for table_name, column_name in column_specs:
                # Report through self.stdout (not print) so the output goes
                # to the same stream as the other messages of this command.
                self.stdout.write(f"{table_name} {column_name}")
                migrate_serial_to_identity(
                    cursor, database, table_name, column_name, write
                )
# SQL query returning one (table_name, column_name) row for every column
# that looks like a 'serial' column, i.e. a column that:
#   * is a regular, non-dropped column (attnum > 0, NOT attisdropped),
#   * has one of the integer types int / int8 / int2, and
#   * has a default expression that is exactly nextval('<seq>'::regclass),
#     where <seq> is the sequence pg_get_serial_sequence() reports for it.
# table_name is the text form of a regclass cast of the owning relation.
# Rows are ordered by the column's attribute number (attnum).
# Adapted from: https://dba.stackexchange.com/a/90567
find_serial_columns = """\
SELECT
a.attrelid::regclass::text AS table_name,
a.attname AS column_name
FROM pg_attribute a
WHERE
a.attnum > 0
AND NOT a.attisdropped
AND a.atttypid = ANY ('{int,int8,int2}'::regtype[])
AND EXISTS (
SELECT FROM pg_attrdef ad
WHERE
ad.adrelid = a.attrelid
AND ad.adnum = a.attnum
AND (
pg_get_expr(ad.adbin, ad.adrelid)
=
'nextval('''
|| (
pg_get_serial_sequence(a.attrelid::regclass::text, a.attname)
)::regclass
|| '''::regclass)'
)
)
ORDER BY a.attnum
"""
def migrate_serial_to_identity(
    cursor: CursorWrapper,
    database: str,
    table_name: str,
    column_name: str,
    write: bool,
) -> None:
    """Convert one 'serial' column into a 'generated by default' identity column.

    All statements run inside a single ``atomic`` block on ``database``.
    The column and its linked sequence are always looked up; the catalog is
    only modified when ``write`` is true.

    Args:
        cursor: Open cursor on the target database, used for every statement.
        database: Connection alias passed to ``atomic(using=...)``.
        table_name: Table reference in a form accepted by a ``::regclass``
            cast (may be schema-qualified and/or double-quoted).
        column_name: Unquoted name of the column, as stored in
            ``pg_attribute.attname``.
        write: When false, stop after the lookups without changing anything.

    Raises:
        SystemExit: With code 1, after printing a message, if the column is
            not found or if it does not have exactly one linked sequence.

    NOTE(review): the UPDATEs below write to system catalogs directly, which
    presumably needs elevated (superuser) privileges — confirm before use.
    """
    with atomic(using=database):
        # Adapted from upgrade_serial_to_identity() in:
        # https://www.enterprisedb.com/blog/postgresql-10-identity-columns-explained
        #
        # Besides the column number, also fetch the regclass text form of the
        # table: PostgreSQL has already schema-qualified and quoted it as
        # needed, so it can be placed into the ALTER TABLE statement as-is.
        cursor.execute(
            """\
SELECT attnum, attrelid::regclass::text
FROM pg_attribute
WHERE attrelid = %s::regclass
AND attname = %s
""",
            (table_name, column_name),
        )
        row = cursor.fetchone()
        if row is None:
            # No such column: fail like the sequence checks below instead of
            # crashing with a TypeError on None[0].
            print("Failed to find column")
            raise SystemExit(1)
        column_number, table_ref = row

        # Find the objects in pg_class that depend on this column with
        # deptype = 'a'; exactly one linked sequence is expected.
        cursor.execute(
            """\
SELECT objid
FROM pg_depend
WHERE (refclassid, refobjid, refobjsubid) = ('pg_class'::regclass, %s::regclass, %s)
AND classid = 'pg_class'::regclass
AND objsubid = 0
AND deptype = 'a'
""",
            (table_name, column_number),
        )
        results = cursor.fetchall()
        if not results:
            print("Failed to find linked sequence")
            raise SystemExit(1)
        elif len(results) > 1:
            print("Found more than one linked sequence!")
            raise SystemExit(1)
        sequence_id = results[0][0]

        if not write:
            # Dry run: the lookups succeeded, nothing is modified.
            return

        # Drop the default. Only the column name goes through quote_name:
        # table_ref is already quoted, and quoting it again would turn a
        # schema-qualified name such as other.tbl into one identifier.
        qn = cursor.db.ops.quote_name
        cursor.execute(
            f"""\
ALTER TABLE {table_ref}
ALTER COLUMN {qn(column_name)} DROP DEFAULT;
"""
        )

        # Modify sequence to be an internal dependency
        cursor.execute(
            """\
UPDATE pg_depend
SET deptype = 'i'
WHERE (classid, objid, objsubid) = ('pg_class'::regclass, %s, 0)
AND deptype = 'a'
""",
            (sequence_id,),
        )

        # Change to identity column, generated by default
        cursor.execute(
            """\
UPDATE pg_attribute
SET attidentity = 'd'
WHERE attrelid = %s::regclass
AND attname = %s
""",
            (table_name, column_name),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment