Created July 10, 2018 15:57
-
-
Save rcurtin/c29a7a513ac070aa407a000aa5a7a80b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/mlpack/methods/lmnn/lmnn_function_impl.hpp b/src/mlpack/methods/lmnn/lmnn_function_impl.hpp | |
index b9fb88ad6..5de8e8279 100644 | |
--- a/src/mlpack/methods/lmnn/lmnn_function_impl.hpp | |
+++ b/src/mlpack/methods/lmnn/lmnn_function_impl.hpp | |
@@ -193,6 +193,8 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation, | |
begin, batchSize); | |
} | |
+ size_t prunes = 0; | |
+ size_t checks = 0; | |
for (size_t i = begin; i < begin + batchSize; i++) | |
{ | |
for (size_t j = 0; j < k ; j++) | |
@@ -211,6 +213,7 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation, | |
// Calculate cost due to {data point, target neighbors, impostors} | |
// triplets. | |
double eval = 0; | |
+ ++checks; | |
if (transformationOld.n_elem != 0) | |
{ | |
eval = evalOld[i] + transformationDiff * (norm(targetNeighbors(j, i)) + | |
@@ -232,6 +235,10 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation, | |
transformedDataset.col(impostors(l, i))); | |
} | |
} | |
+ else | |
+ { | |
+ ++prunes; | |
+ } | |
} | |
else | |
{ | |
@@ -256,6 +263,9 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation, | |
} | |
} | |
+ std::cout << "Evaluate(): " << prunes << " pruned of " << checks << ", " | |
+ << "transformation diff " << transformationDiff << ".\n"; | |
+ | |
// Update cache transformation matrix. | |
transformationOld = transformation; | |
@@ -373,6 +383,7 @@ double LMNNFunction<MetricType>::EvaluateWithGradient( | |
const arma::mat& transformation, | |
GradType& gradient) | |
{ | |
+ Timer::Start("lmnn_function_evaluate_with_gradient"); | |
double cost = 0; | |
// Apply metric over dataset. | |
@@ -399,6 +410,12 @@ double LMNNFunction<MetricType>::EvaluateWithGradient( | |
// Calculate gradient due to impostors. | |
arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows); | |
+ size_t prunes = 0; | |
+ size_t checks = 0; | |
+ size_t active = 0; | |
+ size_t close = 0; | |
+ size_t inactive = 0; | |
+ Timer::Start("lmnn_function_compute_evals"); | |
for (size_t i = 0; i < dataset.n_cols; i++) | |
{ | |
for (size_t j = 0; j < k ; j++) | |
@@ -417,10 +434,15 @@ double LMNNFunction<MetricType>::EvaluateWithGradient( | |
// Calculate cost due to {data point, target neighbors, impostors} | |
// triplets. | |
double eval = 0; | |
+ ++checks; | |
if (transformationOld.n_elem != 0) | |
{ | |
eval = evalOld[i] + transformationDiff * (norm(targetNeighbors(j, i)) + | |
norm(impostors(l, i)) + 2 * norm(i)); | |
+// std::cout << "eval bound of point " << i << " with l " << l << ": " << | |
+//eval << "; tdiff " << transformationDiff << " norm target " << | |
+//norm(targetNeighbors(j, i)) << ", norm impostors " << norm(impostors(l, i)) << | |
+//", norm i " << norm(i) << "; evalOld " << evalOld[i] << "\n"; | |
if (eval > -1) | |
{ | |
// Calculate exact eval. | |
@@ -438,6 +460,10 @@ double LMNNFunction<MetricType>::EvaluateWithGradient( | |
transformedDataset.col(impostors(l, i))); | |
} | |
} | |
+ else | |
+ { | |
+ ++prunes; | |
+ } | |
} | |
else | |
{ | |
@@ -446,16 +472,21 @@ double LMNNFunction<MetricType>::EvaluateWithGradient( | |
distance(l, i); | |
} | |
+ // Update cache eval value. | |
+ evalOld[i] = eval; | |
+ | |
// Check bounding condition. | |
if (eval <= -1) | |
{ | |
+ if (eval < -2) | |
+ ++inactive; | |
+ else | |
+ ++close; | |
// update bound. | |
bp = l; | |
break; | |
} | |
- | |
- // Update cache eval value. | |
- evalOld[i] = eval; | |
+ ++active; | |
cost += regularization * (1 + eval); | |
@@ -468,12 +499,18 @@ double LMNNFunction<MetricType>::EvaluateWithGradient( | |
} | |
} | |
} | |
+ Timer::Stop("lmnn_function_compute_evals"); | |
gradient = 2 * transformation * ((1 - regularization) * cij + | |
regularization * cil); | |
+ std::cout << "Evaluate(): " << prunes << " pruned of " << checks << ", " | |
+ << "transformation diff " << transformationDiff << ". Active: " << active | |
+ << ", close " << close << ", inactive " << inactive << ".\n"; | |
+ | |
// Update cache transformation matrix. | |
transformationOld = transformation; | |
+ Timer::Stop("lmnn_function_evaluate_with_gradient"); | |
return cost; | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.