"""Implements a data generator that creates batches of images for feature
extraction in HistomicsML2."""
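# Pipeline overview: for each whole-slide image, a low-res tissue/foreground
# mask is computed, foreground tiles are Reinhard color-normalized, superpixel
# patches around precomputed centroids are batched by DataGenerator, VGG16 fc1
# features are extracted (optionally PCA-reduced), and the results are written
# to a HistomicsML HDF5 dataset.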
import os
import json
import h5py
import time
from tqdm import tqdm
import pandas as pd
import numpy as np
import keras
import histomicstk.preprocessing.color_normalization as htk_cnorm
import histomicstk.utils as htk_utils
import large_image
import tensorflow as tf
from keras.models import Model
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
from keras import applications
from ctk_cli import CLIArgumentParser
from scipy.misc import imresize
from sklearn.decomposition import PCA
from sklearn.externals import joblib
from histomicstk.cli import utils as cli_utils
import logging
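# Note: these imports pin the script to a legacy stack; scipy.misc.imresize
# was removed in SciPy 1.3, sklearn.externals.joblib was removed in
# scikit-learn 0.23, and tf.contrib (used below) exists only in TensorFlow 1.x.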
TIME_STAMP = time.strftime('%Y-%m-%d-%H-%M-%S')
logging.basicConfig(level=logging.CRITICAL)
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
class Logger(object):
    """Tee stdout to both the terminal and a log file."""
    def __init__(self, path):
        self.terminal = sys.stdout
        self.log = open(path, 'a+')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # This flush method is needed for Python 3 compatibility.
        # It handles the flush command by doing nothing; add extra
        # behavior here if needed.
        pass
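# main() below redirects sys.stdout to a Logger instance, so every print is
# tee'd into a timestamped log file.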
class DataGenerator(keras.utils.Sequence):
    """Generates batches of superpixel patches for Keras."""
    def __init__(self, tile_info, im_nmzd, x_centroids, y_centroids,
                 patchSize, patchSizeResized=224, batch_size=32, shuffle=False):
        self.x_centroids = x_centroids
        self.y_centroids = y_centroids
        self.patchSize = patchSize
        self.patchSizeResized = patchSizeResized
        self.left = tile_info['gx']
        self.top = tile_info['gy']
        self.im_width = tile_info['width']
        self.im_height = tile_info['height']
        self.im_nmzd = im_nmzd
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # Denotes the number of batches per epoch
        return int(np.ceil(len(self.x_centroids) / float(self.batch_size)))

    def __getitem__(self, index):
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # Find the centroids belonging to this batch
        x_batch, y_batch = zip(*[(self.x_centroids[k], self.y_centroids[k])
                                 for k in indexes])
        # Generate data
        img_batch = self.__data_generation(x_batch, y_batch)
        img_batch = preprocess_input(img_batch)
        return img_batch

    def on_epoch_end(self):
        # Updates indexes after each epoch
        self.indexes = np.arange(len(self.x_centroids))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, x_batch, y_batch):
        # Generates a batch of shape (n_samples, patchSizeResized, patchSizeResized, 3)
        X = np.empty((len(x_batch), self.patchSizeResized, self.patchSizeResized, 3))
        for i, (x_centroid, y_centroid) in enumerate(zip(x_batch, y_batch)):
            # centroid coordinates relative to the tile origin
            cen_x = x_centroid - self.left
            cen_y = y_centroid - self.top
            # get bounds of superpixel region
            min_row, max_row, min_col, max_col = \
                get_patch_bounds(cen_y, cen_x,
                                 self.patchSize, self.im_height, self.im_width)
            # resize superpixel patch
            im_patch = \
                imresize(self.im_nmzd[min_row:max_row, min_col:max_col, :],
                         (self.patchSizeResized, self.patchSizeResized, 3))
            X[i] = im_patch
        return X
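# Example usage (illustrative sketch; `tile_info` and `im_nmzd` would come from
# large_image and Reinhard normalization as in compute_superpixel_data_pca
# below, `model` is the truncated VGG16 built in main(), and patchSize=64 is a
# hypothetical value):
#
#   gen = DataGenerator(tile_info, im_nmzd, x_centroids, y_centroids,
#                       patchSize=64, batch_size=32, shuffle=False)
#   feats = model.predict_generator(gen, steps=len(gen))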
def compute_superpixel_data_pca(model, pca, slide_path, tile_position,
                                centroid_path, args, superpixel_kwargs,
                                src_mu_lab=None, src_sigma_lab=None):
    # get slide tile source
    ts = large_image.getTileSource(slide_path)
    tile_info = \
        ts.getSingleTile(tile_position=tile_position,
                         format=large_image.tilesource.TILE_FORMAT_NUMPY,
                         **superpixel_kwargs)
    left = tile_info['gx']
    top = tile_info['gy']
    # get width and height
    width = tile_info['width']
    height = tile_info['height']
    f = h5py.File(centroid_path, 'r')
    x_centroids = f['x_centroid'][:]
    y_centroids = f['y_centroid'][:]
    n_superpixels = len(x_centroids)
    print('total no. of superpixels: {}'.format(n_superpixels))
    # keep only the centroids that fall inside this tile
    x_centroids_valid, y_centroids_valid = [], []
    for x, y in zip(x_centroids, y_centroids):
        if left < x <= (left + width) and top < y <= (top + height):
            x_centroids_valid.append(x)
            y_centroids_valid.append(y)
    print('no. of valid superpixels in tile position {}: {}'.format(
        tile_position, len(x_centroids_valid)))
    if len(x_centroids_valid) == 0:
        print('no valid superpixels, returning empty lists')
        return [], [], []
    # get requested tile
    im_tile = tile_info['tile'][:, :, :3]
    # perform color normalization
    start = time.time()
    print('Applying Reinhard color normalization')
    im_nmzd = htk_cnorm.reinhard(im_tile, args.reference_mu_lab,
                                 args.reference_std_lab, src_mu=src_mu_lab,
                                 src_sigma=src_sigma_lab)
    end = time.time()
    print('time taken {}'.format(cli_utils.disp_time_hms(end - start)))
    BATCH_SIZE = args.batch_size
    datagenerator = DataGenerator(tile_info, im_nmzd,
                                  x_centroids_valid, y_centroids_valid,
                                  patchSize=args.patchSize,
                                  batch_size=BATCH_SIZE, shuffle=False)
    # get superpixel features
    fcn = model.predict_generator(
        datagenerator,
        steps=np.ceil(len(x_centroids_valid) / float(BATCH_SIZE)),
        verbose=True,
        workers=args.num_workers_loader,
        use_multiprocessing=True,
    )
    # reduce the fcn features
    if pca is not None:
        features = pca.transform(fcn)
    else:
        features = fcn
    return features, x_centroids_valid, y_centroids_valid
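# Note: with VGG16's fc1 layer as the output (see main() below), `fcn` has
# shape (n_valid_superpixels, 4096); pca.transform then reduces it to
# (n_valid_superpixels, pca.n_components_).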
def get_patch_bounds(cx, cy, patch, m, n):
    # clamp a patch of side `patch` centered at (cx, cy) to an m x n image
    half_patch = patch / 2.0
    min_row = int(round(cx) - half_patch)
    max_row = int(round(cx) + half_patch)
    min_col = int(round(cy) - half_patch)
    max_col = int(round(cy) + half_patch)
    # shift the window back inside the image if it crosses a border
    if min_row < 0:
        max_row = max_row - min_row
        min_row = 0
    if max_row > m - 1:
        min_row = min_row - (max_row - (m - 1))
        max_row = m - 1
    if min_col < 0:
        max_col = max_col - min_col
        min_col = 0
    if max_col > n - 1:
        min_col = min_col - (max_col - (n - 1))
        max_col = n - 1
    return min_row, max_row, min_col, max_col
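# Worked example (illustrative): patch=64, cx=10, m=1000 gives
# min_row = int(round(10) - 32.0) = -22 and max_row = 42; since min_row < 0,
# the window is shifted to min_row = 0, max_row = 64, preserving the
# 64-pixel height.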
def main(args):
    TIME_STAMP = time.strftime('%Y-%m-%d-%H-%M-%S')
    filepath_log = os.path.join(args.projectName, '{}.log'.format(TIME_STAMP))
    sys.stdout = Logger(filepath_log)
    print(TIME_STAMP)
    print('Selected GPUs', args.gpus)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    total_start_time = time.time()
    print('\n>> CLI Parameters ...\n')
    if args.inputPCAModel:
        print('\n>> Load fitted PCA model ... \n')
        pca = joblib.load(args.inputPCAModel)
    else:
        pca = None
    inputSlidePath = '/' + args.projectName + '/svs'
    inputCentroidPath = '/' + args.projectName + '/centroid'
    outputDataSet = '/' + args.projectName + '/{}_HistomicsML_dataset.h5'.format(TIME_STAMP)
    outputPCAsample = '/' + args.projectName + '/{}_pca_model_sample.pkl'.format(TIME_STAMP)
    #
    # Check whether the slide directory exists
    #
    if os.path.isdir(inputSlidePath):
        img_paths = [
            os.path.join(inputSlidePath, files)
            for files in os.listdir(inputSlidePath)
            if os.path.isfile(os.path.join(inputSlidePath, files))]
    else:
        raise IOError('Slide path is not a directory.')
    #
    # Check whether the centroid directory exists
    #
    if os.path.isdir(inputCentroidPath):
        centroid_paths = [
            os.path.join(inputCentroidPath, files)
            for files in os.listdir(inputCentroidPath)
            if os.path.isfile(os.path.join(inputCentroidPath, files))]
    else:
        raise IOError('Centroid path is not a directory.')
    print('\n>> Reading VGG pre-trained model ... \n')
    # tf.contrib is TensorFlow 1.x only; MirroredStrategy replicates the model
    # across the visible GPUs
    strategy = tf.contrib.distribute.MirroredStrategy()
    # print('Number of devices: {}'.format(strategy.extended.num_replicas_in_sync))
    with strategy.scope():
        model = applications.VGG16(include_top=True, weights='imagenet')
        # keep the 4096-d fc1 activations as the feature extractor output
        model = Model(inputs=model.input, outputs=model.get_layer('fc1').output)
    print('Generate train dataset ... ')
    slide_superpixel_data = []
    slide_x_centroids = []
    slide_y_centroids = []
    slide_name_list = []
    slide_superpixel_index = []
    total_n_slides = len(img_paths)
    first_superpixel_index = np.zeros((total_n_slides, 1), dtype=np.int32)
    slide_wsi_mean = np.zeros((total_n_slides, 3), dtype=np.float32)
    slide_wsi_stddev = np.zeros((total_n_slides, 3), dtype=np.float32)
    index = 0
    for i in tqdm(range(len(img_paths))):
        slide_name = img_paths[i].split('/')[-1].split('.')[0]
        centroid_path = os.path.join(inputCentroidPath, slide_name + '.h5')
        slide_name_list.append(slide_name)
        #
        # Read input image
        print('{} is processing ... \n'.format(slide_name))
        ts = large_image.getTileSource(img_paths[i])
        ts_metadata = ts.getMetadata()
        print(json.dumps(ts_metadata, indent=2))
        # check the magnification before using it, so slides without
        # magnification metadata are skipped instead of raising a TypeError
        is_wsi = ts_metadata['magnification'] is not None
        if is_wsi:
            scale = ts_metadata['magnification'] / args.max_mag
            superpixel_mag = args.max_mag * scale
            superpixel_tile_size = args.max_tile_size * scale
            #
            # Compute tissue/foreground mask at low-res for whole-slide images
            #
            print('\n>> Computing tissue/foreground mask at low-res ...\n')
            start_time = time.time()
            im_fgnd_mask_lres, fgnd_seg_scale = \
                cli_utils.segment_wsi_foreground_at_low_res(ts)
            fgnd_time = time.time() - start_time
            print('low-res foreground mask computation time = {}'.format(
                cli_utils.disp_time_hms(fgnd_time)))
            it_kwargs = {
                'tile_size': {'width': superpixel_tile_size},
                'scale': {'magnification': superpixel_mag},
            }
            start_time = time.time()
            num_tiles = \
                ts.getSingleTile(**it_kwargs)['iterator_range']['position']
            print('Number of tiles = {}'.format(num_tiles))
            tile_fgnd_frac_list = htk_utils.compute_tile_foreground_fraction(
                img_paths[i], im_fgnd_mask_lres, fgnd_seg_scale, it_kwargs)
            num_fgnd_tiles = np.count_nonzero(
                tile_fgnd_frac_list >= args.min_fgnd_frac)
            percent_fgnd_tiles = 100.0 * num_fgnd_tiles / num_tiles
            fgnd_frac_comp_time = time.time() - start_time
            print('Number of foreground tiles = {0:d} ({1:.2f}%)'.format(
                num_fgnd_tiles, percent_fgnd_tiles))
            print('Tile foreground fraction computation time = {}'.format(
                cli_utils.disp_time_hms(fgnd_frac_comp_time)))
            print('\n>> Computing Reinhard color normalization stats ...\n')
            start_time = time.time()
            src_mu_lab, src_sigma_lab = htk_cnorm.reinhard_stats(
                img_paths[i], 0.01, magnification=superpixel_mag)
            # src_mu_lab = np.array([7.64726018, -0.28728364, 0.06082756])
            # src_sigma_lab = np.array([0.7931921, 0.1633065, 0.03005581])
            rstats_time = time.time() - start_time
            print('Reinhard stats computation time = {}'.format(
                cli_utils.disp_time_hms(rstats_time)))
            print('\n>> Detecting superpixel data ...\n')
            superpixel_data = []
            superpixel_x_centroids = []
            superpixel_y_centroids = []
            for tile in ts.tileIterator(**it_kwargs):
                tile_position = tile['tile_position']['position']
                print('tile_position = {}'.format(tile_position))
                if tile_fgnd_frac_list[tile_position] <= args.min_fgnd_frac:
                    continue
                print('Extracting features')
                # detect superpixel data
                tile_features, tile_x_centroids, tile_y_centroids = \
                    compute_superpixel_data_pca(model, pca, img_paths[i],
                                                tile_position, centroid_path,
                                                args, it_kwargs,
                                                src_mu_lab, src_sigma_lab)
                # accumulate single-slide data
                superpixel_data += list(tile_features)
                superpixel_x_centroids += list(tile_x_centroids)
                superpixel_y_centroids += list(tile_y_centroids)
            n_superpixels = len(superpixel_x_centroids)  # per-slide superpixel count
            n_superpixels_acc = len(slide_superpixel_data)  # accumulated count so far
            first_superpixel_index[index, 0] = n_superpixels_acc  # start of this slide
            # accumulate data across all slides
            slide_superpixel_data += superpixel_data
            slide_x_centroids += superpixel_x_centroids
            slide_y_centroids += superpixel_y_centroids
            slide_wsi_mean[index] = src_mu_lab
            slide_wsi_stddev[index] = src_sigma_lab
            slide_superpixel_index += [index] * n_superpixels
            index += 1
        print('no. of features accumulated so far', len(slide_superpixel_data))
    # reshape and convert the accumulated data
    slide_superpixel_data = np.asarray(slide_superpixel_data)
    slide_x_centroids = np.asarray(
        slide_x_centroids, dtype=np.float32).reshape((len(slide_x_centroids), 1))
    slide_y_centroids = np.asarray(
        slide_y_centroids, dtype=np.float32).reshape((len(slide_y_centroids), 1))
    slide_superpixel_index = np.asarray(
        slide_superpixel_index, dtype=np.int32).reshape((len(slide_superpixel_index), 1))
    if args.inputPCAModel:
        superpixel_feature_map = np.asarray(slide_superpixel_data, dtype=np.float32)
    else:
        print('Fitting PCA ... ')
        df = pd.DataFrame(data=slide_superpixel_data,
                          columns=list(range(args.fcn)))
        df_sample = df.reindex(
            np.random.permutation(df.index)).sample(frac=args.pca_sample_scale)
        pca = PCA(n_components=args.pca_dim)
        pca.fit(df_sample.values)
        joblib.dump(pca, outputPCAsample)
        superpixel_feature_map = np.asarray(
            pca.transform(slide_superpixel_data), dtype=np.float32)
    slide_x_centroids = np.asarray(slide_x_centroids, dtype=np.float32)
    slide_y_centroids = np.asarray(slide_y_centroids, dtype=np.float32)
    # get mean and standard deviation of the features for training
    slide_feature_mean = np.reshape(
        np.mean(superpixel_feature_map, axis=0),
        (superpixel_feature_map.shape[1], 1)).astype(np.float32)
    slide_feature_stddev = np.reshape(
        np.std(superpixel_feature_map, axis=0),
        (superpixel_feature_map.shape[1], 1)).astype(np.float32)
    print('total no. of features', len(superpixel_feature_map))
    total_time_taken = time.time() - total_start_time
    print('Total analysis time = {}'.format(
        cli_utils.disp_time_hms(total_time_taken)))
    print('>> Writing raw H5 data file')
    output = h5py.File(outputDataSet, 'w')
    output.create_dataset('slides', data=slide_name_list)
    output.create_dataset('slideIdx', data=slide_superpixel_index)
    output.create_dataset('dataIdx', data=first_superpixel_index)
    output.create_dataset('mean', data=slide_feature_mean)
    output.create_dataset('std_dev', data=slide_feature_stddev)
    output.create_dataset('features', data=superpixel_feature_map)
    output.create_dataset('x_centroid', data=slide_x_centroids)
    output.create_dataset('y_centroid', data=slide_y_centroids)
    output.create_dataset('wsi_mean', data=slide_wsi_mean)
    output.create_dataset('wsi_stddev', data=slide_wsi_stddev)
    output.create_dataset('patch_size', data=args.superpixelSize)
    output.close()


if __name__ == '__main__':
    main(CLIArgumentParser().parse_args())
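# The CLI arguments are defined by the ctk_cli XML spec that accompanies this
# script; from the code above it expects (at least): projectName, gpus,
# inputPCAModel, batch_size, patchSize, num_workers_loader, reference_mu_lab,
# reference_std_lab, max_mag, max_tile_size, min_fgnd_frac, fcn,
# pca_sample_scale, pca_dim, and superpixelSize.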