Skip to content

Instantly share code, notes, and snippets.

@canard0328
Created February 2, 2019 01:37
Show Gist options
  • Save canard0328/9b3fd9f64f13253379e69a8afd984836 to your computer and use it in GitHub Desktop.
Save canard0328/9b3fd9f64f13253379e69a8afd984836 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": ["from tpot import TPOTClassifier"]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": ["import pandas as pd"]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Training data: feature columns plus the boolean Churn label\n",
"data = pd.read_csv(filepath_or_buffer='telecom_train.csv')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Account length</th>\n",
" <th>Area code</th>\n",
" <th>Number vmail messages</th>\n",
" <th>Total day minutes</th>\n",
" <th>Total day calls</th>\n",
" <th>Total day charge</th>\n",
" <th>Total eve minutes</th>\n",
" <th>Total eve calls</th>\n",
" <th>Total eve charge</th>\n",
" <th>Total night minutes</th>\n",
" <th>Total night calls</th>\n",
" <th>Total night charge</th>\n",
" <th>Total intl minutes</th>\n",
" <th>Total intl calls</th>\n",
" <th>Total intl charge</th>\n",
" <th>Customer service calls</th>\n",
" <th>Plan</th>\n",
" <th>Churn</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>70</td>\n",
" <td>415</td>\n",
" <td>0</td>\n",
" <td>230.3</td>\n",
" <td>110</td>\n",
" <td>39.15</td>\n",
" <td>77.9</td>\n",
" <td>87</td>\n",
" <td>6.62</td>\n",
" <td>247.1</td>\n",
" <td>105</td>\n",
" <td>11.12</td>\n",
" <td>13.2</td>\n",
" <td>4</td>\n",
" <td>3.56</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>69</td>\n",
" <td>415</td>\n",
" <td>0</td>\n",
" <td>153.7</td>\n",
" <td>109</td>\n",
" <td>26.13</td>\n",
" <td>194.0</td>\n",
" <td>105</td>\n",
" <td>16.49</td>\n",
" <td>256.1</td>\n",
" <td>114</td>\n",
" <td>11.52</td>\n",
" <td>14.1</td>\n",
" <td>6</td>\n",
" <td>3.81</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>45</td>\n",
" <td>415</td>\n",
" <td>0</td>\n",
" <td>78.2</td>\n",
" <td>127</td>\n",
" <td>13.29</td>\n",
" <td>253.4</td>\n",
" <td>108</td>\n",
" <td>21.54</td>\n",
" <td>255.0</td>\n",
" <td>100</td>\n",
" <td>11.48</td>\n",
" <td>18.0</td>\n",
" <td>3</td>\n",
" <td>4.86</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>111</td>\n",
" <td>510</td>\n",
" <td>0</td>\n",
" <td>197.1</td>\n",
" <td>117</td>\n",
" <td>33.51</td>\n",
" <td>227.8</td>\n",
" <td>128</td>\n",
" <td>19.36</td>\n",
" <td>214.0</td>\n",
" <td>101</td>\n",
" <td>9.63</td>\n",
" <td>9.3</td>\n",
" <td>11</td>\n",
" <td>2.51</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>158</td>\n",
" <td>408</td>\n",
" <td>0</td>\n",
" <td>172.4</td>\n",
" <td>114</td>\n",
" <td>29.31</td>\n",
" <td>256.6</td>\n",
" <td>69</td>\n",
" <td>21.81</td>\n",
" <td>235.3</td>\n",
" <td>104</td>\n",
" <td>10.59</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" <td>0.00</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Account length Area code Number vmail messages Total day minutes \\\n",
"0 70 415 0 230.3 \n",
"1 69 415 0 153.7 \n",
"2 45 415 0 78.2 \n",
"3 111 510 0 197.1 \n",
"4 158 408 0 172.4 \n",
"\n",
" Total day calls Total day charge Total eve minutes Total eve calls \\\n",
"0 110 39.15 77.9 87 \n",
"1 109 26.13 194.0 105 \n",
"2 127 13.29 253.4 108 \n",
"3 117 33.51 227.8 128 \n",
"4 114 29.31 256.6 69 \n",
"\n",
" Total eve charge Total night minutes Total night calls \\\n",
"0 6.62 247.1 105 \n",
"1 16.49 256.1 114 \n",
"2 21.54 255.0 100 \n",
"3 19.36 214.0 101 \n",
"4 21.81 235.3 104 \n",
"\n",
" Total night charge Total intl minutes Total intl calls \\\n",
"0 11.12 13.2 4 \n",
"1 11.52 14.1 6 \n",
"2 11.48 18.0 3 \n",
"3 9.63 9.3 11 \n",
"4 10.59 0.0 0 \n",
"\n",
" Total intl charge Customer service calls Plan Churn \n",
"0 3.56 1 0 False \n",
"1 3.81 1 0 False \n",
"2 4.86 1 0 False \n",
"3 2.51 0 0 False \n",
"4 0.00 2 0 False "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": ["data.head(n=5)"]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Features are every column except the Churn label\n",
"X = data.drop(columns='Churn')\n",
"y = data['Churn']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Small search budget: 5 generations of 20 pipelines each\n",
"tpot = TPOTClassifier(\n",
"    generations=5,\n",
"    population_size=20,\n",
"    n_jobs=-1,  # use all available cores\n",
"    random_state=0,  # fixed seed for the pipeline search\n",
"    verbosity=2,  # progress bar + per-generation best score\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: xgboost.XGBClassifier is not available and will not be used by TPOT.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Optimization Progress', max=120, style=ProgressStyle(descript…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 1 - Current best internal CV score: 0.9327456962151428\n",
"Generation 2 - Current best internal CV score: 0.9327456962151428\n",
"Generation 3 - Current best internal CV score: 0.9327456962151428\n",
"Generation 4 - Current best internal CV score: 0.9327456962151428\n",
"Generation 5 - Current best internal CV score: 0.9327456962151428\n",
"\n",
"Best pipeline: ExtraTreesClassifier(PolynomialFeatures(input_matrix, degree=2, include_bias=False, interaction_only=False), bootstrap=True, criterion=gini, max_features=0.6000000000000001, min_samples_leaf=1, min_samples_split=3, n_estimators=100)\n"
]
},
{
"data": {
"text/plain": [
"TPOTClassifier(config_dict=None, crossover_rate=0.1, cv=5,\n",
" disable_update_check=False, early_stop=None, generations=5,\n",
" max_eval_time_mins=5, max_time_mins=None, memory=None,\n",
" mutation_rate=0.9, n_jobs=-1, offspring_size=None,\n",
" periodic_checkpoint_folder=None, population_size=20,\n",
" random_state=0, scoring=None, subsample=1.0, use_dask=False,\n",
" verbosity=2, warm_start=False)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evolutionary search; each candidate pipeline gets an internal CV score (cv=5 default)\n",
"tpot.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Held-out test set; it also carries the Churn label\n",
"data_test = pd.read_csv(filepath_or_buffer='telecom_test_with_y.csv')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Account length</th>\n",
" <th>Area code</th>\n",
" <th>Number vmail messages</th>\n",
" <th>Total day minutes</th>\n",
" <th>Total day calls</th>\n",
" <th>Total day charge</th>\n",
" <th>Total eve minutes</th>\n",
" <th>Total eve calls</th>\n",
" <th>Total eve charge</th>\n",
" <th>Total night minutes</th>\n",
" <th>Total night calls</th>\n",
" <th>Total night charge</th>\n",
" <th>Total intl minutes</th>\n",
" <th>Total intl calls</th>\n",
" <th>Total intl charge</th>\n",
" <th>Customer service calls</th>\n",
" <th>Plan</th>\n",
" <th>Churn</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>133</td>\n",
" <td>510</td>\n",
" <td>0</td>\n",
" <td>295.0</td>\n",
" <td>141</td>\n",
" <td>50.15</td>\n",
" <td>223.6</td>\n",
" <td>101</td>\n",
" <td>19.01</td>\n",
" <td>229.4</td>\n",
" <td>109</td>\n",
" <td>10.32</td>\n",
" <td>12.9</td>\n",
" <td>4</td>\n",
" <td>3.48</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>99</td>\n",
" <td>415</td>\n",
" <td>0</td>\n",
" <td>200.0</td>\n",
" <td>66</td>\n",
" <td>34.00</td>\n",
" <td>107.9</td>\n",
" <td>104</td>\n",
" <td>9.17</td>\n",
" <td>233.7</td>\n",
" <td>82</td>\n",
" <td>10.52</td>\n",
" <td>11.4</td>\n",
" <td>2</td>\n",
" <td>3.08</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>44</td>\n",
" <td>415</td>\n",
" <td>0</td>\n",
" <td>240.3</td>\n",
" <td>146</td>\n",
" <td>40.85</td>\n",
" <td>164.6</td>\n",
" <td>83</td>\n",
" <td>13.99</td>\n",
" <td>240.7</td>\n",
" <td>106</td>\n",
" <td>10.83</td>\n",
" <td>10.6</td>\n",
" <td>2</td>\n",
" <td>2.86</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>130</td>\n",
" <td>408</td>\n",
" <td>0</td>\n",
" <td>211.2</td>\n",
" <td>119</td>\n",
" <td>35.90</td>\n",
" <td>231.1</td>\n",
" <td>120</td>\n",
" <td>19.64</td>\n",
" <td>220.9</td>\n",
" <td>80</td>\n",
" <td>9.94</td>\n",
" <td>6.3</td>\n",
" <td>9</td>\n",
" <td>1.70</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>90</td>\n",
" <td>408</td>\n",
" <td>0</td>\n",
" <td>222.0</td>\n",
" <td>93</td>\n",
" <td>37.74</td>\n",
" <td>187.0</td>\n",
" <td>103</td>\n",
" <td>15.90</td>\n",
" <td>282.3</td>\n",
" <td>124</td>\n",
" <td>12.70</td>\n",
" <td>12.4</td>\n",
" <td>6</td>\n",
" <td>3.35</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Account length Area code Number vmail messages Total day minutes \\\n",
"0 133 510 0 295.0 \n",
"1 99 415 0 200.0 \n",
"2 44 415 0 240.3 \n",
"3 130 408 0 211.2 \n",
"4 90 408 0 222.0 \n",
"\n",
" Total day calls Total day charge Total eve minutes Total eve calls \\\n",
"0 141 50.15 223.6 101 \n",
"1 66 34.00 107.9 104 \n",
"2 146 40.85 164.6 83 \n",
"3 119 35.90 231.1 120 \n",
"4 93 37.74 187.0 103 \n",
"\n",
" Total eve charge Total night minutes Total night calls \\\n",
"0 19.01 229.4 109 \n",
"1 9.17 233.7 82 \n",
"2 13.99 240.7 106 \n",
"3 19.64 220.9 80 \n",
"4 15.90 282.3 124 \n",
"\n",
" Total night charge Total intl minutes Total intl calls \\\n",
"0 10.32 12.9 4 \n",
"1 10.52 11.4 2 \n",
"2 10.83 10.6 2 \n",
"3 9.94 6.3 9 \n",
"4 12.70 12.4 6 \n",
"\n",
" Total intl charge Customer service calls Plan Churn \n",
"0 3.48 2 0 True \n",
"1 3.08 3 0 False \n",
"2 2.86 1 0 False \n",
"3 1.70 2 0 False \n",
"4 3.35 2 0 False "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": ["data_test.head(n=5)"]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Same feature/label split as the training data\n",
"X_test = data_test.drop(columns='Churn')\n",
"y_test = data_test['Churn']"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9412\n"
]
}
],
"source": [
"# Score the fitted pipeline on the held-out test set.\n",
"# With `scoring` left unset, TPOTClassifier reports accuracy.\n",
"test_accuracy = tpot.score(X_test, y_test)\n",
"test_accuracy"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Write the best pipeline found as a standalone Python script\n",
"tpot.export('tpot_telecom_pipeline.py')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "test",
"language": "python",
"name": "test"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment