Last active
October 28, 2021 23:03
-
-
Save sethaxen/6408d7f104ba44abfbb74981f8a8deb8 to your computer and use it in GitHub Desktop.
Empirical p-Wasserstein distance for multivariate samples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Distances, Distributions, OptimalTransport, LinearAlgebra, Random | |
"""
    PowEuclidean(p)

Pre-metric whose value for two points `a` and `b` is their Euclidean distance raised to
the power `p`.
"""
struct PowEuclidean{T} <: Distances.PreMetric
    p::T
end

# Instances are callable, which is what lets them be passed to `Distances.pairwise`.
function (dist::PowEuclidean)(a, b)
    euclid = Distances.Euclidean()(a, b)
    return euclid^dist.p
end
"""
    wasserstein(x, y, p::Real=1; ϵ=1.0)

Approximate the empirical `p`-Wasserstein distance between two sets of sample points.

For measures μ and ν with support ℝᵈ, the columns of the matrices `x` and `y` are random
points drawn from μ and ν, respectively. Every column receives the same weight. The
transport cost is obtained from `OptimalTransport.sinkhorn2`, an approximate method that
can scale to large numbers of columns.

# Arguments
- `x`, `y`: matrices whose columns are the sample points.
- `p`: order of the distance; must satisfy `p ≥ 1`.

# Keywords
- `ϵ`: regularisation parameter forwarded to `OptimalTransport.sinkhorn2`. Increasing it
  improves scaling but reduces the accuracy of the approximation.

# Throws
- `ArgumentError` if `p` is not at least 1.
"""
function wasserstein(x, y, p::Real=1; ϵ=1.0)
    p ≥ 1 || throw(ArgumentError("p must be at least 1, got $p"))
    # uniformly weight histograms
    nx = size(x, 2)
    ny = size(y, 2)
    μ = fill(inv(nx), nx)
    ν = fill(inv(ny), ny)
    # cost matrix: pairwise Euclidean distances between columns, raised to the power p
    C = Distances.pairwise(PowEuclidean(p), x, y; dims=2)
    # `inv(p)` instead of `1//p` so that non-integer orders are supported as well
    return OptimalTransport.sinkhorn2(μ, ν, C, ϵ)^inv(p)
end
"""
    wasserstein2(μ::Distributions.MvNormal, ν::Distributions.MvNormal)

Compute the exact 2-Wasserstein distance between two multivariate normal distributions.

With means `m₁`, `m₂` and covariances `Σ₁`, `Σ₂`, the squared distance evaluated here is

    ‖m₁ - m₂‖² + tr(Σ₁) + tr(Σ₂) - 2 tr((Σ₁^½ Σ₂ Σ₁^½)^½)
"""
function wasserstein2(μ::Distributions.MvNormal, ν::Distributions.MvNormal)
    # Go through the public accessors instead of the `.μ`/`.Σ` fields: the fields are
    # internal and their storage type depends on how the distribution was constructed.
    Σ1 = Distributions.cov(μ)
    Σ2 = Distributions.cov(ν)
    sqrtΣ1 = sqrt(Symmetric(Σ1))
    # tr(sqrt(S)) equals the sum of the square roots of the eigenvalues of S; clamping
    # at zero keeps the result real when round-off makes an eigenvalue slightly negative
    λ = eigvals(Symmetric(sqrtΣ1 * Σ2 * sqrtΣ1))
    trsqrt = sum(v -> sqrt(max(v, zero(v))), λ)
    B2 = tr(Σ1) + tr(Σ2) - 2trsqrt
    W2sq = Distances.SqEuclidean()(Distributions.mean(μ), Distributions.mean(ν)) + B2
    # round-off can push the squared distance just below zero (e.g. identical
    # distributions), which would make `sqrt` throw a DomainError
    return sqrt(max(W2sq, zero(W2sq)))
end
# Demo: compare the exact distance between two random Gaussians with the Sinkhorn-based
# estimate computed from samples of those Gaussians.
d = 10

# random per-dimension scales
σ1 = randexp(d)
σ2 = randexp(d)

# each draw from `lkj` is combined with the scales: entry (i, j) is multiplied by σᵢ * σⱼ
lkj = LKJ(d, 2.0)
Σ1 = σ1 .* σ1' .* rand(lkj)
μ = MvNormal(Σ1)
Σ2 = σ2 .* σ2' .* rand(lkj)
ν = MvNormal(Σ2)

# 1000 sample points per distribution, stored as columns
x = rand(μ, 1_000)
y = rand(ν, 1_000)

wasserstein2(μ, ν)
wasserstein(x, y, 2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment