Created
October 22, 2022 18:07
-
-
Save kaczmarj/6328b65629c22d382012571fe32bc19a to your computer and use it in GitHub Desktop.
Example torch dataset for loading whole slide image patches on the fly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Example torch Dataset for loading whole slide image patches on the fly.""" | |
from pathlib import Path | |
import large_image | |
import pandas as pd | |
import torch | |
class PatchDataset(torch.utils.data.Dataset):
    """Dataset to load patches from a whole slide image on the fly.

    Each item is one region of the slide, read at the requested physical
    spacing and returned as an RGB array (or whatever ``transform`` returns).

    Parameters
    ----------
    slide : path
        Path to whole slide image.
    csv_of_coordinates : path
        Path to a CSV with the patch coordinates. The CSV must have columns minx, miny,
        width, and height. All values should be in base-resolution pixels. Additional
        columns are allowed and ignored.
    um_px : float
        Spacing of the patches (micrometers per pixel).
    patch_px : int
        Width (and height) of the patches in pixels.
    transform : callable
        Transform function for the patch.

    Raises
    ------
    ValueError
        If the CSV is missing any of the required columns.
    """

    # Columns that must be present in the coordinates CSV.
    _REQUIRED_COLUMNS = ("minx", "miny", "width", "height")

    def __init__(
        self,
        slide,
        csv_of_coordinates,
        um_px=0.5,  # approximately 20x
        patch_px=100,
        transform=None,
    ):
        self.slide = Path(slide)
        self.csv_of_coordinates = Path(csv_of_coordinates)
        self.um_px = um_px
        self.patch_px = patch_px
        self.transform = transform

        self.coords = pd.read_csv(self.csv_of_coordinates)
        # Validate with an explicit exception instead of `assert`, which is
        # stripped when Python runs with -O and would let a bad CSV through.
        missing = [
            col for col in self._REQUIRED_COLUMNS if col not in self.coords.columns
        ]
        if missing:
            raise ValueError(
                f"CSV {self.csv_of_coordinates} is missing required column(s): "
                f"{missing}"
            )

        self.ts = large_image.getTileSource(self.slide)
        # DANGER: this is a ~*~hack~*~. large-image automatically caches tiles and I
        # cannot find a supported API to disable it. If the cache is not disabled, then
        # memory usage increases over time.
        self.ts.cache._Cache__maxsize = 0

    def __len__(self):
        """Return the number of patches (one per row of the coordinates CSV)."""
        return self.coords.shape[0]

    def __getitem__(self, idx):
        """Read the patch at row ``idx`` of the coordinates table.

        Returns
        -------
        The patch with the alpha channel removed, passed through ``transform``
        if one was given.

        Raises
        ------
        ValueError
            If either spatial dimension of the patch differs from ``patch_px``
            by more than one pixel.
        """
        coord_row = self.coords.iloc[idx, :]
        source_region = dict(
            left=coord_row["minx"],
            top=coord_row["miny"],
            width=coord_row["width"],
            height=coord_row["height"],
            units="base_pixels",
        )
        target_scale = dict(mm_x=self.um_px / 1000)  # convert to mm/px.
        tissue_arr, _ = self.ts.getRegionAtAnotherScale(
            format=large_image.tilesource.TILE_FORMAT_NUMPY,
            sourceRegion=source_region,
            targetScale=target_scale,
        )
        # Remove the alpha channel first, then copy: the result is writable and
        # contiguous, and does not keep the original 4-channel buffer alive.
        tissue_arr = tissue_arr[:, :, :3].copy()

        rows, cols = tissue_arr.shape[:2]
        if abs(rows - self.patch_px) > 1 or abs(cols - self.patch_px) > 1:
            # If the coordinates are at the edge of the whole slide image, it is
            # possible that the patch will not have the expected size.
            # We also allow a patch to be within one pixel of the expected size.
            raise ValueError(f"expected {self.patch_px} but got {tissue_arr.shape}")

        if self.transform is not None:
            tissue_arr = self.transform(tissue_arr)
        return tissue_arr
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment