Instantly share code, notes, and snippets.
Created
September 19, 2023 11:55
-
Star
(0)
0
You must be signed in to star a gist -
Fork
(0)
0
You must be signed in to fork a gist
-
Save frederik-encord/bb1ed1c42ac1d612b85bd6414f6ed67f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import re | |
import sys | |
from pathlib import Path | |
from uuid import UUID, uuid4 | |
import typer | |
from encord_active.db.models import ( | |
Project, | |
ProjectDataUnitMetadata, | |
ProjectTag, | |
ProjectTaggedDataUnit, | |
get_engine, | |
) | |
from InquirerPy import inquirer as inq | |
from InquirerPy.base.control import Choice | |
from loguru import logger | |
from sqlmodel import Session, select | |
logger.remove()  # drop loguru's default stderr handler; handlers are re-added in `main`
# Log line layout: UTC timestamp | level | message | bound extra context (e.g. du_hash).
FORMAT = "<e>{time:MMMM D, YYYY > HH:mm:ss!UTC}</e> | <y>{level}</y> | <g>{message}</g> | <le>{extra}</le>"
def main(
    csv_file: Path,
    db_file: Path = Path("encord-active.sqlite"),
    tag_pattern: str = r"\'([\w\d_]+)\'",
    verbose: bool = False,
    log_to_file: bool = False,
):
    """
    Reads csv file and adds tags to the selected project - unless they exist already.
    CSV file should have three "comma separated" columns:
    1. file name
    2. tags - a string that can be matched with a regex (`tag_pattern`) by `findall`.
    3. data_hash from Encord
    Args:
        csv_file: Path to the csv file that contains the tags to import.
        db_file: Path to where the `encord-active.sqlite` file is.
        tag_pattern: By default, finds all strings in single quotes.
        verbose: If verbose, logs additional information during the execution.
        log_to_file: Logs can be stored in a file with this flag set.
    Raises:
        FileNotFoundError: If the database or csv file was not found.
    """
    level = "DEBUG" if verbose else "INFO"
    logger.add(sys.stderr, format=FORMAT, level=level)
    if log_to_file:
        from datetime import datetime

        # File sink always captures DEBUG, independent of `verbose`.
        logger.add(f"{datetime.now().isoformat()}.log", format=FORMAT, level="DEBUG")
    regex = re.compile(tag_pattern)
    if not db_file.is_file():
        raise FileNotFoundError("Couldn't locate db file")
    engine = get_engine(db_file, use_alembic=False)

    # Let the user pick the target project interactively (fuzzy search).
    with Session(engine) as sess:
        projects = sess.exec(select(Project)).all()
        choices = [Choice(project, project.project_name) for project in projects]
        project: Project = inq.fuzzy(
            "Choose project to import tags to",
            choices=choices,
            multiselect=False,
            vi_mode=True,
        ).execute()

    with Session(engine) as sess:
        # All data units of the chosen project, keyed by du_hash.
        data_units = {
            d.du_hash: d
            for d in sess.exec(
                select(ProjectDataUnitMetadata).where(
                    ProjectDataUnitMetadata.project_hash == project.project_hash
                )
            ).all()
        }
        # The csv may store data hashes rather than du hashes; this map translates.
        data_to_du = {d.data_hash: d.du_hash for d in data_units.values()}
        # Tag name -> tag_hash for tags that already exist on the project.
        existing_tags = {
            tag.name: tag.tag_hash
            for tag in sess.exec(
                select(ProjectTag).where(
                    ProjectTag.project_hash == project.project_hash
                )
            ).all()
        }

        def get_or_add_tag(tag: str) -> UUID:
            # Return the hash of an existing tag, or stage a new one in the
            # session (committed together with the tag associations below).
            if tag in existing_tags:
                return existing_tags[tag]
            tag_hash = uuid4()
            new_tag = ProjectTag(
                tag_hash=tag_hash,
                name=tag,
                project_hash=project.project_hash,
                description="",
            )
            sess.add(new_tag)
            existing_tags[tag] = new_tag.tag_hash
            return tag_hash

        with csv_file.open(newline="") as csvfile:
            reader = csv.reader(csvfile, delimiter=",")
            _ = next(reader)  # skip header
            for _, tags_list_str, du_hash_str in reader:
                du_hash = UUID(du_hash_str)
                # Translate a data_hash into its du_hash when applicable.
                if du_hash in data_to_du:
                    du_hash = data_to_du[du_hash]
                with logger.contextualize(du_hash=du_hash):
                    du = data_units.get(du_hash)
                    if not du:
                        # No f-string needed: the hash is carried via the
                        # contextualized `extra` field in FORMAT.
                        logger.warning("du_hash not found in the project. Skipping.")
                        continue
                    # Tags already attached to frame 0 of this data unit,
                    # so re-runs don't create duplicate associations.
                    du_existing_tags = set(
                        sess.exec(
                            select(ProjectTaggedDataUnit.tag_hash).where(
                                ProjectTaggedDataUnit.project_hash
                                == project.project_hash,
                                ProjectTaggedDataUnit.du_hash == du_hash,
                                ProjectTaggedDataUnit.frame == 0,
                            )
                        ).all()
                    )
                    logger.debug(
                        f"Data unit has {len(du_existing_tags)} tag(s) already"
                    )
                    for tag_str in regex.findall(tags_list_str):
                        tag_hash = get_or_add_tag(tag_str)
                        if tag_hash in du_existing_tags:
                            logger.info(f"Skipping {tag_str}, tag exists already.")
                            continue
                        logger.debug(f"Adding {tag_str}")
                        sess.add(
                            ProjectTaggedDataUnit(
                                project_hash=project.project_hash,
                                du_hash=du_hash,
                                frame=0,
                                tag_hash=tag_hash,
                            )
                        )
        # Single commit: new tags and associations land atomically.
        sess.commit()
if __name__ == "__main__":
    # Expose `main` as a typer CLI (argument parsing + --help) when run as a script.
    typer.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment