Skip to content

Instantly share code, notes, and snippets.

@kevinkosterr
Last active January 26, 2024 09:19
Show Gist options
  • Save kevinkosterr/43be0ac226fd355d5a1992118c9e810d to your computer and use it in GitHub Desktop.
Save kevinkosterr/43be0ac226fd355d5a1992118c9e810d to your computer and use it in GitHub Desktop.
Peewee SQLite migration script for creating and dropping columns
import peewee
from playhouse.migrate import SqliteMigrator, migrate
from typing import List, Tuple, Dict
# Names of the columns in each row returned by SQLite's ``PRAGMA table_info``,
# in result order; ``prepare_migrations`` zips every row with this list to
# build a column-name -> value dict.
TABLE_INFO_COLUMNS: List[str] = [
    "cid",
    "name",
    "type",
    "notnull",
    "dflt_value",
    "pk",
]

# TODOs:
# - TODO: Implement dropping entire tables
# - TODO: Implement field type switching
# - TODO: Implement nullable and not nullable switches
def build_table_info_query(model: str) -> str:
    """
    Build a ``PRAGMA table_info`` query for a given model.

    The table name is the lower-cased model name. Any single quote in it is
    doubled so the name stays a well-formed SQL string literal. SQLite's
    PRAGMA syntax takes a literal here, not a bound parameter, so escaping
    is the only option.

    :param model: model name from which the table name is derived.
    :type model: str
    :return: SQL query string, e.g. ``PRAGMA table_info('user');``
    :rtype: str
    """
    table_name = model.lower().replace("'", "''")
    return f"PRAGMA table_info('{table_name}');"
def get_fields(model: peewee.Model) -> Dict[str, peewee.FieldAccessor]:
    """
    Collect the peewee field accessors defined on a model.

    Only entries of the model's own ``__dict__`` whose value is a
    ``peewee.FieldAccessor`` are kept. Methods, plain attributes and other
    descriptor types are skipped.

    :param model: model to inspect (presumably the model class, since the
        accessors are looked up in its ``__dict__``; verify against callers)
    :type model: peewee.Model
    :return: mapping of attribute name to its ``FieldAccessor``, in ``__dict__`` order
    :rtype: Dict[str, peewee.FieldAccessor]
    """
    accessors = {}
    for attr_name, attr_value in vars(model).items():
        if isinstance(attr_value, peewee.FieldAccessor):
            accessors[attr_name] = attr_value
    return accessors
def prepare_migrations(database: peewee.SqliteDatabase,
                       models: List[Tuple[str, peewee.Model]]) -> Dict[str, List]:
    """
    Prepare migrations for all models by diffing model fields against table columns.

    For every ``(model_name, model)`` pair, the current columns of the table
    named by ``build_table_info_query(model_name)`` are read with
    ``PRAGMA table_info`` and compared with the model's field accessors
    (see ``get_fields``). A field "matches" a column when the column is the
    field's attribute name, its attribute name + ``"_id"`` (the default
    foreign-key column) or the field's ``column_name``.

    - A field is scheduled for creation when no existing column matches it.
    - A column is scheduled for dropping when it matches no field.

    If the table does not exist, the PRAGMA returns no rows, so every field
    of that model is scheduled for creation.

    :param database: database to prepare migrations for
    :type database: peewee.SqliteDatabase
    :param models: ``(model_name, model_class)`` pairs to inspect
    :type models: List[Tuple[str, peewee.Model]]
    :return: ``{"created": [FieldAccessor, ...],
        "dropped": [(model_name, column_name), ...]}``
    :rtype: Dict[str, List]
    """
    fields_to_create = []
    fields_to_drop = []
    for model in models:
        model_name, model_class = model[0], model[1]
        model_fields = get_fields(model_class)

        cursor = database.execute_sql(build_table_info_query(model_name))
        column_names = [dict(zip(TABLE_INFO_COLUMNS, row))["name"] for row in cursor.fetchall()]
        existing_columns = set(column_names)

        # Column names each field may occupy. The "_id" form is matched exactly;
        # the previous str.replace("_id", "") also stripped "_id" from the middle
        # of names, so ordinary columns like "external_id" or "user_identifier"
        # were wrongly scheduled for dropping.
        candidates = {
            field_name: {field_name, field_name + "_id", accessor.field.column_name}
            for field_name, accessor in model_fields.items()
        }
        known_columns = set().union(*candidates.values())

        # Add the field instance to fields that should be created during the migration process
        fields_to_create.extend(
            accessor for field_name, accessor in model_fields.items()
            if existing_columns.isdisjoint(candidates[field_name])
        )
        # Columns no field accounts for; foreign-key "<field>_id" columns are kept.
        fields_to_drop.extend(
            (model_name, column) for column in column_names
            if column not in known_columns
        )
    return {"created": fields_to_create, "dropped": fields_to_drop}
def make_migrations(database: peewee.SqliteDatabase, migrations: Dict[str, List]) -> None:
    """
    Execute the migrations that are prepared.

    Every field accessor under ``"created"`` becomes an ``add_column``
    operation on its model's table. Every ``(table, column)`` pair under
    ``"dropped"`` becomes a ``drop_column`` operation. All operations are
    passed to ``playhouse.migrate.migrate`` (additions first), and the counts
    are printed afterwards.

    :param database: database to perform migrations on.
    :type database: peewee.SqliteDatabase
    :param migrations: migrations to perform, shaped like the result of
        ``prepare_migrations``; a missing ``"created"`` or ``"dropped"`` key is
        treated as an empty list.
    :type migrations: Dict[str, List]
    """
    migrator = SqliteMigrator(database)
    # Use the field's column_name, not its attribute name. add_column() creates
    # the SQL column under the name it is given and renames the field to match,
    # so passing field.name turned a foreign key ``user`` into a column "user"
    # instead of peewee's expected "user_id".
    added_columns = [
        migrator.add_column(accessor.model.__name__, accessor.field.column_name, accessor.field)
        for accessor in migrations.get("created", [])
    ]
    dropped_columns = [
        migrator.drop_column(table, column)
        for table, column in migrations.get("dropped", [])
    ]
    migrate(*added_columns, *dropped_columns)
    print(f"Created {len(added_columns)} new column(s).")
    print(f"Dropped {len(dropped_columns)} column(s).")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment