# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @devmotion
# Created May 24, 2021 19:09
# Show Gist options
#   • Save devmotion/6b987174bce2bedf18b074358be80198 to your computer and use it in GitHub Desktop.
# Discrete OT
using Distributions
using SparseArrays
using LinearAlgebra
using StatsBase
"""
    _ot_cost_plan(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; get=:plan)

Sweep through the weight vectors `μ.p` and `ν.p` with two cursors `i` and `j`. Each step
pairs the current entries `(i, j)` with the mass left on the side that runs out first
and then advances that side. Depending on `get`, the steps are either stored in a sparse
matrix or priced with `c` and summed.

# Arguments
- `c`: callable used as `c(μ.support[i], ν.support[j])`; only evaluated when
  `get == :cost`.
- `μ`, `ν`: distributions whose `p` and `support` fields are read directly.

# Keywords
- `get`: `:plan` (default) to return the matrix of moved masses, `:cost` to return the
  accumulated cost.

# Returns
- For `get == :plan`: a sparse `length(μ.p) × length(ν.p)` matrix with element type
  `Base.promote_eltype(μ.p, ν.p)`.
- For `get == :cost`: the sum of `c(μ.support[i], ν.support[j]) * w` over all steps.

# Throws
- `ArgumentError` if `get` is neither `:plan` nor `:cost`.

NOTE(review): presumably this relies on both supports being sorted and on both weight
vectors carrying the same total mass, confirm against the callers.
"""
function _ot_cost_plan(c, μ::DiscreteNonParametric, ν::DiscreteNonParametric; get=:plan)
    # Reject unknown modes up front instead of running the sweep and returning `nothing`.
    if get !== :plan && get !== :cost
        throw(ArgumentError("`get` must be `:plan` or `:cost`, got $(repr(get))"))
    end

    len_μ = length(μ.p)
    len_ν = length(ν.p)
    # Mass still available at the current source (`wi`) and target (`wj`) entry.
    wi = μ.p[1]
    wj = ν.p[1]
    if get == :plan
        γ = spzeros(Base.promote_eltype(μ.p, ν.p), len_μ, len_ν)
    else
        # Seeding with the first step fixes the accumulator's type; the loop skips that
        # step again via `i + j > 2`.
        cost = c(μ.support[1], ν.support[1]) * min(wi, wj)
    end

    i, j = 1, 1
    while true
        # Once `j` sits on the last target entry, only the source cursor may advance.
        if wi < wj || j == len_ν
            if get == :plan
                γ[i, j] = wi
            elseif i + j > 2 # skip the first case, already computed
                cost += c(μ.support[i], ν.support[j]) * wi
            end
            i += 1
            if i == len_μ + 1
                break
            end
            wj -= wi
            wi = μ.p[i]
        else
            if get == :plan
                γ[i, j] = wj
            elseif i + j > 2 # skip the first case, already computed
                cost += c(μ.support[i], ν.support[j]) * wj
            end
            j += 1
            if j == len_ν + 1
                break
            end
            wi -= wj
            wj = ν.p[j]
        end
    end

    if get == :plan
        return γ
    else
        return cost
    end
end
"""
    DiscreteOTIterator(mu, nu)

Iterator over triples `(i, j, w)` obtained by walking the two iterables `mu` and `nu` in
lockstep: each step emits the smaller of the two current values together with the
current positions, advances the side that supplied it, and leaves the difference on the
other side.

The type parameter `T` is `Base.promote_eltype(mu, nu)` and is used as the type of `w` in
`eltype`.
"""
struct DiscreteOTIterator{T,M,N}
    mu::M
    nu::N
end

function DiscreteOTIterator(mu, nu)
    return DiscreteOTIterator{Base.promote_eltype(mu, nu),typeof(mu),typeof(nu)}(mu, nu)
end

# Iteration traits: the element type is known, the number of emitted triples is not.
Base.IteratorSize(::Type{<:DiscreteOTIterator}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:DiscreteOTIterator}) = Base.HasEltype()
Base.eltype(::Type{<:DiscreteOTIterator{T}}) where {T} = Tuple{Int,Int,T}

# First call: start both cursors at 1 and fetch the first value of each side.
function Base.iterate(d::DiscreteOTIterator)
    return iterate(d, (1, 1, iterate(d.mu), iterate(d.nu)))
end

# The state is `(i, j, mu_next, nu_next)`; each `*_next` is either `nothing` (that side is
# exhausted) or a `(value, inner_state)` pair, where `value` may be a leftover remainder.
function Base.iterate(d::DiscreteOTIterator, state)
    row, col, mu_next, nu_next = state
    (mu_next === nothing || nu_next === nothing) && return nothing

    mass_mu, mu_state = mu_next
    mass_nu, nu_state = nu_next
    moved, larger = minmax(mass_mu, mass_nu)
    leftover = larger - moved
    item = (row, col, moved)

    if mass_mu < larger
        # The `mu` entry is used up: advance it and keep the remainder on the `nu` side.
        return item, (row + 1, col, iterate(d.mu, mu_state), (leftover, nu_state))
    end
    # Otherwise (ties included) the `nu` entry is used up.
    return item, (row, col + 1, (leftover, mu_state), iterate(d.nu, nu_state))
end
# Example data: random support points for two discrete distributions of different sizes.
mu_support = randn(200)
nu_support = randn(250)
# Random positive weights, rescaled in place so that each vector sums to one.
mu_probs = rand(200)
mu_probs ./= sum(mu_probs)
nu_probs = rand(250)
nu_probs ./= sum(nu_probs)
# Cost function handed to the functions below: absolute difference of two points.
c(x, y) = abs(x - y)
"""
    ot_plan(_, mu::DiscreteNonParametric, nu::DiscreteNonParametric)

Collect the `(i, j, w)` triples produced by `DiscreteOTIterator(probs(mu), probs(nu))`
into a sparse matrix of size `length(probs(mu)) × length(probs(nu))` whose entry `(i, j)`
is `w`.

The first positional argument is accepted but not used.
"""
function ot_plan(_, mu::DiscreteNonParametric, nu::DiscreteNonParametric)
    probs_mu = probs(mu)
    probs_nu = probs(nu)
    len_mu = length(probs_mu)
    len_nu = length(probs_nu)
    T = Base.promote_eltype(probs_mu, probs_nu)

    I = Int[]
    J = Int[]
    W = T[]
    # Every step of the iterator advances exactly one of the two cursors and it stops as
    # soon as one side is exhausted, so it yields at most `len_mu + len_nu - 1` triples.
    # Reserving only `max(len_mu, len_nu)` would still force the buffers to regrow.
    capacity = max(len_mu + len_nu - 1, 0)
    sizehint!(I, capacity)
    sizehint!(J, capacity)
    sizehint!(W, capacity)

    for (i, j, w) in DiscreteOTIterator(probs_mu, probs_nu)
        push!(I, i)
        push!(J, j)
        push!(W, w)
    end
    return sparse(I, J, W, len_mu, len_nu)
end
# Build the plan once with `ot_plan` and once with `_ot_cost_plan` (default `get=:plan`),
# then compare the two matrices approximately. The result of `≈` is not stored, so it is
# only visible when the line is evaluated interactively.
plan = ot_plan(c, DiscreteNonParametric(mu_support, mu_probs), DiscreteNonParametric(nu_support, nu_probs))
plan2 = _ot_cost_plan(c, DiscreteNonParametric(mu_support, mu_probs), DiscreteNonParametric(nu_support, nu_probs))
plan ≈ plan2
"""
    ot_cost(c, mu::DiscreteNonParametric, nu::DiscreteNonParametric; plan=nothing)

Forward to `_ot_cost(c, mu, nu, plan)`, turning the `plan` keyword into a positional
argument so that the method is selected by the type of `plan`.
"""
ot_cost(c, mu::DiscreteNonParametric, nu::DiscreteNonParametric; plan=nothing) =
    _ot_cost(c, mu, nu, plan)
# Cost when no plan is supplied: walk `DiscreteOTIterator` over the two probability
# vectors and add up `c(x_i, y_j) * w` for every emitted triple, without storing a plan.
function _ot_cost(c, mu, nu, ::Nothing)
    xs = support(mu)
    ys = support(nu)
    moves = DiscreteOTIterator(probs(mu), probs(nu))
    return sum(moves) do (i, j, w)
        c(xs[i], ys[j]) * w
    end
end
# Cost for a sparse plan: only the stored entries returned by `findnz` contribute, each
# weighted by the cost between the corresponding support points.
function _ot_cost(c, mu, nu, plan::SparseMatrixCSC)
    xs = support(mu)
    ys = support(nu)
    rows, cols, weights = findnz(plan)
    return sum(zip(rows, cols, weights)) do (i, j, w)
        c(xs[i], ys[j]) * w
    end
end
# Fallback for any other kind of plan: build the pairwise cost matrix between the two
# supports and take its `dot` product with the plan.
function _ot_cost(c, mu, nu, plan)
    costs = pairwise(c, support(mu), support(nu))
    return dot(plan, costs)
end
# Cost computed without a plan (`plan=nothing`, i.e. straight from the iterator).
cost = ot_cost(
c,
DiscreteNonParametric(mu_support, mu_probs),
DiscreteNonParametric(nu_support, nu_probs)
)
# Same quantity from `_ot_cost_plan` in `:cost` mode.
cost2 = _ot_cost_plan(
c,
DiscreteNonParametric(mu_support, mu_probs),
DiscreteNonParametric(nu_support, nu_probs);
get=:cost
)
# NOTE(review): exact floating-point comparison; presumably intended to show that both
# implementations accumulate the same terms in the same order. Confirm.
cost == cost2
# The cost recomputed from the sparse plan should agree approximately.
cost ≈ ot_cost(
c,
DiscreteNonParametric(mu_support, mu_probs),
DiscreteNonParametric(nu_support, nu_probs);
plan=plan,
)
# Same check with a dense copy of the plan, which takes the generic `_ot_cost` method.
cost ≈ ot_cost(
c,
DiscreteNonParametric(mu_support, mu_probs),
DiscreteNonParametric(nu_support, nu_probs);
plan=Matrix(plan),
)
# `@code_warntype` is provided by InteractiveUtils. The REPL loads that module
# automatically, but a script or `include` does not, so load it explicitly here.
using InteractiveUtils

# Print the inferred types for the plan construction.
@code_warntype ot_plan(c, DiscreteNonParametric(mu_support, mu_probs), DiscreteNonParametric(nu_support, nu_probs))
# Print the inferred types for `ot_cost` without a plan.
@code_warntype ot_cost(
c,
DiscreteNonParametric(mu_support, mu_probs),
DiscreteNonParametric(nu_support, nu_probs)
)
# Print the inferred types for `ot_cost` with the sparse plan.
@code_warntype ot_cost(
c,
DiscreteNonParametric(mu_support, mu_probs),
DiscreteNonParametric(nu_support, nu_probs);
plan=plan,
)
# Print the inferred types for `ot_cost` with a dense plan.
@code_warntype ot_cost(
c,
DiscreteNonParametric(mu_support, mu_probs),
DiscreteNonParametric(nu_support, nu_probs);
plan=Matrix(plan),
)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment