Created
June 2, 2021 16:21
-
-
Save Alexander-Barth/e9d0b19e274b56baa1913dfb6098e23a to your computer and use it in GitHub Desktop.
convolutional variational autoencoder in Flux.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# adapted from | |
# Keras_code_sample_for_Google_IO_2021 | |
# Modern Keras design patterns | Session | |
# https://www.youtube.com/watch?v=FCz9m4T0DI0 | |
# code: https://goo.gle/3t0rSpo | |
# by Martin Gorner, François Chollet | |
using Flux | |
using MLDatasets | |
using Random | |
using Statistics | |
using Test | |
using ProgressMeter: Progress, next! | |
using PyPlot | |
using BSON | |
# cpu or gpu
# Target compute device applied to models and batches via `device(...)`.
# Switch to `cpu` to train without a GPU. NOTE(review): assumes Flux's
# `gpu` degrades gracefully to the CPU when CUDA is unavailable — confirm.
const device = gpu
# Identity pass-through inserted between `Chain` layers so the size of the
# intermediate activations can be printed while debugging the architecture.
@inline function showsz(activation)
    # Uncomment to print the activation size as data flows through:
    # @show size(activation)
    return activation
end
# Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)) as a
# differentiable function of the encoder outputs plus unit-Gaussian noise,
# so gradients can flow through the sampling step.
function sample(z_mean, z_log_var)
    noise = device(randn(Float32, size(z_mean)))
    sigma = exp.(z_log_var ./ 2)
    return z_mean .+ sigma .* noise
end
# note:
# Keras_code_sample_for_Google_IO_2021 trains over training and test data
# i.e. 70 000 samples
# Concatenate MNIST train (60k) and test (10k) images along the sample axis.
x = cat(
    MNIST.traindata(Float32)[1],
    MNIST.testdata(Float32)[1],
    dims = 3)
# side length of one digit image (28 for MNIST)
digit_size = size(x,1)
# Insert a singleton channel dimension: (width, height, channels=1, samples),
# the WHCN layout Flux's Conv layers expect.
x_train = reshape(x,(size(x,1),size(x,2),1,size(x,3)))
latent_dim = 2   # dimension of the VAE latent space
batchsize = 128  # mini-batch size for the DataLoader below
# encoder
# Maps a (28,28,1,N) image batch to a (2*latent_dim, N) matrix: rows 1:2
# hold z_mean, rows 3:4 hold z_log_var (sliced apart in model_loss).
encoder = Chain(
    showsz,
    # 28x28x1 -> 14x14x32 (stride-2 conv halves the spatial dims)
    Conv((3,3),1 => 32,relu,stride = 2, pad = 1),
    showsz,
    # 14x14x32 -> 7x7x64
    Conv((3,3),32 => 64,relu,stride = 2, pad = 1),
    showsz,
    # 7*7*64 = 3136 features per sample
    flatten,
    showsz,
    Dense(3136,16,relu),
    showsz,
    # identity as activation function
    # 2*2 = 2*latent_dim outputs: latent means and log-variances
    Dense(16,2*2)) |> device
#= | |
input = device(x_train[:,:,:,1:batchsize]) | |
x = encoder(device(input)) | |
z_mean = x[1:2,:] | |
z_log_var = x[3:4,:] | |
z = sample(z_mean, z_log_var) | |
=# | |
# decoder
# Maps a (latent_dim, N) batch of latent vectors back to a (28,28,1,N)
# image batch with pixel values in (0,1) via the final sigmoid.
decoder = Chain(
    showsz,
    Dense(latent_dim,7 * 7 * 64,relu),
    showsz,
    # un-flatten to a 7x7x64 feature map per sample
    x -> reshape(x,(7,7,64,:)),
    showsz,
    # note SamePad() is not possible here
    #ConvTranspose((3,3),64 => 64,relu,stride=2, pad = SamePad()),
    ConvTranspose((3,3),64 => 64,relu,stride=2, pad = 0),
    # stride-2 transposed conv yields 15x15 ((7-1)*2+3); crop one row and
    # one column to emulate "same" padding: 7 -> 14
    x -> x[1:end-1,1:end-1,:,:],
    showsz,
    #ConvTranspose((3,3),64 => 32,relu,stride=2, pad = SamePad()),
    ConvTranspose((3,3),64 => 32,relu,stride=2, pad = 0),
    # same cropping trick: 29x29 -> 28x28
    x -> x[1:end-1,1:end-1,:,:],
    showsz,
    # final 1-channel sigmoid conv keeps 28x28 (pad = 1)
    ConvTranspose((3,3),32 => 1,sigmoid, pad = 1)) |> device
#= | |
output = decoder(z); | |
@show size(output) | |
@show maximum(output) | |
@show maximum(input) | |
=# | |
# KL divergence between the encoder posterior N(z_mean, exp(z_log_var)) and
# the standard-normal prior N(0, I): elementwise
# -0.5 * (1 + log σ² - μ² - σ²), summed over the latent dimension (rows)
# and averaged over the batch (columns).
function kl_divergence(z_mean, z_log_var)
    per_element = @. -(1 + z_log_var - z_mean^2 - exp(z_log_var)) / 2
    per_sample = sum(per_element, dims = 1)
    return mean(per_sample)
end
#=
# python test code
z_mean = tf.Variable(np.array([[1,2],[2,5],[2,5]],dtype = "float"))
z_log_var = tf.Variable(np.array([[1,2],[2,5],[-2,5]], dtype = "float"))
kl_divergence(z_mean, z_log_var)
=#
# Sanity checks, only evaluated when debug-level logging is enabled
# (e.g. run with JULIA_DEBUG=Main); reference values were computed with
# the original Keras sample.
@debug begin
    z_mean = hcat([1,2],[2,5],[2,5.f0])
    z_log_var = hcat([1,2],[2,5],[-2,5.f0])
    # test value from Keras_code_sample_for_Google_IO_2021
    @test kl_divergence(z_mean, z_log_var) ≈ 59.743007919118355
    o = Float32.(reshape(1:120,(3,4,1,10)))/200
    i = Float32.(reshape(2:121,(3,4,1,10)))/200
    @test Flux.Losses.binarycrossentropy(o,i; agg = sum) / size(o,4) ≈ 6.471607f0
end
# Full VAE loss for one input batch: binary cross-entropy reconstruction
# loss (summed over pixels, averaged over the batch) plus the KL
# regularizer. Relies on the file-level globals `encoder`, `decoder`,
# `latent_dim` and the `sample`/`kl_divergence` helpers.
function model_loss(input)
    x = encoder(input)
    # The encoder emits 2*latent_dim rows: means first, then log-variances.
    # Derived from `latent_dim` (was hard-coded 1:2 / 3:4) so this stays
    # correct if the latent size is changed in one place.
    z_mean = x[1:latent_dim,:]
    z_log_var = x[latent_dim+1:2*latent_dim,:]
    z = sample(z_mean, z_log_var)
    output = decoder(z)
    batchsize = size(output)[end]
    kl_loss = kl_divergence(z_mean, z_log_var)
    # summed BCE normalized by batch size, matching the Keras reference
    reconstruction_loss_val = Flux.Losses.binarycrossentropy(output,input; agg = sum)/batchsize
    return reconstruction_loss_val + kl_loss
end
# Mini-batches of shape (28,28,1,batchsize); drawn in order (no shuffling).
data = Flux.Data.DataLoader(x_train, batchsize=batchsize)
input = first(data)
@show size(input)
# smoke-test
# one full forward pass (encoder -> sample -> decoder -> loss) before training
@show model_loss(device(input))
@info "number of parameter in encoder " sum(length.(params(encoder)))
@info "number of parameter in decoder " sum(length.(params(decoder)))
# all parameters to optimize
parameters = params(encoder,decoder)
opt = ADAM()
# Train for 30 epochs, timing each one. Gradients are computed with the
# implicit-parameters pullback API and applied in place with ADAM.
for epoch = 1:30
    #progress = Progress(length(data))
    @time for (step, batch) in enumerate(data)
        batch_loss, pullback_fn = Flux.pullback(parameters) do
            model_loss(device(batch))
        end
        gradients = pullback_fn(1f0)
        Flux.Optimise.update!(opt, parameters, gradients)
        #next!(progress; showvalues=[(:loss, batch_loss)])
    end
end
# plot latent space
# Decode an n x n grid of latent points in [-scale, scale]^2 and tile the
# decoded digits into a single large mosaic image.
scale = 1
n = 30
grid_x = LinRange(-scale,scale,n)
grid_y = LinRange(-scale,scale,n)
# 4-D layout (pixel_row, grid_col, pixel_col, grid_row) chosen so the
# final reshape below yields one (digit_size*n) x (digit_size*n) image
fig = zeros(Float32,digit_size,n,digit_size,n)
z_sample = zeros(Float32,latent_dim,1)
for (j,y) in enumerate(grid_y)
    for (i,x) in enumerate(grid_x)
        z_sample[:,1] = [x, y]
        # decode on the device, bring the result back to the CPU for plotting
        x_decoded = cpu(decoder(device(z_sample)))
        digit = reshape(x_decoded,(digit_size,digit_size))
        fig[:,i,:,j] = digit
    end
end
fig = reshape(fig,(digit_size*n,digit_size*n))
clf()
PyPlot.imshow(fig', cmap="Greys_r")
savefig("latent_space.png")
# plot label clusters
# Encode the full training set and scatter the latent means, colored by
# digit label, to visualize class separation in the 2-D latent space.
x,y_train = MNIST.traindata(Float32)
x_train = reshape(x,(size(x,1),size(x,2),1,size(x,3)))
output = cpu(encoder(device(x_train)));
# the first two rows of the encoder output are the latent means
z_mean = output[1:2,:]
clf()
scatter(z_mean[1,:],z_mean[2,:],c=y_train,cmap="jet")
colorbar()
xlabel("z[0]")
ylabel("z[1]")
savefig("label_clusters.png")
# save model
save_path = "."
model_path = joinpath(save_path, "model.bson")
# Move the weights to the CPU before serializing so the BSON file can be
# loaded on machines without a GPU. The `let` shadows the trained models
# with CPU copies; BSON.@save stores them under the variable names
# `encoder` and `decoder`, so those names must not be changed here.
let encoder = cpu(encoder), decoder = cpu(decoder)
    BSON.@save model_path encoder decoder
    @info "Model saved: $(model_path)"
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment