Skip to content

Instantly share code, notes, and snippets.

@kaczmarj
Created October 22, 2022 18:07
Show Gist options
  • Save kaczmarj/6328b65629c22d382012571fe32bc19a to your computer and use it in GitHub Desktop.
Save kaczmarj/6328b65629c22d382012571fe32bc19a to your computer and use it in GitHub Desktop.
Example torch dataset for loading whole slide image patches on the fly.
"""Example torch Dataset for loading whole slide image patches on the fly."""
from pathlib import Path
import large_image
import pandas as pd
import torch
class PatchDataset(torch.utils.data.Dataset):
    """Dataset to load patches from a whole slide image on the fly.

    Every item is read from the slide with ``large_image`` at the moment it is
    indexed, rescaled to the requested spacing.

    Parameters
    ----------
    slide : path
        Path to whole slide image.
    csv_of_coordinates : path
        Path to a CSV with the patch coordinates. The CSV must have columns minx,
        miny, width, and height (any additional columns are ignored). All values
        should be in base-resolution pixels.
    um_px : float
        Spacing of the patches (micrometers per pixel).
    patch_px : int
        Expected width (and height) of the patches in pixels, after rescaling to
        ``um_px``.
    transform : callable
        Transform function for the patch. If None, the patch array is returned
        as read.

    Raises
    ------
    ValueError
        If the CSV is missing any of the required columns.

    Notes
    -----
    A patch is accepted when each side is within one pixel of ``patch_px``, so the
    returned arrays are not guaranteed to all have exactly the same shape. If the
    items are going to be stacked into batches, ``transform`` should resize or
    crop them to a fixed size.
    """

    # Columns that __getitem__ reads from every row of the coordinates CSV.
    _REQUIRED_COLUMNS = frozenset({"minx", "miny", "width", "height"})

    def __init__(
        self,
        slide,
        csv_of_coordinates,
        um_px=0.5,  # approximately 20x
        patch_px=100,
        transform=None,
    ):
        super().__init__()
        self.slide = Path(slide)
        self.csv_of_coordinates = Path(csv_of_coordinates)
        self.um_px = um_px
        self.patch_px = patch_px
        self.transform = transform

        self.coords = pd.read_csv(self.csv_of_coordinates)
        # Validate with a real exception instead of `assert`: assertions are
        # removed under `python -O`, which would defer the failure to a KeyError
        # in __getitem__. Extra columns are harmless, so only require a superset.
        missing = self._REQUIRED_COLUMNS.difference(self.coords.columns)
        if missing:
            raise ValueError(
                f"{self.csv_of_coordinates} is missing required column(s): "
                f"{sorted(missing)}"
            )

        self.ts = large_image.getTileSource(self.slide)
        # DANGER: this is a ~*~hack~*~. large-image automatically caches tiles and
        # there does not seem to be a supported API to disable it. If the cache is
        # not disabled, then memory usage increases over time. The assignment
        # reaches the name-mangled private attribute `__maxsize` of the cache
        # object.
        self.ts.cache._Cache__maxsize = 0

    def __len__(self):
        """Return the number of patches (one per row of the coordinates CSV)."""
        return self.coords.shape[0]

    def __getitem__(self, idx):
        """Read the patch described by row ``idx`` of the coordinates table.

        Parameters
        ----------
        idx : int
            Positional row index into the coordinates table.

        Returns
        -------
        The patch with at most three channels, or the output of ``transform``
        applied to it when a transform was given.

        Raises
        ------
        ValueError
            If the height or width of the patch differs from ``patch_px`` by more
            than one pixel.
        """
        coord_row = self.coords.iloc[idx, :]
        source_region = dict(
            left=coord_row["minx"],
            top=coord_row["miny"],
            width=coord_row["width"],
            height=coord_row["height"],
            units="base_pixels",
        )
        target_scale = dict(mm_x=self.um_px / 1000)  # convert um/px to mm/px.
        tissue_arr, _ = self.ts.getRegionAtAnotherScale(
            format=large_image.tilesource.TILE_FORMAT_NUMPY,
            sourceRegion=source_region,
            targetScale=target_scale,
        )
        # Remove the alpha channel first and copy afterwards: the copy makes sure
        # the array is writable, and doing it after the slice avoids copying the
        # discarded channel and leaves a contiguous array rather than a view.
        tissue_arr = tissue_arr[:, :, :3].copy()
        if (
            abs(tissue_arr.shape[0] - self.patch_px) > 1
            or abs(tissue_arr.shape[1] - self.patch_px) > 1
        ):
            # If the coordinates are at the edge of the whole slide image, it is
            # possible that the patch will not have the expected size.
            # We also allow a patch to be within one pixel of the expected size.
            raise ValueError(f"expected {self.patch_px} but got {tissue_arr.shape}")
        if self.transform is not None:
            tissue_arr = self.transform(tissue_arr)
        return tissue_arr
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment