Skip to content

Instantly share code, notes, and snippets.

@frederik-encord
Last active May 30, 2023 07:52
Show Gist options
  • Save frederik-encord/92b0951d5c596fff5a0a8a1093c11276 to your computer and use it in GitHub Desktop.
Save frederik-encord/92b0951d5c596fff5a0a8a1093c11276 to your computer and use it in GitHub Desktop.
"""
This file demonstrates how to
1. Download a project with DICOM series from Encord Annotate
2. Extract masks based on all the polygons and bounding boxes in the project
The only dependency is `encord-active`.
_Installation:_
```shell
$ python -m venv venv
$ source venv/bin/activate
(venv)$ python -m pip install encord-active==0.1.57
```
The data will be stored at the `out_directory` path and will be organised as
follows:
```
{out_directory}
├── data
│   ├── {label_hash}
│   │   ├── images
│   │   │   ├── {data_hash}_0.dicom
│   │   │   ├── ...
│   │   │   └── {data_hash}_9.dicom
│   │   ├── label_row.json
│   │   └── masks
│   │       ├── {data_hash}_0.png
│   │       ├── ...
│   │       └── {data_hash}_75.png
│   └── mask_lookup.json
├── label_row_meta.json
├── ontology.json
├── prisma.db
└── project_meta.yaml
```
That is, every directory in the `data` directory will contain one DICOM series
where the dicom files are stored in the `images` sub-directory and masks are
stored in the `masks` directory.
The masks will be PNGs with one channel. Each pixel value will correspond to the
class defined in the `data/mask_lookup.json` file. 0 means background.
Note that you can use the `--refresh` option to refresh the project masks with
new labels from the Encord Annotate platform.
Here is an example command:
```shell
(venv)$ python dicom_masks.py "/path/to/project/directory" --project-hash "<project_hash>" --ssh-key-path ~/.ssh/id_ed25519
```
On successive runs of the code, you won't have to provide the project hash and
the ssh key path. They are stored in the `project_meta.yaml` file. So you can
just do
```shell
(venv)$ python dicom_masks.py "/path/to/project/directory"
```
## Remarks:
- Encord Active does not support DICOM files at the time of writing, so we
need to apply a "monkey patch" to make downloading DICOM series possible.
"""
import json
from pathlib import Path
from typing import Optional
import encord_active.lib.project.project as ea_project
import numpy as np
import rich
import typer
from encord import EncordUserClient
from encord.objects.ontology_object import Object
from encord.objects.ontology_structure import OntologyStructure
from encord.orm.label_row import LabelRow
from encord_active.lib.common.utils import download_file, try_execute
from encord_active.lib.db.connection import PrismaConnection
from encord_active.lib.project import Project
from encord_active.lib.project.metadata import update_project_meta
from encord_active.lib.project.project_file_structure import ProjectFileStructure
from PIL import Image, ImageDraw
from rich.panel import Panel
from tqdm import tqdm
### BEGIN PATCH for dicom volumes in Encord Active Project Download
def patch_data_download(download_dicom: bool):
    """
    Monkey patch `encord_active.lib.project.project.download_data` with a
    variant that registers (and optionally downloads) DICOM slices.

    :param download_dicom: when true, every slice with a `file_uri` is
        downloaded into the label row's images directory; when false only the
        database entries are written.
    """

    def download_data(
        label_row: LabelRow, project_file_structure: ProjectFileStructure
    ):
        # Label rows that were never initialised have nowhere to be stored.
        label_hash = label_row.label_hash
        if label_hash is None:
            return

        lr_structure = project_file_structure.label_row_structure(label_hash)
        lr_structure.images_dir.mkdir(parents=True, exist_ok=True)

        def sequence_number(unit) -> int:
            return int(unit["data_sequence"])

        ordered_units = sorted(
            label_row.data_units.values(), key=sequence_number  # type: ignore
        )

        for unit in ordered_units:
            # E.g. a data type of "application/dicom" gives the suffix ".dicom".
            extension = "." + unit["data_type"].split("/")[1]

            # Add non-video type of data to the db
            with PrismaConnection(
                project_file_structure
            ) as conn, conn.batch_() as batcher:
                # NOTE(review): this leaves the whole function, not just the
                # current data unit - confirm that `return` (rather than
                # `continue`) is intended.
                if "data_links" not in unit:
                    return

                slices = tqdm(
                    unit["labels"].items(),
                    desc=f"Downloading dicom frames for {unit['data_title']}",
                    leave=False,
                )
                for frame, du_slice in slices:
                    file_stem = f"{unit['data_hash']}_{frame}"
                    destination = (lr_structure.images_dir / file_stem).with_suffix(
                        extension
                    )

                    url = du_slice.get("metadata", {}).get("file_uri")
                    if not url:
                        continue

                    if download_dicom:
                        download_kwargs = {"url": url, "destination": destination}
                        try_execute(download_file, 5, download_kwargs)

                    frame_number = int(frame)
                    location = destination.resolve().as_posix()
                    # (data_hash, frame) is the compound key of the row.
                    compound_key = {
                        "data_hash": unit["data_hash"],
                        "frame": frame_number,
                    }
                    new_row = {
                        "data_hash": unit["data_hash"],
                        "data_title": unit["data_title"],
                        "frame": frame_number,
                        "location": location,
                        "lr_data_hash": label_row.data_hash,  # type: ignore
                    }
                    changed_fields = {
                        "data_title": unit["data_title"],
                        "location": location,
                    }
                    batcher.dataunit.upsert(
                        where={"data_hash_frame": compound_key},
                        data={"create": new_row, "update": changed_fields},
                    )

    ea_project.download_data = download_data
### END PATCH
def object_to_polygon(obj: dict, img_h: int, img_w: int) -> Optional[np.ndarray]:
    """
    Convert an Encord object dictionary to absolute (pixel) polygon coordinates
    that can be used to draw the geometry onto a mask.

    Polygons and bounding boxes are supported. Coordinates in the object
    dictionary are relative to the image size and are scaled by `img_w` and
    `img_h`.

    :param obj: the encord object dict
    :param img_h: the image height
    :param img_w: the image width
    :return: An array of shape (n_points, 2) holding (x, y) coordinates, or
        `None` if the shape is unsupported or carries no coordinates.
    """
    shape = obj.get("shape")

    if shape == "polygon":
        points = obj.get("polygon")
        if not points:
            return None
        # Vertices are keyed by their index ("0", "1", ...). Order numerically
        # so that "10" comes after "2" and gaps in the numbering are tolerated.
        ordered = sorted(points.items(), key=lambda item: int(item[0]))
        polygon = np.array(
            [[point["x"] * img_w, point["y"] * img_h] for _, point in ordered]
        )
    elif shape == "bounding_box":
        box = obj.get("boundingBox")
        if not box:
            return None
        left, top = box["x"] * img_w, box["y"] * img_h
        right, bottom = (box["x"] + box["w"]) * img_w, (box["y"] + box["h"]) * img_h
        # Corners in drawing order: top-left, top-right, bottom-right, bottom-left.
        polygon = np.array(
            [[left, top], [right, top], [right, bottom], [left, bottom]]
        )
    else:
        # A similar thing could be done for, e.g., rotatable bounding boxes here.
        return None

    return polygon.reshape((-1, 2))
def hex_to_rgb(hex_value: str):
    """
    Convert a hex colour string to a list of three 0-255 integers `[r, g, b]`.

    Accepts `RRGGBB` as well as the shorthand `RGB` form, with or without a
    leading `#`. An alpha channel (`RGBA` / `RRGGBBAA`) is accepted but ignored,
    so the result always has exactly three entries - which is what a flat
    palette list requires.

    :param hex_value: the colour, e.g. `"#D33115"` or `"fff"`
    :return: The `[red, green, blue]` components.
    :raises ValueError: if the string is not a valid hex colour.
    """
    digits = hex_value.strip().lstrip("#")
    if len(digits) in (3, 4):
        # Shorthand notation: every digit stands for a doubled pair (f -> ff).
        digits = "".join(digit * 2 for digit in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex color: {hex_value!r}")
    return [int(digits[i : i + 2], 16) for i in range(0, 6, 2)]
def get_palette(
    ontology: OntologyStructure,
) -> tuple[list[int], dict]:
    """
    Constructs a color palette, which will be applied to the PNG images.

    This is basically a trick to be able to store masks like:

        0 0 0 0 0
        0 1 1 1 0
        0 2 2 2 0
        0 0 0 0 0

    as png images and have them show in image viewers with proper colors (the
    colors from the Encord Annotate editor).
    The palette defines that, e.g., pixel value 2 has rgb (100, 25, 255) or whatever.

    :param ontology: the ontology whose objects define the classes and colors
    :return: A tuple `(palette, lookup)`. `palette` is a flat list of 256
        consecutive `r, g, b` entries indexed by pixel value (unused values are
        black). `lookup` maps pixel values to the dictionary form of the
        ontology object, with 0 mapped to "background".
    :raises ValueError: if an ontology object has a uid larger than 255.
    """
    ontology_lookup: dict[int, Object] = {int(o.uid): o for o in ontology.objects}

    # `default=0` keeps an ontology without objects from crashing `max`; such
    # an ontology simply produces an all-background palette.
    if max(ontology_lookup, default=0) > 255:
        raise ValueError(
            "Because of PNG format, mask can only have objects from 255 distinct classes."
        )

    palette: list[int] = []
    for pixel_value in range(256):
        ontology_object = ontology_lookup.get(pixel_value)
        if ontology_object is None:
            palette += [0, 0, 0]
        else:
            palette += hex_to_rgb(ontology_object.color)

    lookup_json = {
        0: "background",
        **{uid: obj.to_dict() for uid, obj in ontology_lookup.items()},
    }
    return palette, lookup_json
def compute_masks(project: Project):
    """
    Iterate over all labels and store masks associated with each image/frame.

    Every mask will be a composition of all the objects in the frame.
    The `uid` of each object in `project.ontology.objects` determines the pixel
    values. That is, mask pixels with value 0 are background and an object
    whose ontology class has uid 1 is drawn with pixel value 1, etc.

    Masks will be stored in
    `path/to/out_directory/data/{label_hash}/masks/{data_hash}_{frame}.png`
    and the pixel value lookup is written to `data/mask_lookup.json`.

    Note that with the current implementation, there is no notion of object
    instances after exporting. For example, every pixel contained within some
    "lung" polygon, will get the lung value in the mask independent of which
    object it came from. Objects are drawn in the order they are listed, so
    later objects overwrite earlier ones where they overlap.

    Args:
        project: The project that contains all the (cached) data from the project.
    """
    fill_values = {o.feature_node_hash: int(o.uid) for o in project.ontology.objects}
    palette, class_lookup = get_palette(project.ontology)
    (project.file_structure.data / "mask_lookup.json").write_text(
        json.dumps(class_lookup)
    )

    for label_hash, lr_dict in project.label_rows.items():
        label_row_structure = project.file_structure.label_row_structure(label_hash)
        masks_dir = label_row_structure.path / "masks"
        masks_dir.mkdir(parents=True, exist_ok=True)

        for data_hash, data_unit in lr_dict["data_units"].items():
            width, height = data_unit.get("width"), data_unit.get("height")
            if not width or not height:
                # A mask canvas cannot be created without the image dimensions.
                rich.print(
                    f"[yellow]Skipping data unit {data_hash}: "
                    "width/height is missing.[/yellow]"
                )
                continue

            for frame, labels in data_unit["labels"].items():
                # Draw mask: a palettised image in which 0 is background.
                mask = Image.new("P", (width, height), 0)
                draw = ImageDraw.Draw(mask)

                for object_ in labels.get("objects", []):
                    polygon = object_to_polygon(object_, img_h=height, img_w=width)
                    if polygon is None:
                        continue

                    pixel_value = fill_values[object_["featureHash"]]
                    draw.polygon(
                        [tuple(point) for point in polygon.tolist()],
                        outline=pixel_value,
                        fill=pixel_value,
                    )

                mask.putpalette(palette)
                mask.save(masks_dir / f"{data_hash}_{int(frame)}.png")
def main(
    out_directory: Path = typer.Argument(
        help="Where to download all the data to.", default=Path.cwd() / "output"
    ),
    project_hash: str = typer.Option(
        help="The project hash of the project you want to download", default=None
    ),
    ssh_key_path: Path = typer.Option(
        help="Path to your ssh key associated with Encord",
        default=Path.home() / ".ssh" / "id_ed25519",
    ),
    refresh: bool = typer.Option(
        help="Whether to refresh labels for an existing project", default=False
    ),
    download_dicom: bool = typer.Option(
        help="If true, dicom files will also be downloaded", default=False
    ),
):
    """
    Download (or refresh) an Encord project with DICOM series and export masks.

    On a first import, `--project-hash` and a valid `--ssh-key-path` are
    needed. When either is not given (or the key file does not exist), the
    value found in the project meta of `out_directory` is used instead.
    The command aborts if neither source provides a value.
    """
    patch_data_download(download_dicom)

    project = Project(out_directory)
    project_meta = project.project_meta

    if refresh:
        project.refresh().load()
    else:
        if project_hash is None:
            if not project_meta["project_hash"]:
                rich.print(
                    Panel(
                        "\nThe `--project-hash` argument was not defined. Cannot import project without it.\n",
                        title="⛔ Missing project hash ⛔ ",
                        style="yellow",
                    )
                )
                raise typer.Abort()
            project_hash = project_meta["project_hash"]

        # Expand `~` before testing for existence, otherwise a valid
        # "~/.ssh/..." path would be reported as missing.
        if ssh_key_path is None or not ssh_key_path.expanduser().exists():
            if not project_meta["ssh_key_path"]:
                rich.print(
                    Panel(
                        "\nThe `--ssh-key-path` argument was not defined or does not exist. Cannot import project without it.\n",
                        title="⛔ Missing ssh key file ⛔ ",
                        style="yellow",
                    )
                )
                raise typer.Abort()
            ssh_key_path = Path(project_meta["ssh_key_path"])
        ssh_key_path = ssh_key_path.expanduser().absolute()

        # Record the key path *before* the meta is written so that it is
        # available on successive runs.
        project_meta["ssh_key_path"] = ssh_key_path.as_posix()
        update_project_meta(project.file_structure.project_dir, project_meta)

        # The client is created only now that the key path is resolved;
        # reading the key any earlier would crash before the fallback to the
        # stored path could kick in.
        client = EncordUserClient.create_with_ssh_private_key(
            ssh_key_path.read_text(),
        )
        encord_project = client.get_project(project_hash=project_hash)
        project = project.from_encord_project(encord_project).load()

    compute_masks(project)
# Script entry point: let typer build the command line interface from the
# signature of `main` and run it.
if __name__ == "__main__":
    typer.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment