-
-
Save solaris33/c5fcce56903a6c6fcf3080285c864210 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from absl import flags | |
from absl import app | |
import os | |
import numpy as np | |
import skimage.io as io | |
from skimage import img_as_ubyte | |
from keras.preprocessing.image import ImageDataGenerator | |
from model import UNET_ISBI_2012 | |
flags.DEFINE_string('checkpoint_path', default='saved_model_isbi_2012/unet_model.ckpt', help='path to a directory to restore checkpoint file') | |
flags.DEFINE_string('test_dir', default='isbi_2012_test_result', help='directory which test prediction result saved') | |
flags.DEFINE_integer('num_classes', default=1, help='number of prediction classes') | |
FLAGS = flags.FLAGS | |
# set configuration value | |
batch_size = 1 | |
total_test_image_num = 30 | |
# normalize ISBI-2012 data | |
def normalize_isbi_2012(input_images): | |
# 0~255 -> 0.0~1.0 | |
input_images = input_images / 255 | |
return input_images | |
# make prediction mask | |
def create_mask(pred_mask): | |
pred_mask = np.where(pred_mask > 0.5, 255, 0) | |
return_mask = pred_mask[0] | |
if return_mask.shape[2] == 1: | |
return_mask = return_mask.reshape(return_mask.shape[:2]) | |
return return_mask | |
# make test data generator | |
def make_test_generator(batch_size): | |
image_gen = ImageDataGenerator() | |
# set image and mask same augmentation using same seed | |
image_generator = image_gen.flow_from_directory( | |
directory='./isbi_2012/preprocessed', | |
classes = ['test_imgs'], | |
class_mode = None, | |
target_size = (512, 512), | |
batch_size = batch_size, | |
color_mode='grayscale', | |
seed=1 | |
) | |
for batch_images in image_generator: | |
batch_images = normalize_isbi_2012(batch_images) | |
yield batch_images | |
def main(_): | |
# check if checkpoint path exists | |
#if not os.path.exists(FLAGS.checkpoint_path): | |
# print('checkpoint file is not exists!') | |
# exit() | |
# create UNET model | |
unet_model = UNET_ISBI_2012(FLAGS.num_classes) | |
# restore latest checkpoint | |
unet_model.load_weights(FLAGS.checkpoint_path) | |
print(f'{FLAGS.checkpoint_path} checkpoint is restored!') | |
# make generator | |
test_generator = make_test_generator(batch_size) | |
# check total image num | |
print('total test image :', total_test_image_num) | |
# save test prediction result to png file | |
if not os.path.exists(os.path.join(os.getcwd(), FLAGS.test_dir)): | |
os.mkdir(os.path.join(os.getcwd(), FLAGS.test_dir)) | |
for image_num, test_image in enumerate(test_generator): | |
if image_num >= total_test_image_num: | |
break | |
pred_mask = unet_model.predict(test_image) | |
output_image_path = os.path.join(os.getcwd(), FLAGS.test_dir, f'{image_num}_result.png') | |
io.imsave(output_image_path, img_as_ubyte(create_mask(pred_mask))) | |
print(output_image_path + ' saved!') | |
if __name__ == '__main__': | |
app.run(main) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment