Created
November 19, 2017 02:20
-
-
Save shenkev/f8f4d2f2f05d6fc81e20b2e49a35c2ba to your computer and use it in GitHub Desktop.
Gist associated with the Stack Overflow question https://stackoverflow.com/questions/47372815/tensorflow-every-iteration-of-for-loop-gets-slower-and-slower
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import numpy as np | |
import tensorflow as tf | |
import edward as ed | |
import pickle | |
from mcmc.util import plot, plot_save | |
from mcmc.mcmc2 import run_experiment, compare_vae_hmc_loss | |
# TensorFlow and Edward must share one global session, so reuse Edward's.
sess = ed.get_session()

# ---- Dataset ----
# NOTE(review): mnist_data and model_class are not imported in this file;
# confirm which module provides them.
mnist_wrapper_object = mnist_data()
mnist_dataset = mnist_wrapper_object.get_data()

# ---- Model ----
version = 'v0'
model_attributes = dict(
    dataset=mnist_wrapper_object,
    latent_dim=50,
    version=version,
)
model = model_class(sess, **model_attributes)
model.build()

model_sample_reconstructions = 0
model.set_defaults(reconstruction={'sampling': model_sample_reconstructions})

checkpoint = 'models/mnist-vae-gan-v0.weights.tfmod'
model.load(checkpoint)
# =============================== INFERENCE ====================================
inference_batch_size = 100
start_ind = 0  # index of the first adversarial example to process

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# The with-block closes the file even if unpickling fails.
with open("./adversarial_examples_v0.pckl", 'rb') as f:
    attack_set, attack_set_labels, adversarial_examples, adversarial_targets = pickle.load(f)

# Drop every example before start_ind.
attack_set = attack_set[start_ind:]
attack_set_labels = attack_set_labels[start_ind:]
adversarial_examples = adversarial_examples[start_ind:]
adversarial_targets = adversarial_targets[start_ind:]

x_ad = adversarial_examples[0:inference_batch_size]
for i in range(x_ad.shape[0]):
    # The leading number in each filename is the 1-based sample number.
    img_id = str(start_ind + i + 1).zfill(3)
    plot_save(attack_set[i].reshape(1, 784),
              './out/{}_x_gt_label_{}_target{}.png'.format(img_id, attack_set_labels[i], adversarial_targets[i]))
    plot_save(x_ad[i].reshape(1, 784),
              './out/{}_x_adversarial.png'.format(img_id))
# HMC inference settings, read by run_experiment / compare_vae_hmc_loss.
# NOTE(review): keys not commented below are presumably consumed by
# build_experiment; confirm their meaning there.
config = {
    'model': 'hmc',
    'inference_batch_size': inference_batch_size,
    'T': 15000,            # total HMC steps, burn-in included
    'img_dim': 28,
    'step_size': None,
    'leapfrog_steps': None,
    'friction': None,
    'z_dim': 50,
    'likelihood_variance': 0.48,
    'useDiscL': False,
    'keep_ratio': 0.05,    # fraction of the final samples kept after burn-in
    'img_num': 0,          # overwritten per image during evaluation (filename prefix)
    'sample_to_vis': 3,    # how many MCMC samples get plotted per image
}

# Only show TensorFlow errors during the long HMC run.
tf.logging.set_verbosity(tf.logging.ERROR)

# HACK: replace the model's training flag with a constant False tensor before
# running inference (presumably switches the model to inference-mode
# behaviour; confirm in model_class).
model._training = tf.constant([False])

qz, qz_kept = run_experiment(model.decode_op, model.encode_op, x_ad, config,
                             model.discriminator_l_op)
# Draw the latent samples to evaluate once, from the kept (post burn-in) part
# of the chain; samples_to_check[:, i, :] holds the samples for image i.
num_samples = 40
samples_to_check = qz_kept.sample(num_samples).eval()

# Text mode: print() writes str, which a binary ('ab') file rejects on
# Python 3. The with-block closes the log even if evaluation raises.
with open('log.txt', 'a') as f:
    # Bound by the images actually loaded: x_ad can hold fewer than
    # inference_batch_size rows when few examples remain after start_ind.
    for i in range(x_ad.shape[0]):
        config['img_num'] = str(start_ind + i + 1).zfill(3)
        (best_recon_loss, average_recon_loss, best_l2_loss, average_l2_loss,
         best_latent_loss, average_latent_loss,
         vae_recon_loss, vae_l2_loss, vae_latent_loss) = compare_vae_hmc_loss(
            model.decode_op, model.encode_op, model.discriminator_l_op,
            x_ad[i:i + 1], samples_to_check[:, i, :], config)
        print("---------- Summary Image {} ------------".format(start_ind + i + 1), file=f)
        print("VAE recon loss: " + str(vae_recon_loss), file=f)
        print("VAE L2 loss: " + str(vae_l2_loss), file=f)
        print("VAE latent loss: " + str(vae_latent_loss), file=f)
        print("Best mcmc recon loss: " + str(best_recon_loss), file=f)
        print("Best mcmc L2 loss: " + str(best_l2_loss), file=f)
        print("Best mcmc latent loss: " + str(best_latent_loss), file=f)
        print("Average mcmc recon loss: " + str(average_recon_loss), file=f)
        print("Average mcmc l2 loss " + str(average_l2_loss), file=f)
        print("Average mcmc latent loss " + str(average_latent_loss), file=f)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from edward.models import Empirical, Normal | |
import edward as ed | |
import tensorflow as tf | |
from util import plot, plot_save | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
import numpy as np | |
def run_experiment(P, Q, x_gt, config, DiscL):
    """Run HMC over the latent codes of ``x_gt`` and discard the burn-in.

    Args:
        P, Q, DiscL: decoder, encoder and discriminator-feature callables,
            forwarded to ``build_experiment``.
        x_gt: images whose latent codes are inferred.
        config: dict; reads 'T' (total HMC steps, burn-in included) and
            'keep_ratio' (fraction of the final samples to keep), and is
            forwarded to ``build_experiment``.

    Returns:
        (qz, qz_kept): the chain returned by ``build_experiment`` and an
        ``Empirical`` over its last ``keep_ratio`` fraction of samples.
    """
    hmc_steps = config.get('T')
    keep_ratio = config.get('keep_ratio')

    inference, qz = build_experiment(P, Q, x_gt, config, DiscL)
    # Only initialize variables created by build_experiment; restored model
    # weights must not be reset.
    init_uninited_vars()
    for _ in range(hmc_steps):
        info_dict = inference.update()
        inference.print_progress(info_dict)
    inference.finalize()

    # Round rather than truncate: 1 - 0.9 == 0.09999999999999998, so
    # int((1 - keep_ratio) * hmc_steps) would keep all 10 of 10 samples for
    # keep_ratio=0.9. Keep at least one and at most hmc_steps samples.
    num_keep = min(hmc_steps, max(1, int(round(keep_ratio * hmc_steps))))
    to_keep_index = hmc_steps - num_keep
    # Assumes qz.params stacks the chain's samples along axis 0 (Empirical).
    qz_kept = Empirical(qz.params[to_keep_index:])
    return qz, qz_kept
def _sigmoid_ce_sum(labels, logits):
    # NumPy version of tf.nn.sigmoid_cross_entropy_with_logits (same stable
    # formula max(x, 0) - x * z + log(1 + exp(-|x|))), summed over the last axis.
    elementwise = (np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))
    return np.sum(elementwise, axis=-1)


def _per_sample_l2(diffs):
    # L2 norm of each diffs[k], taken over all of its elements.
    return np.linalg.norm(diffs.reshape(diffs.shape[0], -1), axis=1)


def compare_vae_hmc_loss(P, Q, DiscL, x_gt, samples_to_check, config):
    """Score MCMC latent samples and the VAE reconstruction against one image.

    Each latent sample is decoded with ``P`` and compared with ``x_gt`` by
    three losses:
      * recon: sigmoid cross-entropy, decoded image as logits;
      * L2: pixel-space L2 distance;
      * latent: L2 distance between ``DiscL`` features.
    The same losses are computed for the VAE reconstruction ``P(Q(x_gt))``.
    Plots are written under ``./out/`` and a summary is printed.

    Every TF op is built once and evaluated in a single ``sess.run``; the
    per-sample losses are computed in NumPy, so no ops are added per sample.
    NOTE(review): each call still adds one fixed set of decoder/discriminator
    ops (inputs are NumPy constants); feeding placeholders would avoid that too.

    Args:
        P: decoder callable applied to latent codes.
        Q: encoder callable applied to images.
        DiscL: callable returning discriminator features of images.
        x_gt: a single ground-truth image; assumed shape (1, 784), TODO confirm.
        samples_to_check: latent samples, one per row along axis 0.
        config: dict; reads 'img_num' (filename prefix) and 'sample_to_vis'
            (number of decoded samples to plot).

    Returns:
        (best_recon_loss, average_recon_loss, best_l2_loss, average_l2_loss,
         best_latent_loss, average_latent_loss,
         vae_recon_loss, vae_l2_loss, vae_latent_loss)
    """
    print("Starting evaluation...")
    sess = ed.get_session()
    img_num = config.get('img_num')
    sample_to_vis = config.get('sample_to_vis')

    # Build each op once, then evaluate them all in one session run.
    x_samples_op = trim_32_to_28(P(samples_to_check))
    vae_out_op = P(Q(x_gt))
    x_vae_op = trim_32_to_28(vae_out_op)
    vae_img_op = tf.reshape(
        tf.slice(tf.reshape(vae_out_op, [32, 32]), [2, 2], [28, 28]), [1, 784])
    (x_samples, feat_samples, feat_gt,
     x_vae, feat_vae, vae_img) = sess.run([
         x_samples_op, DiscL(x_samples_op), DiscL(x_gt),
         x_vae_op, DiscL(x_vae_op), vae_img_op])

    x_gt = np.asarray(x_gt)
    # One (1, D) row per sample so every sample lines up with x_gt.
    x_samples = np.expand_dims(x_samples, 1)

    recon_losses = _sigmoid_ce_sum(x_gt, x_samples)          # (num_samples, 1)
    l2_losses = _per_sample_l2(x_samples - x_gt)             # (num_samples,)
    latent_losses = _per_sample_l2(feat_samples - feat_gt)   # (num_samples,)

    for j in range(sample_to_vis):
        plot_save(x_samples[j], './out/{}_mcmc_sample_{}.png'.format(img_num, j + 1))
    plot_save(np.mean(x_samples, axis=0), './out/{}_mcmcMean.png'.format(img_num))

    # argmin returns the first minimum, matching a strict '<' scan.
    best_recon_idx = int(np.argmin(recon_losses[:, 0]))
    best_l2_idx = int(np.argmin(l2_losses))
    best_latent_idx = int(np.argmin(latent_losses))
    best_recon_sample, best_recon_loss = x_samples[best_recon_idx], recon_losses[best_recon_idx]
    best_l2_sample, best_l2_loss = x_samples[best_l2_idx], l2_losses[best_l2_idx]
    best_latent_sample, best_latent_loss = x_samples[best_latent_idx], latent_losses[best_latent_idx]

    average_recon_loss = np.mean(recon_losses, axis=0)
    average_l2_loss = np.mean(l2_losses)
    average_latent_loss = np.mean(latent_losses)

    vae_recon_loss = _sigmoid_ce_sum(x_gt, x_vae)
    vae_l2_loss = np.linalg.norm(x_gt - x_vae)
    # Features vs. features: DiscL is applied to the VAE reconstruction exactly
    # as it is to the decoded MCMC samples.
    vae_latent_loss = np.linalg.norm(feat_gt - feat_vae)

    print("---------- Summary Image {} ------------".format(img_num))
    print("VAE recon loss: " + str(vae_recon_loss))
    print("VAE L2 loss: " + str(vae_l2_loss))
    print("VAE latent loss: " + str(vae_latent_loss))
    print("Best mcmc recon loss: " + str(best_recon_loss))
    print("Best mcmc L2 loss: " + str(best_l2_loss))
    print("Best mcmc latent loss: " + str(best_latent_loss))
    print("Average mcmc recon loss: " + str(average_recon_loss))
    print("Average mcmc l2 loss " + str(average_l2_loss))
    print("Average mcmc latent loss " + str(average_latent_loss))

    plot_save(vae_img, './out/{}_vae_recon.png'.format(img_num))
    plot_save(best_recon_sample, './out/{}_best_recon.png'.format(img_num))
    plot_save(best_l2_sample, './out/{}_best_l2.png'.format(img_num))
    plot_save(best_latent_sample, './out/{}_best_latent.png'.format(img_num))
    return best_recon_loss, average_recon_loss, best_l2_loss, average_l2_loss, best_latent_loss, average_latent_loss,\
        vae_recon_loss, vae_l2_loss, vae_latent_loss
def l2_loss(x_gt, x_hmc):
    """Return the L2 norm of ``x_gt - x_hmc`` taken over all elements.

    NumPy inputs are handled in NumPy. Anything else (e.g. a tf.Tensor) goes
    through TensorFlow, which adds new ops to the default graph on every
    call, so avoid that path inside loops.
    """
    if isinstance(x_gt, np.ndarray) and isinstance(x_hmc, np.ndarray):
        return np.linalg.norm(x_gt - x_hmc)
    return tf.norm(x_gt - x_hmc).eval()
def recon_loss(x_gt, x_hmc):
    """Return per-row summed sigmoid cross-entropy (logits=x_hmc, labels=x_gt).

    NumPy inputs use the same numerically stable formula as
    tf.nn.sigmoid_cross_entropy_with_logits,
    max(x, 0) - x * z + log(1 + exp(-|x|)), summed over axis 1. Other inputs
    (e.g. tf.Tensor) go through TensorFlow, which adds graph ops per call.
    """
    if isinstance(x_gt, np.ndarray) and isinstance(x_hmc, np.ndarray):
        logits, labels = x_hmc, x_gt
        elementwise = (np.maximum(logits, 0) - logits * labels
                       + np.log1p(np.exp(-np.abs(logits))))
        return np.sum(elementwise, axis=1)
    return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=x_hmc, labels=x_gt), 1).eval()
def l_latent_loss(l_th_x_gt, l_th_x_hmc):
    """Return the L2 distance between two feature arrays over all elements.

    NumPy inputs are handled in NumPy. Other inputs (e.g. tf.Tensor) go
    through TensorFlow, which adds new graph ops on every call.
    """
    if isinstance(l_th_x_gt, np.ndarray) and isinstance(l_th_x_hmc, np.ndarray):
        return np.linalg.norm(l_th_x_gt - l_th_x_hmc)
    return tf.norm(l_th_x_gt - l_th_x_hmc).eval()
def init_uninited_vars():
    """Initialize only the global variables that are not initialized yet.

    This lets variables created after a checkpoint restore (e.g. by
    build_experiment) be initialized without resetting restored weights.
    """
    sess = ed.get_session()
    global_vars = tf.global_variables()
    if not global_vars:
        return
    # One session run for all checks, on the same session used to initialize,
    # instead of one default-session .eval() per variable.
    inited_flags = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    uninited = [var for var, inited in zip(global_vars, inited_flags) if not inited]
    if uninited:
        sess.run(tf.variables_initializer(uninited))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment