Skip to content

Instantly share code, notes, and snippets.

@theogf
Last active October 2, 2020 15:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save theogf/354ef8709db81c8fd3806586067e3c59 to your computer and use it in GitHub Desktop.
Save theogf/354ef8709db81c8fd3806586067e3c59 to your computer and use it in GitHub Desktop.
using LinearAlgebra: dot, norm
using Statistics: mean
using StatsFuns
"""
    h(x, v, y, ν, ϵ::Real, c)

Evaluate, at a single source point `x`, the entropic semi-dual objective

    h_ϵ(x, v) = ⟨v, ν⟩ - ϵ log(∑ⱼ exp((vⱼ - c(x, yⱼ)) / ϵ) νⱼ) - ϵ

where `y` holds the support points of the discrete measure, `ν` their weights,
`v` one value per support point, and `c(x, yⱼ)` the transport cost.

When `iszero(ϵ)`, return the ϵ → 0 limit `⟨v, ν⟩ + minⱼ (c(x, yⱼ) - vⱼ)`.
This avoids a division by zero.
"""
function h(
    x::AbstractVector, v::AbstractVector, y::AbstractVector, ν::AbstractVector,
    ϵ::Real, c,
)
    costs = c.(Ref(x), y)  # Ref keeps `x` fixed while broadcasting over the points in `y`
    # With ϵ = 0.0 the entropic formula divides by zero and evaluates to NaN;
    # use its limit (the unregularized semi-dual) instead.
    iszero(ϵ) && return dot(v, ν) + minimum(costs .- v)
    return dot(v, ν) - ϵ * logsumexp((v .- costs) ./ ϵ .+ log.(ν)) - ϵ
end
"""
    h(x, v, y, ν, ϵ::Int, c)

Integer-`ϵ` method of the semi-dual objective at a single source point `x`.

- `ϵ == 0`: evaluate the unregularized semi-dual `⟨v, ν⟩ + minⱼ (c(x, yⱼ) - vⱼ)`.
- any other integer: promote `ϵ` with `float` and defer to the entropic
  `ϵ::Real` method.
"""
function h(
    x::AbstractVector, v::AbstractVector, y::AbstractVector, ν::AbstractVector,
    ϵ::Int, c,
)
    # A non-zero integer regularization is still meaningful; hand it to the Real method.
    iszero(ϵ) || return h(x, v, y, ν, float(ϵ), c)
    return dot(v, ν) + minimum(c.(Ref(x), y) .- v)
end
"""
    optim_v(μ, y, ν, η, N, ϵ, c)

Find the vector `v` (one entry per support point in `y`) that maximizes the expectation
of `h(x, v, y, ν, ϵ, c)` over `x ~ μ`. The search uses averaged stochastic gradient
ascent. The structure appears to follow the semi-discrete averaged-SGD scheme of
Genevay et al. (2016).

# Arguments
- `μ`: source measure. Only `rand(μ)` is called, once per iteration, to draw a point.
- `y`: support points of the discrete measure.
- `ν`: weights of the discrete measure (one per element of `y`).
- `η`: base step size. Iteration `k` uses `η / √k`.
- `N`: number of iterations. `N < 1` returns all zeros.
- `ϵ`: entropic regularization. `iszero(ϵ)` selects the unregularized objective.
- `c`: cost function, called as `c(x, yⱼ)`.

# Returns
The running average of the SGD iterates, a float vector shaped like `ν`.
"""
function optim_v(μ, y::AbstractVector, ν::AbstractVector, η::Real, N::Int, ϵ::Real, c)
    ṽ = zero(float(ν))  # current SGD iterate (float so integer weights don't truncate steps)
    v = zero(ṽ)         # running average of the iterates
    for k in 1:N
        xₖ = rand(μ)
        # Ascend along ∇ᵥh at the current iterate (the gradient is w.r.t. v, not ν).
        ṽ .+= (η / √k) .* _grad_v_h(xₖ, ṽ, y, ν, ϵ, c)
        @. v = ṽ / k + (k - 1) / k * v
    end
    return v
end

# Closed-form gradient of `h(x, v, y, ν, ϵ, c)` with respect to `v`: `ν - p`, where `p` is
# the softmax assignment of `x` to `y` (ϵ > 0) or the one-hot argmin assignment (ϵ == 0).
function _grad_v_h(x, v::AbstractVector, y::AbstractVector, ν::AbstractVector, ϵ::Real, c)
    slack = v .- c.(Ref(x), y)
    if iszero(ϵ)
        # Supergradient of ⟨v, ν⟩ + minⱼ (c(x, yⱼ) - vⱼ): the minimizing index loses one unit.
        g = float.(ν)
        g[argmax(slack)] -= 1
        return g
    end
    z = slack ./ ϵ .+ log.(ν)
    p = exp.(z .- maximum(z))  # shift by the max for a numerically stable softmax
    return ν .- p ./ sum(p)
end
"""
    wasserstein_semidiscrete(μ, y, ν, ϵ, c = (x, y) -> norm(x - y), η = 0.1, N = 100, N_MC = 2000)

Monte Carlo estimate of the (ϵ-regularized) semi-discrete transport cost between `μ`
and the discrete measure with support `y` and weights `ν`.

The estimate is computed in two stages:
1. `optim_v` fits the potential with `N` stochastic steps of base size `η`.
2. `h` is averaged over `N_MC` fresh draws from `μ`.

The draws come from `rand(μ, N_MC)` and are iterated with `eachcol`, so each column is
one sample.
"""
function wasserstein_semidiscrete(
    μ, y, ν, ϵ, c=(x, y) -> norm(x - y),
    η::Real=0.1, N::Int=100, N_MC::Int=2000,
)
    potential = optim_v(μ, y, ν, η, N, ϵ, c)
    draws = rand(μ, N_MC)  # sampled only after fitting, preserving the RNG call order
    semidual(point) = h(point, potential, y, ν, ϵ, c)
    return mean(semidual, eachcol(draws))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment