# Driver script: run HMC in the latent space of a pretrained MNIST VAE-GAN to
# reconstruct adversarial examples, then compare the MCMC reconstructions
# against the plain VAE reconstruction (second file below).
from __future__ import print_function
import numpy as np
import tensorflow as tf
import edward as ed
import pickle
from mcmc.util import plot, plot_save
from mcmc.mcmc2 import run_experiment, compare_vae_hmc_loss

sess = ed.get_session()  # make sure TensorFlow and Edward share the same global session

# Load dataset
# NOTE: mnist_data and model_class are assumed to come from the project's own
# modules; their imports are not included in this gist.
mnist_wrapper_object = mnist_data()
mnist_dataset = mnist_wrapper_object.get_data()

# Load model
version = 'v0'
model_attributes = {
    'dataset': mnist_wrapper_object,
    'latent_dim': 50,
    'version': version
}
model = model_class(sess, **model_attributes)
model.build()
model_sample_reconstructions = 0
model.set_defaults(reconstruction={'sampling': model_sample_reconstructions})
checkpoint = 'models/mnist-vae-gan-v0.weights.tfmod'
model.load(checkpoint)
# =============================== INFERENCE ====================================
inference_batch_size = 100
start_ind = 0
f = open("./adversarial_examples_v0.pckl", 'rb')
attack_set, attack_set_labels, adversarial_examples, adversarial_targets = pickle.load(f)
f.close()
attack_set = attack_set[start_ind:]
attack_set_labels = attack_set_labels[start_ind:]
adversarial_examples = adversarial_examples[start_ind:]
adversarial_targets = adversarial_targets[start_ind:]
x_ad = adversarial_examples[0:inference_batch_size]
for i in range(x_ad.shape[0]):
plot_save(attack_set[i].reshape(1, 784), # first number is sample number
'./out/{}_x_gt_label_{}_target{}.png'.format(str(start_ind+i+1).zfill(3), attack_set_labels[i], adversarial_targets[i]))
plot_save(x_ad[i].reshape(1, 784),
'./out/{}_x_adversarial.png'.format(str(start_ind+i+1).zfill(3)))
config = {
    'model': 'hmc',
    'inference_batch_size': inference_batch_size,
    'T': 15000,                    # total HMC steps, including burn-in
    'img_dim': 28,
    'step_size': None,             # None: presumably fall back to the sampler's default
    'leapfrog_steps': None,        # None: presumably fall back to the sampler's default
    'friction': None,
    'z_dim': 50,                   # must match the model's latent_dim
    'likelihood_variance': 0.48,
    'useDiscL': False,
    'keep_ratio': 0.05,            # fraction of final samples kept after discarding burn-in
    'img_num': 0,                  # overwritten per image in the evaluation loop below
    'sample_to_vis': 3             # number of MCMC samples saved as images per input
}
# Silence TensorFlow logging and put the model into inference (non-training) mode.
tf.logging.set_verbosity(tf.logging.ERROR)
model._training = tf.constant([False])
qz, qz_kept = run_experiment(model.decode_op, model.encode_op, x_ad, config, model.discriminator_l_op)
num_samples = 40
# samples_to_check has shape [num_samples, inference_batch_size, z_dim]
samples_to_check = qz_kept.sample(num_samples).eval()
f = open('log.txt', 'a')  # text mode so the print calls below also work under Python 3
for i in range(inference_batch_size):
    config['img_num'] = str(start_ind + i + 1).zfill(3)
    best_recon_loss, average_recon_loss, best_l2_loss, average_l2_loss, best_latent_loss, average_latent_loss, \
        vae_recon_loss, vae_l2_loss, vae_latent_loss \
        = compare_vae_hmc_loss(model.decode_op, model.encode_op, model.discriminator_l_op,
                               x_ad[i:i + 1], samples_to_check[:, i, :], config)
    print("---------- Summary Image {} ------------".format(start_ind + i + 1), file=f)
    print("VAE recon loss: " + str(vae_recon_loss), file=f)
    print("VAE L2 loss: " + str(vae_l2_loss), file=f)
    print("VAE latent loss: " + str(vae_latent_loss), file=f)
    print("Best MCMC recon loss: " + str(best_recon_loss), file=f)
    print("Best MCMC L2 loss: " + str(best_l2_loss), file=f)
    print("Best MCMC latent loss: " + str(best_latent_loss), file=f)
    print("Average MCMC recon loss: " + str(average_recon_loss), file=f)
    print("Average MCMC L2 loss: " + str(average_l2_loss), file=f)
    print("Average MCMC latent loss: " + str(average_latent_loss), file=f)
f.close()

# ---------------------------------------------------------------------------
# Second file of the gist (presumably mcmc/mcmc2.py, imported above): defines
# run_experiment and compare_vae_hmc_loss. build_experiment and trim_32_to_28
# are used below but their definitions are not included in the gist.
# ---------------------------------------------------------------------------
from edward.models import Empirical, Normal
import edward as ed
import tensorflow as tf
from util import plot, plot_save
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

def run_experiment(P, Q, x_gt, config, DiscL):
    hmc_steps = config.get('T')            # number of HMC steps to run, including burn-in
    keep_ratio = config.get('keep_ratio')  # keep only the last <keep_ratio> fraction of samples (burn-in discarded)
    # build_experiment sets up the Edward inference object and the empirical approximation qz.
    inference, qz = build_experiment(P, Q, x_gt, config, DiscL)
    init_uninited_vars()
    for _ in range(hmc_steps):
        info_dict = inference.update()
        inference.print_progress(info_dict)
    to_keep_index = int((1 - keep_ratio) * hmc_steps)
    qz_kept = Empirical(qz.params[to_keep_index:])
    return qz, qz_kept
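

# ---------------------------------------------------------------------------
# NOTE (added for clarity, not part of the original gist): build_experiment is
# called above but not shown. The sketch below illustrates how such a function
# could be written with Edward's HMC API, assuming a standard-normal prior over
# the latent code and a Gaussian likelihood centred on the decoded image with
# the config's 'likelihood_variance'. All names and shapes in this sketch are
# assumptions, not the author's actual implementation.
# ---------------------------------------------------------------------------
def build_experiment_sketch(P, x_gt, config):
    batch_size = config['inference_batch_size']
    z_dim = config['z_dim']
    T = config['T']
    # Standard-normal prior over the VAE latent code.
    z = Normal(loc=tf.zeros([batch_size, z_dim]), scale=tf.ones([batch_size, z_dim]))
    # Gaussian likelihood around the decoded (and cropped) image.
    x = Normal(loc=trim_32_to_28(P(z)),
               scale=np.sqrt(config['likelihood_variance']) * tf.ones([batch_size, 784]))
    # Empirical approximation holding one latent sample per HMC step.
    qz = Empirical(params=tf.Variable(tf.zeros([T, batch_size, z_dim])))
    inference = ed.HMC({z: qz}, data={x: x_gt})
    # A step_size/leapfrog_steps of None falls back to Edward's defaults here.
    inference.initialize(step_size=config['step_size'] or 0.25,
                         n_steps=config['leapfrog_steps'] or 2)
    return inference, qz
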
def compare_vae_hmc_loss(P, Q, DiscL, x_gt, samples_to_check, config):
    print("Starting evaluation...")
    x_samples_to_check = trim_32_to_28(P(samples_to_check)).eval()   # decoded MCMC samples, cropped to 28x28
    l_th_layer_samples = DiscL(trim_32_to_28(P(samples_to_check)))   # discriminator l-th layer features of the samples
    l_th_x_gt = DiscL(x_gt)                                          # discriminator l-th layer features of the input
    x_samples_to_check = np.expand_dims(x_samples_to_check, 1)
    num_samples = samples_to_check.shape[0]
    img_num = config.get('img_num')
    sample_to_vis = config.get('sample_to_vis')

    # Track the best sample under each metric, starting from the first MCMC sample.
    best_recon_sample = x_samples_to_check[0]
    best_recon_loss = recon_loss(x_gt, best_recon_sample)
    best_l2_sample = x_samples_to_check[0]
    best_l2_loss = l2_loss(x_gt, best_l2_sample)
    best_latent_sample = x_samples_to_check[0]
    best_latent_loss = l_latent_loss(l_th_x_gt, l_th_layer_samples[0:1])

    total_recon_loss = 0.0
    total_l2_loss = 0.0
    total_latent_loss = 0.0
    # Save a few individual MCMC samples and their pixel-wise mean for visual inspection.
    for j in range(sample_to_vis):
        plot_save(x_samples_to_check[j], './out/{}_mcmc_sample_{}.png'.format(img_num, j + 1))
    avg_img = np.mean(x_samples_to_check, axis=0)
    plot_save(avg_img, './out/{}_mcmcMean.png'.format(img_num))

    for i, sample in enumerate(tqdm(x_samples_to_check)):
        r_loss = recon_loss(x_gt, sample)
        l_loss = l2_loss(x_gt, sample)
        lat_loss = l_latent_loss(l_th_x_gt, l_th_layer_samples[i:i + 1])
        total_recon_loss += r_loss
        total_l2_loss += l_loss
        total_latent_loss += lat_loss
        if r_loss < best_recon_loss:
            best_recon_sample = sample
            best_recon_loss = r_loss
        if l_loss < best_l2_loss:
            best_l2_sample = sample
            best_l2_loss = l_loss
        if lat_loss < best_latent_loss:
            best_latent_sample = sample
            best_latent_loss = lat_loss
        # print("Recon loss: " + str(r_loss))
        # print("L2 loss: " + str(l_loss))
    average_recon_loss = total_recon_loss / num_samples
    average_l2_loss = total_l2_loss / num_samples
    average_latent_loss = total_latent_loss / num_samples

    # Baseline: the plain VAE reconstruction (encode, decode, crop to 28x28).
    vae_recon_loss = recon_loss(x_gt, trim_32_to_28(P(Q(x_gt))))
    vae_l2_loss = l2_loss(x_gt, trim_32_to_28(P(Q(x_gt))))
    # Compare discriminator features of the cropped VAE reconstruction with those of the
    # input (the original passed the raw 32x32 decoder output, which does not match DiscL's input).
    vae_latent_loss = l_latent_loss(l_th_x_gt, DiscL(trim_32_to_28(P(Q(x_gt)))))

    print("---------- Summary Image {} ------------".format(img_num))
    print("VAE recon loss: " + str(vae_recon_loss))
    print("VAE L2 loss: " + str(vae_l2_loss))
    print("VAE latent loss: " + str(vae_latent_loss))
    print("Best MCMC recon loss: " + str(best_recon_loss))
    print("Best MCMC L2 loss: " + str(best_l2_loss))
    print("Best MCMC latent loss: " + str(best_latent_loss))
    print("Average MCMC recon loss: " + str(average_recon_loss))
    print("Average MCMC L2 loss: " + str(average_l2_loss))
    print("Average MCMC latent loss: " + str(average_latent_loss))

    plot_save(tf.reshape(tf.slice(tf.reshape(P(Q(x_gt)), [32, 32]), [2, 2], [28, 28]), [1, 784]).eval(),
              './out/{}_vae_recon.png'.format(img_num))
    plot_save(best_recon_sample, './out/{}_best_recon.png'.format(img_num))
    plot_save(best_l2_sample, './out/{}_best_l2.png'.format(img_num))
    plot_save(best_latent_sample, './out/{}_best_latent.png'.format(img_num))

    return best_recon_loss, average_recon_loss, best_l2_loss, average_l2_loss, best_latent_loss, average_latent_loss, \
        vae_recon_loss, vae_l2_loss, vae_latent_loss
def l2_loss(x_gt, x_hmc):
    # Euclidean distance between the ground-truth image and a sample, in pixel space.
    return tf.norm(x_gt - x_hmc).eval()


def recon_loss(x_gt, x_hmc):
    # Sigmoid cross-entropy between the sample (treated as logits) and the ground truth, summed over pixels.
    return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=x_hmc, labels=x_gt), 1).eval()


def l_latent_loss(l_th_x_gt, l_th_x_hmc):
    # Euclidean distance between discriminator features of the ground truth and of a sample.
    return tf.norm(l_th_x_gt - l_th_x_hmc).eval()


def init_uninited_vars():
    # Initialize only the variables that are not yet initialized, so the
    # pretrained weights already loaded into the session are left untouched.
    sess = ed.get_session()
    unint_vars = []
    for var in tf.global_variables():
        if not tf.is_variable_initialized(var).eval():
            unint_vars.append(var)
    missing_var_init = tf.variables_initializer(unint_vars)
    sess.run(missing_var_init)
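

# ---------------------------------------------------------------------------
# NOTE (added for clarity, not part of the original gist): trim_32_to_28 is
# used throughout this module but its definition is not included. The version
# below is an assumed implementation, inferred from the explicit
# tf.slice(..., [2, 2], [28, 28]) crop applied to the VAE reconstruction above:
# it crops the decoder's 32x32 output to the central 28x28 MNIST frame and
# flattens it back to [batch, 784].
# ---------------------------------------------------------------------------
def trim_32_to_28(x_32):
    imgs = tf.reshape(x_32, [-1, 32, 32])               # [batch, 32, 32]
    cropped = tf.slice(imgs, [0, 2, 2], [-1, 28, 28])   # drop the 2-pixel border
    return tf.reshape(cropped, [-1, 784])               # flatten back to [batch, 784]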