Created
May 29, 2017 02:59
-
-
Save tai2/7a2e6c356b76125ff1dec64b19bdee63 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys
import tensorflow as tf
import skimage.io as io
import skimage.morphology
import numpy as np

# The segmentation package and TF-Slim are local checkouts found through
# paths relative to the current working directory. Alternative checkouts
# used previously: 'tf-image-segmentation' and 'models-custom/slim'.
sys.path.append('tf-image-segmentation-tf1.0')
sys.path.append('models/slim')

# Checkpoint restored by main(); it is an FCN-8s model.
fcn_8s_checkpoint_path = 'fcn_8s_checkpoint/model_fcn8s_final.ckpt'
# Historical (misleading) name, kept so existing references keep working.
fcn_16s_checkpoint_path = fcn_8s_checkpoint_path

# Uncomment to choose which GPU is visible to TensorFlow.
# os.environ["CUDA_VISIBLE_DEVICES"] = '1'

slim = tf.contrib.slim
def save(image_np, pred_np, output_path='sticker_me.png', class_id=15,
         border_width=5, border_value=248):
    """Write an RGBA "sticker" PNG containing only the pixels of ``class_id``.

    Pixels outside the predicted mask are zeroed and made fully transparent.
    A thin outline just inside the mask edge is set to ``border_value`` on
    all three colour channels, which gives the sticker its contour.

    Args:
        image_np: RGB image array; presumably (H, W, 3) uint8 as decoded by
            ``tf.image.decode_jpeg`` in ``main`` -- TODO confirm.
        pred_np: Per-pixel class predictions. After ``squeeze()`` it must be
            (H, W) so the mask lines up with ``image_np``.
        output_path: Destination path of the PNG.
        class_id: Label to cut out. 15 is presumably the PASCAL VOC "person"
            class -- verify against ``pascal_segmentation_lut()``.
        border_width: Side length of the square structuring element used to
            erode the mask; controls the outline thickness.
        border_value: Intensity written to the outline pixels.
    """
    prediction_mask = pred_np.squeeze() == class_id

    # Keep only the pixels of the selected class; everything else becomes 0.
    cropped_object = image_np * np.dstack((prediction_mask,) * 3)

    # The outline is the ring that erosion strips off: mask AND NOT eroded.
    footprint = skimage.morphology.square(border_width)
    eroded = skimage.morphology.binary_erosion(prediction_mask, footprint)
    contour = prediction_mask & ~eroded
    cropped_object[np.dstack((contour,) * 3)] = border_value

    height, width = cropped_object.shape[:2]
    png_array = np.zeros(shape=[height, width, 4], dtype=np.uint8)
    png_array[:, :, :3] = cropped_object
    # Alpha channel: opaque (255) inside the mask, transparent (0) outside.
    png_array[:, :, 3] = np.uint8(prediction_mask * 255)
    io.imsave(output_path, png_array)
def main(image_filename='me.jpg', checkpoint_path=None):
    """Segment ``image_filename`` with FCN-8s and write the resulting PNGs.

    Writes ``image_np.png`` (the decoded input), ``pred_np.png`` (the raw
    class-id map) and, via ``save``, the RGBA sticker image.

    Args:
        image_filename: JPEG file to segment.
        checkpoint_path: Checkpoint to restore. Defaults to the module-level
            ``fcn_16s_checkpoint_path``, which points at an FCN-8s model.
    """
    # Imported here because these modules are only importable once the
    # module-level sys.path entries have been appended.
    from tf_image_segmentation.models.fcn_8s import FCN_8s
    from tf_image_segmentation.utils.inference import adapt_network_for_any_size_input

    if checkpoint_path is None:
        checkpoint_path = fcn_16s_checkpoint_path

    number_of_classes = 21

    image_filename_placeholder = tf.placeholder(tf.string)
    feed_dict_to_use = {image_filename_placeholder: image_filename}

    image_tensor = tf.read_file(image_filename_placeholder)
    image_tensor = tf.image.decode_jpeg(image_tensor, channels=3)
    # Add a leading batch dimension of size 1.
    image_batch_tensor = tf.expand_dims(image_tensor, axis=0)

    # Wrap the model so arbitrary input sizes are accepted; 32 is presumably
    # the network's total downsampling factor -- confirm in
    # tf_image_segmentation.utils.inference.
    fcn_8s_any_size = adapt_network_for_any_size_input(FCN_8s, 32)
    # The second return value (variable mapping) is not needed for inference.
    pred, _ = fcn_8s_any_size(image_batch_tensor=image_batch_tensor,
                              number_of_classes=number_of_classes,
                              is_training=False)

    initializer = tf.local_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(initializer)
        saver.restore(sess, checkpoint_path)
        image_np, pred_np = sess.run([image_tensor, pred],
                                     feed_dict=feed_dict_to_use)

    io.imsave('image_np.png', image_np)
    # Class ids lie in [0, number_of_classes), so uint8 is lossless. The raw
    # prediction dtype is presumably int64 (argmax), which PNG writers may
    # reject.
    io.imsave('pred_np.png', pred_np.squeeze().astype(np.uint8))
    save(image_np, pred_np)
if __name__ == "__main__":
    # Run the segmentation/sticker pipeline only when executed as a script,
    # not when this module is imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment