@colspan, created December 10, 2017 02:24
Chainer Implementation of Convolutional Variational AutoEncoder
#!/usr/bin/env python
"""
Copyright Kunihiko Miyoshi, Preferred Networks Inc.
MIT License
Reference
https://github.com/pfnet/chainer/blob/master/examples/vae/net.py
https://github.com/crcrpar/chainer-VAE
https://github.com/mlosch/CVAE-Torch/blob/master/CVAE.lua
"""
import six
import chainer
import chainer.functions as F
from chainer.functions.loss.vae import gaussian_kl_divergence
import chainer.links as L
class CVAE(chainer.Chain):
"""Convolutional Variational AutoEncoder"""
    def __init__(self, n_ch, n_latent, n_first, C=1.0, k=1, wscale=0.02):
        """
        Args:
            n_ch (int): Number of channels of the input images.
            n_latent (int): Dimensionality of the latent vector z.
            n_first (int): Number of filters in the first convolution layer.
            C (float): Usually 1.0. Weight of the KL term (the second term
                of the ELBO), which works as regularization; used by the loss.
            k (int): Number of Monte Carlo samples used for the
                reconstruction term of the loss.
            wscale (float): Standard deviation of the Normal weight
                initializer.
        """
w = chainer.initializers.Normal(scale=wscale)
super(CVAE, self).__init__(
e_c0=L.Convolution2D(n_ch, n_first, 4, 2, 1, initialW=w),
e_c1=L.Convolution2D(n_first, n_first * 2, 4, 2, 1, initialW=w),
e_c2=L.Convolution2D(
n_first * 2, n_first * 4, 4, 2, 1, initialW=w),
e_c3=L.Convolution2D(
n_first * 4, n_first * 8, 4, 2, 1, initialW=w),
e_bn1=L.BatchNormalization(n_first * 2, use_gamma=False),
e_bn2=L.BatchNormalization(n_first * 4, use_gamma=False),
e_bn3=L.BatchNormalization(n_first * 8, use_gamma=False),
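            # after four stride-2 convolutions a 32x32 input is reduced to a
            # 2x2 feature map (32 -> 16 -> 8 -> 4 -> 2), so the flattened
            # feature size below is n_first * 8 * 2 * 2 = n_first * 8 * 4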
e_mu=L.Linear(n_first * 8 * 4, n_latent),
e_ln_var=L.Linear(n_first * 8 * 4, n_latent),
d_l0=L.Linear(n_latent, n_first * 8 * 4),
d_dc0=L.Deconvolution2D(
n_first * 8, n_first * 4, 4, 2, 1, initialW=w),
d_dc1=L.Deconvolution2D(
n_first * 4, n_first * 2, 4, 2, 1, initialW=w),
d_dc2=L.Deconvolution2D(n_first * 2, n_first, 4, 2, 1, initialW=w),
            # reconstruct n_ch channels so the output shape matches the input
            d_dc3=L.Deconvolution2D(n_first, n_ch, 4, 2, 1, initialW=w),
d_bn0=L.BatchNormalization(n_first * 4, use_gamma=False),
d_bn1=L.BatchNormalization(n_first * 2, use_gamma=False),
d_bn2=L.BatchNormalization(n_first, use_gamma=False),
            d_bn3=L.BatchNormalization(n_ch, use_gamma=False)
)
self.n_first = n_first
self.C = C
self.k = k
def __call__(self, x):
"""AutoEncoder"""
mu, ln_var = self.encode(x)
batchsize = len(mu.data)
# reconstruction loss
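        # (F.bernoulli_nll applies the sigmoid internally, so the decoder
        # output is treated as unnormalized logits)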
rec_loss = 0
for l in six.moves.range(self.k):
z = F.gaussian(mu, ln_var)
rec_loss += F.bernoulli_nll(x, self.decode(z)) \
/ (self.k * batchsize)
loss = rec_loss + \
self.C * gaussian_kl_divergence(mu, ln_var) / batchsize
chainer.report({'loss': loss}, self)
return loss
def encode(self, x):
# print "=== encoder ==="
# print x.shape
h = F.leaky_relu(self.e_c0(x), slope=0.2)
# print h.shape
h = F.leaky_relu(self.e_bn1(self.e_c1(h)), slope=0.2)
# print h.shape
h = F.leaky_relu(self.e_bn2(self.e_c2(h)), slope=0.2)
# print h.shape
h = F.leaky_relu(self.e_bn3(self.e_c3(h)), slope=0.2)
# print h.shape
h = F.reshape(h, (-1, self.n_first * 8 * 4))
# print h.shape
mu = self.e_mu(h)
ln_var = self.e_ln_var(h)
# print mu.shape
return mu, ln_var
def decode(self, z):
# print "=== decoder ==="
# print z.shape
h = F.relu(self.d_l0(z))
# print h.shape
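        # map the latent vector back onto a (n_first * 8, 2, 2) feature map;
        # the four stride-2 deconvolutions then upsample it to 32x32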
h = F.reshape(h, (-1, self.n_first * 8, 2, 2))
# print h.shape
h = F.relu(self.d_bn0(self.d_dc0(h)))
# print h.shape
h = F.relu(self.d_bn1(self.d_dc1(h)))
# print h.shape
h = F.relu(self.d_bn2(self.d_dc2(h)))
# print h.shape
h = F.relu(self.d_bn3(self.d_dc3(h)))
# print h.shape
return h
    def get_loss_func(self, C=1.0, k=1, train=True):
        """Get loss function of VAE.

        The loss value is equal to ELBO (Evidence Lower Bound)
        multiplied by -1.

        Args:
            C (float): Usually 1.0. Weight of the KL term (the second term
                of the ELBO), which works as regularization.
            k (int): Number of Monte Carlo samples used for the
                reconstruction term.
            train (bool): If true, the loss function is used for training.
        """
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # reconstruction loss
            rec_loss = 0
            for l in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                rec_loss += F.bernoulli_nll(x, self.decode(z)) \
                    / (k * batchsize)
            self.rec_loss = rec_loss
            self.loss = self.rec_loss + \
                C * gaussian_kl_divergence(mu, ln_var) / batchsize
            return self.loss
        return lf
class VAE(chainer.Chain):
"""Variational AutoEncoder"""
    def __init__(self, n_in, n_latent, n_h, C=1.0, k=1):
        """
        Args:
            n_in (int): Dimensionality of the (flattened) input.
            n_latent (int): Dimensionality of the latent vector z.
            n_h (int): Number of units in the hidden layers.
            C (float): Usually 1.0. Weight of the KL term (the second term
                of the ELBO), which works as regularization; used by the loss.
            k (int): Number of Monte Carlo samples used for the
                reconstruction term of the loss.
        """
super(VAE, self).__init__(
# encoder
le1=L.Linear(n_in, n_h),
le2_mu=L.Linear(n_h, n_latent),
le2_ln_var=L.Linear(n_h, n_latent),
# decoder
ld1=L.Linear(n_latent, n_h),
ld2=L.Linear(n_h, n_in),
)
self.C = C
self.k = k
def __call__(self, x, sigmoid=True):
"""AutoEncoder"""
mu, ln_var = self.encode(x)
batchsize = len(mu.data)
# reconstruction loss
rec_loss = 0
for l in six.moves.range(self.k):
z = F.gaussian(mu, ln_var)
rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) \
/ (self.k * batchsize)
loss = rec_loss + \
self.C * gaussian_kl_divergence(mu, ln_var) / batchsize
chainer.report({'loss': loss}, self)
return loss
def encode(self, x):
h1 = F.tanh(self.le1(x))
mu = self.le2_mu(h1)
ln_var = self.le2_ln_var(h1) # log(sigma**2)
return mu, ln_var
def decode(self, z, sigmoid=True):
h1 = F.tanh(self.ld1(z))
h2 = self.ld2(h1)
if sigmoid:
return F.sigmoid(h2)
else:
return h2
    def get_loss_func(self, C=1.0, k=1, train=True):
        """Get loss function of VAE.

        The loss value is equal to ELBO (Evidence Lower Bound)
        multiplied by -1.

        Args:
            C (float): Usually 1.0. Weight of the KL term (the second term
                of the ELBO), which works as regularization.
            k (int): Number of Monte Carlo samples used for the
                reconstruction term.
            train (bool): If true, the loss function is used for training.
        """
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # reconstruction loss
            rec_loss = 0
            for l in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) \
                    / (k * batchsize)
            self.rec_loss = rec_loss
            self.loss = self.rec_loss + \
                C * gaussian_kl_divergence(mu, ln_var) / batchsize
            return self.loss
        return lf
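For reference, here is a minimal training sketch for the fully connected VAE above on flattened 28x28 MNIST digits (784 inputs). The hyperparameters, dataset choice, and loop structure are illustrative assumptions, not part of the gist; the CVAE would instead need n_ch x 32 x 32 image batches, as noted in its constructor.

import numpy as np
import chainer
from chainer import optimizers

model = VAE(n_in=784, n_latent=20, n_h=500, C=1.0, k=1)  # illustrative sizes
optimizer = optimizers.Adam()
optimizer.setup(model)

# flattened MNIST images scaled to [0, 1], without labels
train, _ = chainer.datasets.get_mnist(withlabel=False)
train_iter = chainer.iterators.SerialIterator(train, batch_size=100)

for epoch in range(10):
    for _ in range(len(train) // 100):
        x = np.asarray(train_iter.next())  # (100, 784) float32 batch
        loss = model(x)                    # __call__ returns the negative ELBO
        model.cleargrads()
        loss.backward()
        optimizer.update()
    print('epoch {}: loss {:.2f}'.format(epoch, float(loss.data)))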