Instantly share code, notes, and snippets.
Created
October 9, 2023 10:05
-
Star
(0)
0
You must be signed in to star a gist -
Fork
(0)
0
You must be signed in to fork a gist
-
Save frederik-encord/b63d7b60c8e1c1f73c06514ceec6d758 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
If you have tagged your data (with data tags), this script can help you | |
converting the tags into checklist classifications. | |
Prerequisites: | |
1. You need to be an admin of the Encord Annotate project. | |
2. You need to have `encord-active==0.1.78` installed and working in your shell. | |
Installation: | |
The script uses only dependencies on from `encord-active`. To install | |
`encord-active`, run | |
```shell | |
python -m venv venv | |
source venv/bin/activate | |
python -m pip install encord-active | |
``` | |
Running the script: | |
Easiest way is download this gist and putting it next to your | |
`encord-active.sqlite` file. Next you run | |
```shell | |
python transform_tags_into_checklists.py | |
``` | |
and follow the instructions. | |
Otherwise, this command will help you with running the script: | |
```shell | |
python transform_tags_into_checklists.py --help | |
``` | |
Note that there is a --print-tags option, which will print tags based on the | |
project ontology, which will be used to classify data. | |
""" | |
import sys | |
from collections import Counter | |
from datetime import datetime | |
from enum import Enum | |
from pathlib import Path | |
from typing import Any, Callable, NamedTuple, Optional, TypeVar | |
from uuid import UUID | |
import orjson | |
import rich | |
import typer | |
from encord import EncordUserClient | |
from encord.objects import ( | |
ChecklistAttribute, | |
Classification, | |
FlatOption, | |
OntologyStructure, | |
Option, | |
RadioAttribute, | |
TextAttribute, | |
) | |
from encord.objects.attributes import Attribute | |
from encord.project import LabelRowV2 | |
from encord.project import Project as EncordProject | |
from encord_active.db.models import ( | |
Project, | |
ProjectDataMetadata, | |
ProjectDataUnitMetadata, | |
ProjectTag, | |
ProjectTaggedDataUnit, | |
get_engine, | |
) | |
from InquirerPy import inquirer as inq | |
from InquirerPy.base.control import Choice | |
from loguru import logger | |
from rich.panel import Panel | |
from rich.prompt import Prompt | |
from sqlmodel import Session, select | |
from tqdm import tqdm | |
# Drop loguru's default sink; main() attaches sinks depending on the CLI flags.
logger.remove()

# Line layout shared by the console and the file log sinks.
FORMAT = (
    "<green>{time:MMMM D, YYYY > HH:mm:ss!UTC} | {level} |</green> "
    "<blue>{message}</blue> | "
    "<green><dim>{extra}</dim></green>"
)

# Dotted tag name -> (classification, attribute, option) it selects in the ontology.
OntologyTags = dict[str, tuple[Classification, Attribute, Option]]

# Generic element type used by fuzzy_select.
T = TypeVar("T")
class UpdateStatus(str, Enum):
    """
    Outcome recorded for each tag the script tries to convert.

    Every member's value is a human-readable description; it is what ends up
    in log messages and in the "Description" line of the final stats panel.
    """

    # Positive outcomes.
    SUCCESS = "Added classification successfully"
    SUCCESS_CLASSIFICATION_EXISTS = "Correct classification exists already"

    # Problems reaching or preparing the Encord label row.
    MISSING_LABEL_ROW = "Couldn't find label row in Encord Project"
    LABEL_INITIALIZATION = (
        "Couldn't initialize label row - probably due to network errors. "
        "Try running script again."
    )

    # Tags without a matching ontology entry.
    SKIPPING_TAGS_NOT_IN_ONTOLOGY = (
        "Skipping tags that were not selected for "
        "(or in existing) classification"
    )

    # Remaining failure categories.
    LABEL_SAVE = (
        "Failed to save label row - probably due to network errors. "
        "Try running script again."
    )
    UNKNOWN_ERROR = "Unknown error - see logs for more details"
    INVALID_TAG_FORMAT = (
        "The tag format should be `question.answer[.question.answer]`"
    )
# One tag read from the Encord Active DB: which label row / data unit (and
# optionally which frame) carries which tag name.
TagTuple = NamedTuple(
    "TagTuple",
    [
        ("label_hash", UUID),
        ("du_hash", UUID),
        ("frame", Optional[int]),
        ("tag_name", str),
    ],
)
### Utilities ### | |
def fuzzy_select(
    options: list[T],
    transform: Callable[[T], str] = str,
    description: str = "Select one",
) -> T:
    """
    Interactive CLI selection from a list.

    Every option is shown under the label ``transform(option)``. The prompt
    returns the option's list index, which is mapped back to the option.

    Args:
        options: what to select from.
        transform: function turning each option into the label shown in the CLI.
        description: description to display in the CLI.

    Returns:
        The selected object.

    Raises:
        typer.Abort: if the prompt returns no selection.
    """
    # The index is used as the choice value so the objects themselves never
    # have to go through InquirerPy.
    choices = [Choice(idx, name=transform(opt)) for idx, opt in enumerate(options)]
    # No `transformer`: InquirerPy hands it the chosen label (a string), so the
    # previous `lambda x: x[0]` echoed only the label's first character.
    res = inq.fuzzy(
        description,
        choices=choices,
        multiselect=False,
        vi_mode=True,
    ).execute()
    if res is None:
        raise typer.Abort()
    return options[res]
def serialize_json(obj: Any):
    """
    `default=` hook for `orjson.dumps` covering this script's own types.

    Args:
        obj: an object orjson could not serialize by itself.

    Returns:
        A JSON-compatible value. UUIDs become their canonical string form.
        TagTuples become a dict whose hashes are strings.

    Raises:
        TypeError: for any other type, so orjson reports it as unserializable.
    """
    if isinstance(obj, UUID):
        return str(obj)
    if isinstance(obj, TagTuple):
        return {
            "label_hash": str(obj.label_hash),
            "du_hash": str(obj.du_hash),
            "frame": obj.frame,
            "tag_name": obj.tag_name,
        }
    raise TypeError(f"Type is not JSON serializable: {type(obj).__name__}")
def initialize_label_rows(project: EncordProject, label_rows: list[LabelRowV2], bs=50):
    """
    Initialise the labels of all label rows, `bs` rows per Encord bundle.

    Registering each batch on one bundle and executing the bundle once is much
    faster than initialising every label row individually.

    Args:
        project: Encord project the label rows belong to.
        label_rows: label rows to initialise.
        bs: number of label rows handled per bundle.
    """
    with tqdm(total=len(label_rows), desc="Reading labels from Encord") as pbar:
        for batch_start in range(0, len(label_rows), bs):
            batch = label_rows[batch_start : batch_start + bs]
            init_bundle = project.create_bundle()
            for label_row in batch:
                # The row must be attached to the bundle; otherwise each call
                # goes to the API on its own and the bundle executes nothing.
                label_row.initialise_labels(bundle=init_bundle)
            init_bundle.execute()
            pbar.update(len(batch))
def save_label_rows(project: EncordProject, label_rows: list[LabelRowV2], bs=50):
    """
    Save all label rows back to Encord, `bs` rows per Encord bundle.

    Registering each batch on one bundle and executing the bundle once is much
    faster than saving every label row individually.

    Args:
        project: Encord project the label rows belong to.
        label_rows: label rows to save.
        bs: number of label rows handled per bundle.
    """
    with tqdm(total=len(label_rows), desc="Sending labels to Encord") as pbar:
        for batch_start in range(0, len(label_rows), bs):
            batch = label_rows[batch_start : batch_start + bs]
            save_bundle = project.create_bundle()
            for label_row in batch:
                # The row must be attached to the bundle; otherwise each save
                # goes to the API on its own and the bundle executes nothing.
                label_row.save(bundle=save_bundle)
            save_bundle.execute()
            pbar.update(len(batch))
def print_stats(
    tag_status: list[tuple[TagTuple, UpdateStatus]],
    log_to_file: bool,
    log_file_name: Optional[str],
    ts: datetime,
):
    """
    Write the per-tag results to a JSON file and print a summary panel.

    Args:
        tag_status: (tag, status) pairs collected during the conversion.
        log_to_file: whether a log file was written; if so its name is shown.
        log_file_name: name of that log file, if any.
        ts: timestamp of this run, used in the results file name.
    """
    # Use the run timestamp itself; `ts.now()` would take a fresh time and
    # not match the timestamp used elsewhere for this run.
    results_json = Path(f"updated_{ts.isoformat()}.json")
    results_json.write_bytes(orjson.dumps(tag_status, default=serialize_json))

    counts = Counter(status for _, status in tag_status)
    # `.name` explicitly: formatting a str-mixin Enum directly yields the
    # value or "Class.MEMBER" depending on the Python version.
    stats_str = "\n".join(
        f"[green]{status.name}:[/green] {count:d}\n"
        f"\t[blue]Description:[/blue] {status.value}"
        for status, count in counts.items()
    )

    panel_text = f"""
[blue]Update stats:[/blue]
{stats_str}
Find details here:
JSON results stored in [magenta]`{results_json}`[/magenta]
"""
    if log_to_file and log_file_name:
        panel_text += f"Log stored in [magenta]{log_file_name}[/magenta]"

    rich.print(
        Panel(
            panel_text,
            title="Classification Stats",
        )
    )
def get_tag_to_ontology_lookup(
    ontology: OntologyStructure,
) -> OntologyTags:
    """
    Map every dotted tag name to the ontology entry it selects.

    A tag is ``attribute.option``. For attributes nested under a radio option
    the path keeps growing: ``attribute.option.nested_attribute.nested_option``.
    Text attributes produce no tags. Only radio options are descended into.

    Returns:
        dict of tag name -> (classification, attribute, option).
    """
    lookup: OntologyTags = {}
    pending: list[tuple[str, Attribute, Classification]] = []
    for classification in ontology.classifications:
        for attribute in classification.attributes:
            pending.append((attribute.title, attribute, classification))

    # Breadth-first walk with a read cursor; `pending` grows while nested
    # radio attributes are discovered.
    cursor = 0
    while cursor < len(pending):
        prefix, attribute, classification = pending[cursor]
        cursor += 1

        if isinstance(attribute, TextAttribute):
            continue
        if isinstance(attribute, RadioAttribute):
            for opt in attribute.options:
                lookup[f"{prefix}.{opt.title}"] = (classification, attribute, opt)
            for opt in attribute.options:
                for nested in opt.attributes:
                    pending.append(
                        (f"{prefix}.{opt.title}.{nested.title}", nested, classification)
                    )
        elif isinstance(attribute, ChecklistAttribute):
            for opt in attribute.options:
                lookup[f"{prefix}.{opt.title}"] = (classification, attribute, opt)

    return lookup
### Active related DB operations ### | |
def select_project_from_db(sess: Session) -> Project:
    """
    Select the Encord Active project to work on.

    A single project in the DB is used directly. Otherwise the user picks one
    from an interactive fuzzy prompt listing the project names.

    Raises:
        typer.Abort: if the DB contains no projects, or no project is selected.
    """
    db_projects = sess.exec(select(Project)).all()
    if not db_projects:
        # click/typer drop the message given to Abort and only print
        # "Aborted!", so show the reason explicitly.
        rich.print("[red]Couldn't find any projects in the DB[/red]")
        raise typer.Abort()
    if len(db_projects) == 1:
        return db_projects[0]
    return fuzzy_select(db_projects, lambda dbp: dbp.project_name, "Select project")
def get_tags_from_db(sess: Session, db_project: Project) -> list[TagTuple]:
    """
    Read all (data) tags of `db_project` from the Encord Active DB.

    Tagged data units are joined with the data-unit and data metadata tables
    to resolve the label hash each tag belongs to.

    Returns:
        One TagTuple per distinct (label_hash, du_hash, frame, tag_name) row.
    """
    project_hash = db_project.project_hash
    query = (
        select(
            ProjectDataMetadata.label_hash,
            ProjectTaggedDataUnit.du_hash,
            ProjectTaggedDataUnit.frame,
            ProjectTag.name,
        )
        .where(
            ProjectDataMetadata.project_hash == project_hash,
            ProjectDataUnitMetadata.project_hash == project_hash,
            ProjectTag.project_hash == project_hash,
            ProjectTaggedDataUnit.project_hash == project_hash,
            ProjectDataMetadata.data_hash == ProjectDataUnitMetadata.data_hash,
            ProjectTaggedDataUnit.du_hash == ProjectDataUnitMetadata.du_hash,
            ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash,
        )
        # The data-unit metadata join is on du_hash only. If that table holds
        # several rows per du_hash (presumably one per frame - TODO confirm
        # against the schema), every tag would otherwise come back repeatedly.
        .distinct()
    )
    tags = [TagTuple(*row) for row in sess.exec(query).all()]
    logger.debug(f"Found {len(tags)} tags to convert")
    return tags
### Main code | |
def classify_label_row(
    tag_tuple: TagTuple,
    lr_lookup: dict[UUID, LabelRowV2],
    ontology_lookup: OntologyTags,
    tag_status: list[tuple[TagTuple, UpdateStatus]],
) -> Optional[LabelRowV2]:
    """
    Add the checklist answer corresponding to one tag to its label row.

    Args:
        tag_tuple: tag to convert.
        lr_lookup: lookup from label hash to Encord project label row.
        ontology_lookup: lookup from tag name to ontology entry.
        tag_status: list that is mutated to collect one status per tag.

    Returns:
        The label row if it was modified and needs to be saved, else None.
    """
    with logger.contextualize(t=tag_tuple):
        try:
            label_row = lr_lookup.get(tag_tuple.label_hash)
            if label_row is None:
                logger.warning(UpdateStatus.MISSING_LABEL_ROW)
                tag_status.append((tag_tuple, UpdateStatus.MISSING_LABEL_ROW))
                return None

            if not label_row.is_labelling_initialised:
                try:
                    label_row.initialise_labels()
                except Exception:
                    logger.warning(UpdateStatus.LABEL_INITIALIZATION)
                    tag_status.append((tag_tuple, UpdateStatus.LABEL_INITIALIZATION))
                    return None

            tag_info = ontology_lookup.get(tag_tuple.tag_name)
            if tag_info is None:
                logger.debug(UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY)
                tag_status.append(
                    (tag_tuple, UpdateStatus.SKIPPING_TAGS_NOT_IN_ONTOLOGY)
                )
                return None
            clf, attribute, option = tag_info

            instances = label_row.get_classification_instances(
                filter_ontology_classification=clf, filter_frames=tag_tuple.frame
            )
            if len(instances) > 1:
                raise ValueError(
                    "Multiple classification instances for the same frame. This shouldn't happen."
                )

            if instances:
                instance = instances[0]
                # Read the answers of the tag's own attribute; without
                # `attribute=` the classification's first attribute is used.
                answers: list[FlatOption] = list(
                    instance.get_answer(attribute=attribute) or []  # type: ignore
                )
                if any(o.value == option.value for o in answers):
                    logger.debug(UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS)
                    tag_status.append(
                        (tag_tuple, UpdateStatus.SUCCESS_CLASSIFICATION_EXISTS)
                    )
                    return None
                is_new_instance = False
            else:
                instance = clf.create_instance()
                answers = []
                is_new_instance = True

            instance.set_answer(answers + [option], attribute=attribute, overwrite=True)
            if is_new_instance:
                # Place the new instance on the same frame that was searched
                # above, so a later tag for this frame finds it.
                frame = tag_tuple.frame if tag_tuple.frame is not None else 0
                instance.set_for_frames(frame)
            if not instance.is_assigned_to_label_row():
                label_row.add_classification_instance(instance)

            logger.debug(UpdateStatus.SUCCESS)
            tag_status.append((tag_tuple, UpdateStatus.SUCCESS))
            return label_row
        except Exception as e:
            # Reported through the logger only, so output follows the
            # verbosity flags; the status is still counted in the stats.
            logger.warning(UpdateStatus.UNKNOWN_ERROR, err=e)
            tag_status.append((tag_tuple, UpdateStatus.UNKNOWN_ERROR))
            return None
def transform_tags_into_checklists(
    tags: list[TagTuple], enc_project: EncordProject, ontology_tags: OntologyTags
) -> tuple[list[LabelRowV2], list[tuple[TagTuple, UpdateStatus]]]:
    """
    Transform tags into checklist classifications in the Encord Annotate project.

    Args:
        tags: the list of all the tags in the project.
        enc_project: the Encord project.
        ontology_tags: dictionary to look up the correct ontology entry for a tag.

    Returns:
        The label rows that need to be saved (each once, in first-modified
        order). Also a list with one update status per tag, indicating whether
        it was stored or not and why.
    """
    tag_status: list[tuple[TagTuple, UpdateStatus]] = []
    if not tags:
        # Nothing to convert. This also avoids calling list_label_rows_v2 with
        # an empty filter - NOTE(review): that may not restrict the result at all.
        return [], tag_status

    label_hashes = list({str(t.label_hash) for t in tags})
    label_rows = enc_project.list_label_rows_v2(label_hashes=label_hashes)
    initialize_label_rows(enc_project, label_rows)
    lr_lookup: dict[UUID, LabelRowV2] = {
        UUID(lr.label_hash): lr for lr in label_rows if lr.label_hash is not None
    }

    # Dict keys de-duplicate label rows touched by several tags while keeping
    # a deterministic order (a set would not).
    labels_to_save: dict[LabelRowV2, None] = {}
    for tag in tqdm(tags, desc="Analysing tags"):
        label_row = classify_label_row(
            tag,
            lr_lookup=lr_lookup,
            ontology_lookup=ontology_tags,
            tag_status=tag_status,
        )
        if label_row is not None:
            labels_to_save[label_row] = None
    return list(labels_to_save), tag_status
def _resolve_file(path: Path, question: str, default: str, missing_msg: str) -> Path:
    """Return `path` (with `~` expanded) if it is a file; otherwise prompt for one and abort if that is missing too."""
    path = path.expanduser()
    if not path.is_file():
        path = Path(Prompt.ask(question, default=default)).expanduser()
    if not path.is_file():
        # typer.Abort discards its message, so print the reason first.
        rich.print(f"[red]{missing_msg}[/red]")
        raise typer.Abort()
    return path


def main(
    db_path: Path = Path("encord-active.sqlite"),
    ssh_key_path: Path = Path("~/.ssh/encord_frederik"),
    print_tags: bool = False,
    verbose: bool = False,
    very_verbose: bool = False,
    log_to_file: bool = False,
):
    """
    Converts tags from an Encord Active project into checklist classifications
    on Encord Annotate. The script will take the user through a couple of
    options:

    1. A path to an `encord-active.sqlite` database is needed. Provide it with
       the `db_path` option or answer the associated prompt.
    2. A path to a private ssh key is needed. Provide it with the
       `ssh_key_path` option or answer the associated prompt.
    3. Select which project to convert tags for. Follow the selection prompt.

    Once the script is done, it will print stats on the interactions with
    Encord Annotate. All transactions are stored in an
    `updated_{timestamp}.json` file for easy access afterwards.

    Args:
        db_path: Path to the DB containing the project and its tags.
        ssh_key_path: Path to the private ssh key associated with Encord.
        print_tags: if set, no interactions with Annotate will happen, but tags
            to use in the UI are printed for easy access/copy purposes.
        verbose: Set this flag if you want the terminal to print errors to the
            console.
        very_verbose: Set this flag if you want the terminal to print the status
            of all tag-to-classification actions.
        log_to_file: Set this flag to store the tag-to-classification status of
            every update in a log file in your current working directory.
    """
    ts = datetime.now()
    if print_tags or verbose or very_verbose:
        logger.add(
            sys.stderr,
            format=FORMAT,
            level="DEBUG" if very_verbose else "INFO",
            colorize=True,
        )

    log_file_name = None
    if log_to_file:
        log_file_name = f"updates_{ts.isoformat()}.log"
        logger.add(log_file_name, format=FORMAT, level="DEBUG")
        # Positional arg so loguru fills in the placeholder (the old message
        # had no f-prefix and printed "{log_file_name}" literally).
        logger.info("Storing logs at {}", log_file_name)

    db_path = _resolve_file(
        db_path,
        "What's the path to the Encord Active DB containing the tags",
        "./encord-active.sqlite",
        "DB path missing",
    )
    ssh_key_path = _resolve_file(
        ssh_key_path,
        "What private ssh-key should be used?",
        "~/.ssh/id_ed25519",
        "SSH key path missing.",
    )

    engine = get_engine(db_path, use_alembic=False)
    with Session(engine) as sess:
        db_project = select_project_from_db(sess)
        ontology = OntologyStructure.from_dict(db_project.project_ontology)
        ontology_tags = get_tag_to_ontology_lookup(ontology)

        if print_tags:
            rich.print("[blue]Tag names:[/blue]")
            for k in ontology_tags:
                rich.print(f"`[magenta]{k}[/magenta]`")
            raise typer.Exit()

        tags = get_tags_from_db(sess, db_project)
        client = EncordUserClient.create_with_ssh_private_key(ssh_key_path.read_text())
        enc_project = client.get_project(project_hash=str(db_project.project_hash))
        labels_to_save, tag_status = transform_tags_into_checklists(
            tags, enc_project, ontology_tags
        )
        save_label_rows(enc_project, labels_to_save)
        print_stats(tag_status, log_to_file, log_file_name, ts)
if __name__ == "__main__": | |
typer.run(main) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment