{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Blog bost at: http://kldavenport.com/pure-python-decision-trees/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"###Our Data:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First let's define our data, in this case a list of lists. Let's imagine that this data is for a SaaS product we are selling. We offer users a trial 14-days and at the end of the trial offer the users the ability to sign up for a basic or premium offering. We collect the following info from them from beginning to end:\n",
"\n",
"1. Where the customer was referred from when they signed up for the trial (google, slashdot, etc.) [Domain name string or (direct)]\n",
"2. Country of orgin (resolved by IP) [Country string]\n",
"3. Clicked on our FAQ link during the trail? [boolean]\n",
"4. How many application pages they viewed during the trial. [int]\n",
"5. What service they choose at the end of the trial. [None, Basic, Premium strings]"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"my_data=[['slashdot','USA','yes',18,'None'],\n",
" ['google','France','yes',23,'Premium'],\n",
" ['reddit','USA','yes',24,'Basic'],\n",
" ['kiwitobes','France','yes',23,'Basic'],\n",
" ['google','UK','no',21,'Premium'],\n",
" ['(direct)','New Zealand','no',12,'None'],\n",
" ['(direct)','UK','no',21,'Basic'],\n",
" ['google','USA','no',24,'Premium'],\n",
" ['slashdot','France','yes',19,'None'],\n",
" ['reddit','USA','no',18,'None'],\n",
" ['google','UK','no',18,'None'],\n",
" ['kiwitobes','UK','no',19,'None'],\n",
" ['reddit','New Zealand','yes',12,'Basic'],\n",
" ['slashdot','UK','no',21,'None'],\n",
" ['google','UK','yes',18,'Basic'],\n",
" ['kiwitobes','France','yes',19,'Basic']]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating our tree\n",
"Below we define a class to represent each node a tree.\n",
"\n",
"Note: I'm not assuming a certain python level for this blog post, as such I will go over some programming fundamentals. A class is a user-defined prototype (guide, template, etc.) for an object that defines a set of attributes that characterize any object of the class. The attributes are data members (class variables and instance variables) and methods, accessed via dot notation. If you've been using something like scikit-learn up to this point I'm sure you're used to model.fit(), model.score(), etc. Fit and score are all instance methods of whatever model you've instantiated. For example:\n",
"\n",
"```python\n",
"clf = linear_model.SGDRegressor() # Instantiating SGDRegressor as clf \n",
"clf.fit(X, y) # using clf's (an instance of SGDRegressor) fit method on some data\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class decisionnode:\n",
" def __init__(self,col=-1,value=None,results=None,tb=None,fb=None):\n",
" self.col=col # column index of criteria being tested\n",
" self.value=value # vlaue necessary to get a true result\n",
" self.results=results # dict of results for a branch, None for everything except endpoints\n",
" self.tb=tb # true decision nodes \n",
" self.fb=fb # false decision nodes"
]
},
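{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the structure concrete, here's a quick illustration (mine, not from the original post) of how a leaf and an internal node would be represented. The counts mirror the tree we build later on:\n",
"\n",
"```python\n",
"# A leaf node only carries a dict of outcome counts:\n",
"leaf = decisionnode(results={'Premium': 3})\n",
"# An internal node tests column 3 (pages viewed) against 21 and holds two subtrees:\n",
"node = decisionnode(col=3, value=21, tb=leaf, fb=decisionnode(results={'None': 1}))\n",
"```"
]
},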
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's work on the intelligence behind contructing the tree. One of the more popular methods for consructing trees, CART (Classification And Regression Trees), was developed by Leo Breiman https://www.stat.berkeley.edu/~breiman/papers.html. CART is a recursive partitioning method that builds classification and regression trees for predicting continuous and categorical variables.\n",
"\n",
"The first step is to construct a root node by considering all the observations in our dataset and determine which variable or feature would subset the data the most. If we were looking at a lung cancer outcomes dataset for example and one of the variables was smoker (y/n), it would be intuitive that this would split the dataset up substantially.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Divides a set on a specific column. Can handle numeric or nominal values\n",
"\n",
"def divideset(rows,column,value):\n",
" # Make a function that tells us if a row is in the first group \n",
" # (true) or the second group (false)\n",
" split_function=None\n",
" # for numerical values\n",
" if isinstance(value,int) or isinstance(value,float):\n",
" split_function=lambda row:row[column]>=value\n",
" # for nominal values\n",
" else:\n",
" split_function=lambda row:row[column]==value\n",
" \n",
" # Divide the rows into two sets and return them\n",
" set1=[row for row in rows if split_function(row)] # if split_function(row) \n",
" set2=[row for row in rows if not split_function(row)]\n",
" return (set1,set2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the previously defined `my_data` let's split our data by users in the USA."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"([['slashdot', 'USA', 'yes', 18, 'None'],\n",
" ['reddit', 'USA', 'yes', 24, 'Basic'],\n",
" ['google', 'USA', 'no', 24, 'Premium'],\n",
" ['reddit', 'USA', 'no', 18, 'None']],\n",
" [['google', 'France', 'yes', 23, 'Premium'],\n",
" ['kiwitobes', 'France', 'yes', 23, 'Basic'],\n",
" ['google', 'UK', 'no', 21, 'Premium'],\n",
" ['(direct)', 'New Zealand', 'no', 12, 'None'],\n",
" ['(direct)', 'UK', 'no', 21, 'Basic'],\n",
" ['slashdot', 'France', 'yes', 19, 'None'],\n",
" ['google', 'UK', 'no', 18, 'None'],\n",
" ['kiwitobes', 'UK', 'no', 19, 'None'],\n",
" ['reddit', 'New Zealand', 'yes', 12, 'Basic'],\n",
" ['slashdot', 'UK', 'no', 21, 'None'],\n",
" ['google', 'UK', 'yes', 18, 'Basic'],\n",
" ['kiwitobes', 'France', 'yes', 19, 'Basic']])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"divideset(my_data,1,'USA')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Granted this is a small samples of data, country of orgin doesn't seem to be a good variable to split on at this point as we still have a good representation/mix of subscription outcomes in both sets above (None, Basic, Premium).\n",
"\n",
"We need a formalized manner to assess how mixed a result set is in order to properly check the outcome of spliting on each variable. When constructing our root node we should chose a variable that creates two sets with the least possible amount of mixing. To start let's create a function to count the occurences of the outcomes in each set. We'll use this function later on inside other functions to measure how mixed a set is."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Create counts of possible results (last column of each row is the result)\n",
"def uniquecounts(rows):\n",
" results={}\n",
" for row in rows:\n",
" # The result is the last column\n",
" r=row[len(row)-1]\n",
" if r not in results: results[r]=0\n",
" results[r]+=1\n",
" return results"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sorry I couldn't help but look for an excuse to use `defaultdict` from `collections`!"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"def uniquecounts_dd(rows):\n",
" results = defaultdict(lambda: 0)\n",
" for row in rows:\n",
" r = row[len(row)-1]\n",
" results[r]+=1\n",
" return dict(results) "
]
},
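{
"cell_type": "markdown",
"metadata": {},
"source": [
"(A small aside of mine: `defaultdict(int)` is the more common idiom and behaves identically here, since calling `int()` with no arguments returns `0`, the same default as `lambda: 0`.)"
]
},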
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"({'Basic': 6, 'None': 7, 'Premium': 3},\n",
" 'Same output',\n",
" {'Basic': 6, 'None': 7, 'Premium': 3})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"uniquecounts(my_data),'Same output', uniquecounts_dd(my_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"###Measures of mixture\n",
"\n",
"In a previous post I took an in-depth look at Entropy or the measure of surprisal: http://kldavenport.com/a-real-world-introduction-to-information-entropy/ I'll cover the basics here, but please refer to my other post for more detail. At least watch this creative video on information entropy:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"400\"\n",
" height=\"300\"\n",
" src=\"https://www.youtube.com/embed/R4OlXb9aTvQ\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.YouTubeVideo at 0x1043d5510>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import YouTubeVideo\n",
"YouTubeVideo('R4OlXb9aTvQ')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our entropy function will calculate how many times a class appears and divide it by the number of observations in our data set:\n",
"\n",
"$$ p(i) = frequency(outcome) = count(outcome) \\thinspace/ \\thinspace count(total rows) $$\n",
"\n",
"It then does the following for all outcomes, $p(i)$:\n",
"\n",
"$$ Entropy = sum \\thinspace of \\thinspace p(i) * log(p(i)) $$\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Entropy is the sum of p(x)log(p(x)) across all the different possible results\n",
"def entropy(rows):\n",
" from math import log\n",
" log2=lambda x:log(x)/log(2) \n",
" results=uniquecounts(rows)\n",
" # Now calculate the entropy\n",
" ent=0.0\n",
" for r in results.keys():\n",
" # current probability of class\n",
" p=float(results[r])/len(rows) \n",
" ent=ent-p*log2(p)\n",
" return ent"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"1.5052408149441479"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"entropy(my_data)"
]
},
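{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see where that number comes from: with 16 rows and outcome counts {None: 7, Basic: 6, Premium: 3},\n",
"\n",
"$$ Entropy = -\\left(\\tfrac{7}{16}\\log_2\\tfrac{7}{16} + \\tfrac{6}{16}\\log_2\\tfrac{6}{16} + \\tfrac{3}{16}\\log_2\\tfrac{3}{16}\\right) \\approx 1.505 $$"
]
},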
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Essentially entropy is higher the more mixed up the groups or outcomes of subscription is. Trying the function on a data set where the outcomes either `None` or `Basic` should result in a smaller number:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.9852281360342516"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_data2=[['slashdot','USA','yes',18,'None'],\n",
" ['google','France','yes',23,'None'],\n",
" ['reddit','USA','yes',24,'Basic'],\n",
" ['kiwitobes','France','yes',23,'Basic'],\n",
" ['google','UK','no',21,'None'],\n",
" ['(direct)','New Zealand','no',12,'None'],\n",
" ['(direct)','UK','no',21,'Basic']]\n",
"\n",
"entropy(my_data2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now have a method of assessing entropy. The next step in building our tree will involve assessing the success of each variable's ability to split the dataset. In other words we're attempting to identify the feature that best splits the target class into the purest children nodes. These pure nodes would not contain a mix of output classes, in this case subscription level (None, Basic, Premium).\n",
"\n",
"We'll start by calculating the entropy of the entire data set then dividing the group by all the possible outcomes for each attribute. We determine the best attribute to divide on by calculating information gain (Entropy before - Entropy after). Again more info in my more detailed post on entropy here: http://kldavenport.com/a-real-world-introduction-to-information-entropy/\n",
"\n",
"**Caveats:**\n",
"Information gain is generally a good measure for deciding the relevance of an attribute, but there are some distinct shortcomings. One case is when information gain is applied to variabless that take on a large number of unique values. This is a concern not necessarily from a pure variance perspective, rather that the variable is too descriptive of the current observations.\n",
"\n",
"**High mutual information** indicates a large reduction in uncertainty, credit card numbers or street addresss variables in a dataset uniquely identify a customer. These variables provide a great deal of identifying information if we are trying to predict a customer, but will not generalize well to unobserved/trained-on instances (overfitting)."
]
},
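{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a sketch of mine, not from the original post), we can hand-compute the information gain for the country == 'USA' split we tried earlier, using the `divideset` and `entropy` functions defined above:\n",
"\n",
"```python\n",
"set1, set2 = divideset(my_data, 1, 'USA')\n",
"p = float(len(set1)) / len(my_data)\n",
"gain = entropy(my_data) - p*entropy(set1) - (1-p)*entropy(set2)\n",
"# gain is tiny (~0.02), confirming country alone barely unmixes the outcomes\n",
"```\n",
"\n",
"`buildtree` below performs exactly this computation for every candidate (column, value) pair and recurses on the best split."
]
},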
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def buildtree(rows, scorefun=entropy):\n",
" if len(rows) == 0: return decisionnode()\n",
" current_score = scorefun(rows)\n",
"\n",
" best_gain = 0.0\n",
" best_criteria = None\n",
" best_sets = None\n",
"\n",
" column_count = len(rows[0]) - 1\t# last column is result\n",
" for col in range(0, column_count):\n",
" # find different values in this column\n",
" column_values = set([row[col] for row in rows])\n",
"\n",
" # for each possible value, try to divide on that value\n",
" for value in column_values:\n",
" set1, set2 = divideset(rows, col, value)\n",
"\n",
" # Information gain\n",
" p = float(len(set1)) / len(rows)\n",
" gain = current_score - p*scorefun(set1) - (1-p)*scorefun(set2)\n",
" if gain > best_gain and len(set1) > 0 and len(set2) > 0:\n",
" best_gain = gain\n",
" best_criteria = (col, value)\n",
" best_sets = (set1, set2)\n",
"\n",
" if best_gain > 0:\n",
" trueBranch = buildtree(best_sets[0])\n",
" falseBranch = buildtree(best_sets[1])\n",
" return decisionnode(col=best_criteria[0], value=best_criteria[1],\n",
" tb=trueBranch, fb=falseBranch)\n",
" else:\n",
" return decisionnode(results=uniquecounts(rows))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now have a function that returns a trained decision tree. We can print a rudimentary tree."
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def printtree(tree,indent=''):\n",
" # Is this a leaf node?\n",
" if tree.results!=None:\n",
" print str(tree.results)\n",
" else:\n",
" # Print the criteria\n",
" print 'Column ' + str(tree.col)+' : '+str(tree.value)+'? '\n",
"\n",
" # Print the branches\n",
" print indent+'True->',\n",
" printtree(tree.tb,indent+' ')\n",
" print indent+'False->',\n",
" printtree(tree.fb,indent+' ')"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"[['slashdot', 'USA', 'yes', 18, 'None'],\n",
" ['google', 'France', 'yes', 23, 'Premium'],\n",
" ['reddit', 'USA', 'yes', 24, 'Basic']]"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Printing a few rows of our dataset for context\n",
"my_data[0:3]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When printing the tree we see that the root node checks if column 0 contains 'google'. If the condition is met (condition is True) we then move on to see that anyone that was referred from Google will purchase a subscription (Basic or Premium) if they view 21 pages or more and so on."
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Column 0 : google? \n",
"True-> Column 3 : 21? \n",
" True-> {'Premium': 3}\n",
" False-> Column 2 : yes? \n",
" True-> {'Basic': 1}\n",
" False-> {'None': 1}\n",
"False-> Column 0 : slashdot? \n",
" True-> {'None': 3}\n",
" False-> Column 2 : yes? \n",
" True-> {'Basic': 4}\n",
" False-> Column 3 : 21? \n",
" True-> {'Basic': 1}\n",
" False-> {'None': 3}\n"
]
}
],
"source": [
"printtree(buildtree(my_data))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next step might be to classify new observations by building a function that traverses the tree. We could even make our implementation more advanced by implementing pruning: https://en.wikipedia.org/wiki/Pruning_(decision_trees). If you're interested in pure python implementations of analytics implementations check out the aforementioned book by Toby Segaran [Programming Collective Intelligence](https://amzn.to/DJp4uz)"
]
}
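,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of such a classifier (mine, not from the original post), assuming the `decisionnode` structure defined above:\n",
"\n",
"```python\n",
"def classify(observation, tree):\n",
"    # Leaf node: return its dict of outcome counts\n",
"    if tree.results is not None:\n",
"        return tree.results\n",
"    v = observation[tree.col]\n",
"    # Mirror divideset's logic: >= for numeric values, == for nominal ones\n",
"    if isinstance(v, int) or isinstance(v, float):\n",
"        branch = tree.tb if v >= tree.value else tree.fb\n",
"    else:\n",
"        branch = tree.tb if v == tree.value else tree.fb\n",
"    return classify(observation, branch)\n",
"\n",
"# Following the printed tree: not google, not slashdot, clicked FAQ -> {'Basic': 4}\n",
"classify(['(direct)', 'USA', 'yes', 5], buildtree(my_data))\n",
"```"
]
}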
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}