{-# OPTIONS_GHC -Wall #-}
module OptimalTransport (computeOptimalTransport) where

import qualified Data.Vector.Generic as VG
import Numeric.LinearAlgebra ((<.>), (#>), (<#), (><))
import qualified Numeric.LinearAlgebra as LA

-- | Solve the entropy-regularized optimal transport problem:
--
-- \[
-- \min_{P\in U(\mathbf{r}, \mathbf{c})} \sum_{i,j} P_{i,j} M_{i,j} - \frac{1}{\lambda}h(P)
-- \]
--
-- where
-- \(U(\mathbf{r}, \mathbf{c}) = \{P \in \mathbb{R}_{>0}^{n\times m}\mid P\mathbf{1}_m = \mathbf{r}, P^\intercal\mathbf{1}_n = \mathbf{c}\}\)
-- and \(h(P) = -\sum_{i,j}P_{i,j}\log P_{i,j}\) is the information entropy of \(P\).
--
-- The iteration starts from \(\exp(-\lambda M)\) normalized to unit total
-- mass, then alternately rescales the rows so that they sum to
-- \(\mathbf{r}\) and the columns so that they sum to \(\mathbf{c}\). It
-- stops once the row sums change by at most \(\epsilon\) (in the maximum
-- norm) between two consecutive iterations.
--
-- If a NaN shows up in the row sums (for example because
-- \(\exp(-\lambda M)\) underflowed to zero), the iteration stops
-- immediately and returns the current, NaN-contaminated, matrix instead
-- of looping forever.
--
-- \(\mathbf{r}\) is expected to have length \(n\) and \(\mathbf{c}\)
-- length \(m\); the lengths are not validated here.
--
-- See <https://michielstock.github.io/OptimalTransport/> for detail.
computeOptimalTransport
  :: LA.Matrix Double -- ^ Cost matrix \(M \in \mathbb{R}^{n\times m}\)
  -> LA.Vector Double -- ^ Marginals vector \(\mathbf{r} \in \mathbb{R}^n\)
  -> LA.Vector Double -- ^ Marginals vector \(\mathbf{c} \in \mathbb{R}^m\)
  -> Double -- ^ Strength of the entropic regularization \(\lambda\)
  -> Double -- ^ Convergence parameter \(\epsilon\)
  -> (LA.Matrix Double, Double)
  -- ^ Optimal transport matrix \(P^\star \in U(\mathbf{r}, \mathbf{c})\),
  -- together with dual-Sinkhorn divergence
  -- (\(\sum_{i,j} P^\star_{i,j} M_{i,j}\))
computeOptimalTransport mat r c lam eps = go p0 u0
  where
    (n, m) = LA.size mat
    onesN = VG.replicate n 1
    onesM = VG.replicate m 1

    -- Initial plan: element-wise exp(-lam * M), scaled to total mass 1.
    p0 = LA.scale (1 / LA.sumElements kernel) kernel
      where
        kernel = LA.cmap exp (LA.scale (negate lam) mat)

    -- "Previous" row sums for the very first convergence check.
    u0 = VG.replicate n 0

    -- One iteration: @p@ is the current plan, @u@ the row sums seen in the
    -- previous iteration.
    go :: LA.Matrix Double -> LA.Vector Double -> (LA.Matrix Double, Double)
    go p u
      -- @NaN <= eps@ is False, so NaN has to be tested explicitly or the
      -- recursion would never terminate once the plan degenerates.
      | VG.any isNaN diffs || err <= eps = (p, LA.flatten mat <.> LA.flatten p)
      | otherwise = go p'' r'
      where
        -- Current row sums.
        r' = p #> onesM
        diffs = VG.zipWith (\x y -> abs (x - y)) u r'
        -- 'VG.maximum' is partial on empty vectors.
        err
          | VG.null diffs = 0
          | otherwise = VG.maximum diffs
        -- Scale row i by r_i / r'_i; the one-column matrix is expanded
        -- across all columns by hmatrix's element-wise product.
        p' = LA.asColumn (VG.zipWith (/) r r') * p
        -- Column sums after the row rescaling.
        c' = onesN <# p'
        -- Scale column j by c_j / c'_j.
        p'' = p' * LA.asRow (VG.zipWith (/) c c')

-- | Example run of 'computeOptimalTransport' on a fixed 8x5 problem with
-- \(\lambda = 10\) and \(\lambda = 1\). Not exported; handy from GHCi.
test :: ((LA.Matrix Double, Double), (LA.Matrix Double, Double))
test = (f 10, f 1)
  where
    f lam = computeOptimalTransport mat r c lam eps
    -- The literal matrix is negated before being used as the cost matrix.
    mat =
      LA.scale (-1) $
        (8 >< 5)
          [ 2.0, 2, 1, 0, 0
          , 0.0, -2, -2, -2, 2
          , 1.0, 2, 2, 2, -1
          , 2.0, 1, 0, 1, -1
          , 0.5, 2, 2, 1, 0
          , 0.0, 1, 1, 1, -1
          , -2.0, 2, 2, 1, 1
          , 2.0, 1, 2, 1, -1
          ]
    r = VG.fromList [3, 3, 3, 4, 2, 2, 2, 1]
    c = VG.fromList [4, 2, 6, 4, 4]
    eps = 1e-8