Skip to content

Instantly share code, notes, and snippets.

@dorosch
Created June 1, 2023 07:51
Show Gist options
  • Save dorosch/3724f81f71089c5a4a3312e67884d7f8 to your computer and use it in GitHub Desktop.
Save dorosch/3724f81f71089c5a4a3312e67884d7f8 to your computer and use it in GitHub Desktop.
Tool for managing migrations for Impala
import os
import json
from glob import glob
from pathlib import Path
from datetime import datetime
from types import ModuleType
from typing import List, Type, Optional, Dict
from typing import Protocol
import importlib.util
import typer
from pydantic import BaseModel
from conf import cfg
# Source code of a newly generated migration module.
# Formatting happens in two passes: ``author`` and ``datetime`` are filled in
# first while ``meta='{meta}'`` keeps the placeholder alive; the second pass
# substitutes either ``initial = True`` or ``previous = ['<name>']``.
# The class body must stay indented, otherwise the generated file cannot be
# imported.
MIGRATION_TEMPLATE = """from typing import List
from core.migration import BaseMigration, Executor


class Migration(BaseMigration):
    {meta}
    author = '{author}'
    datetime = '{datetime}'

    def upgrade(self, executor: Executor):
        pass

    def downgrade(self, executor: Executor):
        pass

    def before_migration(self, executor: Executor):
        pass

    def after_migration(self, executor: Executor):
        pass
"""
# Coloured status badges printed by the CLI commands.
OK, ERROR, APPLIED, UNAPPLIED = [
    typer.style(label, fg=colour, bold=True)
    for label, colour in (
        ('Ok', typer.colors.GREEN),
        ('Error', typer.colors.RED),
        ('+', typer.colors.GREEN),
        ('-', typer.colors.WHITE),
    )
]
def import_module(module: str, path: str) -> Optional[ModuleType]:
    """Load and execute a Python source file as a module object.

    Args:
        module: name given to the created module object.
        path: filesystem path of the source file to execute.

    Returns:
        The executed module, or ``None`` when importlib cannot build a spec
        for *path* (for example, an unrecognised file suffix).
    """
    spec = importlib.util.spec_from_file_location(module, path)
    if not spec:
        return None
    loaded = importlib.util.module_from_spec(spec)
    # The module is executed directly; it is not registered in sys.modules.
    spec.loader.exec_module(loaded)
    return loaded
def next_order_number(values: List[str]) -> str:
    """Return the next zero-padded ordinal for a migration file name.

    Each value is expected to look like ``<number>_<rest>``; values whose
    prefix before the first underscore is not a decimal number are ignored.
    The result is padded to at least four digits.

    >>> next_order_number([])
    '0001'
    >>> next_order_number(['0001_a', '0003_b', '0005_c'])
    '0006'
    >>> next_order_number(['__init__', '0002_b'])
    '0003'
    """
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscripts that int() rejects with ValueError.
    numbers = [
        int(prefix)
        for prefix, _, _ in (value.partition('_') for value in values)
        if prefix.isdecimal()
    ]
    return str(max(numbers, default=0) + 1).zfill(4)
class Executor(Protocol):
    """Structural interface for the object that runs migration queries.

    Any object exposing these four coroutine methods satisfies the protocol;
    no inheritance is required. Their concrete behavior is defined by the
    implementation.
    """

    async def connect(self) -> None:
        """Establish the executor's connection (implementation-defined)."""

    async def disconnect(self) -> None:
        """Tear down the executor's connection (implementation-defined)."""

    async def execute(self, query: str, *params) -> None:
        """Run *query* with the given positional *params*."""

    async def executemany(self, query: str, *params) -> None:
        """Run *query* for the given *params* (batch form)."""
class BaseMigration:
    """Base class for a single database migration.

    A migration is a set of actions on the database that moves the schema to
    a certain state. Subclasses override ``upgrade``/``downgrade`` and,
    optionally, the ``before_migration``/``after_migration`` hooks that wrap
    both directions.

    NOTE(review): the ``Executor`` protocol declares coroutine methods, yet
    every hook here is a plain ``def`` and nothing is awaited, so calling
    ``executor.execute(...)`` inside a hook yields an un-awaited coroutine.
    Confirm whether executors are meant to be synchronous.
    """

    # Metadata overridden by each generated migration module.
    initial: bool = False
    previous: List[str] = []  # names of the migrations this one follows
    author: str = None
    datetime: str = None  # ISO-8601 timestamp string
    # Set on the class by the loader after the module is imported.
    _name: str = None
    _applied: bool = False

    @classmethod
    def name(cls) -> str:
        """Return the name assigned to this migration class."""
        return cls._name

    @classmethod
    def applied(cls) -> bool:
        """Return whether this migration was recorded as applied."""
        return cls._applied

    def run_upgrade(self, executor: Executor):
        """Run ``upgrade`` wrapped by the before/after hooks."""
        for hook_name in ('before_migration', 'upgrade', 'after_migration'):
            getattr(self, hook_name)(executor)

    def run_downgrade(self, executor: Executor):
        """Run ``downgrade`` wrapped by the before/after hooks."""
        for hook_name in ('before_migration', 'downgrade', 'after_migration'):
            getattr(self, hook_name)(executor)

    # ### User defined methods ###

    def upgrade(self, executor: Executor):
        """Apply this migration's changes; no-op unless overridden."""

    def downgrade(self, executor: Executor):
        """Revert this migration's changes; no-op unless overridden."""

    def before_migration(self, executor: Executor):
        """Hook run before upgrade/downgrade; no-op unless overridden."""

    def after_migration(self, executor: Executor):
        """Hook run after upgrade/downgrade; no-op unless overridden."""
class MigrationState:
    """Storing the state of applied migrations.

    The state lives in ``<cfg.src_dir>/migrations/state.json`` as
    ``{"applied": [<migration name>, ...]}`` and is rewritten after every
    change.
    """
    state_path: Path = Path(cfg.src_dir) / 'migrations' / 'state.json'

    def __init__(self):
        self._state: Dict = {'applied': []}
        if not self.state_path.exists():
            return
        try:
            with open(self.state_path, encoding='utf-8') as file:
                loaded = json.load(file)
        except (json.decoder.JSONDecodeError, UnicodeDecodeError):
            loaded = None
        if isinstance(loaded, dict) and isinstance(loaded.get('applied', []), list):
            loaded.setdefault('applied', [])
            self._state = loaded
        else:
            # Unreadable or wrongly-shaped file: keep the empty default and
            # write it back (best-effort reset). NOTE: this forgets which
            # migrations were applied. The write happens only after the read
            # handle above has been closed.
            self._sync_state()

    def apply(self, migration_name: str):
        """Record *migration_name* as applied and persist the state."""
        if migration_name not in self._state['applied']:
            self._state['applied'].append(migration_name)
            self._sync_state()

    def unapply(self, migration_name: str):
        """Remove *migration_name* from the applied list and persist."""
        if migration_name in self._state['applied']:
            self._state['applied'].remove(migration_name)
            self._sync_state()

    def is_applied(self, migration_name: str) -> bool:
        """Return whether *migration_name* is recorded as applied."""
        return migration_name in self._state['applied']

    def _sync_state(self):
        """Write the in-memory state to ``state_path`` atomically."""
        # Write a sibling temp file and swap it in with a single rename, so a
        # crash mid-write cannot leave a truncated state.json (which the loader
        # would treat as corrupt and reset).
        tmp_path = self.state_path.with_name(self.state_path.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as file:
            file.write(json.dumps(self._state, sort_keys=True, indent=2))
        tmp_path.replace(self.state_path)
class MigrationHistory:
    """Working with classes that make up the history of migrations.

    Migration modules live in ``<cfg.src_dir>/migrations`` and are named
    ``{order}_{name}.py``. Every module defining a ``Migration`` class is
    loaded, tagged with its name and applied flag, and ordered by the class's
    ``datetime`` attribute.
    """
    migrations_format: str = '{order}_{name}.py'
    migrations_path: Path = Path(cfg.src_dir) / 'migrations'

    def __init__(self):
        self.history: List[Type[BaseMigration]] = []
        self.migration_state = MigrationState()
        if self.migrations_path.exists():
            # Path.glob returns files in arbitrary OS order (and, unlike an
            # interpolated glob pattern, is not confused by glob metacharacters
            # in the directory name). Sorting by name first makes the stable
            # sort below deterministic for migrations sharing a timestamp.
            for file in sorted(self.migrations_path.glob('*.py')):
                migration_class = self.import_migration_class(file)
                if not migration_class:
                    continue  # e.g. __init__.py or a helper module
                name = file.stem
                setattr(migration_class, '_name', name)
                setattr(migration_class, '_applied', self.migration_state.is_applied(name))
                self.history.append(migration_class)
            self.history.sort(key=lambda x: datetime.fromisoformat(x.datetime))
        else:
            # Create the necessary migration's directory structure
            self.migrations_path.mkdir(parents=True, exist_ok=True)
            (self.migrations_path / '__init__.py').touch(exist_ok=True)

    def import_migration_class(self, file) -> Optional[Type[BaseMigration]]:
        """Import migration class from full path of migration file.

        Returns the module's ``Migration`` attribute, or ``None`` when the
        file cannot be loaded as a module or defines no ``Migration``.
        """
        return getattr(
            import_module(self.migrations_path.name, str(file)), 'Migration', None
        )

    def add(self, template: str) -> None:
        """Write a new migration module rendered from *template*.

        *template* must contain a ``{meta}`` placeholder. The first migration
        is named ``0001_initial.py`` and marked ``initial = True``; later ones
        get the next order number and point ``previous`` at the latest
        migration in the history.

        Raises:
            FileExistsError: if the target file already exists (it is never
                silently overwritten).
        """
        last = self.last()
        if last is not None:
            order = next_order_number([migration.name() for migration in self.history])
            meta = f"previous = ['{last.name()}']"
            name = self.migrations_format.format(order=order, name='migration')
        else:
            meta = 'initial = True'
            name = self.migrations_format.format(order='0001', name='initial')
        with open(self.migrations_path / name, 'x', encoding='utf-8') as file:
            file.write(template.format(meta=meta))

    def get_migration(self, target_migration: str) -> Optional[Type[BaseMigration]]:
        """Return the migration class named *target_migration*, or ``None``."""
        return next(
            (migration for migration in self.history if migration.name() == target_migration),
            None,
        )

    def last_applied(self) -> Optional[Type[BaseMigration]]:
        """Return the newest migration flagged as applied, or ``None``."""
        return next(
            (migration for migration in reversed(self.history) if migration.applied()),
            None,
        )

    def last(self) -> Optional[Type[BaseMigration]]:
        """Return the newest migration in the history, or ``None`` if empty."""
        return self.history[-1] if self.history else None
class Emigrant:
    """Migration of models based on `pydantic` classes.

    Front end for the CLI commands that list, create and run migrations.

    Args:
        executor: object implementing the ``Executor`` protocol; it is passed
            to each migration's ``run_upgrade``/``run_downgrade``.
        models: pydantic model classes. Stored on the instance, but no method
            of this class reads them yet.
    """

    def __init__(
        self,
        executor: Executor,
        models: Optional[List[Type[BaseModel]]] = None,
    ):
        self.models = models or []
        self.executor = executor

    async def show_migrations(self):
        """Print every known migration with an applied (+) / pending (-) mark."""
        for migration in MigrationHistory().history:
            migration_name = typer.style(
                migration.name(), fg=typer.colors.WHITE, bold=True
            )
            status = APPLIED if migration.applied() else UNAPPLIED
            typer.echo(f' [{status}] {migration_name}')

    @staticmethod
    def _current_user() -> str:
        """Return the name of the user running the command.

        ``os.getlogin()`` raises OSError when the process has no controlling
        terminal (CI jobs, cron, containers); fall back to
        ``getpass.getuser()``, which consults the environment and the
        password database instead.
        """
        try:
            return os.getlogin()
        except OSError:
            import getpass
            return getpass.getuser()

    async def make_migrations(self, empty: bool):
        """Create a new migration module from ``MIGRATION_TEMPLATE``.

        Args:
            empty: currently not used by this method.
        """
        from datetime import timezone

        # TODO: Move here class for build migration's payload
        # Naive UTC timestamp -- the same value datetime.utcnow() gives, without
        # that call's deprecation (Python 3.12). It must stay naive: the history
        # is sorted via datetime.fromisoformat, and naive and aware datetimes
        # cannot be compared.
        created = datetime.now(timezone.utc).replace(tzinfo=None)
        MigrationHistory().add(
            MIGRATION_TEMPLATE.format(
                author=self._current_user(),
                datetime=created.isoformat(timespec='seconds'),
                # Keep the placeholder for MigrationHistory.add to fill in.
                meta='{meta}',
            )
        )

    async def migrate(self, migration_target_name: Optional[str]):
        """Apply or roll back migrations.

        Without a target name, every pending migration is applied in history
        order; this never rolls anything back, so running it repeatedly is
        safe. With a target name: if the target comes after the last applied
        migration, pending migrations up to and including the target are
        applied; otherwise applied migrations are rolled back, newest first,
        down to and including the target.

        Raises:
            typer.Exit: code 1 if the named migration does not exist; code 0
                if there are no migrations at all.
        """
        typer.secho('Running migrations:', fg=typer.colors.WHITE, bold=True)
        migrations_state = MigrationState()
        migrations_history = MigrationHistory()
        history = migrations_history.history

        if migration_target_name:
            target_migration = migrations_history.get_migration(
                migration_target_name
            )
            if not target_migration:
                typer.echo(f' No migration {migration_target_name}')
                raise typer.Exit(code=1)
            last_applied = migrations_history.last_applied()
            # Compare positions in the (already sorted) history rather than raw
            # timestamps, so migrations created in the same second still order.
            forward = last_applied is None or (
                history.index(last_applied) < history.index(target_migration)
            )
        else:
            target_migration = migrations_history.last()
            if not target_migration:
                typer.echo(' Nothing to apply')
                raise typer.Exit()
            # No explicit target: only apply. When everything is applied, the
            # target is the last applied migration itself, and the backward
            # branch would otherwise roll it back.
            forward = True

        migrations = history if forward else list(reversed(history))
        changed = False
        for migration in migrations:
            migration_name = migration.name()
            if forward and not migration.applied():
                typer.echo(f' [{APPLIED}] {migration_name}')
                migration().run_upgrade(self.executor)
                migrations_state.apply(migration_name)
                changed = True
            elif not forward and migration.applied():
                typer.echo(f' [{UNAPPLIED}] {migration_name}')
                migration().run_downgrade(self.executor)
                migrations_state.unapply(migration_name)
                changed = True
            # Checked for every visited migration, including skipped ones, so
            # the walk never runs past the target.
            if migration is target_migration:
                break

        if not changed:
            typer.echo(' Nothing to apply')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment