@sethaxen
Last active October 28, 2021 23:03
Empirical p-Wasserstein distance for multivariate samples
using Distances, Distributions, OptimalTransport, LinearAlgebra, Random

# pre-metric that raises the Euclidean distance to the power p; used to build the cost matrix
struct PowEuclidean{T} <: Distances.PreMetric
    p::T
end
(m::PowEuclidean)(a, b) = Distances.Euclidean()(a, b)^m.p

# approximate the empirical p-Wasserstein distance between measures μ and ν supported on ℝᵈ,
# given matrices x and y whose columns are points drawn from μ and ν, respectively.
# the optimal transport problem is entropically regularized and solved with the Sinkhorn
# algorithm, which scales to large numbers of columns.
# increasing ϵ improves scaling but reduces the accuracy of the approximation.
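function wasserstein(x, y, p::Int=1; ϵ=1.0)
    # uniform weights on the sample points (columns) of x and y
    nx = size(x, 2)
    μ = fill(inv(nx), nx)
    ny = size(y, 2)
    ν = fill(inv(ny), ny)
    # cost matrix of pairwise Euclidean distances raised to the power p
    C = Distances.pairwise(PowEuclidean(p), x, y; dims=2)
    # the transport cost estimates the p-th power of the distance, so take the p-th root
    return OptimalTransport.sinkhorn2(μ, ν, C, ϵ)^(1//p)
end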
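
# reference check (a sketch, not part of the original gist): in one dimension with equal
# sample sizes, the exact empirical p-Wasserstein distance pairs the sorted samples, so no
# optimal transport solver is needed. `wasserstein1d` is a hypothetical helper name; its
# result can be compared against `wasserstein` on 1×n matrices to see the effect of ϵ.
function wasserstein1d(x::AbstractVector, y::AbstractVector, p::Int=1)
    length(x) == length(y) || throw(ArgumentError("x and y must have the same length"))
    # mean of |x₍ᵢ₎ - y₍ᵢ₎|^p over the sorted samples, followed by the p-th root
    return (sum(abs.(sort(x) .- sort(y)) .^ p) / length(x))^(1 / p)
end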
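
# exact 2-Wasserstein distance between multivariate normal distributions
function wasserstein2(μ::Distributions.MvNormal, ν::Distributions.MvNormal)
    Σ1 = μ.Σ
    Σ2 = ν.Σ
    sqrtΣ1 = sqrt(Symmetric(Σ1))
    # squared Bures distance between the covariances:
    # tr(Σ1) + tr(Σ2) - 2tr((Σ1^½ Σ2 Σ1^½)^½)
    B2 = tr(Σ1) + tr(Σ2) - 2tr(sqrt(Symmetric(sqrtΣ1 * Σ2 * sqrtΣ1)))
    # add the squared Euclidean distance between the means, then take the square root
    return sqrt(Distances.SqEuclidean()(μ.μ, ν.μ) + B2)
end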
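
# sanity check (a sketch, not part of the original gist): when both covariances are
# diagonal, the closed form above simplifies to the Euclidean distance between the
# concatenated means and standard deviations. `wasserstein2_diag` is a hypothetical helper
# name; s1 and s2 are assumed to be vectors of nonnegative standard deviations.
wasserstein2_diag(m1, s1, m2, s2) = sqrt(sum(abs2, m1 .- m2) + sum(abs2, s1 .- s2))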
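
# example: two zero-mean normals in d dimensions with random covariances, built from
# random scales σ and LKJ-distributed correlation matrices
d = 10
σ1 = randexp(d)
σ2 = randexp(d)
lkj = LKJ(d, 2.0)
μ = MvNormal(σ1 .* σ1' .* rand(lkj))
ν = MvNormal(σ2 .* σ2' .* rand(lkj))
# draw 1_000 points from each distribution; columns are samples
x = rand(μ, 1_000)
y = rand(ν, 1_000)
# exact distance between the two distributions
wasserstein2(μ, ν)
# Sinkhorn approximation computed from the samples
wasserstein(x, y, 2)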