This notebook uses a random forest classifier to predict a player's gender based on their trivia question performance.
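The notebook's key preprocessing step turns the long-format table (one row per player and category) into one feature row per player. Below is a minimal sketch of that step using `DataFrame.pivot` in place of the notebook's `groupby` loop. It assumes the column names shown in the notebook output (`anon_id`, `category`, `ratio`, `gender`) and uses a tiny hypothetical table instead of the real CSV, so the values are illustrative only.

```python
import pandas as pd

# Hypothetical toy rows in the same long format as learnedleague_category_stats.csv:
# one row per (player, category) pair. The values are made up for illustration.
long_df = pd.DataFrame({
    "anon_id":  [1, 1, 2, 2, 3, 3],
    "category": ["art", "amer_hist", "art", "amer_hist", "amer_hist", "art"],
    "ratio":    [0.60, 0.85, 0.75, 0.70, 0.90, 0.50],
    "gender":   ["Male", "Male", "Female", "Female", "Male", "Male"],
})

# One row per player, one column per category. pivot() sorts the new columns
# by category name, matching the notebook's sort-by-category loop.
X = long_df.pivot(index="anon_id", columns="category", values="ratio")

# One label per player: 1 for Male, 0 otherwise (the notebook's encoding).
y = (long_df.groupby("anon_id")["gender"].first() == "Male").astype(int)

# Align the labels with the feature rows by player id.
y = y.loc[X.index]

print(X)
print(y.tolist())
```

Unlike the loop, `pivot` matches values to columns by category name, so a player missing a category gets a NaN in that column instead of silently shifted values. Those NaNs would have to be filled or dropped before fitting. `pivot` also raises a `ValueError` if a player has two rows for the same category.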
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import the trivia data into pandas"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>anon_id</th>\n",
" <th>category</th>\n",
" <th>correct</th>\n",
" <th>total</th>\n",
" <th>ratio</th>\n",
" <th>ratio_relative_to_overall</th>\n",
" <th>gender</th>\n",
" <th>overall_correct_pct</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>39</td>\n",
" <td>amer_hist</td>\n",
" <td>128</td>\n",
" <td>148</td>\n",
" <td>0.864865</td>\n",
" <td>0.100654</td>\n",
" <td>Male</td>\n",
" <td>0.764211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>39</td>\n",
" <td>art</td>\n",
" <td>91</td>\n",
" <td>145</td>\n",
" <td>0.627586</td>\n",
" <td>-0.136624</td>\n",
" <td>Male</td>\n",
" <td>0.764211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>39</td>\n",
" <td>bus_econ</td>\n",
" <td>100</td>\n",
" <td>113</td>\n",
" <td>0.884956</td>\n",
" <td>0.120745</td>\n",
" <td>Male</td>\n",
" <td>0.764211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>39</td>\n",
" <td>class_music</td>\n",
" <td>56</td>\n",
" <td>102</td>\n",
" <td>0.549020</td>\n",
" <td>-0.215191</td>\n",
" <td>Male</td>\n",
" <td>0.764211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>39</td>\n",
" <td>curr_events</td>\n",
" <td>71</td>\n",
" <td>96</td>\n",
" <td>0.739583</td>\n",
" <td>-0.024627</td>\n",
" <td>Male</td>\n",
" <td>0.764211</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" anon_id category correct total ratio ratio_relative_to_overall \\\n",
"0 39 amer_hist 128 148 0.864865 0.100654 \n",
"1 39 art 91 145 0.627586 -0.136624 \n",
"2 39 bus_econ 100 113 0.884956 0.120745 \n",
"3 39 class_music 56 102 0.549020 -0.215191 \n",
"4 39 curr_events 71 96 0.739583 -0.024627 \n",
"\n",
" gender overall_correct_pct \n",
"0 Male 0.764211 \n",
"1 Male 0.764211 \n",
"2 Male 0.764211 \n",
"3 Male 0.764211 \n",
"4 Male 0.764211 "
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"trivia_data = pd.read_csv('https://raw.githubusercontent.com/toddwschneider/learnedleague-analysis/master/learnedleague_category_stats.csv')\n",
"trivia_data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
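"# Reshape the data into a format scikit-learn can work with\n",
"\n",
"Each player becomes one row of `X`, holding their `ratio` for every category (sorted by category name so the columns line up across players), and one label in `y`: 1 for `Male`, 0 otherwise. This assumes every player has a row for every category; a player with a missing category would end up with misaligned columns."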
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"X = []\n",
"y = []\n",
"\n",
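"for anon_id, group_id in trivia_data.groupby('anon_id'):\n",
"    group_id = group_id.copy()\n",
"    # DataFrame.sort() no longer exists in pandas; sort_values() replaces it.\n",
"    # Sorting by category keeps the feature columns in the same order for every player.\n",
"    group_id.sort_values('category', inplace=True)\n",
"    X.append(list(group_id.ratio))\n",
"    # Label: 1 for Male, 0 otherwise\n",
"    id_class = 1 if group_id.gender.values[0] == 'Male' else 0\n",
"    y.append(id_class)"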
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
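"# Train a random forest classifier on all of the data\n",
"\n",
"This scores 100% accuracy, but only because the model is scored on the same players it was trained on. With its default settings (unlimited tree depth), a random forest can largely memorize its training set, so this number says little about how well the classifier predicts a player's gender from their trivia category performance. Measuring that requires held-out data."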
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"rfc = RandomForestClassifier(n_estimators=100)\n",
"rfc.fit(X, y)\n",
"rfc.score(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
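"# Evaluate the random forest classifier on a held-out test set\n",
"\n",
"This estimates how well the model generalizes. `train_test_split` holds out 25% of the players (its default) as a test set, and the model trained on the remaining players scores about 87% (0.866) on them. This is a single random split rather than k-fold cross-validation, so the exact figure will vary from run to run."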
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.86627043090638933"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
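"# sklearn.cross_validation was removed in scikit-learn 0.20; use model_selection instead\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hold out 25% of the players (the default test_size) as a test set\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y)\n",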
"\n",
"rfc = RandomForestClassifier(n_estimators=100)\n",
"rfc.fit(X_train, y_train)\n",
"rfc.score(X_test, y_test)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.4.3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
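The notebook's last scored cell uses a single random train/test split rather than k-fold cross-validation, so its 0.866 score shifts from run to run. It also has no baseline to compare against. Below is a minimal sketch of stratified 5-fold cross-validation with `cross_val_score` from `sklearn.model_selection`, set beside a majority-class baseline from `DummyClassifier`. It uses synthetic data from `make_classification` as a hypothetical stand-in for the per-player category ratios, so the numbers it prints say nothing about the trivia data. On the real data you would pass the notebook's `X` and `y` instead.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Hypothetical stand-in data: 300 "players", 10 features, roughly 70/30 class balance.
X, y = make_classification(n_samples=300, n_features=10, n_informative=5,
                           weights=[0.3, 0.7], random_state=0)

# Stratified folds keep the class balance similar in every test fold.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Random forest with the notebook's n_estimators, scored on each held-out fold.
rfc = RandomForestClassifier(n_estimators=100, random_state=0)
rf_scores = cross_val_score(rfc, X, y, cv=cv)

# Baseline: always predict the most common class, ignoring the features.
baseline = DummyClassifier(strategy="most_frequent")
base_scores = cross_val_score(baseline, X, y, cv=cv)

print("random forest:  %.3f +/- %.3f" % (rf_scores.mean(), rf_scores.std()))
print("majority class: %.3f" % base_scores.mean())
```

The baseline's accuracy equals the share of the most common class. When one class dominates, a raw accuracy figure is only meaningful relative to that line. Averaging over folds also smooths out the run-to-run variation of a single split.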