Skip to content

Instantly share code, notes, and snippets.

@manish7294
Last active July 18, 2018 06:22
Show Gist options
  • Save manish7294/3d97be37919658b96bba0125f2f3de84 to your computer and use it in GitHub Desktop.
Save manish7294/3d97be37919658b96bba0125f2f3de84 to your computer and use it in GitHub Desktop.
double innerProduct(arma::mat& Ar, arma::mat& Z)
{
double sum = 0.0;
for (size_t i = 0; i < Z.n_elem; i++)
sum += Ar(i) * Z(i);
return sum;
}
/**
 * Learn a distance transformation with the boosting procedure implemented
 * below and store it in outputMatrix.
 *
 * The routine builds one matrix Ar per triplet returned by
 * constraint.Triplets(), then repeatedly (at most maxIter times):
 *   - forms the weighted sum Acap of the Ar matrices,
 *   - takes the eigenvector of Acap with the largest eigenvalue and builds the
 *     rank-one matrix Z from it,
 *   - finds a step size w by bisection on [0, 10],
 *   - re-weights the triplets and accumulates X += w * Z.
 * The loop stops early once the largest eigenvalue of Acap drops below
 * `tolerance`.
 *
 * Finally X is eigendecomposed as V * diag(lambda) * V^T and outputMatrix is
 * set to trans(V * diag(sqrt(max(lambda, 0)))), so that
 * trans(outputMatrix) * outputMatrix reproduces X with negative eigenvalues
 * clipped to zero.
 *
 * The whole call is timed under the "boostmetric" timer.
 *
 * @param outputMatrix Overwritten with the learned transformation.
 */
template<typename MetricType>
void BOOSTMETRIC<MetricType>::LearnDistance(arma::mat& outputMatrix)
{
  Timer::Start("boostmetric");

  // Temporary variables. Values taken from implementation.
  const size_t maxIter = 500;
  const double tolerance = 1e-7;
  const double wTolerance = 1e-5;

  // One column per triplet; rows 0, 1 and 2 index columns of dataset.
  arma::Mat<size_t> triplets;
  constraint.Triplets(triplets, dataset, labels);

  const size_t N = triplets.n_cols;
  const size_t dim = dataset.n_rows;

  // Start from uniform triplet weights, normalized so that sum(ui) = 1.
  arma::vec ui;
  ui.ones(N);
  ui = ui / N;

  arma::cube Ar(dim, dim, N);

  // Initialize X as zero matrix.
  arma::mat X;
  X.zeros(dim, dim);

  // Initialize Ar. See Equation 2.
  // NOTE(review): presumably row 1 of `triplets` is the same-class neighbor and
  // row 2 the differently-labeled point -- confirm against
  // constraint.Triplets().
  for (size_t i = 0; i < N; i++)
  {
    // Compute each difference vector once instead of twice per outer product.
    const arma::vec diff02 =
        dataset.col(triplets(0, i)) - dataset.col(triplets(2, i));
    const arma::vec diff01 =
        dataset.col(triplets(0, i)) - dataset.col(triplets(1, i));

    Ar.slice(i) = diff02 * arma::trans(diff02) - diff01 * arma::trans(diff01);
  }

  // Used in equation 10.
  arma::mat Acap;
  for (size_t i = 0; i < maxIter; i++)
  {
    // Initialize Acap as the ui-weighted sum of all Ar slices.
    Acap.zeros(dim, dim);
    for (size_t j = 0; j < N; j++)
      Acap += ui(j) * Ar.slice(j);

    arma::vec eigval;
    arma::mat eigvec;
    // Symmetrize before the decomposition.  The division must apply to the
    // whole sum: `Acap + trans(Acap) / 2` would yield 1.5 * Acap for a
    // symmetric Acap and inflate every eigenvalue used in the stopping test.
    arma::eig_sym(eigval, eigvec, (Acap + arma::trans(Acap)) / 2);

    // Get the maximum eigenvalue together with its index, avoiding a
    // floating-point equality search for the position.
    arma::uword maxIndex = 0;
    const double maxEig = eigval.max(maxIndex);
    if (maxEig < tolerance)
      break;

    // Rank-one matrix built from the leading eigenvector.
    const arma::vec leading = eigvec.col(maxIndex);
    arma::mat Z = leading * arma::trans(leading);

    // Values taken from matlab implementation.
    double wHi = 10;
    double wLo = 0;

    // Initialize H vector by computing inner product. Hr = <Ar, Z>
    arma::vec H(N);
    for (size_t j = 0; j < N; j++)
      H(j) = innerProduct(Ar.slice(j), Z);

    // Binary search for w.  The interval halves on every pass, so the loop
    // ends once it is narrower than wTolerance (or earlier if |lhs| is small).
    double w;
    while (true)
    {
      w = (wHi + wLo) * 0.5;

      double lhs = 0;
      for (size_t j = 0; j < N; j++)
        lhs += (H(j) - tolerance) * std::exp(-w * H(j)) * ui(j);

      if (lhs > 0)
        wLo = w;
      else
        wHi = w;

      if (wHi - wLo < wTolerance || std::abs(lhs) < wTolerance)
        break;
    }

    // Update u from equation 9.
    for (size_t j = 0; j < N; j++)
      ui(j) = ui(j) * std::exp(-H(j) * w);

    // Normalize ui so that sum(ui) = 1.
    ui = ui / arma::sum(ui);

    // Update p.s.d matrix.
    X += w * Z;
  }

  // Generate distance from p.s.d matrix: take the square root of the
  // non-negative eigenvalues (negative ones are clipped to zero).
  arma::vec eigval;
  arma::mat eigvec;
  arma::eig_sym(eigval, eigvec, (X + arma::trans(X)) / 2);
  eigval.transform(
      [](double val) { return (val > 0) ? std::sqrt(val) : double(0); });
  outputMatrix = arma::trans(eigvec * arma::diagmat(eigval));

  Timer::Stop("boostmetric");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment