# Source: GitHub gist frederik-encord/8623c5a819414ec03f47c07192cf1524
"""
This script will add tags in Encord Active based on nested radio attributes.
If an object has no nested attributes, it will be tagged with "no classification".
To run the script, make sure to have your encord active environment sourced and
then do:
```shell
(ea-venv)$ python tag_with_radio_classifications.py "/path/to/your/encord-active/project"
```
Note that if you at some point run `encord-active refresh` on the project, you
can run this script again to update the tags. However, existing tags (except for
the "no classification" tag) will not be removed when new classifications come
in.
Requirements:
`encord-active`
To install requirements in a fresh environment:
```shell
python -m venv venv
source venv/bin/activate
python -m pip install encord-active
```
"""
import json
from pathlib import Path
import typer
from encord.objects.ontology_labels_impl import RadioAttribute
from encord.objects.ontology_structure import OntologyStructure
from encord_active.lib.common.iterator import DatasetIterator
from encord_active.lib.db.connection import DBConnection
from encord_active.lib.db.merged_metrics import MergedMetrics
from encord_active.lib.db.tags import Tag, TagScope
from encord_active.lib.project.project_file_structure import ProjectFileStructure
def _radio_attribute_hashes(ontology: OntologyStructure) -> set[str]:
    """Return the feature node hashes of all radio attributes listed on the ontology's objects."""
    return {
        attr.feature_node_hash
        for obj in ontology.objects
        for attr in obj.attributes or []
        if isinstance(attr, RadioAttribute)
    }


def _radio_tags(classifications: list[dict], radio_hashes: set[str]) -> list[Tag]:
    """Build one ``"<attribute name>: <answer name>"`` label tag per answered radio attribute.

    Entries whose ``featureHash`` is not in ``radio_hashes`` are ignored, as are
    radio entries that carry no answer.
    """
    tags: list[Tag] = []
    for answer in classifications:
        if answer.get("featureHash") not in radio_hashes:
            continue
        selected = answer.get("answers") or []
        if not selected:
            # A radio attribute without a chosen answer has nothing to tag with.
            continue
        name = answer["name"]
        value = selected[0]["name"]
        tags.append(Tag(f"{name}: {value}", TagScope.LABEL))
    return tags


def tag_data(pfs: ProjectFileStructure):
    """Tag every labelled object in the project with its radio classification answers.

    For each object found by the dataset iterator whose identifier
    (``<label_hash>_<du_hash>_<frame:05d>_<object_hash>``) is present in the
    merged metrics, a label-scoped tag ``"<attribute name>: <answer name>"`` is
    added for each answered radio attribute. Objects without any such answer
    get the ``"no classification"`` tag instead; that tag is removed again once
    an object has at least one radio answer. Other existing tags are kept.

    Args:
        pfs: File structure of the project whose merged metrics are updated.
    """
    with DBConnection(pfs) as conn:
        mm = MergedMetrics(conn).all()

    ontology = OntologyStructure.from_dict(json.loads(pfs.ontology.read_text()))
    radio_answer_hashes = _radio_attribute_hashes(ontology)

    no_label_tag = Tag("no classification", TagScope.LABEL)
    iterator = DatasetIterator(pfs.project_dir)
    for du, _ in iterator.iterate(desc="Tagging Objects with radio classifications"):
        base_key = f"{iterator.label_hash}_{iterator.du_hash}_{iterator.frame:05d}"
        object_answers = iterator.label_rows[iterator.label_hash]["object_answers"]
        for obj in du.get("labels", {}).get("objects", []):
            object_hash = obj["objectHash"]
            key = f"{base_key}_{object_hash}"
            if key not in mm.index:
                continue

            tags = mm.loc[key].tags
            # An object may have no answers entry, or an entry without
            # classifications; both simply mean "no radio answers".
            classifications = object_answers.get(object_hash, {}).get("classifications", [])
            new_tags = _radio_tags(classifications, radio_answer_hashes)

            changed = False
            if not new_tags:
                new_tags = [no_label_tag]
            elif no_label_tag in tags:
                tags.remove(no_label_tag)
                changed = True

            for tag in new_tags:
                if tag not in tags:
                    tags.append(tag)
                    changed = True

            # Only touch the database when the tag list actually changed.
            if changed:
                with DBConnection(pfs) as conn:
                    MergedMetrics(conn).update_tags(key, tags)
def main(
    out_directory: Path = typer.Argument(
        default=Path.cwd(),
        help="The project directory to update with the tags",
        exists=True,
        # The argument must be a project directory, so reject regular files
        # during CLI validation instead of failing later.
        file_okay=False,
    ),
) -> None:
    """Tag the objects of an Encord Active project with their radio classification answers.

    Args:
        out_directory: Existing project directory; defaults to the current
            working directory.
    """
    pfs = ProjectFileStructure(out_directory)
    tag_data(pfs)
if __name__ == "__main__":
    # Run as a script: hand `main` to Typer so it is invoked with the
    # command-line arguments.
    typer.run(main)
# End of gist.