Last active
July 18, 2018 06:22
-
-
Save manish7294/3d97be37919658b96bba0125f2f3de84 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
double innerProduct(arma::mat& Ar, arma::mat& Z) | |
{ | |
double sum = 0.0; | |
for (size_t i = 0; i < Z.n_elem; i++) | |
sum += Ar(i) * Z(i); | |
return sum; | |
} | |
template<typename MetricType> | |
void BOOSTMETRIC<MetricType>::LearnDistance(arma::mat& outputMatrix) | |
{ | |
Timer::Start("boostmetric"); | |
// Temporary variables. Values taken from implementation. | |
size_t maxIter = 500; | |
double tolerance = 1e-7; | |
double wTolerance = 1e-5; | |
arma::Mat<size_t> triplets; | |
constraint.Triplets(triplets, dataset, labels); | |
size_t N = triplets.n_cols; | |
size_t dim = dataset.n_rows; | |
arma::vec ui; | |
ui.ones(N); | |
// Normalize ui so that sum(ui) = 1 | |
ui = ui / N; | |
arma::cube Ar(dim, dim, N); | |
// Initialize X as zero matrix. | |
arma::mat X; | |
X.zeros(dim, dim); | |
// Initialize Ar. See Equation 2. | |
for (size_t i = 0; i < N; i++) | |
{ | |
Ar.slice(i) = (dataset.col(triplets(0, i)) - dataset.col(triplets(2, i))) * | |
arma::trans(dataset.col(triplets(0, i)) - dataset.col(triplets(2, i))) | |
- | |
(dataset.col(triplets(0, i)) - dataset.col(triplets(1, i))) * | |
arma::trans(dataset.col(triplets(0, i)) - dataset.col(triplets(1, i))); | |
} | |
// Used in equation 10. | |
arma::mat Acap; | |
for (size_t i = 0; i < maxIter; i++) | |
{ | |
// Initialize Acap. | |
Acap.zeros(dim, dim); | |
for (size_t j = 0; j < N; j++) | |
Acap += ui(j) * Ar.slice(j); | |
arma::vec eigval; | |
arma::mat eigvec; | |
arma::eig_sym(eigval, eigvec, (Acap + trans(Acap) / 2)); | |
// Get maximum eigvalue. | |
double maxEig = arma::max(eigval); | |
if (maxEig < tolerance) | |
break; | |
// Get index of maximum element. | |
arma::uvec maxIndex = arma::find(eigval == maxEig); | |
arma::mat Z = eigvec.col(maxIndex(0)) * arma::trans(eigvec.col(maxIndex(0))); | |
// Values taken from matlab implementation. | |
double wHi = 10; | |
double wLo = 0; | |
// Initialize H vector by computing inner product. Hr = <Ar, Z> | |
arma::vec H(N); | |
for (size_t j = 0; j < N; j++) | |
H(j) = innerProduct(Ar.slice(j), Z); | |
// Binary Search for w calculation. | |
double w; | |
while (true) | |
{ | |
w = (wHi + wLo) * 0.5; | |
double lhs = 0; | |
for (size_t j = 0; j < N; j++) | |
lhs += (H(j) - tolerance) * std::exp(-w * H(j)) * ui(j); | |
if (lhs > 0) | |
wLo = w; | |
else | |
wHi = w; | |
if (wHi - wLo < wTolerance || std::abs(lhs) < wTolerance) | |
break; | |
} | |
// Update u from equation 9. | |
for (size_t j = 0; j < N; j++) | |
ui(j) = ui(j) * std::exp(-H(j) * w); | |
// Normalize ui so that sum(ui) = 1. | |
ui = ui / arma::sum(ui); | |
// Update p.s.d matrix. | |
X += w * Z; | |
} | |
// Generate distance from p.s.d matrix. | |
arma::vec eigval; | |
arma::mat eigvec; | |
arma::eig_sym(eigval, eigvec, (X + arma::trans(X)) / 2); | |
eigval.transform( [](double val) { return (val > 0) ? std::sqrt(val) : double(0); } ); | |
outputMatrix = arma::trans(eigvec * diagmat(eigval)); | |
Timer::Stop("boostmetric"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment