Skip to content

Instantly share code, notes, and snippets.

@muminoff
Created April 22, 2016 02:46
Show Gist options
  • Save muminoff/a19871897498ef46504360afcaab2f68 to your computer and use it in GitHub Desktop.
Save muminoff/a19871897498ef46504360afcaab2f68 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import skimage.transform
from skimage.io import imsave, imread
import os
from os import listdir, path
from os.path import isfile, join
def get_directory(folder, extension='.png'):
    """Return the sorted paths of files directly inside *folder* ending in *extension*.

    Only the top level of *folder* is listed; subdirectories are not
    descended into. Matching is a case-sensitive suffix test on the file
    name, and each returned path is ``os.path.join(folder, name)``.

    Args:
        folder: Directory to list.
        extension: File-name suffix to keep. Defaults to ``'.png'``.

    Returns:
        A sorted list of matching paths. It is empty when nothing matches or
        when *folder* cannot be listed (``os.walk`` then yields nothing).
    """
    # The first tuple produced by os.walk describes *folder* itself; taking
    # only that tuple restricts the listing to the top level. Not naming a
    # loop variable ``path`` also avoids shadowing the module-level
    # ``from os import path`` import.
    top_level = next(os.walk(folder), None)
    if top_level is None:
        return []
    root, _subdirs, files = top_level
    return sorted(
        os.path.join(root, name) for name in files if name.endswith(extension)
    )
def load_image(path, size=224):
    """Load the image at *path* as a square, single-channel ``(size, size)`` array.

    The image is center-cropped to a square whose side is its shorter edge,
    resized to ``(size, size)`` with ``skimage.transform.resize``, and then
    desaturated by averaging its first three channels.

    Args:
        path: Path of the image file to read with ``imread``.
        size: Side length of the result. Defaults to 224, matching the
            ``[1, 224, 224, 1]`` input the script reshapes to.

    Returns:
        A 2-D ``(size, size)`` array. Images that are already grayscale
        (2-D), or have fewer than three channels, are returned without
        averaging: their first channel is used as-is.
    """
    img = imread(path)
    # Crop a centered square on the shorter of height and width.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]
    img = skimage.transform.resize(crop_img, (size, size))
    # Grayscale inputs have no channels to average; indexing img[:, :, 1]
    # on them would raise IndexError.
    if img.ndim == 2:
        return img
    if img.shape[2] < 3:
        return img[:, :, 0]
    # Average R, G and B; any further channel (e.g. alpha) is ignored.
    return (img[:, :, 0] + img[:, :, 1] + img[:, :, 2]) / 3.0
# Read the serialized colorization graph from disk.
with open("colorize.tfmodel", mode='rb') as f:
    fileContent = f.read()

# Deserialize it and splice a new [1, 224, 224, 1] float32 placeholder in as
# the graph's "grayscale" input; name='' keeps the imported op names
# unprefixed so "inferred_rgb:0" can be looked up directly later.
graph_def = tf.GraphDef.FromString(fileContent)
grayscale = tf.placeholder(tf.float32, shape=[1, 224, 224, 1])
tf.import_graph_def(graph_def, name='', input_map=dict(grayscale=grayscale))
# Colorize every PNG found directly in input/ and write the result under
# output/, mirroring the input path (input/x.png -> output/input/x.png).
images = get_directory("input")
# The loop only evaluates the "inferred_rgb" tensor, so a single session and
# a single tensor lookup serve every image instead of one per image. The
# `with` block closes the session on exit, so no explicit close() is needed.
with tf.Session() as sess:
    inferred_rgb = sess.graph.get_tensor_by_name("inferred_rgb:0")
    for image in images:
        print(image)
        scene = load_image(image).reshape(1, 224, 224, 1)
        inferred_batch = sess.run(inferred_rgb, feed_dict={grayscale: scene})
        filename = os.path.join("output", image)
        # Make sure the mirrored directory exists before writing; otherwise
        # saving to output/input/... fails when that folder is missing.
        out_dir = os.path.dirname(filename)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        imsave(filename, inferred_batch[0])
        print("saved " + filename)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment