Skip to content

Instantly share code, notes, and snippets.

@theogf
Created May 9, 2020 20:12
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save theogf/4944dea6603a2b48c317dcb163a6b817 to your computer and use it in GitHub Desktop.
Save theogf/4944dea6603a2b48c317dcb163a6b817 to your computer and use it in GitHub Desktop.
Implementation of the Sinkhorn algorithm (standard and log-domain variants) for entropically regularised optimal transport
using Makie
using MakieLayout
using LinearAlgebra, Distances
using Distributions
# Source measure α: a three-component Gaussian mixture with components centred in [0, 1].
α = MixtureModel(
    [Normal(0.2, 0.05), Normal(0.6, 0.06), Normal(0.8, 0.04)],
    [0.3, 0.5, 0.2],
)
# Target measure β (an alternative target that was tried: Normal(0.3, 0.09)).
β = Laplace(0.5, 0.1)

# Both measures are discretised on the same 100-point grid over [xmin, xmax].
xmin, xmax = 0.0, 1.0
x = range(xmin, xmax, length = 100)
y = range(xmin, xmax, length = 100)

# Constant added to every density value before normalisation. It keeps every weight
# strictly positive, which the log.(a) / log.(b) terms in logsinkhorn require.
vmin = 0.05

# Discretisation: evaluate the pdf on the grid, add the floor, normalise to sum 1.
a = let w = pdf.(Ref(α), x) .+ vmin
    w ./ sum(w)
end
b = let w = pdf.(Ref(β), y) .+ vmin
    w ./ sum(w)
end
"""
    sqcost(x, y)

Squared Euclidean ground cost `(x - y)^2` between two scalar points. For 1-D points
this is the same cost that `pairwise(SqEuclidean(), …)` produces.
"""
sqcost(x, y) = abs2(x - y)

"""
    dK(C, ϵ)
    dK(x, y, ϵ, c = sqcost)

Gibbs kernel of the entropically regularised transport problem.

- `dK(C, ϵ)` maps a whole cost matrix `C` element-wise to `exp.(-C / ϵ)`.
- `dK(x, y, ϵ, c)` is the scalar kernel `exp(-c(x, y) / ϵ)` for a single pair of
  points. Broadcast it (`dK.(x, y', ϵ)`) to build the full kernel matrix. The ground
  cost `c` defaults to `sqcost`.
"""
dK(C, ϵ) = exp.(-C / ϵ)
dK(x, y, ϵ, c = sqcost) = exp(-c(x, y) / ϵ)
"""
    sinkhorn(x, y, a, b, ϵ, T = 100)

Sinkhorn (matrix-scaling) algorithm for entropically regularised optimal transport
between the discrete measures `∑ᵢ aᵢ δ(xᵢ)` and `∑ⱼ bⱼ δ(yⱼ)` with squared Euclidean
ground cost, run for `T` iterations.

# Arguments
- `x`, `y`: support points (vectors of scalars).
- `a`, `b`: weights on `x` and `y` (same lengths as `x` and `y`).
- `ϵ`: regularisation strength. For small `ϵ`, entries of `K` underflow to zero and
  the divisions below can produce `Inf`/`NaN`; `logsinkhorn` is the stable variant.
- `T`: number of iterations.

# Returns
`(K, u, v)`, where `K = exp.(-C ./ ϵ)` is the Gibbs kernel and `u`, `v` are the
scaling vectors. The transport plan is `u .* K .* v'`; after each iteration its
column sums equal `b`.
"""
function sinkhorn(x, y, a, b, ϵ, T = 100)
    # Squared Euclidean cost matrix, the same cost logsinkhorn gets from SqEuclidean.
    C = abs2.(x .- y')
    K = exp.(-C ./ ϵ)
    u = ones(length(x))
    v = ones(length(y))
    # Reused buffers for K * v and K' * u, so the loop does not allocate.
    Kv = similar(u)
    Ktu = similar(v)
    for _ in 1:T
        mul!(Kv, K, v)
        u .= a ./ Kv
        mul!(Ktu, K', u)
        v .= b ./ Ktu
    end
    return K, u, v
end
"""
    softmin(x, ϵ)

Soft minimum `-ϵ * log(∑ᵢ exp(-xᵢ / ϵ))` of the collection `x` at temperature `ϵ > 0`.

The value is computed with the log-sum-exp shift. The minimum `m` is factored out,
so every exponent is ≤ 0 and the sum is ≥ 1. The naive formula underflows to
`exp(-xᵢ / ϵ) == 0` for small `ϵ` and then returns `Inf`, which breaks the
log-domain Sinkhorn iterations that rely on this function. An empty `x` gives
`Inf`, as the naive formula does.
"""
function softmin(x, ϵ)
    isempty(x) && return typemax(float(eltype(x)))
    m = minimum(x)
    # -ϵ log ∑ exp(-x/ϵ) == m - ϵ log ∑ exp(-(x - m)/ϵ)
    return m - ϵ * log(sum(xi -> exp(-(xi - m) / ϵ), x))
end
"""
    S!(S, C, f, g)

Overwrite `S` in place with `S[i, j] = C[i, j] - f[i] - g[j]`, where `f` is the
potential on rows and `g` the potential on columns. Returns `S`.
"""
S!(S, C, f, g) = broadcast!((cij, fi, gj) -> cij - fi - gj, S, C, f, g')
"""
    logsinkhorn(x, y, a, b, ϵ, T = 100)

Log-domain (stabilised) Sinkhorn iterations for the same problem as `sinkhorn`.
Instead of the scalings, it iterates on the dual potentials `f` and `g`, using
`softmin` over the rows and then the columns of `S = C .- f .- g'`. Here `C` is the
squared Euclidean cost between `x` and `y`.

# Arguments
- `x`, `y`: support points (vectors of scalars).
- `a`, `b`: strictly positive weights on `x` and `y` (their logs are taken).
- `ϵ`: regularisation strength.
- `T`: number of iterations.

# Returns
`(K, u, v)` with `K = exp.(-C / ϵ)`, `u = exp.(f / ϵ)` and `v = exp.(g / ϵ)`. The
transport plan is `u .* K .* v'`, as for `sinkhorn`.
NOTE(review): for very small `ϵ`, `u`/`v` may overflow and `K` underflow even though
`f`, `g` stay finite, so that product can contain `NaN`. Confirm before relying on it
at the small end of the ϵ range.
"""
function logsinkhorn(x, y, a, b, ϵ, T = 100)
    # x' / y' are 1×n, so with dims = 2 each column is one 1-D point.
    C = pairwise(SqEuclidean(), x', y', dims = 2)
    K = dK(C, ϵ)
    f = zeros(length(x))
    g = zeros(length(y))
    S = similar(K)
    # The ϵ·log(weight) terms are the same in every iteration, so compute them once.
    ϵloga = ϵ * log.(a)
    ϵlogb = ϵ * log.(b)
    for _ in 1:T
        S!(S, C, f, g)
        f .+= softmin.(eachrow(S), ϵ) .+ ϵloga
        # Refresh S with the updated f before updating g.
        S!(S, C, f, g)
        g .+= softmin.(eachcol(S), ϵ) .+ ϵlogb
    end
    return K, exp.(f / ϵ), exp.(g / ϵ)
end
##
# --- Figure: marginals α, β and the regularised transport plan between them ---
# Observable holding the regularisation strength. Everything `lift`ed from it below
# is recomputed whenever `ϵ[]` is reassigned; the recording loop does exactly that.
ϵ = Node(0.1) # Regularisation parameter
# Title of the plan panel: log₁₀(ϵ) rounded to 3 digits.
ϵtitle = lift(ϵ) do ϵ
"π(α, β), log₁₀(ϵ) = $(round(log10(ϵ), digits = 3))"
end
# Transport plan for the current ϵ, re-solved with 1000 log-domain iterations on every
# change of ϵ: P[i, j] = u[i] * K[i, j] * v[j].
# NOTE(review): for the smallest ϵ values in the sweep, u and v (exp.(f/ϵ), exp.(g/ϵ))
# may overflow while K underflows, making this product NaN/Inf. Verify.
P = lift(ϵ) do ϵ
K, u, v = logsinkhorn(x, y, a, b, ϵ, 1000)
v' .* K .* u
end
# One-off solve at the initial ϵ with 200 iterations.
# NOTE(review): these global K, u, v are not used elsewhere in this script; presumably
# kept for interactive inspection. Confirm or remove.
K, u, v = logsinkhorn(x, y, a, b, ϵ[], 200)
# Scene with a pixel camera and a 2×2 layout. Column 2 and row 1 each take 25%,
# leaving the large bottom-left cell for the plan; the top-right cell stays empty.
scene = Scene(resolution = (1200, 675), camera=campixel!)
layout = GridLayout(
scene, 2, 2,
colsizes = [Auto(), Relative(0.25)],
rowsizes = [Relative(0.25), Auto()],
alignmode = Outside(20, 20, 20, 20)
)
# Top-left panel: discretised source weights a over the grid x.
palpha = layout[1,1] = LAxis(scene, title = "p(α)")
# NOTE(review): `resolution` is normally a Scene-level attribute. On this plot! call it
# is probably ignored. Confirm.
plot!(palpha, x, a, linewidth = 3.0, color = :blue, resolution = (500, 300))
hidexdecorations!(palpha)
xlims!(palpha, extrema(x))
# Bottom-right panel: target weights b drawn sideways (b horizontal, y vertical), so
# the y grid lines up with the plan's vertical axis.
pbeta = layout[2,2] = LAxis(scene, title = "p(β)")
plot!(pbeta, b, y, linewidth = 3.0, color = :red)
hideydecorations!(pbeta)
ylims!(pbeta, extrema(y))
pbeta.xticks = MakieLayout.LinearTicks(3)
# Bottom-left panel: the plan, with a title that follows ϵ. Its x-axis is linked to
# the α panel and its y-axis to the β panel.
pπ = layout[2,1] = LAxis(scene, title = ϵtitle)
linkxaxes!(pπ, palpha)
linkyaxes!(pπ, pbeta)
update_limits!(scene)
# Filled contour of P. P is an observable, so the contour redraws when ϵ changes.
contour!(pπ, x, y, P, fillrange = true, linewidth = 0.0)
# Snapshot of the figure at the initial ϵ, written next to this script.
save(joinpath(@__DIR__, "evolution regularized_kantorovich.png"), scene)
##
# Sweep ϵ over 10^-3.3 … 10^2 in steps of 0.1 in log₁₀ space, writing one GIF frame
# per value. Assigning ϵ[] re-triggers the lifted plan and title.
rangeϵ = -3.3:0.1:2
# Alternative sweep (down and back up): rangeϵ = vcat(-4:0.1:2, 2:-0.1:-2)
record(scene, joinpath(@__DIR__, "evolution regularized_kantorovich.gif"), rangeϵ;
       framerate = 10) do log10ϵ
    ϵ[] = 10.0^log10ϵ
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment