Skip to content

Instantly share code, notes, and snippets.

@John1231983
Last active February 18, 2017 16:12
Show Gist options
  • Save John1231983/63c35d1cb4e770350fc76c5d38722e9c to your computer and use it in GitHub Desktop.
Save John1231983/63c35d1cb4e770350fc76c5d38722e9c to your computer and use it in GitHub Desktop.
from __future__ import print_function
import pandas as pd
import h5py
import nibabel as nib
import numpy as np
def load_nifti(filename, with_affine=False):
    """Load a NIfTI volume as a C-contiguous array with singleton axes removed.

    Args:
        filename: path to an image file accepted by ``nib.load``.
        with_affine: if True, also return the image's affine matrix.

    Returns:
        ``data``, or ``(data, affine)`` when ``with_affine`` is True.
        ``data`` is an independent copy (``np.copy``), so it does not
        reference the image's underlying data object.
    """
    img = nib.load(filename)
    # ``img.get_data()`` was deprecated in nibabel 3.0 and removed in 5.0.
    # ``np.asanyarray(img.dataobj)`` is the documented replacement that yields
    # the same array get_data() did (no forced cast to float64, unlike
    # ``get_fdata()``), which matters for the integer label/mask volumes.
    data = np.asanyarray(img.dataobj)
    data = np.squeeze(data)
    data = np.copy(data, order="C")
    if with_affine:
        return data, img.affine
    return data
def extract_patch(scalar_img, label_img, mask_img, patch_shape=None, rng=None):
    """Cut one randomly placed patch out of an image and its label map.

    The patch origin is chosen uniformly among mask voxels whose value exceeds
    0.75 and around which a full patch fits inside the mask volume, so (when
    all three volumes share the mask's shape) the returned patches always
    have exactly ``patch_shape``.

    Args:
        scalar_img: intensity volume, indexed with the chosen window.
        label_img: label volume, indexed with the same window.
            NOTE(review): assumed to have the same shape as ``mask_img`` --
            not checked here.
        mask_img: volume selecting candidate patch centres; never modified.
        patch_shape: patch size per axis. Defaults to the module-level
            ``shape`` so existing three-argument calls keep working.
        rng: optional object with a ``choice`` method (e.g.
            ``np.random.Generator`` or ``np.random.RandomState``) for
            reproducible sampling; defaults to the global ``np.random`` state.

    Returns:
        ``(scalar_patch, label_patch)`` obtained by basic slicing (views when
        the inputs are NumPy arrays).

    Raises:
        ValueError: if ``patch_shape`` has a different rank than the mask, or
            no mask voxel can centre a full patch.
    """
    if patch_shape is None:
        patch_shape = shape  # module-level default used by the script
    patch_shape = tuple(int(p) for p in patch_shape)
    mask_img = np.asarray(mask_img)
    if len(patch_shape) != mask_img.ndim:
        raise ValueError("patch_shape %s does not match mask rank %d"
                         % (patch_shape, mask_img.ndim))

    halves = [p // 2 for p in patch_shape]
    # A centre c yields the window [c - half, c - half + p), which fits when
    # half <= c <= dim - (p - half). Integer division (``//``) keeps the
    # bounds integral on Python 3 and handles odd patch sizes. max() keeps the
    # slice empty (instead of wrapping via a negative stop) when the volume is
    # smaller than the patch.
    interior = tuple(
        slice(h, max(h, dim - (p - h) + 1))
        for h, p, dim in zip(halves, patch_shape, mask_img.shape))

    # Same threshold as the old "double the interior, then test > 1.5" trick
    # (2 * m > 1.5  <=>  m > 0.75), but without mutating the mask (the old
    # in-place ``/= 2`` raised for integer masks) and without letting exterior
    # voxels with values > 1.5 through.
    candidates = np.nonzero(mask_img[interior] > 0.75)
    n_candidates = candidates[0].size
    if n_candidates == 0:
        raise ValueError("no mask voxel can centre a patch of shape %s in a "
                         "volume of shape %s" % (patch_shape, mask_img.shape))

    sampler = np.random if rng is None else rng
    j = sampler.choice(n_candidates)
    # Candidate coordinates are relative to the interior, which starts at
    # ``half`` on each axis, so the window starts exactly at that coordinate.
    window = tuple(
        slice(int(coords[j]), int(coords[j]) + p)
        for coords, p in zip(candidates, patch_shape))
    return scalar_img[window], label_img[window]
# Extract random patches from every training volume and save each volume's
# patches to its own HDF5 file (trainMS_<row>.h5), listed in trainMS_list.txt.
num_sample = 1000  # patches extracted per volume
shape = [80, 80, 80]  # patch size in voxels (also extract_patch's default)
n_batch = 1  # NOTE(review): not used by this script
n_classes = 4  # every label must lie in [0, n_classes)
train_df = pd.read_csv('dataset_train.csv')
# dataset_train.csv detail
# ------------------------
# index,segTRI,mask,preprocessed
# 01,/IBSR_02/IBSR_02_segTRI_ana.nii.gz,/IBSR_02/IBSR_02_ana_brainmask.nii.gz,/IBSR_02/IBSR_02_ana_strip.nii.gz
# 02,/IBSR_03/IBSR_03_segTRI_ana.nii.gz,/IBSR_03/IBSR_03_ana_brainmask.nii.gz,/IBSR_03/IBSR_03_ana_strip.nii.gz
# -------------------------

# The buffers are allocated once and every entry is overwritten for each
# volume. NOTE: with 1000 patches of 80**3 voxels each buffer is ~2 GB.
raw_patches = np.zeros([num_sample, 1] + list(shape), dtype=np.float32)
label_patches = np.zeros([num_sample] + list(shape), dtype=np.int32)

# Rewrite the list instead of appending to it, so re-running the script does
# not accumulate duplicate entries (and hence duplicated training data).
with open('./trainMS_list.txt', 'w') as list_file:
    rows = zip(train_df["preprocessed"], train_df["segTRI"], train_df["mask"])
    for index_file, (scalar_file, label_file, mask_file) in enumerate(rows):
        scalar_img = load_nifti(scalar_file).astype(np.float32)
        # Pre-processing: z-score normalize the whole volume.
        scalar_std = np.std(scalar_img)
        if scalar_std == 0:
            raise ValueError("%s has constant intensity; cannot normalize"
                             % scalar_file)
        scalar_img = (scalar_img - np.mean(scalar_img)) / scalar_std
        # end pre-processing
        label_img = load_nifti(label_file).astype(np.int32)
        mask_img = load_nifti(mask_file)
        for i in range(num_sample):
            scalar_patch, label_patch = extract_patch(scalar_img, label_img, mask_img)
            # Explicit checks rather than `assert`, which `python -O` strips.
            # Shape is checked first: max() would fail on an empty patch.
            if label_patch.shape != tuple(shape):
                raise ValueError("%s: patch %d has shape %s, expected %s"
                                 % (label_file, i, label_patch.shape, tuple(shape)))
            if label_patch.min() < 0 or label_patch.max() >= n_classes:
                raise ValueError("%s: patch %d has labels outside [0, %d)"
                                 % (label_file, i, n_classes))
            raw_patches[i, 0] = scalar_patch
            label_patches[i] = label_patch
        h5_name = 'trainMS_%s.h5' % index_file
        with h5py.File(h5_name, 'w') as f:
            f['data'] = raw_patches
            f['label'] = label_patches
        # Listed only after the HDF5 file has been written completely.
        list_file.write(h5_name + '\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment