Last active
March 17, 2021 09:06
-
-
Save Reasat/76b53d6be24bceff4525e7ab92ca9ffd to your computer and use it in GitHub Desktop.
Implementing a data generator to create batches of images for feature extraction in HistomicsML2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
import h5py | |
import time | |
from tqdm import tqdm | |
import pandas as pd | |
import numpy as np | |
import keras | |
import histomicstk.preprocessing.color_normalization as htk_cnorm | |
import histomicstk.utils as htk_utils | |
import large_image | |
import tensorflow as tf | |
from keras.models import Model | |
from keras.preprocessing import image | |
from keras.applications.vgg16 import preprocess_input | |
from keras import applications | |
from ctk_cli import CLIArgumentParser | |
from scipy.misc import imresize | |
from sklearn.decomposition import PCA | |
from sklearn.externals import joblib | |
from histomicstk.cli import utils as cli_utils | |
import logging | |
TIME_STAMP=time.strftime('%Y-%m-%d-%H-%M-%S') | |
logging.basicConfig(level=logging.CRITICAL) | |
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |
#os.environ["CUDA_VISIBLE_DEVICES"] = "1" | |
import sys | |
class Logger(object):
    """Tee-like stream that mirrors everything written to it into a log file.

    An instance can be assigned to ``sys.stdout`` so that ``print`` output
    reaches both the stream that was ``sys.stdout`` at construction time and
    the file at ``path``.
    """

    def __init__(self, path):
        # Remember the stream that is current right now; output keeps going
        # there even after sys.stdout is replaced by this object.
        self.terminal = sys.stdout
        # Append mode: content already present in the log file is kept.
        self.log = open(path, "a+")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush the terminal stream and the log file.

        Forwarding the flush (instead of ignoring it) makes buffered log
        output reach the disk whenever a caller flushes stdout, so the log
        is not lost if the process dies before the file is closed.
        """
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the terminal stream is left untouched."""
        if not self.log.closed:
            self.log.close()
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` producing batches of resized image patches.

    Every batch holds one square patch per centroid, cut out of the
    color-normalized tile ``im_nmzd``, resized to
    ``patchSizeResized x patchSizeResized`` and passed through the VGG16
    ``preprocess_input``.

    Parameters
    ----------
    tile_info : mapping
        Must provide ``'gx'``, ``'gy'``, ``'width'`` and ``'height'``.
        ``gx``/``gy`` are subtracted from the centroid coordinates to obtain
        positions relative to the tile.
    im_nmzd : array
        Tile image the patches are sliced from (indexed as
        ``[row, col, channel]``).
    x_centroids, y_centroids : sequence
        Centroid coordinates, same length, in the same coordinate frame as
        ``gx``/``gy``.
    patchSize : number
        Side length of the patch cut around each centroid.
    patchSizeResized : int
        Side length of the patches after resizing.
    batch_size : int
        Number of patches per batch (the last batch may be smaller).
    shuffle : bool
        When truthy the centroid order is reshuffled at every epoch end.

    Raises
    ------
    ValueError
        If the centroid sequences differ in length or ``batch_size < 1``.
    """

    def __init__(self, tile_info, im_nmzd, x_centroids, y_centroids, patchSize,
                 patchSizeResized=224, batch_size=32, shuffle=False):
        if len(x_centroids) != len(y_centroids):
            raise ValueError(
                'x_centroids and y_centroids must have the same length '
                '({} != {})'.format(len(x_centroids), len(y_centroids)))
        if batch_size < 1:
            raise ValueError(
                'batch_size must be >= 1, got {}'.format(batch_size))

        self.x_centroids = x_centroids
        self.y_centroids = y_centroids
        self.patchSize = patchSize
        self.patchSizeResized = patchSizeResized
        # Offset and extent of the tile; used to express centroids in
        # tile-local coordinates and to clamp the patches.
        self.left = tile_info['gx']
        self.top = tile_info['gy']
        self.im_width = tile_info['width']
        self.im_height = tile_info['height']
        self.im_nmzd = im_nmzd
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Builds the initial index order.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch (last one may be partial)."""
        return int(np.ceil(len(self.x_centroids) / float(self.batch_size)))

    def __getitem__(self, index):
        """Return the preprocessed image batch number ``index``."""
        start = index * self.batch_size
        batch_ids = self.indexes[start:start + self.batch_size]
        x_batch = [self.x_centroids[k] for k in batch_ids]
        y_batch = [self.y_centroids[k] for k in batch_ids]
        img_batch = self.__data_generation(x_batch, y_batch)
        return preprocess_input(img_batch)

    def on_epoch_end(self):
        """Reset the centroid order, shuffling it when requested."""
        self.indexes = np.arange(len(self.x_centroids))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, x_batch, y_batch):
        """Cut and resize one patch per centroid.

        Returns an array of shape
        ``(len(x_batch), patchSizeResized, patchSizeResized, 3)``.
        """
        X = np.empty((len(x_batch), self.patchSizeResized,
                      self.patchSizeResized, 3))
        for i, (x_centroid, y_centroid) in enumerate(zip(x_batch, y_batch)):
            # Centroid position relative to the tile origin.
            cen_x = x_centroid - self.left
            cen_y = y_centroid - self.top

            # get_patch_bounds takes the row coordinate first, hence (y, x).
            min_row, max_row, min_col, max_col = get_patch_bounds(
                cen_y, cen_x, self.patchSize, self.im_height, self.im_width)

            # Resize the patch to the fixed network input size.
            X[i, ] = imresize(
                self.im_nmzd[min_row:max_row, min_col:max_col, :],
                (self.patchSizeResized, self.patchSizeResized, 3))
        return X
def compute_superpixel_data_pca(model, pca, slide_path, tile_position,
                                centroid_path, args, superpixel_kwargs,
                                src_mu_lab=None, src_sigma_lab=None):
    """Extract features for all centroids that fall inside one tile.

    The tile at ``tile_position`` is read from ``slide_path``, Reinhard
    color-normalized, and a patch around every centroid from
    ``centroid_path`` lying inside the tile is fed through ``model``.

    Parameters
    ----------
    model : object providing ``predict_generator``
    pca : fitted transformer or None
        When not None, ``pca.transform`` is applied to the model output.
    slide_path : str
        Path handed to ``large_image.getTileSource``.
    tile_position : int
        Tile position forwarded to ``getSingleTile``.
    centroid_path : str
        HDF5 file with ``x_centroid`` and ``y_centroid`` datasets.
    args : namespace
        Reads ``reference_mu_lab``, ``reference_std_lab``, ``batch_size``,
        ``patchSize`` and ``num_workers_loader``.
    superpixel_kwargs : dict
        Extra keyword arguments for ``getSingleTile``.
    src_mu_lab, src_sigma_lab : optional
        Source statistics forwarded to ``htk_cnorm.reinhard``.

    Returns
    -------
    tuple
        ``(features, x_centroids_valid, y_centroids_valid)``; three empty
        lists when no centroid lies inside the tile.
    """
    # get slide tile source
    ts = large_image.getTileSource(slide_path)
    tile_info = ts.getSingleTile(
        tile_position=tile_position,
        format=large_image.tilesource.TILE_FORMAT_NUMPY,
        **superpixel_kwargs)

    left = tile_info['gx']
    top = tile_info['gy']
    width = tile_info['width']
    height = tile_info['height']

    # Read both datasets fully into memory, then release the file handle
    # (the original never closed it, leaking one handle per tile).
    with h5py.File(centroid_path, 'r') as centroid_file:
        x_centroids = centroid_file['x_centroid'][:]
        y_centroids = centroid_file['y_centroid'][:]

    n_superpixels = len(x_centroids)
    print('total No. of superpixels: {} '.format(n_superpixels))

    # Keep only the centroids located in this tile. The interval is open on
    # the left/top edge and closed on the right/bottom edge.
    x_centroids_valid, y_centroids_valid = [], []
    for x, y in zip(x_centroids, y_centroids):
        if left < x <= (left + width) and top < y <= (top + height):
            x_centroids_valid.append(x)
            y_centroids_valid.append(y)
    print('No. of valid superpixels in tile position {}: {}'.format(tile_position,len(x_centroids_valid)))

    if len(x_centroids_valid) == 0:
        print('no valid superpixels, returning empty lists')
        return [], [], []

    # get requested tile (drop any channel beyond the first three)
    im_tile = tile_info['tile'][:, :, :3]

    # perform color normalization
    start = time.time()
    print('Applying Reinhard color normalization')
    im_nmzd = htk_cnorm.reinhard(im_tile, args.reference_mu_lab,
                                 args.reference_std_lab, src_mu=src_mu_lab,
                                 src_sigma=src_sigma_lab)
    end = time.time()
    print('time taken {}'.format(cli_utils.disp_time_hms(end-start)))

    datagenerator = DataGenerator(tile_info, im_nmzd,
                                  x_centroids_valid, y_centroids_valid,
                                  patchSize=args.patchSize,
                                  batch_size=args.batch_size,
                                  shuffle=False)

    # get superpixel features; len(datagenerator) is the batch count as an
    # int (np.ceil would hand Keras a float step count).
    fcn = model.predict_generator(
        datagenerator,
        steps=len(datagenerator),
        verbose=True,
        workers=args.num_workers_loader,
        use_multiprocessing=True,
    )

    # reduce the fcn features
    features = pca.transform(fcn) if pca is not None else fcn

    return features, x_centroids_valid, y_centroids_valid
def get_patch_bounds(cx, cy, patch, m, n):
    """Return slice bounds of a ``patch``-sized window kept inside an image.

    Parameters
    ----------
    cx : number
        Window center along the row axis (despite the name, callers pass the
        y coordinate here).
    cy : number
        Window center along the column axis.
    patch : number
        Requested window side length.
    m, n : int
        Number of rows and columns of the image.

    Returns
    -------
    tuple
        ``(min_row, max_row, min_col, max_col)``. A window that sticks out
        of the image is shifted back inside; the upper bounds are capped at
        ``m - 1`` / ``n - 1``. The lower bounds are never negative, so the
        result is always safe to use as slice indices.
    """
    half_patch = patch / 2.0

    def _span(center, size):
        # Bounds along one axis for an image with `size` pixels on that axis.
        low = int(round(center) - half_patch)
        high = int(round(center) + half_patch)
        if low < 0:
            # Shift the window right/down so that it starts at 0.
            high -= low
            low = 0
        if high > size - 1:
            # Shift the window left/up so that it ends at size - 1.
            low -= high - (size - 1)
            high = size - 1
        # When the patch is larger than the image the shift above pushes
        # `low` below zero; a negative slice start would wrap around and
        # select the wrong region, so clamp it.
        return max(low, 0), high

    min_row, max_row = _span(cx, m)
    min_col, max_col = _span(cy, n)
    return min_row, max_row, min_col, max_col
def _list_files(directory, not_dir_message):
    """Return paths of the regular files directly inside ``directory``.

    Raises IOError(``not_dir_message``) when ``directory`` is not a directory.
    """
    if not os.path.isdir(directory):
        raise IOError(not_dir_message)
    candidates = [os.path.join(directory, name)
                  for name in os.listdir(directory)]
    return [path for path in candidates if os.path.isfile(path)]


def main(args):
    """Extract superpixel features for every slide of a project.

    For each slide in ``/<projectName>/svs`` the foreground tiles are
    detected, Reinhard statistics are computed and features are extracted
    for the centroids stored in ``/<projectName>/centroid/<slide>.h5``.
    The features (PCA-reduced) and bookkeeping arrays are written to
    ``/<projectName>/<timestamp>_HistomicsML_dataset.h5``. When no PCA model
    is supplied one is fitted on a sample of the features and saved next to
    the dataset.

    ``args`` attributes read here: ``projectName``, ``gpus``,
    ``inputPCAModel``, ``max_mag``, ``max_tile_size``, ``min_fgnd_frac``,
    ``fcn``, ``pca_sample_scale``, ``pca_dim``, ``superpixelSize``. ``args``
    is also passed on to ``compute_superpixel_data_pca``.

    Raises
    ------
    IOError
        If the slide or centroid directory does not exist.
    ValueError
        If no features could be extracted from any slide.
    """
    time_stamp = time.strftime('%Y-%m-%d-%H-%M-%S')

    # Mirror all printed output into a time-stamped log file.
    # NOTE(review): this path is relative to the working directory while the
    # data paths below are rooted at '/' -- confirm this is intended.
    filepath_log = os.path.join(args.projectName, '{}.log'.format(time_stamp))
    sys.stdout = Logger(filepath_log)
    print(time_stamp)
    print('Selected GPUs', args.gpus)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

    total_start_time = time.time()

    print('\n>> CLI Parameters ...\n')

    if args.inputPCAModel:
        print('\n>> Load PCA fitted model ... \n')
        # NOTE(review): joblib.load unpickles the file; only load trusted
        # model files.
        pca = joblib.load(args.inputPCAModel)
    else:
        pca = None

    inputSlidePath = '/'+args.projectName+'/svs'
    inputCentroidPath = '/'+args.projectName+'/centroid'
    outputDataSet = '/'+args.projectName+'/{}_HistomicsML_dataset.h5'.format(time_stamp)
    outputPCAsample = '/'+args.projectName+'/{}_pca_model_sample.pkl'.format(time_stamp)

    # Check whether slide directory exists and collect the slide files.
    img_paths = _list_files(inputSlidePath, 'Slide path is not directory.')

    # Check whether centroid directory exists (the listing itself is unused;
    # centroid files are looked up by slide name below).
    _list_files(inputCentroidPath, 'Centroid path is not directory.')

    print('\n>> Reading VGG pre-trained model ... \n')
    strategy = tf.contrib.distribute.MirroredStrategy()
    with strategy.scope():
        model = applications.VGG16(include_top=True, weights='imagenet')
        # Features are taken from the first fully connected layer.
        model = Model(inputs=model.input,
                      outputs=model.get_layer('fc1').output)

    print("Generate train dataset ... ")

    slide_superpixel_data = []
    slide_x_centroids = []
    slide_y_centroids = []
    slide_name_list = []
    slide_superpixel_index = []

    total_n_slides = len(img_paths)
    first_superpixel_index = np.zeros((total_n_slides, 1), dtype=np.int32)
    slide_wsi_mean = np.zeros((total_n_slides, 3), dtype=np.float32)
    slide_wsi_stddev = np.zeros((total_n_slides, 3), dtype=np.float32)

    # Number of slides processed so far; row index into the per-slide arrays.
    index = 0

    for img_path in tqdm(img_paths):
        slide_name = os.path.basename(img_path).split('.')[0]
        centroid_path = os.path.join(inputCentroidPath, slide_name + '.h5')

        # Read Input Image
        print('{} is processing ... \n'.format(slide_name))
        ts = large_image.getTileSource(img_path)
        ts_metadata = ts.getMetadata()
        print(json.dumps(ts_metadata, indent=2))

        # The magnification has to be checked *before* it is used in
        # arithmetic; the original divided first and therefore crashed with
        # a TypeError on images without magnification metadata.
        native_mag = ts_metadata['magnification']
        if native_mag is None:
            print('{} has no magnification metadata, skipping'.format(
                slide_name))
            continue

        slide_name_list.append(slide_name)

        scale = native_mag / args.max_mag
        superpixel_mag = args.max_mag * scale
        superpixel_tile_size = args.max_tile_size * scale

        #
        # Compute tissue/foreground mask at low-res for whole slide images
        #
        print('\n>> Computing tissue/foreground mask at low-res ...\n')
        start_time = time.time()
        im_fgnd_mask_lres, fgnd_seg_scale = \
            cli_utils.segment_wsi_foreground_at_low_res(ts)
        fgnd_time = time.time() - start_time
        print('low-res foreground mask computation time = {}'.format(
            cli_utils.disp_time_hms(fgnd_time)))

        it_kwargs = {
            'tile_size': {'width': superpixel_tile_size},
            'scale': {'magnification': superpixel_mag},
        }

        start_time = time.time()
        num_tiles = \
            ts.getSingleTile(**it_kwargs)['iterator_range']['position']
        print('Number of tiles = {}'.format(num_tiles))

        tile_fgnd_frac_list = htk_utils.compute_tile_foreground_fraction(
            img_path, im_fgnd_mask_lres, fgnd_seg_scale,
            it_kwargs
        )

        # NOTE(review): tiles are counted with '>=' here but skipped with
        # '<=' in the loop below, so a tile exactly at the threshold is
        # counted yet not processed -- confirm which comparison is intended.
        num_fgnd_tiles = np.count_nonzero(
            tile_fgnd_frac_list >= args.min_fgnd_frac
        )
        percent_fgnd_tiles = 100.0 * num_fgnd_tiles / num_tiles
        fgnd_frac_comp_time = time.time() - start_time

        # str.format has no '%%' escape and '{:2f}' is a width, not a
        # precision; print a two-decimal percentage with a single '%'.
        print('Number of foreground tiles = {0:d} ({1:.2f}%)'.format(
            num_fgnd_tiles, percent_fgnd_tiles))
        print('Tile foreground fraction computation time = {}'.format(
            cli_utils.disp_time_hms(fgnd_frac_comp_time)))

        print('\n>> Computing reinhard color normalization stats ...\n')
        start_time = time.time()
        src_mu_lab, src_sigma_lab = htk_cnorm.reinhard_stats(
            img_path, 0.01, magnification=superpixel_mag)
        rstats_time = time.time() - start_time
        print('Reinhard stats computation time = {}'.format(
            cli_utils.disp_time_hms(rstats_time)))

        print('\n>> Detecting superpixel data ...\n')
        superpixel_data = []
        superpixel_x_centroids = []
        superpixel_y_centroids = []

        for tile in ts.tileIterator(**it_kwargs):
            tile_position = tile['tile_position']['position']
            print('tile_position = {}'.format(tile_position))

            # Skip tiles without enough foreground.
            if tile_fgnd_frac_list[tile_position] <= args.min_fgnd_frac:
                continue

            print('Extracting features')
            # detect superpixel data
            tile_features, tile_x_centroids, tile_y_centroids = \
                compute_superpixel_data_pca(model, pca, img_path,
                                            tile_position, centroid_path,
                                            args, it_kwargs,
                                            src_mu_lab, src_sigma_lab)

            # accumulate single slide data
            superpixel_data += list(tile_features)
            superpixel_x_centroids += list(tile_x_centroids)
            superpixel_y_centroids += list(tile_y_centroids)

        n_superpixels = len(superpixel_x_centroids)  # per slide sp number

        # Offset of this slide's first superpixel in the accumulated data.
        first_superpixel_index[index, 0] = len(slide_superpixel_data)

        # accumulating all slide data
        slide_superpixel_data += superpixel_data
        slide_x_centroids += superpixel_x_centroids
        slide_y_centroids += superpixel_y_centroids
        slide_wsi_mean[index] = src_mu_lab
        slide_wsi_stddev[index] = src_sigma_lab
        slide_superpixel_index += [index] * n_superpixels

        index += 1
        print('no. of features in slide', len(slide_superpixel_data))

    if len(slide_superpixel_data) == 0:
        # Everything below needs at least one feature row.
        raise ValueError('No superpixel features were extracted.')

    # Drop the rows reserved for slides that were skipped (no-op otherwise).
    first_superpixel_index = first_superpixel_index[:index]
    slide_wsi_mean = slide_wsi_mean[:index]
    slide_wsi_stddev = slide_wsi_stddev[:index]

    # reshaping and changing data format
    slide_superpixel_data = np.asarray(slide_superpixel_data)
    slide_x_centroids = np.asarray(
        slide_x_centroids, dtype=np.float32).reshape((-1, 1))
    slide_y_centroids = np.asarray(
        slide_y_centroids, dtype=np.float32).reshape((-1, 1))
    slide_superpixel_index = np.asarray(
        slide_superpixel_index, dtype=np.int32).reshape((-1, 1))

    if args.inputPCAModel:
        # Features were already reduced per tile with the loaded model.
        superpixel_feature_map = np.asarray(slide_superpixel_data,
                                            dtype=np.float32)
    else:
        print("Fitting PCA ... ")
        df = pd.DataFrame(data=slide_superpixel_data,
                          columns=list(range(args.fcn)))
        df_sample = df.reindex(
            np.random.permutation(df.index)
        ).sample(frac=args.pca_sample_scale)
        pca = PCA(n_components=args.pca_dim)
        pca.fit(df_sample.values)
        joblib.dump(pca, outputPCAsample)
        superpixel_feature_map = np.asarray(
            pca.transform(slide_superpixel_data), dtype=np.float32)

    # get mean and standard deviation for train, stored as column vectors
    n_features = superpixel_feature_map.shape[1]
    slide_feature_mean = np.reshape(
        np.mean(superpixel_feature_map, axis=0), (n_features, 1)
    ).astype(np.float32)
    slide_feature_stddev = np.reshape(
        np.std(superpixel_feature_map, axis=0), (n_features, 1)
    ).astype(np.float32)

    print('total feature shape', len(superpixel_feature_map))

    total_time_taken = time.time() - total_start_time
    print('Total analysis time = {}'.format(
        cli_utils.disp_time_hms(total_time_taken)))

    print('>> Writing raw H5 data file')
    # The context manager closes the file even if a write fails.
    with h5py.File(outputDataSet, 'w') as output:
        # h5py cannot store numpy unicode ('U') arrays, which is what a list
        # of str becomes on Python 3; store fixed-length byte strings.
        output.create_dataset(
            'slides', data=np.asarray(slide_name_list, dtype='S'))
        output.create_dataset('slideIdx', data=slide_superpixel_index)
        output.create_dataset('dataIdx', data=first_superpixel_index)
        output.create_dataset('mean', data=slide_feature_mean)
        output.create_dataset('std_dev', data=slide_feature_stddev)
        output.create_dataset('features', data=superpixel_feature_map)
        output.create_dataset('x_centroid', data=slide_x_centroids)
        output.create_dataset('y_centroid', data=slide_y_centroids)
        output.create_dataset('wsi_mean', data=slide_wsi_mean)
        output.create_dataset('wsi_stddev', data=slide_wsi_stddev)
        output.create_dataset('patch_size', data=args.superpixelSize)
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and run the extraction.
    cli_args = CLIArgumentParser().parse_args()
    main(cli_args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment