Last active
March 29, 2023 09:18
-
-
Save danielhaim1/288ec403ed3baf58e02f2aba6cfc1c36 to your computer and use it in GitHub Desktop.
Decision Tree Classification in JavaScript
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * A single node of a binary decision tree.
 * Internal nodes carry a (featureIndex, threshold) split plus two children;
 * leaf nodes carry only a class prediction.
 */
class DecisionTreeNode {
  /**
   * @param {?number} featureIndex - index of the feature to split on (null for leaves)
   * @param {?number} threshold - split threshold; samples with value < threshold go left
   * @param {?DecisionTreeNode} leftNode - subtree for values below the threshold
   * @param {?DecisionTreeNode} rightNode - subtree for values at/above the threshold
   * @param {*} prediction - class label for leaves; null/undefined for internal nodes
   */
  constructor(featureIndex, threshold, leftNode, rightNode, prediction) {
    this.featureIndex = featureIndex;
    this.threshold = threshold;
    this.leftNode = leftNode;
    this.rightNode = rightNode;
    this.prediction = prediction;
  }

  /**
   * A node is a leaf iff it carries a prediction.
   * BUG FIX: the original compared `!== undefined`, but internal nodes are
   * constructed with `prediction = null`, so they were misclassified as
   * leaves. The loose `!= null` check rejects both null and undefined.
   * @returns {boolean}
   */
  isLeaf() {
    return this.prediction != null;
  }
}
/**
 * A CART-style decision-tree classifier for numeric feature vectors.
 * Splits are binary (feature < threshold) and chosen by maximum
 * information gain (entropy reduction).
 *
 * Fixes over the original:
 *  - removed the duplicate `predict(example)` / `_predict(node, example)`
 *    pair, which overrode the working `predict` and dereferenced
 *    `this.tree` (never assigned; the constructor sets `this.rootNode`),
 *    so every prediction threw a TypeError;
 *  - added the missing `_informationGain`, `_entropy`, and
 *    `_mostCommonLabel` helpers that were called but never defined;
 *  - guard against `_findBestSplit` returning null (no split with
 *    positive gain) before dereferencing it;
 *  - internal nodes are built with `prediction = undefined` so
 *    `isLeaf()` (which tests for a present prediction) stays false.
 */
class DecisionTreeClassifier {
  constructor() {
    this.rootNode = null;
  }

  /**
   * Build the tree from training data.
   * @param {number[][]} features - rows of numeric feature vectors
   * @param {Array} labels - class label for each row (parallel to features)
   */
  train(features, labels) {
    this.rootNode = this._buildTree(features, labels);
  }

  /**
   * Classify one feature vector by walking from the root to a leaf.
   * @param {number[]} features - a single feature vector
   * @returns {*} the predicted class label
   */
  predict(features) {
    let currentNode = this.rootNode;
    while (!currentNode.isLeaf()) {
      if (features[currentNode.featureIndex] < currentNode.threshold) {
        currentNode = currentNode.leftNode;
      } else {
        currentNode = currentNode.rightNode;
      }
    }
    return currentNode.prediction;
  }

  /**
   * Recursively grow the tree. Returns a leaf when the labels are pure,
   * when there is no data to split, or when no split improves purity.
   */
  _buildTree(features, labels) {
    const counts = this._countLabels(labels);
    // Pure node: every sample has the same label.
    if (Object.keys(counts).length === 1) {
      return new DecisionTreeNode(null, null, null, null, Object.keys(counts)[0]);
    }
    if (features.length === 0) {
      return new DecisionTreeNode(null, null, null, null, this._mostCommonLabel(labels));
    }
    const bestSplit = this._findBestSplit(features, labels);
    // null means no candidate split had positive gain; degenerate splits
    // (one empty side) also terminate recursion with a majority-vote leaf.
    if (bestSplit === null || bestSplit.leftLabels.length === 0 || bestSplit.rightLabels.length === 0) {
      return new DecisionTreeNode(null, null, null, null, this._mostCommonLabel(labels));
    }
    const leftNode = this._buildTree(bestSplit.leftFeatures, bestSplit.leftLabels);
    const rightNode = this._buildTree(bestSplit.rightFeatures, bestSplit.rightLabels);
    // undefined (not null) prediction keeps isLeaf() false for internal nodes.
    return new DecisionTreeNode(bestSplit.featureIndex, bestSplit.threshold, leftNode, rightNode, undefined);
  }

  /**
   * Exhaustively try every (feature, unique value) pair as a threshold and
   * keep the split with the highest information gain.
   * @returns {?object} the best split, or null if no split has gain > 0
   */
  _findBestSplit(features, labels) {
    let bestGain = 0;
    let bestSplit = null;
    for (let featureIndex = 0; featureIndex < features[0].length; featureIndex++) {
      const values = features.map((row) => row[featureIndex]);
      const uniqueValues = [...new Set(values)];
      for (const threshold of uniqueValues) {
        const split = this._splitData(features, labels, featureIndex, threshold);
        if (split.leftLabels.length === 0 || split.rightLabels.length === 0) {
          continue;
        }
        const gain = this._informationGain(labels, split.leftLabels, split.rightLabels);
        if (gain > bestGain) {
          bestGain = gain;
          bestSplit = split;
          bestSplit.featureIndex = featureIndex;
          bestSplit.threshold = threshold;
        }
      }
    }
    return bestSplit;
  }

  /**
   * Partition rows by `features[i][featureIndex] < threshold` (left) vs
   * `>= threshold` (right), keeping labels parallel to their rows.
   */
  _splitData(features, labels, featureIndex, threshold) {
    const leftFeatures = [];
    const leftLabels = [];
    const rightFeatures = [];
    const rightLabels = [];
    for (let i = 0; i < features.length; i++) {
      if (features[i][featureIndex] < threshold) {
        leftFeatures.push(features[i]);
        leftLabels.push(labels[i]);
      } else {
        rightFeatures.push(features[i]);
        rightLabels.push(labels[i]);
      }
    }
    return {
      leftFeatures,
      leftLabels,
      rightFeatures,
      rightLabels,
    };
  }

  /** Count occurrences of each label: { label: count }. */
  _countLabels(labels) {
    const counts = {};
    labels.forEach((label) => {
      counts[label] = (counts[label] || 0) + 1;
    });
    return counts;
  }

  /** Shannon entropy of a label multiset: -sum(p * log2(p)). */
  _entropy(labels) {
    const counts = this._countLabels(labels);
    const total = labels.length;
    let entropy = 0;
    for (const label of Object.keys(counts)) {
      const p = counts[label] / total;
      entropy -= p * Math.log2(p);
    }
    return entropy;
  }

  /** Entropy reduction achieved by splitting parentLabels into two children. */
  _informationGain(parentLabels, leftLabels, rightLabels) {
    const total = parentLabels.length;
    const weightedChildEntropy =
      (leftLabels.length / total) * this._entropy(leftLabels) +
      (rightLabels.length / total) * this._entropy(rightLabels);
    return this._entropy(parentLabels) - weightedChildEntropy;
  }

  /** Majority label; ties broken by first-seen insertion order. */
  _mostCommonLabel(labels) {
    const counts = this._countLabels(labels);
    let best = null;
    let bestCount = -1;
    for (const label of Object.keys(counts)) {
      if (counts[label] > bestCount) {
        bestCount = counts[label];
        best = label;
      }
    }
    return best;
  }
}
// Each row rates one customer's coffee preference as
// [sweetness, bitterness, creaminess]; e.g. [10, 1, 0] means very sweet
// (10), slightly bitter (1), and no creaminess (0).
// Together with the parallel label array below, these rows train a
// decision-tree classifier that maps a preference vector to a drink.
const coffeeData = [
  [1, 0, 0],
  [2, 1, 0],
  [3, 1, 1],
  [4, 1, 1],
  [5, 0, 0],
  [6, 1, 0],
  [7, 1, 1],
  [8, 0, 0],
  [9, 1, 1],
  [10, 1, 0],
];
// Drink ordered by the customer in the matching coffeeData row.
const coffeeLabels = [
  "latte",
  "cappuccino",
  "espresso",
  "espresso",
  "latte",
  "cappuccino",
  "espresso",
  "latte",
  "espresso",
  "cappuccino",
];

// Train the classifier, then predict a drink for one new customer.
const classifier = new DecisionTreeClassifier();
classifier.train(coffeeData, coffeeLabels);

const samplePreferences = [3, 1, 0];
const predictedDrink = classifier.predict(samplePreferences);
console.log(predictedDrink); // Output: "espresso"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment