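"""Convert the polygon / bounding-box annotations of an Encord project into a
YOLO-style dataset: images/ and labels/ directories, one txt file per split,
and a dataset yaml that trainers such as Ultralytics YOLO can consume.

Gist by @frederik-encord, created February 29, 2024.
"""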
import os
import random
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import cast
from uuid import UUID
import av
import requests
from encord.constants.enums import DataType
from encord.objects.common import Shape
from encord.objects.coordinates import BoundingBoxCoordinates, PolygonCoordinates
from encord.objects.ontology_labels_impl import LabelRowV2
from encord.objects.ontology_object import Object
from encord.objects.ontology_object_instance import ObjectInstance
from encord.project import Project as EncordProject
from encord.user_client import EncordUserClient
from PIL import Image
from tqdm.auto import tqdm, trange
from yaml import safe_dump

_VALID_SHAPES = {Shape.POLYGON, Shape.BOUNDING_BOX}


def init_project(ssh_key_path: str | Path, project_hash: UUID) -> EncordProject:
    """Authenticate with the given SSH private key and return the Encord project."""
    ssh_key_path = Path(ssh_key_path)
    client = EncordUserClient.create_with_ssh_private_key(
        ssh_key_path.expanduser().resolve().read_text()
    )
    return client.get_project(project_hash=str(project_hash))
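
# Usage sketch (the key path and project hash below are placeholders, not real):
#   project = init_project("~/.ssh/id_ed25519", UUID("00000000-0000-0000-0000-000000000000"))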

def extract_relevant_objects(
    project: EncordProject,
) -> tuple[list[Object], dict[str, int]]:
    """Pick the polygon or bounding-box objects from the ontology and map each
    feature_node_hash to a stable class index."""
    valid_object_types = {
        o.shape for o in project.ontology_structure.objects
    }.intersection(_VALID_SHAPES)
    n_valid = len(valid_object_types)
    shape: Shape | None = None
    if n_valid == 0:
        raise ValueError(
            "Encord project does not have any valid object types in the ontology. "
            "Make sure that the ontology contains either polygons or bounding boxes."
        )
    elif n_valid == 1:
        shape = next(iter(valid_object_types))
    else:
        options = list(valid_object_types)
        # Use `.value` so the prompt shows "polygon"/"bounding_box" rather than
        # the enum repr, independent of the Python version.
        opt_str = ", ".join(f"({o.value[0]}): {o.value}" for o in options)
        selection = input(f"Please select one of: {opt_str} ").strip().lower()
        shape = next(o for o in options if o.value[0] == selection[:1])
    objects = sorted(
        [o for o in project.ontology_structure.objects if o.shape == shape],
        key=lambda o: o.uid,
    )
    fh_to_idx = {o.feature_node_hash: i for i, o in enumerate(objects)}
    return objects, fh_to_idx
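
# With, e.g., two polygon objects "cat" and "dog" in the ontology (hypothetical
# names), this returns them sorted by uid together with a mapping like
# {"<cat feature hash>": 0, "<dog feature hash>": 1}, which become the YOLO class ids.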

def init_labels(
    project: EncordProject,
    batch_size: int = 1000,
) -> list[LabelRowV2]:
    """Initialise all label rows, bundling API calls in batches for speed."""
    label_rows = project.list_label_rows_v2()
    for start in range(0, len(label_rows), batch_size):
        with project.create_bundle() as bundle:
            for lr in label_rows[start : start + batch_size]:
                lr.initialise_labels(bundle=bundle)
    return label_rows

def download_image(url: str, destination: Path) -> Path:
    if destination.is_file():
        return destination
    destination.parent.mkdir(parents=True, exist_ok=True)
    response = requests.get(url)
    response.raise_for_status()
    destination.write_bytes(response.content)
    return destination

def image_path_to_label_path(image_path: Path) -> Path:
    """Mirror the YOLO convention: swap the `images` directory for `labels` and
    the image suffix for `.txt` (e.g. images/train/a.jpg -> labels/train/a.txt)."""
    parents = image_path.parents
    images_idx = next(i for i, p in enumerate(parents) if p.name == "images")
    before = parents[images_idx + 1 :]
    images = parents[images_idx]
    label_path = Path("labels") / image_path.relative_to(images)
    if before:
        label_path = before[0] / label_path
    return label_path.with_suffix(".txt")

SplitDict = dict[str, float]
"""
Dict defining the dataset splits. Each key becomes a `{key}.txt` file listing
that split's images; the values are fractions and should sum to one.
"""


class ProjectConverter:
    def __init__(self, project: EncordProject, dest: Path) -> None:
        self.project = project
        self.ontology_objects, self.fh_to_cls_idx = extract_relevant_objects(project)
        self.title = (
            project.title.replace(" ", "_").replace("(", "").replace(")", "").lower()
        )
        self.dest = dest / self.title
        self.dest.mkdir(exist_ok=True, parents=True)
        print(f"Storing data at {self.dest}")
        self.image_dir = self.dest / "images"
        self.label_dir = self.dest / "labels"
        self.image_dir.mkdir(exist_ok=True)
        self.label_dir.mkdir(exist_ok=True)
        self.label_rows: list[LabelRowV2] = []

    def do_it(self, splits: SplitDict, batch_size: int = 20, seed: int = 42) -> Path:
        """
        Build the dataset and return the path to the YOLO/Roboflow-style dataset yaml file.
        """
        random.seed(seed)
        self.select_label_rows()
        file_defs: dict[str, str] = {}
        subsets = self.split_label_rows(splits)
        for split_name, label_rows in subsets.items():
            split_file_path = self.dest / f"{split_name}.txt"
            file_paths = self.write_label_rows(
                label_rows, split=split_name, batch_size=batch_size
            )
            with split_file_path.open("w") as f:
                f.writelines(file_paths)
            file_defs[split_name] = split_file_path.relative_to(self.dest).as_posix()
            print(split_file_path)
        return self.write_text_files(file_defs)

    def split_label_rows(self, splits: SplitDict) -> dict[str, list[LabelRowV2]]:
        # Note: sampling with `random.choices` only approximates the requested
        # fractions; exact splits would require shuffling and slicing.
        names = list(splits.keys())
        values = [splits[n] for n in names]
        choices = random.choices(names, weights=values, k=len(self.label_rows))
        ret: dict[str, list[LabelRowV2]] = {}
        for choice, lr in zip(choices, self.label_rows):
            ret.setdefault(choice, []).append(lr)
        return ret

    def write_text_files(self, file_defs: dict[str, str]) -> Path:
        lookup = {o.feature_node_hash: o for o in self.ontology_objects}
        yaml_dict = {
            "path": self.dest.expanduser().resolve().as_posix(),
            "names": {v: lookup[k].name for k, v in self.fh_to_cls_idx.items()},
            **file_defs,
        }
        dataset_def = Path.cwd() / f"{self.title}.yaml"
        dataset_def.write_text(safe_dump(yaml_dict))
        return dataset_def
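
    # The resulting yaml matches what Ultralytics YOLO expects, roughly
    # (class names here are hypothetical):
    #   path: /abs/path/to/datasets/<title>
    #   names: {0: cat, 1: dog}
    #   train: train.txt
    #   val: val.txt
    #   test: test.txt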

    def select_label_rows(self):
        label_rows = init_labels(self.project)
        # Iterate in reverse order so items can be removed from the list directly.
        for i in range(len(label_rows) - 1, -1, -1):
            lr = label_rows[i]
            relevant_instances = list(
                chain.from_iterable(
                    lr.get_object_instances(filter_ontology_object=o)
                    for o in self.ontology_objects
                )
            )
            if len(relevant_instances) == 0 and len(lr.get_object_instances()) != 0:
                # Assume that label rows with only other labels were not explicitly
                # chosen as negative examples; drop them from the dataset build.
                label_rows.pop(i)
        self.label_rows = label_rows

    def write_label_rows(
        self, label_rows: list[LabelRowV2], split: str, batch_size: int = 20
    ) -> list[str]:
        def write(lr: LabelRowV2) -> list[Path]:
            if lr.data_type == DataType.VIDEO:
                return self.cache_video(lr, split=split)
            return self.cache_image(lr, split=split)

        (self.image_dir / split).mkdir(exist_ok=True)
        (self.label_dir / split).mkdir(exist_ok=True)
        paths: list[list[Path]] = []
        for step in trange(
            0, len(label_rows), batch_size, desc=f"Downloading Data [{split}]"
        ):
            with ThreadPoolExecutor(8) as exe:
                paths += list(exe.map(write, label_rows[step : step + batch_size]))
        return [
            "./" + r.relative_to(self.dest).as_posix() + os.linesep
            for path in paths
            for r in path
        ]

    def obj_to_txt_row(self, obj: Object, annotation: ObjectInstance.Annotation) -> str:
        cls = self.fh_to_cls_idx[obj.feature_node_hash]
        if obj.shape == Shape.POLYGON:
            c = cast(PolygonCoordinates, annotation.coordinates)
            coords = " ".join(f"{p.x:.4f} {p.y:.4f}" for p in c.values)
        elif obj.shape == Shape.BOUNDING_BOX:
            c = cast(BoundingBoxCoordinates, annotation.coordinates)
            # Encord stores the (normalised) top-left corner, while the YOLO txt
            # format expects the (normalised) box centre.
            coords = (
                f"{c.top_left_x + c.width / 2:.4f} {c.top_left_y + c.height / 2:.4f} "
                f"{c.width:.4f} {c.height:.4f}"
            )
        else:
            raise ValueError(f"Shape {obj.shape} is not supported")
        return f"{cls} {coords}{os.linesep}"

    def cache_label(self, view: LabelRowV2.FrameView, img_dest: Path) -> Path:
        label_dest = image_path_to_label_path(img_dest)
        lines: list[str] = [
            self.obj_to_txt_row(obj, ins.get_annotation(frame=view.frame))
            for obj in self.ontology_objects
            for ins in view.get_object_instances(filter_ontology_object=obj)
        ]
        with label_dest.open("w") as f:
            f.writelines(lines)
        return label_dest

    def cache_video(self, lr: LabelRowV2, split: str, byte_size: int = 1024) -> list[Path]:
        video, _ = self.project.get_data(lr.data_hash, get_signed_url=True)
        if video is None:
            return []
        # Only extract frames that actually carry object instances.
        valid_frames = {
            fr.frame for fr in lr.get_frame_views() if len(fr.get_object_instances())
        }
        file_name_format = "{data_hash}_{frame:06d}.jpg"
        ret: list[Path] = []
        with NamedTemporaryFile("wb", suffix=".mp4", delete=False) as fp:
            # Download into a temporary file
            r = requests.get(video["file_link"], stream=True)
            if r.status_code != 200:
                raise ConnectionError(
                    f"Something happened, couldn't download file from: {video['file_link']}"
                )
            for chunk in tqdm(
                r.iter_content(chunk_size=byte_size),
                desc="Downloading video",
                leave=False,
            ):
                if chunk:  # filter out keep-alive chunks
                    fp.write(chunk)
            fp.flush()
        container = av.open(fp.name)
        data_hash = lr.data_hash
        for frame_num, frame in enumerate(container.decode(video=0)):
            if frame_num not in valid_frames:
                continue
            image: Image.Image = frame.to_image().convert("RGB")
            frame_destination = (
                self.image_dir
                / split
                / file_name_format.format(data_hash=data_hash, frame=frame_num)
            )
            image.save(frame_destination)
            self.cache_label(lr.get_frame_view(frame_num), frame_destination)
            ret.append(frame_destination)
        container.close()
        os.unlink(fp.name)  # clean up: the temp file was created with delete=False
        return ret

    def cache_image(self, lr: LabelRowV2, split: str) -> list[Path]:
        img, frames = self.project.get_data(lr.data_hash, get_signed_url=True)
        split_img_dir = self.image_dir / split
        if lr.data_type == DataType.IMAGE:
            if img is None:
                return []
            img_dest = (
                split_img_dir / f"{lr.data_hash}{Path(lr.data_title).suffix or '.jpg'}"
            )
            img_dest = download_image(img["file_link"], img_dest)
            self.cache_label(lr.get_frame_view(0), img_dest)
            return [img_dest]
        # Image groups / sequences: one file per frame.
        if frames is None:
            return []
        ret: list[Path] = []
        for frame in frames:
            image_hash = frame["image_hash"]
            view = lr.get_frame_view(image_hash)
            relevant_instances = list(
                chain.from_iterable(
                    view.get_object_instances(filter_ontology_object=o)
                    for o in self.ontology_objects
                )
            )
            if len(relevant_instances) == 0:
                continue
            img_dest = (
                split_img_dir / f"{image_hash}{Path(frame['title']).suffix or '.jpg'}"
            )
            download_image(frame["file_link"], img_dest)
            self.cache_label(view, img_dest)
            ret.append(img_dest)
        return ret


if __name__ == "__main__":
    project = init_project(
        Path("~/.ssh/encord_frederik"),
        # UUID("8a47fd61-b98f-4e24-ac61-8250248b9adf"),  # dental polygons
        UUID("d5b2e9a8-0dc6-42c3-af81-46cced183eed"),  # Micro all sorts
    )
    path = Path("./datasets")
    path.mkdir(exist_ok=True)
    converter = ProjectConverter(project, path)
    converter.do_it(splits={"train": 0.8, "val": 0.1, "test": 0.1})
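
# Once the yaml is written you can point a trainer at it, e.g. with the
# Ultralytics CLI (an assumption, not part of this script):
#   yolo detect train data=<title>.yaml model=yolov8n.pt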