using Flux
using Flux: Dense, Chain, Dropout, relu, leakyrelu
using Flux: glorot_normal, sparse_init
using Flux.Optimise: update!
using Flux.Losses: logitbinarycrossentropy
using Statistics: mean
using GR
using MLDatasets
using Random
using LinearAlgebra
# Preparing the dataset: MLDatasets' Iris features come as a 4x150 matrix
# (one row per feature, one column per flower).
iris = MLDatasets.Iris.features()

# Min-max scale each feature row into [0, 1], in place. May reuse later.
function scale_0_1!(data)
    for j in 1:size(data, 1)
        max_row, min_row = maximum(data[j, :]), minimum(data[j, :])
        data[j, :] = (data[j, :] .- min_row) ./ (max_row - min_row)
    end # end for-loop
end # End scale_0_1!

scale_0_1!(iris)
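# Quick sanity check (my addition, not in the original gist): after min-max
# scaling, every feature row should span exactly [0, 1].
for j in 1:size(iris, 1)
    @assert extrema(iris[j, :]) == (0.0, 1.0)
end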
# Loss functions: the critic maximizes E[f(x)] - E[f(g(z))], so we minimize
# the negation; the generator minimizes -E[f(g(z))]. No Float64 restriction
# here, since the critic's scores may come back as Float32.
function wasserstein_c_loss(sample, gen)
    return mean(gen) - mean(sample)
end # End wasserstein_c_loss

function wasserstein_g_loss(gen)
    return -mean(gen)
end # End wasserstein_g_loss
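# Tiny worked example (my addition, with hypothetical critic scores): scores
# of 1.0 on real samples and 0.0 on fakes give c_loss = -1.0 (the critic is
# winning) and g_loss = -0.0 (the generator is losing).
@assert wasserstein_c_loss(ones(1, 4), zeros(1, 4)) == -1.0
@assert wasserstein_g_loss(zeros(1, 4)) == -0.0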
# Use gradient normalization to enforce a 1-Lipschitz constraint:
# f_hat = f / (||df/dx|| + |f|).
function gradient_normalization(net_c, x) # Add restriction to net_c later
    f, back = Flux.pullback(net_c, x)
    grads = first(back(one.(f))) # dy/dx; the seed must match f's shape
    grad_norm = √(sum((grads .^ 2) .+ 1e-8))
    f_hat = f ./ (grad_norm .+ abs.(f))
    return f_hat
end # End gradient_normalization
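# Minimal sketch (my addition): for any critic, the normalized score is
# bounded, |f_hat| = |f| / (||grad f|| + |f|) <= 1. Quick check on a toy net:
let toy = Dense(4, 1), x = rand(Float32, 4, 8)
    f_hat = gradient_normalization(toy, x)
    @assert all(abs.(f_hat) .<= 1)
end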
function create_generator(n_features::Int)
    return Chain(
        Dense(n_features, n_features * 2, relu),
        Dropout(0.3),
        Dense(n_features * 2, n_features * 4, relu),
        Dropout(0.3),
        Dense(n_features * 4, n_features, relu),
        Dense(n_features, n_features, relu) # relu output matches the [0, 1]-scaled data being non-negative
    )
end # End create_generator
function create_critic(n_features::Int)
    return Chain(
        Dense(n_features, n_features * 4),
        x -> leakyrelu.(x, 0.2),
        Dense(n_features * 4, n_features * 2),
        x -> leakyrelu.(x, 0.2),
        Dense(n_features * 2, n_features),
        x -> leakyrelu.(x, 0.2),
        Dense(n_features, 1) # linear output: a Wasserstein critic's score must be unbounded, so no relu here
    )
end # End create_critic
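# Smoke test (my addition): for the 4 Iris features, the generator should map
# 4xN inputs to 4xN outputs, and the critic should score each column once.
let g = create_generator(4), c = create_critic(4), x = rand(Float32, 4, 10)
    @assert size(g(x)) == (4, 10)
    @assert size(c(x)) == (1, 10)
end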
function train_crit!(gen, crit, opt_crit, sample_data)
    ps = Flux.params(crit)
    x_dim, y_dim = size(sample_data)
    noise = reshape(rand(x_dim * y_dim), (x_dim, y_dim))
    gen_in = hcat(noise, sample_data)
    gen_data = gen(gen_in)
    local loss
    # The critic's forward pass must happen *inside* the gradient closure;
    # computing it beforehand gives Zygote nothing to differentiate, so
    # update! applies empty gradients and the weights never move (the
    # original "why won't you work" bug). Note this requires Zygote to
    # differentiate through the pullback in gradient_normalization.
    grads = Flux.gradient(ps) do
        sample = gradient_normalization(crit, sample_data)
        generated = gradient_normalization(crit, gen_data)
        loss = wasserstein_c_loss(sample, generated)
    end # End gradient
    Flux.Optimise.update!(opt_crit, ps, grads)
    return loss
end # End train_crit!
function train_gen!(gen, crit, opt_gen, sample_data)
    ps = Flux.params(gen)
    x_dim, y_dim = size(sample_data)
    noise = reshape(rand(x_dim * y_dim), (x_dim, y_dim))
    gen_in = hcat(noise, sample_data)
    local loss
    # As above, the generator and critic forward passes live inside the
    # closure. The loss is taken on the critic's *normalized scores* for the
    # generated batch, not on the raw generated data itself.
    grads = Flux.gradient(ps) do
        gen_data = gen(gen_in)
        gen_norm = gradient_normalization(crit, gen_data)
        loss = wasserstein_g_loss(gen_norm)
    end # End gradient
    Flux.update!(opt_gen, ps, grads)
    return loss
end # End train_gen!
function train(data,
               epochs::Int, batch_size::Int,
               sample_interval::Int)
    x_dim, _ = size(data)
    gen = create_generator(x_dim)
    crit = create_critic(x_dim)
    Flux.trainmode!(gen, true)
    Flux.trainmode!(crit, true)
    opt_crit = Flux.Optimise.OADAM(0.004)
    opt_gen = Flux.Optimise.OADAM(0.002)
    c_loss = 0
    for epoch in 0:epochs-1
        # Snapshot a weight array via deepcopy; comparing Flux.params
        # directly always matches itself, since Params holds references.
        w_bef = deepcopy(first(Flux.params(gen)))
        # Shuffle *column indices* to draw a batch; Random.shuffle(data)
        # would scramble individual feature values across samples.
        for crit_loop in 1:5
            sample_data = data[:, Random.shuffle(1:size(data, 2))[1:batch_size]]
            c_loss = train_crit!(gen, crit, opt_crit, sample_data)
        end # end crit_loop
        sample_data = data[:, Random.shuffle(1:size(data, 2))[1:batch_size]]
        g_loss = train_gen!(gen, crit, opt_gen, sample_data)
        if epoch % sample_interval == 0
            if w_bef == first(Flux.params(gen))
                print("It is happening again..") # generator weights did not move
            else
                print("Epoch: $epoch, c_loss: $c_loss, g_loss = $g_loss \n")
            end
        end # end if
    end # end for
    return tuple(gen, crit)
end # End train
# Run the thing!
gen_out, crit_out = train(iris, 5, 50, 1)
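# Generating synthetic flowers (my addition, a minimal sketch): switch the
# generator to test mode so Dropout is disabled, then feed it uniform noise,
# matching the rand() noise distribution used during training. Each column of
# the output is one synthetic 4-feature sample.
Flux.testmode!(gen_out, true)
synth = gen_out(rand(4, 50))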