Last active
March 22, 2024 13:44
-
-
Save frederik-encord/8fd1cdca6780db7f5a7aaa3568001be8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from concurrent.futures import ThreadPoolExecutor | |
from pathlib import Path | |
from tempfile import NamedTemporaryFile | |
from uuid import UUID | |
import av | |
import requests | |
from encord.constants.enums import DataType | |
from encord.objects.common import Shape | |
from encord.project import LabelRow, LabelRowMetadata | |
from encord.project import Project as EncordProject | |
from encord.user_client import EncordUserClient | |
from PIL import Image | |
from tqdm.auto import tqdm, trange | |
# Punctuation that survives sanitising; any other non-alphanumeric character is dropped.
_KEEP_CHARACTERS = {".", "_", " ", "-"}
# Surviving characters that are swapped for a substitute in the output.
_REPLACE_CHARACTERS = {" ": "_"}


def safe_str(unsafe: str) -> str:
    """Return a sanitised copy of *unsafe*.

    Alphanumeric characters and the characters in ``_KEEP_CHARACTERS`` are
    retained, everything else is discarded, and retained characters listed in
    ``_REPLACE_CHARACTERS`` are substituted (spaces become underscores).

    Raises:
        ValueError: If *unsafe* is not a ``str``.
    """
    if not isinstance(unsafe, str):
        raise ValueError(f"{unsafe} ({type(unsafe)}) not a string")

    sanitised = []
    for char in unsafe:
        if not (char.isalnum() or char in _KEEP_CHARACTERS):
            continue
        sanitised.append(_REPLACE_CHARACTERS.get(char, char))
    # NOTE: spaces have already become underscores at this point, so this
    # rstrip() never removes anything; it is kept to mirror the original call.
    return "".join(sanitised).rstrip()
def init_project(ssh_key_path: str | Path, project_hash: UUID) -> EncordProject:
    """Open the Encord project identified by *project_hash*.

    Args:
        ssh_key_path: Location of the SSH private key file; ``~`` is expanded
            and the path is resolved before the file is read.
        project_hash: Hash of the project, passed on as a string.

    Returns:
        The project object returned by ``EncordUserClient.get_project``.
    """
    key_file = Path(ssh_key_path).expanduser().resolve()
    user_client = EncordUserClient.create_with_ssh_private_key(key_file.read_text())
    return user_client.get_project(project_hash=str(project_hash))
def download_image(url: str, destination: Path) -> Path:
    """Download *url* to *destination* unless that file already exists.

    Args:
        url: Address the image bytes are fetched from with ``requests.get``.
        destination: File path the downloaded bytes are written to. Missing
            parent directories are created.

    Returns:
        *destination*, whether it was freshly written or already present.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if destination.is_file():
        return destination

    # exist_ok avoids a FileExistsError when another worker thread creates the
    # directory first; parents=True covers destinations nested several levels deep.
    destination.parent.mkdir(parents=True, exist_ok=True)

    response = requests.get(url)
    # Fail loudly rather than storing an error body that later runs would
    # mistake for an already-downloaded image.
    response.raise_for_status()
    destination.write_bytes(response.content)
    return destination
def image_path_to_label_path(image_path: Path):
    """Map an image path to the path of its label file.

    The closest ancestor directory named ``images`` is swapped for a sibling
    ``labels`` directory and the file suffix becomes ``.json``, e.g.
    ``root/images/a/b.jpg`` -> ``root/labels/a/b.json``.

    Raises:
        StopIteration: If no ancestor of *image_path* is named ``images``.
    """
    images_dir = next(
        ancestor for ancestor in image_path.parents if ancestor.name == "images"
    )
    relative_image = image_path.relative_to(images_dir)
    # A directory literally named "images" always has a parent (possibly "."),
    # so the labels directory can be anchored on that parent directly.
    label_path = images_dir.parent / "labels" / relative_image
    return label_path.with_suffix(".json")
def make_project(
    project: EncordProject,
    dest: Path,
):
    """Export a project's ontology, images and labels to a local directory.

    A directory named after the sanitised project title is created under
    *dest*. It receives ``ontology.json`` plus an ``images`` and a ``labels``
    directory. Every label row of the project is then handed to
    ``cache_video`` or ``cache_image`` (depending on its data type) from a
    pool of 8 worker threads, 100 rows per progress-bar step.

    Rows whose export raises are reported on stdout and skipped, so one bad
    row does not prevent the remaining rows from being exported.

    Args:
        project: The Encord project to export.
        dest: Directory under which the project directory is created.

    Returns:
        None.
    """
    title = safe_str(project.title)
    dest = dest / title
    dest.mkdir(exist_ok=True, parents=True)
    (dest / "ontology.json").write_text(
        json.dumps(project.ontology_structure.to_dict(), indent=2)
    )
    print(f"Storing data at {dest}")

    image_dir = dest / "images"
    label_dir = dest / "labels"
    image_dir.mkdir(exist_ok=True)
    label_dir.mkdir(exist_ok=True)

    label_rows = project.list_label_rows(include_uninitialised_labels=True)

    def write(lr_meta: LabelRowMetadata):
        # Rows without a label hash have no label row to fetch.
        lr = None
        if lr_meta.label_hash:
            lr = project.get_label_row(lr_meta.label_hash) or None
        if lr_meta.data_type == DataType.VIDEO:
            return cache_video(project, image_dir, lr_meta, lr)
        return cache_image(project, image_dir, lr_meta, lr)

    batch_size = 100
    with ThreadPoolExecutor(8) as exe:
        for step in trange(0, len(label_rows), batch_size, desc="Downloading Data"):
            batch = label_rows[step : step + batch_size]
            futures = [exe.submit(write, lr_meta) for lr_meta in batch]
            # Inspect every future: an un-consumed ``exe.map`` would discard
            # worker exceptions without any trace.
            for lr_meta, future in zip(batch, futures):
                error = future.exception()
                if error is not None:
                    print(f"Failed to cache {lr_meta.data_hash}: {error!r}")
def cache_label(labels: dict, img_dest: Path):
    """Write *labels* as JSON next to the image stored at *img_dest*.

    The target file is derived with ``image_path_to_label_path``, i.e. it
    mirrors the image's location below a ``labels`` directory and uses a
    ``.json`` suffix.

    Args:
        labels: JSON-serialisable label content.
        img_dest: Path of the image the labels belong to.
    """
    label_dest = image_path_to_label_path(img_dest)
    # parents=True so the call also works when the labels tree does not exist yet.
    label_dest.parent.mkdir(parents=True, exist_ok=True)
    label_dest.write_text(json.dumps(labels))
def cache_video(
    project, image_root, lr_meta: LabelRowMetadata, lr: LabelRow | None, byte_size=1024
):
    """Download a video and store its frames (and their labels) as files.

    The video is streamed into a temporary ``.mp4`` file and decoded with
    PyAV. Frames are written to *image_root* as
    ``{data_hash}_{frame:06d}.jpg``. When the label row lists labelled
    frames, only those frames are stored and each one's labels are written
    with ``cache_label``; otherwise every frame is stored without labels.
    The temporary file is removed afterwards, even on failure.

    Args:
        project: Object providing ``get_data(data_hash, get_signed_url=True)``.
        image_root: Directory the frame images are written to.
        lr_meta: Metadata of the label row; its ``data_hash`` identifies the video.
        lr: The label row, or ``None`` when there is none. Presumably its
            first data unit's ``"labels"`` maps frame numbers (as strings) to
            that frame's labels; verify against the Encord label row schema.
        byte_size: Chunk size in bytes used while streaming the download.

    Returns:
        The list of written frame paths, or ``None`` if no video data was returned.

    Raises:
        ConnectionError: If the download does not answer with status 200.
    """
    video, _ = project.get_data(lr_meta.data_hash, get_signed_url=True)
    if video is None:
        return

    # A dict view is not an iterator, so it must go through iter() before next().
    data_unit = next(iter(lr["data_units"].values()), None) if lr else None
    frame_labels = data_unit.get("labels", {}) if data_unit else {}
    valid_frames = {int(frame) for frame in frame_labels}

    file_name_format = "{data_hash}_{frame:06d}.jpg"
    data_hash = lr_meta.data_hash
    ret: list[Path] = []

    tmp_file = NamedTemporaryFile("wb", suffix=".mp4", delete=False)
    video_path = Path(tmp_file.name)
    try:
        # Download into temporary file
        with tmp_file, requests.get(video["file_link"], stream=True) as r:
            if r.status_code != 200:
                raise ConnectionError(
                    f"Something happened, couldn't download file from: {video['file_link']}"
                )
            for chunk in tqdm(
                r.iter_content(chunk_size=byte_size),
                desc="Downloading video",
                leave=False,
            ):
                if chunk:  # filter out keep-alive new chunks
                    tmp_file.write(chunk)

        # The temp file is closed at this point, so the decoder can open it by name.
        with av.open(str(video_path)) as container:
            for frame_num, frame in enumerate(container.decode(video=0)):
                if valid_frames and frame_num not in valid_frames:
                    continue

                image: Image.Image = frame.to_image().convert("RGB")
                frame_destination = image_root / file_name_format.format(
                    data_hash=data_hash, frame=frame_num
                )
                image.save(frame_destination)

                # Unlabelled frames are still exported; they just get no label file.
                labels = frame_labels.get(str(frame_num))
                if labels is not None:
                    cache_label(labels, frame_destination)
                ret.append(frame_destination)
    finally:
        # delete=False means nobody else cleans this file up.
        video_path.unlink(missing_ok=True)

    return ret
def cache_image(project, image_root, lr_meta: LabelRowMetadata, lr: LabelRow | None):
    """Download the image(s) of a label row and store their labels.

    For ``DataType.IMAGE`` rows, the single image is saved to *image_root* as
    ``{data_hash}{suffix}``. For any other data type (image group / sequence),
    every frame is saved under ``image_root / safe_str(data_title)`` using the
    frame's own title. The suffix falls back to ``.jpg`` when the title has
    none. Non-empty labels are written with ``cache_label``.

    Args:
        project: Object providing ``get_data(data_hash, get_signed_url=True)``.
        image_root: Directory below which images are stored.
        lr_meta: Metadata of the label row (data hash, title and type).
        lr: The label row, or ``None`` when there is none. Its ``"data_units"``
            mapping is looked up by image hash for image groups.

    Returns:
        None.
    """
    img, frames = project.get_data(lr_meta.data_hash, get_signed_url=True)

    if lr_meta.data_type == DataType.IMAGE:
        if img is None:
            return

        # A dict view is not an iterator, so it must go through iter() before next().
        data_unit = next(iter(lr["data_units"].values()), {}) if lr is not None else {}
        img_dest = (
            image_root
            / f"{lr_meta.data_hash}{Path(lr_meta.data_title).suffix or '.jpg'}"
        )
        img_dest = download_image(img["file_link"], img_dest)

        labels = data_unit.get("labels", {})
        if labels:
            cache_label(labels, img_dest)
        return

    # Image group / sequence
    if frames is None:
        return

    seq_dir = image_root / safe_str(lr_meta.data_title)
    seq_dir.mkdir(exist_ok=True)

    data_units = lr["data_units"] if lr is not None else {}
    for frame in frames:
        image_hash = frame["image_hash"]
        title = Path(frame["title"])
        img_dest = seq_dir / f"{title.stem}{title.suffix or '.jpg'}"
        download_image(frame["file_link"], img_dest)

        labels = data_units.get(image_hash, {}).get("labels", {})
        if labels:
            cache_label(labels, img_dest)
# if __name__ == "__main__": | |
# client = EncordUserClient.create_with_ssh_private_key( | |
# Path("~/.ssh/encord_frederik").expanduser().resolve().read_text() | |
# ) | |
# # project = client.get_project("8a47fd61-b98f-4e24-ac61-8250248b9adf") # dental polygons | |
# project = client.get_project( | |
# "98dd1573-54ad-4ea7-9358-37af037e7f0a" | |
# ) # Micro all sorts | |
# path = Path("./datasets") | |
# path.mkdir(exist_ok=True) | |
# make_project( | |
# project, | |
# path, | |
# ) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment