Created July 26, 2019 00:27
-
-
Save rcurtin/d8fd23ded2bda55277689265362c224f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From 56aeb6499b17ffc5e4f82813d640a35ce7d5bb74 Mon Sep 17 00:00:00 2001 | |
From: Ryan Curtin <ryan@ratml.org> | |
Date: Thu, 25 Jul 2019 20:27:28 -0400 | |
Subject: [PATCH] Add move and copy operators to RectangleTree. | |
--- | |
.../tree/rectangle_tree/rectangle_tree.hpp | 16 +++- | |
.../rectangle_tree/rectangle_tree_impl.hpp | 82 +++++++++++++++++++ | |
.../neighbor_search/neighbor_search_impl.hpp | 7 +- | |
src/mlpack/tests/knn_test.cpp | 56 +++++++++++++ | |
4 files changed, 157 insertions(+), 4 deletions(-) | |
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp | |
index a3093df17..014babec3 100644 | |
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp | |
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp | |
@@ -181,10 +181,24 @@ class RectangleTree | |
/** | |
* Create a rectangle tree by moving the other tree. | |
* | |
- * @param other The tree to be copied. | |
+ * @param other The tree to be moved. | |
*/ | |
RectangleTree(RectangleTree&& other); | |
+ /** | |
+ * Replace the contents of this tree with a deep copy of the given tree. | |
+ * The existing children of this node are deleted (and the dataset too, if | |
+ * this node owned it); afterwards this node has no parent and owns its own | |
+ * copy of the dataset. | |
+ * | |
+ * @param other The tree to be copied. | |
+ * @return A reference to this tree. | |
+ */ | |
+ RectangleTree& operator=(const RectangleTree& other); | |
+ | |
+ /** | |
+ * Take ownership of the given rectangle tree. The existing children of this | |
+ * node are deleted (and the dataset too, if this node owned it); then the | |
+ * children, parent pointer, bound, points, and dataset pointer (together with | |
+ * its ownership flag) of the given tree are taken over. | |
+ * | |
+ * @param other The tree to take ownership of. | |
+ * @return A reference to this tree. | |
+ */ | |
+ RectangleTree& operator=(RectangleTree&& other); | |
+ | |
/** | |
* Construct the tree from a boost::serialization archive. | |
*/ | |
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp | |
index 59a20e26f..aecc6dc07 100644 | |
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp | |
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp | |
@@ -234,6 +234,88 @@ RectangleTree(RectangleTree&& other) : | |
other.ownsDataset = false; | |
} | |
+/**
+ * Copy assignment: replace this tree with a deep copy of `other`.  The result
+ * is a root node (no parent) that owns its own copy of the dataset.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType,
+         template<typename> class AuxiliaryInformationType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+              AuxiliaryInformationType>&
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+              AuxiliaryInformationType>::
+operator=(const RectangleTree& other)
+{
+  // Self-assignment must be a no-op: otherwise the children and (owned)
+  // dataset would be deleted below and then copied out of freed memory.
+  if (this == &other)
+    return *this;
+
+  // Release everything this node currently owns.
+  for (size_t i = 0; i < numChildren; i++)
+    delete children[i];
+
+  if (ownsDataset)
+    delete dataset;
+
+  maxNumChildren = other.MaxNumChildren();
+  minNumChildren = other.MinNumChildren();
+  numChildren = other.NumChildren();
+  // assign() (unlike resize()) also overwrites the slots that still hold the
+  // pointers deleted above, so no dangling child pointer survives.
+  children.assign(maxNumChildren + 1, NULL);
+  // The copy is a standalone root node.
+  parent = NULL;
+  begin = other.Begin();
+  count = other.Count();
+  numDescendants = other.numDescendants;
+  maxLeafSize = other.MaxLeafSize();
+  minLeafSize = other.MinLeafSize();
+  bound = other.bound;
+  stat = other.stat;
+  parentDistance = other.ParentDistance();
+  // Deep-copy the dataset so this tree does not depend on `other`.
+  dataset = new MatType(*other.dataset);
+  ownsDataset = true;
+  points = other.points;
+  auxiliaryInfo = AuxiliaryInfoType(other.auxiliaryInfo, this, true);
+
+  // Deep-copy each child of `other`, passing this node as its parent.
+  for (size_t i = 0; i < numChildren; i++)
+    children[i] = new RectangleTree(other.Child(i), true, this);
+
+  return *this;
+}
+ | |
+/**
+ * Move assignment: take over the nodes, bound, points and dataset of `other`.
+ * Afterwards `other` is an empty node that owns nothing, so destroying it
+ * cannot free anything this tree now uses.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType,
+         template<typename> class AuxiliaryInformationType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+              AuxiliaryInformationType>&
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+              AuxiliaryInformationType>::
+operator=(RectangleTree&& other)
+{
+  // Moving a tree into itself must leave it unchanged (and not free it).
+  if (this == &other)
+    return *this;
+
+  // Release everything this node currently owns.
+  for (size_t i = 0; i < numChildren; i++)
+    delete children[i];
+
+  if (ownsDataset)
+    delete dataset;
+
+  maxNumChildren = other.MaxNumChildren();
+  minNumChildren = other.MinNumChildren();
+  numChildren = other.NumChildren();
+  children = std::move(other.children);
+  parent = other.Parent();
+  begin = other.Begin();
+  count = other.Count();
+  numDescendants = other.numDescendants;
+  maxLeafSize = other.MaxLeafSize();
+  minLeafSize = other.MinLeafSize();
+  bound = std::move(other.bound);
+  stat = std::move(other.stat);
+  parentDistance = other.ParentDistance();
+  dataset = other.dataset;
+  ownsDataset = other.ownsDataset;
+  points = std::move(other.points);
+  auxiliaryInfo = std::move(other.auxiliaryInfo);
+
+  // The adopted children still point at the moved-from node; repoint them.
+  for (size_t i = 0; i < numChildren; i++)
+    children[i]->parent = this;
+
+  // If `other` was somebody's child, that parent must now refer to this node.
+  if (parent)
+  {
+    for (size_t i = 0; i < parent->NumChildren(); i++)
+    {
+      if (parent->children[i] == &other)
+      {
+        parent->children[i] = this;
+        break;
+      }
+    }
+  }
+
+  // Reset `other` to an empty node that owns nothing.  Without this it would
+  // keep numChildren > 0 over an emptied children vector and still believe it
+  // owns the dataset that this node has just taken over (double delete).
+  other.maxNumChildren = 0;
+  other.minNumChildren = 0;
+  other.numChildren = 0;
+  other.parent = NULL;
+  other.begin = 0;
+  other.count = 0;
+  other.numDescendants = 0;
+  other.maxLeafSize = 0;
+  other.minLeafSize = 0;
+  other.parentDistance = 0;
+  other.dataset = NULL;
+  other.ownsDataset = false;
+
+  return *this;
+}
+ | |
/** | |
* Construct the tree from a boost::serialization archive. | |
*/ | |
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp | |
index d260171b6..ef81bf795 100644 | |
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp | |
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp | |
@@ -129,7 +129,9 @@ SingleTreeTraversalType>::NeighborSearch(const NeighborSearchMode mode, | |
// Build the tree on the empty dataset, if necessary. | |
if (mode != NAIVE_MODE) | |
{ | |
- referenceTree = BuildTree<Tree>(*referenceSet, oldFromNewReferences); | |
+ referenceTree = BuildTree<Tree>(std::move(*referenceSet), | |
+ oldFromNewReferences); | |
+ delete referenceSet; | |
referenceSet = &referenceTree->Dataset(); | |
} | |
} | |
@@ -181,8 +183,7 @@ SingleTreeTraversalType>::NeighborSearch(NeighborSearch&& other) : | |
treeNeedsReset(other.treeNeedsReset) | |
{ | |
// Clear the other model. | |
- other.referenceSet = new MatType(); | |
- other.referenceTree = BuildTree<Tree>(*other.referenceSet, | |
+ other.referenceTree = BuildTree<Tree>(std::move(MatType()), | |
other.oldFromNewReferences); | |
other.referenceSet = &other.referenceTree->Dataset(); | |
other.searchMode = DUAL_TREE_MODE, | |
diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp | |
index da961eb13..91c84ba04 100644 | |
--- a/src/mlpack/tests/knn_test.cpp | |
+++ b/src/mlpack/tests/knn_test.cpp | |
@@ -1300,6 +1300,34 @@ BOOST_AUTO_TEST_CASE(CopyConstructorAndOperatorTest) | |
CheckMatrices(distances, distances3); | |
} | |
+/**
+ * Test the copy constructor and copy operator using the RectangleTree.
+ */
+BOOST_AUTO_TEST_CASE(CopyConstructorAndOperatorRTreeTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(5, 500);
+  typedef NeighborSearch<NearestNeighborSort, EuclideanDistance, arma::mat,
+      RTree> NeighborSearchType;
+  NeighborSearchType knn(std::move(dataset));
+
+  // Copy constructor.
+  NeighborSearchType knn2(knn);
+
+  // Copy operator.  "NeighborSearchType knn3 = knn;" would be
+  // copy-initialization, which calls the copy constructor again; the target
+  // has to exist already for operator=() to actually be invoked.
+  NeighborSearchType knn3;
+  knn3 = knn;
+
+  // Get results.
+  arma::mat distances, distances2, distances3;
+  arma::Mat<size_t> neighbors, neighbors2, neighbors3;
+
+  knn.Search(3, neighbors, distances);
+  knn2.Search(3, neighbors2, distances2);
+  knn3.Search(3, neighbors3, distances3);
+
+  CheckMatrices(neighbors, neighbors2);
+  CheckMatrices(neighbors, neighbors3);
+  CheckMatrices(distances, distances2);
+  CheckMatrices(distances, distances3);
+}
+ | |
/** | |
* Test the move constructor. | |
*/ | |
@@ -1325,6 +1353,34 @@ BOOST_AUTO_TEST_CASE(MoveConstructorTest) | |
CheckMatrices(distances, distances2); | |
} | |
+/**
+ * Test the move constructor using R trees.  The original model is deleted
+ * before the moved-to model is searched, so the moved-to model must not rely
+ * on anything still held by the moved-from one.
+ */
+BOOST_AUTO_TEST_CASE(MoveConstructorRTreeTest)
+{
+  typedef NeighborSearch<NearestNeighborSort, EuclideanDistance, arma::mat,
+      RTree> NeighborSearchType;
+
+  arma::mat referenceData = arma::randu<arma::mat>(5, 500);
+  NeighborSearchType* source =
+      new NeighborSearchType(std::move(referenceData));
+
+  // Results computed before the move.
+  arma::Mat<size_t> neighborsBefore;
+  arma::mat distancesBefore;
+  source->Search(3, neighborsBefore, distancesBefore);
+
+  // Build a new model by moving from the original, then destroy the original.
+  NeighborSearchType target(std::move(*source));
+  delete source;
+
+  // Results from the moved-to model must match.
+  arma::Mat<size_t> neighborsAfter;
+  arma::mat distancesAfter;
+  target.Search(3, neighborsAfter, distancesAfter);
+
+  CheckMatrices(neighborsBefore, neighborsAfter);
+  CheckMatrices(distancesBefore, distancesAfter);
+}
+
/** | |
* Test the move operator. | |
*/ | |
-- | |
2.22.0 | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.