Skip to content

Instantly share code, notes, and snippets.

@frederik-encord
Created March 14, 2024 16:19
Show Gist options
  • Save frederik-encord/95d57f4fa2fec8fd347e5d57cca99079 to your computer and use it in GitHub Desktop.
Save frederik-encord/95d57f4fa2fec8fd347e5d57cca99079 to your computer and use it in GitHub Desktop.
import json
import os
import random
from concurrent.futures import ThreadPoolExecutor
from itertools import chain
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import cast
from uuid import UUID
import av
import numpy as np
import requests
from encord.constants.enums import DataType
from encord.objects.common import Shape
from encord.objects.coordinates import (
BoundingBoxCoordinates,
PointCoordinate,
PolygonCoordinates,
)
from encord.objects.ontology_labels_impl import LabelRowV2
from encord.objects.ontology_object import Object
from encord.objects.ontology_object_instance import ObjectInstance
from encord.project import Project as EncordProject
from encord.user_client import EncordUserClient
from PIL import Image
from tqdm.auto import tqdm, trange
from yaml import safe_dump
# Ontology shapes this exporter knows how to turn into YOLO label lines.
_VALID_SHAPES = set((Shape.POLYGON, Shape.BOUNDING_BOX))
# Non-alphanumeric characters that safe_str keeps.
_KEEP_CHARACTERS = {".", "_", " ", "-"}
# Substitutions applied to the kept characters (spaces become underscores).
_REPLACE_CHARACTERS = {" ": "_"}


def safe_str(unsafe: str) -> str:
    """Turn an arbitrary string into a file-system friendly name.

    Only alphanumeric characters and those in ``_KEEP_CHARACTERS`` survive.
    Trailing whitespace is dropped, then the remaining characters are mapped
    through ``_REPLACE_CHARACTERS`` (spaces become underscores).

    Raises:
        ValueError: if ``unsafe`` is not a ``str``.
    """
    if not isinstance(unsafe, str):
        raise ValueError(f"{unsafe} ({type(unsafe)}) not a string")
    kept = "".join(c for c in unsafe if c.isalnum() or c in _KEEP_CHARACTERS)
    # Strip *before* replacing: once spaces have become underscores there is
    # no whitespace left for rstrip() to remove.
    return "".join(_REPLACE_CHARACTERS.get(c, c) for c in kept.rstrip())
def init_project(ssh_key_path: str | Path, project_hash: UUID) -> EncordProject:
    """Authenticate with Encord using an SSH private key file and open a project.

    Args:
        ssh_key_path: Location of the private key; ``~`` is expanded.
        project_hash: Hash of the project to fetch.

    Returns:
        The Encord project handle.
    """
    if not isinstance(ssh_key_path, Path):
        ssh_key_path = Path(ssh_key_path)
    private_key = ssh_key_path.expanduser().resolve().read_text()
    user_client = EncordUserClient.create_with_ssh_private_key(private_key)
    return user_client.get_project(project_hash=str(project_hash))
def extract_relevant_objects(
    project: EncordProject,
) -> tuple[list[Object], dict[str, int]]:
    """Pick the ontology objects to export and assign each a YOLO class index.

    Only one shape (polygon *or* bounding box) is exported per dataset. If the
    ontology contains both, the user is asked interactively. The answer may be
    the first letter of the shape (e.g. ``p``) or its full value (e.g.
    ``polygon``), and the prompt repeats until the answer is valid.

    Returns:
        The selected ontology objects sorted by ``uid``, and a mapping from each
        object's ``feature_node_hash`` to its class index (its position in that
        sorted list).

    Raises:
        ValueError: if the ontology has neither polygons nor bounding boxes.
    """
    valid_object_types = {
        o.shape for o in project.ontology_structure.objects
    }.intersection(_VALID_SHAPES)
    if not valid_object_types:
        raise ValueError(
            "Encord project does not have any valid object types in the ontology. Make sure that there's either polygons or bounding boxes in the dataset"
        )
    if len(valid_object_types) == 1:
        (shape,) = valid_object_types
    else:
        # Sorted so the prompt is stable (set iteration order is arbitrary).
        options = sorted(valid_object_types, key=lambda s: str(s.value))
        answers: dict[str, Shape] = {}
        for option in options:
            name = str(option.value).lower()
            answers.setdefault(name[0], option)
            answers[name] = option
        opt_str = ", ".join(f"({str(o.value)[0]}): {o.value}" for o in options)
        shape = None
        while shape is None:
            selection = input(f"Please select one of: {opt_str} ").strip().lower()
            shape = answers.get(selection)
            if shape is None:
                print(f"Invalid selection {selection!r}, please try again.")
    objects = sorted(
        [o for o in project.ontology_structure.objects if o.shape == shape],
        key=lambda o: o.uid,
    )
    fh_to_idx = {o.feature_node_hash: i for i, o in enumerate(objects)}
    return objects, fh_to_idx
def init_labels(
    project: EncordProject,
    batch_size: int = 1000,
) -> list[LabelRowV2]:
    """Fetch every label row of ``project`` and initialise its labels.

    Rows are initialised in chunks of ``batch_size``. Each chunk shares one
    ``project.create_bundle()`` context (presumably so Encord can group the
    underlying requests).

    Returns:
        All label rows of the project, initialised in place.
    """
    rows = project.list_label_rows_v2()
    for offset in range(0, len(rows), batch_size):
        chunk = rows[offset : offset + batch_size]
        with project.create_bundle() as bundle:
            for row in chunk:
                row.initialise_labels(bundle=bundle)
    return rows
def download_image(url: str, destination: Path) -> Path:
if destination.is_file():
return destination
if not destination.parent.is_dir():
destination.parent.mkdir()
destination.write_bytes(requests.get(url).content)
return destination
def image_path_to_label_path(image_path: Path) -> Path:
    """Map an image path to its YOLO label path.

    The innermost ancestor directory named ``images`` is swapped for a
    sibling ``labels`` directory, and the suffix becomes ``.txt``. For example,
    ``ds/images/train/a.jpg`` becomes ``ds/labels/train/a.txt``.

    Raises:
        ValueError: if no ancestor directory is named ``images``. This is raised
            explicitly so that no bare ``StopIteration`` escapes, which would
            surface as an opaque ``RuntimeError`` inside generators or
            ``Executor.map``.
    """
    images_dir = next((p for p in image_path.parents if p.name == "images"), None)
    if images_dir is None:
        raise ValueError(f"{image_path} is not located inside an 'images' directory")
    label_path = images_dir.parent / "labels" / image_path.relative_to(images_dir)
    return label_path.with_suffix(".txt")
SplitDict = dict[str, float]
"""
Dict which defines dataset splits. Each key becomes ``{key}.txt``, listing the
image files of that split. Values are relative weights handed to
``random.choices``; fractions summing to one (e.g. 0.8/0.1/0.1) are the
intended usage, but only their ratios matter.
"""
class ProjectConverter:
    """Export an Encord project into a YOLO-style dataset on disk.

    Files are written below ``dest / safe_str(project.title)``::

        images/<split>/...        downloaded images and extracted video frames
        labels/<split>/....txt    one YOLO label line per exported instance
        <split>.txt               "./"-relative image paths of each split

    The dataset YAML itself is written to the current working directory by
    ``write_text_files``.
    """

    def __init__(self, project: EncordProject, dest: Path) -> None:
        self.project = project
        # Exported ontology objects and their feature_node_hash -> class index.
        self.ontology_objects, self.fh_to_cls_idx = extract_relevant_objects(project)
        self.title = safe_str(project.title)
        self.dest = dest / self.title
        self.dest.mkdir(exist_ok=True, parents=True)
        print(f"Storing data at {self.dest}")
        self.image_dir = self.dest / "images"
        self.label_dir = self.dest / "labels"
        self.image_dir.mkdir(exist_ok=True)
        self.label_dir.mkdir(exist_ok=True)
        # Populated by select_label_rows().
        self.label_rows: list[LabelRowV2] = []

    def do_it(
        self,
        splits: SplitDict,
        batch_size: int = 20,
        seed: int = 42,
        ignore_empty: bool = True,
    ) -> Path:
        """Build the full dataset and return the path of its YOLO dataset YAML.

        Args:
            splits: Split name -> sampling weight (see ``SplitDict``).
            batch_size: Number of label rows handed to the download pool at once.
            seed: Seed for the random assignment of label rows to splits.
            ignore_empty: Drop label rows holding none of the exported objects.

        Returns:
            The YAML file written by ``write_text_files``.
        """
        self.select_label_rows(ignore_empty)
        # A private generator keeps the split reproducible without reseeding
        # the interpreter-wide `random` state.
        subsets = self.split_label_rows(splits, rng=random.Random(seed))
        file_defs: dict[str, str] = {}
        for split_name, label_rows in subsets.items():
            split_file_path = self.dest / f"{split_name}.txt"
            file_paths = self.write_label_rows(
                label_rows, split=split_name, batch_size=batch_size
            )
            with split_file_path.open("w") as f:
                f.writelines(file_paths)
            file_defs[split_name] = split_file_path.relative_to(self.dest).as_posix()
            print(split_file_path)
        return self.write_text_files(file_defs)

    def split_label_rows(
        self, splits: SplitDict, rng: random.Random | None = None
    ) -> dict[str, list[LabelRowV2]]:
        """Randomly assign each selected label row to one split.

        Args:
            splits: Split name -> sampling weight.
            rng: Random generator to draw from; defaults to the ``random`` module.

        Returns:
            Split name -> label rows. Every split in ``splits`` is present,
            possibly with an empty list, so that it still gets a split file and
            a YAML entry.
        """
        names = list(splits)
        weights = [splits[n] for n in names]
        chooser = random if rng is None else rng
        choices = chooser.choices(names, weights=weights, k=len(self.label_rows))
        ret: dict[str, list[LabelRowV2]] = {name: [] for name in names}
        for choice, lr in zip(choices, self.label_rows):
            ret[choice].append(lr)
        return ret

    def write_text_files(self, file_defs: dict[str, str]) -> Path:
        """Write the dataset YAML and return its path.

        The YAML holds ``path`` (absolute ``self.dest``), ``names`` (class index
        -> ontology object name) and one entry per split file. It is written as
        ``{title}.yaml`` in the current working directory, not in ``self.dest``.
        """
        lookup = {o.feature_node_hash: o for o in self.ontology_objects}
        yaml_dict = {
            "path": self.dest.expanduser().resolve().as_posix(),
            "names": {idx: lookup[fh].name for fh, idx in self.fh_to_cls_idx.items()},
            **file_defs,
        }
        dataset_def = Path.cwd() / f"{self.title}.yaml"
        dataset_def.write_text(safe_dump(yaml_dict))
        return dataset_def

    def _has_relevant_instances(self, view) -> bool:
        """Return True if ``view`` (label row or frame view) holds an exported object."""
        return any(
            view.get_object_instances(filter_ontology_object=o)
            for o in self.ontology_objects
        )

    def select_label_rows(self, ignore_empty: bool):
        """Load and initialise all label rows into ``self.label_rows``.

        With ``ignore_empty``, rows without any exported object are dropped.
        Rows that only carry other labels are assumed not to be deliberate
        negative examples.
        """
        label_rows = init_labels(self.project)
        if ignore_empty:
            label_rows = [lr for lr in label_rows if self._has_relevant_instances(lr)]
        self.label_rows = label_rows

    def write_label_rows(
        self, label_rows, split: str, batch_size: int = 20
    ) -> list[str]:
        """Download the data of ``label_rows`` into ``split`` and write their labels.

        Returns:
            One line per written image: its path relative to ``self.dest``,
            prefixed with ``./`` and terminated by ``\\n``.
        """

        def write(lr: LabelRowV2) -> list[Path]:
            if lr.data_type == DataType.VIDEO:
                return self.cache_video(lr, split=split)
            return self.cache_image(lr, split=split)

        (self.image_dir / split).mkdir(exist_ok=True)
        (self.label_dir / split).mkdir(exist_ok=True)
        paths: list[list[Path]] = []
        # One pool for all batches instead of a fresh pool per batch.
        with ThreadPoolExecutor(8) as exe:
            for step in trange(
                0, len(label_rows), batch_size, desc=f"Downloading Data [{split}]"
            ):
                paths.extend(exe.map(write, label_rows[step : step + batch_size]))
        # "\n" rather than os.linesep: the split file is opened in text mode,
        # which already translates "\n" (os.linesep would give "\r\r\n" on Windows).
        return [
            "./" + r.relative_to(self.dest).as_posix() + "\n"
            for path in paths
            for r in path
        ]

    def obj_to_txt_row(self, obj: Object, annotation: ObjectInstance.Annotation) -> str:
        """Format one annotation as a YOLO label line.

        Polygons are written as ``cls x1 y1 x2 y2 ...``. Bounding boxes are
        written in YOLO's ``cls x_center y_center width height`` form,
        converted from Encord's top-left based box. Coordinates are copied as
        Encord provides them (assumed relative to the image size, as YOLO
        expects; confirm).

        Raises:
            ValueError: for shapes other than polygon / bounding box.
        """
        cls = self.fh_to_cls_idx[obj.feature_node_hash]
        if obj.shape == Shape.POLYGON:
            c = cast(PolygonCoordinates, annotation.coordinates)
            coords = " ".join(f"{p.x:.6f} {p.y:.6f}" for p in c.values)
        elif obj.shape == Shape.BOUNDING_BOX:
            c = cast(BoundingBoxCoordinates, annotation.coordinates)
            # YOLO boxes are centre based; Encord stores the top-left corner.
            x_center = c.top_left_x + c.width / 2
            y_center = c.top_left_y + c.height / 2
            coords = f"{x_center:.6f} {y_center:.6f} {c.width:.6f} {c.height:.6f}"
        else:
            raise ValueError(f"Shape {obj.shape} is not supported")
        return f"{cls} {coords}\n"

    def cache_label(self, view: LabelRowV2.FrameView, img_dest: Path) -> Path:
        """Write the label file for the image at ``img_dest`` from ``view``.

        The file may be empty when the view holds no exported objects.
        """
        label_dest = image_path_to_label_path(img_dest)
        lines: list[str] = [
            self.obj_to_txt_row(obj, ins.get_annotation(frame=view.frame))
            for obj in self.ontology_objects
            for ins in view.get_object_instances(filter_ontology_object=obj)
        ]
        with label_dest.open("w") as f:
            f.writelines(lines)
        return label_dest

    def cache_video(self, lr: LabelRowV2, split: str, byte_size=1024) -> list[Path]:
        """Save the annotated frames of a video label row as JPEGs plus labels.

        Only frames holding at least one exported object are saved, as
        ``{data_hash}_{frame:06d}.jpg``. The video is streamed to a temporary
        file that is always removed afterwards.

        Returns:
            Paths of the saved frames. The list is empty if no frame qualifies
            or the video data is unavailable.

        Raises:
            ConnectionError: if the download does not answer with HTTP 200.
        """
        # Same relevance rule as for image groups in cache_image: frames that
        # only hold objects of other ontology types are not exported.
        valid_frames = {
            fr.frame for fr in lr.get_frame_views() if self._has_relevant_instances(fr)
        }
        if not valid_frames:
            return []  # don't download a video we would not use
        video, _ = self.project.get_data(lr.data_hash, get_signed_url=True)
        if video is None:
            return []
        # delete=False keeps the file after it is closed so av can reopen it
        # by name; it is removed explicitly in `finally`.
        tmp = NamedTemporaryFile("wb", suffix=".mp4", delete=False)
        video_path = Path(tmp.name)
        try:
            with tmp as fp, requests.get(video["file_link"], stream=True) as r:
                if r.status_code != 200:
                    raise ConnectionError(
                        f"Something happened, couldn't download file from: {video['file_link']}"
                    )
                for chunk in tqdm(
                    r.iter_content(chunk_size=byte_size),
                    desc="Downloading video",
                    leave=False,
                ):
                    if chunk:  # filter out keep-alive new chunks
                        fp.write(chunk)
            return self._extract_video_frames(lr, split, video_path, valid_frames)
        finally:
            video_path.unlink(missing_ok=True)

    def _extract_video_frames(
        self, lr: LabelRowV2, split: str, video_path: Path, valid_frames: set[int]
    ) -> list[Path]:
        """Decode ``video_path`` and save the frames in ``valid_frames`` with labels."""
        file_name_format = "{data_hash}_{frame:06d}.jpg"
        last_frame = max(valid_frames)
        ret: list[Path] = []
        with av.open(str(video_path)) as container:
            for frame_num, frame in enumerate(container.decode(video=0)):
                if frame_num > last_frame:
                    break  # no annotated frames left; stop decoding
                if frame_num not in valid_frames:
                    continue
                image: Image.Image = frame.to_image().convert("RGB")
                frame_destination = (
                    self.image_dir
                    / split
                    / file_name_format.format(data_hash=lr.data_hash, frame=frame_num)
                )
                image.save(frame_destination)
                self.cache_label(lr.get_frame_view(frame_num), frame_destination)
                ret.append(frame_destination)
        return ret

    def cache_image(self, lr: LabelRowV2, split: str) -> list[Path]:
        """Download a single image or the relevant images of an image group.

        Single images are stored as ``{data_hash}{suffix}`` and always get a
        label file. Group images are stored as ``{image_hash}{suffix}`` and
        are skipped when they hold no exported object.

        Returns:
            Paths of the stored images (empty when no data is available).
        """
        img, frames = self.project.get_data(lr.data_hash, get_signed_url=True)
        split_img_dir = self.image_dir / split
        if lr.data_type == DataType.IMAGE:
            if img is None:
                return []
            img_dest = (
                split_img_dir / f"{lr.data_hash}{Path(lr.data_title).suffix or '.jpg'}"
            )
            img_dest = download_image(img["file_link"], img_dest)
            self.cache_label(lr.get_frame_view(0), img_dest)
            return [img_dest]
        if frames is None:
            return []
        ret: list[Path] = []
        for frame in frames:
            image_hash = frame["image_hash"]
            view = lr.get_frame_view(image_hash)
            if not self._has_relevant_instances(view):
                continue
            img_dest = (
                split_img_dir / f"{image_hash}{Path(frame['title']).suffix or '.jpg'}"
            )
            download_image(frame["file_link"], img_dest)
            self.cache_label(view, img_dest)
            ret.append(img_dest)
        return ret

    def set_boxes_from_predictions(self, lr: LabelRowV2, frame: int, pred_file: Path):
        """Add the predictions of a YOLO ``--save-conf`` txt file to ``lr`` at ``frame``.

        Each line is ``cls <coords...> conf``. Exactly four coordinates are read
        as a YOLO box ``x_center y_center width height`` and converted to
        Encord's top-left based box. Any other count is read as polygon
        ``x y`` pairs, so a two-point polygon is indistinguishable from a box.
        """
        with pred_file.open("r") as f:
            for pred_line in f:
                # Lines keep their "\n", so test the stripped text for blanks.
                if not pred_line.strip():
                    continue
                cls_str, *points, cnf_str = pred_line.split()
                instance = self.ontology_objects[int(cls_str)].create_instance()
                if len(points) == 4:
                    x_center, y_center, w, h = map(float, points)
                    coords = BoundingBoxCoordinates(
                        top_left_x=x_center - w / 2,
                        top_left_y=y_center - h / 2,
                        width=w,
                        height=h,
                    )
                else:
                    np_points = np.array(
                        list(map(float, points)), dtype=np.float32
                    ).reshape(-1, 2)
                    coords = PolygonCoordinates(
                        [PointCoordinate(x=p[0], y=p[1]) for p in np_points]
                    )
                instance.set_for_frames(coords, frames=frame, confidence=float(cnf_str))
                lr.add_object_instance(instance)

    def create_encord_json_predictions(
        self, pred_dir: Path, out_file: Path | None = None
    ) -> Path:
        """
        Creates a json file with the predictions from the given directory.

        This won't work unless the `--save-conf` flag was set when running
        `python detect.py ...`. It relies on ``self.label_rows``, so
        ``select_label_rows`` (or ``do_it``) must have run first. Prediction
        files are matched by name to the exported images:
        ``{data_hash}_{frame}`` for video frames, ``{data_hash}`` for single
        images and ``{image_hash}`` for image-group images.

        Args:
            pred_dir: The directory with the prediction files. For the YoloV9 examples, this is ./runs/detect/exp*/labels
            out_file: Where to store the json. If omitted, ``pred_dir/encord.json``
                is used. If it is an existing directory, the file becomes
                ``out_file/{project_title}_{pred_dir.parent.name}.json``.

        Returns: the path of the written json file.

        Raises:
            ValueError: if a matched file name has a frame part that is not an integer.
        """
        data_lookup = {lr.data_hash: lr.label_hash for lr in self.label_rows}
        for lr in self.label_rows:
            if lr.data_type != DataType.IMG_GROUP:
                continue
            data_lookup.update(
                {fv.image_hash: lr.label_hash for fv in lr.get_frame_views()}
            )
        label_hashes = [lr.label_hash for lr in self.label_rows if lr.label_hash]
        fresh_label_rows = self.project.list_label_rows_v2(label_hashes=label_hashes)
        for start in range(0, len(fresh_label_rows), 1000):
            with self.project.create_bundle() as bundle:
                for lr in fresh_label_rows[start : start + 1000]:
                    # Start from empty labels so only predictions end up in the json.
                    lr.initialise_labels(
                        bundle=bundle,
                        include_object_feature_hashes=set(),
                        include_classification_feature_hashes=set(),
                    )
        fresh_lookup = {lr.label_hash: lr for lr in fresh_label_rows}
        for pred_file in pred_dir.iterdir():
            if pred_file.suffix != ".txt":
                continue
            file_hash, *rest = pred_file.stem.split("_")
            label_hash = data_lookup.get(file_hash)
            if label_hash is None:
                continue
            new_lr = fresh_lookup[label_hash]
            frame_key: int | str
            if rest:
                # Video frame exported as "{data_hash}_{frame:06d}".
                try:
                    frame_key = int(rest[0])
                except ValueError as err:
                    raise ValueError(
                        f"Cannot parse a frame number from prediction file {pred_file.name}"
                    ) from err
            elif file_hash != new_lr.data_hash:
                # Image-group image exported as "{image_hash}".
                frame_key = file_hash
            else:
                # Single image exported as "{data_hash}".
                frame_key = 0
            self.set_boxes_from_predictions(
                new_lr, new_lr.get_frame_view(frame_key).frame, pred_file
            )
        if out_file is None:
            out_file = pred_dir / "encord.json"
        elif out_file.is_dir():
            out_file = out_file / f"{self.title}_{pred_dir.parent.name}.json"
        out_file.write_text(
            json.dumps([lr.to_encord_dict() for lr in fresh_label_rows])
        )
        return out_file
if __name__ == "__main__":
    # Command-line entry point. The defaults reproduce the original hard-coded run.
    import argparse

    parser = argparse.ArgumentParser(
        description="Export an Encord project as a YOLO dataset."
    )
    parser.add_argument(
        "--ssh-key",
        default="~/.ssh/encord_frederik",
        help="Path to the Encord SSH private key.",
    )
    parser.add_argument(
        "--project-hash",
        type=UUID,
        default=UUID("d5b2e9a8-0dc6-42c3-af81-46cced183eed"),  # Micro all sorts
        help="Hash of the Encord project to export.",
    )
    parser.add_argument(
        "--dest",
        type=Path,
        default=Path("./datasets"),
        help="Root directory the dataset folder is created in.",
    )
    args = parser.parse_args()

    project = init_project(args.ssh_key, args.project_hash)
    args.dest.mkdir(exist_ok=True)
    converter = ProjectConverter(project, args.dest)
    dataset_yaml = converter.do_it(splits={"train": 0.8, "val": 0.1, "test": 0.1})
    print(f"Dataset definition written to {dataset_yaml}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment