Skip to content

Instantly share code, notes, and snippets.

@frederik-encord
Created September 18, 2023 10:24
Show Gist options
  • Save frederik-encord/78e3976f095c0e0acfd5e61d05d1f3a3 to your computer and use it in GitHub Desktop.
Save frederik-encord/78e3976f095c0e0acfd5e61d05d1f3a3 to your computer and use it in GitHub Desktop.
import sys
from datetime import datetime
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, NamedTuple, Optional, TypeVar
from uuid import UUID
import orjson
import typer
from encord import EncordUserClient
from encord.objects import (
Classification,
ClassificationInstance,
NestableOption,
RadioAttribute,
)
from encord.project import LabelRowV2
from encord.project import Project as EncordProject
from encord_active.db.models import (
Project,
ProjectDataMetadata,
ProjectDataUnitMetadata,
ProjectTag,
ProjectTaggedDataUnit,
get_engine,
)
from encord_active.lib.common.data_utils import collect_async
from InquirerPy import inquirer as inq
from InquirerPy.base.control import Choice
from loguru import logger
from rich.prompt import Prompt
from sqlmodel import Session, select
# One shared record layout for every sink: UTC timestamp, level, message and
# the bound/contextual "extra" fields.
FORMAT = "{time:MMMM D, YYYY > HH:mm:ss!UTC} | {level} | {message} | {extra}"

# Swap loguru's default handler for a verbose file log plus a quieter console log.
logger.remove()
logger.add("updates.log", level="DEBUG", format=FORMAT)
logger.add(sys.stderr, level="INFO", format=FORMAT)
class UpdateStatus(str, Enum):
    """Outcome of trying to add a tag classification to one label row.

    Each member's value is a human-readable description; members are passed
    directly to the logger as the message for that outcome and are written to
    the results file.
    """

    SUCCESS = "Updated classification successfully"
    SUCCESS_CLASSIFICATION_EXISTS = "Correct classification exists already"
    MISSING_LABEL_ROW = "Couldn't find label row in Encord Project"
    # Not produced anywhere in this script at the moment.
    NO_EMPTY_LRS_LEFT = "No empty label rows left to assign label rows to."
    LABEL_INITIALIZATION = "Couldn't initialize label row"
    SKIPPING_TAGS_NOT_IN_ONTOLOGY = (
        "Skipping tags that were not selected for (or in existing) classification"
    )
    LABEL_SAVE = "Failed to save label row"
    UNKNOWN_ERROR = "Unknown error - see logs for more details"

    def __str__(self) -> str:
        # Enum.__str__ renders "UpdateStatus.MEMBER" even for str-mixin enums,
        # which is what a logger shows when handed a member as the message.
        # Render the human-readable description instead.
        return str(self.value)
class TagTuple(NamedTuple):
    """A single tag attached to a data unit (and optionally a frame) in the DB.

    Field order mirrors the columns selected in ``get_tags_from_db`` so query
    rows can be unpacked positionally into this tuple.
    """

    label_hash: UUID  # key of the Encord label row to update
    du_hash: UUID  # data unit that carries the tag
    frame: Optional[int]  # tagged frame, may be None
    tag_name: str  # tag name, matched against the option titles
def classify_label_row(
    tag_tuple: TagTuple,
    lr_lookup: dict[UUID, LabelRowV2],
    tag_clf_obj: Classification,
) -> tuple[TagTuple, UpdateStatus]:
    """Add the answer named ``tag_tuple.tag_name`` to the tag's label row.

    Args:
        tag_tuple: Tag to convert. ``label_hash`` selects the label row and
            ``frame`` (``None`` is treated as 0) the frame to classify.
        lr_lookup: Label rows of the Encord project keyed by label hash.
        tag_clf_obj: Ontology classification whose option titled
            ``tag_tuple.tag_name`` is used as the answer.

    Returns:
        ``(tag_tuple, status)``. Ordinary failures are logged and reported
        through ``status`` instead of being raised.
    """
    with logger.contextualize(t=tag_tuple):
        try:
            label_row = lr_lookup.get(tag_tuple.label_hash)
            if label_row is None:
                logger.warning(UpdateStatus.MISSING_LABEL_ROW)
                return tag_tuple, UpdateStatus.MISSING_LABEL_ROW

            if not label_row.is_labelling_initialised:
                try:
                    label_row.initialise_labels()
                except Exception as e:
                    logger.warning(UpdateStatus.LABEL_INITIALIZATION, err=e)
                    return tag_tuple, UpdateStatus.LABEL_INITIALIZATION

            # Any existing instance of this classification counts as done; its
            # answer and frames are not compared with this tag.
            instances = label_row.get_classification_instances(
                filter_ontology_classification=tag_clf_obj
            )
            if instances:  # type: ignore
                logger.debug(UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS)
                return tag_tuple, UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS

            # Resolve the answer first so tags without a matching option are
            # skipped before any instance is created.
            try:
                answer = tag_clf_obj.get_child_by_title(
                    tag_tuple.tag_name, NestableOption
                )
            except Exception:
                logger.debug(UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY)
                return tag_tuple, UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY

            instance = tag_clf_obj.create_instance()
            instance.set_answer(answer)
            instance.set_for_frames(tag_tuple.frame or 0)
            label_row.add_classification_instance(instance)

            try:
                label_row.save()
            except Exception as e:
                logger.warning(UpdateStatus.LABEL_SAVE, err=e)
                return tag_tuple, UpdateStatus.LABEL_SAVE

            logger.debug(UpdateStatus.SUCCESS)
            return tag_tuple, UpdateStatus.SUCCESS
        except Exception as e:
            logger.warning(UpdateStatus.UNKNOWN_ERROR, err=e)
            return tag_tuple, UpdateStatus.UNKNOWN_ERROR
T = TypeVar("T")


def fuzzy(options: list[T], transform: Callable[[T], str], description: str) -> T:
    """Let the user fuzzy-pick one element of ``options`` in the terminal.

    Args:
        options: Candidates to choose from.
        transform: Maps a candidate to the label shown in the picker.
        description: Prompt message displayed above the picker.

    Returns:
        The selected element of ``options``.

    Raises:
        typer.Abort: If ``options`` is empty or no selection was returned.
    """
    if not options:
        # Nothing to choose from: abort cleanly instead of opening an empty picker.
        raise typer.Abort()
    # Choice values are list indices, so the selection maps straight back to
    # the original option whatever its type.
    choices = [
        Choice(index, name=transform(option)) for index, option in enumerate(options)
    ]
    res = inq.fuzzy(
        description,
        choices=choices,
        multiselect=False,
        vi_mode=True,
    ).execute()
    if res is None:
        raise typer.Abort()
    return options[res]
def serialize_json(obj: Any):
    """``default`` hook for ``orjson.dumps`` covering this script's result types.

    ``UUID`` becomes its canonical string form. ``TagTuple`` becomes a dict with
    stringified hashes. Any other type raises ``TypeError`` naming the type, so
    the failure is easy to trace.
    """
    if isinstance(obj, UUID):
        return str(obj)
    if isinstance(obj, TagTuple):
        return {
            "label_hash": str(obj.label_hash),
            "du_hash": str(obj.du_hash),
            "frame": obj.frame,
            "tag_name": obj.tag_name,
        }
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def get_tags_from_db(db_path: Path) -> tuple[UUID, list[TagTuple]]:
    """Load the tagged data units of one project from an Encord Active DB.

    If the DB holds several projects, the user picks one interactively.

    Args:
        db_path: Path to the Encord Active SQLite database.

    Returns:
        ``(project_hash, tags)``, where ``tags`` has one ``TagTuple`` per distinct
        (label hash, data unit, frame, tag name) combination.

    Raises:
        typer.Abort: If the DB contains no projects.
    """
    engine = get_engine(db_path, use_alembic=False)
    with Session(engine) as sess:
        db_projects = sess.exec(
            select(Project.project_name, Project.project_hash)
        ).all()
        if not db_projects:
            # Typer/click report an Abort only as "Aborted!", so log the reason first.
            logger.error("Couldn't find any projects in the DB")
            raise typer.Abort("Couldn't find any projects in the DB")
        if len(db_projects) > 1:
            project_hash_uuid = fuzzy(
                db_projects, lambda dbp: dbp[0], "Select project"
            )[1]
        else:
            project_hash_uuid = db_projects[0][1]
        logger.info(f"Converting tags for project hash {project_hash_uuid}")

        query = (
            select(
                ProjectDataMetadata.label_hash,
                ProjectTaggedDataUnit.du_hash,
                ProjectTaggedDataUnit.frame,
                ProjectTag.name,
            )
            .where(
                ProjectDataMetadata.project_hash == project_hash_uuid,
                ProjectDataUnitMetadata.project_hash == project_hash_uuid,
                ProjectTag.project_hash == project_hash_uuid,
                ProjectTaggedDataUnit.project_hash == project_hash_uuid,
                ProjectDataMetadata.data_hash == ProjectDataUnitMetadata.data_hash,
                ProjectTaggedDataUnit.du_hash == ProjectDataUnitMetadata.du_hash,
                ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash,
            )
            # NOTE(review): ProjectDataUnitMetadata is joined on du_hash only.
            # If it stores one row per frame, every tag would repeat once per
            # frame and the same label row would be updated redundantly.
            # DISTINCT keeps one row per tag either way.
            .distinct()
        )
        tags = [TagTuple(*row) for row in sess.exec(query).all()]
    logger.debug(f"Found {len(tags)} tags to convert")
    return project_hash_uuid, tags
def add_new_radio_classification(
    client: EncordUserClient, project: EncordProject, tags: list[TagTuple]
) -> Classification:
    """Create a radio classification whose options are user-selected tag names.

    Asks for the question text and which tag names become options. It then saves
    the ontology and returns the new classification as found in the project's
    refetched ontology structure.

    Raises:
        typer.Abort: If no tag names were selected.
    """
    ontology = client.get_ontology(project.ontology_hash)
    ontology_question = Prompt.ask("What should the ontology question be?")

    # Sorted so the picker lists tags in a stable order; set iteration order
    # of strings varies between interpreter runs.
    answers = sorted({t.tag_name for t in tags})
    selected_tags = inq.fuzzy(
        "What tags should be included?", choices=answers, multiselect=True, vi_mode=True
    ).execute()
    if not selected_tags:
        # Don't save a radio attribute without any options.
        logger.error("No tags selected; not adding an empty classification")
        raise typer.Abort()

    classification = ontology.structure.add_classification()
    radio_attribute = classification.add_attribute(
        cls=RadioAttribute, name=ontology_question
    )
    for answer in selected_tags:
        radio_attribute.add_option(answer)
    ontology.save()

    # Look the classification up again through the project's ontology,
    # presumably so it matches the structure the project's label rows use.
    project.refetch_ontology()
    return project.ontology_structure.get_child_by_hash(
        classification.feature_node_hash, type_=Classification
    )
def _prompt_for_file(path: Path, question: str, default: str, missing_msg: str) -> Path:
    """Return ``path`` if it names an existing file; otherwise ask the user for one.

    Both the given and the prompted path are ``~``-expanded before checking.

    Raises:
        typer.Abort: If the prompted path is not an existing file either.
    """
    path = path.expanduser()
    if not path.is_file():
        path = Path(Prompt.ask(question, default=default)).expanduser()
    if not path.is_file():
        logger.error(missing_msg)
        raise typer.Abort(missing_msg)
    return path


def main(
    db_path: Path = Path("encord-active.sqlite"),
    ssh_key_path: Path = Path("~/.ssh/encord_frederik"),
):
    """Convert Encord Active tags into a radio classification on Encord label rows.

    The steps are:
    1. Read the tags of one project from the Encord Active DB.
    2. Let the user pick (or create) a radio classification in the matching
       Encord project's ontology.
    3. Apply one answer per tag.
    4. Write the per-tag outcome to ``updated_<timestamp>.json``.

    Args:
        db_path: Encord Active SQLite DB; the user is prompted if it is not a file.
        ssh_key_path: Private SSH key for Encord; the user is prompted if it is
            not a file.
    """
    db_path = _prompt_for_file(
        db_path,
        "What's the path to the Encord Active DB containing the tags",
        "./encord-active.sqlite",
        "DB path missing",
    )
    ssh_key_path = _prompt_for_file(
        ssh_key_path,
        "What private ssh-key should be used?",
        "~/.ssh/id_ed25519",
        "SSH key path missing.",
    )

    project_hash_uuid, tags = get_tags_from_db(db_path)
    client = EncordUserClient.create_with_ssh_private_key(ssh_key_path.read_text())
    project = client.get_project(project_hash=str(project_hash_uuid))

    use_existing_ontology = (
        "y" in Prompt.ask("Would you like to use an existing ontology? [y/n]").lower()
    )
    if use_existing_ontology:
        ontology = project.ontology_structure
        # (radio attribute, owning classification) pairs; the picker shows the
        # attribute name, i.e. the classification's question.
        options = [
            (a, c)
            for c in ontology.classifications
            for a in c.attributes
            if isinstance(a, RadioAttribute)
        ]
        if not options:
            logger.error("No radio classifications found in the project ontology")
            raise typer.Abort()
        _, selected_classification = fuzzy(
            options, lambda x: x[0].name, "Select existing classification"
        )
        tag_clf_obj = ontology.get_child_by_hash(
            selected_classification.feature_node_hash, type_=Classification
        )
    else:
        tag_clf_obj = add_new_radio_classification(client, project, tags)

    lr_lookup: dict[UUID, LabelRowV2] = {
        UUID(lr.label_hash): lr
        for lr in project.list_label_rows_v2()
        if lr.label_hash is not None
    }
    results: list[tuple[TagTuple, UpdateStatus]] = collect_async(
        partial(classify_label_row, lr_lookup=lr_lookup, tag_clf_obj=tag_clf_obj),
        map(lambda t: (t,), tags),
        desc="Tagging data with tag hashes",
    )
    # isoformat() contains ':', which is not allowed in Windows file names.
    timestamp = datetime.now().isoformat().replace(":", "-")
    Path(f"updated_{timestamp}.json").write_bytes(
        orjson.dumps(results, default=serialize_json)
    )
# Script entry point: typer.run exposes main()'s parameters as CLI options.
if __name__ == "__main__":
    typer.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment