// Eigen Sinkhorn
// GitHub gist fangzhou-xie/0918123ba17897399ceb0608b921f645 by @fangzhou-xie
// (last active September 22, 2021).
//
// Implements the Sinkhorn algorithm with Eigen so that its result can be
// compared against the POT (Python Optimal Transport) library.
#include "Eigen/Eigen"
// #include "unsupported/Eigen/MatrixFunctions"
#include <iostream>
#include <stdexcept>
void sinkhorn(Eigen::VectorXd a, Eigen::VectorXd b, Eigen::MatrixXd M,
const double reg, const int numItermax = 1000,
const double stopThr = 1e-9) {
// only compute 1d to 1d case
// init u, v
Eigen::VectorXd u(a.rows());
Eigen::VectorXd v(b.rows());
// Eigen::MatrixXd K(M.rows(), M.cols());
// Eigen::VectorXd uprev;
// Eigen::VectorXd vprev;
u.setOnes();
v.setOnes();
u = u / u.rows();
v = v / b.rows();
Eigen::MatrixXd K = (M / (-reg)).array().exp();
Eigen::MatrixXd Kp = K.array().colwise() * a.array().cwiseInverse();
unsigned int cpt = 0;
double err = 1.0;
Eigen::VectorXd temp1(v.rows());
temp1.setOnes();
while (err > stopThr & cpt < numItermax) {
Eigen::VectorXd uprev = u;
Eigen::VectorXd vprev = v;
Eigen::MatrixXd KtransposeU = K.transpose() * u; // n*1
// std::cout << "shape of KtransposeU:" << KtransposeU.rows() << " "
// << KtransposeU.cols() << std::endl;
v = b.cwiseQuotient(K.transpose() * u);
// std::cout << "shape of v:" << v.rows() << " " << v.cols() << std::endl;
// std::cout << a.cwiseInverse() << std::endl;
// v = b / KtransposeU;
// v = KtransposeU.cwiseInverse() * b; // need broadcast here
// TODO: this is a broadcast multiply
// Eigen::MatrixXd Kp = K * a.cwiseInverse();
// Kp.array().colwise() /
// Kp.colwise() *= a.cwiseInverse();
// = a.cwiseInverse() * K.colwise();
u = (Kp * v).cwiseInverse();
if (u.array().isNaN().any() | v.array().isNaN().any() |
u.array().isInf().any() | v.array().isInf().any() |
(KtransposeU.array() == 0.0).any()) {
std::cout << "numerical error" << std::endl;
u = uprev;
v = vprev;
break;
}
// check error
if (cpt % 10 == 0) {
Eigen::VectorXd temp =
(u.asDiagonal() * K * v.asDiagonal()).transpose() * temp1;
err = (temp - b).norm();
}
cpt += 1;
}
// u.resize(u.rows(), 1);
// v.resize(1, v.rows());
Eigen::MatrixXd mat =
(K.array().colwise() * u.array()).rowwise() * v.array().transpose();
// Eigen::MatrixXd mat = (u.array() * K.array().colwise()).array().rowwise() *
// v.array().transpose();
std::cout << mat << std::endl;
// Eigen::MatrixXd mat = u.resize(u.rows(), 1) * K * v.resize(1, v.cols());
}
int main() {
Eigen::VectorXd a(2);
a << 0.5, 0.5;
Eigen::VectorXd b(2);
b << 0.5, 0.5;
Eigen::MatrixXd M(2, 2);
M << 0.0, 1.0, 1.0, 0.0;
double reg = 1.0;
// sinkhorn(a, b, M, reg);
// try out broadcasting
// std::cout << M.array().colwise() * a.array() << std::endl;
// std::cout << a.array().transpose() * M.array().colwise() << std::endl;
// try the elementwise power
// Eigen::VectorXd c = a.array().pow(b.array());
// std::cout << c << std::endl;
// Eigen::MatrixXd K = (M / (-reg)).array().exp();
// std::cout << K << std::endl;
// std::cout << a << std::endl;
// a.resize(1, 2);
// std::cout << a << std::endl;
// Eigen::MatrixXd b = a.cwiseInverse();
// std::cout << b << std::endl;
return 0;
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment