Instantly share code, notes, and snippets.
Created
September 18, 2023 10:24
-
Star
(0)
0
You must be signed in to star a gist -
Fork
(0)
0
You must be signed in to fork a gist
-
Save frederik-encord/78e3976f095c0e0acfd5e61d05d1f3a3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
from datetime import datetime | |
from enum import Enum | |
from functools import partial | |
from pathlib import Path | |
from typing import Any, Callable, NamedTuple, Optional, TypeVar | |
from uuid import UUID | |
import orjson | |
import typer | |
from encord import EncordUserClient | |
from encord.objects import ( | |
Classification, | |
ClassificationInstance, | |
NestableOption, | |
RadioAttribute, | |
) | |
from encord.project import LabelRowV2 | |
from encord.project import Project as EncordProject | |
from encord_active.db.models import ( | |
Project, | |
ProjectDataMetadata, | |
ProjectDataUnitMetadata, | |
ProjectTag, | |
ProjectTaggedDataUnit, | |
get_engine, | |
) | |
from encord_active.lib.common.data_utils import collect_async | |
from InquirerPy import inquirer as inq | |
from InquirerPy.base.control import Choice | |
from loguru import logger | |
from rich.prompt import Prompt | |
from sqlmodel import Session, select | |
# Shared loguru line layout; {extra} surfaces values bound with
# logger.contextualize(...) or passed as keyword arguments.
FORMAT = "{time:MMMM D, YYYY > HH:mm:ss!UTC} | {level} | {message} | {extra}"

# Drop loguru's pre-installed handler so only the two sinks below are active.
logger.remove()
# Full DEBUG trail is kept on disk in "updates.log" ...
logger.add(sink="updates.log", level="DEBUG", format=FORMAT)
# ... while the console only shows INFO and above.
logger.add(sink=sys.stderr, level="INFO", format=FORMAT)
class UpdateStatus(str, Enum):
    """Outcome of transferring a single tag into an Encord label row.

    Members mix in ``str``, so each member compares equal to its message text.
    """

    # Outcomes that count as success.
    SUCCESS = "Updated classification successfully"
    SUCCESS_CLASSIFICATION_EXISTS = "Correct classification exists already"
    # Outcomes where nothing was written.
    MISSING_LABEL_ROW = "Couldn't find label row in Encord Project"
    NO_EMPTY_LRS_LEFT = "No empty label rows left to assign label rows to."
    LABEL_INITIALIZATION = "Couldn't initialize label row"
    SKIPPING_TAGS_NOT_IN_ONTOLOGY = "Skipping tags that were not selected for (or in existing) classification"
    LABEL_SAVE = "Failed to save label row"
    UNKNOWN_ERROR = "Unknown error - see logs for more details"
# One tag on a data unit, as selected from the Encord Active DB. Field order
# matters: DB rows are unpacked positionally into this tuple.
TagTuple = NamedTuple(
    "TagTuple",
    [
        ("label_hash", UUID),  # label row the tagged data unit belongs to
        ("du_hash", UUID),  # the tagged data unit
        ("frame", Optional[int]),  # frame of the tag; may be None
        ("tag_name", str),  # name of the tag
    ],
)
def classify_label_row(
    tag_tuple: TagTuple,
    lr_lookup: dict[UUID, LabelRowV2],
    tag_clf_obj: Classification,
) -> tuple[TagTuple, UpdateStatus]:
    """Write one Encord Active tag into its Encord label row as a classification.

    Args:
        tag_tuple: The tag to transfer (label hash, data unit, frame, tag name).
        lr_lookup: The Encord project's label rows, keyed by label hash.
        tag_clf_obj: Ontology classification whose option titles are tag names.

    Returns:
        ``(tag_tuple, status)``. Failures are reported through ``status``
        instead of being raised, so one bad tag does not stop the batch.
    """
    with logger.contextualize(t=tag_tuple):
        try:
            label_row = lr_lookup.get(tag_tuple.label_hash)
            if label_row is None:
                # Log the human-readable message, not the enum member's name.
                logger.warning(UpdateStatus.MISSING_LABEL_ROW.value)
                return tag_tuple, UpdateStatus.MISSING_LABEL_ROW

            if not label_row.is_labelling_initialised:
                try:
                    label_row.initialise_labels()
                except Exception as e:
                    logger.opt(exception=e).warning(
                        UpdateStatus.LABEL_INITIALIZATION.value
                    )
                    return tag_tuple, UpdateStatus.LABEL_INITIALIZATION

            # NOTE(review): this checks the whole label row, not tag_tuple.frame,
            # so later tags on other frames/data units of the same label row are
            # reported as already existing. Confirm that is intended.
            instances = label_row.get_classification_instances(
                filter_ontology_classification=tag_clf_obj
            )
            if instances:
                logger.debug(UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS.value)
                return tag_tuple, UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS

            # Tags that are not an option of the classification are skipped.
            try:
                answer = tag_clf_obj.get_child_by_title(
                    tag_tuple.tag_name, NestableOption
                )
            except Exception:
                logger.debug(UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY.value)
                return tag_tuple, UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY

            instance = tag_clf_obj.create_instance()
            instance.set_answer(answer)
            # A missing frame is treated as frame 0.
            instance.set_for_frames(tag_tuple.frame or 0)
            label_row.add_classification_instance(instance)

            try:
                label_row.save()
            except Exception as e:
                # NOTE(review): the instance stays attached to the in-memory
                # label row even though saving failed.
                logger.opt(exception=e).warning(UpdateStatus.LABEL_SAVE.value)
                return tag_tuple, UpdateStatus.LABEL_SAVE

            logger.debug(UpdateStatus.SUCCESS.value)
            return tag_tuple, UpdateStatus.SUCCESS
        except Exception as e:
            # Last-resort boundary: record the traceback and report, never raise.
            logger.opt(exception=e).warning(UpdateStatus.UNKNOWN_ERROR.value, err=e)
            return tag_tuple, UpdateStatus.UNKNOWN_ERROR
T = TypeVar("T")


def fuzzy(options: list[T], transform: Callable[[T], str], description: str) -> T:
    """Let the user pick exactly one of ``options`` via a fuzzy-search prompt.

    Args:
        options: Candidates to choose from.
        transform: Maps an option to the label displayed in the prompt.
        description: Prompt message.

    Returns:
        The selected option itself, not its label.

    Raises:
        typer.Abort: If ``options`` is empty or no option was selected.
    """
    if not options:
        # Nothing to choose from; abort instead of showing an empty prompt.
        raise typer.Abort()
    # The index is the choice value, so any option type can be returned as-is.
    # No `transformer` is passed: the prompt echoes the full selected label.
    choices = [Choice(idx, name=transform(option)) for idx, option in enumerate(options)]
    selected_idx = inq.fuzzy(
        description,
        choices=choices,
        multiselect=False,
        vi_mode=True,
    ).execute()
    if selected_idx is None:
        raise typer.Abort()
    return options[selected_idx]
def serialize_json(obj: Any):
    """``default`` hook for ``orjson.dumps`` covering this script's own types.

    UUIDs become their canonical string form. ``TagTuple`` records become
    objects with named keys, with their UUID fields as strings.

    Raises:
        TypeError: For any other type, naming the type so failures are traceable.
    """
    if isinstance(obj, UUID):
        return str(obj)
    if isinstance(obj, TagTuple):
        return {
            "label_hash": str(obj.label_hash),
            "du_hash": str(obj.du_hash),
            "frame": obj.frame,
            "tag_name": obj.tag_name,
        }
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def get_tags_from_db(db_path: Path) -> tuple[UUID, list[TagTuple]]:
    """Read every data-unit tag of one project from an Encord Active SQLite DB.

    If the DB holds more than one project, the user is asked to pick one.

    Args:
        db_path: Path to the Encord Active database file.

    Returns:
        ``(project_hash, tags)``, with one ``TagTuple`` per distinct
        (label hash, data unit, frame, tag name) combination.

    Raises:
        typer.Abort: If the DB contains no projects, or project selection is cancelled.
    """
    engine = get_engine(db_path, use_alembic=False)
    with Session(engine) as sess:
        db_projects = sess.exec(
            select(Project.project_name, Project.project_hash)
        ).all()
        if not db_projects:
            # click/typer prints only "Aborted!" for Abort, so log the reason too.
            logger.error("Couldn't find any projects in the DB")
            raise typer.Abort("Couldn't find any projects in the DB")

        if len(db_projects) > 1:
            project_hash_uuid = fuzzy(
                db_projects, lambda dbp: dbp[0], "Select project"
            )[1]
        else:
            project_hash_uuid = db_projects[0][1]
        logger.info(f"Converting tags for project hash {project_hash_uuid}")

        tag_query = (
            select(
                ProjectDataMetadata.label_hash,
                ProjectTaggedDataUnit.du_hash,
                ProjectTaggedDataUnit.frame,
                ProjectTag.name,
            )
            .where(
                ProjectDataMetadata.project_hash == project_hash_uuid,
                ProjectDataUnitMetadata.project_hash == project_hash_uuid,
                ProjectTag.project_hash == project_hash_uuid,
                ProjectTaggedDataUnit.project_hash == project_hash_uuid,
                ProjectDataMetadata.data_hash == ProjectDataUnitMetadata.data_hash,
                ProjectTaggedDataUnit.du_hash == ProjectDataUnitMetadata.du_hash,
                ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash,
            )
            # ProjectDataUnitMetadata is joined on du_hash alone. If it holds
            # several rows per data unit (presumably one per video frame —
            # verify), each tag would come back repeatedly without DISTINCT.
            .distinct()
        )
        tags = [TagTuple(*row) for row in sess.exec(tag_query).all()]

    logger.debug(f"Found {len(tags)} tags to convert")
    return project_hash_uuid, tags
def add_new_radio_classification(
    client: EncordUserClient, project: EncordProject, tags: list[TagTuple]
) -> Classification:
    """Add a radio classification, with user-selected tag names as options, to the ontology.

    The user is first asked for the question text, then for which of the tag
    names found in ``tags`` become answer options. The classification is saved
    to the project's ontology and then looked up again in the project's
    refetched ontology structure.

    Args:
        client: Authenticated Encord client used to load and save the ontology.
        project: Encord project whose ontology is extended.
        tags: Tags whose distinct names are offered as options.

    Returns:
        The new classification, taken from ``project.ontology_structure``.

    Raises:
        typer.Abort: If there are no tag names, or none was selected.
    """
    ontology = client.get_ontology(project.ontology_hash)
    ontology_question = Prompt.ask("What should the ontology question be?")

    # Sorted so the prompt lists tags in a stable, alphabetical order.
    answers = sorted({t.tag_name for t in tags})
    if not answers:
        logger.error("No tags found - nothing to add to the ontology")
        raise typer.Abort()
    selected_tags = inq.fuzzy(
        "What tags should be included?", choices=answers, multiselect=True, vi_mode=True
    ).execute()
    if not selected_tags:
        logger.error("No tags selected - not creating an empty classification")
        raise typer.Abort()

    # The ontology is modified only after all input is collected, so aborting
    # above leaves it untouched.
    classification = ontology.structure.add_classification()
    radio_attribute = classification.add_attribute(
        cls=RadioAttribute, name=ontology_question
    )
    for answer in selected_tags:
        radio_attribute.add_option(answer)
    ontology.save()

    project.refetch_ontology()
    return project.ontology_structure.get_child_by_hash(
        classification.feature_node_hash, type_=Classification
    )
def _ask_for_existing_file(
    path: Path, question: str, default: str, missing_message: str
) -> Path:
    """Return ``path`` with ``~`` expanded if it is a file; otherwise prompt for one.

    Raises:
        typer.Abort: If the prompted path is not an existing file either.
    """
    path = path.expanduser()
    if path.is_file():
        return path
    path = Path(Prompt.ask(question, default=default)).expanduser()
    if not path.is_file():
        logger.error(f"{missing_message} ({path})")
        raise typer.Abort()
    return path


def _select_existing_classification(project: EncordProject) -> Classification:
    """Let the user pick a radio attribute of the project ontology; return its classification.

    Raises:
        typer.Abort: If the ontology has no classification with a radio attribute.
    """
    ontology = project.ontology_structure
    options = [
        (attribute, classification)
        for classification in ontology.classifications
        for attribute in classification.attributes
        if isinstance(attribute, RadioAttribute)
    ]
    if not options:
        logger.error("The project ontology has no classification with a radio attribute")
        raise typer.Abort()
    _, selected_classification = fuzzy(
        options, lambda x: x[0].name, "Select existing classification"
    )
    return ontology.get_child_by_hash(
        selected_classification.feature_node_hash, type_=Classification
    )


def main(
    db_path: Path = Path("encord-active.sqlite"),
    ssh_key_path: Path = Path("~/.ssh/encord_frederik"),
):
    """Copy data-unit tags from an Encord Active DB into Encord label rows as a classification.

    Steps:
    1. Resolve the DB and SSH-key paths, prompting if they are not files.
    2. Read the tags from the DB.
    3. Use an existing radio classification, or create one.
    4. Apply each tag to its label row.
    5. Write per-tag results to ``updated_<timestamp>.json`` and log a summary.
    """
    from collections import Counter

    from rich.prompt import Confirm

    db_path = _ask_for_existing_file(
        db_path,
        "What's the path to the Encord Active DB containing the tags",
        "./encord-active.sqlite",
        "DB path missing",
    )
    ssh_key_path = _ask_for_existing_file(
        ssh_key_path,
        "What private ssh-key should be used?",
        "~/.ssh/id_ed25519",
        "SSH key path missing.",
    )

    project_hash_uuid, tags = get_tags_from_db(db_path)
    client = EncordUserClient.create_with_ssh_private_key(ssh_key_path.read_text())
    project = client.get_project(project_hash=str(project_hash_uuid))

    # Confirm only accepts a yes/no answer, unlike a substring test for "y".
    if Confirm.ask("Would you like to use an existing ontology?"):
        tag_clf_obj = _select_existing_classification(project)
    else:
        tag_clf_obj = add_new_radio_classification(client, project, tags)

    lr_lookup: dict[UUID, LabelRowV2] = {
        UUID(lr.label_hash): lr
        for lr in project.list_label_rows_v2()
        if lr.label_hash is not None
    }
    results: list[tuple[TagTuple, UpdateStatus]] = collect_async(
        partial(classify_label_row, lr_lookup=lr_lookup, tag_clf_obj=tag_clf_obj),
        map(lambda t: (t,), tags),
        desc="Tagging data with tag hashes",
    )

    # isoformat() contains ':', which is not allowed in Windows file names.
    out_path = Path(f"updated_{datetime.now():%Y-%m-%dT%H-%M-%S}.json")
    out_path.write_bytes(orjson.dumps(results, default=serialize_json))

    for status, count in Counter(status for _, status in results).items():
        logger.info(f"{count} x {status.value}")
    logger.info(f"Wrote per-tag results to {out_path}")
if __name__ == "__main__":
    # Run `main` as a single-command CLI; typer derives the command-line
    # options from main's parameters and their defaults.
    typer.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment