Created
February 29, 2024 14:09
-
-
Save frederik-encord/e3e469d4062a24589fcab4b816b0d6ec to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import os | |
import random | |
from concurrent.futures import ThreadPoolExecutor | |
from itertools import chain | |
from pathlib import Path | |
from tempfile import NamedTemporaryFile | |
from typing import cast | |
from uuid import UUID | |
import av | |
import requests | |
from encord.constants.enums import DataType | |
from encord.objects.common import Shape | |
from encord.objects.coordinates import BoundingBoxCoordinates, PolygonCoordinates | |
from encord.objects.ontology_labels_impl import LabelRowV2 | |
from encord.objects.ontology_object import Object | |
from encord.objects.ontology_object_instance import ObjectInstance | |
from encord.project import Project as EncordProject | |
from encord.user_client import EncordUserClient | |
from PIL import Image | |
from tqdm.auto import tqdm, trange | |
from yaml import safe_dump | |
# Ontology object shapes that `extract_relevant_objects` accepts.
_VALID_SHAPES = {Shape.POLYGON, Shape.BOUNDING_BOX}
# Non-alphanumeric characters that `safe_str` lets through.
_KEEP_CHARACTERS = {".", "_", " ", "-"}
# Per-character substitutions `safe_str` applies to the characters it keeps.
_REPLACE_CHARACTERS = {" ": "_"}
def safe_str(unsafe: str) -> str:
    """Return a sanitised copy of ``unsafe``.

    Only alphanumeric characters and those in ``_KEEP_CHARACTERS`` survive;
    surviving characters are then mapped through ``_REPLACE_CHARACTERS``.

    Raises:
        ValueError: if ``unsafe`` is not a ``str``.
    """
    if not isinstance(unsafe, str):
        raise ValueError(f"{unsafe} ({type(unsafe)}) not a string")

    kept: list[str] = []
    for char in unsafe:
        if char in _KEEP_CHARACTERS or char.isalnum():
            kept.append(_REPLACE_CHARACTERS.get(char, char))
    return "".join(kept).rstrip()
def init_project(ssh_key_path: str | Path, project_hash: UUID) -> EncordProject:
    """Open the Encord project ``project_hash`` using a private SSH key file.

    Args:
        ssh_key_path: Location of the private key file; ``~`` is expanded.
        project_hash: Identifier of the project to fetch.

    Returns:
        The project object returned by the Encord user client.
    """
    key_file = Path(ssh_key_path).expanduser().resolve()
    private_key = key_file.read_text()
    user_client = EncordUserClient.create_with_ssh_private_key(private_key)
    return user_client.get_project(project_hash=str(project_hash))
def extract_relevant_objects(
    project: EncordProject,
) -> tuple[list[Object], dict[str, int]]:
    """Pick the ontology objects of one supported shape and index them.

    If the ontology contains more than one shape from ``_VALID_SHAPES`` the
    user is asked on stdin which one to export; the question is repeated
    until the answer starts with one of the offered initials.

    Args:
        project: Project whose ontology is inspected.

    Returns:
        A tuple ``(objects, fh_to_idx)``: the ontology objects of the chosen
        shape sorted by ``uid``, and a mapping from each object's
        ``feature_node_hash`` to its position in that list.

    Raises:
        ValueError: if the ontology has no object with a supported shape.
    """
    ontology_objects = project.ontology_structure.objects
    # Sort so the prompt is stable; a set has no defined iteration order.
    candidates = sorted(
        {o.shape for o in ontology_objects}.intersection(_VALID_SHAPES),
        key=lambda s: str(s.value),
    )

    if not candidates:
        raise ValueError(
            "Encord project does not have any valid object types in the ontology. Make sure that there's either polygons or bounding boxes in the dataset"
        )

    if len(candidates) == 1:
        shape = candidates[0]
    else:
        # Key on the first letter of the enum *value*; the enum member itself
        # is not guaranteed to be subscriptable.
        by_initial = {str(s.value)[0].lower(): s for s in candidates}
        opt_str = ", ".join(f"({key}): {s}" for key, s in by_initial.items())
        prompt = f"Please select one of: {opt_str}"
        shape = None
        while shape is None:
            answer = input(prompt).strip().lower()
            # Accept both "p" and "polygon"; anything else asks again.
            shape = by_initial.get(answer[:1])

    objects = sorted(
        (o for o in ontology_objects if o.shape == shape),
        key=lambda o: o.uid,
    )
    fh_to_idx = {o.feature_node_hash: i for i, o in enumerate(objects)}
    return objects, fh_to_idx
def init_labels(
    project: EncordProject,
    batch_size: int = 1000,
) -> list[LabelRowV2]:
    """List the project's label rows and initialise their labels in bundles.

    Args:
        project: Project to read label rows from.
        batch_size: Number of label rows initialised per bundle.

    Returns:
        The list obtained from ``project.list_label_rows_v2()``.
    """
    rows = project.list_label_rows_v2()
    batches = [
        rows[offset : offset + batch_size]
        for offset in range(0, len(rows), batch_size)
    ]
    for batch in batches:
        # One bundle per batch groups the initialise calls together.
        with project.create_bundle() as bundle:
            for row in batch:
                row.initialise_labels(bundle=bundle)
    return rows
def download_image(url: str, destination: Path, timeout: float = 60.0) -> Path:
    """Download ``url`` to ``destination`` unless the file already exists.

    Args:
        url: Address of the image.
        destination: File to write; missing parent directories are created.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.

    Returns:
        ``destination``.

    Raises:
        requests.HTTPError: if the server answers with an error status, so
            that an error page is never stored as an image.
    """
    if destination.is_file():
        return destination

    # exist_ok avoids a check-then-create race when several worker threads
    # download into the same directory.
    destination.parent.mkdir(parents=True, exist_ok=True)

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    destination.write_bytes(response.content)
    return destination
def image_path_to_label_path(image_path: Path) -> Path:
    """Map an image path to its label path.

    The nearest ancestor directory named ``images`` is swapped for a sibling
    ``labels`` directory and the file suffix becomes ``.txt``, e.g.
    ``root/images/train/a.jpg`` -> ``root/labels/train/a.txt``.

    Raises:
        ValueError: if no ancestor of ``image_path`` is named ``images``.
    """
    images_dir = next(
        (parent for parent in image_path.parents if parent.name == "images"),
        None,
    )
    if images_dir is None:
        # A bare StopIteration here would surface as an opaque RuntimeError
        # when this runs inside an executor's result generator.
        raise ValueError(f"{image_path} has no parent directory named 'images'")

    relative = image_path.relative_to(images_dir)
    return (images_dir.parent / "labels" / relative).with_suffix(".txt")
SplitDict = dict[str, float] | |
""" | |
Dict which defines dataset splits. Each key becomes a `{key}.txt` file listing that split's images.
Values are sampling weights (fractions of the dataset) and should sum to one.
""" | |
class ProjectConverter:
    """Export an Encord project as a YOLO-style dataset and read predictions back.

    Data is written below ``dest / <sanitised project title>`` into
    ``images/<split>/`` and ``labels/<split>/`` plus one ``<split>.txt`` file
    per split that lists the exported images.
    """

    def __init__(self, project: EncordProject, dest: Path) -> None:
        """Create the output directories and pick the ontology objects to export.

        Args:
            project: Project to export.
            dest: Parent directory; a sub-directory named after the
                sanitised project title is created inside it.
        """
        self.project = project
        self.ontology_objects, self.fh_to_cls_idx = extract_relevant_objects(project)
        self.title = safe_str(project.title)
        self.dest = dest / self.title
        self.dest.mkdir(exist_ok=True, parents=True)
        print(f"Storing data at {self.dest}")
        self.image_dir = self.dest / "images"
        self.label_dir = self.dest / "labels"
        self.image_dir.mkdir(exist_ok=True)
        self.label_dir.mkdir(exist_ok=True)
        # Filled by `select_label_rows`.
        self.label_rows: list[LabelRowV2] = []

    # ------------------------------------------------------------------ #
    # Small pure helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _box_to_yolo(
        top_left_x: float, top_left_y: float, width: float, height: float
    ) -> tuple[float, float, float, float]:
        """Turn a top-left anchored box into YOLO's ``(x_center, y_center, w, h)``."""
        return top_left_x + width / 2, top_left_y + height / 2, width, height

    @staticmethod
    def _parse_prediction_line(
        pred_line: str,
    ) -> tuple[int, float, float, float, float, float] | None:
        """Parse one ``cls x_center y_center w h conf`` prediction row.

        Returns:
            ``(cls, top_left_x, top_left_y, width, height, confidence)`` or
            ``None`` for a blank line.

        Raises:
            ValueError: if the row does not have exactly six fields (for
                example when predictions were saved without confidences).
        """
        fields = pred_line.split()
        if not fields:
            return None
        if len(fields) != 6:
            raise ValueError(
                f"Expected 'cls x y w h conf' but got {len(fields)} fields: {pred_line!r}"
            )
        cls = int(fields[0])
        x_center, y_center, width, height, confidence = map(float, fields[1:])
        return (
            cls,
            x_center - width / 2,
            y_center - height / 2,
            width,
            height,
            confidence,
        )

    @staticmethod
    def _frame_key_for_prediction(
        file_hash: str, suffix_parts: list[str], row_data_hash: str
    ) -> int | str:
        """Work out which frame a prediction file belongs to.

        Args:
            file_hash: First ``_``-separated part of the file stem.
            suffix_parts: Remaining parts of the file stem.
            row_data_hash: ``data_hash`` of the label row the file maps to.

        Returns:
            A frame number, or an image hash identifying a frame.
        """
        if suffix_parts:
            suffix = suffix_parts[0]
            try:
                return int(suffix)  # "<data_hash>_<frame>" written for videos
            except ValueError:
                if suffix != row_data_hash:
                    return file_hash
                raise
        if file_hash != row_data_hash:
            # "<image_hash>" written for an image inside an image group: the
            # hash names the frame; falling back to 0 would pile every
            # prediction onto the first image.
            return file_hash
        return 0

    def _has_relevant_instances(self, source) -> bool:
        """Tell whether ``source`` (label row or frame view) holds exported objects."""
        return any(
            source.get_object_instances(filter_ontology_object=o)
            for o in self.ontology_objects
        )

    # ------------------------------------------------------------------ #
    # Export
    # ------------------------------------------------------------------ #
    def do_it(
        self,
        splits: SplitDict,
        batch_size: int = 20,
        seed: int = 42,
        ignore_empty: bool = True,
    ) -> Path:
        """Export the whole project and return the dataset yaml file.

        Args:
            splits: Split name -> sampling weight.
            batch_size: Label rows handed to the download pool at a time.
            seed: Seed for the global ``random`` module (controls the split).
            ignore_empty: Skip label rows without any exported object.
        """
        random.seed(seed)
        self.select_label_rows(ignore_empty)

        file_defs: dict[str, str] = {}
        for split_name, label_rows in self.split_label_rows(splits).items():
            split_file_path = self.dest / f"{split_name}.txt"
            file_paths = self.write_label_rows(
                label_rows, split=split_name, batch_size=batch_size
            )
            with split_file_path.open("w") as f:
                f.writelines(file_paths)
            file_defs[split_name] = split_file_path.relative_to(self.dest).as_posix()
            print(split_file_path)

        return self.write_text_files(file_defs)

    def split_label_rows(self, splits: SplitDict) -> dict[str, list[LabelRowV2]]:
        """Randomly assign every selected label row to one split.

        Each row is drawn independently with the given weights, so split sizes
        are approximate and a split that receives no row is absent from the
        result.
        """
        names = list(splits.keys())
        weights = [splits[n] for n in names]
        choices = random.choices(names, weights=weights, k=len(self.label_rows))
        ret: dict[str, list[LabelRowV2]] = {}
        for choice, lr in zip(choices, self.label_rows):
            ret.setdefault(choice, []).append(lr)
        return ret

    def write_text_files(self, file_defs: dict[str, str]) -> Path:
        """Write ``<title>.yaml`` into the current working directory.

        The yaml holds the absolute dataset ``path``, the class index -> name
        mapping under ``names`` and one entry per split from ``file_defs``.
        """
        lookup = {o.feature_node_hash: o for o in self.ontology_objects}
        yaml_dict = {
            "path": self.dest.expanduser().resolve().as_posix(),
            "names": {v: lookup[k].name for k, v in self.fh_to_cls_idx.items()},
            **file_defs,
        }
        dataset_def = Path.cwd() / f"{self.title}.yaml"
        dataset_def.write_text(safe_dump(yaml_dict))
        return dataset_def

    def select_label_rows(self, ignore_empty: bool) -> None:
        """Load the project's label rows into ``self.label_rows``.

        With ``ignore_empty`` set, rows without any exported object are
        dropped: rows that only carry other labels are assumed not to be
        deliberate negative examples.
        """
        label_rows = init_labels(self.project)
        if ignore_empty:
            label_rows = [
                lr for lr in label_rows if self._has_relevant_instances(lr)
            ]
        self.label_rows = label_rows

    def write_label_rows(
        self, label_rows, split: str, batch_size: int = 20
    ) -> list[str]:
        """Download images and write label files for ``label_rows``.

        Returns:
            One newline-terminated ``./images/<split>/<file>`` line per
            exported image, relative to the dataset directory.
        """

        def write(lr):
            if lr.data_type == DataType.VIDEO:
                return self.cache_video(lr, split=split)
            return self.cache_image(lr, split=split)

        (self.image_dir / split).mkdir(exist_ok=True)
        (self.label_dir / split).mkdir(exist_ok=True)

        paths: list[list[Path]] = []
        # One pool for all batches; each batch still finishes before the
        # progress bar advances.
        with ThreadPoolExecutor(8) as exe:
            for step in trange(
                0, len(label_rows), batch_size, desc=f"Downloading Data [{split}]"
            ):
                paths += list(exe.map(write, label_rows[step : step + batch_size]))

        # "\n" rather than os.linesep: the lines go through a text-mode file,
        # which already translates newlines for the platform.
        return [
            "./" + r.relative_to(self.dest).as_posix() + "\n"
            for path in paths
            for r in path
        ]

    def obj_to_txt_row(self, obj: Object, annotation: ObjectInstance.Annotation):
        """Format one annotation as a label-file row.

        Polygons become ``cls x1 y1 x2 y2 ...``; bounding boxes become
        ``cls x_center y_center width height``.

        Raises:
            ValueError: if ``obj`` has an unsupported shape.
        """
        cls = self.fh_to_cls_idx[obj.feature_node_hash]
        if obj.shape == Shape.POLYGON:
            c = cast(PolygonCoordinates, annotation.coordinates)
            coords = " ".join(f"{p.x:.4f} {p.y:.4f}" for p in c.values)
        elif obj.shape == Shape.BOUNDING_BOX:
            c = cast(BoundingBoxCoordinates, annotation.coordinates)
            # The coordinates are anchored at the top-left corner whereas the
            # label row expects the box centre.
            x_center, y_center, width, height = self._box_to_yolo(
                c.top_left_x, c.top_left_y, c.width, c.height
            )
            coords = f"{x_center} {y_center} {width} {height}"
        else:
            raise ValueError(f"Shape {obj.shape} is not supported")
        return f"{cls} {coords}\n"

    def cache_label(self, view: LabelRowV2.FrameView, img_dest: Path) -> Path:
        """Write the label file belonging to ``img_dest`` and return its path."""
        label_dest = image_path_to_label_path(img_dest)
        lines: list[str] = [
            self.obj_to_txt_row(obj, ins.get_annotation(frame=view.frame))
            for obj in self.ontology_objects
            for ins in view.get_object_instances(filter_ontology_object=obj)
        ]
        with label_dest.open("w") as f:
            f.writelines(lines)
        return label_dest

    def cache_video(self, lr: LabelRowV2, split: str, byte_size=1024) -> list[Path]:
        """Export every frame of a video that carries at least one object.

        The video is downloaded to a temporary file, decoded, and the frames
        of interest are stored as ``<data_hash>_<frame:06d>.jpg`` together
        with their label files.

        Args:
            lr: Label row of the video.
            split: Name of the split directory to write into.
            byte_size: Chunk size used while streaming the download.

        Returns:
            Paths of the written frame images.

        Raises:
            ConnectionError: if the download does not answer with status 200.
        """
        video, _ = self.project.get_data(lr.data_hash, get_signed_url=True)
        if video is None:
            return []

        valid_frames = {
            fr.frame for fr in lr.get_frame_views() if len(fr.get_object_instances())
        }
        if not valid_frames:
            # Nothing to export, so skip the download altogether.
            return []
        last_valid_frame = max(valid_frames)

        file_name_format = "{data_hash}_{frame:06d}.jpg"
        data_hash = lr.data_hash
        ret: list[Path] = []
        tmp_path: Path | None = None
        try:
            # delete=False so the file can be reopened by name once closed.
            with NamedTemporaryFile("wb", suffix=".mp4", delete=False) as fp:
                tmp_path = Path(fp.name)
                with requests.get(video["file_link"], stream=True) as r:
                    if r.status_code != 200:
                        raise ConnectionError(
                            f"Something happened, couldn't download file from: {video['file_link']}"
                        )
                    for chunk in tqdm(
                        r.iter_content(chunk_size=byte_size),
                        desc="Downloading video",
                        leave=False,
                    ):
                        if chunk:  # filter out keep-alive new chunks
                            fp.write(chunk)

            with av.open(str(tmp_path)) as container:
                for frame_num, frame in enumerate(container.decode(video=0)):
                    if frame_num > last_valid_frame:
                        break  # no labelled frame left to decode
                    if frame_num not in valid_frames:
                        continue
                    image: Image.Image = frame.to_image().convert("RGB")
                    frame_destination = (
                        self.image_dir
                        / split
                        / file_name_format.format(data_hash=data_hash, frame=frame_num)
                    )
                    image.save(frame_destination)
                    self.cache_label(lr.get_frame_view(frame_num), frame_destination)
                    ret.append(frame_destination)
        finally:
            # Always remove the temporary video, also when something failed.
            if tmp_path is not None:
                tmp_path.unlink(missing_ok=True)
        return ret

    def cache_image(self, lr: LabelRowV2, split: str) -> list[Path]:
        """Export a single image or the labelled images of an image group.

        Returns:
            Paths of the images written (or already present) for ``lr``.
        """
        img, frames = self.project.get_data(lr.data_hash, get_signed_url=True)
        split_img_dir = self.image_dir / split

        if lr.data_type == DataType.IMAGE:
            if img is None:
                return []
            img_dest = (
                split_img_dir / f"{lr.data_hash}{Path(lr.data_title).suffix or '.jpg'}"
            )
            img_dest = download_image(img["file_link"], img_dest)
            self.cache_label(lr.get_frame_view(0), img_dest)
            return [img_dest]

        if frames is None:
            return []

        ret: list[Path] = []
        for frame in frames:
            image_hash = frame["image_hash"]
            view = lr.get_frame_view(image_hash)
            if not self._has_relevant_instances(view):
                continue
            img_dest = (
                split_img_dir / f"{image_hash}{Path(frame['title']).suffix or '.jpg'}"
            )
            download_image(frame["file_link"], img_dest)
            self.cache_label(view, img_dest)
            ret.append(img_dest)
        return ret

    # ------------------------------------------------------------------ #
    # Import of predictions
    # ------------------------------------------------------------------ #
    def set_boxes_from_predictions(self, lr: LabelRowV2, frame: int, pred_file: Path):
        """Add every box of ``pred_file`` to ``lr`` on ``frame``.

        Each non-blank row must read ``cls x_center y_center w h conf``; the
        centre is converted back to the top-left anchor the coordinates use.

        Raises:
            ValueError: if a row is malformed.
        """
        with pred_file.open("r") as f:
            for pred_line in f:
                parsed = self._parse_prediction_line(pred_line)
                if parsed is None:
                    continue
                cls, x, y, w, h, cnf = parsed
                instance = self.ontology_objects[cls].create_instance()
                coords = BoundingBoxCoordinates(
                    top_left_x=x, top_left_y=y, width=w, height=h
                )
                instance.set_for_frames(coords, frames=frame, confidence=cnf)
                lr.add_object_instance(instance)

    def create_encord_json_predictions(
        self, pred_dir: Path, out_file: Path | None = None
    ) -> Path:
        """
        Creates a json file with the predictions from the given directory.
        Note, that this won't work if you don't have the `--save-conf` flag set
        when you run `python detect.py ...`.

        Args:
            pred_dir: The directory with the prediction files. For the YoloV9 examples, this is ./runs/detect/exp*/labels
            out_file: The json file to store the result in. If not provided, `pred_dir/encord.json` is used. If it is
                an existing directory, `{project_title}_{name of pred_dir's parent}.json` is created inside it.

        Returns: the file where the json is stored.
        """
        # Map every hash that can prefix a prediction file name to its label row.
        data_lookup = {lr.data_hash: lr.label_hash for lr in self.label_rows}
        for lr in self.label_rows:
            if lr.data_type != DataType.IMG_GROUP:
                continue
            data_lookup.update(
                {fv.image_hash: lr.label_hash for fv in lr.get_frame_views()}
            )

        # Fetch the rows again, initialised with empty feature-hash filters,
        # and attach the predictions to those fresh copies.
        label_hashes = [lr.label_hash for lr in self.label_rows if lr.label_hash]
        fresh_label_rows = self.project.list_label_rows_v2(label_hashes=label_hashes)
        for start in range(0, len(fresh_label_rows), 1000):
            with self.project.create_bundle() as bundle:
                for lr in fresh_label_rows[start : start + 1000]:
                    lr.initialise_labels(
                        bundle=bundle,
                        include_object_feature_hashes=set(),
                        include_classification_feature_hashes=set(),
                    )
        fresh_lookup = {lr.label_hash: lr for lr in fresh_label_rows}

        for pred_file in pred_dir.iterdir():
            if pred_file.suffix != ".txt":
                continue
            file_hash, *suffix_parts = pred_file.stem.split("_")
            label_hash = data_lookup.get(file_hash)
            if label_hash is None:
                continue
            new_lr = fresh_lookup[label_hash]
            frame_key = self._frame_key_for_prediction(
                file_hash, suffix_parts, new_lr.data_hash
            )
            self.set_boxes_from_predictions(
                new_lr, new_lr.get_frame_view(frame_key).frame, pred_file
            )

        if out_file is None:
            out_file = pred_dir / "encord.json"
        elif out_file.is_dir():
            out_file = out_file / f"{self.title}_{pred_dir.parent.name}.json"
        out_file.write_text(
            json.dumps([lr.to_encord_dict() for lr in fresh_label_rows])
        )
        return out_file
if __name__ == "__main__":
    # Example run: export one project into ./datasets with an 80/10/10 split.
    # Alternative project: "8a47fd61-b98f-4e24-ac61-8250248b9adf" (dental polygons)
    project = init_project(
        "~/.ssh/encord_frederik",
        UUID("d5b2e9a8-0dc6-42c3-af81-46cced183eed"),  # Micro all sorts
    )
    path = Path("./datasets")
    path.mkdir(exist_ok=True)
    converter = ProjectConverter(
        project,
        path,
    )
    converter.do_it(splits={"train": 0.8, "val": 0.1, "test": 0.1})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment