Testing a VAE on Fashion-MNIST with different binarization schemes
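"""Train a VAE on Fashion-MNIST with static or dynamic binarization.

A minimal usage sketch, assuming TensorFlow 1.x and that the file is saved
as vae_fashion_mnist.py (the script name is illustrative):

    python vae_fashion_mnist.py                 # dynamic binarization (default)
    python vae_fashion_mnist.py --static_bin    # static binarization
    tensorboard --logdir=./train_dir            # view reconstructions/samples
"""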
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import sys
import urllib.request as req

import numpy as np
import tensorflow as tf
import tqdm
from absl import flags

# np.nan is not a valid threshold in recent NumPy; sys.maxsize gives the
# intended behavior (print full arrays without truncation).
np.set_printoptions(threshold=sys.maxsize)

k = tf.keras
def load_mnist(path, kind='train'):
  """Load Fashion-MNIST images and labels from gzipped IDX files."""
  labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
  images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)
  with gzip.open(labels_path, 'rb') as lbpath:
    # Skip the 8-byte IDX header (magic number + item count).
    labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)
  with gzip.open(images_path, 'rb') as imgpath:
    # Skip the 16-byte IDX header (magic number + dims); flatten to 784 pixels.
    images = np.frombuffer(
        imgpath.read(), dtype=np.uint8, offset=16).reshape(len(labels), 784)
  return images, labels
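
# Example usage (shapes assume the standard Fashion-MNIST splits of
# 60k train / 10k test images):
#   images, labels = load_mnist('./data/fashion', kind='t10k')
#   # images: (10000, 784) uint8, labels: (10000,) uint8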
class VAE(k.Model):

  def __init__(self, latent_dim=50):
    super(VAE, self).__init__()
    self.latent_dim = latent_dim
    # Encoder: 784 -> 200 -> 200 -> 2 * latent_dim (mean and log-variance).
    self.enc1 = k.layers.Dense(200)
    self.enc2 = k.layers.Dense(200)
    self.enc3 = k.layers.Dense(latent_dim * 2)
    # Decoder: latent_dim -> 200 -> 200 -> 784 (Bernoulli logits per pixel).
    self.dec1 = k.layers.Dense(200)
    self.dec2 = k.layers.Dense(200)
    self.dec3 = k.layers.Dense(784)

  def encode(self, x):
    net = self.enc1(x)
    net = tf.nn.relu(net)
    net = self.enc2(net)
    net = tf.nn.relu(net)
    net = self.enc3(net)
    mean, logvar = net[:, :self.latent_dim], net[:, self.latent_dim:]
    return mean, logvar

  def decode(self, z):
    net = self.dec1(z)
    net = tf.nn.relu(net)
    net = self.dec2(net)
    net = tf.nn.relu(net)
    logit = self.dec3(net)
    return logit

  def reparam(self, mean, logvar):
    # Reparameterization trick: z = mean + std * eps with eps ~ N(0, I),
    # so gradients flow through mean and logvar.
    std = tf.exp(.5 * logvar)
    eps = tf.random_normal(shape=mean.shape)
    return mean + std * eps

  def forward(self, x):
    mean, logvar = self.encode(x)
    z = self.reparam(mean, logvar)
    x_logit = self.decode(z)
    # KL(q(z|x) || N(0, I)), summed over latent dimensions.
    kl = tf.reduce_sum(normal_kl(mean, logvar, 0., 0.), axis=1)
    # Bernoulli log-likelihood log p(x|z), summed over pixels.
    logpx = -tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(
            logits=x_logit, labels=x), axis=1)
    elbo = logpx - kl
    return x_logit, elbo, logpx, kl
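
# For reference, forward() computes a single-sample estimate of the ELBO,
#   ELBO(x) = log p(x|z) - KL(q(z|x) || p(z)),  z ~ q(z|x),
# which lower-bounds log p(x); training minimizes its negative.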
def normal_kl(mu1, lv1, mu2, lv2):
  """Elementwise KL(N(mu1, exp(lv1)) || N(mu2, exp(lv2)))."""
  v1, v2 = tf.exp(lv1), tf.exp(lv2)
  lstd1, lstd2 = lv1 / 2., lv2 / 2.
  return lstd2 - lstd1 + ((v1 + tf.square(mu1 - mu2)) / (2. * v2)) - .5
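
# Sanity check for the standard-normal prior used in forward(): with
# mu2 = lv2 = 0 the expression reduces to the familiar closed form
#   KL(N(mu, v) || N(0, 1)) = 0.5 * (mu^2 + v - log(v) - 1),
# which is zero exactly when mu = 0 and v = 1.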
def main(_):
  latent_dim = 50
  batch_size = 100
  learning_rate = 1e-3
  epochs = 100

  model = VAE(latent_dim=latent_dim)
  x = tf.placeholder(shape=(batch_size, 784), dtype=tf.float32)
  if not FLAGS.static_bin:
    # Dynamic binarization: resample a binary image from the pixel
    # intensities on every step. sign(x - u) is +1 where x > u and -1
    # otherwise, so relu(sign(...)) yields a fresh {0, 1} sample.
    input_ = tf.nn.relu(tf.sign(x - tf.random_uniform(shape=x.shape)))
  else:
    input_ = x
  x_logit, elbo, logpx, kl = model.forward(input_)
  elbo = tf.reduce_mean(elbo, axis=0)
  loss = -elbo
  optim = tf.train.AdamOptimizer(learning_rate=learning_rate)
  train_op = optim.minimize(loss)

  # TensorBoard summaries: originals, reconstructions, and prior samples.
  eps = tf.random_normal(shape=(10, latent_dim))
  sample = tf.sigmoid(model.decode(eps))
  x_recon = tf.sigmoid(x_logit)
  x_ = tf.reshape(x, (-1, 28, 28, 1))
  x_recon_ = tf.reshape(x_recon, (-1, 28, 28, 1))
  sample_ = tf.reshape(sample, (-1, 28, 28, 1))
  tf.summary.image("original images", x_, max_outputs=10)
  tf.summary.image("reconstructions", x_recon_, max_outputs=10)
  tf.summary.image("random samples", sample_, max_outputs=10)
  summary = tf.summary.merge_all()

  # Download Fashion-MNIST on first run.
  if not os.path.exists("./data/fashion"):
    os.makedirs("./data/fashion")
    req.urlretrieve(
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
        "./data/fashion/train-images-idx3-ubyte.gz")
    req.urlretrieve(
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
        "./data/fashion/train-labels-idx1-ubyte.gz")
    req.urlretrieve(
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
        "./data/fashion/t10k-images-idx3-ubyte.gz")
    req.urlretrieve(
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
        "./data/fashion/t10k-labels-idx1-ubyte.gz")

  x_train, _ = load_mnist('./data/fashion', kind='train')
  x_test, _ = load_mnist('./data/fashion', kind='t10k')
  x_train = x_train.copy().astype(np.float32)
  x_test = x_test.copy().astype(np.float32)
  x_train /= 255.
  x_test /= 255.
  if FLAGS.static_bin:
    # Static binarization: threshold once at 0.5 and keep the result fixed.
    x_train[x_train >= .5] = 1.
    x_train[x_train < .5] = 0.
    x_test[x_test >= .5] = 1.
    x_test[x_test < .5] = 0.
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    summary_writer = tf.summary.FileWriter("./train_dir", sess.graph)
    train_its = x_train.shape[0] // batch_size
    test_its = x_test.shape[0] // batch_size
    for e in range(epochs):
      for it in tqdm.tqdm(range(train_its)):
        data = x_train[batch_size * it: batch_size * (it + 1)]
        sess.run(train_op, feed_dict={x: data})

      # Evaluate the ELBO over the full train and test sets after each epoch.
      train_elbo = []
      for it in range(train_its):
        data = x_train[batch_size * it: batch_size * (it + 1)]
        train_elbo.append(sess.run(elbo, feed_dict={x: data}))
      train_elbo = np.mean(train_elbo)

      test_elbo = []
      for it in range(test_its):
        data = x_test[batch_size * it: batch_size * (it + 1)]
        test_elbo.append(sess.run(elbo, feed_dict={x: data}))
      test_elbo = np.mean(test_elbo)

      print("Epoch {}, train elbo {:.4f}, test elbo {:.4f}".format(
          e, train_elbo, test_elbo))
      summary_writer.add_summary(sess.run(summary, feed_dict={x: data}), e)
if __name__ == "__main__":
  flags.DEFINE_boolean(
      "static_bin",
      default=False,
      help="Use static binarization if True; otherwise use dynamic binarization.")
  FLAGS = flags.FLAGS
  tf.app.run(main=main)