Created
May 25, 2023 10:45
-
-
Save frederik-encord/8623c5a819414ec03f47c07192cf1524 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script will add tags in Encord Active based on nested radio attributes. | |
If an object has no nested attributes, it will be tagged with "no classification". | |
To run the script, make sure to have your encord active environment sourced and | |
then do: | |
```shell | |
(ea-venv)$ python tag_with_radio_classifications.py "/path/to/your/encord-active/project" | |
``` | |
Note that if you at some point run `encord-active refresh` on the project, you
can run this script again to update the tags. However, existing tags (except for
the "no classification" tag) will not be removed when new classifications come
in.
Requirements: | |
`encord-active` | |
To install requirements in a fresh environment: | |
```shell | |
python -m venv venv | |
source venv/bin/activate | |
python -m pip install encord-active | |
``` | |
""" | |
import json | |
from pathlib import Path | |
import typer | |
from encord.objects.ontology_labels_impl import RadioAttribute | |
from encord.objects.ontology_structure import OntologyStructure | |
from encord_active.lib.common.iterator import DatasetIterator | |
from encord_active.lib.db.connection import DBConnection | |
from encord_active.lib.db.merged_metrics import MergedMetrics | |
from encord_active.lib.db.tags import Tag, TagScope | |
from encord_active.lib.project.project_file_structure import ProjectFileStructure | |
def _radio_tags(classifications, radio_hashes):
    """Build one ``"<attribute>: <answer>"`` label tag per answered radio attribute.

    Args:
        classifications: The ``"classifications"`` list of one object's entry
            in the label row's ``object_answers``.
        radio_hashes: Feature node hashes of the ontology's radio attributes.

    Returns:
        A list of ``Tag`` objects with ``TagScope.LABEL``; empty when the
        object has no answered radio attribute.
    """
    tags = []
    for answer in classifications:
        if answer.get("featureHash") not in radio_hashes:
            continue
        selected = answer.get("answers") or []
        if not selected:
            # Radio attribute present but nothing chosen: nothing to tag.
            continue
        tags.append(Tag(f"{answer['name']}: {selected[0]['name']}", TagScope.LABEL))
    return tags


def tag_data(pfs: ProjectFileStructure) -> None:
    """Tag every labelled object with its nested radio classification answers.

    For each object found while iterating the project's data units, a tag
    ``"<attribute name>: <answer name>"`` is added for every radio attribute
    answer. Objects without any such answer get the ``"no classification"``
    tag instead; that tag is removed again once an object has real answers.
    Other existing tags are never removed.

    Args:
        pfs: File structure of the Encord Active project to update.
    """
    with DBConnection(pfs) as conn:
        mm = MergedMetrics(conn).all()

    ontology = OntologyStructure.from_dict(json.loads(pfs.ontology.read_text()))
    radio_answer_hashes: set[str] = {
        attr.feature_node_hash
        for obj in ontology.objects
        for attr in (obj.attributes or [])
        if isinstance(attr, RadioAttribute)
    }

    no_label_tag = Tag("no classification", TagScope.LABEL)
    iterator = DatasetIterator(pfs.project_dir)
    for du, _ in iterator.iterate(desc="Tagging Objects with radio classifications"):
        base_key = f"{iterator.label_hash}_{iterator.du_hash}_{iterator.frame:05d}"
        object_answers = iterator.label_rows[iterator.label_hash]["object_answers"]

        for obj in du.get("labels", {}).get("objects", []):
            object_hash = obj["objectHash"]
            key = f"{base_key}_{object_hash}"
            if key not in mm.index:
                continue

            tags = mm.loc[key].tags
            # A missing object entry or a missing "classifications" key both
            # mean "no answers"; the previous `[{}]` default raised KeyError
            # on the empty dict instead of reaching the fallback tag below.
            classifications = object_answers.get(object_hash, {}).get("classifications", [])
            new_tags = _radio_tags(classifications, radio_answer_hashes)

            if not new_tags:
                new_tags = [no_label_tag]
            elif no_label_tag in tags:
                # Real answers have arrived since the last run.
                tags.remove(no_label_tag)

            for tag in new_tags:
                if tag not in tags:
                    tags.append(tag)

            # NOTE(review): one connection per object mirrors the original;
            # sharing a single connection may be faster but depends on
            # DBConnection's commit semantics - confirm before changing.
            with DBConnection(pfs) as conn:
                MergedMetrics(conn).update_tags(key, tags)
def main(
    out_directory: Path = typer.Argument(
        default=Path.cwd(),
        exists=True,
        help="The project director to update with the tags",
    ),
):
    # CLI entry point: resolve the project layout for the given directory
    # (defaults to the current working directory) and tag its objects.
    # Documented with comments rather than a docstring because Typer would
    # show a docstring as the command's --help text.
    tag_data(ProjectFileStructure(out_directory))
# Run the Typer command-line interface only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    typer.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment