@gideonite
Created February 9, 2019 14:08
#!/usr/bin/python
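"""Variational autoencoder trained on statically binarized MNIST.

Summary of what the script below does (paraphrasing its own code): an encoder
network maps each image to a diagonal-Gaussian q(z|x), a decoder network maps
latent samples to an independent-Bernoulli p(x|z), and training minimizes the
negative ELBO (reconstruction NLL + KL to a standard-normal prior) with plain
SGD under eager execution.
"""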
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import tensorflow.contrib.eager as tfe
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import urllib.request
from absl import app
from absl import flags
from absl import logging
FLAGS = flags.FLAGS
flags.DEFINE_integer('seed', 0, 'random seed')
flags.DEFINE_integer('batch_size', 28, 'size of batch for training')
flags.DEFINE_integer('n_samples', 8, 'number of samples from q(z|x) per example during training')
flags.DEFINE_integer('z_size', 4, 'dimensionality of the continuous latent variable z')
flags.DEFINE_string('outdir', '/tmp/', 'directory for checkpoints and logged metrics')
ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
IMAGE_SHAPE = [28, 28, 1]
def download(directory, filename):
  """Downloads a file."""
  filepath = os.path.expanduser(os.path.join(directory, filename))
  if tf.gfile.Exists(filepath):
    return filepath
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  url = os.path.join(ROOT_PATH, filename)
  print("Downloading %s to %s" % (url, filepath))
  urllib.request.urlretrieve(url, filepath)
  return filepath
def mnist_dataset(directory, split_name):
  """Returns binary static MNIST tf.data.Dataset."""
  FILE_TEMPLATE = "binarized_mnist_{split}.amat"
  amat_file = download(directory, FILE_TEMPLATE.format(split=split_name))
  dataset = tf.data.TextLineDataset(amat_file)
  str_to_arr = lambda string: np.array([c == b"1" for c in string.split()])
  def _parser(s):
    booltensor = tf.py_func(str_to_arr, [s], tf.bool)
    #reshaped = tf.reshape(booltensor, [28, 28, 1])
    #return tf.to_float(reshaped), tf.constant(0, tf.int32)
    # TODO remove
    return tf.to_float(booltensor), tf.constant(0, tf.int32)
  return dataset.map(_parser)
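# Note on the data format: each line of the .amat files holds 28*28 = 784
# space-separated "0"/"1" tokens, so _parser yields a flat float vector of
# length 784; the commented-out reshape would restore IMAGE_SHAPE.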
def make_nn(out_size, hidden_size=(128, 64)):
  layers = []
  for h in hidden_size:
    layers.append(tf.keras.layers.Dense(h,
                                        activation=tf.nn.relu))
  layers.append(tf.keras.layers.Dense(out_size))
  return tf.keras.Sequential(layers)
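# The encoder network emits 2 * z_size values per example: the first half is
# used as the mean of q(z|x) and the second half, pushed through softplus, as
# its per-dimension scale. MultivariateNormalDiag sampling is reparameterized
# in TFP, so gradients flow from the decoder loss back through z.sample().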
def encode(x, net):
  def _compute_z_size(shape):
    z_size = int(shape)
    assert z_size % 2 == 0
    return z_size // 2
  mapped = net(x)
  z_size = _compute_z_size(mapped.shape[-1])
  return tfd.MultivariateNormalDiag(
      loc=mapped[..., :z_size],
      scale_diag=tf.nn.softplus(mapped[..., z_size:]))
def decode(z, net, batch_size, n_samples):
  """Maps latent samples z through `net` to the Bernoulli likelihood p(x|z)."""
  del batch_size, n_samples  # unused
  mapped = net(z)
  return tfd.Independent(tfd.Bernoulli(logits=mapped), reinterpreted_batch_ndims=1)
def make_prior(z_size, dtype=tf.float32):
  return tfd.MultivariateNormalDiag(
      loc=tf.zeros(z_size, dtype),
      scale_diag=tf.ones(z_size, dtype))
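# For q = N(mu, diag(sigma^2)) and the standard-normal prior above,
# tfd.kl_divergence uses the closed form
#   KL(q || p) = 0.5 * sum_i (mu_i^2 + sigma_i^2 - 1 - log sigma_i^2),
# so no Monte Carlo estimate of the KL term is needed in the training loop.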
def mnist_train_data(batch_size):
  # TODO .shuffle().repeat()
  return mnist_dataset("~/data/mnist", "train").batch(batch_size)
def main(argv):
  del argv  # unused
  tf.enable_eager_execution()
  np.random.seed(FLAGS.seed)
  tf.set_random_seed(FLAGS.seed)

  train_data = mnist_train_data(FLAGS.batch_size)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  encoding_net = make_nn(FLAGS.z_size * 2)
  decoding_net = make_nn(28 * 28)  # TODO
  ckpt = tf.train.Checkpoint(encoding_net=encoding_net, decoding_net=decoding_net)
  elbos = []
  prior = make_prior(FLAGS.z_size)

  for (x, y) in train_data:
    step = tf.train.get_or_create_global_step().numpy()
    if step % 100 == 0:
      ckpt.save(os.path.join(FLAGS.outdir, 'ckpt/ckpt'))
      np.savetxt(os.path.join(FLAGS.outdir, 'elbos.txt'), elbos)

    with tf.GradientTape() as tape:
      z = encode(x, encoding_net)
      xhat = decode(z.sample(FLAGS.n_samples),
                    decoding_net,
                    FLAGS.batch_size,
                    FLAGS.n_samples)
      if step % 100 == 0:
        xhat_logits = xhat.parameters['distribution'].logits
        np.savez(os.path.join(FLAGS.outdir, "xhat_logits_{}.npz".format(step)),
                 xhat_logits.numpy())

      negloglik = -xhat.log_prob(x)
      assert negloglik.shape == (FLAGS.n_samples, FLAGS.batch_size)
      loss = tf.reduce_mean(negloglik)

      kl = tfd.kl_divergence(z, prior)
      assert kl.shape == (FLAGS.batch_size,)
      kl = tf.reduce_mean(kl)

      # N.b. despite the name, `elbo` is the *negative* ELBO: the loss being minimized.
      elbo = loss + kl
      elbos.append(elbo)
      assert elbo.shape == ()
      print("step", step, "elbo", elbo.numpy())

    trainable_vars = encoding_net.variables + decoding_net.variables
    grads = tape.gradient(elbo, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars),
                              global_step=tf.train.get_or_create_global_step())
# TRICKS
# - IWAE (a rough sketch of the bound is below)
# TODO
# - Create outdir if it doesn't exist.
# - What metrics other than the ELBO?
# - Refactor make_nn, encoder, and decoder. N.b. you have to expose the
#   variables for later gradient computations.
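# Sketch of the IWAE bound mentioned above -- an assumption about how it could
# slot into this loop, not something the script currently does. With K =
# FLAGS.n_samples importance samples z_k ~ q(z|x), the bound is
#   L_K = E[ log (1/K) sum_k p(x|z_k) p(z_k) / q(z_k|x) ],
# which in this script's terms would look roughly like:
#
#   z_samples = z.sample(FLAGS.n_samples)               # [K, batch, z_size]
#   xhat = decode(z_samples, decoding_net,
#                 FLAGS.batch_size, FLAGS.n_samples)    # batch shape [K, batch]
#   log_w = (xhat.log_prob(x) + prior.log_prob(z_samples)
#            - z.log_prob(z_samples))                   # [K, batch]
#   iwae_bound = tf.reduce_mean(
#       tf.reduce_logsumexp(log_w, axis=0)
#       - tf.log(tf.to_float(FLAGS.n_samples)))
#   loss_to_minimize = -iwae_bound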
if __name__ == "__main__":
  app.run(main)
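# Example invocation (flag values are illustrative, and "vae_mnist.py" stands
# for whatever this file is saved as):
#   python vae_mnist.py --batch_size=28 --n_samples=8 --z_size=4 --outdir=/tmp/vae_run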