#!/usr/bin/python
"""Trains a variational autoencoder (VAE) on statically binarized MNIST
using TensorFlow Probability with eager execution."""

import os
import urllib.request

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_integer('seed', 0, 'random seed')
flags.DEFINE_integer('batch_size', 28, 'size of batch for training')
flags.DEFINE_integer('n_samples', 8, 'number of samples from q(z|x) during training')
flags.DEFINE_integer('z_size', 4, 'size of the continuous latent variable z')
flags.DEFINE_string('outdir', '/tmp/', 'directory for checkpoints and logged metrics')

ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
IMAGE_SHAPE = [28, 28, 1]
def download(directory, filename):
  """Downloads a file, skipping the download if it already exists."""
  filepath = os.path.expanduser(os.path.join(directory, filename))
  if tf.gfile.Exists(filepath):
    return filepath
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  url = os.path.join(ROOT_PATH, filename)
  print("Downloading %s to %s" % (url, filepath))
  urllib.request.urlretrieve(url, filepath)
  return filepath
def mnist_dataset(directory, split_name):
  """Returns the statically binarized MNIST split as a tf.data.Dataset."""
  FILE_TEMPLATE = "binarized_mnist_{split}.amat"
  amat_file = download(directory, FILE_TEMPLATE.format(split=split_name))
  dataset = tf.data.TextLineDataset(amat_file)
  # Each line of the .amat file is a space-separated string of 784 zeros and
  # ones; tf.py_func hands the line to Python as bytes.
  str_to_arr = lambda string: np.array([c == b"1" for c in string.split()])
  def _parser(s):
    booltensor = tf.py_func(str_to_arr, [s], tf.bool)
    # Images are kept flat (784,) to match the dense encoder/decoder; a
    # dummy label is returned so elements look like (image, label) pairs.
    return tf.to_float(booltensor), tf.constant(0, tf.int32)
  return dataset.map(_parser)
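
# A quick sanity check of the parsed dataset (a sketch, assuming eager
# execution is enabled and the download has succeeded):
#
#   ds = mnist_dataset("~/data/mnist", "train")
#   x, y = next(iter(ds))
#   x.shape    # (784,) -- one flattened 28x28 image, float32 in {0.0, 1.0}
#   y.numpy()  # 0 -- the dummy label added by _parser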
def make_nn(out_size, hidden_size=(128, 64)):
  layers = []
  for h in hidden_size:
    layers.append(tf.keras.layers.Dense(h, activation=tf.nn.relu))
  layers.append(tf.keras.layers.Dense(out_size))
  return tf.keras.Sequential(layers)
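# make_nn(8), for example, builds a 128 -> 64 -> 8 MLP; Keras infers the
# input width (784 for the encoder, z_size for the decoder) the first time
# the network is called. The final layer is deliberately linear: its raw
# outputs are consumed directly as distribution parameters (Gaussian
# loc/scale inputs in encode, Bernoulli logits in decode).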
def encode(x, net):
  def _compute_z_size(shape):
    z_size = int(shape)
    assert z_size % 2 == 0
    return z_size // 2

  mapped = net(x)
  z_size = _compute_z_size(mapped.shape[-1])
  return tfd.MultivariateNormalDiag(
      loc=mapped[..., :z_size],
      scale_diag=tf.nn.softplus(mapped[..., z_size:]))
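# encode() splits the network output in half: the first z_size values act as
# the mean of q(z|x) and the last z_size, pushed through softplus to keep
# them positive, as its per-dimension standard deviation. A minimal numeric
# sketch (hypothetical values, with z_size == 2 so the net emits 4 numbers):
#
#   mapped     = [0.3, -1.2, 0.0, 0.5]
#   loc        = [0.3, -1.2]                          # first half: the mean
#   scale_diag = softplus([0.0, 0.5]) ~ [0.69, 0.97]  # second half: the scale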
def decode(z, net):
  mapped = net(z)
  return tfd.Independent(tfd.Bernoulli(logits=mapped),
                         reinterpreted_batch_ndims=1)
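# tfd.Independent with reinterpreted_batch_ndims=1 folds the trailing pixel
# dimension of the Bernoulli batch into the event, so xhat.log_prob(x) sums
# the 784 per-pixel log-likelihoods into one scalar per image instead of
# returning a (..., 784) tensor.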
def make_prior(z_size, dtype=tf.float32):
  return tfd.MultivariateNormalDiag(
      loc=tf.zeros(z_size, dtype),
      scale_diag=tf.ones(z_size, dtype))
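# A standard-normal prior paired with the diagonal-Gaussian q(z|x) from
# encode() means tfd.kl_divergence below can evaluate the KL term of the
# ELBO in closed form rather than by Monte Carlo.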
def mnist_train_data(batch_size):
  # Shuffle the full training split; drop_remainder keeps every batch at
  # exactly batch_size, which the shape asserts in main() rely on.
  # Note: .repeat() is omitted, so one pass over the data is one epoch.
  return (mnist_dataset("~/data/mnist", "train")
          .shuffle(50000)
          .batch(batch_size, drop_remainder=True))
def main(argv):
  del argv  # unused
  tf.enable_eager_execution()
  np.random.seed(FLAGS.seed)
  tf.set_random_seed(FLAGS.seed)

  if not tf.gfile.Exists(FLAGS.outdir):
    tf.gfile.MakeDirs(FLAGS.outdir)

  train_data = mnist_train_data(FLAGS.batch_size)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  encoding_net = make_nn(FLAGS.z_size * 2)  # loc and scale for q(z|x)
  decoding_net = make_nn(np.prod(IMAGE_SHAPE))  # per-pixel Bernoulli logits
  ckpt = tf.train.Checkpoint(encoding_net=encoding_net, decoding_net=decoding_net)
  elbos = []
  prior = make_prior(FLAGS.z_size)

  for (x, _) in train_data:
    step = tf.train.get_or_create_global_step().numpy()
    if step % 100 == 0:
      ckpt.save(os.path.join(FLAGS.outdir, 'ckpt/ckpt'))
      np.savetxt(os.path.join(FLAGS.outdir, 'elbos.txt'), elbos)
    with tf.GradientTape() as tape:
      z = encode(x, encoding_net)
      xhat = decode(z.sample(FLAGS.n_samples), decoding_net)
      if step % 100 == 0:
        xhat_logits = xhat.distribution.logits
        np.savez(os.path.join(FLAGS.outdir, "xhat_logits_{}.npz".format(step)),
                 xhat_logits.numpy())
      negloglik = -xhat.log_prob(x)
      assert negloglik.shape == (FLAGS.n_samples, FLAGS.batch_size)
      loss = tf.reduce_mean(negloglik)
      kl = tfd.kl_divergence(z, prior)
      assert kl.shape == (FLAGS.batch_size,)
      kl = tf.reduce_mean(kl)
      # Note: `elbo` here is the *negative* ELBO; minimizing it maximizes
      # the ELBO.
      elbo = loss + kl
      elbos.append(elbo.numpy())
      assert elbo.shape == ()
    print("step", step, "elbo", elbo.numpy())
    trainable_vars = encoding_net.variables + decoding_net.variables
    grads = tape.gradient(elbo, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars),
                              global_step=tf.train.get_or_create_global_step())
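    # The quantity minimized above is the negative evidence lower bound,
    #
    #   -ELBO(x) = -E_{z ~ q(z|x)}[log p(x|z)] + KL(q(z|x) || p(z)),
    #
    # where the expectation is approximated with n_samples Monte Carlo draws
    # from q(z|x) (the z.sample(FLAGS.n_samples) call) and the KL term is
    # exact thanks to the Gaussian prior/posterior pair.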
# TRICKS
# - IWAE
# TODO
# - What metrics besides the ELBO are worth tracking?
# - Refactor make_nn, encode, and decode. N.B. the variables must be
#   exposed for later gradient computations.
if __name__ == "__main__":
  app.run(main)
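
# Example invocation (a sketch; the filename "vae.py" and the flag values,
# which match the defaults above, are illustrative):
#
#   python vae.py --z_size=4 --batch_size=28 --n_samples=8 --outdir=/tmp/vae/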