Skip to content

Instantly share code, notes, and snippets.

@rcurtin
Created July 10, 2018 15:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rcurtin/c29a7a513ac070aa407a000aa5a7a80b to your computer and use it in GitHub Desktop.
diff --git a/src/mlpack/methods/lmnn/lmnn_function_impl.hpp b/src/mlpack/methods/lmnn/lmnn_function_impl.hpp
index b9fb88ad6..5de8e8279 100644
--- a/src/mlpack/methods/lmnn/lmnn_function_impl.hpp
+++ b/src/mlpack/methods/lmnn/lmnn_function_impl.hpp
@@ -193,6 +193,8 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation,
begin, batchSize);
}
+ size_t prunes = 0;
+ size_t checks = 0;
for (size_t i = begin; i < begin + batchSize; i++)
{
for (size_t j = 0; j < k ; j++)
@@ -211,6 +213,7 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation,
// Calculate cost due to {data point, target neighbors, impostors}
// triplets.
double eval = 0;
+ ++checks;
if (transformationOld.n_elem != 0)
{
eval = evalOld[i] + transformationDiff * (norm(targetNeighbors(j, i)) +
@@ -232,6 +235,10 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation,
transformedDataset.col(impostors(l, i)));
}
}
+ else
+ {
+ ++prunes;
+ }
}
else
{
@@ -256,6 +263,9 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation,
}
}
+ std::cout << "Evaluate(): " << prunes << " pruned of " << checks << ", "
+ << "transformation diff " << transformationDiff << ".\n";
+
// Update cache transformation matrix.
transformationOld = transformation;
@@ -373,6 +383,7 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
const arma::mat& transformation,
GradType& gradient)
{
+ Timer::Start("lmnn_function_evaluate_with_gradient");
double cost = 0;
// Apply metric over dataset.
@@ -399,6 +410,12 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
// Calculate gradient due to impostors.
arma::mat cil = arma::zeros(dataset.n_rows, dataset.n_rows);
+ size_t prunes = 0;
+ size_t checks = 0;
+ size_t active = 0;
+ size_t close = 0;
+ size_t inactive = 0;
+ Timer::Start("lmnn_function_compute_evals");
for (size_t i = 0; i < dataset.n_cols; i++)
{
for (size_t j = 0; j < k ; j++)
@@ -417,10 +434,15 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
// Calculate cost due to {data point, target neighbors, impostors}
// triplets.
double eval = 0;
+ ++checks;
if (transformationOld.n_elem != 0)
{
eval = evalOld[i] + transformationDiff * (norm(targetNeighbors(j, i)) +
norm(impostors(l, i)) + 2 * norm(i));
+// std::cout << "eval bound of point " << i << " with l " << l << ": " <<
+//eval << "; tdiff " << transformationDiff << " norm target " <<
+//norm(targetNeighbors(j, i)) << ", norm impostors " << norm(impostors(l, i)) <<
+//", norm i " << norm(i) << "; evalOld " << evalOld[i] << "\n";
if (eval > -1)
{
// Calculate exact eval.
@@ -438,6 +460,10 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
transformedDataset.col(impostors(l, i)));
}
}
+ else
+ {
+ ++prunes;
+ }
}
else
{
@@ -446,16 +472,21 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
distance(l, i);
}
+ // Update cache eval value.
+ evalOld[i] = eval;
+
// Check bounding condition.
if (eval <= -1)
{
+ if (eval < -2)
+ ++inactive;
+ else
+ ++close;
// update bound.
bp = l;
break;
}
-
- // Update cache eval value.
- evalOld[i] = eval;
+ ++active;
cost += regularization * (1 + eval);
@@ -468,12 +499,18 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
}
}
}
+ Timer::Stop("lmnn_function_compute_evals");
gradient = 2 * transformation * ((1 - regularization) * cij +
regularization * cil);
+ std::cout << "EvaluateWithGradient(): " << prunes << " pruned of " << checks
+ << ", transformation diff " << transformationDiff << ". Active: " << active
+ << ", close " << close << ", inactive " << inactive << ".\n";
+
// Update cache transformation matrix.
transformationOld = transformation;
+ Timer::Stop("lmnn_function_evaluate_with_gradient");
return cost;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment