Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Experimental bug fixes for TensorFlow's variational autoencoder example, with an image batch loader.
import sys, os
import math
from random import randint, choice
from glob import glob
import tensorflow as tf
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def xavier_init(fan_in, fan_out, constant=1):
    """Return a (fan_in, fan_out) float32 tensor drawn uniformly from [-limit, limit].

    ``limit = constant * sqrt(6 / (fan_in + fan_out))`` is the Glorot/Xavier
    uniform bound; ``constant`` scales it.
    """
    limit = constant * np.sqrt(6.0 / (fan_in + fan_out))
    return tf.random_uniform(
        (fan_in, fan_out),
        minval=-limit,
        maxval=limit,
        dtype=tf.float32,
    )
class ImageDirectoryIterator(object):
    """Sample random grayscale training batches from a directory of images.

    Every entry matching ``folder_name/*`` is a candidate. Each batch row comes
    from an independently chosen file (with replacement). The image is
    converted to luminance ('L'), optionally cropped to a random patch, and
    resized to ``target_size``. It is then flattened to ``width * height``
    floats in [0, 1].
    """

    def __init__(self, folder_name, batch_size, target_size, cache=False, patches=False):
        """
        :param folder_name: Directory whose entries (``glob(folder_name/*)``) are sampled.
        :param batch_size: Number of rows returned by :meth:`get_batch`.
        :param target_size: ``(width, height)`` in PIL order; each row holds
            ``width * height`` values.
        :param cache: If true, keep decoded grayscale images in memory, keyed by filename.
        :param patches: If true, use a random crop of each image (random position
            and size, never smaller than ``target_size``) instead of the whole image.
        """
        # A dict when caching is enabled, None otherwise. Always test with
        # ``is not None``: an empty dict is falsy.
        self.cache = dict() if cache else None
        self.target_size = target_size
        self.patches = patches
        self.batch_size = batch_size
        self.folder_name = folder_name
        self.folder_glob = glob(os.path.join(folder_name, "*"))

    def get_batch(self):
        """Return a ``(batch_size, width*height)`` float64 array of random examples.

        Files that fail to load (IOError/ValueError) are reported on stdout, and
        their row is left as all zeros.
        """
        batch = np.zeros((self.batch_size, self.target_size[0] * self.target_size[1]), dtype=np.float64)
        for i in range(self.batch_size):
            # Chosen outside the try so the error message can always name the file.
            # An empty directory still raises IndexError here.
            image_filename = choice(self.folder_glob)
            try:
                batch[i, :] = self.disk_image_to_matrix(image_filename)
            except (IOError, ValueError) as e:
                print("Problem loading image filename {}: {}".format(image_filename, e))
        return batch

    def disk_image_to_matrix(self, image_filename):
        """Load one image file as a ``(1, width*height)`` float64 row in [0, 1].

        Images smaller than ``target_size`` in either dimension are ignored and
        produce an all-zero row.
        """
        tw, th = self.target_size
        final_example = np.zeros((1, tw * th), dtype=np.float64)
        img = self._load_luminance(image_filename)
        if img.size[0] < tw or img.size[1] < th:
            return final_example  # Too small to cover the target; ignore it.
        if self.patches:
            img = img.crop(self._random_crop_box(img.size[0], img.size[1]))
        # resize() yields exactly target_size. thumbnail() kept the aspect ratio,
        # so non-square images came out smaller than the target and were
        # silently dropped. Note: the aspect ratio is not preserved here.
        img = img.resize((tw, th), Image.LANCZOS)
        final_example[:] = np.asarray(img, dtype=np.float64).reshape((1, -1)) / 255.0
        return final_example

    def _load_luminance(self, image_filename):
        """Return the file as a grayscale ('L') PIL image, using the cache if enabled."""
        if self.cache is not None and image_filename in self.cache:
            # Safe to share: callers only use non-mutating crop()/resize().
            return self.cache[image_filename]
        # The context manager closes the file handle. convert() loads the pixel
        # data into a new image, which stays valid after the file is closed.
        with Image.open(image_filename) as src:
            img = src.convert('L')
        if self.cache is not None:
            self.cache[image_filename] = img
        return img

    def _random_crop_box(self, width, height):
        """Pick a random (left, top, right, bottom) box inside width x height, at least target_size."""
        tw, th = self.target_size
        left = randint(0, width - tw)
        right = randint(left + tw, width)
        top = randint(0, height - th)
        bottom = randint(top + th, height)
        return left, top, right, bottom

    def matrix_to_disk_image(self, mat, filename):
        """Min-max normalize *mat* to 0..255 and save it as an 8-bit image at *filename*.

        A constant (or NaN-containing) matrix is saved as all zeros instead of
        going through a 0/0 division.
        """
        mat = np.asarray(mat, dtype=np.float64)
        low, high = mat.min(), mat.max()
        span = high - low
        if span > 0:
            scaled = 255.0 * ((mat - low) / span)
        else:
            # 0/0 would give NaN, whose uint8 cast is undefined.
            scaled = np.zeros_like(mat)
        img = Image.fromarray(np.asarray(scaled, dtype=np.uint8))
        img.save(filename)
class VariationalAutoencoder(object):
    """Single-layer Gaussian variational autoencoder (TensorFlow 1.x graph mode).

    The encoder has two linear maps, from the input to the latent mean and to
    the latent log-variance. The decoder is a linear map from the sampled
    latent code back to input space. Each instance owns a ``tf.Session``,
    created and initialized in the constructor.
    """

    def __init__(self, n_input, n_hidden, optimizer=None):
        """
        :param n_input: Dimensionality of each (flattened) input row.
        :param n_hidden: Dimensionality of the latent code.
        :param optimizer: A ``tf.train`` optimizer. Defaults to a fresh
            ``tf.train.AdamOptimizer()`` per instance.
        """
        # Created here rather than as a default argument. A default would be
        # evaluated once at import time and shared by every model.
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer()
        self.n_input = n_input
        self.n_hidden = n_hidden
        self.weights = self._initialize_weights()

        # Model
        self.x = tf.placeholder(tf.float32, [None, self.n_input])
        self.z_mean = tf.add(tf.matmul(self.x, self.weights['w1']), self.weights['b1'])
        self.z_log_sigma_sq = tf.add(tf.matmul(self.x, self.weights['log_sigma_w1']), self.weights['log_sigma_b1'])
        # Reparameterization: z = mean + sigma * eps, with eps ~ N(0, 1) shaped like z_mean.
        eps = tf.random_normal(tf.shape(self.z_mean), 0, 1, dtype=tf.float32)
        self.z = tf.add(self.z_mean, tf.multiply(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps))
        self.reconstruction = tf.add(tf.matmul(self.z, self.weights['w2']), self.weights['b2'])

        # Cost: both terms are per example (summed over features), then averaged
        # over the batch. Previously the reconstruction term was summed over the
        # whole batch while the KL term was averaged, which shrank the relative
        # KL weight by a factor of batch_size.
        reconstr_loss = 0.5 * tf.reduce_sum(tf.square(tf.subtract(self.reconstruction, self.x)), 1)
        latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq - tf.square(self.z_mean) - tf.exp(self.z_log_sigma_sq), 1)
        self.cost = tf.reduce_mean(reconstr_loss + latent_loss)
        self.optimizer = optimizer.minimize(self.cost)

        init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init)
        self.saver = tf.train.Saver()

    def _initialize_weights(self):
        """Create the model variables: Xavier-initialized encoder weights, zero biases and decoder."""
        all_weights = dict()
        all_weights['w1'] = tf.Variable(xavier_init(self.n_input, self.n_hidden))
        all_weights['log_sigma_w1'] = tf.Variable(xavier_init(self.n_input, self.n_hidden))
        all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
        all_weights['log_sigma_b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
        all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32))
        all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype=tf.float32))
        return all_weights

    def partial_fit(self, X):
        """Run one optimizer step on batch *X* and return the batch cost (before the update)."""
        cost, _ = self.sess.run((self.cost, self.optimizer), feed_dict={self.x: X})
        return cost

    def calc_total_cost(self, X):
        """Return the cost on *X* without updating the weights."""
        return self.sess.run(self.cost, feed_dict={self.x: X})

    def transform(self, X):
        """Return the latent means for the rows of *X*."""
        return self.sess.run(self.z_mean, feed_dict={self.x: X})

    def generate(self, hidden=None):
        """Decode latent codes *hidden* (rows of length n_hidden).

        If *hidden* is omitted, one code is drawn from a standard normal.
        """
        if hidden is None:
            hidden = np.random.normal(size=(1, self.n_hidden))
        return self.sess.run(self.reconstruction, feed_dict={self.z: hidden})

    def reconstruct(self, X):
        """Encode, sample and decode *X*; stochastic because z is sampled."""
        return self.sess.run(self.reconstruction, feed_dict={self.x: X})

    def getWeights(self):
        """Return the encoder mean weight matrix ``w1``."""
        return self.sess.run(self.weights['w1'])

    def getBiases(self):
        """Return the encoder mean bias vector ``b1``."""
        return self.sess.run(self.weights['b1'])

    def save(self, filename):
        """Write a checkpoint of all graph variables to *filename*."""
        self.saver.save(self.sess, filename)

    def load(self, filename):
        """Restore variables from the checkpoint at *filename*."""
        self.saver.restore(self.sess, filename)

    def close(self):
        """Release the session's resources; the model cannot be used afterwards."""
        self.sess.close()
def main():
    """Train a VAE on 64x64 grayscale samples from /tmp/data/, with checkpointing.

    Progress is printed every ``10 * d`` iterations, where ``d`` is the number
    of decimal digits in the iteration count, so output gets steadily sparser.
    Every 100 iterations with a finite loss, the model is saved. If the loss
    becomes NaN, the last saved checkpoint is restored.

    :raises RuntimeError: if the loss becomes NaN before any checkpoint was saved.
    """
    checkpoint = "/tmp/test_imgur_scrape"
    data = ImageDirectoryIterator("/tmp/data/", 10, (64, 64), True, False)
    vae = VariationalAutoencoder(64 * 64, 256)
    have_checkpoint = False
    for i in range(1, 10000):
        loss = vae.partial_fit(data.get_batch())
        # len(str(i)) counts digits exactly. floor(log(i, 10) + 1) undercounts
        # at powers of ten, because math.log(1000, 10) == 2.999...
        if i % (len(str(i)) * 10) == 0:  # Decrease steadily how often we dump output.
            print("Iter {} - Loss {} ".format(i, loss))
        if math.isnan(loss):
            if not have_checkpoint:
                raise RuntimeError("Loss became NaN before any checkpoint was saved.")
            print("Loading...")
            vae.load(checkpoint)
            print("Loaded.")
            # TODO: Decay learning rate.
        elif i % 100 == 0:
            print("Saving...")
            vae.save(checkpoint)
            have_checkpoint = True
            print("Saved.")
"""
from matplotlib.pyplot import imshow
#arr = data.get_batch()[0,:]
arr = vae.generate()
arr -= arr.min()+1e-8
arr /= arr.max()
im = Image.fromarray(arr.reshape((64,64))*255.0)
%matplotlib inline
imshow(im)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.