Skip to content

Instantly share code, notes, and snippets.

@metasyn
Created June 9, 2016 23:42
Show Gist options
  • Save metasyn/5505453cb505cc205f69091f7d91859e to your computer and use it in GitHub Desktop.
Save metasyn/5505453cb505cc205f69091f7d91859e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestClassifier as RFC\n",
"from sklearn.cross_validation import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"df = pd.read_csv(\"/Users/aljohnson/data/churn.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Account Length</th>\n",
" <th>Area Code</th>\n",
" <th>Churn?</th>\n",
" <th>CustServ Calls</th>\n",
" <th>Day Calls</th>\n",
" <th>Day Charge</th>\n",
" <th>Day Mins</th>\n",
" <th>Eve Calls</th>\n",
" <th>Eve Charge</th>\n",
" <th>Eve Mins</th>\n",
" <th>...</th>\n",
" <th>Intl Charge</th>\n",
" <th>Intl Mins</th>\n",
" <th>Night Calls</th>\n",
" <th>Night Charge</th>\n",
" <th>Night Mins</th>\n",
" <th>Phone</th>\n",
" <th>State</th>\n",
" <th>VMail Message</th>\n",
" <th>VMail Plan</th>\n",
" <th>predicted(Churn?)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>128</td>\n",
" <td>415</td>\n",
" <td>False.</td>\n",
" <td>1</td>\n",
" <td>110</td>\n",
" <td>45.07</td>\n",
" <td>265.1</td>\n",
" <td>99</td>\n",
" <td>16.78</td>\n",
" <td>197.4</td>\n",
" <td>...</td>\n",
" <td>2.70</td>\n",
" <td>10.0</td>\n",
" <td>91</td>\n",
" <td>11.01</td>\n",
" <td>244.7</td>\n",
" <td>382-4657</td>\n",
" <td>KS</td>\n",
" <td>25</td>\n",
" <td>yes</td>\n",
" <td>False.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>107</td>\n",
" <td>415</td>\n",
" <td>False.</td>\n",
" <td>1</td>\n",
" <td>123</td>\n",
" <td>27.47</td>\n",
" <td>161.6</td>\n",
" <td>103</td>\n",
" <td>16.62</td>\n",
" <td>195.5</td>\n",
" <td>...</td>\n",
" <td>3.70</td>\n",
" <td>13.7</td>\n",
" <td>103</td>\n",
" <td>11.45</td>\n",
" <td>254.4</td>\n",
" <td>371-7191</td>\n",
" <td>OH</td>\n",
" <td>26</td>\n",
" <td>yes</td>\n",
" <td>False.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>137</td>\n",
" <td>415</td>\n",
" <td>False.</td>\n",
" <td>0</td>\n",
" <td>114</td>\n",
" <td>41.38</td>\n",
" <td>243.4</td>\n",
" <td>110</td>\n",
" <td>10.30</td>\n",
" <td>121.2</td>\n",
" <td>...</td>\n",
" <td>3.29</td>\n",
" <td>12.2</td>\n",
" <td>104</td>\n",
" <td>7.32</td>\n",
" <td>162.6</td>\n",
" <td>358-1921</td>\n",
" <td>NJ</td>\n",
" <td>0</td>\n",
" <td>no</td>\n",
" <td>False.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>84</td>\n",
" <td>408</td>\n",
" <td>False.</td>\n",
" <td>2</td>\n",
" <td>71</td>\n",
" <td>50.90</td>\n",
" <td>299.4</td>\n",
" <td>88</td>\n",
" <td>5.26</td>\n",
" <td>61.9</td>\n",
" <td>...</td>\n",
" <td>1.78</td>\n",
" <td>6.6</td>\n",
" <td>89</td>\n",
" <td>8.86</td>\n",
" <td>196.9</td>\n",
" <td>375-9999</td>\n",
" <td>OH</td>\n",
" <td>0</td>\n",
" <td>no</td>\n",
" <td>True.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>75</td>\n",
" <td>415</td>\n",
" <td>False.</td>\n",
" <td>3</td>\n",
" <td>113</td>\n",
" <td>28.34</td>\n",
" <td>166.7</td>\n",
" <td>122</td>\n",
" <td>12.61</td>\n",
" <td>148.3</td>\n",
" <td>...</td>\n",
" <td>2.73</td>\n",
" <td>10.1</td>\n",
" <td>121</td>\n",
" <td>8.41</td>\n",
" <td>186.9</td>\n",
" <td>330-6626</td>\n",
" <td>OK</td>\n",
" <td>0</td>\n",
" <td>no</td>\n",
" <td>True.</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 22 columns</p>\n",
"</div>"
],
"text/plain": [
" Account Length Area Code Churn? CustServ Calls Day Calls Day Charge \\\n",
"0 128 415 False. 1 110 45.07 \n",
"1 107 415 False. 1 123 27.47 \n",
"2 137 415 False. 0 114 41.38 \n",
"3 84 408 False. 2 71 50.90 \n",
"4 75 415 False. 3 113 28.34 \n",
"\n",
" Day Mins Eve Calls Eve Charge Eve Mins ... Intl Charge \\\n",
"0 265.1 99 16.78 197.4 ... 2.70 \n",
"1 161.6 103 16.62 195.5 ... 3.70 \n",
"2 243.4 110 10.30 121.2 ... 3.29 \n",
"3 299.4 88 5.26 61.9 ... 1.78 \n",
"4 166.7 122 12.61 148.3 ... 2.73 \n",
"\n",
" Intl Mins Night Calls Night Charge Night Mins Phone State \\\n",
"0 10.0 91 11.01 244.7 382-4657 KS \n",
"1 13.7 103 11.45 254.4 371-7191 OH \n",
"2 12.2 104 7.32 162.6 358-1921 NJ \n",
"3 6.6 89 8.86 196.9 375-9999 OH \n",
"4 10.1 121 8.41 186.9 330-6626 OK \n",
"\n",
" VMail Message VMail Plan predicted(Churn?) \n",
"0 25 yes False. \n",
"1 26 yes False. \n",
"2 0 no False. \n",
"3 0 no True. \n",
"4 0 no True. \n",
"\n",
"[5 rows x 22 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"target = df['Churn?'].map(lambda x: (x[:-1]))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"df.drop('Churn?', axis=1, inplace=True)\n",
"df.drop('predicted(Churn?)', axis=1, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(df, target, test_size=0.33, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"clf = RFC(n_estimators=10, max_depth=5)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"enc = LabelEncoder()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"LabelEncoder()"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"enc.fit(y_train)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array(['False', 'True'], dtype=object)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"enc.classes_"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"target = enc.transform(y_train)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"enc2 = LabelEncoder()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(2233, 2)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.shape(np.array(X_train[['Day Calls', 'Day Charge']]))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n",
" max_depth=5, max_features='auto', max_leaf_nodes=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n",
" oob_score=False, random_state=None, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.fit(np.array(X_train[['Day Calls', 'Day Charge']]), target)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"preds = clf.predict(np.array(X_test[['Day Calls', 'Day Charge']]))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 1, ..., 0, 1, 0])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"enc.transform(y_test)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.86636363636363634"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.score(np.array(X_test[['Day Calls', 'Day Charge']]), pd.DataFrame(enc.transform(y_test)))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n",
" max_features=None, max_leaf_nodes=None, min_samples_leaf=1,\n",
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
" presort=False, random_state=None, splitter='best')"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.base_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"('criterion',\n",
" 'max_depth',\n",
" 'min_samples_split',\n",
" 'min_samples_leaf',\n",
" 'min_weight_fraction_leaf',\n",
" 'max_features',\n",
" 'max_leaf_nodes',\n",
" 'random_state')"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.estimator_params"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.16866398, 0.83133602])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.feature_importances_"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=5,\n",
" max_features='auto', max_leaf_nodes=None, min_samples_leaf=1,\n",
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
" presort=False, random_state=438965221, splitter='best')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.__reduce__()[2]['estimators_'][0]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.tree import export_graphviz\n",
"import StringIO"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sio = StringIO.StringIO()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"j = export_graphviz(clf.__reduce__()[2]['estimators_'][0], sio)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<StringIO.StringIO instance at 0x10f1bdf80>"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sio"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"'digraph Tree {\\nnode [shape=box] ;\\n0 [label=\"X[1] <= 48.375\\\\ngini = 0.2267\\\\nsamples = 1425\\\\nvalue = [1942, 291]\"] ;\\n1 [label=\"X[0] <= 102.5\\\\ngini = 0.1943\\\\nsamples = 1375\\\\nvalue = [1920, 235]\"] ;\\n0 -> 1 [labeldistance=2.5, labelangle=45, headlabel=\"True\"] ;\\n2 [label=\"X[0] <= 95.5\\\\ngini = 0.1585\\\\nsamples = 737\\\\nvalue = [1052, 100]\"] ;\\n1 -> 2 ;\\n3 [label=\"X[1] <= 19.645\\\\ngini = 0.1864\\\\nsamples = 524\\\\nvalue = [715, 83]\"] ;\\n2 -> 3 ;\\n4 [label=\"X[0] <= 93.5\\\\ngini = 0.4193\\\\nsamples = 63\\\\nvalue = [75, 32]\"] ;\\n3 -> 4 ;\\n5 [label=\"gini = 0.3723\\\\nsamples = 54\\\\nvalue = [70, 23]\"] ;\\n4 -> 5 ;\\n6 [label=\"gini = 0.4592\\\\nsamples = 9\\\\nvalue = [5, 9]\"] ;\\n4 -> 6 ;\\n7 [label=\"X[0] <= 91.5\\\\ngini = 0.1367\\\\nsamples = 461\\\\nvalue = [640, 51]\"] ;\\n3 -> 7 ;\\n8 [label=\"gini = 0.1172\\\\nsamples = 367\\\\nvalue = [510, 34]\"] ;\\n7 -> 8 ;\\n9 [label=\"gini = 0.2045\\\\nsamples = 94\\\\nvalue = [130, 17]\"] ;\\n7 -> 9 ;\\n10 [label=\"X[0] <= 101.5\\\\ngini = 0.0914\\\\nsamples = 213\\\\nvalue = [337, 17]\"] ;\\n2 -> 10 ;\\n11 [label=\"X[1] <= 22.09\\\\ngini = 0.11\\\\nsamples = 177\\\\nvalue = [274, 17]\"] ;\\n10 -> 11 ;\\n12 [label=\"gini = 0.2486\\\\nsamples = 32\\\\nvalue = [47, 8]\"] ;\\n11 -> 12 ;\\n13 [label=\"gini = 0.0734\\\\nsamples = 145\\\\nvalue = [227, 9]\"] ;\\n11 -> 13 ;\\n14 [label=\"gini = 0.0\\\\nsamples = 36\\\\nvalue = [63, 0]\"] ;\\n10 -> 14 ;\\n15 [label=\"X[0] <= 162.5\\\\ngini = 0.233\\\\nsamples = 638\\\\nvalue = [868, 135]\"] ;\\n1 -> 15 ;\\n16 [label=\"X[1] <= 38.255\\\\ngini = 0.2317\\\\nsamples = 637\\\\nvalue = [868, 134]\"] ;\\n15 -> 16 ;\\n17 [label=\"X[0] <= 113.5\\\\ngini = 0.1671\\\\nsamples = 511\\\\nvalue = [730, 74]\"] ;\\n16 -> 17 ;\\n18 [label=\"gini = 0.2009\\\\nsamples = 247\\\\nvalue = [360, 46]\"] ;\\n17 -> 18 ;\\n19 [label=\"gini = 0.1308\\\\nsamples = 264\\\\nvalue = [370, 28]\"] ;\\n17 -> 19 ;\\n20 [label=\"X[0] <= 103.5\\\\ngini = 0.4224\\\\nsamples = 126\\\\nvalue = [138, 60]\"] ;\\n16 -> 20 ;\\n21 [label=\"gini = 0.0\\\\nsamples = 2\\\\nvalue = [0, 4]\"] ;\\n20 -> 21 ;\\n22 [label=\"gini = 0.4107\\\\nsamples = 124\\\\nvalue = [138, 56]\"] ;\\n20 -> 22 ;\\n23 [label=\"gini = 0.0\\\\nsamples = 1\\\\nvalue = [0, 1]\"] ;\\n15 -> 23 ;\\n24 [label=\"X[0] <= 105.5\\\\ngini = 0.405\\\\nsamples = 50\\\\nvalue = [22, 56]\"] ;\\n0 -> 24 [labeldistance=2.5, labelangle=-45, headlabel=\"False\"] ;\\n25 [label=\"X[0] <= 86.5\\\\ngini = 0.4918\\\\nsamples = 26\\\\nvalue = [17, 22]\"] ;\\n24 -> 25 ;\\n26 [label=\"X[0] <= 69.0\\\\ngini = 0.1327\\\\nsamples = 11\\\\nvalue = [1, 13]\"] ;\\n25 -> 26 ;\\n27 [label=\"gini = 0.0\\\\nsamples = 4\\\\nvalue = [0, 6]\"] ;\\n26 -> 27 ;\\n28 [label=\"X[1] <= 50.815\\\\ngini = 0.2188\\\\nsamples = 7\\\\nvalue = [1, 7]\"] ;\\n26 -> 28 ;\\n29 [label=\"gini = 0.0\\\\nsamples = 4\\\\nvalue = [0, 5]\"] ;\\n28 -> 29 ;\\n30 [label=\"gini = 0.4444\\\\nsamples = 3\\\\nvalue = [1, 2]\"] ;\\n28 -> 30 ;\\n31 [label=\"X[1] <= 50.73\\\\ngini = 0.4608\\\\nsamples = 15\\\\nvalue = [16, 9]\"] ;\\n25 -> 31 ;\\n32 [label=\"X[1] <= 49.175\\\\ngini = 0.4444\\\\nsamples = 7\\\\nvalue = [4, 8]\"] ;\\n31 -> 32 ;\\n33 [label=\"gini = 0.0\\\\nsamples = 2\\\\nvalue = [4, 0]\"] ;\\n32 -> 33 ;\\n34 [label=\"gini = 0.0\\\\nsamples = 5\\\\nvalue = [0, 8]\"] ;\\n32 -> 34 ;\\n35 [label=\"X[1] <= 53.415\\\\ngini = 0.142\\\\nsamples = 8\\\\nvalue = [12, 1]\"] ;\\n31 -> 35 ;\\n36 [label=\"gini = 0.0\\\\nsamples = 6\\\\nvalue = [11, 0]\"] ;\\n35 -> 36 ;\\n37 [label=\"gini = 0.5\\\\nsamples = 2\\\\nvalue = [1, 1]\"] ;\\n35 -> 37 ;\\n38 [label=\"X[1] <= 49.43\\\\ngini = 0.2235\\\\nsamples = 24\\\\nvalue = [5, 34]\"] ;\\n24 -> 38 ;\\n39 [label=\"gini = 0.0\\\\nsamples = 9\\\\nvalue = [0, 13]\"] ;\\n38 -> 39 ;\\n40 [label=\"X[0] <= 146.5\\\\ngini = 0.3107\\\\nsamples = 15\\\\nvalue = [5, 21]\"] ;\\n38 -> 40 ;\\n41 [label=\"X[0] <= 130.0\\\\ngini = 0.2688\\\\nsamples = 14\\\\nvalue = [4, 21]\"] ;\\n40 -> 41 ;\\n42 [label=\"gini = 0.375\\\\nsamples = 10\\\\nvalue = [4, 12]\"] ;\\n41 -> 42 ;\\n43 [label=\"gini = 0.0\\\\nsamples = 4\\\\nvalue = [0, 9]\"] ;\\n41 -> 43 ;\\n44 [label=\"gini = 0.0\\\\nsamples = 1\\\\nvalue = [1, 0]\"] ;\\n40 -> 44 ;\\n}'"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sio.getvalue()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment