Skip to content

Instantly share code, notes, and snippets.

@frederik-encord
Created October 9, 2023 10:05
Show Gist options
  • Save frederik-encord/b63d7b60c8e1c1f73c06514ceec6d758 to your computer and use it in GitHub Desktop.
Save frederik-encord/b63d7b60c8e1c1f73c06514ceec6d758 to your computer and use it in GitHub Desktop.
"""
If you have tagged your data (with data tags), this script can help you
convert the tags into checklist classifications.
Prerequisites:
1. You need to be an admin of the Encord Annotate project.
2. You need to have `encord-active==0.1.78` installed and working in your shell.
Installation:
The script only uses dependencies that come with `encord-active`. To install
`encord-active`, run
```shell
python -m venv venv
source venv/bin/activate
python -m pip install encord-active
```
Running the script:
The easiest way is to download this gist and put it next to your
`encord-active.sqlite` file. Then run
```shell
python transform_tags_into_checklists.py
```
and follow the instructions.
Otherwise, this command will help you with running the script:
```shell
python transform_tags_into_checklists.py --help
```
Note that there is a `--print-tags` option, which prints the tag names derived
from the project ontology. Only data tags with these names are converted into
classifications.
"""
import sys
from collections import Counter
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Callable, NamedTuple, Optional, TypeVar
from uuid import UUID
import orjson
import rich
import typer
from encord import EncordUserClient
from encord.objects import (
ChecklistAttribute,
Classification,
FlatOption,
OntologyStructure,
Option,
RadioAttribute,
TextAttribute,
)
from encord.objects.attributes import Attribute
from encord.project import LabelRowV2
from encord.project import Project as EncordProject
from encord_active.db.models import (
Project,
ProjectDataMetadata,
ProjectDataUnitMetadata,
ProjectTag,
ProjectTaggedDataUnit,
get_engine,
)
from InquirerPy import inquirer as inq
from InquirerPy.base.control import Choice
from loguru import logger
from rich.panel import Panel
from rich.prompt import Prompt
from sqlmodel import Session, select
from tqdm import tqdm
# Remove all existing loguru handlers; `main` adds sinks again depending on
# the CLI flags, so by default nothing is logged.
logger.remove()
# Format string handed to `logger.add` in `main` for both console and file sinks.
FORMAT = "<green>{time:MMMM D, YYYY > HH:mm:ss!UTC} | {level} |</green> <blue>{message}</blue> | <green><dim>{extra}</dim></green>"
# Lookup type: tag name -> (classification, attribute, option) from the ontology.
OntologyTags = dict[str, tuple[Classification, Attribute, Option]]
# Generic element type used by `fuzzy_select`.
T = TypeVar("T")
class UpdateStatus(str, Enum):
    """
    Outcome of trying to convert one tag into a classification.
    The member value is the human readable description shown in the final
    stats panel; the member itself is what gets collected per tag.
    """

    # A classification answer was added and the label row needs saving.
    SUCCESS = "Added classification successfully"
    # The option was already among the answers; nothing to do.
    SUCCESS_CLASSIFICATION_EXISTS = "Correct classification exists already"
    # The tag's label hash was not among the fetched label rows.
    MISSING_LABEL_ROW = "Couldn't find label row in Encord Project"
    # `initialise_labels()` raised for the label row.
    LABEL_INITIALIZATION = "Couldn't initialize label row - probably due to network errors. Try running script again."
    # The tag name has no entry in the ontology lookup.
    SKIPPING_TAGS_NOT_IN_ONTOLOGY = (
        "Skipping tags that were not selected for (or in existing) classification"
    )
    # NOTE(review): not referenced by the code visible in this file.
    LABEL_SAVE = "Failed to save label row - probably due to network errors. Try running script again."
    # Any other exception raised while processing a tag.
    UNKNOWN_ERROR = "Unknown error - see logs for more details"
    # NOTE(review): not referenced by the code visible in this file.
    INVALID_TAG_FORMAT = "The tag format should be `question.answer[.question.answer]`"
class TagTuple(NamedTuple):
    """
    One tagged data unit as read from the Encord Active database.
    The field order matches the column order selected in `get_tags_from_db`.
    """

    # Hash of the label row the data unit belongs to.
    label_hash: UUID
    # Hash of the tagged data unit.
    du_hash: UUID
    # Frame the tag was put on; may be missing.
    frame: Optional[int]
    # Name of the tag, matched against the keys of the ontology lookup.
    tag_name: str
### Utilities ###
def fuzzy_select(
    options: list[T],
    transform: Callable[[T], str] = str,
    description: str = "Select one",
) -> T:
    """
    Let the user pick a single element from a list via a fuzzy CLI prompt.
    Args:
        options: the candidates to choose between.
        transform: turns a candidate into the text displayed in the prompt.
        description: the prompt message.
    Returns: the element of `options` that was picked.
    Raises:
        typer.Abort: if the prompt returns no selection.
    """
    # Each choice carries the list index as its value, so the prompt result
    # can be mapped straight back to the original object.
    menu = [
        Choice(index, name=transform(candidate))
        for index, candidate in enumerate(options)
    ]
    prompt = inq.fuzzy(
        description,
        choices=menu,
        # NOTE(review): keeps only element [0] of whatever InquirerPy hands
        # the transformer for display after selection - confirm this is the
        # intended echo.
        transformer=lambda shown: shown[0],
        multiselect=False,
        vi_mode=True,
    )
    picked_index = prompt.execute()
    if picked_index is None:
        raise typer.Abort()
    return options[picked_index]
def serialize_json(obj: Any):
    """
    Fallback serializer passed as `default=` to `orjson.dumps`.
    Args:
        obj: an object that needs converting.
    Returns: a string for a `UUID`, a plain dict for a `TagTuple`.
    Raises:
        TypeError: for any other type.
    """
    if isinstance(obj, UUID):
        return str(obj)
    if not isinstance(obj, TagTuple):
        raise TypeError
    # Hashes are UUIDs, so they are stringified here as well.
    return {
        "label_hash": str(obj.label_hash),
        "du_hash": str(obj.du_hash),
        "frame": obj.frame,
        "tag_name": obj.tag_name,
    }
def initialize_label_rows(project: EncordProject, label_rows: list[LabelRowV2], bs=50):
    """
    A batched way to initialize all label rows. This is much faster than doing
    it individually on each label row.
    Args:
        project: the Encord project used to create the request bundles.
        label_rows: the label rows to initialize (in place).
        bs: number of label rows per bundle.
    """
    # The context manager makes sure the progress bar is closed even if a
    # request fails half way through.
    with tqdm(total=len(label_rows), desc="Reading labels from Encord") as pbar:
        for batch_start in range(0, len(label_rows), bs):
            init_bundle = project.create_bundle()
            batch = label_rows[batch_start : batch_start + bs]
            for label_row in batch:
                # Passing the bundle defers the request; without it every
                # label row is fetched on its own and the bundle stays empty.
                label_row.initialise_labels(bundle=init_bundle)
            # One round-trip for the whole batch.
            init_bundle.execute()
            pbar.update(len(batch))
def save_label_rows(project: EncordProject, label_rows: list[LabelRowV2], bs=50):
    """
    A batched way to save all label rows. This is much faster than doing
    it individually on each label row.
    Args:
        project: the Encord project used to create the request bundles.
        label_rows: the label rows to save.
        bs: number of label rows per bundle.
    """
    # The context manager makes sure the progress bar is closed even if a
    # request fails half way through.
    with tqdm(total=len(label_rows), desc="Sending labels to Encord") as pbar:
        for batch_start in range(0, len(label_rows), bs):
            save_bundle = project.create_bundle()
            batch = label_rows[batch_start : batch_start + bs]
            for label_row in batch:
                # Passing the bundle defers the request; without it every
                # label row is saved on its own and the bundle stays empty.
                label_row.save(bundle=save_bundle)
            # One round-trip for the whole batch.
            save_bundle.execute()
            pbar.update(len(batch))
def print_stats(
    tag_status: list[tuple[TagTuple, UpdateStatus]],
    log_to_file: bool,
    log_file_name: Optional[str],
    ts: datetime,
):
    """
    Store the per-tag results as JSON and print a summary panel.
    Args:
        tag_status: one (tag, status) entry per processed tag.
        log_to_file: whether a log file was written during the run.
        log_file_name: name of that log file; only shown if `log_to_file`.
        ts: timestamp of the run, used in the name of the JSON file.
    """
    # Use the timestamp of the run (not a fresh `now()`), so the JSON file
    # name matches the log file name derived from the same `ts` in `main`.
    results_json = Path(f"updated_{ts.isoformat()}.json")
    results_json.write_bytes(orjson.dumps(tag_status, default=serialize_json))

    status_counts = Counter(status for _, status in tag_status)
    # `.name` is spelled out because formatting a `(str, Enum)` member
    # directly yields different text depending on the Python version.
    stats = [
        f"[green]{status.name}:[/green] {count:d}\n"
        f"\t[blue]Description:[/blue] {status.value}"
        for status, count in status_counts.items()
    ]
    stats_str = "\n".join(stats)

    panel_text = (
        "\n[blue]Update stats:[/blue]\n"
        f"{stats_str}\n"
        "Find details here:\n"
        f"JSON results stored in [magenta]`{results_json}`[/magenta]\n"
    )
    if log_to_file:
        panel_text += f"Log stored in [magenta]{log_file_name}[/magenta]"
    rich.print(
        Panel(
            panel_text,
            title="Classification Stats",
        )
    )
def get_tag_to_ontology_lookup(
    ontology: OntologyStructure,
) -> OntologyTags:
    """
    Build the lookup from tag name to (classification, attribute, option).
    Tag names are the dot-joined titles along the path
    `attribute.option[.attribute.option ...]`. Text attributes contribute no
    entries. Attributes nested under radio options are followed; attributes
    nested under checklist options are not.
    """
    lookup: OntologyTags = {}
    # Work list of (name prefix, attribute, owning classification), walked
    # with a cursor so entries appended along the way are visited in FIFO
    # order.
    pending: list[tuple[str, Attribute, Classification]] = []
    for classification in ontology.classifications:
        for attribute in classification.attributes:
            pending.append((attribute.title, attribute, classification))

    cursor = 0
    while cursor < len(pending):
        prefix, attribute, classification = pending[cursor]
        cursor += 1
        if isinstance(attribute, TextAttribute):
            continue
        is_radio = isinstance(attribute, RadioAttribute)
        if not is_radio and not isinstance(attribute, ChecklistAttribute):
            continue
        for opt in attribute.options:
            lookup[f"{prefix}.{opt.title}"] = (classification, attribute, opt)
        if is_radio:
            # Only radio options are descended into.
            for opt in attribute.options:
                for nested in opt.attributes:
                    pending.append(
                        (f"{prefix}.{opt.title}.{nested.title}", nested, classification)
                    )
    return lookup
### Active related DB operations ###
def select_project_from_db(sess: Session) -> Project:
    """
    Select a project from the database to work on.
    If the database holds exactly one project it is returned directly;
    otherwise the user is asked to pick one.
    Args:
        sess: an open session on the Encord Active database.
    Returns: the chosen project row.
    Raises:
        typer.Abort: if the database contains no projects.
    """
    db_projects = sess.exec(select(Project)).all()
    if not db_projects:
        # A message passed to `typer.Abort` is not shown to the user, so
        # print the reason before aborting.
        rich.print("[red]Couldn't find any projects in the DB[/red]")
        raise typer.Abort()
    if len(db_projects) == 1:
        return db_projects[0]
    return fuzzy_select(db_projects, lambda dbp: dbp.project_name, "Select project")
def get_tags_from_db(sess: Session, db_project: Project) -> list[TagTuple]:
    """
    Read every (data) tag of the given project from the database.
    Args:
        sess: an open session on the Encord Active database.
        db_project: the project whose tags should be read.
    Returns: one `TagTuple` per tagged data unit.
    """
    project_hash = db_project.project_hash
    # Selected columns are in `TagTuple` field order.
    statement = select(
        ProjectDataMetadata.label_hash,
        ProjectTaggedDataUnit.du_hash,
        ProjectTaggedDataUnit.frame,
        ProjectTag.name,
    ).where(
        # Restrict every involved table to the chosen project ...
        ProjectDataMetadata.project_hash == project_hash,
        ProjectDataUnitMetadata.project_hash == project_hash,
        ProjectTag.project_hash == project_hash,
        ProjectTaggedDataUnit.project_hash == project_hash,
        # ... and join them: data -> data unit -> tagged data unit -> tag.
        ProjectDataMetadata.data_hash == ProjectDataUnitMetadata.data_hash,
        ProjectTaggedDataUnit.du_hash == ProjectDataUnitMetadata.du_hash,
        ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash,
    )
    rows = sess.exec(statement).all()
    tags = [TagTuple(*row) for row in rows]
    logger.debug(f"Found {len(tags)} tags to convert")
    return tags
### Main code
def classify_label_row(
    tag_tuple: TagTuple,
    lr_lookup: dict[UUID, LabelRowV2],
    ontology_lookup: OntologyTags,
    tag_status: list[tuple[TagTuple, UpdateStatus]],
) -> Optional[LabelRowV2]:
    """
    Turn one tag into a classification answer on its label row.
    Exactly one (tag, status) entry is appended to `tag_status` per call.
    Exceptions are caught and recorded as `UpdateStatus.UNKNOWN_ERROR`.
    Args:
        tag_tuple: tuple to tag
        lr_lookup: lookup for encord project label rows
        ontology_lookup: lookup for encord ontology
        tag_status: a list that will be mutated to collect update statuses
    Returns:
        label row if it needs to be saved.
    """
    with logger.contextualize(t=tag_tuple):
        try:
            label_row = lr_lookup.get(tag_tuple.label_hash)
            if label_row is None:
                logger.warning(UpdateStatus.MISSING_LABEL_ROW)
                tag_status.append((tag_tuple, UpdateStatus.MISSING_LABEL_ROW))
                return None

            if not label_row.is_labelling_initialised:
                try:
                    label_row.initialise_labels()
                except Exception:
                    # `Exception` rather than a bare `except`, so Ctrl+C is
                    # not swallowed and reported as an initialisation error.
                    logger.warning(UpdateStatus.LABEL_INITIALIZATION)
                    tag_status.append((tag_tuple, UpdateStatus.LABEL_INITIALIZATION))
                    return None

            tag_info = ontology_lookup.get(tag_tuple.tag_name)
            if tag_info is None:
                logger.debug(UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY)
                tag_status.append(
                    (tag_tuple, UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY)
                )
                return None
            clf, _, option = tag_info

            instances = label_row.get_classification_instances(
                filter_ontology_classification=clf, filter_frames=tag_tuple.frame
            )
            if len(instances) > 1:
                raise ValueError(
                    "Multiple classification instances for the same frame. This shouldn't happen."
                )

            is_new_instance = not instances
            if is_new_instance:
                instance = clf.create_instance()
                answers: list[FlatOption] = []
            else:
                instance = instances[0]
                # NOTE(review): the answer is treated as a list (checklist).
                # The lookup also contains radio options - confirm how those
                # should be handled.
                answers = instance.get_answer()  # type: ignore
                if any(o.value == option.value for o in answers):
                    logger.debug(UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS)
                    tag_status.append(
                        (tag_tuple, UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS)
                    )
                    return None

            instance.set_answer(answers + [option], overwrite=True)
            if is_new_instance:
                # Place the instance on the tagged frame - the same frame the
                # lookup above filters on - falling back to frame 0 when the
                # tag carries no frame.
                frame = tag_tuple.frame if tag_tuple.frame is not None else 0
                instance.set_for_frames(frame)
            if not instance.is_assigned_to_label_row():
                label_row.add_classification_instance(instance)

            logger.debug(UpdateStatus.SUCCESS)
            tag_status.append((tag_tuple, UpdateStatus.SUCCESS))
            return label_row
        except Exception as e:
            # Printed as well as logged: without the verbose flags no log
            # sink is attached and the error would otherwise be invisible.
            print(e)
            logger.warning(UpdateStatus.UNKNOWN_ERROR, err=e)
            tag_status.append((tag_tuple, UpdateStatus.UNKNOWN_ERROR))
            return None
def transform_tags_into_checklists(
    tags: list[TagTuple], enc_project: EncordProject, ontology_tags: OntologyTags
) -> tuple[list[LabelRowV2], list[tuple[TagTuple, UpdateStatus]]]:
    """
    Transforms tags into checklist classifications in the Encord Annotate project.
    Args:
        tags: the list of all the tags in the project.
        enc_project: the Encord project.
        ontology_tags: dictionary to look up correct entry in the ontology based
            on the tags.
    Returns:
        The label rows that need to be saved and list of update status for each
        tag to indicate whether it was stored or not - and why.
    """
    if not tags:
        # Nothing to convert: skip the round-trips to Encord entirely.
        return [], []

    label_hashes = list({str(t.label_hash) for t in tags})
    label_rows = enc_project.list_label_rows_v2(label_hashes=label_hashes)
    initialize_label_rows(enc_project, label_rows)
    lr_lookup: dict[UUID, LabelRowV2] = {
        UUID(lr.label_hash): lr for lr in label_rows if lr.label_hash is not None
    }

    tag_status: list[tuple[TagTuple, UpdateStatus]] = []
    # Several tags can touch the same label row; keep each row once, in the
    # order it was first modified (keyed by identity).
    labels_to_save: dict[int, LabelRowV2] = {}
    for tag in tqdm(tags, desc="Anaylysing tags"):
        updated_row = classify_label_row(
            tag,
            lr_lookup=lr_lookup,
            ontology_lookup=ontology_tags,
            tag_status=tag_status,
        )
        if updated_row is not None:
            labels_to_save[id(updated_row)] = updated_row
    return list(labels_to_save.values()), tag_status
def main(
    db_path: Path = Path("encord-active.sqlite"),
    ssh_key_path: Path = Path("~/.ssh/encord_frederik"),
    print_tags: bool = False,
    verbose: bool = False,
    very_verbose: bool = False,
    log_to_file: bool = False,
):
    """
    Converts tags from an Encord Active project into checklist classifications
    on Encord Annotate. The script will take the user through a couple of
    options:
    1. A path to an `encord-active.sqlite` database is needed. Provide it by
    the `db_path` option or answer the associated prompt.
    2. A path to a private ssh key is needed. Provide it by the `ssh_key_path`
    option or answer the associated prompt.
    3. Select which project to convert tags for. Follow the selection prompt.
    Once the script is done, it will print stats on the interactions with Encord
    Annotate.
    All transactions are stored in an `updated_{timestamp}.json` file for easy
    access afterwards.
    Args:
        db_path: Path to the DB containing the project and its tags.
        ssh_key_path: Path to the private ssh key associated with Encord.
        print_tags: if set, no interactions with Annotate will happen, but tags
            to use in the UI are printed for easy access/copy purposes.
        verbose: Set this flag if you want the terminal to print errors to the
            console.
        very_verbose: Set this flag if you want the terminal to print the status
            of all tag-to-classification actions.
        log_to_file: Set this flag to store the tag-to-classification status of
            every update in a log file in your current working directory.
    """
    ts = datetime.now()
    if print_tags or verbose or very_verbose:
        logger.add(
            sys.stderr,
            format=FORMAT,
            level="DEBUG" if very_verbose else "INFO",
            colorize=True,
        )
    log_file_name = None
    if log_to_file:
        log_file_name = f"updates_{ts.isoformat()}.log"
        logger.add(log_file_name, format=FORMAT, level="DEBUG")
        logger.info(f"Storing logs at {log_file_name}")

    # `~` must be expanded before `is_file()`, otherwise a home-relative path
    # is never found and the user is always prompted.
    db_path = db_path.expanduser()
    if not db_path.is_file():
        db_path = Path(
            Prompt.ask(
                "What's the path to the Encord Active DB containing the tags",
                default="./encord-active.sqlite",
            )
        ).expanduser()
        if not db_path.is_file():
            # The message is printed because `typer.Abort` does not show one.
            rich.print("[red]DB path missing[/red]")
            raise typer.Abort()

    # The ssh key is only used to talk to Encord Annotate, which does not
    # happen when just printing tags.
    if not print_tags:
        ssh_key_path = ssh_key_path.expanduser()
        if not ssh_key_path.is_file():
            ssh_key_path = Path(
                Prompt.ask(
                    "What private ssh-key should be used?",
                    default="~/.ssh/id_ed25519",
                )
            ).expanduser()
            if not ssh_key_path.is_file():
                rich.print("[red]SSH key path missing.[/red]")
                raise typer.Abort()

    engine = get_engine(db_path, use_alembic=False)
    with Session(engine) as sess:
        db_project = select_project_from_db(sess)
        ontology = OntologyStructure.from_dict(db_project.project_ontology)
        ontology_tags = get_tag_to_ontology_lookup(ontology)
        if print_tags:
            rich.print("[blue]Tag names:[/blue]")
            for tag_name in ontology_tags:
                rich.print(f"`[magenta]{tag_name}[/magenta]`")
            raise typer.Exit()
        tags = get_tags_from_db(sess, db_project)
        # Read while the session is still open.
        project_hash = str(db_project.project_hash)

    client = EncordUserClient.create_with_ssh_private_key(ssh_key_path.read_text())
    enc_project = client.get_project(project_hash=project_hash)
    labels_to_save, tag_status = transform_tags_into_checklists(
        tags, enc_project, ontology_tags
    )
    save_label_rows(enc_project, labels_to_save)
    print_stats(tag_status, log_to_file, log_file_name, ts)
# Script entry point: Typer builds the command line interface from the
# signature of `main` and runs it.
if __name__ == "__main__":
    typer.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment