Skip to content

Instantly share code, notes, and snippets.

@tai2
Created May 29, 2017 02:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tai2/7a2e6c356b76125ff1dec64b19bdee63 to your computer and use it in GitHub Desktop.
Save tai2/7a2e6c356b76125ff1dec64b19bdee63 to your computer and use it in GitHub Desktop.
import os
import sys
import tensorflow as tf
import skimage.io as io
import skimage.morphology
import numpy as np
# Make the tf-image-segmentation (TF 1.0 branch) and TF-Slim checkouts
# importable. These entries are relative paths, so Python resolves them
# against the current working directory: run the script from the directory
# that contains both checkouts.
sys.path.append('tf-image-segmentation-tf1.0')
sys.path.append('models/slim')

# Pretrained FCN-8s weights restored in main().
FCN_8S_CHECKPOINT_PATH = 'fcn_8s_checkpoint/model_fcn8s_final.ckpt'
# Legacy alias kept for existing callers: despite the "16s" in its name,
# this has always pointed at the FCN-8s checkpoint above.
fcn_16s_checkpoint_path = FCN_8S_CHECKPOINT_PATH

# Optional: uncomment to restrict TensorFlow to a specific GPU.
# os.environ["CUDA_VISIBLE_DEVICES"] = '1'

# Short alias for TF-Slim (tf.contrib exists only in TensorFlow 1.x).
slim = tf.contrib.slim
def save(image_np, pred_np, class_index=15, output_path='sticker_me.png',
         contour_width=5, contour_value=248):
    """Cut one class out of ``image_np`` and write it as a transparent PNG sticker.

    Pixels whose predicted label equals ``class_index`` keep their colour.
    Every other pixel becomes black and fully transparent. The band of object
    pixels removed by a binary erosion (the object's inner edge) is painted
    with ``contour_value`` in all three colour channels, which gives the
    sticker a light outline.

    Args:
        image_np: RGB image of shape (H, W, 3), as decoded in ``main`` with
            ``channels=3``. Presumably uint8; TODO confirm for other callers.
        pred_np: Per-pixel class labels that squeeze to shape (H, W).
        class_index: Label to keep. The default 15 is "person" in the
            standard 21-class PASCAL VOC ordering. This assumes the model
            uses that ordering.
        output_path: File that the RGBA PNG is written to.
        contour_width: Side length of the square structuring element used
            for the erosion. Larger values give a thicker outline.
        contour_value: Value written to R, G and B on the outline pixels.
    """
    object_mask = pred_np.squeeze() == class_index

    # Black out everything that is not the object (2-D mask broadcast over RGB).
    cropped_object = image_np * object_mask[..., np.newaxis]

    # Erosion shrinks the mask; the pixels it removed form the inner rim.
    eroded = skimage.morphology.binary_erosion(
        object_mask, skimage.morphology.square(contour_width))
    contour = np.logical_and(object_mask, np.logical_not(eroded))
    # A 2-D boolean index on an (H, W, 3) array selects whole pixels,
    # so all three channels of each contour pixel are set.
    cropped_object[contour] = contour_value

    # Assemble RGBA: colour from the cut-out, alpha opaque only on the object.
    rgba = np.zeros(cropped_object.shape[:2] + (4,), dtype=np.uint8)
    rgba[..., :3] = cropped_object
    rgba[..., 3] = np.where(object_mask, 255, 0)
    io.imsave(output_path, rgba)
def main(image_filename='me.jpg', checkpoint_path=None):
    """Segment an image with a pretrained FCN-8s and write PNG outputs.

    Writes ``image_np.png`` (the decoded input) and ``pred_np.png`` (the
    predicted label map) to the current directory, then calls ``save`` to
    produce the sticker image.

    Args:
        image_filename: Image to segment; decoded with
            ``tf.image.decode_jpeg`` into 3 channels.
        checkpoint_path: FCN-8s checkpoint to restore. ``None`` means the
            module-level ``fcn_16s_checkpoint_path``.
    """
    # Presumably resolved through the sys.path entries added at module level.
    from tf_image_segmentation.models.fcn_8s import FCN_8s
    from tf_image_segmentation.utils.inference import adapt_network_for_any_size_input

    if checkpoint_path is None:
        checkpoint_path = fcn_16s_checkpoint_path

    number_of_classes = 21

    # The filename is fed at run time through a placeholder.
    image_filename_placeholder = tf.placeholder(tf.string)
    feed_dict_to_use = {image_filename_placeholder: image_filename}

    image_tensor = tf.read_file(image_filename_placeholder)
    image_tensor = tf.image.decode_jpeg(image_tensor, channels=3)
    image_batch_tensor = tf.expand_dims(image_tensor, axis=0)

    # Wrap the network so inputs of arbitrary size can be fed (multiple = 32).
    # A new name avoids shadowing the imported FCN_8s.
    fcn_8s_any_size = adapt_network_for_any_size_input(FCN_8s, 32)
    pred, _ = fcn_8s_any_size(image_batch_tensor=image_batch_tensor,
                              number_of_classes=number_of_classes,
                              is_training=False)

    initializer = tf.local_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(initializer)
        saver.restore(sess, checkpoint_path)
        image_np, pred_np = sess.run([image_tensor, pred],
                                     feed_dict=feed_dict_to_use)

    io.imsave('image_np.png', image_np)
    # pred_np presumably holds int64 class ids in [0, number_of_classes).
    # PNG writers reject or rescale int64, so cast explicitly to keep the
    # label values intact.
    io.imsave('pred_np.png', pred_np.squeeze().astype(np.uint8))
    save(image_np, pred_np)
if __name__ == '__main__':
    # Run the segmentation demo only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment