Instantly share code, notes, and snippets.
Created
September 18, 2023 12:03
-
Star
(0)
0
You must be signed in to star a gist -
Fork
(0)
0
You must be signed in to fork a gist
-
Save frederik-encord/b8eaa3dacf0160d42a1490a7b83274ce to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
If you have tagged your data (with data tags), this script can help you
convert the tags into Radio Button Classifications.
Prerequisites: | |
1. You need to be an admin of both the project and the ontology associated | |
with the project. | |
2. You need to have `encord-active` installed and working in your shell. | |
Installation: | |
The script only uses dependencies that come with `encord-active`. To install
`encord-active`, run
```shell | |
python -m venv venv | |
source venv/bin/activate | |
python -m pip install encord-active | |
``` | |
Running the script: | |
The easiest way is to download this gist and put it next to your
`encord-active.sqlite` file. Next, run
```shell | |
python transform_tags_into_classifications.py | |
``` | |
and follow the instructions. | |
Otherwise, this command will help you with running the script: | |
```shell | |
python transform_tags_into_classifications.py --help | |
``` | |
""" | |
import sys | |
from collections import Counter | |
from datetime import datetime | |
from enum import Enum | |
from functools import partial | |
from pathlib import Path | |
from typing import Any, Callable, NamedTuple, Optional, TypeVar | |
from uuid import UUID | |
import orjson | |
import rich | |
import typer | |
from encord import EncordUserClient | |
from encord.objects import ( | |
Classification, | |
ClassificationInstance, | |
NestableOption, | |
RadioAttribute, | |
) | |
from encord.project import LabelRowV2 | |
from encord.project import Project as EncordProject | |
from encord_active.db.models import ( | |
Project, | |
ProjectDataMetadata, | |
ProjectDataUnitMetadata, | |
ProjectTag, | |
ProjectTaggedDataUnit, | |
get_engine, | |
) | |
from encord_active.lib.common.data_utils import collect_async | |
from InquirerPy import inquirer as inq | |
from InquirerPy.base.control import Choice | |
from loguru import logger | |
from rich.panel import Panel | |
from rich.prompt import Prompt | |
from sqlmodel import Session, select | |
# Start without any loguru sink; `main` attaches sinks only when asked to.
logger.remove()

# Record layout shared by every sink: UTC timestamp, level, message, and the
# contextual `extra` values bound to the record.
FORMAT = " | ".join(
    (
        "{time:MMMM D, YYYY > HH:mm:ss!UTC}",
        "{level}",
        "{message}",
        "{extra}",
    )
)
class UpdateStatus(str, Enum):
    """Outcome of trying to turn a single tag into a classification.

    The value of each member is its human-readable description.
    """

    # The label row ends up with the classification.
    SUCCESS = "Added classification successfully"
    SUCCESS_CLASSIFICATION_EXISTS = "Correct classification exists already"

    # Nothing was written for the tag.
    MISSING_LABEL_ROW = "Couldn't find label row in Encord Project"
    LABEL_INITIALIZATION = (
        "Couldn't initialize label row - probably due to network errors. "
        "Try running script again."
    )
    SKIPPING_TAGS_NOT_IN_ONTOLOGY = "Skipping tags that were not selected for (or in existing) classification"
    LABEL_SAVE = (
        "Failed to save label row - probably due to network errors. "
        "Try running script again."
    )
    UNKNOWN_ERROR = "Unknown error - see logs for more details"
class TagTuple(NamedTuple):
    """One tag attached to one data unit, as read from the Encord Active DB."""

    # Hash of the label row that the tagged data unit belongs to.
    label_hash: UUID
    # Hash of the tagged data unit.
    du_hash: UUID
    # Frame the tag was placed on; `None` is treated as frame 0 when the
    # classification is written.
    frame: Optional[int]
    # Name of the tag; matched against answer titles in the ontology.
    tag_name: str
def classify_label_row(
    tag_tuple: TagTuple,
    lr_lookup: dict[UUID, LabelRowV2],
    tag_clf_obj: Classification,
) -> tuple[TagTuple, UpdateStatus]:
    """Store one tag as an answer to ``tag_clf_obj`` on the tag's label row.

    Every failure is turned into a status instead of an exception, so a batch
    run can continue and report per-tag outcomes afterwards.

    Args:
        tag_tuple: The tag to convert. Its ``label_hash`` selects the label row
            and its ``tag_name`` selects the answer option by title.
        lr_lookup: Label rows of the Encord project, keyed by label hash.
        tag_clf_obj: The ontology classification whose options are matched
            against the tag name.

    Returns:
        The unchanged ``tag_tuple`` paired with the ``UpdateStatus`` describing
        what happened.
    """
    # Messages are logged via `.value` so the text is always the plain
    # description, independent of how the enum member itself is stringified.
    with logger.contextualize(t=tag_tuple):
        try:
            label_row = lr_lookup.get(tag_tuple.label_hash)
            if label_row is None:
                logger.warning(UpdateStatus.MISSING_LABEL_ROW.value)
                return tag_tuple, UpdateStatus.MISSING_LABEL_ROW

            if not label_row.is_labelling_initialised:
                try:
                    label_row.initialise_labels()
                except Exception as err:
                    logger.warning(UpdateStatus.LABEL_INITIALIZATION.value, err=err)
                    return tag_tuple, UpdateStatus.LABEL_INITIALIZATION

            # NOTE(review): any existing instance of this classification counts
            # as "already done", whatever its answer or frame - confirm that
            # this is the intended behaviour.
            instances = label_row.get_classification_instances(
                filter_ontology_classification=tag_clf_obj
            )
            if instances:  # type: ignore
                logger.debug(UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS.value)
                return tag_tuple, UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS

            # Tags without a matching answer option are skipped on purpose.
            try:
                answer = tag_clf_obj.get_child_by_title(
                    tag_tuple.tag_name, NestableOption
                )
            except Exception:
                logger.debug(UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY.value)
                return tag_tuple, UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY

            instance = tag_clf_obj.create_instance()
            instance.set_answer(answer)
            # A tag without a frame is stored on frame 0.
            instance.set_for_frames(tag_tuple.frame or 0)
            label_row.add_classification_instance(instance)

            try:
                label_row.save()
            except Exception as err:
                logger.warning(UpdateStatus.LABEL_SAVE.value, err=err)
                return tag_tuple, UpdateStatus.LABEL_SAVE

            logger.debug(UpdateStatus.SUCCESS.value)
            return tag_tuple, UpdateStatus.SUCCESS
        except Exception as err:
            logger.warning(UpdateStatus.UNKNOWN_ERROR.value, err=err)
            return tag_tuple, UpdateStatus.UNKNOWN_ERROR
T = TypeVar("T")


def fuzzy(options: list[T], transform: Callable[[T], str], description: str) -> T:
    """Let the user pick exactly one of ``options`` with a fuzzy-search prompt.

    Args:
        options: The candidates to choose from.
        transform: Produces the text shown in the prompt for a candidate.
        description: The prompt message.

    Returns:
        The selected element of ``options``.

    Raises:
        typer.Abort: If the prompt finished without a selection.
    """
    # The choice value is the index into `options`, so the original object can
    # be returned no matter what text is displayed for it.
    choices = [
        Choice(index, name=transform(option)) for index, option in enumerate(options)
    ]
    # No `transformer` is passed: it would receive the displayed name (a
    # string), and indexing that with `[0]` echoed only its first character.
    res = inq.fuzzy(
        description,
        choices=choices,
        multiselect=False,
        vi_mode=True,
    ).execute()
    if res is None:
        raise typer.Abort()
    return options[res]
def serialize_json(obj: Any):
    """Convert objects for JSON output; used as the ``default`` hook of ``orjson.dumps``.

    A ``UUID`` becomes its string form and a ``TagTuple`` becomes a dict with
    stringified hashes. Any other type raises ``TypeError``.
    """
    if isinstance(obj, UUID):
        return str(obj)
    if not isinstance(obj, TagTuple):
        raise TypeError
    return dict(
        label_hash=str(obj.label_hash),
        du_hash=str(obj.du_hash),
        frame=obj.frame,
        tag_name=obj.tag_name,
    )
def get_tags_from_db(db_path: Path) -> tuple[UUID, list[TagTuple]]:
    """Read all tagged data units of one project from an Encord Active DB.

    If the database holds several projects, the user is asked to pick one;
    a single project is used without asking.

    Args:
        db_path: Path to the ``encord-active.sqlite`` database.

    Returns:
        The hash of the chosen project and one ``TagTuple`` per tagged data
        unit (and frame) in that project.

    Raises:
        typer.Abort: If the database contains no projects.
    """
    engine = get_engine(db_path, use_alembic=False)
    with Session(engine) as sess:
        db_projects = sess.exec(
            select(Project.project_name, Project.project_hash)
        ).all()
        if not db_projects:
            # A message passed to `typer.Abort` is never displayed (the user
            # only sees "Aborted!"), so print the reason explicitly first.
            typer.echo("Couldn't find any projects in the DB", err=True)
            raise typer.Abort()

        if len(db_projects) > 1:
            project_hash_uuid = fuzzy(
                db_projects, lambda dbp: dbp[0], "Select project"
            )[1]
        else:
            project_hash_uuid = db_projects[0][1]
        logger.info(f"Converting tags for project hash {project_hash_uuid}")

        # Join tag -> tagged data unit -> data unit metadata -> data metadata,
        # restricting every table to the chosen project.
        tag_query = select(
            ProjectDataMetadata.label_hash,
            ProjectTaggedDataUnit.du_hash,
            ProjectTaggedDataUnit.frame,
            ProjectTag.name,
        ).where(
            ProjectDataMetadata.project_hash == project_hash_uuid,
            ProjectDataUnitMetadata.project_hash == project_hash_uuid,
            ProjectTag.project_hash == project_hash_uuid,
            ProjectTaggedDataUnit.project_hash == project_hash_uuid,
            ProjectDataMetadata.data_hash == ProjectDataUnitMetadata.data_hash,
            ProjectTaggedDataUnit.du_hash == ProjectDataUnitMetadata.du_hash,
            ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash,
        )
        tags = [TagTuple(*row) for row in sess.exec(tag_query).all()]

    logger.debug(f"Found {len(tags)} tags to convert")
    return project_hash_uuid, tags
def add_new_radio_classification(
    client: EncordUserClient, project: EncordProject, tags: list[TagTuple]
) -> Classification:
    """Create a radio-button classification in the project ontology from tags.

    The user is asked for the question text and for which tag names should
    become answer options. Nothing is saved unless at least one tag is chosen.

    Args:
        client: Authenticated Encord client used to fetch the ontology.
        project: The Encord project whose ontology is extended.
        tags: Tags read from the Encord Active DB; their distinct names are
            the candidate answers.

    Returns:
        The newly created classification, looked up in the project's
        refreshed ontology structure.

    Raises:
        typer.Abort: If there are no tag names to offer or none was selected.
    """
    # Sorted so the prompt gets a list in a stable order, not an unordered set.
    answers = sorted({tag.tag_name for tag in tags})
    if not answers:
        typer.echo("No tags available to turn into answers.", err=True)
        raise typer.Abort()

    ontology = client.get_ontology(project.ontology_hash)
    ontology_question = Prompt.ask("What should the ontology question be?")

    selected_tags = inq.fuzzy(
        "What tags should be included?",
        choices=answers,
        multiselect=True,
        vi_mode=True,
        instruction="Use `tab` to select",
    ).execute()
    if not selected_tags:
        # Saving a question without any answer would only clutter the ontology.
        typer.echo("No tags selected - the ontology was left unchanged.", err=True)
        raise typer.Abort()

    classification = ontology.structure.add_classification()
    radio_attribute = classification.add_attribute(
        cls=RadioAttribute, name=ontology_question
    )
    for answer in selected_tags:
        radio_attribute.add_option(answer)
    ontology.save()

    # Return the object from the project's own (refreshed) ontology structure,
    # not the one that was built locally above.
    project.refetch_ontology()
    return project.ontology_structure.get_child_by_hash(
        classification.feature_node_hash, type_=Classification
    )
def _abort(message: str) -> typer.Abort:
    """Print ``message`` to stderr and return the exception for the caller to raise."""
    # A message passed to `typer.Abort` itself is never shown to the user.
    typer.echo(message, err=True)
    return typer.Abort()


def main(
    db_path: Path = Path("encord-active.sqlite"),
    ssh_key_path: Path = Path("~/.ssh/encord_frederik"),
    verbose: bool = False,
    very_verbose: bool = False,
    log_to_file: bool = False,
):
    """
    Converts tags from an Encord Active project into radio button classifications
    on Encord Annotate. The script will take the user through a couple of
    options:

    1. A path to an `encord-active.sqlite` database is needed. Provide it by
       the `db_path` option or answer the associated prompt.
    2. A path to a private ssh key is needed. Provide it by the `ssh_key_path`
       option or answer the associated prompt.
    3. Select which project to convert tags for. Follow the selection prompt.
    4. Either tags can be associated with an existing question in the
       ontology or the script can automatically create a new question in the
       ontology. If a new one is selected, you will be prompted to select
       which tags should go into the ontology as answers to the question.

    Once the script is done, it will print stats on the interactions with Encord
    Annotate.
    All transactions are stored in an `updated_{timestamp}.json` file for easy
    access afterwards.

    Args:
        db_path: Path to the DB containing the project and its tags.
        ssh_key_path: Path to the private ssh key associated with Encord.
        verbose: Set this flag if you want the terminal to print errors to the
            console.
        very_verbose: Set this flag if you want the terminal to print the status
            of all tag-to-classification actions.
        log_to_file: Set this flag to store the tag-to-classification status of
            every update in a log file in your current working directory.
    """
    # One timestamp for the whole run, so the log and JSON file names match.
    # NOTE(review): `isoformat()` contains ':' which is not a valid file name
    # character on Windows - confirm whether that platform must be supported.
    ts = datetime.now()

    if verbose or very_verbose:
        logger.add(sys.stderr, format=FORMAT, level="DEBUG" if very_verbose else "INFO")
    log_file_name = None
    if log_to_file:
        log_file_name = f"updates_{ts.isoformat()}.log"
        logger.add(log_file_name, format=FORMAT, level="DEBUG")
        logger.info(f"Storing logs at {log_file_name}")

    # Resolve the database path, asking once if the given one does not exist.
    db_path = db_path.expanduser()
    if not db_path.is_file():
        db_path = Path(
            Prompt.ask(
                "What's the path to the Encord Active DB containing the tags",
                default="./encord-active.sqlite",
            )
        ).expanduser()
        if not db_path.is_file():
            raise _abort("DB path missing")

    # Resolve the ssh key path; `~` must be expanded before the existence check.
    ssh_key_path = ssh_key_path.expanduser()
    if not ssh_key_path.is_file():
        ssh_key_path = Path(
            Prompt.ask(
                "What private ssh-key should be used?",
                default="~/.ssh/id_ed25519",
            )
        ).expanduser()
        if not ssh_key_path.is_file():
            raise _abort("SSH key path missing.")

    project_hash_uuid, tags = get_tags_from_db(db_path)
    if not tags:
        # Nothing to convert: stop before contacting Encord or touching the ontology.
        typer.echo("No tags found for the selected project - nothing to convert.")
        raise typer.Exit()

    client = EncordUserClient.create_with_ssh_private_key(ssh_key_path.read_text())
    project = client.get_project(project_hash=str(project_hash_uuid))

    use_existing_ontology = typer.confirm("Would you like to use an existing ontology?")
    if use_existing_ontology:
        ontology = project.ontology_structure
        # Offer every radio attribute together with the classification owning it.
        options = [
            (attribute, classification)
            for classification in ontology.classifications
            for attribute in classification.attributes
            if isinstance(attribute, RadioAttribute)
        ]
        if not options:
            raise _abort(
                "The project ontology has no radio button classification to select."
            )
        _, selected_classification = fuzzy(
            options, lambda x: x[0].name, "Select existing classification"
        )
        tag_clf_obj = ontology.get_child_by_hash(
            selected_classification.feature_node_hash, type_=Classification
        )
    else:
        tag_clf_obj = add_new_radio_classification(client, project, tags)

    lr_lookup: dict[UUID, LabelRowV2] = {
        UUID(lr.label_hash): lr
        for lr in project.list_label_rows_v2()
        if lr.label_hash is not None
    }

    # Each tag is wrapped in a 1-tuple so it is passed as a single positional
    # argument instead of being unpacked (a `TagTuple` is itself a tuple).
    # NOTE(review): several tags can share one label row; if `collect_async`
    # runs jobs concurrently they may touch the same row at once - confirm.
    results: list[tuple[TagTuple, UpdateStatus]] = collect_async(
        partial(classify_label_row, lr_lookup=lr_lookup, tag_clf_obj=tag_clf_obj),
        [(tag,) for tag in tags],
        desc="Tagging data with tag hashes",
    )

    results_json = Path(f"updated_{ts.isoformat()}.json")
    results_json.write_bytes(orjson.dumps(results, default=serialize_json))

    # `status.name` gives a stable header on every Python version; formatting
    # the member itself yields the value on some versions and the qualified
    # name on others.
    status_counts = Counter(status for _, status in results)
    stats = [
        f"[green]{status.name}:[/green] {count:d}\n"
        f"\t[blue]Description:[/blue] {status.value}"
        for status, count in status_counts.items()
    ]

    panel_lines = [
        "",
        "[blue]Update stats:[/blue]",
        *stats,
        "",
        "Find details here:",
        f"JSON results stored in [magenta]`{results_json}`[/magenta]",
    ]
    if log_to_file:
        panel_lines.append(f"Log stored in [magenta]{log_file_name}[/magenta]")

    rich.print(
        Panel(
            "\n".join(panel_lines),
            title="Classification Stats",
        )
    )


if __name__ == "__main__":
    typer.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment