convolutional variational autoencoder in Flux.jl
# adapted from
# Keras_code_sample_for_Google_IO_2021
# Modern Keras design patterns | Session
# https://www.youtube.com/watch?v=FCz9m4T0DI0
# code: https://goo.gle/3t0rSpo
# by Martin Gorner, François Chollet
using Flux
using MLDatasets
using Random
using Statistics
using Test
using ProgressMeter: Progress, next!
using PyPlot
using BSON
# cpu or gpu
const device = gpu
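# note: Flux's `gpu` should fall back to the identity (with an informational
# message) when no functional CUDA device is available, so the script can also
# be run with `device = gpu` on a CPU-only machine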
# helper function to show the output size of intermediate layers
@inline function showsz(x)
    #@show size(x)
    return x
end
# reparameterization trick: z = μ + σ .* ε with ε ~ N(0,I), so sampling
# remains differentiable with respect to z_mean and z_log_var
function sample(z_mean, z_log_var)
    epsilon = device(randn(Float32, size(z_mean)))
    return z_mean + exp.(z_log_var / 2) .* epsilon
end
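# A minimal usage sketch (illustrative values; the inputs are moved to `device`
# so that they match the noise drawn inside `sample`):
#=
z_mean_demo    = device(zeros(Float32, 2, 3))   # 2 latent dimensions, batch of 3
z_log_var_demo = device(zeros(Float32, 2, 3))   # log-variance 0, i.e. σ = 1
z_demo = sample(z_mean_demo, z_log_var_demo)    # one draw per column
@show size(z_demo)                              # (2, 3)
=#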
# note:
# Keras_code_sample_for_Google_IO_2021 trains over training and test data
# i.e. 70 000 samples
x = cat(MNIST.traindata(Float32)[1],
        MNIST.testdata(Float32)[1],
        dims = 3)
digit_size = size(x,1)
x_train = reshape(x,(size(x,1),size(x,2),1,size(x,3)))
latent_dim = 2
batchsize = 128
# encoder
encoder = Chain(
    showsz,
    Conv((3,3), 1 => 32, relu, stride = 2, pad = 1),
    showsz,
    Conv((3,3), 32 => 64, relu, stride = 2, pad = 1),
    showsz,
    flatten,
    showsz,
    Dense(3136, 16, relu),
    showsz,
    # no activation (identity); the 2*latent_dim outputs are
    # z_mean (rows 1:2) and z_log_var (rows 3:4)
    Dense(16, 2*2)) |> device
#=
input = device(x_train[:,:,:,1:batchsize])
x = encoder(device(input))
z_mean = x[1:2,:]
z_log_var = x[3:4,:]
z = sample(z_mean, z_log_var)
=#
# decoder
decoder = Chain(
    showsz,
    Dense(latent_dim, 7 * 7 * 64, relu),
    showsz,
    x -> reshape(x, (7, 7, 64, :)),
    showsz,
    # note SamePad() is not possible here
    #ConvTranspose((3,3),64 => 64,relu,stride=2, pad = SamePad()),
    ConvTranspose((3,3), 64 => 64, relu, stride = 2, pad = 0),
    x -> x[1:end-1, 1:end-1, :, :],
    showsz,
    #ConvTranspose((3,3),64 => 32,relu,stride=2, pad = SamePad()),
    ConvTranspose((3,3), 64 => 32, relu, stride = 2, pad = 0),
    x -> x[1:end-1, 1:end-1, :, :],
    showsz,
    ConvTranspose((3,3), 32 => 1, sigmoid, pad = 1)) |> device
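# shape note (batch size B): Dense + reshape give (7, 7, 64, B); each pad = 0,
# stride = 2 ConvTranspose enlarges 7 → 15 and 14 → 29, and the slicing layers
# crop off the extra row/column to reach 14 and 28, i.e. the sizes SamePad
# would produce in Keras; the final pad = 1 ConvTranspose keeps 28 × 28, so the
# decoder output matches the 28 × 28 × 1 MNIST input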
#=
output = decoder(z);
@show size(output)
@show maximum(output)
@show maximum(input)
=#
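# KL divergence between the approximate posterior N(z_mean, exp(z_log_var))
# and the standard normal prior, computed per latent dimension as
#   -(1 + z_log_var - z_mean^2 - exp(z_log_var)) / 2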
function kl_divergence(z_mean, z_log_var)
    kl_loss = -(1 .+ z_log_var - z_mean.^2 - exp.(z_log_var)) / 2
    # sum over latent space and
    # average over batch
    return mean(sum(kl_loss, dims=1))
end
#=
# python test code
z_mean = tf.Variable(np.array([[1,2],[2,5],[2,5]],dtype = "float"))
z_log_var = tf.Variable(np.array([[1,2],[2,5],[-2,5]], dtype = "float"))
kl_divergence(z_mean, z_log_var)
=#
@debug begin
    z_mean = hcat([1,2],[2,5],[2,5.f0])
    z_log_var = hcat([1,2],[2,5],[-2,5.f0])
    # test value from Keras_code_sample_for_Google_IO_2021
    @test kl_divergence(z_mean, z_log_var) ≈ 59.743007919118355
    o = Float32.(reshape(1:120,(3,4,1,10)))/200
    i = Float32.(reshape(2:121,(3,4,1,10)))/200
    @test Flux.Losses.binarycrossentropy(o,i; agg = sum) / size(o,4) ≈ 6.471607f0
end
function model_loss(input)
    x = encoder(input)
    z_mean = x[1:2,:]
    z_log_var = x[3:4,:]
    z = sample(z_mean, z_log_var)
    output = decoder(z)
    batchsize = size(output)[end]
    kl_loss = kl_divergence(z_mean, z_log_var)
    # reconstruction term: pixel-wise binary cross-entropy, summed over the
    # image and averaged over the batch
    reconstruction_loss_val = Flux.Losses.binarycrossentropy(output, input; agg = sum) / batchsize
    # total loss = negative ELBO = reconstruction loss + KL divergence
    return reconstruction_loss_val + kl_loss
end
data = Flux.Data.DataLoader(x_train, batchsize=batchsize)
input = first(data)
@show size(input)
# smoke-test
@show model_loss(device(input))
@info "number of parameter in encoder " sum(length.(params(encoder)))
@info "number of parameter in decoder " sum(length.(params(decoder)))
# all parameters to optimize
parameters = params(encoder,decoder)
opt = ADAM()
for e = 1:30
    #progress = Progress(length(data))
    @time for (i,d) in enumerate(data)
        loss, back = Flux.pullback(parameters) do
            model_loss(device(d))
        end
        grads = back(1f0)
        Flux.Optimise.update!(opt, parameters, grads)
        #next!(progress; showvalues=[(:loss, loss)])
    end
end
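# The loop above only reports timing via @time. A minimal sketch for also
# logging the mean loss per epoch, reusing the loss value that Flux.pullback
# already returns (commented out; it would replace the loop above):
#=
for e = 1:30
    epoch_loss = 0f0
    @time for d in data
        loss, back = Flux.pullback(parameters) do
            model_loss(device(d))
        end
        Flux.Optimise.update!(opt, parameters, back(1f0))
        epoch_loss += loss
    end
    @info "epoch $e" mean_loss = epoch_loss / length(data)
end
=#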
# plot latent space: decode an n × n grid of latent points into a mosaic of digits
scale = 1
n = 30
grid_x = LinRange(-scale,scale,n)
grid_y = LinRange(-scale,scale,n)
fig = zeros(Float32,digit_size,n,digit_size,n)
z_sample = zeros(Float32,latent_dim,1)
for (j,y) in enumerate(grid_y)
    for (i,x) in enumerate(grid_x)
        z_sample[:,1] = [x, y]
        x_decoded = cpu(decoder(device(z_sample)))
        digit = reshape(x_decoded,(digit_size,digit_size))
        fig[:,i,:,j] = digit
    end
end
fig = reshape(fig,(digit_size*n,digit_size*n))
clf()
PyPlot.imshow(fig', cmap="Greys_r")
savefig("latent_space.png")
# plot label clusters
x,y_train = MNIST.traindata(Float32)
x_train = reshape(x,(size(x,1),size(x,2),1,size(x,3)))
output = cpu(encoder(device(x_train)));
z_mean = output[1:2,:]
clf()
scatter(z_mean[1,:],z_mean[2,:],c=y_train,cmap="jet")
colorbar()
xlabel("z[0]")
ylabel("z[1]")
savefig("label_clusters.png")
# save model
save_path = "."
model_path = joinpath(save_path, "model.bson")
let encoder = cpu(encoder), decoder = cpu(decoder)
    BSON.@save model_path encoder decoder
    @info "Model saved: $(model_path)"
end
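# A minimal sketch of how the saved model could be reloaded later and used to
# decode a single latent vector into a digit (commented out; it assumes the
# default save_path "." above and the same Flux/BSON versions used for saving):
#=
using Flux, BSON
BSON.@load "model.bson" encoder decoder
z = reshape(Float32[0.5, -0.5], 2, 1)      # an arbitrary latent point, batch of 1
digit = reshape(decoder(z), 28, 28)        # 28×28 image with values in (0,1)
=#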