# Created May 30, 2016 11:20
# Save YusukeSuzuki/4919c3577347918fa5735dad91d4d168 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
#!/usr/bin/env python3
# This is a full TensorFlow version of Ryan Dahl's Automatic Colorization.
# The original is here -> http://tinyclouds.org/colorize/
import sys
import tensorflow as tf
# Side length, in pixels, of the square image fed to the colorization model.
INPUT_EDGE = 224
# JPEG to colorize (it is decoded as single-channel grayscale on load).
INPUT_FILE = 'shark.jpg'
# Destination for the colorized JPEG.
OUTPUT_FILE = 'shark-color.jpg'
def load_image_jpg_tf(path):
    """Load a JPEG as a grayscale model-input tensor.

    The file is decoded as a single-channel image and scaled from 0-255 to
    0.0-1.0. It is then center-cropped (or padded) to a square whose side is
    the shorter image edge, resized to INPUT_EDGE x INPUT_EDGE, and reshaped
    to a batch of one.

    Only JPEG input is handled. For PNG input, use tf.image.decode_png in
    place of tf.image.decode_jpeg.

    Args:
        path: Path of the JPEG file to read.

    Returns:
        A float tensor of shape (1, INPUT_EDGE, INPUT_EDGE, 1), created in the
        default graph.
    """
    # tf.read_file yields the whole file in one op, so no filename queue,
    # queue-runner threads or Coordinator are needed for a single image.
    raw = tf.read_file(path)
    gray = tf.image.decode_jpeg(raw, channels=1)  # gray image, values 0 - 255
    gray = tf.to_float(gray) / 255  # the model needs values in 0.0 - 1.0

    # Evaluate once: the resulting array's .shape supplies the concrete size,
    # so the image is not read and decoded a second time just for its shape.
    with tf.Session() as sess:
        img = sess.run(gray)

    short_edge = min(img.shape[0], img.shape[1])
    img = tf.image.resize_image_with_crop_or_pad(img, short_edge, short_edge)
    # Pre-1.0 signature resize_images(images, height, width); TF >= 1.0 takes
    # resize_images(images, [height, width]) instead.
    img = tf.image.resize_images(img, INPUT_EDGE, INPUT_EDGE)
    return tf.reshape(img, (1, INPUT_EDGE, INPUT_EDGE, 1))
shark_gray = load_image_jpg_tf(INPUT_FILE)

# Load the serialized colorization graph.
with open("colorize.tfmodel", mode='rb') as f:
    file_content = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(file_content)

# Import it into the default graph, feeding the preprocessed image into the
# model's "grayscale" input.
tf.import_graph_def(graph_def, input_map={"grayscale": shark_gray}, name='')

with tf.Session() as sess:
    inferred_rgb = sess.graph.get_tensor_by_name("inferred_rgb:0")
    # The model output is treated as 0.0 - 1.0 (it is scaled by 255 below).
    # Clamp first: tf.cast to uint8 does not saturate, so out-of-range values
    # would otherwise corrupt pixels. Round instead of truncating.
    scaled_rgb = tf.round(tf.clip_by_value(inferred_rgb, 0.0, 1.0) * 255)
    out_uint8_img = tf.cast(scaled_rgb, tf.uint8)
    # NOTE(review): assumes inferred_rgb holds exactly one
    # INPUT_EDGE x INPUT_EDGE RGB image; the reshape fails otherwise.
    jpeg_data_tensor = tf.reshape(out_uint8_img, (INPUT_EDGE, INPUT_EDGE, 3))
    jpeg_data = tf.image.encode_jpeg(jpeg_data_tensor)
    # A single run performs inference and JPEG encoding together.
    jpeg = sess.run(jpeg_data)

with open(OUTPUT_FILE, mode='wb') as f:
    f.write(jpeg)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.