Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save MikeLing/2e1056c873eecbf845c14b318ad4a17f to your computer and use it in GitHub Desktop.
Save MikeLing/2e1056c873eecbf845c14b318ad4a17f to your computer and use it in GitHub Desktop.
Weighted-learning-for-decision-trees.patch
diff --git a/src/mlpack/methods/decision_tree/all_categorical_split.hpp b/src/mlpack/methods/decision_tree/all_categorical_split.hpp
index 23af2c8ca..486320eb1 100644
--- a/src/mlpack/methods/decision_tree/all_categorical_split.hpp
+++ b/src/mlpack/methods/decision_tree/all_categorical_split.hpp
@@ -47,7 +47,7 @@ class AllCategoricalSplit
* @param aux Auxiliary split information, which may be modified on a
* successful split.
*/
- template<typename VecType>
+ template<bool UseWeights, typename VecType, typename WeightVecType>
static double SplitIfBetter(
const double bestGain,
const VecType& data,
@@ -55,6 +55,7 @@ class AllCategoricalSplit
const arma::Row<size_t>& labels,
const size_t numClasses,
const size_t minimumLeafSize,
+ const WeightVecType& weights,
arma::Col<typename VecType::elem_type>& classProbabilities,
AuxiliarySplitInfo<typename VecType::elem_type>& aux);
diff --git a/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp b/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
index 3b7941bba..deb9ed08b 100644
--- a/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
+++ b/src/mlpack/methods/decision_tree/all_categorical_split_impl.hpp
@@ -11,7 +11,7 @@ namespace mlpack {
namespace tree {
template<typename FitnessFunction>
-template<typename VecType>
+template<bool UseWeights, typename VecType, typename WeightVecType>
double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(
const double bestGain,
const VecType& data,
@@ -19,6 +19,7 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(
const arma::Row<size_t>& labels,
const size_t numClasses,
const size_t minimumLeafSize,
+ const WeightVecType& weights,
arma::Col<typename VecType::elem_type>& classProbabilities,
AuxiliarySplitInfo<typename VecType::elem_type>& /* aux */)
{
@@ -53,8 +54,8 @@ double AllCategoricalSplit<FitnessFunction>::SplitIfBetter(
{
// Calculate the gain of this child.
const double childPct = double(counts[i]) / double(data.n_elem);
- const double childGain = FitnessFunction::Evaluate(childLabels[i],
- numClasses);
+    // NOTE(review): the `template` disambiguator is required because
+    // FitnessFunction is a dependent template parameter.  Also, `weights` is
+    // the full weight vector while childLabels[i] holds only a subset of the
+    // labels, so the per-child weight subset should be passed -- verify.
+    const double childGain = FitnessFunction::template Evaluate<UseWeights>(
+        childLabels[i], numClasses, weights);
overallGain += childPct * childGain;
}
diff --git a/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp b/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp
index 254bdae5b..787d9e14a 100644
--- a/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp
+++ b/src/mlpack/methods/decision_tree/best_binary_numeric_split.hpp
@@ -45,13 +45,14 @@ class BestBinaryNumericSplit
* @param aux Auxiliary split information, which may be modified on a
* successful split.
*/
- template<typename VecType>
+ template<bool UseWeights, typename VecType, typename WeightVecType>
static double SplitIfBetter(
const double bestGain,
const VecType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const size_t minimumLeafSize,
+ const WeightVecType& weights,
arma::Col<typename VecType::elem_type>& classProbabilities,
AuxiliarySplitInfo<typename VecType::elem_type>& aux);
diff --git a/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
index 154a78a4a..0581d5386 100644
--- a/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
+++ b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp
@@ -11,13 +11,14 @@ namespace mlpack {
namespace tree {
template<typename FitnessFunction>
-template<typename VecType>
+template<bool UseWeights, typename VecType, typename WeightVecType>
double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
const double bestGain,
const VecType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
const size_t minimumLeafSize,
+ const WeightVecType& weights,
arma::Col<typename VecType::elem_type>& classProbabilities,
AuxiliarySplitInfo<typename VecType::elem_type>& /* aux */)
{
@@ -42,10 +43,10 @@ double BestBinaryNumericSplit<FitnessFunction>::SplitIfBetter(
continue;
// Calculate the gain for the left and right child.
- const double leftGain = FitnessFunction::Evaluate(sortedLabels.subvec(0,
- index - 1), numClasses);
- const double rightGain = FitnessFunction::Evaluate(sortedLabels.subvec(
- index, sortedLabels.n_elem - 1), numClasses);
+    // NOTE(review): `template` disambiguator required for this dependent call.
+    // The weights must also be permuted the same way as sortedLabels and
+    // sliced to the same subvectors; passing the full unsorted weight vector
+    // pairs weight i with the wrong label -- verify against the sorting code.
+    const double leftGain = FitnessFunction::template Evaluate<UseWeights>(
+        sortedLabels.subvec(0, index - 1), numClasses, weights);
+    const double rightGain = FitnessFunction::template Evaluate<UseWeights>(
+        sortedLabels.subvec(index, sortedLabels.n_elem - 1), numClasses, weights);
// Calculate the fraction of points in the left and right children.
const double leftRatio = double(index) / double(sortedLabels.n_elem);
diff --git a/src/mlpack/methods/decision_tree/decision_tree.hpp b/src/mlpack/methods/decision_tree/decision_tree.hpp
index f42e3b673..285b3ce75 100644
--- a/src/mlpack/methods/decision_tree/decision_tree.hpp
+++ b/src/mlpack/methods/decision_tree/decision_tree.hpp
@@ -129,13 +129,15 @@ class DecisionTree :
* @param datasetInfo Type information for each dimension.
* @param labels Labels for each training point.
* @param numClasses Number of classes in the dataset.
+   * @param weights Weight of each training point (one entry per label).
* @param minimumLeafSize Minimum number of points in each leaf node.
*/
- template<typename MatType>
+ template<bool UseWeights, typename MatType, typename WeightVecType>
void Train(const MatType& data,
const data::DatasetInfo& datasetInfo,
const arma::Row<size_t>& labels,
const size_t numClasses,
+ const WeightVecType& weights,
const size_t minimumLeafSize = 10);
/**
@@ -147,12 +149,14 @@ class DecisionTree :
* @param data Dataset to train on.
* @param labels Labels for each training point.
* @param numClasses Number of classes in the dataset.
+   * @param weights Weight of each training point (one entry per label).
* @param minimumLeafSize Minimum number of points in each leaf node.
*/
- template<typename MatType>
+ template<bool UseWeights, typename MatType, typename WeightVecType>
void Train(const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
+ const WeightVecType& weights,
const size_t minimumLeafSize = 10);
/**
diff --git a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
index 7e6ab1ffa..7b8b8a8cc 100644
--- a/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
+++ b/src/mlpack/methods/decision_tree/decision_tree_impl.hpp
@@ -10,7 +10,7 @@
namespace mlpack {
namespace tree {
-//! Construct and train.
+//! Construct and train without instance weights.
template<typename FitnessFunction,
template<typename> class NumericSplitType,
template<typename> class CategoricalSplitType,
@@ -27,11 +27,13 @@ DecisionTree<FitnessFunction,
const size_t numClasses,
const size_t minimumLeafSize)
{
+ // Pass to unweighted training function.
+ arma::rowvec weights;
// Pass off work to the Train() method.
- Train(data, datasetInfo, labels, numClasses, minimumLeafSize);
+ Train<false>(data, datasetInfo, labels, numClasses, weights, minimumLeafSize);
}
-//! Construct and train.
+//! Construct and train without instance weights.
template<typename FitnessFunction,
template<typename> class NumericSplitType,
template<typename> class CategoricalSplitType,
@@ -47,8 +49,10 @@ DecisionTree<FitnessFunction,
const size_t numClasses,
const size_t minimumLeafSize)
{
+ // Pass to unweighted training function.
+ arma::rowvec weights;
// Pass off work to the Train() method.
- Train(data, labels, numClasses, minimumLeafSize);
+ Train<false>(data, labels, numClasses, weights, minimumLeafSize);
}
//! Construct, don't train.
@@ -211,7 +215,7 @@ template<typename FitnessFunction,
template<typename> class CategoricalSplitType,
typename ElemType,
bool NoRecursion>
-template<typename MatType>
+template<bool UseWeights, typename MatType, typename WeightVecType>
void DecisionTree<FitnessFunction,
NumericSplitType,
CategoricalSplitType,
@@ -220,6 +224,7 @@ void DecisionTree<FitnessFunction,
const data::DatasetInfo& datasetInfo,
const arma::Row<size_t>& labels,
const size_t numClasses,
+ const WeightVecType& weights,
const size_t minimumLeafSize)
{
// Clear children if needed.
@@ -232,7 +237,7 @@ void DecisionTree<FitnessFunction,
// numericAux and categoricalAux (and clear them later if we make not split),
// and use classProbabilities as auxiliary information. Later we'll overwrite
// classProbabilities to the empirical class probabilities if we do not split.
- double bestGain = FitnessFunction::Evaluate(labels, numClasses);
+  double bestGain =
+      FitnessFunction::template Evaluate<UseWeights>(labels, numClasses, weights);
size_t bestDim = datasetInfo.Dimensionality(); // This means "no split".
for (size_t i = 0; i < datasetInfo.Dimensionality(); ++i)
{
@@ -332,7 +337,7 @@ template<typename FitnessFunction,
template<typename> class CategoricalSplitType,
typename ElemType,
bool NoRecursion>
-template<typename MatType>
+template<bool UseWeights, typename MatType, typename WeightVecType>
void DecisionTree<FitnessFunction,
NumericSplitType,
CategoricalSplitType,
@@ -340,6 +345,7 @@ void DecisionTree<FitnessFunction,
NoRecursion>::Train(const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
+ const WeightVecType& weights,
const size_t minimumLeafSize)
{
// Clear children if needed.
@@ -355,7 +361,7 @@ void DecisionTree<FitnessFunction,
// later if we don't make a split), and use classProbabilities as auxiliary
// information. Later we'll overwrite classProbabilities to the empirical
// class probabilities if we do not split.
- double bestGain = FitnessFunction::Evaluate(labels, numClasses);
+  double bestGain =
+      FitnessFunction::template Evaluate<UseWeights>(labels, numClasses, weights);
size_t bestDim = data.n_rows; // This means "no split".
for (size_t i = 0; i < data.n_rows; ++i)
{
diff --git a/src/mlpack/methods/decision_tree/gini_gain.hpp b/src/mlpack/methods/decision_tree/gini_gain.hpp
index c1f08da78..b60a53f97 100644
--- a/src/mlpack/methods/decision_tree/gini_gain.hpp
+++ b/src/mlpack/methods/decision_tree/gini_gain.hpp
@@ -34,25 +34,48 @@ class GiniGain
* @param labels Set of labels to evaluate Gini impurity on.
* @param numClasses Number of classes in the dataset.
+   * @param weights Weight of each label (only used if UseWeights is true).
   */
-  template<typename RowType>
+  template<bool UseWeights, typename RowType, typename WeightVecType>
   static double Evaluate(const RowType& labels,
-                         const size_t numClasses)
+                         const size_t numClasses,
+                         const WeightVecType& weights)
{
// Corner case: if there are no elements, the impurity is zero.
if (labels.n_elem == 0)
return 0.0;
+ // Calculate the Gini impurity of the un-split node.
+ double impurity = 0.0;
+
arma::Col<size_t> counts(numClasses);
counts.zeros();
- for (size_t i = 0; i < labels.n_elem; ++i)
- counts[labels[i]]++;
- // Calculate the Gini impurity of the un-split node.
- double impurity = 0.0;
- for (size_t i = 0; i < numClasses; ++i)
+ if (UseWeights)
{
- const double f = ((double) counts[i] / (double) labels.n_elem);
- impurity += f * (1.0 - f);
+    // Sum of all weights; the normalizer for weighted class frequencies.
+    double accWeights = 0.0;
+    // Weighted counts must be floating-point: accumulating fractional
+    // weights into the size_t `counts` vector would truncate them.
+    arma::vec weightedCounts(numClasses, arma::fill::zeros);
+
+    for (size_t i = 0; i < labels.n_elem; ++i)
+    {
+      weightedCounts[labels[i]] += weights[i];
+      accWeights += weights[i];  // accumulate (was `=`, keeping only the last weight)
+    }
+
+    for (size_t i = 0; i < numClasses; ++i)
+    {
+      const double f = weightedCounts[i] / accWeights;
+      impurity += f * (1.0 - f);
+    }
+ }
+ else
+ {
+ for (size_t i = 0; i < labels.n_elem; ++i)
+ counts[labels[i]]++;
+
+ for (size_t i = 0; i < numClasses; ++i)
+ {
+ const double f = ((double) counts[i] / (double) labels.n_elem);
+ impurity += f * (1.0 - f);
+ }
}
return -impurity;
diff --git a/src/mlpack/methods/decision_tree/information_gain.hpp b/src/mlpack/methods/decision_tree/information_gain.hpp
index 2dbf814dd..21817f18b 100644
--- a/src/mlpack/methods/decision_tree/information_gain.hpp
+++ b/src/mlpack/methods/decision_tree/information_gain.hpp
@@ -31,26 +31,51 @@ class InformationGain
* @param labels Labels of the dataset.
* @param numClasses Number of classes in the dataset.
+   * @param weights Weight of each label (only used if UseWeights is true).
   */
+  template<bool UseWeights, typename WeightVecType>
   static double Evaluate(const arma::Row<size_t>& labels,
-                         const size_t numClasses)
+                         const size_t numClasses,
+                         const WeightVecType& weights)
{
// Edge case: if there are no elements, the gain is zero.
if (labels.n_elem == 0)
return 0.0;
+ // Calculate the information gain.
+ double gain = 0.0;
+
// Count the number of elements in each class.
arma::Col<size_t> counts(numClasses);
counts.zeros();
- for (size_t i = 0; i < labels.n_elem; ++i)
- counts[labels[i]]++;
- // Calculate the information gain.
- double gain = 0.0;
- for (size_t i = 0; i < numClasses; ++i)
+ if (UseWeights)
{
- const double f = ((double) counts[i] / (double) labels.n_elem);
- if (f > 0.0)
+    // Sum of all weights; the normalizer for weighted class frequencies.
+    double accWeights = 0.0;
+    // Weighted counts must be floating-point: accumulating fractional
+    // weights into the size_t `counts` vector would truncate them.
+    arma::vec weightedCounts(numClasses, arma::fill::zeros);
+
+    for (size_t i = 0; i < labels.n_elem; ++i)
+    {
+      weightedCounts[labels[i]] += weights[i];
+      accWeights += weights[i];  // accumulate (was `=`, keeping only the last weight)
+    }
+
+    for (size_t i = 0; i < numClasses; ++i)
+    {
+      const double f = weightedCounts[i] / accWeights;
+      if (f > 0.0)
+        gain += f * std::log2(f);
+    }
+ }
+ else
+ {
+ for (size_t i = 0; i < labels.n_elem; ++i)
+ counts[labels[i]]++;
+
+ for (size_t i = 0; i < numClasses; ++i)
+ {
+ const double f = ((double) counts[i] / (double) labels.n_elem);
+ if (f > 0.0)
gain += f * std::log2(f);
+ }
}
return gain;
--
2.11.1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment