Created
March 8, 2017 06:47
-
-
Save MikeLing/2e1056c873eecbf845c14b318ad4a17f to your computer and use it in GitHub Desktop.
Weighted-learning-for-decision-trees.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/mlpack/methods/decision_tree/all_categorical_split.hpp b/src/mlpack/methods/decision_tree/all_categorical_split.hpp | |
index 23af2c8ca..486320eb1 100644 | |
--- a/src/mlpack/methods/decision_tree/all_categorical_split.hpp | |
+++ b/src/mlpack/methods/decision_tree/all_categorical_split.hpp | |
@@ -47,7 +47,7 @@ class AllCategoricalSplit | |
* @param aux Auxiliary split information, which may be modified on a | |
* successful split. | |
*/ | |
- template<typename VecType> | |
+ template<bool UseWeights, typename VecType, typename WeightVecType> | |
static double SplitIfBetter( | |
const double bestGain, | |
const VecType& data, | |
@@ -55,6 +55,7 @@ class AllCategoricalSplit | |
const arma::Row<size_t>& labels, | |
const size_t numClasses, | |
const size_t minimumLeafSize, | |
+ const WeightVecType& weights, | |
arma::Col<typename VecType::elem_type>& classProbabilities, | |
AuxiliarySplitInfo<typename VecType::elem_type>& aux); | |
diff --git a/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp b/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp | |
index 3b7941bba..deb9ed08b 100644 | |
--- a/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp | |
+++ b/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp | |
@@ -11,7 +11,7 @@ namespace mlpack { | |
namespace tree { | |
template<typename FitnessFunction> | |
-template<typename VecType> | |
+template<bool UseWeights, typename VecType, typename WeightVecType> | |
double AllCategoricalSplit<FitnessFunction>::SplitIfBetter( | |
const double bestGain, | |
const VecType& data, | |
@@ -19,6 +19,7 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter( | |
const arma::Row<size_t>& labels, | |
const size_t numClasses, | |
const size_t minimumLeafSize, | |
+ const WeightVecType& weights, | |
arma::Col<typename VecType::elem_type>& classProbabilities, | |
AuxiliarySplitInfo<typename VecType::elem_type>& /* aux */) | |
{ | |
@@ -53,8 +54,8 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter( | |
{ | |
// Calculate the gain of this child. | |
const double childPct = double(counts[i]) / double(data.n_elem); | |
- const double childGain = FitnessFunction::Evaluate(childLabels[i], | |
- numClasses); | |
+ const double childGain = FitnessFunction::template Evaluate<UseWeights>( | |
+ childLabels[i], numClasses, weights); | |
overallGain += childPct * childGain; | |
} | |
diff --git a/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp b/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp | |
index 254bdae5b..787d9e14a 100644 | |
--- a/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp | |
+++ b/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp | |
@@ -45,13 +45,14 @@ class BestBinaryNumericSplit | |
* @param aux Auxiliary split information, which may be modified on a | |
* successful split. | |
*/ | |
- template<typename VecType> | |
+ template<bool UseWeights, typename VecType, typename WeightVecType> | |
static double SplitIfBetter( | |
const double bestGain, | |
const VecType& data, | |
const arma::Row<size_t>& labels, | |
const size_t numClasses, | |
const size_t minimumLeafSize, | |
+ const WeightVecType& weights, | |
arma::Col<typename VecType::elem_type>& classProbabilities, | |
AuxiliarySplitInfo<typename VecType::elem_type>& aux); | |
diff --git a/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp | |
index 154a78a4a..0581d5386 100644 | |
--- a/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp | |
+++ b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp | |
@@ -11,13 +11,14 @@ namespace mlpack { | |
namespace tree { | |
template<typename FitnessFunction> | |
-template<typename VecType> | |
+template<bool UseWeights, typename VecType, typename WeightVecType> | |
double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter( | |
const double bestGain, | |
const VecType& data, | |
const arma::Row<size_t>& labels, | |
const size_t numClasses, | |
const size_t minimumLeafSize, | |
+ const WeightVecType& weights, | |
arma::Col<typename VecType::elem_type>& classProbabilities, | |
AuxiliarySplitInfo<typename VecType::elem_type>& /* aux */) | |
{ | |
@@ -42,10 +43,10 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter( | |
continue; | |
// Calculate the gain for the left and right child. | |
- const double leftGain = FitnessFunction::Evaluate(sortedLabels.subvec(0, | |
- index - 1), numClasses); | |
- const double rightGain = FitnessFunction::Evaluate(sortedLabels.subvec( | |
- index, sortedLabels.n_elem - 1), numClasses); | |
+ const double leftGain = FitnessFunction::template Evaluate<UseWeights>( | |
+ sortedLabels.subvec(0, index - 1), numClasses, weights); | |
+ const double rightGain = FitnessFunction::template Evaluate<UseWeights>( | |
+ sortedLabels.subvec(index, sortedLabels.n_elem - 1), numClasses, weights); | |
// Calculate the fraction of points in the left and right children. | |
const double leftRatio = double(index) / double(sortedLabels.n_elem); | |
diff --git a/src/mlpack/methods/decision_tree/decision_tree.hpp b/src/mlpack/methods/decision_tree/decision_tree.hpp | |
index f42e3b673..285b3ce75 100644 | |
--- a/src/mlpack/methods/decision_tree/decision_tree.hpp | |
+++ b/src/mlpack/methods/decision_tree/decision_tree.hpp | |
@@ -129,13 +129,15 @@ class DecisionTree : | |
* @param datasetInfo Type information for each dimension. | |
* @param labels Labels for each training point. | |
* @param numClasses Number of classes in the dataset. | |
+ * @param weights Weights of all the labels. | |
* @param minimumLeafSize Minimum number of points in each leaf node. | |
*/ | |
- template<typename MatType> | |
+ template<bool UseWeights, typename MatType, typename WeightVecType> | |
void Train(const MatType& data, | |
const data::DatasetInfo& datasetInfo, | |
const arma::Row<size_t>& labels, | |
const size_t numClasses, | |
+ const WeightVecType& weights, | |
const size_t minimumLeafSize = 10); | |
/** | |
@@ -147,12 +149,14 @@ class DecisionTree : | |
* @param data Dataset to train on. | |
* @param labels Labels for each training point. | |
* @param numClasses Number of classes in the dataset. | |
+ * @param weights Weights of all the labels. | |
* @param minimumLeafSize Minimum number of points in each leaf node. | |
*/ | |
- template<typename MatType> | |
+ template<bool UseWeights, typename MatType, typename WeightVecType> | |
void Train(const MatType& data, | |
const arma::Row<size_t>& labels, | |
const size_t numClasses, | |
+ const WeightVecType& weights, | |
const size_t minimumLeafSize = 10); | |
/** | |
diff --git a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp | |
index 7e6ab1ffa..7b8b8a8cc 100644 | |
--- a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp | |
+++ b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp | |
@@ -10,7 +10,7 @@ | |
namespace mlpack { | |
namespace tree { | |
-//! Construct and train. | |
+//! Construct and train without weights. | |
template<typename FitnessFunction, | |
template<typename> class NumericSplitType, | |
template<typename> class CategoricalSplitType, | |
@@ -27,11 +27,13 @@ DecisionTree<FitnessFunction, | |
const size_t numClasses, | |
const size_t minimumLeafSize) | |
{ | |
+ // Pass to unweighted training function. | |
+ arma::rowvec weights; | |
// Pass off work to the Train() method. | |
- Train(data, datasetInfo, labels, numClasses, minimumLeafSize); | |
+ Train<false>(data, datasetInfo, labels, numClasses, weights, minimumLeafSize); | |
} | |
-//! Construct and train. | |
+//! Construct and train without weights. | |
template<typename FitnessFunction, | |
template<typename> class NumericSplitType, | |
template<typename> class CategoricalSplitType, | |
@@ -47,8 +49,10 @@ DecisionTree<FitnessFunction, | |
const size_t numClasses, | |
const size_t minimumLeafSize) | |
{ | |
+ // Pass to unweighted training function. | |
+ arma::rowvec weights; | |
// Pass off work to the Train() method. | |
- Train(data, labels, numClasses, minimumLeafSize); | |
+ Train<false>(data, labels, numClasses, weights, minimumLeafSize); | |
} | |
//! Construct, don't train. | |
@@ -211,7 +215,7 @@ template<typename FitnessFunction, | |
template<typename> class CategoricalSplitType, | |
typename ElemType, | |
bool NoRecursion> | |
-template<typename MatType> | |
+template<bool UseWeights, typename MatType, typename WeightVecType> | |
void DecisionTree<FitnessFunction, | |
NumericSplitType, | |
CategoricalSplitType, | |
@@ -220,6 +224,7 @@ void DecisionTree<FitnessFunction, | |
const data::DatasetInfo& datasetInfo, | |
const arma::Row<size_t>& labels, | |
const size_t numClasses, | |
+ const WeightVecType& weights, | |
const size_t minimumLeafSize) | |
{ | |
// Clear children if needed. | |
@@ -232,7 +237,7 @@ void DecisionTree<FitnessFunction, | |
// numericAux and categoricalAux (and clear them later if we make not split), | |
// and use classProbabilities as auxiliary information. Later we'll overwrite | |
// classProbabilities to the empirical class probabilities if we do not split. | |
- double bestGain = FitnessFunction::Evaluate(labels, numClasses); | |
+ double bestGain = FitnessFunction::template Evaluate<UseWeights>(labels, numClasses, weights); | |
size_t bestDim = datasetInfo.Dimensionality(); // This means "no split". | |
for (size_t i = 0; i < datasetInfo.Dimensionality(); ++i) | |
{ | |
@@ -332,7 +337,7 @@ template<typename FitnessFunction, | |
template<typename> class CategoricalSplitType, | |
typename ElemType, | |
bool NoRecursion> | |
-template<typename MatType> | |
+template<bool UseWeights, typename MatType, typename WeightVecType> | |
void DecisionTree<FitnessFunction, | |
NumericSplitType, | |
CategoricalSplitType, | |
@@ -340,6 +345,7 @@ void DecisionTree<FitnessFunction, | |
NoRecursion>::Train(const MatType& data, | |
const arma::Row<size_t>& labels, | |
const size_t numClasses, | |
+ const WeightVecType& weights, | |
const size_t minimumLeafSize) | |
{ | |
// Clear children if needed. | |
@@ -355,7 +361,7 @@ void DecisionTree<FitnessFunction, | |
// later if we don't make a split), and use classProbabilities as auxiliary | |
// information. Later we'll overwrite classProbabilities to the empirical | |
// class probabilities if we do not split. | |
- double bestGain = FitnessFunction::Evaluate(labels, numClasses); | |
+ double bestGain = FitnessFunction::template Evaluate<UseWeights>(labels, numClasses, weights); | |
size_t bestDim = data.n_rows; // This means "no split". | |
for (size_t i = 0; i < data.n_rows; ++i) | |
{ | |
diff --git a/src/mlpack/methods/decision_tree/gini_gain.hpp b/src/mlpack/methods/decision_tree/gini_gain.hpp | |
index c1f08da78..b60a53f97 100644 | |
--- a/src/mlpack/methods/decision_tree/gini_gain.hpp | |
+++ b/src/mlpack/methods/decision_tree/gini_gain.hpp | |
@@ -34,25 +34,48 @@ class GiniGain | |
* @param labels Set of labels to evaluate Gini impurity on. | |
* @param numClasses Number of classes in the dataset. | |
*/ | |
- template<typename RowType> | |
+ template<bool UseWeights, typename RowType, typename WeightVecType> | |
static double Evaluate(const RowType& labels, | |
- const size_t numClasses) | |
+ const size_t numClasses, | |
+ const WeightVecType& weights) | |
{ | |
// Corner case: if there are no elements, the impurity is zero. | |
if (labels.n_elem == 0) | |
return 0.0; | |
+ // Calculate the Gini impurity of the un-split node. | |
+ double impurity = 0.0; | |
+ | |
arma::Col<size_t> counts(numClasses); | |
counts.zeros(); | |
- for (size_t i = 0; i < labels.n_elem; ++i) | |
- counts[labels[i]]++; | |
- // Calculate the Gini impurity of the un-split node. | |
- double impurity = 0.0; | |
- for (size_t i = 0; i < numClasses; ++i) | |
+ if (UseWeights) | |
{ | |
- const double f = ((double) counts[i] / (double) labels.n_elem); | |
- impurity += f * (1.0 - f); | |
+ // sum all the weights up | |
+ double accWeights = 0.0; | |
+ | |
+ for (size_t i = 0; i < labels.n_elem; ++i) | |
+ { | |
+ counts[labels[i]] += weights[i]; | |
+ accWeights += weights[i]; | |
+ } | |
+ | |
+ for (size_t i = 0; i < numClasses; ++i) | |
+ { | |
+ const double f = ((double) counts[i] / accWeights); | |
+ impurity += f * (1.0 - f); | |
+ } | |
+ } | |
+ else | |
+ { | |
+ for (size_t i = 0; i < labels.n_elem; ++i) | |
+ counts[labels[i]]++; | |
+ | |
+ for (size_t i = 0; i < numClasses; ++i) | |
+ { | |
+ const double f = ((double) counts[i] / (double) labels.n_elem); | |
+ impurity += f * (1.0 - f); | |
+ } | |
} | |
return -impurity; | |
diff --git a/src/mlpack/methods/decision_tree/information_gain.hpp b/src/mlpack/methods/decision_tree/information_gain.hpp | |
index 2dbf814dd..21817f18b 100644 | |
--- a/src/mlpack/methods/decision_tree/information_gain.hpp | |
+++ b/src/mlpack/methods/decision_tree/information_gain.hpp | |
@@ -31,26 +31,51 @@ class InformationGain | |
* @param labels Labels of the dataset. | |
* @param numClasses Number of classes in the dataset. | |
*/ | |
+ template<bool UseWeights> | |
static double Evaluate(const arma::Row<size_t>& labels, | |
- const size_t numClasses) | |
+ const size_t numClasses, | |
+ const arma::Row<double>& weights) | |
{ | |
// Edge case: if there are no elements, the gain is zero. | |
if (labels.n_elem == 0) | |
return 0.0; | |
+ // Calculate the information gain. | |
+ double gain = 0.0; | |
+ | |
// Count the number of elements in each class. | |
arma::Col<size_t> counts(numClasses); | |
counts.zeros(); | |
- for (size_t i = 0; i < labels.n_elem; ++i) | |
- counts[labels[i]]++; | |
- // Calculate the information gain. | |
- double gain = 0.0; | |
- for (size_t i = 0; i < numClasses; ++i) | |
+ if (UseWeights) | |
{ | |
- const double f = ((double) counts[i] / (double) labels.n_elem); | |
- if (f > 0.0) | |
+ // sum all the weights up | |
+ double accWeights = 0.0; | |
+ | |
+ for (size_t i = 0; i < labels.n_elem; ++i) | |
+ { | |
+ counts[labels[i]] += weights[i]; | |
+ accWeights += weights[i]; | |
+ } | |
+ | |
+ for (size_t i = 0; i < numClasses; ++i) | |
+ { | |
+ const double f = ((double) counts[i] / accWeights); | |
+ if (f > 0.0) | |
+ gain += f * std::log2(f); | |
+ } | |
+ } | |
+ else | |
+ { | |
+ for (size_t i = 0; i < labels.n_elem; ++i) | |
+ counts[labels[i]]++; | |
+ | |
+ for (size_t i = 0; i < numClasses; ++i) | |
+ { | |
+ const double f = ((double) counts[i] / (double) labels.n_elem); | |
+ if (f > 0.0) | |
gain += f * std::log2(f); | |
+ } | |
} | |
return gain; | |
-- | |
2.11.1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment