Skip to content

Instantly share code, notes, and snippets.

@bhatiaabhinav
Created July 13, 2022 04:48
Show Gist options
  • Save bhatiaabhinav/1b67de77cc5b2551185bcfaf1b4c0148 to your computer and use it in GitHub Desktop.
Save bhatiaabhinav/1b67de77cc5b2551185bcfaf1b4c0148 to your computer and use it in GitHub Desktop.
Julia implementation of WGAN-GP using Flux and Zygote
using Flux
using Flux: update!
using Zygote
using StatsBase
"""
    train_WGAN_GP(𝐺, 𝐷, 𝐗, latent_size, num_iters, device_fn;
                  m=32, λ=10f0, ncritic=5, α=0.0001, β₁=0, β₂=0.9)

Train a Wasserstein GAN with gradient penalty (WGAN-GP) following Algorithm 1 of
https://proceedings.neurips.cc/paper/2017/file/892c3b1c6dccd52936e27cbd0ff683d6-Paper.pdf.
The code mirrors the paper's algorithm almost line by line.

`𝐺` and `𝐷` are deep-copied and passed through `device_fn` before training, so the
models supplied by the caller are left untouched; the trained copies are returned.

# Arguments
- `𝐺`: generator, called as `𝐺(𝐳)` on a batch of latent samples.
- `𝐷`: critic, called as `𝐷(x)` on a batch of (real, fake or interpolated) data.
- `𝐗::Array{Float32,N}`: dataset; the last dimension indexes samples.
- `latent_size`: splattable collection (e.g. a tuple) of latent dimensions; latent
  batches have size `(latent_size..., m)`.
- `num_iters`: number of generator updates.
- `device_fn`: applied to the models and to every sampled batch (e.g. `cpu`).

# Keywords
- `m`: batch size (samples are drawn with replacement).
- `λ`: gradient-penalty coefficient.
- `ncritic`: critic updates per generator update.
- `α`, `β₁`, `β₂`: ADAM learning rate and decay rates, used by both optimisers.

# Returns
`(𝐺, 𝐷)`: the trained generator and critic.

# Throws
- `ArgumentError` if `𝐗` contains no samples.
"""
function train_WGAN_GP(𝐺, 𝐷, 𝐗::Array{Float32, N}, latent_size, num_iters, device_fn;
                       m=32, λ=10f0, ncritic=5, α=0.0001, β₁=0, β₂=0.9) where N
    n = size(𝐗, N)  # number of samples in the dataset
    n > 0 || throw(ArgumentError("dataset 𝐗 must contain at least one sample"))

    feature_colons = ntuple(_ -> Colon(), N - 1)  # select all feature dims of a sample
    ϵ_shape = ntuple(_ -> 1, N - 1)               # one ϵ per sample, broadcast over features
    feature_dims = 1:(N - 1)                      # dims reduced for the per-sample norm

    𝐺, 𝐷 = device_fn(deepcopy(𝐺)), device_fn(deepcopy(𝐷))
    θ, 𝑤 = params(𝐺), params(𝐷)
    adamθ, adam𝑤 = ADAM(α, (β₁, β₂)), ADAM(α, (β₁, β₂))

    for _ in 1:num_iters
        for _ in 1:ncritic
            # Sample a batch of real data x, latent variables z, random numbers ϵ ∼ U[0, 1].
            𝐱 = device_fn(𝐗[feature_colons..., rand(1:n, m)])
            𝐳 = device_fn(randn(Float32, latent_size..., m))
            𝛜 = device_fn(rand(Float32, ϵ_shape..., m))

            # Fake samples are computed outside the gradient call: only the critic
            # parameters 𝑤 are updated in this inner loop.
            𝐱̃ = 𝐺(𝐳)
            𝐱̂ = @. 𝛜 * 𝐱 + (1f0 - 𝛜) * 𝐱̃  # random interpolates between real and fake

            ∇𝑤L = gradient(𝑤) do
                # Input gradient of the critic at the interpolates (differentiated again
                # by the outer call). Summing over the batch yields each sample's own
                # gradient assuming 𝐷 treats samples independently (no batch norm).
                ∇𝐱̂𝐷, = gradient(x -> sum(𝐷(x)), 𝐱̂)
                # Per-sample L2 norm over ALL feature dims (not just dims=1), so the
                # penalty is also correct for N > 2 data such as images. The 1f-12 keeps
                # sqrt differentiable when the gradient is exactly zero.
                grad_norm = sqrt.(sum(∇𝐱̂𝐷 .^ 2; dims=feature_dims) .+ 1f-12)
                mean(𝐷(𝐱̃)) - mean(𝐷(𝐱)) + λ * mean((grad_norm .- 1f0) .^ 2)
            end
            update!(adam𝑤, 𝑤, ∇𝑤L)
        end

        # Generator step: maximise the critic's score on fresh fake samples.
        𝐳 = device_fn(randn(Float32, latent_size..., m))
        ∇θ𝐷 = gradient(θ) do
            -mean(𝐷(𝐺(𝐳)))
        end
        update!(adamθ, θ, ∇θ𝐷)
    end
    return 𝐺, 𝐷
end
# Minimal smoke test of train_WGAN_GP on random data.
𝐗 = rand(Float32, 50, 10000)  # dummy data: 50 features × 10000 samples
data_dim = size(𝐗, 1)          # feature width, derived from the data instead of hard-coded
z = 16                         # latent size
𝐺 = Chain(Dense(z, 32, leakyrelu), Dense(32, data_dim))  # Generator: latent → data
𝐷 = Chain(Dense(data_dim, 32, leakyrelu), Dense(32, 1))  # Critic: data → scalar score
𝐺, 𝐷 = train_WGAN_GP(𝐺, 𝐷, 𝐗, (z,), 1, cpu)  # works
# NOTE: passing `gpu` as `device_fn` instead of `cpu` fails; GPU training is not
# supported yet.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment