Created
July 31, 2015 01:57
-
-
Save rhiever/2dbea52780f656fc54e9 to your computer and use it in GitHub Desktop.
This notebook uses a random forest classifier to predict a player's gender based on their trivia question performance.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#Import the trivia data into pandas" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>anon_id</th>\n", | |
" <th>category</th>\n", | |
" <th>correct</th>\n", | |
" <th>total</th>\n", | |
" <th>ratio</th>\n", | |
" <th>ratio_relative_to_overall</th>\n", | |
" <th>gender</th>\n", | |
" <th>overall_correct_pct</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>39</td>\n", | |
" <td>amer_hist</td>\n", | |
" <td>128</td>\n", | |
" <td>148</td>\n", | |
" <td>0.864865</td>\n", | |
" <td>0.100654</td>\n", | |
" <td>Male</td>\n", | |
" <td>0.764211</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>39</td>\n", | |
" <td>art</td>\n", | |
" <td>91</td>\n", | |
" <td>145</td>\n", | |
" <td>0.627586</td>\n", | |
" <td>-0.136624</td>\n", | |
" <td>Male</td>\n", | |
" <td>0.764211</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>39</td>\n", | |
" <td>bus_econ</td>\n", | |
" <td>100</td>\n", | |
" <td>113</td>\n", | |
" <td>0.884956</td>\n", | |
" <td>0.120745</td>\n", | |
" <td>Male</td>\n", | |
" <td>0.764211</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>39</td>\n", | |
" <td>class_music</td>\n", | |
" <td>56</td>\n", | |
" <td>102</td>\n", | |
" <td>0.549020</td>\n", | |
" <td>-0.215191</td>\n", | |
" <td>Male</td>\n", | |
" <td>0.764211</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>39</td>\n", | |
" <td>curr_events</td>\n", | |
" <td>71</td>\n", | |
" <td>96</td>\n", | |
" <td>0.739583</td>\n", | |
" <td>-0.024627</td>\n", | |
" <td>Male</td>\n", | |
" <td>0.764211</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" anon_id category correct total ratio ratio_relative_to_overall \\\n", | |
"0 39 amer_hist 128 148 0.864865 0.100654 \n", | |
"1 39 art 91 145 0.627586 -0.136624 \n", | |
"2 39 bus_econ 100 113 0.884956 0.120745 \n", | |
"3 39 class_music 56 102 0.549020 -0.215191 \n", | |
"4 39 curr_events 71 96 0.739583 -0.024627 \n", | |
"\n", | |
" gender overall_correct_pct \n", | |
"0 Male 0.764211 \n", | |
"1 Male 0.764211 \n", | |
"2 Male 0.764211 \n", | |
"3 Male 0.764211 \n", | |
"4 Male 0.764211 " | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import pandas as pd\n", | |
"\n", | |
"trivia_data = pd.read_csv('https://raw.githubusercontent.com/toddwschneider/learnedleague-analysis/master/learnedleague_category_stats.csv')\n", | |
"trivia_data.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#Wrangle the data into a format sklearn can work with" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"X = []\n", | |
"y = []\n", | |
"\n", | |
"for anon_id, group_id in trivia_data.groupby('anon_id'):\n", | |
" group_id = group_id.copy()\n", | |
" group_id.sort('category', inplace=True)\n", | |
" X.append(list(group_id.ratio))\n", | |
" id_class = 1 if group_id.gender.values[0] == 'Male' else 0\n", | |
" y.append(id_class)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#Train a random forest classifier on all of the data\n", | |
"\n", | |
"This achieves 100% accuracy - the random forest classifier can accurately determine whether the player is male or female based off of their trivia category performance." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from sklearn.ensemble import RandomForestClassifier\n", | |
"\n", | |
"rfc = RandomForestClassifier(n_estimators=100)\n", | |
"rfc.fit(X, y)\n", | |
"rfc.score(X, y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"#Perform cross-validation with the random forest classifier\n", | |
"\n", | |
"This shows the model's ability to generalize. It achieves about 86% on the testing set when we perform cross-validation." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.86627043090638933" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from sklearn.ensemble import RandomForestClassifier\n", | |
"from sklearn.cross_validation import train_test_split\n", | |
"\n", | |
"X_train, X_test, y_train, y_test = train_test_split(X, y)\n", | |
"\n", | |
"rfc = RandomForestClassifier(n_estimators=100)\n", | |
"rfc.fit(X_train, y_train)\n", | |
"rfc.score(X_test, y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment