Skip to content

Instantly share code, notes, and snippets.

@rcurtin
Created September 14, 2018 21:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rcurtin/ecde14f3f9e459e4fb5d756052738466 to your computer and use it in GitHub Desktop.
diff --git a/src/mlpack/methods/lmnn/lmnn_function.hpp b/src/mlpack/methods/lmnn/lmnn_function.hpp
index fab543e2c..ee2a5144b 100644
--- a/src/mlpack/methods/lmnn/lmnn_function.hpp
+++ b/src/mlpack/methods/lmnn/lmnn_function.hpp
@@ -257,7 +257,7 @@ class LMNNFunction
const size_t begin,
const size_t batchSize);
// Recalculate impostors.
- inline void ReCalculateImpostors(const arma::mat& transformedDataset,
+ inline bool ReCalculateImpostors(const arma::mat& transformedDataset,
double transformationDiff);
};
diff --git a/src/mlpack/methods/lmnn/lmnn_function_impl.hpp b/src/mlpack/methods/lmnn/lmnn_function_impl.hpp
index 21c35f532..00e8bc63c 100644
--- a/src/mlpack/methods/lmnn/lmnn_function_impl.hpp
+++ b/src/mlpack/methods/lmnn/lmnn_function_impl.hpp
@@ -89,7 +89,7 @@ LMNNFunction<MetricType>::LMNNFunction(const arma::mat& dataset,
}
constraint.TargetNeighbors(targetNeighbors, dataset, labels, norm);
- constraint.Impostors(impostors, dataset, labels, norm);
+ constraint.Impostors(impostors, distance, dataset, labels, norm);
// Precalculate and save the gradient due to target neighbors.
Precalculate();
@@ -218,7 +218,7 @@ inline void LMNNFunction<MetricType>::TransDiff(
// Recalculate impostors.
template<typename MetricType>
-inline void LMNNFunction<MetricType>::ReCalculateImpostors(
+inline bool LMNNFunction<MetricType>::ReCalculateImpostors(
const arma::mat& transformedDataset,
double transformationDiff)
{
@@ -248,13 +248,17 @@ inline void LMNNFunction<MetricType>::ReCalculateImpostors(
constraint.Impostors(impostors, distance, transformedDataset, labels,
norm);
}
+ return true;
}
else if (iteration++ % range == 0)
{
// Re-calculate impostors on transformed dataset.
constraint.Impostors(impostors, distance, transformedDataset, labels,
norm);
+ return true;
}
+
+ return false;
}
//! Evaluate cost over whole dataset.
@@ -273,7 +277,8 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation)
transformationDiff = arma::norm(transformation - transformationOld);
}
- ReCalculateImpostors(transformedDataset,transformationDiff);
+ bool didRecalculate = ReCalculateImpostors(transformedDataset,
+ transformationDiff);
for (size_t i = 0; i < dataset.n_cols; i++)
{
@@ -309,7 +314,7 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation)
// Calculate exact eval value.
if (eval > -1)
{
- if (iteration - 1 % range == 0)
+ if (didRecalculate)
{
eval = metric.Evaluate(transformedDataset.col(i),
transformedDataset.col(targetNeighbors(j, i))) -
@@ -369,6 +374,7 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation,
// Apply metric over dataset.
transformedDataset = transformation * dataset;
+ bool didRecalculate = false;
if (recalculate)
{
// Re-calculate impostors on transformed dataset.
@@ -376,6 +382,7 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation,
norm);
// Set recalculate to false.
recalculate = false;
+ didRecalculate = true;
}
for (size_t i = begin; i < begin + batchSize; i++)
@@ -412,7 +419,7 @@ double LMNNFunction<MetricType>::Evaluate(const arma::mat& transformation,
// Calculate exact eval value.
if (eval > -1)
{
- if (iteration - 1 % range == 0)
+ if (didRecalculate)
{
eval = metric.Evaluate(transformedDataset.col(i),
transformedDataset.col(targetNeighbors(j, i))) -
@@ -475,7 +482,8 @@ void LMNNFunction<MetricType>::Gradient(const arma::mat& transformation,
transformationDiff = arma::norm(transformation - transformationOld);
}
- ReCalculateImpostors(transformedDataset,transformationDiff);
+ bool didRecalculate = ReCalculateImpostors(transformedDataset,
+ transformationDiff);
gradient.zeros(transformation.n_rows, transformation.n_cols);
@@ -510,7 +518,7 @@ void LMNNFunction<MetricType>::Gradient(const arma::mat& transformation,
// Calculate exact eval value.
if (eval > -1)
{
- if (iteration - 1 % range == 0)
+ if (didRecalculate)
{
eval = metric.Evaluate(transformedDataset.col(i),
transformedDataset.col(targetNeighbors(j, i))) -
@@ -576,6 +584,7 @@ void LMNNFunction<MetricType>::Gradient(const arma::mat& transformation,
std::map<size_t, double> transformationDiffs;
TransDiff(transformationDiffs, transformation, begin, batchSize);
+ bool didRecalculate = false;
if (recalculate)
{
// Re-calculate impostors on transformed dataset.
@@ -583,6 +592,7 @@ void LMNNFunction<MetricType>::Gradient(const arma::mat& transformation,
norm);
// Set recalculate to false.
recalculate = false;
+ didRecalculate = true;
}
gradient.zeros(transformation.n_rows, transformation.n_cols);
@@ -622,7 +632,7 @@ void LMNNFunction<MetricType>::Gradient(const arma::mat& transformation,
// Calculate exact eval value.
if (eval > -1)
{
- if (iteration - 1 % range == 0)
+ if (didRecalculate)
{
eval = metric.Evaluate(transformedDataset.col(i),
transformedDataset.col(targetNeighbors(j, i))) -
@@ -694,7 +704,8 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
transformationDiff = arma::norm(transformation - transformationOld);
}
- ReCalculateImpostors(transformedDataset,transformationDiff);
+ bool didRecalculate = ReCalculateImpostors(transformedDataset,
+ transformationDiff);
gradient.zeros(transformation.n_rows, transformation.n_cols);
@@ -737,7 +748,7 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
// Calculate exact eval value.
if (eval > -1)
{
- if (iteration - 1 % range == 0)
+ if (didRecalculate)
{
eval = metric.Evaluate(transformedDataset.col(i),
transformedDataset.col(targetNeighbors(j, i))) -
@@ -802,6 +813,7 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
// Apply metric over dataset.
transformedDataset = transformation * dataset;
+ bool didRecalculate = false;
if (recalculate)
{
// Re-calculate impostors on transformed dataset.
@@ -809,6 +821,7 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
norm);
// Set recalculate to false.
recalculate = false;
+ didRecalculate = true;
}
gradient.zeros(transformation.n_rows, transformation.n_cols);
@@ -853,7 +866,7 @@ double LMNNFunction<MetricType>::EvaluateWithGradient(
// Calculate exact eval value.
if (eval > -1)
{
- if (iteration - 1 % range == 0)
+ if (didRecalculate)
{
eval = metric.Evaluate(transformedDataset.col(i),
transformedDataset.col(targetNeighbors(j, i))) -
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment