Skip to content

Instantly share code, notes, and snippets.

@sanderbboisen
Created September 20, 2021 12:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sanderbboisen/8e667b78348587169a4ff90f2c5dbe7c to your computer and use it in GitHub Desktop.
Save sanderbboisen/8e667b78348587169a4ff90f2c5dbe7c to your computer and use it in GitHub Desktop.
using Flux;
using Flux:Dense;
using Flux:Chain;
using Flux:glorot_normal;
using Flux:sparse_init;
using Flux:relu;
using Flux:Dropout
using Flux.Optimise:update!
using Flux.Losses: logitbinarycrossentropy
using Statistics:mean;
using GR;
using MLDatasets;
using Random;
using LinearAlgebra;
# Preparing the dataset
# Load the Iris feature matrix. The rest of the script treats it as
# features × samples: rows are features (size(data, 1) sizes the networks) and
# columns are samples (batches are taken as column subsets).
# NOTE(review): `MLDatasets.Iris.features()` is the older MLDatasets API —
# confirm it exists in the installed MLDatasets version.
iris = MLDatasets.Iris.features()
# May reuse later
"""
    scale_0_1!(data)

Min-max scale each row of the matrix `data` in place to the range [0, 1] and
return `data`.

A row whose entries are all equal (zero range) is set to all zeros instead of
being filled with `NaN` by a `0/0` division. Empty rows are left untouched.
"""
function scale_0_1!(data)
    for row in eachrow(data)          # `row` is a view, so writes hit `data`
        isempty(row) && continue
        lo, hi = extrema(row)
        span = hi - lo
        if iszero(span)
            fill!(row, zero(eltype(row)))
        else
            row .= (row .- lo) ./ span
        end
    end
    return data
end
# Min-max scale every row (feature) of `iris` to [0, 1]; `iris` is modified in
# place and the call's return value is not used.
scale_0_1!(iris)
# Loss functions
"""
    wasserstein_c_loss(sample, gen)

Critic loss for a Wasserstein GAN: the average score given to the generated
batch `gen` minus the average score given to the real batch `sample`.
Minimising it widens the gap between real and fake scores.
"""
function wasserstein_c_loss(sample, gen)
    fake_score = mean(gen)
    real_score = mean(sample)
    return fake_score - real_score
end
"""
    wasserstein_g_loss(gen)

Generator loss for a Wasserstein GAN: the negated mean of `gen`, which is meant
to hold the critic's scores for generated samples. Minimising it pushes the
generator toward samples the critic scores highly.

Accepts any array. The previous `Matrix{Float64}` restriction rejected
`Float32` network outputs, vectors, and views with a `MethodError`.
"""
function wasserstein_g_loss(gen::AbstractArray)
    return -mean(gen)
end
# Use gradient normalization to enforce 1-Lipschitz constraint
"""
    gradient_normalization(net_c, x)

Evaluate the critic `net_c` on `x` and return its gradient-normalised scores

    f̂(x) = f(x) / (‖∇ₓ f(x)‖ + |f(x)|),   f = net_c(x),

with the gradient norm taken per sample (per column of `x`), so each sample
gets its own normaliser rather than one shared across the whole batch. An
epsilon of 1e-8 inside each square root keeps the denominator non-zero when
both the score and its gradient vanish.

The result has the same shape as `net_c(x)`.
"""
function gradient_normalization(net_c, x)  # Add restriction to net_c later
    f, back = Flux.pullback(net_c, x)
    # Seed with ones shaped like the output. Assumes each output column depends
    # only on the matching input column (true for a Dense/activation stack such
    # as the critic here; batch-coupling layers would break this). Under that
    # assumption one reverse pass yields ∂fᵢ/∂xᵢ for every sample i.
    dfdx = first(back(ones(eltype(f), size(f))))
    grad_norm = sqrt.(sum(abs2, dfdx; dims=1) .+ 1e-8)   # 1 × batch
    return f ./ (grad_norm .+ abs.(f))
end
"""
    create_generator(n_features::Int)

Build the generator network. It maps `n_features` inputs to `n_features`
outputs, expanding through hidden widths `2n_features` and `4n_features` and
then contracting back. Every Dense layer uses ReLU, and Dropout(0.3) follows
each of the two expanding layers.
"""
function create_generator(n_features::Int)
    mid = 2 * n_features
    wide = 4 * n_features
    expand = (Dense(n_features, mid, relu), Dropout(0.3),
              Dense(mid, wide, relu), Dropout(0.3))
    contract = (Dense(wide, n_features, relu), Dense(n_features, n_features, relu))
    return Chain(expand..., contract...)
end
"""
    create_critic(n_features::Int)

Build the WGAN critic. It maps `n_features` inputs to one score per sample
through hidden widths `4n_features`, `2n_features`, `n_features`, with leaky
ReLU (slope 0.2) after each hidden layer.

The output layer is linear. A Wasserstein critic score must be able to take any
real value; a ReLU output clamps low scores to 0, where the critic has zero
gradient and the generator receives no training signal.
"""
function create_critic(n_features::Int)
    return Chain(
        Dense(n_features, n_features * 4),
        x -> leakyrelu.(x, 0.2),
        Dense(n_features * 4, n_features * 2),
        x -> leakyrelu.(x, 0.2),
        Dense(n_features * 2, n_features),
        x -> leakyrelu.(x, 0.2),
        Dense(n_features, 1),   # identity activation: unbounded critic score
    )
end
"""
    train_crit!(gen, crit, opt_crit, sample_data)

Perform one optimiser step on the critic `crit` and return the critic loss.

The generator `gen` produces the fake batch outside the differentiated closure,
so it is held fixed during the critic step. The critic's gradient-normalised
scores for the real batch `sample_data` and for the fake batch are computed
inside `Flux.gradient`, so the loss actually depends on `Flux.params(crit)`.
"""
function train_crit!(gen, crit, opt_crit, sample_data)
    ps = Flux.params(crit)
    x_dim, y_dim = size(sample_data)
    noise = rand(x_dim, y_dim)
    # NOTE(review): `hcat` appends the real samples as extra *columns* (extra
    # samples) of generator input, so real data is also pushed through `gen`.
    # A plain GAN generator would take only `noise` — confirm this is intended.
    gen_in = hcat(noise, sample_data)
    gen_data = gen(gen_in)
    local loss
    grads = Flux.gradient(ps) do
        # Critic evaluations must live inside this closure: values computed
        # outside it are constants to Zygote and yield no gradient for `ps`.
        # NOTE(review): gradient_normalization differentiates internally, so
        # this is second-order AD — confirm Zygote handles it for this critic.
        real_scores = gradient_normalization(crit, sample_data)
        fake_scores = gradient_normalization(crit, gen_data)
        loss = wasserstein_c_loss(real_scores, fake_scores)
    end
    Flux.Optimise.update!(opt_crit, ps, grads)
    return loss
end
"""
    train_gen!(gen, crit, opt_gen, sample_data)

Perform one optimiser step on the generator `gen` and return the generator
loss.

The generator forward pass and the critic's gradient-normalised scores of its
output are computed inside `Flux.gradient`, so the loss depends on
`Flux.params(gen)`. The loss is taken over critic scores, not raw generator
output. Only the generator's parameters are updated; `crit` is left unchanged.
"""
function train_gen!(gen, crit, opt_gen, sample_data)
    ps = Flux.params(gen)
    x_dim, y_dim = size(sample_data)
    noise = rand(x_dim, y_dim)
    # NOTE(review): as in train_crit!, real samples are appended as extra
    # generator inputs; a plain GAN generator would take only `noise`.
    gen_in = hcat(noise, sample_data)
    local loss
    grads = Flux.gradient(ps) do
        gen_data = gen(gen_in)
        fake_scores = gradient_normalization(crit, gen_data)
        loss = wasserstein_g_loss(fake_scores)
    end
    Flux.Optimise.update!(opt_gen, ps, grads)
    return loss
end
# Return `batch_size` distinct columns (samples) of `data`, chosen at random.
# `Random.shuffle(data)` on a matrix would permute every *element*, mixing
# feature values across samples, so the column indices are permuted instead.
_random_batch(data, batch_size) = data[:, Random.randperm(size(data, 2))[1:batch_size]]

"""
    train(data, epochs::Int, batch_size::Int, sample_interval::Int; n_critic::Int = 5)

Train a WGAN with a gradient-normalised critic on `data` and return the tuple
`(gen, crit)`.

`data` holds one feature per row and one sample per column; `size(data, 1)`
sizes both networks. Each epoch runs `n_critic` critic steps and then one
generator step, each on a freshly drawn batch of `batch_size` distinct columns.
Losses are printed every `sample_interval` epochs. If the generator's
parameters did not change during an epoch, a warning line is printed instead.

# Throws
- `ArgumentError` if `batch_size` is not in `1:size(data, 2)`, or if
  `sample_interval` or `n_critic` is not positive.
"""
function train(data,
               epochs::Int, batch_size::Int,
               sample_interval::Int; n_critic::Int = 5)
    x_dim, n_samples = size(data)
    1 <= batch_size <= n_samples ||
        throw(ArgumentError("batch_size must be in 1:$n_samples, got $batch_size"))
    sample_interval > 0 ||
        throw(ArgumentError("sample_interval must be positive, got $sample_interval"))
    n_critic > 0 ||
        throw(ArgumentError("n_critic must be positive, got $n_critic"))

    gen = create_generator(x_dim)
    crit = create_critic(x_dim)
    Flux.trainmode!(gen, true)
    Flux.trainmode!(crit, true)
    opt_crit = Flux.Optimise.OADAM(0.004)
    opt_gen = Flux.Optimise.OADAM(0.002)

    c_loss = 0.0
    for epoch in 0:epochs-1
        # Snapshot *copies*: Flux.params holds references to the live arrays,
        # so comparing against a second Params is always true after an
        # in-place update and could never detect a change.
        gen_before = [copy(p) for p in Flux.params(gen)]
        for _ in 1:n_critic
            c_loss = train_crit!(gen, crit, opt_crit, _random_batch(data, batch_size))
        end
        g_loss = train_gen!(gen, crit, opt_gen, _random_batch(data, batch_size))

        gen_unchanged = all(old == new for (old, new) in zip(gen_before, Flux.params(gen)))
        if gen_unchanged
            println("It is happening again..")
        elseif epoch % sample_interval == 0
            println("Epoch: $epoch, c_loss: $c_loss, g_loss = $g_loss")
        end
    end
    return (gen, crit)
end
# Run the thing!
# Train on the scaled Iris matrix for 5 epochs with batches of 50 samples,
# printing every epoch (sample_interval = 1); the trained generator and critic
# are returned as a tuple.
gen_out, crit_out = train(iris, 5, 50, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment