Created
December 13, 2020 09:39
-
-
Save msakai/4ded370ca75bbab8f02c8ec22bdadf11 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
module OptimalTransport (computeOptimalTransport) where | |
import qualified Data.Vector.Generic as VG | |
import Numeric.LinearAlgebra ((<.>), (#>), (<#), (><)) | |
import qualified Numeric.LinearAlgebra as LA | |
-- | Solve the entropy-regularized optimal transport problem
--
-- \[
--   \min_{P\in U(\mathbf{r}, \mathbf{c})} \sum_{i,j} P_{i,j} M_{i,j} - \frac{1}{\lambda}h(P)
-- \]
--
-- where
-- \(U(\mathbf{r}, \mathbf{c}) = \{P \in \mathbb{R}_{>0}^{n\times m}\mid P\mathbf{1}_m = \mathbf{r}, P^\intercal\mathbf{1}_n = \mathbf{c}\}\)
-- and \(h(P) = -\sum_{i,j}P_{i,j}\log P_{i,j}\) is the information entropy of \(P\).
--
-- The solution is computed with Sinkhorn's iteration. The algorithm starts
-- from \(P \propto \exp(-\lambda M)\), normalized so that its entries sum
-- to 1. It then repeatedly rescales the rows of \(P\) to have sums
-- \(\mathbf{r}\), followed by the columns to have sums \(\mathbf{c}\).
-- Iteration stops once no row sum of \(P\) changes by more than
-- \(\epsilon\) between two successive steps.
--
-- If a row or column of \(P\) sums to zero, the rescaling produces NaN.
-- This can happen, for example, when \(\exp(-\lambda M)\) underflows for
-- a large \(\lambda\). In that case iteration stops, and the returned
-- result is not meaningful (typically NaN) rather than the function
-- looping forever.
--
-- NOTE: there is no iteration cap, so \(\epsilon\) must be attainable in
-- floating-point arithmetic.
--
-- See https://michielstock.github.io/OptimalTransport/ for details.
computeOptimalTransport
  :: LA.Matrix Double -- ^ Cost matrix \(M \in \mathbb{R}^{n\times m}\)
  -> LA.Vector Double -- ^ Marginals vector \(\mathbf{r} \in \mathbb{R}^n\) (target row sums)
  -> LA.Vector Double -- ^ Marginals vector \(\mathbf{c} \in \mathbb{R}^m\) (target column sums)
  -> Double           -- ^ Strength of the entropic regularization \(\lambda\)
  -> Double           -- ^ Convergence parameter \(\epsilon\): largest allowed change of any row sum between steps
  -> (LA.Matrix Double, Double)
     -- ^ Optimal transport matrix \(P^\star \in U(\mathbf{r}, \mathbf{c})\),
     -- together with the dual-Sinkhorn divergence
     -- \(\sum_{i,j} P^\star_{i,j} M_{i,j}\)
computeOptimalTransport mat r c lam eps = go p0 u0
  where
    (n, m) = LA.size mat
    onesN = VG.replicate n 1
    onesM = VG.replicate m 1

    -- Initial plan: elementwise exp(-lam * M), normalized to total mass 1.
    p0 = let k = LA.cmap exp (LA.scale (negate lam) mat)
         in LA.scale (1 / LA.sumElements k) k

    -- "Previous" row sums for the first convergence check.
    u0 = VG.replicate n 0

    -- One Sinkhorn step per call. @u@ holds the row sums computed by the
    -- previous call and is used only for the convergence test.
    go :: LA.Matrix Double -> LA.Vector Double -> (LA.Matrix Double, Double)
    go p u
      | converged || hasNaN = (p, LA.flatten mat <.> LA.flatten p)
      | otherwise = go p'' r'
      where
        r' = p #> onesM
        diffs = VG.zipWith (\x y -> abs (x - y)) u r'
        -- 'VG.all' is total (True on an empty vector, i.e. n == 0),
        -- unlike 'VG.maximum', which errors on empty input.
        converged = VG.all (<= eps) diffs
        -- Every comparison with NaN is False. Without this check a NaN
        -- would keep 'converged' False forever and 'go' would never return.
        hasNaN = VG.any isNaN diffs
        -- Scale row i by r_i / r'_i: the outer product replicates the
        -- factor across each row, and (*) on matrices is elementwise.
        p' = LA.outer (VG.zipWith (/) r r') onesM * p
        c' = onesN <# p'
        -- Then scale column j by c_j / c'_j.
        p'' = p' * LA.outer onesN (VG.zipWith (/) c c')
-- | Example: solve an 8x5 transport problem for \(\lambda = 10\) and
-- \(\lambda = 1\).
--
-- The marginals are balanced: both @r@ and @c@ sum to 20. The literal
-- matrix is negated before being used as the cost, so larger entries in
-- the literal are cheaper to transport along.
--
-- NOTE(review): the entries look like preference scores (possibly from
-- the referenced tutorial); confirm. This binding is not exported, so
-- @-Wall@ still reports it as unused.
test :: ((LA.Matrix Double, Double), (LA.Matrix Double, Double))
test = (solve 10, solve 1)
  where
    solve lam = computeOptimalTransport mat r c lam eps
    mat = LA.scale (-1) $
      (8 >< 5)
        [ 2.0, 2, 1, 0, 0
        , 0.0, -2, -2, -2, 2
        , 1.0, 2, 2, 2, -1
        , 2.0, 1, 0, 1, -1
        , 0.5, 2, 2, 1, 0
        , 0.0, 1, 1, 1, -1
        , -2.0, 2, 2, 1, 1
        , 2.0, 1, 2, 1, -1
        ]
    -- Target row sums (sum = 20).
    r = VG.fromList [3, 3, 3, 4, 2, 2, 2, 1]
    -- Target column sums (sum = 20).
    c = VG.fromList [4, 2, 6, 4, 4]
    eps = 1e-8
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment