Skip to content

Instantly share code, notes, and snippets.

@msakai
Created December 13, 2020 09:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save msakai/4ded370ca75bbab8f02c8ec22bdadf11 to your computer and use it in GitHub Desktop.
Save msakai/4ded370ca75bbab8f02c8ec22bdadf11 to your computer and use it in GitHub Desktop.
{-# OPTIONS_GHC -Wall #-}
module OptimalTransport (computeOptimalTransport) where
import qualified Data.Vector.Generic as VG
import Numeric.LinearAlgebra ((<.>), (#>), (<#), (><))
import qualified Numeric.LinearAlgebra as LA
-- | Solve the entropy-regularized optimal transport problem
--
-- \[
-- \min_{P\in U(\mathbf{r}, \mathbf{c})} \sum_{i,j} P_{i,j} M_{i,j} - \frac{1}{\lambda}h(P)
-- \]
--
-- where
-- \(U(\mathbf{r}, \mathbf{c}) = \{P \in \mathbb{R}_{>0}^{n\times m}\mid P\mathbf{1}_m = \mathbf{r}, P^\intercal\mathbf{1}_n = \mathbf{c}\}\)
-- and \(h(P) = -\sum_{i,j}P_{i,j}\log P_{i,j}\) is the information entropy of \(P\).
--
-- The problem is solved with Sinkhorn's algorithm. Starting from
-- \(P \propto e^{-\lambda M}\), each iteration first rescales the rows of \(P\)
-- so they sum to \(\mathbf{r}\), then rescales its columns so they sum to
-- \(\mathbf{c}\). Iteration stops once no row sum changes by more than
-- \(\epsilon\) between two consecutive iterates.
--
-- NOTE(review): the marginals are presumably expected to be balanced
-- (\(\sum_i r_i = \sum_j c_j\)); this is not checked.
--
-- Calls 'error' instead of looping forever if the iteration produces @NaN@.
-- This can happen, for example, when \(\lambda\) is so large that a whole
-- row or column of \(e^{-\lambda M}\) underflows to zero.
--
-- See https://michielstock.github.io/OptimalTransport/ for details.
computeOptimalTransport
  :: LA.Matrix Double -- ^ Cost matrix \(M \in \mathbb{R}^{n\times m}\)
  -> LA.Vector Double -- ^ Row marginals \(\mathbf{r} \in \mathbb{R}^n\)
  -> LA.Vector Double -- ^ Column marginals \(\mathbf{c} \in \mathbb{R}^m\)
  -> Double -- ^ Strength of the entropic regularization \(\lambda\)
  -> Double -- ^ Convergence parameter \(\epsilon\) (tolerance on the change of row sums)
  -> (LA.Matrix Double, Double) -- ^ Optimal transport matrix \(P^\star \in U(\mathbf{r}, \mathbf{c})\) , together with dual-Sinkhorn divergence ( \(\sum_{i,j} P^\star_{i,j} M_{i,j}\) )
computeOptimalTransport mat r c lam eps = go p0 u0
  where
    (n, m) = LA.size mat
    onesN = VG.replicate n 1
    onesM = VG.replicate m 1

    -- Initial plan P0 = exp(-lam * M) / sum(exp(-lam * M)).
    -- The logits are shifted by their maximum before exponentiating. The
    -- shift is a constant factor that the normalisation cancels exactly, but
    -- it keeps 'exp' from overflowing to Infinity. Otherwise Inf/Inf would
    -- yield NaN when lam * M is very negative.
    logits = LA.scale (negate lam) mat
    flatLogits = LA.flatten logits
    shift
      | VG.null flatLogits = 0
      | otherwise = VG.maximum flatLogits
    unnormalized = LA.cmap (\x -> exp (x - shift)) logits
    p0 = LA.scale (1 / LA.sumElements unnormalized) unnormalized

    -- Row sums of the "previous" iterate. They start at zero, so the first
    -- convergence check compares against the row sums of P0.
    u0 = VG.replicate n 0

    -- One Sinkhorn step per call. @u@ holds the row sums of the previous iterate.
    go :: LA.Matrix Double -> LA.Vector Double -> (LA.Matrix Double, Double)
    go p u
      -- NaN never satisfies @err <= eps@, so without this guard a numerical
      -- breakdown would make the recursion run forever.
      | VG.any isNaN rowSums =
          error "OptimalTransport.computeOptimalTransport: Sinkhorn iteration produced NaN (try a smaller lambda)"
      | err <= eps = (p, LA.flatten mat <.> LA.flatten p)
      | otherwise = go pCols rowSums
      where
        rowSums = p #> onesM
        err = VG.maximum (VG.zipWith (\x y -> abs (x - y)) u rowSums)
        -- Scale row i by r_i / rowSums_i so that the rows sum to r.
        -- hmatrix's element-wise (*) expands the n x 1 factor to n x m.
        pRows = LA.asColumn (VG.zipWith (/) r rowSums) * p
        -- Then scale column j by c_j / colSums_j so that the columns sum to c.
        colSums = onesN <# pRows
        pCols = LA.asRow (VG.zipWith (/) c colSums) * pRows
-- | Example run of 'computeOptimalTransport', solved with two regularization
-- strengths: \(\lambda = 10\) (first component) and \(\lambda = 1\) (second
-- component).
--
-- The cost matrix is the negation of the 8×5 score matrix listed below, so
-- higher scores mean lower transport cost. The row marginals and the column
-- marginals both sum to 20, so the problem is balanced.
--
-- NOTE(review): this binding is not in the module's export list, so @-Wall@
-- reports it as unused. Consider exporting it or moving it to a test suite.
test :: ((LA.Matrix Double, Double), (LA.Matrix Double, Double))
test = (solveWith 10, solveWith 1)
  where
    solveWith :: Double -> (LA.Matrix Double, Double)
    solveWith lam = computeOptimalTransport mat r c lam eps

    mat :: LA.Matrix Double
    mat = LA.scale (-1) $
      (8 >< 5)
        [ 2.0, 2, 1, 0, 0
        , 0.0, -2, -2, -2, 2
        , 1.0, 2, 2, 2, -1
        , 2.0, 1, 0, 1, -1
        , 0.5, 2, 2, 1, 0
        , 0.0, 1, 1, 1, -1
        , -2.0, 2, 2, 1, 1
        , 2.0, 1, 2, 1, -1
        ]

    r :: LA.Vector Double
    r = VG.fromList [3, 3, 3, 4, 2, 2, 2, 1]

    c :: LA.Vector Double
    c = VG.fromList [4, 2, 6, 4, 4]

    eps :: Double
    eps = 1e-8
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment