Last active
November 15, 2018 20:44
-
-
Save lxuechen/d5208a2a79e99a00e5adaf2504b0c3ef to your computer and use it in GitHub Desktop.
Testing a VAE on Fashion-MNIST with different binarization schemes (static vs. dynamic)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tqdm | |
import os | |
import gzip | |
from absl import flags | |
import urllib.request as req | |
import numpy as np | |
np.set_printoptions(threshold=np.nan) | |
import tensorflow as tf | |
k = tf.keras | |
def load_mnist(path, kind='train'):
    """Load one split of an MNIST-format dataset stored as gzip'd idx files.

    Args:
      path: Directory containing ``<kind>-labels-idx1-ubyte.gz`` and
        ``<kind>-images-idx3-ubyte.gz``.
      kind: File-name prefix of the split, e.g. ``'train'`` or ``'t10k'``.

    Returns:
      ``(images, labels)``. ``images`` is a uint8 array of shape
      ``(num_examples, 784)``; ``labels`` is a uint8 array with one entry per
      example. Both are built with ``np.frombuffer`` and are therefore
      read-only. Copy them before modifying them in place.
    """
    def _read_payload(template, header_len):
        # The idx header (magic number plus dimension sizes) is skipped by
        # byte offset, not parsed or validated.
        filename = os.path.join(path, template % kind)
        with gzip.open(filename, 'rb') as stream:
            raw = stream.read()
        return np.frombuffer(raw, dtype=np.uint8, offset=header_len)

    labels = _read_payload('%s-labels-idx1-ubyte.gz', 8)
    flat_pixels = _read_payload('%s-images-idx3-ubyte.gz', 16)
    images = flat_pixels.reshape(len(labels), 784)
    return images, labels
class VAE(k.Model):
    """Variational autoencoder with two-hidden-layer MLP encoder and decoder.

    The encoder produces the mean and log-variance of a diagonal Gaussian
    posterior over ``latent_dim`` latents. The decoder maps a latent sample to
    784 logits, which ``forward`` scores with a Bernoulli (sigmoid
    cross-entropy) likelihood against the input.
    """

    def __init__(self, latent_dim=50):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.enc1 = k.layers.Dense(200)
        self.enc2 = k.layers.Dense(200)
        # One head emits [mean, logvar] concatenated along the last axis.
        self.enc3 = k.layers.Dense(latent_dim * 2)
        self.dec1 = k.layers.Dense(200)
        self.dec2 = k.layers.Dense(200)
        self.dec3 = k.layers.Dense(784)

    def encode(self, x):
        """Return ``(mean, logvar)`` of q(z | x), each with ``latent_dim`` columns."""
        net = tf.nn.relu(self.enc1(x))
        net = tf.nn.relu(self.enc2(net))
        net = self.enc3(net)
        mean, logvar = net[:, :self.latent_dim], net[:, self.latent_dim:]
        return mean, logvar

    def decode(self, z):
        """Return 784 pre-sigmoid logits per row of latent codes ``z``."""
        net = tf.nn.relu(self.dec1(z))
        net = tf.nn.relu(self.dec2(net))
        return self.dec3(net)

    def reparam(self, mean, logvar):
        """Sample ``z = mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
        std = tf.exp(.5 * logvar)
        # tf.shape is the runtime shape, so sampling also works when the batch
        # dimension is unknown while the graph is built. The static mean.shape
        # would be partially undefined in that case. Noise also matches the
        # dtype of the posterior statistics.
        eps = tf.random_normal(shape=tf.shape(mean), dtype=mean.dtype)
        return mean + std * eps

    def forward(self, x):
        """Run the VAE on ``x`` and return per-example ELBO terms.

        Args:
          x: Reconstruction targets, used as sigmoid cross-entropy labels, so
            expected in [0, 1]. Presumably shaped (batch, 784) to match the
            decoder output (TODO: confirm at callers).

        Returns:
          ``(x_logit, elbo, logpx, kl)``:
            - ``x_logit``: the reconstruction logits.
            - ``elbo``: the ELBO.
            - ``logpx``: the Bernoulli log-likelihood of ``x``.
            - ``kl``: the KL divergence to the N(0, I) prior.
          The last three are summed over axis 1, giving one value per example.
        """
        mean, logvar = self.encode(x)
        z = self.reparam(mean, logvar)
        x_logit = self.decode(z)
        # normal_kl with mu2 = 0 and lv2 = 0 is the KL to a standard normal.
        kl = tf.reduce_sum(normal_kl(mean, logvar, 0., 0.), axis=1)
        logpx = -tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x),
            axis=1)
        elbo = logpx - kl
        return x_logit, elbo, logpx, kl
def normal_kl(mu1, lv1, mu2, lv2):
    """Elementwise KL( N(mu1, exp(lv1)) || N(mu2, exp(lv2)) ).

    ``lv1`` and ``lv2`` are log-variances. Nothing is reduced, so the caller
    sums over whichever axes it needs.
    """
    # KL = log(sigma2) - log(sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2 sigma2^2) - 1/2,
    # where log(sigma) = logvar / 2 and sigma^2 = exp(logvar).
    var1 = tf.exp(lv1)
    var2 = tf.exp(lv2)
    log_std1 = lv1 / 2.
    log_std2 = lv2 / 2.
    quadratic = (var1 + tf.square(mu1 - mu2)) / (2. * var2)
    return log_std2 - log_std1 + quadratic - .5
_FASHION_MNIST_URL = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
_FASHION_MNIST_FILES = (
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
)


def _download_fashion_mnist(data_dir):
    """Download each Fashion-MNIST archive that is missing from ``data_dir``.

    Files are checked individually, not just the directory. Each download goes
    to a ``.part`` name and is renamed only once complete, so an interrupted
    run is retried next time. Otherwise it could leave an absent or truncated
    archive for the loader.
    """
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    for name in _FASHION_MNIST_FILES:
        dest = os.path.join(data_dir, name)
        if os.path.exists(dest):
            continue
        partial = dest + ".part"
        req.urlretrieve(_FASHION_MNIST_URL + name, partial)
        os.replace(partial, dest)


def main(_):
    """Train a VAE on Fashion-MNIST and report per-epoch train/test ELBO.

    With ``--static_bin``, pixels are thresholded once at 0.5. Otherwise each
    pixel is resampled as Bernoulli(intensity) every time a batch is fed
    (dynamic binarization). Image summaries are written to ``./train_dir``.
    """
    latent_dim = 50
    batch_size = 100
    learning_rate = 1e-3
    epochs = 100

    model = VAE(latent_dim=latent_dim)
    # The batch size is baked into the graph; only full batches are fed.
    x = tf.placeholder(shape=(batch_size, 784), dtype=tf.float32)
    if not FLAGS.static_bin:
        # Dynamic binarization: for u ~ U[0, 1), sign(x - u) is +1 with
        # probability x. relu maps -1 and 0 to 0, giving a Bernoulli(x) draw.
        input_ = tf.nn.relu(tf.sign(x - tf.random_uniform(shape=x.shape)))
    else:
        input_ = x
    x_logit, elbo, logpx, kl = model.forward(input_)
    elbo = tf.reduce_mean(elbo, axis=0)
    loss = -elbo
    optim = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optim.minimize(loss)

    # TensorBoard image summaries: inputs, reconstructions, prior samples.
    eps = tf.random_normal(shape=(10, latent_dim))
    sample = tf.sigmoid(model.decode(eps))
    x_recon = tf.sigmoid(x_logit)
    x_ = tf.reshape(x, (-1, 28, 28, 1))
    x_recon_ = tf.reshape(x_recon, (-1, 28, 28, 1))
    sample_ = tf.reshape(sample, (-1, 28, 28, 1))
    tf.summary.image("original images", x_, max_outputs=10)
    tf.summary.image("reconstructions", x_recon_, max_outputs=10)
    tf.summary.image("random samples", sample_, max_outputs=10)
    summary = tf.summary.merge_all()

    # Fashion-MNIST, converted from uint8 to float32 in [0, 1]. astype copies,
    # so the read-only buffers returned by load_mnist are not touched.
    _download_fashion_mnist("./data/fashion")
    x_train, _ = load_mnist('./data/fashion', kind='train')
    x_test, _ = load_mnist('./data/fashion', kind='t10k')
    x_train = x_train.astype(np.float32) / 255.
    x_test = x_test.astype(np.float32) / 255.
    if FLAGS.static_bin:
        # Static binarization: threshold once at 0.5.
        x_train = (x_train >= .5).astype(np.float32)
        x_test = (x_test >= .5).astype(np.float32)

    # Trailing examples that do not fill a batch are skipped.
    train_its = x_train.shape[0] // batch_size
    test_its = x_test.shape[0] // batch_size
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter("./train_dir", sess.graph)
        try:
            for e in range(epochs):
                # Draw a fresh random batch order each epoch, so SGD does not
                # see the identical batch sequence every time.
                order = np.random.permutation(x_train.shape[0])
                for it in tqdm.tqdm(range(train_its)):
                    batch_idx = order[batch_size * it: batch_size * (it + 1)]
                    sess.run(train_op, feed_dict={x: x_train[batch_idx]})
                train_elbo = []
                for it in range(train_its):
                    data = x_train[batch_size * it: batch_size * (it + 1)]
                    train_elbo.append(sess.run(elbo, feed_dict={x: data}))
                train_elbo = np.mean(train_elbo)
                test_elbo = []
                for it in range(test_its):
                    data = x_test[batch_size * it: batch_size * (it + 1)]
                    test_elbo.append(sess.run(elbo, feed_dict={x: data}))
                test_elbo = np.mean(test_elbo)
                print("Epoch {}, train elbo {:.4f}, test elbo {:.4f}".format(
                    e, train_elbo, test_elbo))
                # Images are logged for the last test batch evaluated above.
                summary_writer.add_summary(
                    sess.run(summary, feed_dict={x: data}), e)
        finally:
            # Flush pending events to disk even if training is interrupted.
            summary_writer.close()
if __name__ == "__main__":
    # The flag is registered only when run as a script. main() reads it
    # through the module-level FLAGS name bound here.
    FLAGS = flags.FLAGS
    flags.DEFINE_boolean(
        name="static_bin",
        default=False,
        help="Use static binarization if True; otherwise use dynamic binarization")
    tf.app.run(main=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment