@danielhaim1
Last active March 29, 2023 09:18
Decision Tree Classification in JavaScript
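// A minimal decision tree classifier. train() grows a binary tree by recursively
// choosing the feature/threshold split with the highest information gain;
// predict() walks the tree, going left when the example's feature value is below
// a node's threshold, until it reaches a leaf that stores a class label.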
class DecisionTreeNode {
  constructor(featureIndex, threshold, leftNode, rightNode, prediction) {
    this.featureIndex = featureIndex;
    this.threshold = threshold;
    this.leftNode = leftNode;
    this.rightNode = rightNode;
    this.prediction = prediction;
  }

  // Internal nodes are constructed with prediction = null, so check for both
  // null and undefined rather than undefined alone.
  isLeaf() {
    return this.prediction != null;
  }
}
class DecisionTreeClassifier {
  constructor() {
    this.rootNode = null;
  }

  train(features, labels) {
    this.rootNode = this._buildTree(features, labels);
  }

  // Predicts the label for a single feature vector by walking the tree.
  predict(features) {
    let currentNode = this.rootNode;
    while (!currentNode.isLeaf()) {
      if (features[currentNode.featureIndex] < currentNode.threshold) {
        currentNode = currentNode.leftNode;
      } else {
        currentNode = currentNode.rightNode;
      }
    }
    return currentNode.prediction;
  }

  _buildTree(features, labels) {
    const predictions = this._countLabels(labels);
    // All examples share one label: make a pure leaf.
    if (Object.keys(predictions).length === 1) {
      return new DecisionTreeNode(null, null, null, null, Object.keys(predictions)[0]);
    }
    if (features.length === 0) {
      return new DecisionTreeNode(null, null, null, null, this._mostCommonLabel(labels));
    }
    const bestSplit = this._findBestSplit(features, labels);
    // No split improves purity: fall back to a majority-vote leaf.
    if (!bestSplit || bestSplit.leftLabels.length === 0 || bestSplit.rightLabels.length === 0) {
      return new DecisionTreeNode(null, null, null, null, this._mostCommonLabel(labels));
    }
    const leftNode = this._buildTree(bestSplit.leftFeatures, bestSplit.leftLabels);
    const rightNode = this._buildTree(bestSplit.rightFeatures, bestSplit.rightLabels);
    return new DecisionTreeNode(bestSplit.featureIndex, bestSplit.threshold, leftNode, rightNode, null);
  }
  _findBestSplit(features, labels) {
    let bestGain = 0;
    let bestSplit = null;
    for (let featureIndex = 0; featureIndex < features[0].length; featureIndex++) {
      const values = features.map(row => row[featureIndex]);
      const uniqueValues = [...new Set(values)];
      for (const threshold of uniqueValues) {
        const split = this._splitData(features, labels, featureIndex, threshold);
        if (split.leftLabels.length === 0 || split.rightLabels.length === 0) {
          continue;
        }
        const gain = this._informationGain(labels, split.leftLabels, split.rightLabels);
        if (gain > bestGain) {
          bestGain = gain;
          bestSplit = split;
          bestSplit.featureIndex = featureIndex;
          bestSplit.threshold = threshold;
        }
      }
    }
    return bestSplit;
  }

  _splitData(features, labels, featureIndex, threshold) {
    const leftFeatures = [];
    const leftLabels = [];
    const rightFeatures = [];
    const rightLabels = [];
    for (let i = 0; i < features.length; i++) {
      if (features[i][featureIndex] < threshold) {
        leftFeatures.push(features[i]);
        leftLabels.push(labels[i]);
      } else {
        rightFeatures.push(features[i]);
        rightLabels.push(labels[i]);
      }
    }
    return {
      leftFeatures,
      leftLabels,
      rightFeatures,
      rightLabels,
    };
  }
  _countLabels(labels) {
    const counts = {};
    labels.forEach(label => {
      counts[label] = (counts[label] || 0) + 1;
    });
    return counts;
  }

  // Majority vote: the label that occurs most often in the set.
  _mostCommonLabel(labels) {
    const counts = this._countLabels(labels);
    return Object.keys(counts).reduce((a, b) => (counts[a] >= counts[b] ? a : b));
  }

  // Shannon entropy of a label set: -sum(p * log2(p)) over each label's proportion p.
  _entropy(labels) {
    const counts = this._countLabels(labels);
    return Object.values(counts).reduce((entropy, count) => {
      const p = count / labels.length;
      return entropy - p * Math.log2(p);
    }, 0);
  }

  // Information gain: parent entropy minus the size-weighted entropy of the two children.
  _informationGain(labels, leftLabels, rightLabels) {
    const leftWeight = leftLabels.length / labels.length;
    const rightWeight = rightLabels.length / labels.length;
    return this._entropy(labels) - (leftWeight * this._entropy(leftLabels) + rightWeight * this._entropy(rightLabels));
  }
}
// Each data point is [sweetness, bitterness, creaminess]. For example, [10, 1, 0]
// represents a person who prefers their coffee very sweet (sweetness 10), bitter
// (bitterness 1), and with no creaminess (creaminess 0).
// This array of data points, together with the labels below, is used to train the
// decision tree classifier and then predict a coffee for a new set of preferences.
const coffeeData = [
  [1, 0, 0],
  [2, 1, 0],
  [3, 1, 1],
  [4, 1, 1],
  [5, 0, 0],
  [6, 1, 0],
  [7, 1, 1],
  [8, 0, 0],
  [9, 1, 1],
  [10, 1, 0],
];

const coffeeLabels = [
  "latte",
  "cappuccino",
  "espresso",
  "espresso",
  "latte",
  "cappuccino",
  "espresso",
  "latte",
  "espresso",
  "cappuccino",
];
const coffeeTree = new DecisionTreeClassifier();
coffeeTree.train(coffeeData, coffeeLabels);

const customerPreferences = [3, 1, 0];
const predictedCoffee = coffeeTree.predict(customerPreferences);
console.log(predictedCoffee); // "cappuccino": bitterness 1 with creaminess 0 matches the cappuccino rows
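// An additional, illustrative sanity check (the variable name creamyPreferences is
// hypothetical): a creamy preference should follow the espresso branch, since only
// the espresso rows in coffeeData have creaminess = 1.
const creamyPreferences = [7, 1, 1];
console.log(coffeeTree.predict(creamyPreferences)); // "espresso"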