Experimental bugfixes on TensorFlow's variational autoencoder example, with an image batch loader.
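Note: the script below targets the TensorFlow 0.x API (tf.pack, tf.mul, tf.sub, tf.initialize_all_variables); TensorFlow 1.0 later renamed these to tf.stack, tf.multiply, tf.subtract, and tf.global_variables_initializer.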
import os
import math
from random import randint, choice
from glob import glob

import tensorflow as tf
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Tolerate truncated/partially-downloaded files instead of raising.
def xavier_init(fan_in, fan_out, constant=1):
	"""Glorot/Xavier uniform initialization: U(-limit, limit) with limit = constant * sqrt(6 / (fan_in + fan_out))."""
	low = -constant * np.sqrt(6.0 / (fan_in + fan_out))
	high = constant * np.sqrt(6.0 / (fan_in + fan_out))
	return tf.random_uniform((fan_in, fan_out), minval=low, maxval=high, dtype=tf.float32)

class ImageDirectoryIterator(object):
	def __init__(self, folder_name, batch_size, target_size, cache=False, patches=False):
		self.cache = dict() if cache else None
		self.target_size = target_size
		self.patches = patches
		self.batch_size = batch_size
		self.folder_name = folder_name
		self.folder_glob = glob(os.path.join(folder_name, "*"))
	def get_batch(self):
		batch = np.zeros((self.batch_size, self.target_size[0]*self.target_size[1]), dtype=np.float)
		for i in range(self.batch_size):
			image_filename = choice(self.folder_glob)  # Chosen outside the try so the except can report it.
			try:
				batch[i, :] = self.disk_image_to_matrix(image_filename)
			except (IOError, ValueError) as e:
				print("Problem loading image filename {}: {}".format(image_filename, e))  # Row stays zeros.
		return batch
	def disk_image_to_matrix(self, image_filename):
		# If caching is enabled, check and load. (Compare against None: an empty
		# cache dict is falsey, which would otherwise keep the cache from ever filling.)
		if self.cache is None or image_filename not in self.cache:
			img = Image.open(image_filename)
			if self.cache is not None:
				self.cache[image_filename] = img.copy()
		else:
			img = self.cache[image_filename].copy()
		if self.patches:
			# Randomly select a piece. Randomize the location AND the scale, but keep the
			# crop at least target_size so the thumbnail below only ever shrinks it.
			# (thumbnail preserves aspect ratio, so strongly non-square patches will still
			# fail the size check below and yield a zero row.)
			left = randint(0, img.size[0] - self.target_size[0])
			right = randint(left + self.target_size[0], img.size[0])
			top = randint(0, img.size[1] - self.target_size[1])
			bottom = randint(top + self.target_size[1], img.size[1])
			img = img.crop((left, top, right, bottom))
		img.thumbnail(self.target_size, Image.ANTIALIAS)  # Defaults to highest quality. Works in place.
		img = img.convert('L')  # Enforce luminance. Does not work in place.
		arr = np.array(img).reshape((1, -1))  # Flatten.
		final_example = np.zeros((1, self.target_size[0]*self.target_size[1]), dtype=np.float)
		if img.size[0] >= self.target_size[0] and img.size[1] >= self.target_size[1]:  # If our image is too small, ignore it.
			final_example[:] = arr / 255.0  # arr is still dtype=uint8.
		return final_example
	def matrix_to_disk_image(self, mat, filename):
		# Rescale to [0, 255] before writing; the epsilon guards against a constant matrix.
		low, high = mat.min(), mat.max()
		mat = np.asarray(255.0 * ((mat - low) / (high - low + 1e-8)), dtype=np.uint8)
		img = Image.fromarray(mat)
		img.save(filename)
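
# A minimal usage sketch for the loader above (the path is illustrative and assumes
# a directory of images at least 64x64 on a side):
#
#   loader = ImageDirectoryIterator("/tmp/data/", batch_size=10, target_size=(64, 64), cache=True)
#   batch = loader.get_batch()  # Shape (10, 4096), values in [0, 1]; failed loads leave zero rows.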

class VariationalAutoencoder(object):
	def __init__(self, n_input, n_hidden, optimizer=tf.train.AdamOptimizer()):
		self.n_input = n_input
		self.n_hidden = n_hidden
		network_weights = self._initialize_weights()
		self.weights = network_weights

		# Model: one linear layer to the latent mean and one to the latent log-variance.
		self.x = tf.placeholder(tf.float32, [None, self.n_input])
		self.z_mean = tf.add(tf.matmul(self.x, self.weights['w1']), self.weights['b1'])
		self.z_log_sigma_sq = tf.add(tf.matmul(self.x, self.weights['log_sigma_w1']), self.weights['log_sigma_b1'])

		# Reparameterization trick: draw z = mean + sigma * eps with eps ~ N(0, 1),
		# which keeps the sampling step differentiable w.r.t. the encoder weights.
		eps = tf.random_normal(tf.pack([tf.shape(self.x)[0], self.n_hidden]), 0, 1, dtype=tf.float32)
		self.z = tf.add(self.z_mean, tf.mul(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps))
		self.reconstruction = tf.add(tf.matmul(self.z, self.weights['w2']), self.weights['b2'])

		# Cost: squared-error reconstruction loss, plus the KL divergence of the latent
		# posterior N(mean, sigma^2) from the N(0, 1) prior, which has the closed form
		# -0.5 * sum(1 + log(sigma^2) - mean^2 - sigma^2).
		reconstr_loss = 0.5 * tf.reduce_sum(tf.pow(tf.sub(self.reconstruction, self.x), 2.0))
		latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq - tf.square(self.z_mean) - tf.exp(self.z_log_sigma_sq), 1)
		self.cost = tf.reduce_mean(reconstr_loss + latent_loss)
		self.optimizer = optimizer.minimize(self.cost)

		init = tf.initialize_all_variables()
		self.sess = tf.Session()
		self.sess.run(init)
		self.saver = tf.train.Saver()
	def _initialize_weights(self):
		all_weights = dict()
		all_weights['w1'] = tf.Variable(xavier_init(self.n_input, self.n_hidden))
		all_weights['log_sigma_w1'] = tf.Variable(xavier_init(self.n_input, self.n_hidden))
		all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
		all_weights['log_sigma_b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
		all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32))
		all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype=tf.float32))
		return all_weights
	def partial_fit(self, X):
		cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict={self.x: X})
		return cost

	def calc_total_cost(self, X):
		return self.sess.run(self.cost, feed_dict={self.x: X})

	def transform(self, X):
		return self.sess.run(self.z_mean, feed_dict={self.x: X})

	def generate(self, hidden=None):
		if hidden is None:
			hidden = np.atleast_2d(np.random.normal(size=self.weights["b1"].get_shape().as_list()))
		return self.sess.run(self.reconstruction, feed_dict={self.z: hidden})

	def reconstruct(self, X):
		return self.sess.run(self.reconstruction, feed_dict={self.x: X})

	def getWeights(self):
		return self.sess.run(self.weights['w1'])

	def getBiases(self):
		return self.sess.run(self.weights['b1'])

	def save(self, filename):
		self.saver.save(self.sess, filename)

	def load(self, filename):
		self.saver.restore(self.sess, filename)
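
# A minimal round-trip sketch for the class above (variable names are illustrative):
#
#   vae = VariationalAutoencoder(n_input=64*64, n_hidden=256)
#   cost = vae.partial_fit(batch)   # One gradient step; returns the batch cost.
#   codes = vae.transform(batch)    # Encode: inputs -> latent means.
#   recon = vae.reconstruct(batch)  # Encode + decode in one pass.
#   sample = vae.generate()         # Decode a random draw from the latent prior.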

def main():
	data = ImageDirectoryIterator("/tmp/data/", 10, (64, 64), cache=True, patches=False)
	vae = VariationalAutoencoder(64*64, 256)
	for i in range(1, 10000):
		loss = vae.partial_fit(data.get_batch())
		if i % (int(math.floor(math.log(i, 10) + 1)) * 10) == 0:  # Decrease steadily how often we dump output.
			print("Iter {} - Loss {}".format(i, loss))
		if math.isnan(loss):
			# Training diverged; roll back to the last checkpoint.
			print("Loading...")
			vae.load("/tmp/test_imgur_scrape")
			print("Loaded.")
			# TODO: Decay learning rate.
		elif i % 100 == 0:
			print("Saving...")
			vae.save("/tmp/test_imgur_scrape")
			print("Saved.")
""" | |
from matplotlib.pyplot import imshow | |
#arr = data.get_batch()[0,:] | |
arr = vae.generate() | |
arr -= arr.min()+1e-8 | |
arr /= arr.max() | |
im = Image.fromarray(arr.reshape((64,64))*255.0) | |
%matplotlib inline | |
imshow(im) | |
""" |