Skip to content

Instantly share code, notes, and snippets.

@lynxnathan
Last active June 7, 2020 02:41
Show Gist options
  • Save lynxnathan/9834926bc0e42fac5d840e6f4e9261ee to your computer and use it in GitHub Desktop.
Save lynxnathan/9834926bc0e42fac5d840e6f4e9261ee to your computer and use it in GitHub Desktop.
Naive single decision tree implementation, written in the context of the fast.ai ML1 course
const csv = require('csv/lib/sync');
const fs = require('fs');
// Naive single decision tree implementation in javascript
// Train.csv can be downloaded from Kaggle's "Blue Book for Bulldozers" competition
/**
 * Naive single regression tree.
 *
 * Each node picks the (variable, value) pair that minimises the "branch
 * diffusion": the sum, over both sides of the split, of
 * `side size * standard deviation of log(dependent variable)`.
 */
class DecisionTree {
  /**
   * @param {Array<Array<*>>} csvData - Parsed CSV; the first row holds the column names.
   *   The array is NOT modified.
   * @param {string[]} independentVariables - Column names to consider for splits.
   *   Names that are not present in the header are silently ignored.
   * @param {string} dependentVariable - Column name of the value being predicted.
   * @param {{sampleSize?: number, depth?: number}} [options]
   *   `sampleSize`: when > -1, only the first `sampleSize` data rows are used.
   *   `depth`: number of tree levels (root = level 1) that get child branches;
   *   falsy means only the root split is computed.
   * @throws {TypeError} If `csvData` is not a non-empty array.
   * @throws {Error} If `dependentVariable` is not one of the columns.
   */
  constructor(csvData, independentVariables, dependentVariable, options = {}) {
    if (!Array.isArray(csvData) || csvData.length === 0) {
      throw new TypeError('csvData must be a non-empty array whose first row is the header');
    }
    this.options = options;
    // Kept for backward compatibility; fit() tracks the level per recursive call.
    this.depth = 0;
    this.columns = csvData[0];
    // slice() instead of shift(): the caller's array must keep its header row.
    this.data = csvData.slice(1);

    const dependentIndexes = DecisionTree.getVariablesIndexes(this.columns, [dependentVariable]);
    if (dependentIndexes.length === 0) {
      throw new Error(`Unknown dependent variable: ${dependentVariable}`);
    }
    this.dependentVariableIndex = dependentIndexes[0];
    this.independentVariablesIndexes = DecisionTree.getVariablesIndexes(this.columns, independentVariables);

    if (options.sampleSize > -1 && this.data.length > options.sampleSize) {
      this.data = this.data.slice(0, options.sampleSize);
    }
  }

  /**
   * Builds the tree over all (sampled) rows.
   *
   * @returns {Object} Root node: `{branchDiffusion, splitValue, splitVariable,
   *   leftBranch?, rightBranch?, sampleSize}`. A node for which no variable
   *   yields a usable split is returned as `{branchDiffusion: Infinity}`.
   */
  fit() {
    const maxDepth = this.options.depth;

    const buildNode = (subtreeIndexes, level) => {
      // First pick the winning variable; ties keep the earlier variable.
      let winner = { branchDiffusion: Infinity };
      let winnerVariableIndex = -1;
      for (const variableIndex of this.independentVariablesIndexes) {
        const candidate = this.getBestSplitForVariable(subtreeIndexes, variableIndex);
        if (candidate.branchDiffusion < winner.branchDiffusion) {
          winner = candidate;
          winnerVariableIndex = variableIndex;
        }
      }
      if (winnerVariableIndex === -1) {
        return winner;
      }

      const node = {
        branchDiffusion: winner.branchDiffusion,
        splitValue: winner.splitValue,
        splitVariable: this.columns[winnerVariableIndex],
      };
      // Recurse only once the winner is known, and pass the level down
      // explicitly so both subtrees get the same depth budget.
      if (maxDepth && level <= maxDepth) {
        node.leftBranch = buildNode(winner.leftBranchIndexes, level + 1);
        node.rightBranch = buildNode(winner.rightBranchIndexes, level + 1);
      }
      node.sampleSize = subtreeIndexes.length;
      return node;
    };

    const allRows = Array.from({ length: this.data.length }, (_, i) => i);
    return buildNode(allRows, 1);
  }

  /**
   * Finds the best threshold for one variable within a subset of rows.
   * Rows whose value is `<=` the threshold go left, all others go right.
   *
   * @param {number[]} subtreeIndexes - Row indexes (into `this.data`) to split.
   * @param {number} independentVariableIndex - Column index of the variable.
   * @returns {{branchDiffusion: number, splitValue: *, leftBranchIndexes: number[],
   *   rightBranchIndexes: number[]}} Best split found; `branchDiffusion` is
   *   `Infinity` (and `splitValue` is `null`) when no candidate is usable.
   */
  getBestSplitForVariable(subtreeIndexes, independentVariableIndex) {
    // Parse the feature and take the log of the target once per row rather
    // than once per row per candidate.
    const features = subtreeIndexes.map(
      (rowIndex) => parseFloat(this.data[rowIndex][independentVariableIndex]),
    );
    const logTargets = subtreeIndexes.map(
      (rowIndex) => Math.log(this.data[rowIndex][this.dependentVariableIndex]),
    );
    // Duplicate candidates score identically and the comparison below is
    // strict, so de-duplicating (first occurrence kept) cannot change the result.
    const candidates = new Set(this.getValuesFromColumn(subtreeIndexes, independentVariableIndex));

    let best = { branchDiffusion: Infinity, splitValue: null, leftBranchIndexes: [], rightBranchIndexes: [] };
    for (const splitValue of candidates) {
      const threshold = parseFloat(splitValue);
      const leftBranchIndexes = [];
      const rightBranchIndexes = [];
      const leftTargets = [];
      const rightTargets = [];
      subtreeIndexes.forEach((rowIndex, position) => {
        if (features[position] <= threshold) {
          leftBranchIndexes.push(rowIndex);
          leftTargets.push(logTargets[position]);
        } else {
          rightBranchIndexes.push(rowIndex);
          rightTargets.push(logTargets[position]);
        }
      });

      // If one side is empty its standard deviation is NaN, which makes the
      // score NaN; `NaN < x` is false, so such a split is never selected.
      let branchDiffusion = leftBranchIndexes.length * DecisionTree.standardDeviation(leftTargets);
      branchDiffusion += rightBranchIndexes.length * DecisionTree.standardDeviation(rightTargets);

      if (branchDiffusion < best.branchDiffusion) {
        best = { branchDiffusion, leftBranchIndexes, rightBranchIndexes, splitValue };
      }
    }
    return best;
  }

  /**
   * Maps column names to their positions in the header.
   *
   * @param {string[]} columns - Header row.
   * @param {string[]} variables - Column names to look up.
   * @returns {number[]} Indexes in the order of `variables`; unknown names are dropped.
   */
  static getVariablesIndexes(columns, variables) {
    return variables
      .map((name) => columns.indexOf(name))
      .filter((columnIndex) => columnIndex !== -1);
  }

  /**
   * @param {number[]} subtreeIndexes - Row indexes into `this.data`.
   * @param {number} columnIndex - Column to read.
   * @returns {Array<*>} The raw cell values of that column, one per row index.
   */
  getValuesFromColumn(subtreeIndexes, columnIndex) {
    return subtreeIndexes.map((rowIndex) => this.data[rowIndex][columnIndex]);
  }

  /**
   * Population standard deviation (divides by N, not N - 1).
   *
   * @param {Array<number|string>} values - Values, parsed with `parseFloat`.
   * @returns {number} The standard deviation; `NaN` for an empty array.
   */
  static standardDeviation(values) {
    // Parse once so the mean and the variance see the same numbers.
    const numbers = values.map((value) => parseFloat(value));
    const mean = numbers.reduce((sum, value) => sum + value, 0) / numbers.length;
    const variance = numbers.reduce((sum, value) => sum + Math.pow(value - mean, 2), 0) / numbers.length;
    return Math.sqrt(variance);
  }
}
// Demo run: parse Train.csv, then fit a tree on SalePrice using two predictors
// with the options {sampleSize: 1000, depth: 1} and print the resulting tree.
const DEPENDENT_VARIABLE = 'SalePrice';
const INDEPENDENT_VARIABLES = ['MachineHoursCurrentMeter', 'YearMade'];
const csvData = csv.parse(fs.readFileSync('Train.csv'));
const bulldozerTree = new DecisionTree(
  csvData,
  INDEPENDENT_VARIABLES,
  DEPENDENT_VARIABLE,
  {sampleSize: 1000, depth: 1},
);
console.log(bulldozerTree.fit());
// Result:
//
// { branchDiffusion: 672.0239238184622,
// splitValue: '2178',
// splitVariable: 'MachineHoursCurrentMeter',
// leftBranch:
// { branchDiffusion: 299.86590963795163,
// splitValue: '2003',
// splitVariable: 'YearMade',
// sampleSize: 470 },
// rightBranch:
// { branchDiffusion: 349.137024345939,
// splitValue: '1997',
// splitVariable: 'YearMade',
// sampleSize: 530 },
// sampleSize: 1000 }
// Jeremy's (fast.ai) Python implementation's result:
// ml1 course – lesson3 notebook
// n: 1000; val:10.160352993311724; score:672.0239238184623; split:2178.0; var:MachineHoursCurrentMeter
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment