Skip to content

Instantly share code, notes, and snippets.

@YusukeSuzuki
Created May 30, 2016 11:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save YusukeSuzuki/4919c3577347918fa5735dad91d4d168 to your computer and use it in GitHub Desktop.
Save YusukeSuzuki/4919c3577347918fa5735dad91d4d168 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# This is a full TensorFlow version of Ryan Dahl's Automatic Colorization.
# The original is available at http://tinyclouds.org/colorize/
import sys
import tensorflow as tf
# Side length, in pixels, of the square grayscale image handed to the model.
INPUT_EDGE: int = 224
# Photo to colorize, and where the colorized JPEG is written.
INPUT_FILE: str = 'shark.jpg'
OUTPUT_FILE: str = 'shark-color.jpg'
# load image with TensorFlow
# jpeg file only
# png needed, use tf.image.decode.png()
def load_image_jpg_tf(path):
filename_queue = tf.train.string_input_producer([path])
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
read_img = tf.image.decode_jpeg(value, channels=1) # output gray 0 - 255 image
float_img = tf.to_float(read_img)
float_img = float_img / 255 # model need 0. - 1. image
float_shape = tf.shape(float_img)
init_op = tf.initialize_all_variables()
sess = tf.Session()
with sess as default:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
img = sess.run(float_img)
img_shape = sess.run(float_shape)
coord.request_stop()
coord.join(threads)
short_edge = min(img_shape[0], img_shape[1])
img = tf.image.resize_image_with_crop_or_pad(img, short_edge, short_edge)
img = tf.image.resize_images(img, INPUT_EDGE, INPUT_EDGE)
img = tf.reshape(img, (1, INPUT_EDGE, INPUT_EDGE, 1))
return img
# Build the grayscale input, splice it into the pre-trained colorization
# graph as its "grayscale" input, run inference once and save the result.
shark_gray = load_image_jpg_tf(INPUT_FILE)

with open("colorize.tfmodel", mode='rb') as f:
    model_bytes = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_bytes)
tf.import_graph_def(graph_def, input_map={"grayscale": shark_gray}, name='')

with tf.Session() as sess:
    inferred_rgb = sess.graph.get_tensor_by_name("inferred_rgb:0")
    # Clip before scaling: a float outside 0.0-1.0 multiplied by 255 does not
    # saturate when cast to uint8, so out-of-range pixels would be corrupted.
    scaled_rgb = tf.clip_by_value(inferred_rgb, 0.0, 1.0) * 255
    out_uint8_img = tf.cast(scaled_rgb, tf.uint8)
    # encode_jpeg needs a rank-3 (height, width, channels) image; the model
    # output presumably carries a leading batch dimension of 1 — TODO confirm.
    jpeg_data_tensor = tf.reshape(out_uint8_img, (INPUT_EDGE, INPUT_EDGE, 3))
    jpeg_data = tf.image.encode_jpeg(jpeg_data_tensor)
    # A single run: the network is evaluated once, as part of JPEG encoding.
    jpeg = sess.run(jpeg_data)

with open(OUTPUT_FILE, mode='wb') as f:
    f.write(jpeg)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment