"""Created November 16, 2019 18:33.

This is the code provided with the VIB (Deep Variational Information
Bottleneck) paper for the MNIST example. It produces output like the
following, taken from the last epochs of a 200-epoch run. That run logged
every epoch, whereas the training loop below logs only every 10th epoch:

191: IZY=3.21 IZX=22.02 acc=0.9841 avg_acc=0.9888 err=0.0159 avg_err=0.0112
192: IZY=3.20 IZX=22.57 acc=0.9834 avg_acc=0.9884 err=0.0166 avg_err=0.0116
193: IZY=3.22 IZX=22.54 acc=0.9836 avg_acc=0.9887 err=0.0164 avg_err=0.0113
194: IZY=3.21 IZX=21.95 acc=0.9827 avg_acc=0.9884 err=0.0173 avg_err=0.0116
195: IZY=3.19 IZX=22.25 acc=0.9827 avg_acc=0.9886 err=0.0173 avg_err=0.0114
196: IZY=3.21 IZX=22.34 acc=0.9841 avg_acc=0.9886 err=0.0159 avg_err=0.0114
197: IZY=3.21 IZX=22.54 acc=0.9831 avg_acc=0.9883 err=0.0169 avg_err=0.0117
198: IZY=3.21 IZX=22.21 acc=0.9826 avg_acc=0.9883 err=0.0174 avg_err=0.0117
199: IZY=3.20 IZX=22.24 acc=0.9824 avg_acc=0.9881 err=0.0176 avg_err=0.0119
"""
import numpy as np
import matplotlib.pyplot as plt
# matplotlib inline
import tensorflow as tf
# --- Session: enable XLA JIT compilation graph-wide (level ON_1). ---
config = tf.ConfigProto()
jit_opts = config.graph_options.optimizer_options
jit_opts.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.InteractiveSession(config=config)

# --- Data: MNIST with validation_size=0, so no training examples are
# held out for validation. ---
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets('/tmp/mnistdata', validation_size=0)

# --- Graph inputs. ---
# Flattened images, 784 features each (pixel values presumably in [0, 1]).
images = tf.placeholder(tf.float32, shape=[None, 784], name='images')
# Integer class ids, one-hot encoded to depth 10 for the cross-entropy loss.
labels = tf.placeholder(tf.int64, shape=[None], name='labels')
one_hot_labels = tf.one_hot(labels, depth=10)

# Short aliases for the TF 1.x contrib modules used by the model code.
layers, ds = tf.contrib.layers, tf.contrib.distributions
def encoder(images):
    """Map a batch of flattened images to a diagonal Gaussian over the latent z.

    Two 1024-unit ReLU layers feed a linear layer that emits 512 numbers per
    example. The first 256 are the means; the last 256 are pre-softplus scale
    parameters.

    Args:
        images: float tensor of shape [batch, 784]. Values are presumably in
            [0, 1]; they are mapped to [-1, 1] before the first layer.

    Returns:
        A ``ds.NormalWithSoftplusScale`` distribution with batch shape
        [batch, 256].
    """
    # Affine rescale of the inputs (0 -> -1, 1 -> 1).
    net = layers.relu(2 * images - 1, 1024)
    net = layers.relu(net, 1024)
    params = layers.linear(net, 512)
    mu, rho = params[:, :256], params[:, 256:]
    # scale = softplus(rho - 5). Because softplus is increasing, the -5 offset
    # biases the standard deviations toward small values.
    encoding = ds.NormalWithSoftplusScale(mu, rho - 5.0)
    return encoding
def decoder(encoding_sample):
    """Return unnormalized logits over the 10 classes for latent samples.

    The decoder is a single linear layer, i.e. a softmax classifier on z.

    Args:
        encoding_sample: float tensor whose last axis is the latent dimension.
            Any leading axes (for example, the sample axis from
            ``encoding.sample(12)``) are expected to be carried through to the
            output; the multi-sample accuracy computation relies on this.

    Returns:
        Logits tensor with last axis of size 10.
    """
    net = layers.linear(encoding_sample, 10)
    return net
# Fixed standard-normal prior r(z); the scalar parameters broadcast over every
# latent dimension when the KL term is computed.
prior = ds.Normal(0.0, 1.0)
import math

# Encoder and decoder weights live in separate variable scopes.
with tf.variable_scope('encoder'):
    encoding = encoder(images)

# Logits from a single posterior sample per image (training loss, `accuracy`).
with tf.variable_scope('decoder'):
    logits = decoder(encoding.sample())

# reuse=True shares the decoder weights created just above. Twelve posterior
# samples per image are decoded so predictions can be averaged (`avg_accuracy`).
with tf.variable_scope('decoder', reuse=True):
    many_logits = decoder(encoding.sample(12))
# Distortion: mean softmax cross-entropy over the batch, converted from nats to
# bits by dividing by ln 2.
class_loss = tf.losses.softmax_cross_entropy(
    onehot_labels=one_hot_labels, logits=logits) / math.log(2)

# Weight on the rate (compression) term of the objective.
BETA = 1e-3

# Rate: elementwise KL(q(z|x) || r(z)), averaged over the batch (axis 0), summed
# over the latent dimensions, and converted to bits.
kl_elementwise = ds.kl_divergence(encoding, prior)
info_loss = tf.reduce_sum(tf.reduce_mean(kl_elementwise, 0)) / math.log(2)

# VIB objective: distortion + BETA * rate.
total_loss = class_loss + BETA * info_loss

# Accuracy of the single-sample logits.
correct = tf.equal(tf.argmax(logits, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

# Accuracy after averaging softmax probabilities over the posterior samples
# (axis 0 of many_logits).
mean_probs = tf.reduce_mean(tf.nn.softmax(many_logits), 0)
avg_correct = tf.equal(tf.argmax(mean_probs, 1), labels)
avg_accuracy = tf.reduce_mean(tf.cast(avg_correct, tf.float32))

# Information-plane estimates in bits. log2(10) equals the label entropy only if
# the 10 classes are equally likely (assumption). The KL rate serves as the
# I(Z;X) estimate.
IZY_bound = math.log(10, 2) - class_loss
IZX_bound = info_loss
batch_size = 100
# Optimizer steps per full pass over the training set (i.e. steps per epoch,
# despite the name; kept because the training loop uses it).
steps_per_batch = mnist_data.train.num_examples // batch_size

global_step = tf.contrib.framework.get_or_create_global_step()
# Staircase decay: multiply the learning rate by 0.97 once every two epochs.
# `decay_steps` is a required argument of exponential_decay; omitting it
# raises TypeError.
learning_rate = tf.train.exponential_decay(1e-4, global_step,
                                           decay_steps=2 * steps_per_batch,
                                           decay_rate=0.97, staircase=True)
# Adam with beta1 = 0.5.
opt = tf.train.AdamOptimizer(learning_rate, 0.5)

# Exponential moving (Polyak) averages of all model variables.
ma = tf.train.ExponentialMovingAverage(0.999, zero_debias=True)
ma_update = ma.apply(tf.model_variables())

# `saver` checkpoints the raw variables. `saver_polyak` maps the moving-average
# shadow variables back onto the model variables when restoring.
saver = tf.train.Saver()
saver_polyak = tf.train.Saver(ma.variables_to_restore())

# One training step: minimize total_loss, increment global_step, and run the
# moving-average update as part of the same op.
train_tensor = tf.contrib.training.create_train_op(total_loss, opt,
                                                   global_step,
                                                   update_ops=[ma_update])

# Initialize every variable, including the optimizer slots created above.
tf.global_variables_initializer().run()
def evaluate():
    """Evaluate the current model on the full MNIST test set.

    Results are stochastic because the logits are computed from sampled
    latents.

    Returns:
        Tuple ``(IZY, IZX, acc, avg_acc, err, avg_err)``:
        - IZY, IZX: the I(Z;Y) and I(Z;X) estimates in bits.
        - acc: single-sample accuracy.
        - avg_acc: multi-sample-averaged accuracy.
        - err, avg_err: the error rates ``1 - acc`` and ``1 - avg_acc``.
    """
    IZY, IZX, acc, avg_acc = sess.run(
        [IZY_bound, IZX_bound, accuracy, avg_accuracy],
        feed_dict={images: mnist_data.test.images,
                   labels: mnist_data.test.labels})
    return IZY, IZX, acc, avg_acc, 1 - acc, 1 - avg_acc
import sys
# Same line format for per-epoch and final reports; the first field is the
# epoch number or a label.
report = ("{}: IZY={:.2f}\tIZX={:.2f}\tacc={:.4f}\tavg_acc={:.4f}"
          "\terr={:.4f}\tavg_err={:.4f}")

for epoch in range(200):
    for _ in range(steps_per_batch):
        im, ls = mnist_data.train.next_batch(batch_size)
        sess.run(train_tensor, feed_dict={images: im, labels: ls})
    # Report test-set metrics every 10th epoch.
    if epoch % 10 == 0:
        print(report.format(epoch, *evaluate()))
        sys.stdout.flush()

# Checkpoint the final raw weights together with their moving averages.
savepth = saver.save(sess, '/tmp/mnistvib', global_step)

# Evaluate with the Polyak-averaged weights...
saver_polyak.restore(sess, savepth)
print(report.format('polyak', *evaluate()))

# ...then restore the raw weights and evaluate those as well.
saver.restore(sess, savepth)
print(report.format('raw', *evaluate()))
sys.stdout.flush()
