Skip to content

Instantly share code, notes, and snippets.

@andrewssobral
Created July 24, 2017 17:27
Show Gist options
  • Save andrewssobral/0448f87258125d10459cf5b6f4bdd996 to your computer and use it in GitHub Desktop.
Save andrewssobral/0448f87258125d10459cf5b6f4bdd996 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
from pylab import *
import os
import sys
import cv2
from PIL import Image
from keras.preprocessing.image import *
from keras.models import load_model
import keras.backend as K
from keras.applications.imagenet_utils import preprocess_input
from models import *
def inference(model_name, weight_file, image_size, image_list, data_dir, label_dir, return_results=True, save_dir=None,
              label_suffix='.png',
              data_suffix='.jpg',
              classes=2):
    """Run a trained segmentation model over a list of images.

    Each image is zero-padded (centered) to at least ``image_size``, resized
    to exactly ``image_size``, normalized with ``preprocess_input`` and fed
    to the network. The per-pixel argmax is mapped back onto the padded
    canvas and cropped to the original image extent.

    Args:
        model_name: name of a model constructor visible in this module's
            globals; also the sub-directory of ``Models/`` (next to this
            file) holding the checkpoint.
        weight_file: checkpoint file name inside ``Models/<model_name>``.
        image_size: ``(height, width)`` network input size (tuple or list).
        image_list: iterable of image ids (file stems); trailing newlines are
            stripped, so lines read from a file can be passed directly.
        data_dir: directory containing ``<id><data_suffix>`` input images.
        label_dir: directory containing ``<id><label_suffix>`` label images;
            only the label's palette is used (copied onto the prediction).
        return_results: if true, collect the predicted images in the
            returned list.
        save_dir: if given, each prediction is saved there as ``<id>.png``
            (the directory is created when missing).
        label_suffix: file extension of the label images.
        data_suffix: file extension of the input images.
        classes: number of output classes the model is built with.

    Returns:
        list of mode-'P' PIL images, empty when ``return_results`` is false.
    """
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoint_path = os.path.join(current_dir, 'Models', model_name, weight_file)
    image_size = tuple(image_size)  # tuples are concatenated below
    batch_shape = (1, ) + image_size + (3, )

    # Let TensorFlow grow GPU memory on demand instead of grabbing it all.
    # NOTE(review): `tf` is not imported in this file; presumably it arrives
    # through `from models import *` -- confirm.
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=config)
    K.set_session(session)

    # The constructor is looked up by name among this module's globals.
    # input_shape follows image_size (it was hard-coded to 320x320, which
    # silently disagreed with batch_shape for any other size).
    model = globals()[model_name](batch_shape=batch_shape, input_shape=image_size + (3, ), classes=classes)
    model.load_weights(checkpoint_path, by_name=True)
    model.summary()

    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    net_h, net_w = image_size
    results = []
    for total, img_num in enumerate(image_list, 1):
        img_num = img_num.strip('\n')
        print('#%d: %s' % (total, img_num))

        with Image.open('%s/%s%s' % (data_dir, img_num, data_suffix)) as src:
            image = img_to_array(src.convert('RGB'))
        with Image.open('%s/%s%s' % (label_dir, img_num, label_suffix)) as label:
            # Only the palette is used, to colorize the prediction. The old
            # .convert('L') always discarded it (palette was always None).
            palette = label.palette

        img_h, img_w = image.shape[0:2]
        # Center the image on a zero canvas at least image_size large.
        # `//` keeps the pad widths integral (np.pad rejects floats on
        # Python 3) and matches the old Python 2 `/` result.
        pad_h = max(net_h - img_h, 0)
        pad_w = max(net_w - img_w, 0)
        top, left = pad_h // 2, pad_w // 2
        image = np.pad(image, ((top, pad_h - top), (left, pad_w - left), (0, 0)),
                       'constant', constant_values=0.)
        canvas_h, canvas_w = image.shape[0:2]

        # cv2.resize takes dsize as (width, height); image_size is (height, width).
        image = cv2.resize(image, (net_w, net_h))
        image = np.expand_dims(image, axis=0)
        image = preprocess_input(image)

        result = model.predict(image, batch_size=1)
        result = np.argmax(np.squeeze(result), axis=-1).astype(np.uint8)

        result_img = Image.fromarray(result, mode='P')
        result_img.palette = palette
        # The prediction lives in the network's (net_h, net_w) frame. Map it
        # back onto the padded canvas before cropping; otherwise images
        # larger than image_size are cropped with coordinates from the wrong
        # frame. NEAREST keeps the class indices intact.
        if result_img.size != (canvas_w, canvas_h):
            result_img = result_img.resize((canvas_w, canvas_h), resample=Image.NEAREST)
        result_img = result_img.crop((left, top, left + img_w, top + img_h))
        if return_results:
            results.append(result_img)
        if save_dir:
            result_img.save(os.path.join(save_dir, img_num + '.png'))
    return results
if __name__ == '__main__':
    # Usage: python <this script> IMAGE_ID [IMAGE_ID ...]
    # Each id is a file stem looked up in data_dir / label_dir below
    # (e.g. '2007_000491' for VOC2012).
    model_name = 'AtrousFCN_Resnet50_16s'
    # Alternatives: 'Atrous_DenseNet', 'DenseNet_FCN'
    weight_file = 'checkpoint_weights.hdf5'
    image_size = (320, 320)
    data_dir = os.path.expanduser('~/.keras/datasets/thales/JPEGImages')
    label_dir = os.path.expanduser('~/.keras/datasets/thales/SegmentationClass')
    # VOC2012 alternative:
    # data_dir = os.path.expanduser('~/.keras/datasets/VOC2012/VOCdevkit/VOC2012/JPEGImages')
    # label_dir = os.path.expanduser('~/.keras/datasets/VOC2012/VOCdevkit/VOC2012/SegmentationClass')
    image_list = sys.argv[1:]
    if not image_list:
        sys.exit('usage: %s IMAGE_ID [IMAGE_ID ...]' % sys.argv[0])

    # Write predictions next to this script (the same base directory
    # inference() uses for Models/), not to a hard-coded /home/ubuntu path.
    save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'output')
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    inference(model_name, weight_file, image_size, image_list, data_dir, label_dir,
              return_results=False, save_dir=save_dir,
              label_suffix='.jpg', data_suffix='.jpg')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment