Skip to content

Instantly share code, notes, and snippets.

@delta2323
Last active October 17, 2017 18:50
Show Gist options
  • Save delta2323/5ee462335e5a8676584a8424c8effb27 to your computer and use it in GitHub Desktop.
Save delta2323/5ee462335e5a8676584a8424c8effb27 to your computer and use it in GitHub Desktop.
Chainer example of a tied-weight autoencoder (an autoencoder whose encoder and decoder share one weight matrix)
import chainer
import chainer.functions as F
from chainer import initializers as I
from chainer import reporter
from chainer import training
from chainer.training import extensions as E
import numpy
import scipy.misc
class TiedWeightAutoEncoder(chainer.Chain):
    """Autoencoder with one hidden layer and a tied (shared) weight matrix.

    A single parameter ``W`` of shape ``(n_hidden, n_in)`` is used by both
    halves of the network: the encoder applies ``W`` and the decoder applies
    its transpose.  Each half has its own bias vector.

    Args:
        n_in (int): Dimensionality of the input / reconstruction.
        n_hidden (int): Dimensionality of the hidden code.
    """

    def __init__(self, n_in, n_hidden):
        super(TiedWeightAutoEncoder, self).__init__()
        # NOTE(review): ``add_param`` is the Chainer v1 registration API and
        # is reported as deprecated in later releases (``init_scope`` +
        # ``chainer.Parameter``); kept here so the example runs unchanged.
        # W uses the (out, in) layout F.linear expects on the encoder side.
        self.add_param('W', (n_hidden, n_in), numpy.float32, I.HeNormal())
        self.add_param('b_enc', (n_hidden,), numpy.float32, I.Constant(0))
        self.add_param('b_dec', (n_in,), numpy.float32, I.Constant(0))

    def encode(self, x):
        """Map a batch of inputs to hidden codes using ``W`` and ``b_enc``."""
        return F.linear(x, self.W, self.b_enc)

    def decode(self, h):
        """Map hidden codes back to input space using ``W^T`` and ``b_dec``."""
        # Transposing inside the graph keeps a single parameter, so gradients
        # from the decoder path are accumulated into the same ``W``.
        return F.linear(h, F.transpose(self.W), self.b_dec)

    def reconstruct(self, x):
        """Encode ``x`` and immediately decode it again."""
        h = self.encode(x)
        return self.decode(h)

    def __call__(self, x, t=None):
        """Return the mean squared reconstruction error for ``x``.

        Args:
            x: Batch of input vectors.
            t: Ignored.  Accepted (and now optional) so the model can be fed
                batches from labelled ``(x, t)`` datasets as well as called
                with the input alone.

        Returns:
            The loss variable; it is also reported under the key ``'loss'``
            with this link as the observer.
        """
        x_rec = self.reconstruct(x)
        loss = F.mean_squared_error(x, x_rec)
        reporter.report({'loss': loss}, self)
        return loss
# Debug mode enables Chainer's extra runtime checks during the run.
chainer.set_debug(True)

# 784 input units (flattened MNIST image) compressed to a 100-unit code.
model = TiedWeightAutoEncoder(784, 100)

# Device id: a negative value keeps everything on the CPU.
gpu = -1
if gpu >= 0:
    chainer.cuda.get_device(gpu).use()
    model.to_gpu()

opt = chainer.optimizers.Adam()
opt.setup(model)

batchsize = 128
train, test = chainer.datasets.get_mnist()
train_iter = chainer.iterators.SerialIterator(train, batchsize)
# Single ordered pass over the test set for evaluation.
test_iter = chainer.iterators.SerialIterator(
    test, batchsize, repeat=False, shuffle=False)

epoch = 5
stop_trigger = (epoch, 'epoch')
updater = training.StandardUpdater(train_iter, opt, device=gpu)
trainer = training.Trainer(updater, stop_trigger)

# Columns printed for every logged entry.
report_keys = [
    'epoch',
    'main/loss',
    'validation/main/loss',
    'elapsed_time',
]
trainer.extend(E.Evaluator(test_iter, model, device=gpu))
trainer.extend(E.LogReport())
trainer.extend(E.PrintReport(report_keys))

trainer.run()
# Reconstruct
# The sample below is a host (NumPy) array and the images are written from
# host memory, so bring the parameters back to the CPU if training used a GPU.
# Without this, setting ``gpu >= 0`` above makes this section fail.
if gpu >= 0:
    model.to_cpu()

# First training image; ``[None]`` prepends the batch axis the model expects.
x = train[0][0][None]
x_rec = model.reconstruct(x).data[0]

# NOTE(review): ``scipy.misc.imsave`` was deprecated and later removed from
# SciPy; on a recent SciPy these calls need an image writer such as
# ``imageio.imwrite`` instead. Left as-is to avoid adding a dependency.
scipy.misc.imsave('orig.png', x.reshape(28, 28))
scipy.misc.imsave('rec.png', x_rec.reshape(28, 28))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment