VAE experiments
# Based on implementations:
# - vae core: https://github.com/pytorch/examples/blob/master/vae/main.py
# - miwae: https://github.com/yoonholee/pytorch-vae
# - notes on the VAE metrics from https://iopscience.iop.org/article/10.3847/PSJ/ab9a52 (similar notes are available elsewhere)
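
# Objective: the importance-weighted autoencoder bound (IWAE, Burda et al. 2015),
#   L_k(x) = E_{z_1..z_k ~ q(z|x)} [ log (1/k) * sum_j p(x, z_j) / q(z_j | x) ],
# estimated below with k importance samples and averaged over M repetitions;
# for k = 1 it reduces to the standard ELBO.
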
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.normal import Normal
from PIL import Image
import numpy as np
import os

parser = argparse.ArgumentParser(description='VAE MNIST Example')
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
                    help='input batch size for training (default: 20)')
parser.add_argument('--epochs', type=int, default=4000, metavar='N',
                    help='number of epochs to train (default: 4000)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--k', type=int, default=1)
parser.add_argument('--M', type=int, default=1)
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
args.log_interval = 1  # note: overridden to 500 below, after the model definition

torch.manual_seed(args.seed)
device = torch.device("cuda" if args.cuda else "cpu")
print("running on", device)

path = "./MNIST"
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

class stochMNIST(datasets.MNIST):
    """ Gets a new stochastic binarization of MNIST at each call. """

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]
        img = Image.fromarray(img.numpy(), mode='L')
        img = transforms.ToTensor()(img)
        img = torch.bernoulli(img)  # stochastically binarize
        return img, target

    def get_mean_img(self):
        imgs = self.train_data.type(torch.float) / 255
        mean_img = imgs.mean(0).reshape(-1).numpy()
        return mean_img

train_loader = torch.utils.data.DataLoader(
    stochMNIST(path, train=True, download=True, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    stochMNIST(path, train=False, transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)
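
# Note: stochMNIST.__getitem__ draws a fresh torch.bernoulli() sample on every
# access, so two reads of the same index almost surely differ:
#   x1, _ = train_loader.dataset[0]
#   x2, _ = train_loader.dataset[0]
#   (x1 != x2).any()  # -> tensor(True) with overwhelming probability
# i.e. the model sees a new binarization of each digit every epoch.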

def debug_shape(item):
    # small helper for inspecting tensor shapes while debugging
    return item.cpu().detach().numpy().shape

class VAE(nn.Module):
    def __init__(self, hidden_size=400, latent_size=20):
        super(VAE, self).__init__()
        # encoder layers
        self.fc11 = nn.Linear(784, hidden_size)
        self.fc12 = nn.Linear(hidden_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)
        self.fc22 = nn.Linear(hidden_size, latent_size)
        # decoder layers
        self.fc31 = nn.Linear(latent_size, hidden_size)
        self.fc32 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, 784)
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.prior_distribution = Normal(torch.zeros([self.latent_size]).to(device),
                                         torch.ones([self.latent_size]).to(device))

    def encode(self, x):
        x = torch.tanh(self.fc11(x))  # torch.tanh (F.tanh is deprecated)
        x = torch.tanh(self.fc12(x))
        mu_enc = self.fc21(x)
        std_enc = self.fc22(x)
        return Normal(mu_enc, F.softplus(std_enc))  # softplus keeps the scale positive

    def reparameterize(self, mu, logvar):
        # kept from the original VAE example; unused below, since rsample()
        # on the Normal returned by encode() already reparameterizes
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        x = torch.tanh(self.fc31(z))
        x = torch.tanh(self.fc32(x))
        x = self.fc4(x)
        return Bernoulli(logits=x)
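
    # Note: Bernoulli(logits=...).log_prob(target) equals the elementwise
    # negative binary cross-entropy, so summing it over the 784 pixels (done
    # in elbo() below) yields the reconstruction term log p(x|z).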

    def forward(self, x, M, k):
        input_x = x.view(-1, 784).to(device)
        # encoded distribution ~ q(z|x, params) = Normal(real input_x; encoder_into_Mu, encoder_into_Std)
        z_distribution = self.encode(input_x)
        # sample z values from this distribution
        z = z_distribution.rsample(torch.Size([M, k]))
        # reconstruction distribution ~ p(x|z, params) = Bernoulli(sampled z)
        x_distribution = self.decode(z)
        # prior distribution ~ p(z) = Normal(sampled z; 0s, 1s), built once in __init__
        elbo = self.elbo(input_x, z, x_distribution, z_distribution)  # [M, k, batch_size]
        elbo_iwae = self.logmeanexp(elbo, 1).squeeze(1)  # [M, batch_size]
        loss = - torch.mean(elbo_iwae, 0)  # [batch_size]
        return x_distribution.probs, elbo, loss
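
    # Shape walkthrough for forward(), with batch size B:
    #   z    : [M, k, B, latent_size]  (k importance samples, M outer repeats)
    #   elbo : [M, k, B]               (per-sample log importance weights)
    # logmeanexp over dim 1 collapses the k samples into the IWAE bound and
    # the mean over dim 0 averages the M repetitions (as in the miwae code
    # referenced above).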

    def logmeanexp(self, inputs, dim=1):
        # numerically stable log(mean(exp(inputs))) along `dim`: shift by the
        # max so exp() cannot overflow; keepdim=True keeps the reduced dim so
        # broadcasting with input_max stays aligned (forward() squeezes it out)
        if inputs.size(dim) == 1:
            return inputs
        else:
            input_max = inputs.max(dim, keepdim=True)[0]
            return (inputs - input_max).exp().mean(dim, keepdim=True).log() + input_max

    def elbo(self, input_x, z, x_distribution, z_distribution):
        lpxz = x_distribution.log_prob(input_x).sum(-1)  # log p(x|z)
        lpz = self.prior_distribution.log_prob(z).sum(-1)  # log p(z)
        lqzx = z_distribution.log_prob(z).sum(-1)  # log q(z|x)
        kl = -lpz + lqzx
        return -kl + lpxz
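
# Note: each value produced by elbo() is a per-sample log importance weight,
#   log w = log p(x|z) + log p(z) - log q(z|x),
# and logmeanexp over the k dimension turns k such samples into the IWAE bound
#   log (1/k) * sum_j exp(log w_j).
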
args.log_interval = 500
M = args.M
k = args.k
#M = 5
#k = 5
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        _, elbo, loss_mk = model(data, M, k)
        loss = loss_mk.mean()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))  # / len(data)

def test(epoch):
    #print_metrics = ((epoch-1) % 10) == 0
    print_metrics = True
    if print_metrics:
        model.eval()
        with torch.no_grad():
            # Tests: the IWAE bound with k, 64, and 5000 importance samples
            elbos = []
            for data, _ in test_loader:
                _, elbo, _ = model(data, M=1, k=5000)
                elbos.append(elbo.squeeze(0))  # [5000, batch_size]
            elbos = torch.cat(elbos, dim=1)  # [5000, test_set_size]
            k_to_run = [k, 64, 5000]
            all_losses = []
            for k_for_loss in k_to_run:
                # take the first k_for_loss importance samples per image,
                # collapse them with logmeanexp, and average over the test set
                loss = model.logmeanexp(elbos[:k_for_loss], 0).mean().item()
                all_losses.append(- loss)
            test_loss_iwae_k, test_loss_iwae64, test_loss_iwae5000 = all_losses
            print('====>Test metrics: IWAE M=', M, ',k=', k, ' || epoch', epoch)
            print("IWAE-64: ", test_loss_iwae64)
            print("logˆp(x) = IWAE-5000: ", test_loss_iwae5000)
            print("−KL(Q||P): ", test_loss_iwae64 - test_loss_iwae5000)
            print("---------------")

if __name__ == "__main__":
    os.makedirs('results', exist_ok=True)  # save_image below writes into ./results
    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).probs.cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_epoch' + str(epoch).zfill(4) + '.png')
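
Sample output (from a run with M=8, k=8, batch size 20):
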
Train Epoch: 1 [0/60000 (0%)] Loss: 544.618103
Train Epoch: 1 [10000/60000 (17%)] Loss: 138.734665
Train Epoch: 1 [20000/60000 (33%)] Loss: 115.618584
Train Epoch: 1 [30000/60000 (50%)] Loss: 116.206688
Train Epoch: 1 [40000/60000 (67%)] Loss: 111.551384
Train Epoch: 1 [50000/60000 (83%)] Loss: 120.980362
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 1
IWAE-64: 101.856346
logˆp(x) = IWAE-5000: 101.97784
−KL(Q||P): -0.12149048
---------------
Train Epoch: 2 [0/60000 (0%)] Loss: 101.694847
Train Epoch: 2 [10000/60000 (17%)] Loss: 94.400818
Train Epoch: 2 [20000/60000 (33%)] Loss: 108.299316
Train Epoch: 2 [30000/60000 (50%)] Loss: 101.753235
Train Epoch: 2 [40000/60000 (67%)] Loss: 104.659843
Train Epoch: 2 [50000/60000 (83%)] Loss: 99.216331
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 2
IWAE-64: 97.00398
logˆp(x) = IWAE-5000: 97.345924
−KL(Q||P): -0.34194183
---------------
Train Epoch: 3 [0/60000 (0%)] Loss: 106.367607
Train Epoch: 3 [10000/60000 (17%)] Loss: 102.621948
Train Epoch: 3 [20000/60000 (33%)] Loss: 93.247398
Train Epoch: 3 [30000/60000 (50%)] Loss: 109.849731
Train Epoch: 3 [40000/60000 (67%)] Loss: 105.828445
Train Epoch: 3 [50000/60000 (83%)] Loss: 93.767998
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 3
IWAE-64: 95.24419
logˆp(x) = IWAE-5000: 95.411156
−KL(Q||P): -0.1669693
---------------
Train Epoch: 4 [0/60000 (0%)] Loss: 97.471848
Train Epoch: 4 [10000/60000 (17%)] Loss: 103.686646
Train Epoch: 4 [20000/60000 (33%)] Loss: 102.596367
Train Epoch: 4 [30000/60000 (50%)] Loss: 93.631889
Train Epoch: 4 [40000/60000 (67%)] Loss: 90.186600
Train Epoch: 4 [50000/60000 (83%)] Loss: 100.661491
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 4
IWAE-64: 94.75698
logˆp(x) = IWAE-5000: 94.39016
−KL(Q||P): 0.3668213
---------------
Train Epoch: 5 [0/60000 (0%)] Loss: 109.656487
Train Epoch: 5 [10000/60000 (17%)] Loss: 89.555992
Train Epoch: 5 [20000/60000 (33%)] Loss: 97.195396
Train Epoch: 5 [30000/60000 (50%)] Loss: 100.248428
Train Epoch: 5 [40000/60000 (67%)] Loss: 104.410034
Train Epoch: 5 [50000/60000 (83%)] Loss: 104.687523
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 5
IWAE-64: 94.646225
logˆp(x) = IWAE-5000: 93.69805
−KL(Q||P): 0.9481735
---------------
Train Epoch: 6 [0/60000 (0%)] Loss: 97.765373
Train Epoch: 6 [10000/60000 (17%)] Loss: 107.476028
Train Epoch: 6 [20000/60000 (33%)] Loss: 97.607529
Train Epoch: 6 [30000/60000 (50%)] Loss: 105.302513
Train Epoch: 6 [40000/60000 (67%)] Loss: 109.760330
Train Epoch: 6 [50000/60000 (83%)] Loss: 96.512207
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 6
IWAE-64: 92.74505
logˆp(x) = IWAE-5000: 92.943184
−KL(Q||P): -0.19813538
---------------
Train Epoch: 7 [0/60000 (0%)] Loss: 92.339188
Train Epoch: 7 [10000/60000 (17%)] Loss: 98.588173
Train Epoch: 7 [20000/60000 (33%)] Loss: 89.058235
Train Epoch: 7 [30000/60000 (50%)] Loss: 88.806847
Train Epoch: 7 [40000/60000 (67%)] Loss: 96.309105
Train Epoch: 7 [50000/60000 (83%)] Loss: 94.803154
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 7
IWAE-64: 91.760635
logˆp(x) = IWAE-5000: 92.285355
−KL(Q||P): -0.52471924
---------------
Train Epoch: 8 [0/60000 (0%)] Loss: 87.517845
Train Epoch: 8 [10000/60000 (17%)] Loss: 99.885033
Train Epoch: 8 [20000/60000 (33%)] Loss: 104.214409
Train Epoch: 8 [30000/60000 (50%)] Loss: 97.933716
Train Epoch: 8 [40000/60000 (67%)] Loss: 99.270409
Train Epoch: 8 [50000/60000 (83%)] Loss: 100.278252
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 8
IWAE-64: 90.84507
logˆp(x) = IWAE-5000: 92.03044
−KL(Q||P): -1.1853714
---------------
Train Epoch: 9 [0/60000 (0%)] Loss: 105.112419
Train Epoch: 9 [10000/60000 (17%)] Loss: 91.309120
Train Epoch: 9 [20000/60000 (33%)] Loss: 96.311066
Train Epoch: 9 [30000/60000 (50%)] Loss: 103.695045
Train Epoch: 9 [40000/60000 (67%)] Loss: 102.628288
Train Epoch: 9 [50000/60000 (83%)] Loss: 94.594231
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 9
IWAE-64: 92.28972
logˆp(x) = IWAE-5000: 91.84398
−KL(Q||P): 0.44573975
---------------
Train Epoch: 10 [0/60000 (0%)] Loss: 102.444267
Train Epoch: 10 [10000/60000 (17%)] Loss: 98.669945
Train Epoch: 10 [20000/60000 (33%)] Loss: 91.118675
Train Epoch: 10 [30000/60000 (50%)] Loss: 96.950302
Train Epoch: 10 [40000/60000 (67%)] Loss: 107.136940
Train Epoch: 10 [50000/60000 (83%)] Loss: 97.390648
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 10
IWAE-64: 91.120186
logˆp(x) = IWAE-5000: 91.54648
−KL(Q||P): -0.42629242
---------------
Train Epoch: 11 [0/60000 (0%)] Loss: 109.360893
Train Epoch: 11 [10000/60000 (17%)] Loss: 104.534805
Train Epoch: 11 [20000/60000 (33%)] Loss: 104.689880
Train Epoch: 11 [30000/60000 (50%)] Loss: 103.057434
Train Epoch: 11 [40000/60000 (67%)] Loss: 105.310524
Train Epoch: 11 [50000/60000 (83%)] Loss: 92.356544
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 11
IWAE-64: 91.86742
logˆp(x) = IWAE-5000: 91.24765
−KL(Q||P): 0.61976624
---------------
Train Epoch: 12 [0/60000 (0%)] Loss: 94.402412
Train Epoch: 12 [10000/60000 (17%)] Loss: 102.237709
Train Epoch: 12 [20000/60000 (33%)] Loss: 92.341049
Train Epoch: 12 [30000/60000 (50%)] Loss: 89.968994
Train Epoch: 12 [40000/60000 (67%)] Loss: 91.344337
Train Epoch: 12 [50000/60000 (83%)] Loss: 99.439751
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 12
IWAE-64: 89.93662
logˆp(x) = IWAE-5000: 91.165276
−KL(Q||P): -1.228653
---------------
Train Epoch: 13 [0/60000 (0%)] Loss: 91.478836
Train Epoch: 13 [10000/60000 (17%)] Loss: 94.909088
Train Epoch: 13 [20000/60000 (33%)] Loss: 91.767891
Train Epoch: 13 [30000/60000 (50%)] Loss: 92.367569
Train Epoch: 13 [40000/60000 (67%)] Loss: 107.229668
Train Epoch: 13 [50000/60000 (83%)] Loss: 98.232750
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 13
IWAE-64: 91.38028
logˆp(x) = IWAE-5000: 90.99059
−KL(Q||P): 0.38968658
---------------
Train Epoch: 14 [0/60000 (0%)] Loss: 90.363869
Train Epoch: 14 [10000/60000 (17%)] Loss: 99.742142
Train Epoch: 14 [20000/60000 (33%)] Loss: 91.261124
Train Epoch: 14 [30000/60000 (50%)] Loss: 90.453880
Train Epoch: 14 [40000/60000 (67%)] Loss: 98.580307
Train Epoch: 14 [50000/60000 (83%)] Loss: 99.148628
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 14
IWAE-64: 90.65437
logˆp(x) = IWAE-5000: 90.89895
−KL(Q||P): -0.2445755
---------------
Train Epoch: 15 [0/60000 (0%)] Loss: 108.186623
Train Epoch: 15 [10000/60000 (17%)] Loss: 92.393219
Train Epoch: 15 [20000/60000 (33%)] Loss: 100.103477
Train Epoch: 15 [30000/60000 (50%)] Loss: 85.533005
Train Epoch: 15 [40000/60000 (67%)] Loss: 103.622581
Train Epoch: 15 [50000/60000 (83%)] Loss: 102.047340
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 15
IWAE-64: 90.0911
logˆp(x) = IWAE-5000: 90.891205
−KL(Q||P): -0.80010223
---------------
Train Epoch: 16 [0/60000 (0%)] Loss: 98.122261
Train Epoch: 16 [10000/60000 (17%)] Loss: 92.934647
Train Epoch: 16 [20000/60000 (33%)] Loss: 85.830734
Train Epoch: 16 [30000/60000 (50%)] Loss: 95.870377
Train Epoch: 16 [40000/60000 (67%)] Loss: 93.688805
Train Epoch: 16 [50000/60000 (83%)] Loss: 90.419800
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 16
IWAE-64: 89.59951
logˆp(x) = IWAE-5000: 90.65003
−KL(Q||P): -1.0505219
---------------
Train Epoch: 17 [0/60000 (0%)] Loss: 93.840065
Train Epoch: 17 [10000/60000 (17%)] Loss: 86.847694
Train Epoch: 17 [20000/60000 (33%)] Loss: 98.986687
Train Epoch: 17 [30000/60000 (50%)] Loss: 98.521729
Train Epoch: 17 [40000/60000 (67%)] Loss: 99.243057
Train Epoch: 17 [50000/60000 (83%)] Loss: 91.025291
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 17
IWAE-64: 90.11647
logˆp(x) = IWAE-5000: 90.590324
−KL(Q||P): -0.47385406
---------------
Train Epoch: 18 [0/60000 (0%)] Loss: 94.464935
Train Epoch: 18 [10000/60000 (17%)] Loss: 99.852882
Train Epoch: 18 [20000/60000 (33%)] Loss: 91.386147
Train Epoch: 18 [30000/60000 (50%)] Loss: 90.344818
Train Epoch: 18 [40000/60000 (67%)] Loss: 92.691124
Train Epoch: 18 [50000/60000 (83%)] Loss: 97.712929
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 18
IWAE-64: 90.838585
logˆp(x) = IWAE-5000: 90.65541
−KL(Q||P): 0.18317413
---------------
Train Epoch: 19 [0/60000 (0%)] Loss: 99.118088
Train Epoch: 19 [10000/60000 (17%)] Loss: 105.104935
Train Epoch: 19 [20000/60000 (33%)] Loss: 94.164665
Train Epoch: 19 [30000/60000 (50%)] Loss: 100.436256
Train Epoch: 19 [40000/60000 (67%)] Loss: 90.244896
Train Epoch: 19 [50000/60000 (83%)] Loss: 86.268738
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 19
IWAE-64: 89.80083
logˆp(x) = IWAE-5000: 90.411835
−KL(Q||P): -0.6110077
---------------
Train Epoch: 20 [0/60000 (0%)] Loss: 105.900833
Train Epoch: 20 [10000/60000 (17%)] Loss: 85.296181
Train Epoch: 20 [20000/60000 (33%)] Loss: 102.006134
Train Epoch: 20 [30000/60000 (50%)] Loss: 91.458534
Train Epoch: 20 [40000/60000 (67%)] Loss: 98.606804
Train Epoch: 20 [50000/60000 (83%)] Loss: 92.486732
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 20
IWAE-64: 89.921814
logˆp(x) = IWAE-5000: 90.39562
−KL(Q||P): -0.4738083
---------------
Train Epoch: 21 [0/60000 (0%)] Loss: 86.889793
Train Epoch: 21 [10000/60000 (17%)] Loss: 93.808105
Train Epoch: 21 [20000/60000 (33%)] Loss: 85.814552
Train Epoch: 21 [30000/60000 (50%)] Loss: 97.433723
Train Epoch: 21 [40000/60000 (67%)] Loss: 92.292229
Train Epoch: 21 [50000/60000 (83%)] Loss: 84.512245
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 21
IWAE-64: 90.70811
logˆp(x) = IWAE-5000: 90.39505
−KL(Q||P): 0.31305695
---------------
Train Epoch: 22 [0/60000 (0%)] Loss: 97.888206
Train Epoch: 22 [10000/60000 (17%)] Loss: 95.112480
Train Epoch: 22 [20000/60000 (33%)] Loss: 96.822960
Train Epoch: 22 [30000/60000 (50%)] Loss: 105.579887
Train Epoch: 22 [40000/60000 (67%)] Loss: 88.926628
Train Epoch: 22 [50000/60000 (83%)] Loss: 83.429054
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 22
IWAE-64: 90.45231
logˆp(x) = IWAE-5000: 90.28255
−KL(Q||P): 0.16976166
---------------
Train Epoch: 23 [0/60000 (0%)] Loss: 89.223228
Train Epoch: 23 [10000/60000 (17%)] Loss: 93.890137
Train Epoch: 23 [20000/60000 (33%)] Loss: 93.568741
Train Epoch: 23 [30000/60000 (50%)] Loss: 88.926697
Train Epoch: 23 [40000/60000 (67%)] Loss: 92.509758
Train Epoch: 23 [50000/60000 (83%)] Loss: 99.818192
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 23
IWAE-64: 89.77586
logˆp(x) = IWAE-5000: 90.075615
−KL(Q||P): -0.29975128
---------------
Train Epoch: 24 [0/60000 (0%)] Loss: 82.467995
Train Epoch: 24 [10000/60000 (17%)] Loss: 95.007713
Train Epoch: 24 [20000/60000 (33%)] Loss: 102.897850
Train Epoch: 24 [30000/60000 (50%)] Loss: 103.482498
Train Epoch: 24 [40000/60000 (67%)] Loss: 94.505943
Train Epoch: 24 [50000/60000 (83%)] Loss: 97.068161
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 24
IWAE-64: 90.61882
logˆp(x) = IWAE-5000: 90.10386
−KL(Q||P): 0.51496124
---------------
Train Epoch: 25 [0/60000 (0%)] Loss: 88.563004
Train Epoch: 25 [10000/60000 (17%)] Loss: 96.062202
Train Epoch: 25 [20000/60000 (33%)] Loss: 91.589104
Train Epoch: 25 [30000/60000 (50%)] Loss: 100.115807
Train Epoch: 25 [40000/60000 (67%)] Loss: 97.718956
Train Epoch: 25 [50000/60000 (83%)] Loss: 92.590294
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 25
IWAE-64: 90.77629
logˆp(x) = IWAE-5000: 90.21303
−KL(Q||P): 0.56326294
---------------
Train Epoch: 26 [0/60000 (0%)] Loss: 90.206627
Train Epoch: 26 [10000/60000 (17%)] Loss: 95.104202
Train Epoch: 26 [20000/60000 (33%)] Loss: 99.151428
Train Epoch: 26 [30000/60000 (50%)] Loss: 93.590454
Train Epoch: 26 [40000/60000 (67%)] Loss: 92.422302
Train Epoch: 26 [50000/60000 (83%)] Loss: 103.758888
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 26
IWAE-64: 90.73986
logˆp(x) = IWAE-5000: 90.08644
−KL(Q||P): 0.6534195
---------------
Train Epoch: 27 [0/60000 (0%)] Loss: 98.630524
Train Epoch: 27 [10000/60000 (17%)] Loss: 84.656273
Train Epoch: 27 [20000/60000 (33%)] Loss: 102.395241
Train Epoch: 27 [30000/60000 (50%)] Loss: 103.834000
Train Epoch: 27 [40000/60000 (67%)] Loss: 86.922234
Train Epoch: 27 [50000/60000 (83%)] Loss: 111.384987
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 27
IWAE-64: 89.31232
logˆp(x) = IWAE-5000: 89.94471
−KL(Q||P): -0.6323929
---------------
Train Epoch: 28 [0/60000 (0%)] Loss: 90.933304
Train Epoch: 28 [10000/60000 (17%)] Loss: 99.818108
Train Epoch: 28 [20000/60000 (33%)] Loss: 87.769615
Train Epoch: 28 [30000/60000 (50%)] Loss: 94.958702
Train Epoch: 28 [40000/60000 (67%)] Loss: 93.918137
Train Epoch: 28 [50000/60000 (83%)] Loss: 99.295448
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 28
IWAE-64: 89.74451
logˆp(x) = IWAE-5000: 90.16274
−KL(Q||P): -0.41823578
---------------
Train Epoch: 29 [0/60000 (0%)] Loss: 88.518005
Train Epoch: 29 [10000/60000 (17%)] Loss: 92.207855
Train Epoch: 29 [20000/60000 (33%)] Loss: 100.995888
Train Epoch: 29 [30000/60000 (50%)] Loss: 81.873978
Train Epoch: 29 [40000/60000 (67%)] Loss: 106.279015
Train Epoch: 29 [50000/60000 (83%)] Loss: 98.464409
====>Test metrics: IWAE M= 8 ,k= 8 || epoch 29
IWAE-64: 88.7935
logˆp(x) = IWAE-5000: 90.025635
−KL(Q||P): -1.232132
---------------