Last active
February 7, 2023 04:50
-
-
Save TianyiFranklinWang/2c4cc0103edb30c15790bb462351e3c4 to your computer and use it in GitHub Desktop.
Hierarchically generate and store patches from WSIs.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""This code is for hierarchically patching large whole slide images. | |
Copyright (c) 2022 Tianyi Wang & Mengkang Lu | |
All rights reserved. | |
Released under MIT License. | |
""" | |
import gc | |
import json | |
import os | |
import time | |
import typing | |
import cv2 | |
import h5py | |
import numpy as np | |
import skimage | |
import tifffile as tiff | |
import torch | |
from models.resnet_custom import resnet50_baseline | |
class Config:
    """Parameters used for patching.

    All settings are class attributes. ``save_config`` serialises the class
    ``__dict__``, so the declaration order below is also the key order in
    ``config.json``.

    Attributes:
        input_dir (str): Directory holding the input wsi files.
        output_dir (str): Directory that receives one output folder per wsi.
        largest_scale (int): Largest scale (magnification) level in wsi, used for pyramid level 0.
        pyramid_level (int): Number of levels in the wsi pyramid.
        relevant_scale (int): Scale factor between neighbouring levels of the pyramid.
        patch_size (int): Side length of a square patch in pixels.
        pad_value_wsi (int): Constant used to pad the wsi border.
        pad_value_thresh (int): Constant meant for padding the threshed wsi
            (currently not read by any function in this script).
        thresh_method (str): Threshold method, either 'otsu' or 'adaptive'.
        area_threshold (int): ``area_threshold`` argument for skimage's remove_small_holes.
        min_size (int): ``min_size`` argument for skimage's remove_small_objects.
        connectivity (int): ``connectivity`` argument for remove_small_holes and remove_small_objects.
        no_thresh_images (list[str]): Wsi file names for which thresholding is
            skipped and the whole image is treated as foreground.
        device_type (str): 'cuda' when torch reports an available GPU, otherwise 'cpu'.
        save (bool): Whether to save algorithm results.
        compression_opts_level (int): gzip level used for the HDF5 datasets.
    """

    # To process the LUAD subset instead, use:
    #   input_dir = "../../TCGA_LUNG/LUAD"
    #   output_dir = "../../TCGA_LUNG_processed/LUAD"
    input_dir = "../../TCGA_LUNG/LUSC"
    output_dir = "../../TCGA_LUNG_processed/LUSC"

    # Pyramid geometry.
    largest_scale = 20
    pyramid_level = 4
    relevant_scale = 2
    patch_size = 2 ** 8

    # Border padding values.
    pad_value_wsi = 255
    pad_value_thresh = 0

    # Tissue-mask thresholding and clean-up.
    thresh_method = "otsu"  # alternatively "adaptive"
    area_threshold = 128 * 128
    min_size = 128 * 128
    connectivity = 8
    no_thresh_images = [
        "TCGA-05-4415-01Z-00-DX1.55E0C429-B308-4962-8DA9-41D7D3F7764E.svs",
        "TCGA-05-4418-01Z-00-DX1.f3863ea5-564f-482f-9878-cc104cf69401.svs",
        "TCGA-05-4398-01Z-00-DX1.269bc75f-492e-48b1-87ee-85924aa80e74.svs",
    ]

    # Runtime / output.
    device_type = "cpu" if not torch.cuda.is_available() else "cuda"
    save = True
    compression_opts_level = 5
def create_none_exist_folder(path: str) -> None:
    """Create a folder, including missing parents, unless it already exists.

    ``os.makedirs(..., exist_ok=True)`` replaces the former check-then-create
    sequence. That sequence could raise ``FileExistsError`` when another process
    created the folder between the ``exists`` check and the ``makedirs`` call.

    Args:
        path (str): Folder path.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
def save_config(config: Config, path: str) -> dict[str, typing.Any]:
    """Save the public settings of a config as a JSON file.

    ``config`` may be the ``Config`` class itself or an instance of it. For an
    instance, instance attributes override the class-level defaults.

    Every dunder entry of the class namespace is skipped rather than deleting a
    fixed list of keys. This covers ``__module__``, ``__doc__``, ``__dict__`` and
    ``__weakref__``, plus ``__firstlineno__`` / ``__static_attributes__`` on
    Python 3.13+, which previously leaked into the JSON. It also means an
    instance no longer triggers a ``KeyError``.

    Args:
        config (Config): Config class or instance.
        path (str): Path of the JSON file to write.

    Returns:
        dict: The saved settings, in declaration order.
    """
    config_cls = config if isinstance(config, type) else type(config)
    dic = {key: value for key, value in vars(config_cls).items() if not key.startswith("__")}
    if not isinstance(config, type):
        dic.update({key: value for key, value in vars(config).items() if not key.startswith("__")})
    with open(path, "w", encoding="utf-8") as f:
        json.dump(dic, f)
    return dic
def resize_wsi(wsi: np.ndarray, scale: int) -> np.ndarray:
    """Downsample a wsi by an integer factor using area interpolation.

    Args:
        wsi (np.ndarray): Image whose first two axes are (height, width).
        scale (int): Relevant scale factor; each spatial side is floor-divided by it.

    Returns:
        np.ndarray: Resized wsi.
    """
    height = wsi.shape[0]
    width = wsi.shape[1]
    # cv2 expects the target size as (width, height), i.e. columns first.
    target_size = (width // scale, height // scale)
    return cv2.resize(wsi, target_size, interpolation=cv2.INTER_AREA)
def pad_wsi(config: Config, wsi: np.ndarray, pad_value: int) -> np.ndarray:
    """Pad the wsi so that every pyramid level splits into whole patches.

    Level 0 is later downsampled ``pyramid_level - 1`` times by
    ``relevant_scale``, and each level is cut into ``patch_size`` patches that
    are grouped ``relevant_scale x relevant_scale`` into their parents. Both
    spatial sides must therefore be multiples of
    ``patch_size * relevant_scale ** (pyramid_level - 1)``.

    The former multiple ``patch_size * 2 * pyramid_level`` only coincides with
    this for ``relevant_scale=2, pyramid_level=4``. A side that is already a
    multiple now receives no padding; previously a whole extra block was added.

    Padding is split between both sides of an axis; when the amount is odd, the
    extra pixel goes to the bottom/right. Any trailing axes, such as colour
    channels, are left unpadded.

    Args:
        config (Config): Configurations (patch_size, relevant_scale, pyramid_level).
        wsi (np.ndarray): Wsi to be padded; at least 2-D with (height, width) first.
        pad_value (int): Constant padding value.

    Returns:
        np.ndarray: Padded wsi.
    """
    multiple = config.patch_size * config.relevant_scale ** (config.pyramid_level - 1)
    height, width = wsi.shape[0], wsi.shape[1]
    # Smallest non-negative amount that makes each side a multiple (0 if already one).
    pad_h = -height % multiple
    pad_w = -width % multiple
    pad_width = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]
    pad_width += [(0, 0)] * (wsi.ndim - 2)
    return np.pad(wsi, pad_width, constant_values=pad_value)
def thresh_wsi(config: Config, wsi: np.ndarray) -> np.ndarray:
    """Compute a binary foreground (tissue) mask for an RGB wsi.

    Pipeline:
    1. Convert RGB to grayscale.
    2. Apply a 3x3 median blur.
    3. Binarise inversely: pixels at or below the threshold become 255.
    4. Clean up morphologically: fill holes smaller than
       ``config.area_threshold`` and drop objects smaller than
       ``config.min_size``.

    Args:
        config (Config): Configurations. Reads ``thresh_method``: 'otsu' for a
            global Otsu threshold, or 'adaptive' for a Gaussian-weighted local
            threshold over 21x21 neighbourhoods with offset 8. Also reads
            ``area_threshold``, ``min_size`` and ``connectivity``.
        wsi (np.ndarray): RGB wsi to threshold. Presumably uint8, since cv2's
            Otsu and adaptive thresholding require it — TODO confirm for all inputs.

    Returns:
        np.ndarray: uint8 mask with 255 for foreground and 0 for background.

    Raises:
        AttributeError: If ``config.thresh_method`` is neither 'otsu' nor 'adaptive'.
    """
    grayscale = cv2.cvtColor(wsi, cv2.COLOR_RGB2GRAY)
    smoothed = cv2.medianBlur(grayscale, 3)

    if config.thresh_method == "otsu":
        # The threshold argument (0) is ignored when THRESH_OTSU is set.
        _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_OTSU | cv2.THRESH_BINARY_INV)
    elif config.thresh_method == "adaptive":
        binary = cv2.adaptiveThreshold(smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                       cv2.THRESH_BINARY_INV, 21, 8)
    else:
        raise AttributeError(f"No thresh method named {config.thresh_method}")

    # NOTE(review): skimage documents connectivity as 1..ndim (i.e. up to 2 for a
    # 2-D mask); values such as the default 8 are outside that range — confirm
    # they behave as full 8-connectivity.
    foreground = binary > 0
    foreground = skimage.morphology.remove_small_holes(foreground, area_threshold=config.area_threshold,
                                                       connectivity=config.connectivity)
    foreground = skimage.morphology.remove_small_objects(foreground, min_size=config.min_size,
                                                         connectivity=config.connectivity)
    return foreground.astype(np.uint8) * 255
def gen_patch(wsi: np.ndarray, patch_size: int) -> np.ndarray:
    """Split a wsi into non-overlapping square patches, numbered row by row.

    With ``n_cols = wsi.shape[1] // patch_size``, patch ``k`` covers grid row
    ``k // n_cols`` and grid column ``k % n_cols``. Trailing axes are kept
    unchanged: none for a mask, or any number of channels for an image.
    Previously only exactly 3 channels were supported, and other ranks hit an
    ``UnboundLocalError``.

    Args:
        wsi (np.ndarray): Array of shape ``(H, W, *rest)``. H and W must be
            multiples of ``patch_size``.
        patch_size (int): Side length of a patch.

    Returns:
        np.ndarray: Array of shape ``((H // patch_size) * (W // patch_size),
        patch_size, patch_size, *rest)``.

    Raises:
        ValueError: If ``wsi`` has fewer than two axes, or H or W is not a
            multiple of ``patch_size``.
    """
    if wsi.ndim < 2:
        raise ValueError(f"Expected an array with at least 2 dimensions, got shape {wsi.shape}")
    height, width, *rest = wsi.shape
    if height % patch_size or width % patch_size:
        raise ValueError(f"Spatial shape {(height, width)} is not a multiple of patch_size={patch_size}")
    grid = wsi.reshape(height // patch_size, patch_size, width // patch_size, patch_size, *rest)
    # (rows, p, cols, p, ...) -> (rows, cols, p, p, ...) so patches are enumerated row-major.
    return grid.swapaxes(1, 2).reshape(-1, patch_size, patch_size, *rest)
# Feature extractors already built in this process, keyed by device string.
_FEATURE_EXTRACTORS: dict = {}


def _get_feature_extractor(device: torch.device) -> torch.nn.Module:
    """Return the pretrained ``resnet50_baseline`` in eval mode on ``device``, building it on first use only."""
    key = str(device)
    if key not in _FEATURE_EXTRACTORS:
        model = resnet50_baseline(pretrained=True).to(device)
        model.eval()
        _FEATURE_EXTRACTORS[key] = model
    return _FEATURE_EXTRACTORS[key]


def gen_features(config: Config, patches: np.ndarray, selected_idx: list[int]) -> np.ndarray:
    """Generate features for the selected patches.

    Each selected patch is run through the pretrained ``resnet50_baseline`` on
    ``config.device_type``, one patch at a time, under ``torch.no_grad()``.
    The model is built once per process and device and stays resident.
    Previously it was re-created, with its pretrained weights, on every call,
    i.e. once per pyramid level of every wsi.

    Patches are processed in ascending index order, each index once, and
    indices outside ``patches`` are ignored. This is the same set and order the
    former ``idx in selected_idx`` scan over all patches produced, without its
    O(len(patches) * len(selected_idx)) cost or the repeated
    ``np.concatenate``.

    Args:
        config (Config): Configurations (device_type).
        patches (np.ndarray): Patches generated from wsi. Each patch is 3-D and
            is transposed from (H, W, C) to (C, H, W) before inference.
        selected_idx (list[int]): Indices of the patches to embed.

    Returns:
        np.ndarray: Model outputs stacked along axis 0, one per embedded patch.
        The shape is presumably ``(n_selected, D)``, with D set by the model.
        An empty float32 array is returned when nothing is selected; previously
        this returned None, which made callers' ``enumerate(features)`` fail.
    """
    device = torch.device(config.device_type)
    model = _get_feature_extractor(device)
    wanted_idx = sorted({idx for idx in selected_idx if 0 <= idx < len(patches)})
    if not wanted_idx:
        return np.empty((0,), dtype=np.float32)

    outputs = []
    with torch.no_grad():
        for idx in wanted_idx:
            # HWC -> 1 x C x H x W float32 tensor.
            # NOTE(review): raw pixel values are fed with no [0, 1] scaling or
            # mean/std normalisation — confirm this matches the pretrained weights.
            chw = np.expand_dims(patches[idx].transpose(2, 0, 1), axis=0)
            tensor = torch.from_numpy(chw).float().to(device)
            outputs.append(model(tensor).cpu().numpy())
    return np.concatenate(outputs, axis=0)
def select_patch(thresh_patches: np.ndarray) -> list[int]:
    """Return the indices of patches whose threshold mask sums to a positive value.

    For the 0/255 masks produced by thresholding, a positive sum means the patch
    contains at least one foreground pixel.

    Args:
        thresh_patches (np.ndarray): Patches cut from the threshed wsi.

    Returns:
        list[int]: Indices of the kept patches, in ascending order.
    """
    return [position for position, mask in enumerate(thresh_patches) if mask.sum() > 0]
def gen_index(global_index_mat: typing.Union[np.ndarray, None], selected_idx: list[int], num_patches: list[int],
              patches_per_row: list[int], level: int,
              relevant_scale: typing.Optional[int] = None) -> np.ndarray:
    """Generate index chains linking the selected patches of ``level`` to finer levels.

    Every row of the result is a chain ``[idx_level, ..., idx_0]`` of flat
    patch indices, ordered from the current level down to level 0. A patch at
    ``level`` covers a ``relevant_scale x relevant_scale`` block of patches at
    ``level - 1``. An existing chain is extended with its parent only when that
    parent is selected at the current level; otherwise the chain is dropped.

    Args:
        global_index_mat (np.ndarray or None): Chains built for ``level - 1``,
            of shape ``(n, level)``. A 1-D single chain, or None for no chains,
            is also accepted. Ignored when ``level == 0``.
        selected_idx (list[int]): Indices of selected patches at ``level``.
        num_patches (list[int]): Number of patches of each level.
        patches_per_row (list[int]): Per level, the value used as the number of
            ROWS of the patch grid when the flat index range is reshaped,
            i.e. ``grid_height // patch_size``.
        level (int): Current level (0 is the finest).
        relevant_scale (int, optional): Scale factor between neighbouring
            levels. Defaults to ``Config.relevant_scale``. The original code read
            the module-level ``config``, which only exists when the file runs as
            a script.

    Returns:
        np.ndarray: Integer array of shape ``(n, level + 1)``. It is always 2-D,
        with ``n == 0`` when no chain survives. Previously a single chain came
        back 1-D and no chain came back as None; both crashed the next level.
    """
    if level == 0:
        return np.asarray(selected_idx, dtype=np.int64).reshape(-1, 1)

    if relevant_scale is None:
        relevant_scale = Config.relevant_scale
    scale = relevant_scale

    # Flat level-(L-1) indices laid out on their (rows, cols) grid.
    prev_grid = np.arange(num_patches[level - 1]).reshape(patches_per_row[level - 1], -1)
    grid_rows, grid_cols = prev_grid.shape
    # children[p] = level-(L-1) indices under level-L patch p, row-major within the block.
    children = (prev_grid.reshape(grid_rows // scale, scale, grid_cols // scale, scale)
                .swapaxes(1, 2)
                .reshape(-1, scale * scale))

    # Group existing chains by their head so each child lookup is O(1); row order is preserved.
    chains_by_head: dict = {}
    if global_index_mat is not None:
        for chain in np.atleast_2d(np.asarray(global_index_mat)).tolist():
            chains_by_head.setdefault(chain[0], []).append(chain)

    selected = set(selected_idx)
    rows = []
    for parent, child_block in enumerate(children.tolist()):
        if parent not in selected:
            continue
        for child in child_block:
            for chain in chains_by_head.get(child, ()):
                rows.append([parent] + chain)

    if not rows:
        return np.empty((0, level + 1), dtype=np.int64)
    return np.asarray(rows, dtype=np.int64)
def convert_to_relevant_index(global_index_mat: np.ndarray, patches_per_row: list[int]) -> np.ndarray:
    """Convert the flat patch indices of every chain into (x, y) grid coordinates.

    Chains start at the coarsest level, so column ``j`` of ``global_index_mat``
    belongs to level ``len(patches_per_row) - 1 - j``. ``patches_per_row`` is
    therefore read in reverse, assuming one entry per level. Each index ``k``
    becomes ``(k % n, k // n)``, where ``n`` is the number of patches in one row
    of that level's grid.

    The caller's list is no longer reversed in place, so repeated calls give the
    same result. The work is vectorised instead of using a quadratic
    ``np.vstack``.

    Args:
        global_index_mat (np.ndarray): The global index matrix, shape
            ``(n_chains, n_levels)``. A single 1-D chain is also accepted.
        patches_per_row (list[int]): Number of patches in a row of each level,
            ordered from level 0 upwards.

    Returns:
        np.ndarray: Array of shape ``(n_chains, 2 * n_levels)`` laid out as
        ``x, y`` per level, coarsest level first. It is always 2-D, including
        for zero or one chain.
    """
    chains = np.atleast_2d(np.asarray(global_index_mat))
    n_chains, n_levels = chains.shape
    row_lengths = np.asarray(list(reversed(patches_per_row))[:n_levels])
    xs = chains % row_lengths
    ys = chains // row_lengths
    # Interleave to x0, y0, x1, y1, ... for every chain.
    return np.stack([xs, ys], axis=-1).reshape(n_chains, 2 * n_levels)
def save_current_level_as_h5(config: Config, patches: np.ndarray, features: np.ndarray, num_patches_per_row: int,
                             wsi_name: str, scale_factor: float,
                             selected_idx: typing.Optional[list[int]] = None) -> None:
    """Saves processing results from current level as h5 files.

    Output goes under ``<output_dir>/<wsi_name up to its first '.'>/``:

    * ``patches/<scale_factor>x.h5``: group "Patches" with one gzip dataset per
      selected patch, plus the dataset "Indices/local_index".
    * ``features/<scale_factor>x.h5``: group "Feature" with one gzip dataset per
      feature row, plus the dataset "Indices/local_index".
    * ``global_index.h5``: opened in append mode; gains ``Indices/<scale_factor>x_index``.

    Each local index is an ``(n, 3)`` ``uintc`` array of rows ``[idx, x, y]``,
    where ``x = idx % num_patches_per_row`` and ``y = idx // num_patches_per_row``.
    It is always 2-D. Previously a single selection gave a 1-D array, and an
    empty selection passed ``data=None`` to h5py and failed.

    Args:
        config (Config): Configurations (output_dir, compression_opts_level).
        patches (np.ndarray): All patches of the current level, indexed by flat patch index.
        features (np.ndarray): Feature rows; row ``i`` belongs to patch ``selected_idx[i]``.
        num_patches_per_row (int): Number of patches in one row of the grid,
            i.e. the column count, so that ``idx % n`` is the column.
        wsi_name (str): Name of the wsi.
        scale_factor (float): The absolute scale factor of the wsi.
        selected_idx (list[int], optional): Selected patch indices. When
            omitted, the module-level ``selected_idx`` is used, which is what
            this function originally always read.

    Raises:
        NameError: If ``selected_idx`` is omitted and no module-level one exists.
    """
    if selected_idx is None:
        try:
            selected_idx = globals()["selected_idx"]
        except KeyError:
            raise NameError("selected_idx was not passed and no module-level selected_idx is defined") from None

    wsi_dir = os.path.join(config.output_dir, wsi_name.split(".")[0])
    patch_h5_file_name = os.path.join(wsi_dir, "patches", f"{scale_factor}x.h5")
    feature_h5_file_name = os.path.join(wsi_dir, "features", f"{scale_factor}x.h5")

    # Same patches, in the same order, as scanning `patches` for members of
    # `selected_idx`, without the O(len(patches) * len(selected_idx)) scan.
    kept_idx = sorted({idx for idx in selected_idx if 0 <= idx < len(patches)})
    selected_coords = np.asarray(
        [[idx, idx % num_patches_per_row, idx // num_patches_per_row] for idx in kept_idx],
        dtype=np.uintc,
    ).reshape(-1, 3)

    gzip_kwargs = {"compression": "gzip", "compression_opts": config.compression_opts_level}
    # NOTE(review): HDF5 may reject chunked (compressed) datasets with a
    # zero-length fixed dimension, so an empty index is stored uncompressed.
    index_kwargs = gzip_kwargs if selected_coords.size else {}

    with h5py.File(patch_h5_file_name, "w") as hf:
        with h5py.File(feature_h5_file_name, "w") as feature_hf:
            patch_index_group = hf.create_group("Indices")
            patch_group = hf.create_group("Patches")
            feature_index_group = feature_hf.create_group("Indices")
            feature_group = feature_hf.create_group("Feature")

            for idx in kept_idx:
                coords = (idx % num_patches_per_row, idx // num_patches_per_row)
                patch_dataset = patch_group.create_dataset(f"{scale_factor}x_{coords[0]}_{coords[1]}",
                                                           data=patches[idx], **gzip_kwargs)
                patch_dataset.attrs["wsi_name"] = wsi_name
                patch_dataset.attrs["scale_factor"] = scale_factor
                patch_dataset.attrs["id"] = idx
                patch_dataset.attrs["coords"] = coords

            for i, feature in enumerate(features):
                idx = selected_idx[i]
                coords = (idx % num_patches_per_row, idx // num_patches_per_row)
                feature_dataset = feature_group.create_dataset(f"{scale_factor}x_{coords[0]}_{coords[1]}",
                                                               data=feature, **gzip_kwargs)
                feature_dataset.attrs["wsi_name"] = wsi_name
                feature_dataset.attrs["scale_factor"] = scale_factor
                feature_dataset.attrs["id"] = idx
                feature_dataset.attrs["coords"] = coords

            for index_group in (patch_index_group, feature_index_group):
                index_dataset = index_group.create_dataset("local_index", data=selected_coords, **index_kwargs)
                index_dataset.attrs["wsi_name"] = wsi_name
                index_dataset.attrs["scale_factor"] = scale_factor

    global_index_h5_file_name = os.path.join(wsi_dir, "global_index.h5")
    with h5py.File(global_index_h5_file_name, "a") as hf:
        # require_group also works if the "Indices" group was never created.
        group = hf.require_group("Indices")
        dataset = group.create_dataset(f"{scale_factor}x_index", data=selected_coords, **index_kwargs)
        dataset.attrs["wsi_name"] = wsi_name
        dataset.attrs["scale_factor"] = scale_factor
def save_current_level_as_jpeg_pt(config: Config, patches: np.ndarray, features: np.ndarray, num_patches_per_row: int,
                                  wsi_name: str, scale_factor: float,
                                  selected_idx: typing.Optional[list[int]] = None) -> None:
    """Saves processing results from current level as jpeg and pt files.

    Output goes under ``<output_dir>/<wsi_name up to its first '.'>/``:

    * ``patches/<scale_factor>x/<x>_<y>.jpeg``: one image per selected patch.
    * ``features/<scale_factor>x/<x>_<y>.pt``: one ``torch.save``'d tensor per feature row.
    * ``global_index.h5``: opened in append mode; gains ``Indices/<scale_factor>x_index``.

    Here ``x = idx % num_patches_per_row`` and ``y = idx // num_patches_per_row``.
    The index dataset is an ``(n, 3)`` ``uintc`` array of rows ``[idx, x, y]``.
    It is always 2-D. Previously a single selection gave a 1-D array, and an
    empty selection passed ``data=None`` to h5py and failed.

    Args:
        config (Config): Configurations (output_dir, compression_opts_level).
        patches (np.ndarray): All patches of the current level, indexed by flat
            patch index. They are treated as RGB and converted to BGR for cv2.
        features (np.ndarray): Feature rows; row ``i`` belongs to patch ``selected_idx[i]``.
        num_patches_per_row (int): Number of patches in one row of the grid,
            i.e. the column count, so that ``idx % n`` is the column.
        wsi_name (str): Name of the wsi.
        scale_factor (float): The absolute scale factor of the wsi.
        selected_idx (list[int], optional): Selected patch indices. When
            omitted, the module-level ``selected_idx`` is used, which is what
            this function originally always read.

    Raises:
        NameError: If ``selected_idx`` is omitted and no module-level one exists.
        OSError: If cv2 reports that a patch image could not be written.
    """
    if selected_idx is None:
        try:
            selected_idx = globals()["selected_idx"]
        except KeyError:
            raise NameError("selected_idx was not passed and no module-level selected_idx is defined") from None

    wsi_dir = os.path.join(config.output_dir, wsi_name.split(".")[0])
    patch_folder = os.path.join(wsi_dir, "patches", f"{scale_factor}x")
    feature_folder = os.path.join(wsi_dir, "features", f"{scale_factor}x")
    create_none_exist_folder(patch_folder)
    create_none_exist_folder(feature_folder)

    # Same patches, in the same order, as scanning `patches` for members of
    # `selected_idx`, without the O(len(patches) * len(selected_idx)) scan.
    kept_idx = sorted({idx for idx in selected_idx if 0 <= idx < len(patches)})
    for idx in kept_idx:
        coords = (idx % num_patches_per_row, idx // num_patches_per_row)
        patch_path = os.path.join(patch_folder, f"{coords[0]}_{coords[1]}.jpeg")
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(patch_path, cv2.cvtColor(patches[idx], cv2.COLOR_RGB2BGR)):
            raise OSError(f"Failed to write patch image: {patch_path}")

    for i, feature in enumerate(features):
        idx = selected_idx[i]
        coords = (idx % num_patches_per_row, idx // num_patches_per_row)
        torch.save(torch.from_numpy(feature), os.path.join(feature_folder, f"{coords[0]}_{coords[1]}.pt"))

    selected_coords = np.asarray(
        [[idx, idx % num_patches_per_row, idx // num_patches_per_row] for idx in kept_idx],
        dtype=np.uintc,
    ).reshape(-1, 3)
    # NOTE(review): HDF5 may reject chunked (compressed) datasets with a
    # zero-length fixed dimension, so an empty index is stored uncompressed.
    index_kwargs = ({"compression": "gzip", "compression_opts": config.compression_opts_level}
                    if selected_coords.size else {})

    global_index_h5_file_name = os.path.join(wsi_dir, "global_index.h5")
    with h5py.File(global_index_h5_file_name, "a") as hf:
        # require_group also works if the "Indices" group was never created.
        group = hf.require_group("Indices")
        dataset = group.create_dataset(f"{scale_factor}x_index", data=selected_coords, **index_kwargs)
        dataset.attrs["wsi_name"] = wsi_name
        dataset.attrs["scale_factor"] = scale_factor
def create_global_index_h5_file(config: Config, wsi_name: str) -> None:
    """Start a fresh global-index HDF5 file for one wsi.

    The file is ``global_index.h5`` inside ``config.output_dir``, in the folder
    named after the part of ``wsi_name`` before its first '.'. It is opened in
    h5py's "w" mode, so any existing file is truncated, and it receives a
    single empty group called "Indices".

    Args:
        config (Config): Configurations (output_dir).
        wsi_name (str): Name of the wsi.
    """
    wsi_folder = wsi_name.split(".")[0]
    index_path = os.path.join(config.output_dir, wsi_folder, "global_index.h5")
    with h5py.File(index_path, "w") as index_file:
        index_file.create_group("Indices")
def save_index(config: Config, global_index_mat: np.ndarray, relevant_global_index_mat: np.ndarray,
               wsi_name: str) -> None:
    """Save the absolute and relevant global indices of a wsi.

    Both matrices go into the "Indices" group of
    ``<output_dir>/<wsi_name up to its first '.'>/global_index.h5``, as the
    datasets ``global_index`` and ``relevant_global_index``. Each dataset
    carries a ``wsi_name`` attribute. The file is opened in append mode, and
    the group is created if missing; previously ``hf.get`` returned None and
    the call failed with an AttributeError.

    Args:
        config (Config): Configurations (output_dir, compression_opts_level).
        global_index_mat (np.ndarray): Absolute index mat.
        relevant_global_index_mat (np.ndarray): Relevant index mat.
        wsi_name (str): Name of the wsi.
    """
    global_index_h5_file_name = os.path.join(config.output_dir, wsi_name.split(".")[0], "global_index.h5")
    with h5py.File(global_index_h5_file_name, "a") as hf:
        group = hf.require_group("Indices")
        for dataset_name, data in (("global_index", global_index_mat),
                                   ("relevant_global_index", relevant_global_index_mat)):
            # NOTE(review): HDF5 may reject chunked (compressed) datasets with a
            # zero-length fixed dimension, so an empty index is stored uncompressed.
            compression_kwargs = ({"compression": "gzip", "compression_opts": config.compression_opts_level}
                                  if np.size(data) else {})
            dataset = group.create_dataset(dataset_name, data=data, **compression_kwargs)
            dataset.attrs["wsi_name"] = wsi_name
if __name__ == "__main__":
    # `config` and `selected_idx` stay module-level names: gen_index and the
    # save_current_level_* helpers may read them as globals.
    config = Config()
    create_none_exist_folder(config.output_dir)
    # The class itself is saved: the settings live in its __dict__.
    save_config(Config, os.path.join(config.output_dir, "config.json"))

    wsi_names = [file_name for file_name in os.listdir(config.input_dir)
                 if os.path.isfile(os.path.join(config.input_dir, file_name))]
    if len(wsi_names) == 0:
        raise FileNotFoundError(f"No file found under {config.input_dir}")
    print(f" -> Processing {len(wsi_names)} WSI(s)...")
    print(f" -> File names: {wsi_names}")

    start_time = time.time()
    for index, wsi_name in enumerate(wsi_names):
        wsi_output_dir = os.path.join(config.output_dir, wsi_name.split(".")[0])
        # An existing output folder means an earlier run already handled this wsi.
        # NOTE(review): a run interrupted mid-wsi leaves a partial folder that is skipped as well.
        if os.path.exists(wsi_output_dir):
            print(f"\n - Processed {wsi_name} [{index + 1}/{len(wsi_names)}]")
            print(" - Skip")
            continue

        print(f"\n - Processing on {wsi_name} [{index + 1}/{len(wsi_names)}]")
        create_none_exist_folder(wsi_output_dir)
        create_none_exist_folder(os.path.join(wsi_output_dir, "patches"))
        create_none_exist_folder(os.path.join(wsi_output_dir, "features"))
        if config.save:
            create_global_index_h5_file(config, wsi_name)

        global_index_mat = None
        num_patches = list()
        # gen_patch numbers patches row by row (idx = row * n_cols + col), so:
        #   * gen_index reshapes each level's flat index range by the number of ROWS;
        #   * the save/convert helpers compute (x, y) = (idx % n, idx // n), which is
        #     only correct when n is the number of COLUMNS (patches in one row).
        # Passing the row count to both garbled the coordinates of any non-square wsi.
        grid_rows = list()
        grid_cols = list()
        for level in range(config.pyramid_level):
            scale_factor = config.largest_scale / (config.relevant_scale ** level)
            print(f" - Scale {scale_factor}x")
            if level == 0:
                wsi = tiff.imread(os.path.join(config.input_dir, wsi_name))
                # NOTE(review): assumes the full-resolution image is at 40x — confirm per dataset.
                wsi = resize_wsi(wsi, 40 // config.largest_scale)
                wsi = pad_wsi(config, wsi, config.pad_value_wsi)
                if wsi_name not in config.no_thresh_images:
                    threshed_wsi = thresh_wsi(config, wsi)
                else:
                    # Thresholding is skipped for these slides: every pixel counts as foreground.
                    threshed_wsi = np.ones(shape=(wsi.shape[0], wsi.shape[1]), dtype=np.uint8) * 255
            else:
                wsi = resize_wsi(wsi, config.relevant_scale)
                threshed_wsi = resize_wsi(threshed_wsi, config.relevant_scale)
            patches = gen_patch(wsi, config.patch_size)
            thresh_patches = gen_patch(threshed_wsi, config.patch_size)
            selected_idx = select_patch(thresh_patches)
            features = gen_features(config, patches, selected_idx)
            num_patches.append(len(thresh_patches))
            grid_rows.append(wsi.shape[0] // config.patch_size)
            grid_cols.append(wsi.shape[1] // config.patch_size)
            global_index_mat = gen_index(global_index_mat, selected_idx, num_patches, grid_rows, level)
            if config.save:
                save_current_level_as_jpeg_pt(config, patches, features, grid_cols[-1], wsi_name,
                                              scale_factor)
            gc.collect()
        relevant_global_index_mat = convert_to_relevant_index(global_index_mat, grid_cols)
        if config.save:
            save_index(config, global_index_mat, relevant_global_index_mat, wsi_name)
        gc.collect()
    print(f"\nComplete in {time.time() - start_time}s.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment