import os
import random
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import cast
from uuid import UUID

import av
import requests
from encord.constants.enums import DataType
from encord.objects.common import Shape
from encord.objects.coordinates import BoundingBoxCoordinates, PolygonCoordinates
from encord.objects.ontology_labels_impl import LabelRowV2
from encord.objects.ontology_object import Object
from encord.objects.ontology_object_instance import ObjectInstance
from encord.project import Project as EncordProject
from encord.user_client import EncordUserClient
from PIL import Image
from tqdm.auto import tqdm, trange
from yaml import safe_dump

_VALID_SHAPES = {Shape.POLYGON, Shape.BOUNDING_BOX}


def init_project(ssh_key_path: str | Path, project_hash: UUID) -> EncordProject:
    ssh_key_path = Path(ssh_key_path)
    client = EncordUserClient.create_with_ssh_private_key(
        ssh_key_path.expanduser().resolve().read_text()
    )
    return client.get_project(project_hash=str(project_hash))
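
# Usage (illustrative; the key path and project hash below are placeholders):
#     project = init_project("~/.ssh/id_ed25519", UUID("00000000-0000-0000-0000-000000000000"))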


def extract_relevant_objects(
    project: EncordProject,
) -> tuple[list[Object], dict[str, int]]:
    valid_object_types = {
        o.shape for o in project.ontology_structure.objects
    }.intersection(_VALID_SHAPES)
    n_valid = len(valid_object_types)
    shape: Shape | None = None
    if n_valid == 0:
        raise ValueError(
            "Encord project does not have any valid object types in the ontology. "
            "Make sure the ontology contains either polygons or bounding boxes."
        )
    elif n_valid == 1:
        shape = next(iter(valid_object_types))
    else:
        options = list(valid_object_types)
        opt_str = ", ".join(f"({o[0]}): {o}" for o in options)
        selection = input(f"Please select one of: {opt_str} ")
        shape = next(o for o in options if o[0] == selection.strip().lower())
    objects = sorted(
        [o for o in project.ontology_structure.objects if o.shape == shape],
        key=lambda o: o.uid,
    )
    fh_to_idx = {o.feature_node_hash: i for i, o in enumerate(objects)}
    return objects, fh_to_idx
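
# Example (illustrative): for an ontology with polygon objects "cat" and "dog",
# this returns ([<cat Object>, <dog Object>],
#               {"<cat feature hash>": 0, "<dog feature hash>": 1}).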


def init_labels(
    project: EncordProject,
    batch_size: int = 1000,
) -> list[LabelRowV2]:
    label_rows = project.list_label_rows_v2()
    for start in range(0, len(label_rows), batch_size):
        with project.create_bundle() as bundle:
            for lr in label_rows[start : start + batch_size]:
                lr.initialise_labels(bundle=bundle)
    return label_rows
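
# Bundling defers the `initialise_labels` calls so each batch of rows is
# fetched with far fewer API round-trips than one request per row.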


def download_image(url: str, destination: Path) -> Path:
    if destination.is_file():
        return destination
    destination.parent.mkdir(parents=True, exist_ok=True)
    r = requests.get(url)
    r.raise_for_status()
    destination.write_bytes(r.content)
    return destination


def image_path_to_label_path(image_path: Path) -> Path:
    parents = image_path.parents
    images_idx = next(i for i, p in enumerate(parents) if p.name == "images")
    before = parents[images_idx + 1 :]
    images = parents[images_idx]
    label_path = Path("labels") / image_path.relative_to(images)
    if before:
        label_path = before[0] / label_path
    return label_path.with_suffix(".txt")
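
# Example mapping (illustrative):
#   datasets/my_project/images/train/abc.jpg -> datasets/my_project/labels/train/abc.txt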


SplitDict = dict[str, float]
"""
Dict defining the dataset splits. Each key produces a `{key}.txt` split file,
e.g. {"train": 0.8, "val": 0.1, "test": 0.1}.
Values are fractions and should sum to one.
"""


class ProjectConverter:
    def __init__(self, project: EncordProject, dest: Path) -> None:
        self.project = project
        self.ontology_objects, self.fh_to_cls_idx = extract_relevant_objects(project)
        self.title = (
            project.title.replace(" ", "_").replace("(", "").replace(")", "").lower()
        )
        self.dest = dest / self.title
        self.dest.mkdir(exist_ok=True, parents=True)
        print(f"Storing data at {self.dest}")
        self.image_dir = self.dest / "images"
        self.label_dir = self.dest / "labels"
        self.image_dir.mkdir(exist_ok=True)
        self.label_dir.mkdir(exist_ok=True)
        self.label_rows: list[LabelRowV2] = []

    def do_it(self, splits: SplitDict, batch_size: int = 20, seed: int = 42) -> Path:
        """
        Converts the project and returns the path to the generated YOLO-style
        dataset YAML file.
        """
        random.seed(seed)
        self.select_label_rows()
        file_defs: dict[str, str] = {}
        subsets = self.split_label_rows(splits)
        for split_name, label_rows in subsets.items():
            split_file_path = self.dest / f"{split_name}.txt"
            file_paths = self.write_label_rows(
                label_rows, split=split_name, batch_size=batch_size
            )
            with split_file_path.open("w") as f:
                f.writelines(file_paths)
            file_defs[split_name] = split_file_path.relative_to(self.dest).as_posix()
            print(split_file_path)
        return self.write_text_files(file_defs)

    def split_label_rows(self, splits: SplitDict) -> dict[str, list[LabelRowV2]]:
        names = list(splits.keys())
        values = [splits[n] for n in names]
        # Sampling with `random.choices` yields approximately, not exactly,
        # the requested proportions.
        choices = random.choices(names, weights=values, k=len(self.label_rows))
        ret: dict[str, list[LabelRowV2]] = {}
        for choice, lr in zip(choices, self.label_rows):
            ret.setdefault(choice, []).append(lr)
        return ret

    def write_text_files(self, file_defs: dict[str, str]) -> Path:
        lookup = {o.feature_node_hash: o for o in self.ontology_objects}
        yaml_dict = {
            "path": self.dest.expanduser().resolve().as_posix(),
            "names": {v: lookup[k].name for k, v in self.fh_to_cls_idx.items()},
            **file_defs,
        }
        dataset_def = Path.cwd() / f"{self.title}.yaml"
        dataset_def.write_text(safe_dump(yaml_dict))
        return dataset_def
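
    # The resulting YAML (illustrative content) looks like:
    #   path: /abs/path/to/datasets/my_project
    #   names: {0: cat, 1: dog}
    #   test: test.txt
    #   train: train.txt
    #   val: val.txt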

    def select_label_rows(self):
        label_rows = init_labels(self.project)
        # Iterate in reverse so items can be removed from the list directly.
        for i in range(len(label_rows) - 1, -1, -1):
            lr = label_rows[i]
            relevant_instances = list(
                chain(
                    *[
                        lr.get_object_instances(filter_ontology_object=o)
                        for o in self.ontology_objects
                    ]
                )
            )
            if len(relevant_instances) == 0 and len(lr.get_object_instances()) != 0:
                # Assume that label rows containing only irrelevant labels were
                # not explicitly chosen as negative examples; drop them from
                # the dataset build.
                label_rows.pop(i)
        self.label_rows = label_rows

    def write_label_rows(
        self, label_rows: list[LabelRowV2], split: str, batch_size: int = 20
    ) -> list[str]:
        def write(lr: LabelRowV2) -> list[Path]:
            if lr.data_type == DataType.VIDEO:
                return self.cache_video(lr, split=split)
            else:
                return self.cache_image(lr, split=split)

        (self.image_dir / split).mkdir(exist_ok=True)
        (self.label_dir / split).mkdir(exist_ok=True)
        paths: list[list[Path]] = []
        for step in trange(
            0, len(label_rows), batch_size, desc=f"Downloading Data [{split}]"
        ):
            with ThreadPoolExecutor(8) as exe:
                paths += list(exe.map(write, label_rows[step : step + batch_size]))
        return [
            "./" + r.relative_to(self.dest).as_posix() + os.linesep
            for path in paths
            for r in path
        ]

    def obj_to_txt_row(
        self, obj: Object, annotation: ObjectInstance.Annotation
    ) -> str:
        cls = self.fh_to_cls_idx[obj.feature_node_hash]
        if obj.shape == Shape.POLYGON:
            c = cast(PolygonCoordinates, annotation.coordinates)
            coords = " ".join(f"{p.x:.4f} {p.y:.4f}" for p in c.values)
        elif obj.shape == Shape.BOUNDING_BOX:
            c = cast(BoundingBoxCoordinates, annotation.coordinates)
            # YOLO boxes are (x_center, y_center, width, height), all normalized,
            # while Encord stores the top-left corner, so shift to the center.
            coords = (
                f"{c.top_left_x + c.width / 2:.4f} "
                f"{c.top_left_y + c.height / 2:.4f} "
                f"{c.width:.4f} {c.height:.4f}"
            )
        else:
            raise ValueError(f"Shape {obj.shape} is not supported")
        return f"{cls} {coords}{os.linesep}"

    def cache_label(self, view: LabelRowV2.FrameView, img_dest: Path) -> Path:
        label_dest = image_path_to_label_path(img_dest)
        lines: list[str] = [
            self.obj_to_txt_row(obj, ins.get_annotation(frame=view.frame))
            for obj in self.ontology_objects
            for ins in view.get_object_instances(filter_ontology_object=obj)
        ]
        with label_dest.open("w") as f:
            f.writelines(lines)
        return label_dest

    def cache_video(
        self, lr: LabelRowV2, split: str, byte_size: int = 1024
    ) -> list[Path]:
        video, _ = self.project.get_data(lr.data_hash, get_signed_url=True)
        if video is None:
            return []
        valid_frames = {
            fr.frame for fr in lr.get_frame_views() if len(fr.get_object_instances())
        }
        file_name_format = "{data_hash}_{frame:06d}.jpg"
        ret: list[Path] = []
        with NamedTemporaryFile("wb", suffix=".mp4", delete=False) as fp:
            # Download the video into a temporary file.
            r = requests.get(video["file_link"], stream=True)
            if r.status_code != 200:
                raise ConnectionError(
                    f"Couldn't download video from: {video['file_link']}"
                )
            for chunk in tqdm(
                r.iter_content(chunk_size=byte_size),
                desc="Downloading video",
                leave=False,
            ):
                if chunk:  # filter out keep-alive chunks
                    fp.write(chunk)
            fp.flush()
        container = av.open(fp.name)
        data_hash = lr.data_hash
        for frame_num, frame in enumerate(container.decode(video=0)):
            if frame_num not in valid_frames:
                continue
            image: Image.Image = frame.to_image().convert("RGB")
            frame_destination = (
                self.image_dir
                / split
                / file_name_format.format(data_hash=data_hash, frame=frame_num)
            )
            image.save(frame_destination)
            self.cache_label(lr.get_frame_view(frame_num), frame_destination)
            ret.append(frame_destination)
        container.close()
        os.unlink(fp.name)  # clean up the temporary download
        return ret
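
    # Note: video frames without any object instances are skipped entirely, so
    # no negative video frames are exported.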

    def cache_image(self, lr: LabelRowV2, split: str) -> list[Path]:
        img, frames = self.project.get_data(lr.data_hash, get_signed_url=True)
        split_img_dir = self.image_dir / split
        if lr.data_type == DataType.IMAGE:
            if img is None:
                return []
            img_dest = (
                split_img_dir / f"{lr.data_hash}{Path(lr.data_title).suffix or '.jpg'}"
            )
            img_dest = download_image(img["file_link"], img_dest)
            self.cache_label(lr.get_frame_view(0), img_dest)
            return [img_dest]
        else:
            if frames is None:
                return []
            ret: list[Path] = []
            for frame in frames:
                image_hash = frame["image_hash"]
                view = lr.get_frame_view(image_hash)
                relevant_instances = list(
                    chain(
                        *[
                            view.get_object_instances(filter_ontology_object=o)
                            for o in self.ontology_objects
                        ]
                    )
                )
                if len(relevant_instances) == 0:
                    continue
                img_dest = (
                    split_img_dir
                    / f"{image_hash}{Path(frame['title']).suffix or '.jpg'}"
                )
                download_image(frame["file_link"], img_dest)
                self.cache_label(view, img_dest)
                ret.append(img_dest)
            return ret


if __name__ == "__main__":
    project = init_project(
        "~/.ssh/encord_frederik",
        # UUID("8a47fd61-b98f-4e24-ac61-8250248b9adf"),  # dental polygons
        UUID("d5b2e9a8-0dc6-42c3-af81-46cced183eed"),  # Micro all sorts
    )
    path = Path("./datasets")
    path.mkdir(exist_ok=True)
    converter = ProjectConverter(project, path)
    converter.do_it(splits={"train": 0.8, "val": 0.1, "test": 0.1})
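
# After conversion, the YAML written next to this script can be used for
# training, e.g. with Ultralytics YOLO (a sketch assuming the `ultralytics`
# package is installed; the checkpoint name below is illustrative):
#     from ultralytics import YOLO
#     YOLO("yolov8n.pt").train(data="micro_all_sorts.yaml", epochs=10)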