Skip to content

Instantly share code, notes, and snippets.

@TobeTek
Created June 23, 2024 20:13
Show Gist options
  • Save TobeTek/74b8eb75900c261466ed30eaeb7b5070 to your computer and use it in GitHub Desktop.
Save TobeTek/74b8eb75900c261466ed30eaeb7b5070 to your computer and use it in GitHub Desktop.
A Django management command that automatically creates migrations for all models with a PostgreSQL SearchVectorField
import os
import string
from collections import defaultdict
from django.core.management.base import BaseCommand, CommandError
from django.db import migrations
from django.db.migrations.writer import MigrationWriter
from django.db.models import Model
# Marker embedded in the name of every migration file this command writes;
# it is also what the command looks for to recognise migrations it has
# already generated.
MIGRATION_FILE_NAME = "searchvectortrigger"
class Command(BaseCommand):
    """
    Create migrations that set up the search vector columns of selected models.

    Should be invoked after `./manage.py makemigrations`.

    One migration file is written per configured model. Each file contains the
    `RunSQL` operations returned by `generate_search_vector_sql` and depends on
    the most recent numbered migration already present in the app.
    """

    help = "Creates new migration(s) to create triggers for search vectors in models."
    include_header = True

    @property
    def log_output(self):
        """Return the stream that `log` writes to (the command's stdout)."""
        return self.stdout

    def log(self, msg):
        """Write `msg` to the command's output stream."""
        self.log_output.write(msg)

    def handle(self, *app_labels, **options):
        """
        Build the SQL operations for every configured model and write them out.

        `app_labels` and `options` are accepted for compatibility with the
        management command interface but are not used.
        """
        # Project models are imported at call time rather than at module
        # level -- presumably so this module can be imported before the app
        # registry is ready (TODO confirm).
        from articles.models.articles import Article
        from articles.models.categories import Category
        from community.models import Post
        from tools_and_settings.models import AppTool

        SEARCH_VECTOR_FIELDS = {
            Article: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            Category: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            Post: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            AppTool: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
                {
                    "vector_column": "description_fts_vector",
                    "trigger_columns": ["description"],
                },
            ],
        }

        # app_label -> list of [model_name, operations] pairs.
        search_trigger_migrations = defaultdict(list)
        for model, search_vectors in SEARCH_VECTOR_FIELDS.items():
            operations = [
                generate_search_vector_sql(
                    model=model,
                    vector_column=search_vector["vector_column"],
                    trigger_columns=search_vector["trigger_columns"],
                )
                for search_vector in search_vectors
            ]
            search_trigger_migrations[model._meta.app_label].append(
                [model._meta.model_name, operations]
            )

        self.write_migration_files(search_trigger_migrations)

    def write_migration_files(self, changes):
        """
        Take a changes dict and write them out as migration files.

        `changes` maps an app label to a list of `[model_name, operations]`
        pairs. A model that already has a search vector migration on disk is
        skipped, so running the command repeatedly does not create duplicates.
        """
        for app_label, model_migrations in changes.items():
            for model_name, operations in model_migrations:
                # A provisional migration is needed first: the writer is what
                # resolves the app's migrations directory.
                migration = self._build_migration(
                    app_label, model_name, operations, number=1, dependencies=[]
                )
                writer = MigrationWriter(migration, self.include_header)

                if self.has_search_vector_migration(writer.basedir, model_name):
                    continue

                # Chain onto the latest existing migration, if there is one.
                if dependency := self.get_most_recent_migration(writer.basedir):
                    dependency_migration_no, _ = dependency.split("_", 1)
                    migration = self._build_migration(
                        app_label,
                        model_name,
                        operations,
                        number=int(dependency_migration_no) + 1,
                        dependencies=[(app_label, dependency)],
                    )
                    writer = MigrationWriter(migration, self.include_header)

                migrations_directory = os.path.dirname(writer.path)
                os.makedirs(migrations_directory, exist_ok=True)
                init_path = os.path.join(migrations_directory, "__init__.py")
                if not os.path.isfile(init_path):
                    # Create an empty package marker.
                    with open(init_path, "w", encoding="utf-8"):
                        pass

                migration_string = writer.as_string()
                with open(writer.path, "w", encoding="utf-8") as fh:
                    fh.write(migration_string)

    def _build_migration(self, app_label, model_name, operations, number, dependencies):
        """Return a Migration instance named `<number>_<marker>_<model_name>`."""
        subclass = type(
            "Migration",
            (migrations.Migration,),
            {
                "dependencies": dependencies,
                "operations": operations,
            },
        )
        return subclass(
            name=f"{number:0>4}_{MIGRATION_FILE_NAME}_{model_name}",
            app_label=app_label,
        )

    def has_search_vector_migration(self, app_migrations_folder: str, model_name: str):
        """
        Return True if a search vector migration for `model_name` already exists.

        The file name (without extension) must end with
        `_<MIGRATION_FILE_NAME>_<model_name>`; a plain substring test would
        wrongly match a longer model name that merely starts with `model_name`.
        A missing migrations folder counts as "no migration".
        """
        expected_suffix = f"_{MIGRATION_FILE_NAME}_{model_name}"
        try:
            filenames = os.listdir(app_migrations_folder)
        except FileNotFoundError:
            return False
        return any(
            os.path.splitext(filename)[0].endswith(expected_suffix)
            for filename in filenames
        )

    def get_most_recent_migration(self, app_migrations_folder: str):
        """
        Return the name (no extension) of the highest-numbered migration.

        Only `.py` files named `<digits>_<anything>` are considered, so
        `__init__.py`, `__pycache__` and other stray entries are ignored.
        Returns None when the folder is missing or holds no such file.
        """
        try:
            filenames = os.listdir(app_migrations_folder)
        except FileNotFoundError:
            return None

        numbered_migrations = []
        for filename in filenames:
            stem, extension = os.path.splitext(filename)
            prefix, separator, _ = stem.partition("_")
            if extension == ".py" and separator and prefix.isdecimal():
                # Compare numerically so ordering does not rely on zero padding.
                numbered_migrations.append((int(prefix), stem))

        if not numbered_migrations:
            return None
        return max(numbered_migrations)[1]
def generate_search_vector_sql(
    model: type[Model], vector_column: str, trigger_columns: list[str]
):
    """
    Build a `RunSQL` operation that makes `vector_column` a generated tsvector.

    The forward SQL drops `vector_column` from the model's table if it exists
    and re-adds it as a stored generated column. The column is the
    concatenation of one weighted `to_tsvector('english', ...)` per entry in
    `trigger_columns`. The reverse SQL drops the column.

    Args:
        model: Model class whose `_meta.db_table` names the table to alter.
        vector_column: Name of the tsvector column to (re)create.
        trigger_columns: Columns feeding the vector, most important first;
            the first gets weight 'A', the second 'B', and so on.

    Raises:
        CommandError: If `trigger_columns` is empty or has more entries than
            there are weight labels.
    """
    # PostgreSQL's setweight() only accepts the labels A, B, C and D.
    weight_labels = "ABCD"

    if not trigger_columns:
        raise CommandError("At least one trigger column is required for search vector")
    if len(trigger_columns) > len(weight_labels):
        raise CommandError("Maximum number of trigger columns exceeded for search vector")

    db_table = model._meta.db_table

    def quote_identifier(identifier: str) -> str:
        # Double quotes mark an SQL identifier; single quotes would turn the
        # column reference into a string literal and index that text instead.
        escaped = identifier.replace('"', '""')
        return f'"{escaped}"'

    setweight_stmts = " || ".join(
        f"setweight(to_tsvector('english', coalesce({quote_identifier(column)}, '')), '{weight}')"
        for column, weight in zip(trigger_columns, weight_labels)
    )

    create_sql = (
        f"ALTER TABLE {db_table} DROP COLUMN IF EXISTS {vector_column}; "
        f"ALTER TABLE {db_table} ADD COLUMN {vector_column} tsvector "
        f"GENERATED ALWAYS AS ({setweight_stmts}) STORED;"
    )
    reverse_sql = f"ALTER TABLE {db_table} DROP COLUMN {vector_column};"

    return migrations.RunSQL(
        sql=create_sql.replace("\n", " "),
        reverse_sql=reverse_sql.replace("\n", " "),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment