import chainer
import chainer.functions as F
import chainer.links as L


class ConvolutionalAutoEncoder(chainer.Chain):

    def __init__(self):
        super(ConvolutionalAutoEncoder, self).__init__(
            c1=L.Convolution2D(...),
            c2=L.Convolution2D(...),
            dc1=L.Deconvolution2D(...),
            dc2=L.Deconvolution2D(...),
        )

    def convolve(self, x):
        # Optionally insert F.max_pooling_2d between the convolutions.
        return self.c2(self.c1(x))

    def deconvolve(self, h):
        return self.dc2(self.dc1(h))

    def __call__(self, x):
        h = self.convolve(x)
        x_hat = self.deconvolve(h)
        # Instead of wrapping the model in L.Classifier, we compute the loss
        # value inside the autoencoder itself.
        # If we want to get reconstructed images, we need to implement a
        # method like this:
        #
        #     def reconstruct(self, x):
        #         return self.deconvolve(self.convolve(x))
        #
        # Because we compare them with mean_squared_error, x and x_hat must
        # have the same shape.
        loss = F.mean_squared_error(x, x_hat)
        return loss
model = ConvolutionalAutoEncoder()
# Create the optimizer, updater, and trainer as usual, for example as sketched below.
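# A minimal sketch of that setup, assuming a hypothetical `train_dataset` of
# unlabeled images; any optimizer and stop trigger would do just as well.
train_iter = chainer.iterators.SerialIterator(train_dataset, batch_size=32)
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)
updater = chainer.training.StandardUpdater(train_iter, optimizer)
trainer = chainer.training.Trainer(updater, (20, 'epoch'), out='result')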
trainer.run()
# Convolve and deconvolve test images.
images = get_test_batch()
# convolved_images is an instance of chainer.Variable; the raw array is
# available through its .data attribute.
convolved_images = model.convolve(images)
deconvolved_images = model.deconvolve(convolved_images)
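# For example (a sketch, assuming get_test_batch() returns a NumPy array):
# pull the raw array out of the Variable and check that the reconstruction
# has the same shape as the input batch, as the MSE loss above requires.
reconstructed = deconvolved_images.data
assert reconstructed.shape == images.shape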