Skip to content

Instantly share code, notes, and snippets.

@afrendeiro
Last active November 21, 2023 17:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save afrendeiro/54b7e767e45e836227e06c192061507f to your computer and use it in GitHub Desktop.
Save afrendeiro/54b7e767e45e836227e06c192061507f to your computer and use it in GitHub Desktop.
Use torch dataloaders with nuclei coordinates for training.
"""
Use dataloaders with nuclei coordinates for training.
"""
from functools import partial
import requests
import h5py
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from wsi_core import WholeSlideImage
from wsi_core.utils import Path
from wsi_core.utils import collate_features
# set seed to test reproducibility: this seeds torch's global RNG, which drives
# ConcatDataset's random choice of slide (torch.randint below) and, in the main
# process, the DataLoader's shuffling.
# NOTE(review): if num_workers > 0 is ever used, each worker derives its own
# seed from this one — confirm reproducibility still holds in that setup.
torch.manual_seed(42)
class ConcatDataset(Dataset):
    """Mix several datasets by drawing every item from a randomly chosen member.

    Unlike ``torch.utils.data.ConcatDataset`` the members are not chained end
    to end: on each access, one member dataset is picked uniformly at random
    using torch's global RNG, and index ``i`` is then looked up inside it. A
    shuffled DataLoader over this dataset therefore yields batches that mix
    items from all members.

    Args:
        *datasets: Datasets to sample from; each must support ``len()`` and
            integer indexing.

    Note:
        The length is that of the *shortest* member, so items past that index
        in longer members are never returned.
    """

    def __init__(self, *datasets):
        self.datasets = datasets
        # Number of member datasets (one per slide in the script below).
        self.n_slides = len(datasets)

    def __getitem__(self, i: int):
        # Choose the source dataset at random; ``i`` only indexes within it.
        pick = int(torch.randint(0, self.n_slides, (1,)))
        chosen = self.datasets[pick]
        return chosen[i]

    def __len__(self):
        lengths = [len(member) for member in self.datasets]
        return min(lengths)
def download_slide(slide_file: Path, overwrite: bool = False) -> None:
    """Download a whole-slide image from the NCI BRD image-download endpoint.

    The slide identifier is taken from the file name, so
    ``Path("GTEX-SNMC-0626.svs")`` fetches slide ``GTEX-SNMC-0626``.

    Args:
        slide_file: Destination path; its stem is used as the slide ID.
        overwrite: Re-download even if ``slide_file`` already exists.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if slide_file.exists() and not overwrite:
        return
    # Derive the ID from the argument rather than relying on a global variable.
    slide_id = slide_file.stem
    url = f"https://brd.nci.nih.gov/brd/imagedownload/{slide_id}"
    # Write to a sibling temporary file and move it into place only when the
    # download completed, so a failed or interrupted request never leaves a
    # truncated slide that the exists() check above would later accept.
    partial_file = slide_file.with_name(slide_file.name + ".part")
    # stream=True avoids buffering the whole (large) slide in memory.
    with requests.get(url, stream=True) as req:
        req.raise_for_status()
        with open(partial_file, "wb") as handle:
            for block in req.iter_content(chunk_size=1 << 20):
                handle.write(block)
    partial_file.replace(slide_file)
def set_nuclei_coordinates(
    s: WholeSlideImage,
    tile_width: int = 32,
    nuclei_file=None,
    save_path: str = 'data/gtex/svs',
) -> None:
    """Replace a slide's tile coordinates with tiles placed around its nuclei.

    Reads nuclei centroids from a CSV, shifts each one by ``tile_width // 2``
    to turn it into a tile corner, and writes the result to the ``coords``
    dataset of ``<slide_id>.nuclei.h5`` together with patch metadata
    attributes. ``s.hdf5_file`` is pointed at that new file.

    Args:
        s: Slide to update; its ``path.stem`` is used as the slide ID.
        tile_width: Tile size in pixels, stored as the ``patch_size`` attribute.
        nuclei_file: CSV with ``centroid_x`` and ``centroid_y`` columns.
            Defaults to ``<slide_id>_nuclei.csv`` in the working directory.
        save_path: Value stored in the ``save_path`` attribute of ``coords``.
    """
    slide_id = s.path.stem
    # Read in nuclei positions.
    # NOTE(review): the default file is HoVer-Net output; pass nuclei_file to use
    # e.g. StarDist results instead. Also double-check that x/y are not swapped
    # relative to what wsi_core expects (the usual dimension-order issue).
    if nuclei_file is None:
        nuclei_file = Path(f"{slide_id}_nuclei.csv")
    nuclei = pd.read_csv(nuclei_file)[['centroid_x', 'centroid_y']]
    # Write to h5, converting centroids to tile corners.
    s.hdf5_file = Path(f"{slide_id}.nuclei.h5")
    with h5py.File(s.hdf5_file, "w") as f:
        ds = f.create_dataset("coords", data=nuclei.values - (tile_width // 2))
        # Coordinates are at full resolution: level 0, downsample 1.
        to_add = {
            'downsample': np.array([1., 1.]),
            'downsampled_level_dim': np.array(s.wsi.dimensions),
            'level_dim': np.array(s.wsi.dimensions),
            'name': slide_id,
            'patch_level': 0,
            'patch_size': tile_width,
            'save_path': save_path,
        }
        for k, v in to_add.items():
            ds.attrs[k] = v
# Let's download a couple of slides and prepare them (set nuclear coordinates)
slide_ids = ['GTEX-SNMC-0626', 'GTEX-12ZZW-2726']
slides = list()
_coords = list()
for slide_id in slide_ids:
    slide_file = Path(f"{slide_id}.svs")
    download_slide(slide_file)
    s = WholeSlideImage(slide_file)
    set_nuclei_coordinates(s)
    slides.append(s)
    # I am going to keep track of the coordinates for each slide to make some checks later
    c = pd.DataFrame(s.get_tile_coordinates(), columns=['x', 'y']).assign(slide=slide_id)
    _coords.append(c)
coords = pd.concat(_coords)
# Number of tiles per slide (a bare expression would only display in a notebook)
print(coords.groupby('slide').size())
# slide
# GTEX-12ZZW-2726 110836
# GTEX-SNMC-0626 378617
# Check that no tile corner appears twice within a slide. This only rules out
# exact duplicates: tiles around nearby nuclei can still overlap spatially.
assert coords.groupby('slide').apply(lambda x: x.duplicated().sum()).sum() == 0
coords = coords.set_index(['x', 'y'])  # just to make it easier to index later
# Now, the only thing we need to do is to create a dataset that mixes the per-slide datasets ("ConcatDataset"):
ds = ConcatDataset(*[s.as_tile_bag() for s in slides])
# Then we create a dataloader that returns the coordinates of tiles across slides:
collate_fn = partial(collate_features, with_coords=True)  # just here so we get coordinates now to check, not needed in training
dl = DataLoader(ds, batch_size=64, shuffle=True, collate_fn=collate_fn)
for batch, batch_coords in tqdm(dl):
    # Check every batch has tiles from both slides. The batch's (x, y) pairs are
    # matched element-wise: coords.loc[xs, ys, :] would instead select every
    # combination of any batch x with any batch y, over-counting tiles.
    # NOTE(review): assumes batch_coords is (batch, 2) with x in column 0 and y in
    # column 1 (np.asarray also accepts a CPU tensor) — confirm against
    # collate_features. A corner present in both slides still matches both.
    batch_xy = np.asarray(batch_coords)
    in_batch = coords.index.isin(list(zip(batch_xy[:, 0], batch_xy[:, 1])))
    n_samples = coords.loc[in_batch].groupby('slide').size()
    assert n_samples.shape[0] == ds.n_slides
    # To test the reproducibility of setting a seed, break here (get first batch) and compare the coordinates
    # break
    # ground_truth = np.load('test_batch_coords.npy')
    # assert np.all(batch_coords == ground_truth)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment