Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rcurtin/d8fd23ded2bda55277689265362c224f to your computer and use it in GitHub Desktop.
Save rcurtin/d8fd23ded2bda55277689265362c224f to your computer and use it in GitHub Desktop.
From 56aeb6499b17ffc5e4f82813d640a35ce7d5bb74 Mon Sep 17 00:00:00 2001
From: Ryan Curtin <ryan@ratml.org>
Date: Thu, 25 Jul 2019 20:27:28 -0400
Subject: [PATCH] Add move and copy operators to RectangleTree.
---
.../tree/rectangle_tree/rectangle_tree.hpp | 16 +++-
.../rectangle_tree/rectangle_tree_impl.hpp | 82 +++++++++++++++++++
.../neighbor_search/neighbor_search_impl.hpp | 7 +-
src/mlpack/tests/knn_test.cpp | 56 +++++++++++++
4 files changed, 157 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index a3093df17..014babec3 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -181,10 +181,24 @@ class RectangleTree
/**
* Create a rectangle tree by moving the other tree.
*
- * @param other The tree to be copied.
+ * @param other The tree to be moved.
*/
RectangleTree(RectangleTree&& other);
+ /**
+ * Deep-copy the given rectangle tree into this one.  The children and (if
+ * owned) the dataset currently held by this tree are freed first.  Afterwards
+ * this tree is a root node (it has no parent) that owns its own copy of the
+ * given tree's dataset, and each child of the given tree is copied as well.
+ *
+ * @param other The tree to be copied.
+ */
+ RectangleTree& operator=(const RectangleTree& other);
+
+ /**
+ * Take ownership of the given rectangle tree: its children, dataset pointer
+ * (together with its dataset-ownership flag), points, bound and auxiliary
+ * information are transferred to this tree.  The children and (if owned) the
+ * dataset currently held by this tree are freed first.
+ *
+ * @param other The tree to take ownership of.
+ */
+ RectangleTree& operator=(RectangleTree&& other);
+
/**
* Construct the tree from a boost::serialization archive.
*/
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 59a20e26f..aecc6dc07 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -234,6 +234,88 @@ RectangleTree(RectangleTree&& other) :
other.ownsDataset = false;
}
+/**
+ * Deep-copy the given tree into this one.  The current children and (if
+ * owned) dataset are released first; afterwards this node is a root (no
+ * parent) that owns a private copy of the dataset, and every child of `other`
+ * is copied beneath it.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType,
+         template<typename> class AuxiliaryInformationType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+              AuxiliaryInformationType>&
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+              AuxiliaryInformationType>::
+operator=(const RectangleTree& other)
+{
+  // Self-assignment would otherwise free the children and dataset below and
+  // then copy from the freed memory.
+  if (this == &other)
+    return *this;
+
+  // Copy the dataset before anything is released, so that a failure here
+  // (e.g. std::bad_alloc) leaves this tree untouched.
+  MatType* newDataset = new MatType(*other.dataset);
+
+  // Release everything currently held by this tree.
+  for (size_t i = 0; i < numChildren; i++)
+    delete children[i];
+  if (ownsDataset)
+    delete dataset;
+
+  maxNumChildren = other.MaxNumChildren();
+  minNumChildren = other.MinNumChildren();
+  numChildren = other.NumChildren();
+  // assign() rather than resize(): resize() would keep the just-deleted
+  // pointers in any slots that already existed.
+  children.assign(maxNumChildren + 1, NULL);
+  parent = NULL; // The copy is always a root node.
+  begin = other.Begin();
+  count = other.Count();
+  numDescendants = other.numDescendants;
+  maxLeafSize = other.MaxLeafSize();
+  minLeafSize = other.MinLeafSize();
+  bound = other.bound;
+  parentDistance = other.ParentDistance();
+  dataset = newDataset;
+  ownsDataset = true;
+  points = other.points;
+  auxiliaryInfo = AuxiliaryInfoType(other.auxiliaryInfo, this, true);
+
+  // Copy each child of `other` underneath this node.
+  for (size_t i = 0; i < numChildren; i++)
+    children[i] = new RectangleTree(other.Child(i), true, this);
+
+  return *this;
+}
+
+/**
+ * Take over the contents of `other`: its children, dataset (and ownership
+ * flag), points, bound and auxiliary information are transferred to this
+ * node, and `other` is left as an empty tree that no longer refers to any of
+ * them.  The current children and (if owned) dataset are released first.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType,
+         template<typename> class AuxiliaryInformationType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+              AuxiliaryInformationType>&
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+              AuxiliaryInformationType>::
+operator=(RectangleTree&& other)
+{
+  // Self-assignment would otherwise free the children and dataset below and
+  // then take ownership of the freed memory.
+  if (this == &other)
+    return *this;
+
+  // Release everything currently held by this tree.
+  for (size_t i = 0; i < numChildren; i++)
+    delete children[i];
+  if (ownsDataset)
+    delete dataset;
+
+  maxNumChildren = other.MaxNumChildren();
+  minNumChildren = other.MinNumChildren();
+  numChildren = other.NumChildren();
+  children = std::move(other.children);
+  parent = other.Parent();
+  begin = other.Begin();
+  count = other.Count();
+  numDescendants = other.numDescendants;
+  maxLeafSize = other.MaxLeafSize();
+  minLeafSize = other.MinLeafSize();
+  bound = std::move(other.bound);
+  parentDistance = other.ParentDistance();
+  dataset = other.dataset;
+  ownsDataset = other.ownsDataset;
+  points = std::move(other.points);
+  auxiliaryInfo = std::move(other.auxiliaryInfo);
+
+  // The transferred children still record `other` as their parent.
+  for (size_t i = 0; i < numChildren; i++)
+    children[i]->parent = this;
+
+  // If `other` was itself a child node, its parent must now refer to us.
+  if (parent)
+  {
+    for (size_t i = 0; i < parent->numChildren; i++)
+    {
+      if (parent->children[i] == &other)
+      {
+        parent->children[i] = this;
+        break;
+      }
+    }
+  }
+
+  // Reset `other` to an empty tree.  Without this it would still report
+  // numChildren > 0 over a moved-from children vector and ownsDataset == true
+  // for a dataset this tree now owns, so releasing it (as the cleanup code at
+  // the top of this operator does) would free that memory a second time.
+  other.maxNumChildren = 0;
+  other.minNumChildren = 0;
+  other.numChildren = 0;
+  other.children.clear();
+  other.parent = NULL;
+  other.begin = 0;
+  other.count = 0;
+  other.numDescendants = 0;
+  other.maxLeafSize = 0;
+  other.minLeafSize = 0;
+  other.parentDistance = 0;
+  other.dataset = NULL;
+  other.ownsDataset = false;
+  other.points.clear();
+
+  return *this;
+}
+
/**
* Construct the tree from a boost::serialization archive.
*/
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index d260171b6..ef81bf795 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -129,7 +129,9 @@ SingleTreeTraversalType>::NeighborSearch(const NeighborSearchMode mode,
// Build the tree on the empty dataset, if necessary.
if (mode != NAIVE_MODE)
{
- referenceTree = BuildTree<Tree>(*referenceSet, oldFromNewReferences);
+ referenceTree = BuildTree<Tree>(std::move(*referenceSet),
+ oldFromNewReferences);
+ delete referenceSet;
referenceSet = &referenceTree->Dataset();
}
}
@@ -181,8 +183,7 @@ SingleTreeTraversalType>::NeighborSearch(NeighborSearch&& other) :
treeNeedsReset(other.treeNeedsReset)
{
// Clear the other model.
- other.referenceSet = new MatType();
- other.referenceTree = BuildTree<Tree>(*other.referenceSet,
+ other.referenceTree = BuildTree<Tree>(std::move(MatType()),
other.oldFromNewReferences);
other.referenceSet = &other.referenceTree->Dataset();
other.searchMode = DUAL_TREE_MODE,
diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp
index da961eb13..91c84ba04 100644
--- a/src/mlpack/tests/knn_test.cpp
+++ b/src/mlpack/tests/knn_test.cpp
@@ -1300,6 +1300,34 @@ BOOST_AUTO_TEST_CASE(CopyConstructorAndOperatorTest)
CheckMatrices(distances, distances3);
}
+/**
+ * Test the copy constructor and copy operator using the RectangleTree.
+ */
+BOOST_AUTO_TEST_CASE(CopyConstructorAndOperatorRTreeTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(5, 500);
+  typedef NeighborSearch<NearestNeighborSort, EuclideanDistance, arma::mat,
+      RTree> NeighborSearchType;
+  NeighborSearchType knn(std::move(dataset));
+
+  // Copy constructor.
+  NeighborSearchType knn2(knn);
+
+  // Copy operator.  "NeighborSearchType knn3 = knn;" is copy-initialization
+  // and would just call the copy constructor again, so instead assign into an
+  // object that already holds a tree built on different data.
+  arma::mat otherDataset = arma::randu<arma::mat>(5, 100);
+  NeighborSearchType knn3(std::move(otherDataset));
+  knn3 = knn;
+
+  // Get results.
+  arma::mat distances, distances2, distances3;
+  arma::Mat<size_t> neighbors, neighbors2, neighbors3;
+
+  knn.Search(3, neighbors, distances);
+  knn2.Search(3, neighbors2, distances2);
+  knn3.Search(3, neighbors3, distances3);
+
+  CheckMatrices(neighbors, neighbors2);
+  CheckMatrices(neighbors, neighbors3);
+  CheckMatrices(distances, distances2);
+  CheckMatrices(distances, distances3);
+}
+
/**
* Test the move constructor.
*/
@@ -1325,6 +1353,34 @@ BOOST_AUTO_TEST_CASE(MoveConstructorTest)
CheckMatrices(distances, distances2);
}
+/**
+ * Test the move constructor using R trees: the moved-to model must give the
+ * same results as the original model did, even after the moved-from model
+ * has been deleted.
+ */
+BOOST_AUTO_TEST_CASE(MoveConstructorRTreeTest)
+{
+  typedef NeighborSearch<NearestNeighborSort, EuclideanDistance, arma::mat,
+      RTree> NeighborSearchType;
+
+  arma::mat referenceData = arma::randu<arma::mat>(5, 500);
+  NeighborSearchType* original =
+      new NeighborSearchType(std::move(referenceData));
+
+  // Results computed before the move serve as the expected values.
+  arma::Mat<size_t> expectedNeighbors;
+  arma::mat expectedDistances;
+  original->Search(3, expectedNeighbors, expectedDistances);
+
+  // Move into a new model, then delete the moved-from one before searching
+  // with the new model.
+  NeighborSearchType moved(std::move(*original));
+  delete original;
+
+  arma::Mat<size_t> movedNeighbors;
+  arma::mat movedDistances;
+  moved.Search(3, movedNeighbors, movedDistances);
+
+  CheckMatrices(expectedNeighbors, movedNeighbors);
+  CheckMatrices(expectedDistances, movedDistances);
+}
+
/**
* Test the move operator.
*/
--
2.22.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment