Skip to content

Instantly share code, notes, and snippets.

@frederik-encord
Created September 18, 2023 12:00
Show Gist options
  • Save frederik-encord/a922143648389f75e978e0d71be409a4 to your computer and use it in GitHub Desktop.
Save frederik-encord/a922143648389f75e978e0d71be409a4 to your computer and use it in GitHub Desktop.
"""
If you have tagged your data (with data tags), this script can help you
convert the tags into Radio Button Classifications.
Prerequisites:
1. You need to be an admin of both the project and the ontology associated
with the project.
2. You need to have `encord-active` installed and working in your shell.
Installation:
The script uses only dependencies from `encord-active`. To install
`encord-active`, run
```shell
python -m venv venv
source venv/bin/activate
python -m pip install encord-active
```
Running the script:
This command will help you with running the script:
```shell
python transform_tags_into_classifications.py --help
```
"""
import sys
from collections import Counter
from datetime import datetime
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, NamedTuple, Optional, TypeVar
from uuid import UUID
import orjson
import rich
import typer
from encord import EncordUserClient
from encord.objects import (
Classification,
ClassificationInstance,
NestableOption,
RadioAttribute,
)
from encord.project import LabelRowV2
from encord.project import Project as EncordProject
from encord_active.db.models import (
Project,
ProjectDataMetadata,
ProjectDataUnitMetadata,
ProjectTag,
ProjectTaggedDataUnit,
get_engine,
)
from encord_active.lib.common.data_utils import collect_async
from InquirerPy import inquirer as inq
from InquirerPy.base.control import Choice
from loguru import logger
from rich.panel import Panel
from rich.prompt import Prompt
from sqlmodel import Session, select
# Drop the handlers loguru has registered so far; `main` adds sinks back only
# when the verbosity / log-file flags ask for them.
logger.remove()
# Record layout passed as `format=` to every sink that `main` registers.
FORMAT = "{time:MMMM D, YYYY > HH:mm:ss!UTC} | {level} | {message} | {extra}"
class UpdateStatus(str, Enum):
    """Outcome of trying to turn one tag into a classification.

    Each member's value is the human-readable description that ends up in the
    logs and in the final stats panel.
    """

    # Successful outcomes.
    SUCCESS = "Added classification successfully"
    SUCCESS_CLASSIFICATION_EXISTS = "Correct classification exists already"

    # Outcomes where the tag could not (or should not) be converted.
    MISSING_LABEL_ROW = "Couldn't find label row in Encord Project"
    LABEL_INITIALIZATION = (
        "Couldn't initialize label row - probably due to network errors. "
        "Try running script again."
    )
    SKIPPING_TAGS_NOT_IN_ONTOLOGY = "Skipping tags that were not selected for (or in existing) classification"
    LABEL_SAVE = (
        "Failed to save label row - probably due to network errors. "
        "Try running script again."
    )
    UNKNOWN_ERROR = "Unknown error - see logs for more details"
class TagTuple(NamedTuple):
    """One tag assignment as read from the Encord Active DB.

    Field order matters: `get_tags_from_db` builds instances positionally from
    the columns of its select statement.
    """

    # Key used to look up the label row in `classify_label_row`.
    label_hash: UUID
    # Hash of the tagged data unit.
    du_hash: UUID
    # Tagged frame; `classify_label_row` falls back to frame 0 when this is None.
    frame: Optional[int]
    # Tag name; matched against the option titles of the radio classification.
    tag_name: str
def classify_label_row(
    tag_tuple: TagTuple,
    lr_lookup: dict[UUID, LabelRowV2],
    tag_clf_obj: Classification,
) -> tuple[TagTuple, UpdateStatus]:
    """Turn a single tag into a classification instance on its label row.

    The function never raises for ordinary errors: every failure mode is
    mapped to an `UpdateStatus` so the caller can aggregate results.

    Args:
        tag_tuple: The tag assignment to convert.
        lr_lookup: Label rows of the Encord project keyed by label hash.
        tag_clf_obj: The ontology classification whose options correspond to
            tag names.

    Returns:
        The unchanged `tag_tuple` together with the status of the update.
    """
    # Attach the tag to every log record emitted while handling it.
    with logger.contextualize(t=tag_tuple):
        try:
            label_row = lr_lookup.get(tag_tuple.label_hash)
            if label_row is None:
                logger.warning(UpdateStatus.MISSING_LABEL_ROW)
                return tag_tuple, UpdateStatus.MISSING_LABEL_ROW

            if not label_row.is_labelling_initialised:
                try:
                    label_row.initialise_labels()
                # `Exception` rather than a bare `except` so that Ctrl-C and
                # SystemExit still propagate.
                except Exception as err:
                    logger.warning(UpdateStatus.LABEL_INITIALIZATION, err=err)
                    return tag_tuple, UpdateStatus.LABEL_INITIALIZATION

            # If the label row already holds an instance of this
            # classification it is left untouched.
            instances = label_row.get_classification_instances(
                filter_ontology_classification=tag_clf_obj
            )
            if instances:  # type: ignore
                logger.debug(UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS)
                return tag_tuple, UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS

            instance = tag_clf_obj.create_instance()
            try:
                answer = tag_clf_obj.get_child_by_title(
                    tag_tuple.tag_name, NestableOption
                )
            except Exception:
                # The tag name has no matching option in the classification.
                logger.debug(UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY)
                return tag_tuple, UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY

            instance.set_answer(answer)
            # A tag without a frame is placed on frame 0.
            instance.set_for_frames(tag_tuple.frame or 0)
            label_row.add_classification_instance(instance)

            try:
                label_row.save()
            except Exception as err:
                logger.warning(UpdateStatus.LABEL_SAVE, err=err)
                return tag_tuple, UpdateStatus.LABEL_SAVE

            logger.debug(UpdateStatus.SUCCESS)
            return tag_tuple, UpdateStatus.SUCCESS
        except Exception as e:
            logger.warning(UpdateStatus.UNKNOWN_ERROR, err=e)
            return tag_tuple, UpdateStatus.UNKNOWN_ERROR
T = TypeVar("T")


def fuzzy(options: list[T], transform: Callable[[T], str], description: str) -> T:
    """Let the user pick exactly one of `options` through a fuzzy-search prompt.

    Args:
        options: Candidates to choose from.
        transform: Produces the text displayed for a candidate.
        description: Prompt message.

    Returns:
        The selected element of `options`.

    Raises:
        typer.Abort: If the prompt yields no selection.
    """
    # Each choice carries the option's index as its value, so the result maps
    # straight back into `options`.
    indexed_choices = [
        Choice(index, name=transform(option)) for index, option in enumerate(options)
    ]
    prompt = inq.fuzzy(
        description,
        choices=indexed_choices,
        # NOTE(review): the transformer presumably only affects what is echoed
        # after selection; if it receives the displayed name, `[0]` shows just
        # its first character - confirm against the InquirerPy docs.
        transformer=lambda dbproj: dbproj[0],
        multiselect=False,
        vi_mode=True,
    )
    picked = prompt.execute()
    if picked is None:
        raise typer.Abort()
    return options[picked]
def serialize_json(obj: Any):
    """Fallback serializer handed to `orjson.dumps(..., default=...)`.

    Converts a `UUID` to its string form and a `TagTuple` to a plain dict with
    stringified hashes.

    Raises:
        TypeError: For any other type.
    """
    if isinstance(obj, UUID):
        return str(obj)
    if not isinstance(obj, TagTuple):
        raise TypeError
    return dict(
        label_hash=str(obj.label_hash),
        du_hash=str(obj.du_hash),
        frame=obj.frame,
        tag_name=obj.tag_name,
    )
def get_tags_from_db(db_path: Path) -> tuple[UUID, list[TagTuple]]:
    """Read all tag assignments of one project from an Encord Active DB.

    If the DB holds several projects, the user is prompted to choose one.

    Args:
        db_path: Path to the Encord Active sqlite database.

    Returns:
        The hash of the chosen project and its tag assignments.

    Raises:
        typer.Abort: If the DB contains no projects.
    """
    engine = get_engine(db_path, use_alembic=False)
    with Session(engine) as sess:
        db_projects = sess.exec(
            select(Project.project_name, Project.project_hash)
        ).all()
        if not db_projects:
            # `typer.Abort` does not display its argument, so print the
            # reason explicitly before aborting.
            typer.echo("Couldn't find any projects in the DB", err=True)
            raise typer.Abort()

        if len(db_projects) > 1:
            project_hash_uuid = fuzzy(
                db_projects, lambda dbp: dbp[0], "Select project"
            )[1]
        else:
            project_hash_uuid = db_projects[0][1]
        logger.info(f"Converting tags for project hash {project_hash_uuid}")

        # Column order must match the field order of `TagTuple`.
        tag_query = select(
            ProjectDataMetadata.label_hash,
            ProjectTaggedDataUnit.du_hash,
            ProjectTaggedDataUnit.frame,
            ProjectTag.name,
        ).where(
            # Restrict every table to the chosen project ...
            ProjectDataMetadata.project_hash == project_hash_uuid,
            ProjectDataUnitMetadata.project_hash == project_hash_uuid,
            ProjectTag.project_hash == project_hash_uuid,
            ProjectTaggedDataUnit.project_hash == project_hash_uuid,
            # ... and join them: data -> data unit -> tagged data unit -> tag.
            ProjectDataMetadata.data_hash == ProjectDataUnitMetadata.data_hash,
            ProjectTaggedDataUnit.du_hash == ProjectDataUnitMetadata.du_hash,
            ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash,
        )
        tags = [TagTuple(*row) for row in sess.exec(tag_query).all()]
        logger.debug(f"Found {len(tags)} tags to convert")
        return project_hash_uuid, tags
def add_new_radio_classification(
    client: EncordUserClient, project: EncordProject, tags: list[TagTuple]
) -> Classification:
    """Add a new radio-button classification to the project's ontology.

    The user is asked for the question text and for which tag names should
    become answer options. The ontology is saved and the project's copy of it
    is refetched before returning.

    Args:
        client: Client used to fetch the ontology.
        project: Project whose ontology is extended.
        tags: Tag assignments; their distinct names are offered as options.

    Returns:
        The newly created classification, looked up in the refetched project
        ontology.
    """
    ontology = client.get_ontology(project.ontology_hash)
    ontology_question = Prompt.ask("What should the ontology question be?")
    classification = ontology.structure.add_classification()
    radio_attribute = classification.add_attribute(
        cls=RadioAttribute, name=ontology_question
    )

    # Sorted list instead of a bare set: set iteration order is arbitrary, so
    # the prompt would otherwise list the tags in a different order every run.
    answers = sorted({t.tag_name for t in tags})
    selected_tags = inq.fuzzy(
        "What tags should be included?",
        choices=answers,
        multiselect=True,
        vi_mode=True,
        instruction="Use `tab` to select",
    ).execute()
    for answer in selected_tags:
        radio_attribute.add_option(answer)

    ontology.save()
    # Pick up the saved ontology on the project object.
    project.refetch_ontology()
    return project.ontology_structure.get_child_by_hash(
        classification.feature_node_hash, type_=Classification
    )
def main(
    db_path: Path = Path("encord-active.sqlite"),
    ssh_key_path: Path = Path("~/.ssh/encord_frederik"),
    verbose: bool = False,
    very_verbose: bool = False,
    log_to_file: bool = False,
):
    """
    Converts tags from an Encord Active project into radio button classifications
    on Encord Annotate. The script will take the user through a couple of
    options:

    1. A path to an `encord-active.sqlite` database is needed. Provide it by
       the `db_path` option or answer the associated prompt.
    2. A path to a private ssh key is needed. Provide it by the `ssh_key_path`
       option or answer the associated prompt.
    3. Select which project to convert tags for. Follow the selection prompt.
    4. Either tags can be associated with an existing question in the
       ontology or the script can automatically create a new question in the
       ontology. If a new one is selected, you will be prompted to select
       which tags should go into the ontology as answers to the question.

    Once the script is done, it will print stats on the interactions with Encord
    Annotate.
    All transactions are stored in an `updated_{timestamp}.json` file for easy
    access afterwards.

    Args:
        db_path: Path to the DB containing the project and its tags.
        ssh_key_path: Path to the private ssh key associated with Encord.
        verbose: Set this flag if you want the terminal to print errors to the
            console.
        very_verbose: Set this flag if you want the terminal to print the status
            of all tag-to-classification actions.
        log_to_file: Set this flag to store the tag-to-classification status of
            every update in a log file in your current working directory.
    """
    # One timestamp shared by the log file and the JSON results file.
    ts = datetime.now()

    if verbose or very_verbose:
        logger.add(sys.stderr, format=FORMAT, level="DEBUG" if very_verbose else "INFO")

    log_file_name = None
    if log_to_file:
        log_file_name = f"updates_{ts.isoformat()}.log"
        logger.add(log_file_name, format=FORMAT, level="DEBUG")
        logger.info(f"Storing logs at {log_file_name}")

    # Resolve the DB path, asking the user if the given one doesn't exist.
    db_path = db_path.expanduser()
    if not db_path.is_file():
        db_path = Path(
            Prompt.ask(
                "What's the path to the Encord Active DB containing the tags",
                default="./encord-active.sqlite",
            )
        ).expanduser()
        if not db_path.is_file():
            # `typer.Abort` does not display a message, so print it first.
            typer.echo("DB path missing", err=True)
            raise typer.Abort()

    # Resolve the ssh key path the same way. `~` has to be expanded before
    # checking for existence, otherwise the default is never found.
    ssh_key_path = ssh_key_path.expanduser()
    if not ssh_key_path.is_file():
        ssh_key_path = Path(
            Prompt.ask(
                "What private ssh-key should be used?",
                default="~/.ssh/id_ed25519",
            )
        ).expanduser()
        if not ssh_key_path.is_file():
            typer.echo("SSH key path missing.", err=True)
            raise typer.Abort()

    project_hash_uuid, tags = get_tags_from_db(db_path)

    client = EncordUserClient.create_with_ssh_private_key(ssh_key_path.read_text())
    project = client.get_project(project_hash=str(project_hash_uuid))

    use_existing_ontology = typer.confirm("Would you like to use an existing ontology?")
    if use_existing_ontology:
        ontology = project.ontology_structure
        # Every (radio attribute, owning classification) pair in the ontology.
        options = [
            (a, c)
            for c in ontology.classifications
            for a in c.attributes
            if isinstance(a, RadioAttribute)
        ]
        if not options:
            typer.echo(
                "The ontology has no radio button classification to select.",
                err=True,
            )
            raise typer.Abort()
        _, selected_classification = fuzzy(
            options, lambda x: x[0].name, "Select existing classification"
        )
        tag_clf_obj = ontology.get_child_by_hash(
            selected_classification.feature_node_hash, type_=Classification
        )
    else:
        tag_clf_obj = add_new_radio_classification(client, project, tags)

    # Label rows without a label hash cannot be matched to tags and are skipped.
    lr_lookup: dict[UUID, LabelRowV2] = {
        UUID(lr.label_hash): lr
        for lr in project.list_label_rows_v2()
        if lr.label_hash is not None
    }

    # Each job's arguments are passed as a one-element tuple holding the tag.
    results: list[tuple[TagTuple, UpdateStatus]] = collect_async(
        partial(classify_label_row, lr_lookup=lr_lookup, tag_clf_obj=tag_clf_obj),
        map(lambda t: (t,), tags),
        desc="Tagging data with tag hashes",
    )

    results_json = Path(f"updated_{ts.isoformat()}.json")
    results_json.write_bytes(orjson.dumps(results, default=serialize_json))

    stats: list[str] = []
    for key, count in Counter(status for _, status in results).items():
        # Use the member name for the heading: formatting a str-mixin Enum
        # member directly renders differently across Python versions.
        cat_stats = f"[green]{key.name}:[/green] {count:d}\n"
        cat_stats += f"\t[blue]Description:[/blue] {key.value}"
        stats.append(cat_stats)
    stats_str = "\n".join(stats)

    panel_text = f"""
[blue]Update stats:[/blue]
{stats_str}
Find details here:
JSON results stored in [magenta]`{results_json}`[/magenta]
"""
    if log_to_file:
        panel_text += f"Log stored in [magenta]{log_file_name}[/magenta]"

    rich.print(
        Panel(
            panel_text,
            title="Classification Stats",
        )
    )
# Script entry point: typer builds the command-line interface from the
# signature and docstring of `main`.
if __name__ == "__main__":
    typer.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment