Skip to content

Instantly share code, notes, and snippets.

@frederik-encord
Last active March 22, 2024 13:44
Show Gist options
  • Save frederik-encord/8fd1cdca6780db7f5a7aaa3568001be8 to your computer and use it in GitHub Desktop.
Save frederik-encord/8fd1cdca6780db7f5a7aaa3568001be8 to your computer and use it in GitHub Desktop.
import json
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from tempfile import NamedTemporaryFile
from uuid import UUID
import av
import requests
from encord.constants.enums import DataType
from encord.objects.common import Shape
from encord.project import LabelRow, LabelRowMetadata
from encord.project import Project as EncordProject
from encord.user_client import EncordUserClient
from PIL import Image
from tqdm.auto import tqdm, trange
_KEEP_CHARACTERS = {".", "_", " ", "-"}
_REPLACE_CHARACTERS = {" ": "_"}


def safe_str(unsafe: str) -> str:
    """
    Turn an arbitrary title into a string usable as a file or directory name.

    Alphanumeric characters and the characters ``.``, ``_``, ``-`` and space
    are kept (spaces are then rewritten to underscores); every other
    character is dropped.

    :param unsafe: The raw string, e.g. a project or data title.
    :return: The sanitised string (may be empty).
    :raises ValueError: If ``unsafe`` is not a ``str``.
    """
    if isinstance(unsafe, str):
        pieces = [
            _REPLACE_CHARACTERS.get(ch, ch)
            for ch in unsafe
            if ch in _KEEP_CHARACTERS or ch.isalnum()
        ]
        return "".join(pieces).rstrip()
    raise ValueError(f"{unsafe} ({type(unsafe)}) not a string")
def init_project(ssh_key_path: str | Path, project_hash: UUID) -> EncordProject:
    """
    Authenticate with Encord using an SSH private key and open a project.

    :param ssh_key_path: Path to the private key file; ``~`` is expanded.
    :param project_hash: Hash of the project to open.
    :return: The project handle returned by ``EncordUserClient.get_project``.
    """
    private_key = Path(ssh_key_path).expanduser().resolve().read_text()
    user_client = EncordUserClient.create_with_ssh_private_key(private_key)
    return user_client.get_project(project_hash=str(project_hash))
def download_image(url: str, destination: Path) -> Path:
    """
    Download ``url`` to ``destination`` unless that file already exists.

    An existing file is treated as a cache hit and is not downloaded again,
    so a failed download must never leave a file behind.

    :param url: (Signed) URL of the file to fetch.
    :param destination: Local file path to write; missing parent directories
        are created.
    :return: ``destination``.
    :raises ConnectionError: If the server does not answer with a success
        status (otherwise the error page would be cached as the image).
    """
    if destination.is_file():
        return destination
    response = requests.get(url)
    if not response.ok:
        raise ConnectionError(
            f"Something happened, couldn't download file from: {url} "
            f"(HTTP {response.status_code})"
        )
    # Tolerate the folder already existing (or being created by another worker).
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_bytes(response.content)
    return destination
def image_path_to_label_path(image_path: Path) -> Path:
    """
    Map an image path below an ``images`` directory to its label JSON path.

    The label lives in a ``labels`` directory that sits next to the closest
    ``images`` ancestor, mirrors the sub-path below it, and uses a ``.json``
    suffix, e.g. ``data/proj/images/seq/img.jpg`` ->
    ``data/proj/labels/seq/img.json``.

    :param image_path: Path of an image located somewhere below ``images``.
    :return: The corresponding label path.
    :raises ValueError: If no ancestor of ``image_path`` is named ``images``.
    """
    # Closest ancestor named "images" (Path.parents goes from nearest outward).
    images_dir = next((p for p in image_path.parents if p.name == "images"), None)
    if images_dir is None:
        raise ValueError(f"{image_path} is not located below an 'images' directory")
    label_path = images_dir.parent / "labels" / image_path.relative_to(images_dir)
    return label_path.with_suffix(".json")
def make_project(
    project: EncordProject,
    dest: Path,
    *,
    max_workers: int = 8,
):
    """
    Export an Encord project's ontology, media and labels to local disk.

    Everything is written below ``dest / safe_str(project.title)``:

    * ``ontology.json`` - the project's ontology structure,
    * ``images/`` - downloaded images, image-group folders and video frames,
    * ``labels/`` - label JSON files mirroring the ``images/`` layout.

    Label rows are exported concurrently. A failure in one label row does not
    abort the export: it is reported and the remaining rows are processed.

    :param project: The Encord project to export.
    :param dest: Parent directory; a sub-folder named after the project is created.
    :param max_workers: Number of threads used to export label rows.
    """
    title = safe_str(project.title)
    dest = dest / title
    dest.mkdir(exist_ok=True, parents=True)
    (dest / "ontology.json").write_text(
        json.dumps(project.ontology_structure.to_dict(), indent=2)
    )
    print(f"Storing data at {dest}")

    image_dir = dest / "images"
    label_dir = dest / "labels"
    image_dir.mkdir(exist_ok=True)
    label_dir.mkdir(exist_ok=True)

    label_rows = project.list_label_rows(include_uninitialised_labels=True)

    def write(lr_meta: LabelRowMetadata):
        # label_hash may be falsy (presumably for uninitialised rows); those are
        # exported without labels.
        lr = (lr_meta.label_hash and project.get_label_row(lr_meta.label_hash)) or None
        if lr_meta.data_type == DataType.VIDEO:
            return cache_video(project, image_dir, lr_meta, lr)
        return cache_image(project, image_dir, lr_meta, lr)

    def try_write(lr_meta: LabelRowMetadata) -> Exception | None:
        # Return the error instead of raising so it is not lost inside the
        # executor and one bad row does not stop the others.
        try:
            write(lr_meta)
        except Exception as exc:
            return exc
        return None

    failed = 0
    with ThreadPoolExecutor(max_workers) as exe:
        results = tqdm(
            exe.map(try_write, label_rows),
            total=len(label_rows),
            desc="Downloading Data",
        )
        for lr_meta, error in zip(label_rows, results):
            if error is not None:
                failed += 1
                tqdm.write(f"Failed to export data {lr_meta.data_hash}: {error!r}")
    if failed:
        print(f"{failed} of {len(label_rows)} label rows failed to export")
def cache_label(labels: dict, img_dest: Path):
    """
    Write ``labels`` as JSON to the label file belonging to ``img_dest``.

    The target path is derived with ``image_path_to_label_path``; missing
    parent directories are created.

    :param labels: JSON-serialisable label payload for one image or frame.
    :param img_dest: Path of the image the labels describe.
    """
    label_dest = image_path_to_label_path(img_dest)
    # parents=True: the labels root or an intermediate folder may not exist yet.
    label_dest.parent.mkdir(parents=True, exist_ok=True)
    label_dest.write_text(json.dumps(labels))
def cache_video(
    project, image_root, lr_meta: LabelRowMetadata, lr: LabelRow | None, byte_size=1024
):
    """
    Download a video, save its frames as JPEGs and cache per-frame labels.

    If the label row has labelled frames, only those frames are saved;
    otherwise every frame is saved. Frames are written to
    ``image_root / "{data_hash}_{frame:06d}.jpg"`` and, for frames that have
    labels, the labels are written with ``cache_label``.

    :param project: Encord project used to obtain a signed video URL.
    :param image_root: Directory receiving the extracted frames.
    :param lr_meta: Metadata of the label row (provides ``data_hash``).
    :param lr: The label row, or ``None`` when there is none.
    :param byte_size: Chunk size in bytes used while streaming the download.
    :return: Paths of the saved frames, or ``None`` if no video was returned.
    :raises ConnectionError: If the download does not answer with HTTP 200.
    """
    video, _ = project.get_data(lr_meta.data_hash, get_signed_url=True)
    if video is None:
        return

    # dict_values is not an iterator, so wrap it in iter() before next().
    # NOTE(review): assumes a video label row holds a single data unit — confirm.
    data_unit = next(iter(lr["data_units"].values()), None) if lr else None
    frame_labels: dict = (data_unit or {}).get("labels", {})
    valid_frames = set(map(int, frame_labels))
    last_valid_frame = max(valid_frames, default=-1)
    file_name_format = "{data_hash}_{frame:06d}.jpg"
    data_hash = lr_meta.data_hash
    ret: list[Path] = []

    # delete=False so the file can be reopened by name for decoding; it is
    # removed explicitly in the `finally` below.
    fp = NamedTemporaryFile("wb", suffix=".mp4", delete=False)
    try:
        with fp, requests.get(video["file_link"], stream=True) as r:
            if r.status_code != 200:
                raise ConnectionError(
                    f"Something happened, couldn't download file from: {video['file_link']}"
                )
            for chunk in tqdm(
                r.iter_content(chunk_size=byte_size),
                desc="Downloading video",
                leave=False,
            ):
                if chunk:  # filter out keep-alive new chunks
                    fp.write(chunk)

        with av.open(fp.name) as container:
            for frame_num, frame in enumerate(container.decode(video=0)):
                if valid_frames:
                    if frame_num > last_valid_frame:
                        break  # no labelled frames remain; skip decoding the rest
                    if frame_num not in valid_frames:
                        continue
                image: Image.Image = frame.to_image().convert("RGB")
                frame_destination = image_root / file_name_format.format(
                    data_hash=data_hash, frame=frame_num
                )
                image.save(frame_destination)
                labels = frame_labels.get(str(frame_num))
                if labels is not None:
                    cache_label(labels, frame_destination)
                ret.append(frame_destination)
    finally:
        Path(fp.name).unlink(missing_ok=True)
    return ret
def cache_image(project, image_root, lr_meta: LabelRowMetadata, lr: LabelRow | None):
    """
    Download a single image or an image group and cache its labels.

    * ``DataType.IMAGE``: saved as ``image_root / f"{data_hash}{suffix}"``,
      where the suffix comes from the data title (``.jpg`` if it has none).
    * Any other data type is treated as an image group / sequence: each frame
      is saved below ``image_root / safe_str(data_title)`` using its own title.

    Label JSON is written only for images/frames with non-empty labels.

    :param project: Encord project used to obtain signed URLs.
    :param image_root: Directory receiving the images.
    :param lr_meta: Metadata of the label row.
    :param lr: The label row, or ``None`` when there is none.
    """
    img, frames = project.get_data(lr_meta.data_hash, get_signed_url=True)
    if lr_meta.data_type == DataType.IMAGE:
        if img is None:
            return
        # dict_values is not an iterator, so wrap it in iter() before next();
        # fall back to {} when there are no data units.
        # NOTE(review): assumes a single image has exactly one data unit — confirm.
        data_unit = (
            next(iter(lr["data_units"].values()), {}) if lr is not None else {}
        )
        img_dest = (
            image_root
            / f"{lr_meta.data_hash}{Path(lr_meta.data_title).suffix or '.jpg'}"
        )
        img_dest = download_image(img["file_link"], img_dest)
        labels = data_unit.get("labels", {})
        if labels:
            cache_label(labels, img_dest)
        return

    # Image group / sequence
    if frames is None:
        return
    seq_dir = image_root / safe_str(lr_meta.data_title)
    seq_dir.mkdir(exist_ok=True)
    # Group label rows key their data units by image hash.
    data_units = lr["data_units"] if lr is not None else {}
    for frame in frames:
        image_hash = frame["image_hash"]
        title = Path(frame["title"])
        img_dest = seq_dir / f"{title.stem}{title.suffix or '.jpg'}"
        download_image(frame["file_link"], img_dest)
        labels = data_units.get(image_hash, {}).get("labels", {})
        if labels:
            cache_label(labels, img_dest)
# if __name__ == "__main__":
# client = EncordUserClient.create_with_ssh_private_key(
# Path("~/.ssh/encord_frederik").expanduser().resolve().read_text()
# )
# # project = client.get_project("8a47fd61-b98f-4e24-ac61-8250248b9adf") # dental polygons
# project = client.get_project(
# "98dd1573-54ad-4ea7-9358-37af037e7f0a"
# ) # Micro all sorts
# path = Path("./datasets")
# path.mkdir(exist_ok=True)
# make_project(
# project,
# path,
# )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment