Created
January 7, 2024 21:02
-
-
Save colonelpanic8/c5df24f293c3497c42e8f48d1870a54b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import argparse | |
import logging | |
from dependency_injector.wiring import Provide, inject | |
from sqlalchemy.orm import joinedload | |
from railbird import config, console | |
from railbird.containers import RBDeps | |
from railbird.datatypes import models | |
logger = logging.getLogger(__name__) | |
def build_migration_parser(
    parent_parser: argparse.ArgumentParser, overrides: config.Overrides
):
    """Return the command-line parser for this migration script.

    The new parser inherits every argument of *parent_parser* (via argparse's
    ``parents`` mechanism). It sets ``func`` to :func:`migrate` as a default,
    so the parsed namespace carries the callable to run.

    Args:
        parent_parser: Parser whose arguments are copied into the new one.
        overrides: Not used by this function; presumably part of the builder
            signature that ``console.main`` expects. TODO confirm.

    Returns:
        The newly constructed ``argparse.ArgumentParser``.
    """
    migration_parser = argparse.ArgumentParser(parents=[parent_parser])
    migration_parser.set_defaults(func=migrate)
    return migration_parser
def main():
    """Run this script through the shared ``railbird.console`` entry point.

    Passes ``build_migration_parser`` as the parser builder and
    ``wire=[__name__]``. Presumably the latter wires this module so that the
    ``@inject``/``Provide`` defaults below are resolved; verify against
    ``console.main``.
    """
    console.main(
        build_main_parser=build_migration_parser,
        wire=[__name__],
    )
def _build_cue_object_features(shot):
    """Return a new ``CueObjectFeatures`` row built from *shot*'s legacy relations.

    Reads ``cue_object_distance.distance``, ``cue_object_angle.angle``,
    ``cue_ball_speed.speed`` and ``shot_direction.direction``. If any of those
    relations is ``None``, the attribute access raises ``AttributeError``.
    """
    return models.CueObjectFeatures(
        shot_id=shot.id,
        cue_object_distance=shot.cue_object_distance.distance,
        cue_object_angle=shot.cue_object_angle.angle,
        cue_ball_speed=shot.cue_ball_speed.speed,
        shot_direction=shot.shot_direction.direction,
    )


def _build_pocketing_intention_features(shot):
    """Return a new ``PocketingIntentionFeatures`` row built from *shot*.

    Reads ``target_pocket_distance.distance`` and ``intended_pocket.type``. If
    either relation is ``None``, the attribute access raises ``AttributeError``.
    """
    return models.PocketingIntentionFeatures(
        shot_id=shot.id,
        target_pocket_distance=shot.target_pocket_distance.distance,
        intended_pocket_type=shot.intended_pocket.type,
    )


@inject
def migrate(args, sessionmaker=Provide[RBDeps.sync_sessionmaker]):
    """Attach consolidated feature rows to every shot, in one transaction.

    For each ``ShotModel``, builds a ``CueObjectFeatures`` and a
    ``PocketingIntentionFeatures`` row from the shot's legacy per-feature
    relations and assigns them to the shot. All work runs inside a single
    ``session.begin()`` block. If any shot fails, the transaction is rolled
    back and nothing is committed.

    Args:
        args: Parsed CLI namespace. It is unused here; presumably the
            dispatcher calls ``args.func(args)``. TODO confirm in
            ``console.main``.
        sessionmaker: Injected synchronous SQLAlchemy session factory.

    Raises:
        Exception: Any error raised while building a shot's features (for
            example ``AttributeError`` when a legacy relation is missing).
            The error is logged together with the offending shot and then
            re-raised unchanged.
    """
    with sessionmaker() as session:
        # session.begin() as a context manager commits when the block exits
        # cleanly and rolls back if an exception escapes it, so an explicit
        # session.commit() inside the block is unnecessary.
        with session.begin():
            logger.info("Starting transaction")
            shots = (
                session.query(models.ShotModel)
                .options(
                    # Eager-load every relation the builders read, so each
                    # shot does not trigger its own lazy-load queries.
                    joinedload(models.ShotModel.cue_object_distance),
                    joinedload(models.ShotModel.target_pocket_distance),
                    joinedload(models.ShotModel.cue_object_angle),
                    joinedload(models.ShotModel.cue_ball_speed),
                    joinedload(models.ShotModel.intended_pocket),
                    joinedload(models.ShotModel.shot_direction),
                )
                .all()
            )
            for shot in shots:
                logger.info("Processing %s", shot)
                try:
                    shot.cue_object_features = _build_cue_object_features(shot)
                    shot.pocketing_intention_features = (
                        _build_pocketing_intention_features(shot)
                    )
                except Exception:
                    # Record which shot broke the migration before the
                    # transaction is rolled back, then propagate unchanged.
                    logger.exception("Failed to migrate %s", shot)
                    raise
                session.add(shot)
            logger.info("Committing transaction")
# Script entry point: run the migration only when this file is executed
# directly, never when it is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment