Hierarchically generate and store patches from WSIs.
"""This code is for hierarchically patching large whole slide images.
Copyright (c) 2022 Tianyi Wang & Mengkang Lu
All rights reserved.
Released under MIT License.
"""
import gc
import json
import os
import time
import typing
import cv2
import h5py
import numpy as np
import skimage
import tifffile as tiff
import torch
from models.resnet_custom import resnet50_baseline
class Config:
"""Parameters used for patching.
Attributes:
input_dir (str): Input directory.
output_dir (str): Output directory.
        largest_scale (int): Objective magnification of the largest (base) pyramid level, e.g. 20 for 20x.
        pyramid_level (int): Number of scale levels in the wsi pyramid.
        relevant_scale (int): Relative scale factor between neighbouring levels in the wsi pyramid.
        patch_size (int): Size of patches.
        pad_value_wsi (int): Border pad value for the wsi.
        pad_value_thresh (int): Border pad value for the threshed wsi.
        thresh_method (str): Threshold method, either 'otsu' or 'adaptive'.
        area_threshold (int): Area_threshold argument for the remove_small_holes function.
        min_size (int): Min_size argument for the remove_small_objects function.
        connectivity (int): Connectivity argument for remove_small_holes and remove_small_objects
            (in practice skimage treats any value above 2 as full 2-D connectivity).
        no_thresh_images (list[str]): Wsi file names to skip thresholding for (treated as all tissue).
        device_type (str): Device used for feature-extraction inference, 'cuda' or 'cpu'.
        save (bool): Whether to save algorithm results.
        compression_opts_level (int): Gzip compression level for h5py datasets.
"""
# input_dir = "../../TCGA_LUNG/LUAD"
# output_dir = "../../TCGA_LUNG_processed/LUAD"
input_dir = "../../TCGA_LUNG/LUSC"
output_dir = "../../TCGA_LUNG_processed/LUSC"
largest_scale = 20
pyramid_level = 4
relevant_scale = 2
patch_size = 256
pad_value_wsi = 255
pad_value_thresh = 0
thresh_method = 'otsu' # or 'adaptive'
area_threshold = 16384
min_size = 16384
connectivity = 8
no_thresh_images = ["TCGA-05-4415-01Z-00-DX1.55E0C429-B308-4962-8DA9-41D7D3F7764E.svs",
"TCGA-05-4418-01Z-00-DX1.f3863ea5-564f-482f-9878-cc104cf69401.svs",
"TCGA-05-4398-01Z-00-DX1.269bc75f-492e-48b1-87ee-85924aa80e74.svs"]
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
save = True
compression_opts_level = 5
def create_none_exist_folder(path: str) -> None:
    """Create a folder (and any missing parents) if it does not exist.
    Args:
        path (str): Folder path.
    """
    os.makedirs(path, exist_ok=True)
def save_config(config: type[Config], path: str) -> dict[str, typing.Any]:
    """Saves a config as a json file.
    Args:
        config (type[Config]): The Config class; attributes are read from its class dict.
        path (str): Path to save at.
    Returns:
        dict: Config as a dictionary.
    """
    dic = {key: value for key, value in config.__dict__.items() if not key.startswith("__")}
    with open(path, "w") as f:
        json.dump(dic, f)
    return dic
def resize_wsi(wsi: np.ndarray, scale: int) -> np.ndarray:
    """Downscale the wsi by the given factor.
    Args:
        wsi (np.ndarray): Wsi to be resized.
        scale (int): Downscale factor.
    Returns:
        np.ndarray: Resized wsi.
    """
    dsize = (wsi.shape[1] // scale, wsi.shape[0] // scale)
    wsi = cv2.resize(wsi, dsize=dsize, interpolation=cv2.INTER_AREA)
    return wsi
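# A minimal usage sketch (illustrative only; a dummy array stands in for a real wsi):
#     dummy = np.zeros((512, 1024, 3), dtype=np.uint8)
#     resize_wsi(dummy, 2).shape  # -> (256, 512, 3)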
def pad_wsi(config: Config, wsi: np.ndarray, pad_value: int) -> np.ndarray:
    """Pad the wsi so that every pyramid level divides evenly into patches.
    Args:
        config (Config): Configurations.
        wsi (np.ndarray): Wsi to be padded.
        pad_value (int): Padding value.
    Returns:
        np.ndarray: Padded wsi.
    """
    scaled_shape = wsi.shape
    # Each side must be divisible by patch_size at every level, i.e. by
    # patch_size * relevant_scale ** (pyramid_level - 1) at the base level.
    pad_size = config.patch_size * (config.relevant_scale ** (config.pyramid_level - 1))
    pad0 = (-scaled_shape[0]) % pad_size
    pad1 = (-scaled_shape[1]) % pad_size
    if len(scaled_shape) == 3:
        wsi = np.pad(wsi, [[pad0 // 2, pad0 - pad0 // 2], [pad1 // 2, pad1 - pad1 // 2], [0, 0]],
                     constant_values=pad_value)
    elif len(scaled_shape) == 2:
        wsi = np.pad(wsi, [[pad0 // 2, pad0 - pad0 // 2], [pad1 // 2, pad1 - pad1 // 2]],
                     constant_values=pad_value)
    return wsi
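# Padding contract sketch under the default config (patch_size=256, relevant_scale=2,
# pyramid_level=4, hence pad_size = 256 * 2 ** 3 = 2048); the dummy shape is illustrative:
#     dummy = np.zeros((3000, 5000, 3), dtype=np.uint8)
#     pad_wsi(Config(), dummy, 255).shape  # -> (4096, 6144, 3)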
def thresh_wsi(config: Config, wsi: np.ndarray) -> np.ndarray:
"""Apply thresholding to the wsi.
Args:
config (Config): Configurations.
wsi (np.ndarray): Wsi to be threshed.
Returns:
np.ndarray: Threshed wsi.
"""
gray_scaled_wsi = cv2.cvtColor(wsi, cv2.COLOR_RGB2GRAY)
blured_scaled_wsi = cv2.medianBlur(gray_scaled_wsi, 3)
if config.thresh_method == "adaptive":
threshed_wsi = cv2.adaptiveThreshold(blured_scaled_wsi, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
cv2.THRESH_BINARY_INV, 21, 8)
elif config.thresh_method == "otsu":
_, threshed_wsi = cv2.threshold(blured_scaled_wsi, 0, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY_INV)
    else:
        raise ValueError(f"No thresh method named {config.thresh_method}")
threshed_wsi = skimage.morphology.remove_small_holes(threshed_wsi > 0, area_threshold=config.area_threshold,
connectivity=config.connectivity)
threshed_wsi = skimage.morphology.remove_small_objects(threshed_wsi, min_size=config.min_size,
connectivity=config.connectivity)
return threshed_wsi.astype(np.uint8) * 255
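# Usage sketch (assumes an RGB uint8 array named wsi_rgb; tissue maps to 255 and
# background to 0 because of THRESH_BINARY_INV):
#     mask = thresh_wsi(Config(), wsi_rgb)
#     tissue_fraction = (mask > 0).mean()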
def gen_patch(wsi: np.ndarray, patch_size: int) -> np.ndarray:
    """Generate non-overlapping patches of the given size from the wsi.
    Args:
        wsi (np.ndarray): Wsi to be processed.
        patch_size (int): Size of a patch.
    Returns:
        np.ndarray: Patches of the given size, in row-major order.
    """
shape = wsi.shape
if len(shape) == 2:
patches = wsi.reshape(shape[0] // patch_size, patch_size,
shape[1] // patch_size, patch_size)
patches = patches.transpose(0, 2, 1, 3)
patches = patches.reshape(-1, patch_size, patch_size)
elif len(shape) == 3:
patches = wsi.reshape(shape[0] // patch_size, patch_size,
shape[1] // patch_size, patch_size, 3)
patches = patches.transpose(0, 2, 1, 3, 4)
patches = patches.reshape(-1, patch_size, patch_size, 3)
return patches
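# Tiling sketch: the reshape/transpose pair cuts the image into a row-major grid
# of non-overlapping tiles. For a toy 4x4 array and patch_size=2:
#     toy = np.arange(16).reshape(4, 4)
#     gen_patch(toy, 2).shape  # -> (4, 2, 2), with patch 0 equal to toy[:2, :2]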
def gen_features(config: Config, patches: np.ndarray, selected_idx: list[int]) -> np.ndarray:
    """Generate features from patches.
    Args:
        config (Config): Configurations.
        patches (np.ndarray): Patches generated from wsi.
        selected_idx (list[int]): List of selected indices.
    Returns:
        np.ndarray: Features generated from patches, one row per selected index.
    """
    device = torch.device(config.device_type)
    model = resnet50_baseline(pretrained=True).to(device)
    model.eval()
    features = list()
    with torch.no_grad():
        for idx in selected_idx:
            patch = patches[idx].transpose(2, 0, 1)  # HWC -> CHW
            patch = np.expand_dims(patch, axis=0)
            patch = torch.from_numpy(patch).float().to(device)
            feature = model(patch)
            features.append(feature.cpu().numpy())
    # Mirrors the original behaviour of returning None when nothing was selected.
    return np.concatenate(features, axis=0) if features else None
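# Note: the backbone is reloaded on every call, which keeps the function
# self-contained at the cost of speed. Usage sketch (patches and selected_idx as
# produced above; the feature width depends on resnet50_baseline):
#     features = gen_features(Config(), patches, selected_idx)
#     features.shape  # -> (len(selected_idx), feature_dim)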
def select_patch(thresh_patches: np.ndarray) -> list[int]:
    """Select patches that contain tissue, i.e. any foreground in the threshed wsi.
    Args:
        thresh_patches (np.ndarray): Patches from the threshed wsi.
    Returns:
        list[int]: Sorted list of selected indices.
    """
selected_idx = list()
for idx, thresh_patch in enumerate(thresh_patches):
if thresh_patch.sum() > 0:
selected_idx.append(idx)
return selected_idx
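# Selection sketch: with a mask that is empty except inside one tile, only that
# tile's index survives. For a 512x512 mask split into a 2x2 grid of 256-pixel tiles:
#     mask = np.zeros((512, 512), dtype=np.uint8)
#     mask[300, 300] = 255
#     select_patch(gen_patch(mask, 256))  # -> [3]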
def gen_index(config: Config, global_index_mat: typing.Optional[np.ndarray], selected_idx: list[int],
              num_patches: list[int], patches_per_row: list[int], level: int) -> np.ndarray:
    """Generate index for selected patches.
    Args:
        config (Config): Configurations.
        global_index_mat (np.ndarray or None): Previously built indices, or None if there are none yet.
        selected_idx (list[int]): Indices of selected patches.
        num_patches (list[int]): Number of patches at each level.
        patches_per_row (list[int]): Number of patches in a row at each level.
        level (int): Current level.
    Returns:
        np.ndarray: An array of all known indices; each row runs from the current level down to level 0.
    """
    local_index_mat = None
    if level == 0:
        local_index_mat = np.expand_dims(np.asarray(selected_idx), axis=0).T
    else:
        # Map every parent patch at this level to its relevant_scale ** 2 children
        # at the previous (finer) level.
        local_index_tree = dict()
        index_mat = np.arange(0, num_patches[level - 1]).reshape(patches_per_row[level - 1], -1)
        shape = index_mat.shape
        index_patches = index_mat.reshape(shape[0] // config.relevant_scale, config.relevant_scale,
                                          shape[1] // config.relevant_scale, config.relevant_scale)
        index_patches = index_patches.transpose(0, 2, 1, 3)
        index_patches = index_patches.reshape(-1, config.relevant_scale, config.relevant_scale)
        for father_index, index_patch in enumerate(index_patches):
            local_index_tree[father_index] = list(index_patch.flatten())
        for key, val in local_index_tree.items():
            if key in selected_idx:
                for index in val:
                    rows = np.where(global_index_mat[:, 0] == index)[0]
                    for row in rows:
                        row = global_index_mat[row, :].tolist()
                        row.insert(0, key)
                        if local_index_mat is None:
                            local_index_mat = np.asarray(row)
                        else:
                            local_index_mat = np.vstack([local_index_mat, row])
    return local_index_mat
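# Index layout sketch: each row of the returned matrix reads
# [idx_current_level, idx_current_level - 1, ..., idx_level_0]. A selected parent
# patch covers a relevant_scale x relevant_scale block of children at the next
# finer level, and contributes one row per child that was itself selected there.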
def convert_to_relevant_index(global_index_mat: np.ndarray, patches_per_row: list[int]) -> np.ndarray:
    """Convert the global (flat) index matrix to per-level (column, row) coordinates.
    Args:
        global_index_mat (np.ndarray): The global index matrix.
        patches_per_row (list[int]): Number of patches in a row at each level.
    Returns:
        np.ndarray: The relevant index matrix, with two coordinate columns per level.
    """
    relevant_global_index_mat = None
    # Rows of global_index_mat run from the coarsest level down to level 0, so walk
    # patches_per_row in reverse; use a copy to avoid mutating the caller's list.
    patches_per_row = list(reversed(patches_per_row))
    for row in global_index_mat:
        row_arr = list()
        for level, index in enumerate(row):
            coords = (index % patches_per_row[level], index // patches_per_row[level])
            row_arr.append(coords[0])
            row_arr.append(coords[1])
        if relevant_global_index_mat is None:
            relevant_global_index_mat = np.asarray(row_arr)
        else:
            relevant_global_index_mat = np.vstack([relevant_global_index_mat, row_arr])
    return relevant_global_index_mat
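# Coordinate sketch, mirroring the modulo/floor-division above: a flat index i on
# a level with ppr patches per row maps to (i % ppr, i // ppr), so flat index 5
# with ppr = 4 becomes (1, 1).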
def save_current_level_as_h5(config: Config, patches: np.ndarray, features: np.ndarray, selected_idx: list[int],
                             num_patches_per_row: int, wsi_name: str, scale_factor: float) -> None:
    """Saves processing results from the current level as h5 files.
    Args:
        config (Config): Configurations.
        patches (np.ndarray): The patches of the current level.
        features (np.ndarray): The features of the current level.
        selected_idx (list[int]): Indices of selected patches.
        num_patches_per_row (int): The number of patches per row at the current level.
        wsi_name (str): Name of the wsi.
        scale_factor (float): The absolute scale factor of the wsi.
    """
selected_coords = None
patch_h5_file_name = os.path.join(config.output_dir, wsi_name.split(".")[0], "patches", f"{scale_factor}x.h5")
feature_h5_file_name = os.path.join(config.output_dir, wsi_name.split(".")[0], "features", f"{scale_factor}x.h5")
with h5py.File(patch_h5_file_name, 'w') as hf:
with h5py.File(feature_h5_file_name, 'w') as feature_hf:
patch_index_group = hf.create_group("Indices")
patch_group = hf.create_group("Patches")
feature_index_group = feature_hf.create_group("Indices")
feature_group = feature_hf.create_group("Feature")
            for idx, patch in enumerate(patches):
if idx in selected_idx:
coords = (idx % num_patches_per_row, idx // num_patches_per_row)
patch_dataset = patch_group.create_dataset(f"{scale_factor}x_{coords[0]}_{coords[1]}", data=patch,
compression='gzip',
compression_opts=config.compression_opts_level)
patch_dataset.attrs["wsi_name"] = wsi_name
patch_dataset.attrs["scale_factor"] = scale_factor
patch_dataset.attrs["id"] = idx
patch_dataset.attrs["coords"] = coords
np_coords = np.asarray([idx, coords[0], coords[1]], dtype=np.uintc)
if selected_coords is None:
selected_coords = np_coords
else:
selected_coords = np.vstack([selected_coords, np_coords])
for i, feature in enumerate(features):
idx = selected_idx[i]
coords = (idx % num_patches_per_row, idx // num_patches_per_row)
feature_dataset = feature_group.create_dataset(f"{scale_factor}x_{coords[0]}_{coords[1]}", data=feature,
compression='gzip',
compression_opts=config.compression_opts_level)
feature_dataset.attrs["wsi_name"] = wsi_name
feature_dataset.attrs["scale_factor"] = scale_factor
feature_dataset.attrs["id"] = idx
feature_dataset.attrs["coords"] = coords
patch_index_dataset = patch_index_group.create_dataset("local_index", data=selected_coords,
compression='gzip',
compression_opts=config.compression_opts_level)
patch_index_dataset.attrs["wsi_name"] = wsi_name
patch_index_dataset.attrs["scale_factor"] = scale_factor
feature_index_dataset = feature_index_group.create_dataset("local_index", data=selected_coords,
compression='gzip',
compression_opts=config.compression_opts_level)
feature_index_dataset.attrs["wsi_name"] = wsi_name
feature_index_dataset.attrs["scale_factor"] = scale_factor
global_index_h5_file_name = os.path.join(config.output_dir, wsi_name.split(".")[0], "global_index.h5")
with h5py.File(global_index_h5_file_name, 'a') as hf:
group = hf.get("Indices")
dataset = group.create_dataset(f"{scale_factor}x_index", data=selected_coords, compression='gzip',
compression_opts=config.compression_opts_level)
dataset.attrs["wsi_name"] = wsi_name
dataset.attrs["scale_factor"] = scale_factor
def save_current_level_as_jpeg_pt(config: Config, patches: np.ndarray, features: np.ndarray, selected_idx: list[int],
                                  num_patches_per_row: int, wsi_name: str, scale_factor: float) -> None:
    """Saves processing results from the current level as jpeg and pt files.
    Args:
        config (Config): Configurations.
        patches (np.ndarray): The patches of the current level.
        features (np.ndarray): The features of the current level.
        selected_idx (list[int]): Indices of selected patches.
        num_patches_per_row (int): The number of patches per row at the current level.
        wsi_name (str): Name of the wsi.
        scale_factor (float): The absolute scale factor of the wsi.
    """
selected_coords = None
patch_folder = os.path.join(config.output_dir, wsi_name.split(".")[0], "patches", f"{scale_factor}x")
feature_folder = os.path.join(config.output_dir, wsi_name.split(".")[0], "features", f"{scale_factor}x")
create_none_exist_folder(patch_folder)
create_none_exist_folder(feature_folder)
    for idx, patch in enumerate(patches):
if idx in selected_idx:
coords = (idx % num_patches_per_row, idx // num_patches_per_row)
cv2.imwrite(os.path.join(patch_folder, f"{coords[0]}_{coords[1]}.jpeg"),
cv2.cvtColor(patch, cv2.COLOR_RGB2BGR))
np_coords = np.asarray([idx, coords[0], coords[1]], dtype=np.uintc)
if selected_coords is None:
selected_coords = np_coords
else:
selected_coords = np.vstack([selected_coords, np_coords])
for i, feature in enumerate(features):
idx = selected_idx[i]
coords = (idx % num_patches_per_row, idx // num_patches_per_row)
torch.save(torch.from_numpy(feature), os.path.join(feature_folder, f"{coords[0]}_{coords[1]}.pt"))
global_index_h5_file_name = os.path.join(config.output_dir, wsi_name.split(".")[0], "global_index.h5")
with h5py.File(global_index_h5_file_name, 'a') as hf:
group = hf.get("Indices")
dataset = group.create_dataset(f"{scale_factor}x_index", data=selected_coords, compression='gzip',
compression_opts=config.compression_opts_level)
dataset.attrs["wsi_name"] = wsi_name
dataset.attrs["scale_factor"] = scale_factor
def create_global_index_h5_file(config: Config, wsi_name: str) -> None:
"""Create an hdf5 file for storing the global index.
Args:
config (Config): Configurations.
wsi_name (str): Name of the wsi.
"""
global_index_h5_file_name = os.path.join(config.output_dir, wsi_name.split(".")[0], "global_index.h5")
with h5py.File(global_index_h5_file_name, 'w') as hf:
hf.create_group("Indices")
def save_index(config: Config, global_index_mat: np.ndarray, relevant_global_index_mat: np.ndarray,
wsi_name: str) -> None:
"""Save the relevant and absolute indices.
Args:
config (Config): Configurations.
global_index_mat (np.ndarray): Absolute index mat.
relevant_global_index_mat (np.ndarray): Relevant index mat.
wsi_name (str): Name of the wsi.
"""
global_index_h5_file_name = os.path.join(config.output_dir, wsi_name.split(".")[0], "global_index.h5")
with h5py.File(global_index_h5_file_name, 'a') as hf:
group = hf.get("Indices")
dataset = group.create_dataset(f"global_index", data=global_index_mat, compression='gzip',
compression_opts=config.compression_opts_level)
dataset.attrs["wsi_name"] = wsi_name
dataset = group.create_dataset(f"relevant_global_index", data=relevant_global_index_mat, compression='gzip',
compression_opts=config.compression_opts_level)
dataset.attrs["wsi_name"] = wsi_name
if __name__ == "__main__":
config = Config()
create_none_exist_folder(config.output_dir)
save_config(Config, os.path.join(config.output_dir, "config.json"))
wsi_names = [file_name for file_name in os.listdir(config.input_dir)
if os.path.isfile(os.path.join(config.input_dir, file_name))]
if len(wsi_names) == 0:
raise FileNotFoundError(f"No file found under {config.input_dir}")
print(f" -> Processing {len(wsi_names)} WSI(s)...")
print(f" -> File names: {wsi_names}")
start_time = time.time()
for index, wsi_name in enumerate(wsi_names):
        if not os.path.exists(os.path.join(config.output_dir, wsi_name.split(".")[0])):
print(f"\n - Processing on {wsi_name} [{index + 1}/{len(wsi_names)}]")
create_none_exist_folder(os.path.join(config.output_dir, wsi_name.split(".")[0]))
create_none_exist_folder(os.path.join(config.output_dir, wsi_name.split(".")[0], "patches"))
create_none_exist_folder(os.path.join(config.output_dir, wsi_name.split(".")[0], "features"))
if config.save:
create_global_index_h5_file(config, wsi_name)
global_index_mat = None
num_patches = list()
patches_per_row = list()
for level in range(config.pyramid_level):
scale_factor = config.largest_scale / (config.relevant_scale ** level)
print(f" - Scale {scale_factor}x")
if level == 0:
                    wsi = tiff.imread(os.path.join(config.input_dir, wsi_name))
                    # Assumes the slides were scanned at 40x objective magnification.
                    wsi = resize_wsi(wsi, 40 // config.largest_scale)
wsi = pad_wsi(config, wsi, config.pad_value_wsi)
if wsi_name not in config.no_thresh_images:
threshed_wsi = thresh_wsi(config, wsi)
else:
threshed_wsi = np.ones(shape=(wsi.shape[0], wsi.shape[1]), dtype=np.uint8) * 255
else:
wsi = resize_wsi(wsi, config.relevant_scale)
threshed_wsi = resize_wsi(threshed_wsi, config.relevant_scale)
patches = gen_patch(wsi, config.patch_size)
thresh_patches = gen_patch(threshed_wsi, config.patch_size)
selected_idx = select_patch(thresh_patches)
features = gen_features(config, patches, selected_idx)
num_patches.append(len(thresh_patches))
patches_per_row.append(wsi.shape[0] // config.patch_size)
                global_index_mat = gen_index(config, global_index_mat, selected_idx, num_patches,
                                             patches_per_row, level)
if config.save:
                    save_current_level_as_jpeg_pt(config, patches, features, selected_idx, patches_per_row[-1],
                                                  wsi_name, scale_factor)
gc.collect()
relevant_global_index_mat = convert_to_relevant_index(global_index_mat, patches_per_row)
if config.save:
save_index(config, global_index_mat, relevant_global_index_mat, wsi_name)
gc.collect()
else:
print(f"\n - Processed {wsi_name} [{index + 1}/{len(wsi_names)}]")
print(" - Skip")
print(f"\nComplete in {time.time() - start_time}s.")