Last active: January 26, 2024 09:19
-
-
Save kevinkosterr/43be0ac226fd355d5a1992118c9e810d to your computer and use it in GitHub Desktop.
Peewee SQLite migration script for creating and dropping columns
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import peewee | |
from playhouse.migrate import SqliteMigrator, migrate | |
from typing import List, Tuple, Dict | |
# Names given, in order, to the fields of each row returned by
# ``PRAGMA table_info``; rows are zipped with this list so columns can be
# looked up by name (e.g. ``"name"``) instead of by position.
TABLE_INFO_COLUMNS: List[str] = [
    "cid", "name", "type", "notnull", "dflt_value", "pk",
]

# TODO: Implement dropping entire tables.
# TODO: Implement field type switching.
# TODO: Implement nullable and not nullable switches.
def build_table_info_query(model: str) -> str:
    """
    Build a PRAGMA table_info query for a given model.

    The name is lower-cased before use, and any single quote in it is doubled
    so the result is always a well-formed SQLite string literal. A name such
    as ``O'Brien`` can therefore neither break the statement nor inject SQL.

    :param model: model (table) name to inspect.
    :type model: str
    :return: SQL query string
    :rtype: str
    """
    # SQLite escapes a single quote inside a '...' literal by doubling it.
    table_name = model.lower().replace("'", "''")
    return f"PRAGMA table_info('{table_name}');"
def get_fields(model: peewee.Model) -> Dict[str, peewee.FieldAccessor]:
    """
    Collect the field accessors found in a model class's own namespace.

    Every attribute of ``vars(model)`` that is a ``peewee.FieldAccessor`` is
    returned, in namespace order; any other attribute is ignored.

    :param model: model class whose database fields should be collected
    :type model: peewee.Model
    :return: mapping of attribute name to its ``peewee.FieldAccessor``
    :rtype: Dict[str, peewee.FieldAccessor]
    """
    accessors: Dict[str, peewee.FieldAccessor] = {}
    for attr_name, attr_value in vars(model).items():
        if not isinstance(attr_value, peewee.FieldAccessor):
            continue
        accessors[attr_name] = attr_value
    return accessors
def prepare_migrations(database: peewee.SqliteDatabase,
                       models: List[Tuple[str, peewee.Model]]) -> Dict[str, List]:
    """
    Prepare migrations for all models.

    For every ``(table_name, model_class)`` pair, the table's current columns
    are read with ``PRAGMA table_info`` and compared with the model's fields.
    The comparison uses each field's database column name
    (``field.column_name``), so a foreign key whose column is ``user_id`` is
    matched exactly. No name is guessed by stripping an ``_id`` suffix.

    :param database: database to prepare migrations for
    :type database: peewee.SqliteDatabase
    :param models: ``(table_name, model_class)`` pairs to inspect. The table
        name is used for the ``PRAGMA table_info`` lookup and as the table of
        every scheduled column drop.
    :type models: List[Tuple[str, peewee.Model]]
    :return: ``{"created": [...], "dropped": [...]}``. ``created`` holds the
        ``peewee.FieldAccessor`` of every field whose column is missing from
        its table. ``dropped`` holds a ``(table_name, column_name)`` tuple for
        every table column that no model field maps to.
    :rtype: Dict[str, List]
    """
    name_index = TABLE_INFO_COLUMNS.index("name")
    fields_to_create = []
    fields_to_drop = []

    for table_name, model_class in models:
        model_fields = get_fields(model_class)
        cursor = database.execute_sql(build_table_info_query(table_name))
        # NOTE(review): if the table does not exist, PRAGMA table_info returns
        # no rows, and every field is then scheduled for creation.
        existing_columns = [row[name_index] for row in cursor.fetchall()]
        existing = set(existing_columns)
        expected = {accessor.field.column_name for accessor in model_fields.values()}

        # Fields whose backing column is missing from the table.
        fields_to_create.extend(
            accessor for accessor in model_fields.values()
            if accessor.field.column_name not in existing
        )
        # Match on real column names. The previous heuristic,
        # ``name.replace("_id", "")``, dropped ordinary columns such as
        # ``external_id`` or ``user_identifier`` on every run.
        fields_to_drop.extend(
            (table_name, column) for column in existing_columns
            if column not in expected
        )

    return {"created": fields_to_create, "dropped": fields_to_drop}
def make_migrations(database: peewee.SqliteDatabase, migrations: Dict[str, List]) -> None:
    """
    Execute the migrations that are prepared.

    Each accessor in ``migrations["created"]`` becomes an ``add_column``
    operation on its model's table (``_meta.table_name``), using the field's
    database column name. For a ``ForeignKeyField`` named ``user`` that is
    usually ``user_id``, not ``user``. Each ``(table, column)`` pair in
    ``migrations["dropped"]`` becomes a ``drop_column`` operation. All
    operations run inside one transaction, so a failure does not leave the
    schema half-migrated. A summary is printed afterwards.

    :param database: database to perform migrations on.
    :type database: peewee.SqliteDatabase
    :param migrations: migrations to perform, as returned by
        ``prepare_migrations``. Missing keys are treated as empty.
    :type migrations: Dict[str, List]
    """
    migrator = SqliteMigrator(database)
    added_columns = [
        migrator.add_column(accessor.model._meta.table_name,
                            accessor.field.column_name,
                            accessor.field)
        for accessor in migrations.get("created", [])
    ]
    dropped_columns = [
        migrator.drop_column(table_name, column_name)
        for table_name, column_name in migrations.get("dropped", [])
    ]
    # playhouse.migrate does not open a transaction itself.
    with database.atomic():
        migrate(*added_columns, *dropped_columns)
    print(f"Created {len(added_columns)} new column(s).")
    print(f"Dropped {len(dropped_columns)} column(s).")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.