Skip to content

Instantly share code, notes, and snippets.

@Joshuaek
Created February 25, 2018 07:31
Show Gist options
  • Save Joshuaek/c96c777ed9db32c245e3a013e19374e9 to your computer and use it in GitHub Desktop.
Save Joshuaek/c96c777ed9db32c245e3a013e19374e9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense\n",
"from keras.layers import Dropout\n",
"from keras.wrappers.scikit_learn import KerasClassifier\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"seed = 7\n",
"np.random.seed(seed)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>V1</th>\n",
" <th>V2</th>\n",
" <th>V3</th>\n",
" <th>V4</th>\n",
" <th>V5</th>\n",
" <th>V6</th>\n",
" <th>V7</th>\n",
" <th>V8</th>\n",
" <th>V9</th>\n",
" <th>V10</th>\n",
" <th>V11</th>\n",
" <th>V12</th>\n",
" <th>V13</th>\n",
" <th>V14</th>\n",
" <th>V15</th>\n",
" <th>V16</th>\n",
" <th>Class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>58</td>\n",
" <td>management</td>\n",
" <td>married</td>\n",
" <td>tertiary</td>\n",
" <td>no</td>\n",
" <td>2143</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>unknown</td>\n",
" <td>5</td>\n",
" <td>may</td>\n",
" <td>261</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>unknown</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>44</td>\n",
" <td>technician</td>\n",
" <td>single</td>\n",
" <td>secondary</td>\n",
" <td>no</td>\n",
" <td>29</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>unknown</td>\n",
" <td>5</td>\n",
" <td>may</td>\n",
" <td>151</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>unknown</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>33</td>\n",
" <td>entrepreneur</td>\n",
" <td>married</td>\n",
" <td>secondary</td>\n",
" <td>no</td>\n",
" <td>2</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>unknown</td>\n",
" <td>5</td>\n",
" <td>may</td>\n",
" <td>76</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>unknown</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>47</td>\n",
" <td>blue-collar</td>\n",
" <td>married</td>\n",
" <td>unknown</td>\n",
" <td>no</td>\n",
" <td>1506</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>unknown</td>\n",
" <td>5</td>\n",
" <td>may</td>\n",
" <td>92</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>unknown</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>33</td>\n",
" <td>unknown</td>\n",
" <td>single</td>\n",
" <td>unknown</td>\n",
" <td>no</td>\n",
" <td>1</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>unknown</td>\n",
" <td>5</td>\n",
" <td>may</td>\n",
" <td>198</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>unknown</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 \\\n",
"0 58 management married tertiary no 2143 yes no unknown 5 \n",
"1 44 technician single secondary no 29 yes no unknown 5 \n",
"2 33 entrepreneur married secondary no 2 yes yes unknown 5 \n",
"3 47 blue-collar married unknown no 1506 yes no unknown 5 \n",
"4 33 unknown single unknown no 1 no no unknown 5 \n",
"\n",
" V11 V12 V13 V14 V15 V16 Class \n",
"0 may 261 1 -1 0 unknown 1 \n",
"1 may 151 1 -1 0 unknown 1 \n",
"2 may 76 1 -1 0 unknown 1 \n",
"3 may 92 1 -1 0 unknown 1 \n",
"4 may 198 1 -1 0 unknown 1 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv('https://www.openml.org/data/get_csv/1586218/phpkIxskf')\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 2])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['Class'].unique()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>V1</th>\n",
" <th>V6</th>\n",
" <th>V10</th>\n",
" <th>V12</th>\n",
" <th>V13</th>\n",
" <th>V14</th>\n",
" <th>V15</th>\n",
" <th>V2_admin.</th>\n",
" <th>V2_blue-collar</th>\n",
" <th>V2_entrepreneur</th>\n",
" <th>...</th>\n",
" <th>V11_mar</th>\n",
" <th>V11_may</th>\n",
" <th>V11_nov</th>\n",
" <th>V11_oct</th>\n",
" <th>V11_sep</th>\n",
" <th>V16_failure</th>\n",
" <th>V16_other</th>\n",
" <th>V16_success</th>\n",
" <th>V16_unknown</th>\n",
" <th>Class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>58</td>\n",
" <td>2143</td>\n",
" <td>5</td>\n",
" <td>261</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>44</td>\n",
" <td>29</td>\n",
" <td>5</td>\n",
" <td>151</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>33</td>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>76</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>47</td>\n",
" <td>1506</td>\n",
" <td>5</td>\n",
" <td>92</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>33</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>198</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>35</td>\n",
" <td>231</td>\n",
" <td>5</td>\n",
" <td>139</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>28</td>\n",
" <td>447</td>\n",
" <td>5</td>\n",
" <td>217</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>42</td>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>380</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>58</td>\n",
" <td>121</td>\n",
" <td>5</td>\n",
" <td>50</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>43</td>\n",
" <td>593</td>\n",
" <td>5</td>\n",
" <td>55</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>10 rows × 52 columns</p>\n",
"</div>"
],
"text/plain": [
" V1 V6 V10 V12 V13 V14 V15 V2_admin. V2_blue-collar \\\n",
"0 58 2143 5 261 1 -1 0 0 0 \n",
"1 44 29 5 151 1 -1 0 0 0 \n",
"2 33 2 5 76 1 -1 0 0 0 \n",
"3 47 1506 5 92 1 -1 0 0 1 \n",
"4 33 1 5 198 1 -1 0 0 0 \n",
"5 35 231 5 139 1 -1 0 0 0 \n",
"6 28 447 5 217 1 -1 0 0 0 \n",
"7 42 2 5 380 1 -1 0 0 0 \n",
"8 58 121 5 50 1 -1 0 0 0 \n",
"9 43 593 5 55 1 -1 0 0 0 \n",
"\n",
" V2_entrepreneur ... V11_mar V11_may V11_nov V11_oct V11_sep \\\n",
"0 0 ... 0 1 0 0 0 \n",
"1 0 ... 0 1 0 0 0 \n",
"2 1 ... 0 1 0 0 0 \n",
"3 0 ... 0 1 0 0 0 \n",
"4 0 ... 0 1 0 0 0 \n",
"5 0 ... 0 1 0 0 0 \n",
"6 0 ... 0 1 0 0 0 \n",
"7 1 ... 0 1 0 0 0 \n",
"8 0 ... 0 1 0 0 0 \n",
"9 0 ... 0 1 0 0 0 \n",
"\n",
" V16_failure V16_other V16_success V16_unknown Class \n",
"0 0 0 0 1 1 \n",
"1 0 0 0 1 1 \n",
"2 0 0 0 1 1 \n",
"3 0 0 0 1 1 \n",
"4 0 0 0 1 1 \n",
"5 0 0 0 1 1 \n",
"6 0 0 0 1 1 \n",
"7 0 0 0 1 1 \n",
"8 0 0 0 1 1 \n",
"9 0 0 0 1 1 \n",
"\n",
"[10 rows x 52 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dummy_df = pd.get_dummies(df)\n",
"cols = list(dummy_df.columns)\n",
"cols.remove('Class')\n",
"cols.append('Class')\n",
"dummy_df = dummy_df[cols]\n",
"dummy_df.head(n=10)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100\n"
]
}
],
"source": [
"dataset = dummy_df.values\n",
"# split into input (X) and output (Y) variables\n",
"X = dataset[:,0:51].astype(float)\n",
"Y = dataset[:,51]\n",
"Y = [[y-1] for y in Y]\n",
"\n",
"training_examples = 100\n",
"\n",
"x_test = np.asarray(X[training_examples:])\n",
"x_train = np.asarray(X[:training_examples])\n",
"\n",
"y_test = np.asarray(Y[training_examples:])\n",
"y_train = np.asarray(Y[:training_examples])\n",
"\n",
"\n",
"print(len(x_train))\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"51\n"
]
}
],
"source": [
"print(len(X[23]))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"100/100 [==============================] - 0s 3ms/step - loss: 5.0533 - acc: 0.6000\n",
"Epoch 2/30\n",
"100/100 [==============================] - 0s 77us/step - loss: 3.1687 - acc: 0.7900\n",
"Epoch 3/30\n",
"100/100 [==============================] - 0s 77us/step - loss: 2.8346 - acc: 0.7400\n",
"Epoch 4/30\n",
"100/100 [==============================] - 0s 83us/step - loss: 1.8216 - acc: 0.8500\n",
"Epoch 5/30\n",
"100/100 [==============================] - 0s 95us/step - loss: 1.7830 - acc: 0.8600\n",
"Epoch 6/30\n",
"100/100 [==============================] - 0s 97us/step - loss: 1.3023 - acc: 0.9000\n",
"Epoch 7/30\n",
"100/100 [==============================] - 0s 95us/step - loss: 2.3336 - acc: 0.8100\n",
"Epoch 8/30\n",
"100/100 [==============================] - 0s 80us/step - loss: 1.2987 - acc: 0.9000\n",
"Epoch 9/30\n",
"100/100 [==============================] - 0s 107us/step - loss: 1.1307 - acc: 0.9100\n",
"Epoch 10/30\n",
"100/100 [==============================] - 0s 76us/step - loss: 1.5822 - acc: 0.8900\n",
"Epoch 11/30\n",
"100/100 [==============================] - 0s 77us/step - loss: 0.8198 - acc: 0.9300\n",
"Epoch 12/30\n",
"100/100 [==============================] - 0s 73us/step - loss: 1.2554 - acc: 0.9200\n",
"Epoch 13/30\n",
"100/100 [==============================] - 0s 95us/step - loss: 1.4007 - acc: 0.9000\n",
"Epoch 14/30\n",
"100/100 [==============================] - 0s 76us/step - loss: 0.8098 - acc: 0.9500\n",
"Epoch 15/30\n",
"100/100 [==============================] - 0s 94us/step - loss: 0.8629 - acc: 0.9200\n",
"Epoch 16/30\n",
"100/100 [==============================] - 0s 72us/step - loss: 1.2508 - acc: 0.9100\n",
"Epoch 17/30\n",
"100/100 [==============================] - 0s 96us/step - loss: 1.3948 - acc: 0.8900\n",
"Epoch 18/30\n",
"100/100 [==============================] - 0s 89us/step - loss: 1.6759 - acc: 0.8600\n",
"Epoch 19/30\n",
"100/100 [==============================] - 0s 79us/step - loss: 0.6739 - acc: 0.9500\n",
"Epoch 20/30\n",
"100/100 [==============================] - 0s 71us/step - loss: 0.8056 - acc: 0.9500\n",
"Epoch 21/30\n",
"100/100 [==============================] - 0s 62us/step - loss: 0.6576 - acc: 0.9400\n",
"Epoch 22/30\n",
"100/100 [==============================] - 0s 66us/step - loss: 0.8816 - acc: 0.9300\n",
"Epoch 23/30\n",
"100/100 [==============================] - 0s 90us/step - loss: 0.6433 - acc: 0.9600\n",
"Epoch 24/30\n",
"100/100 [==============================] - 0s 58us/step - loss: 0.5001 - acc: 0.9600\n",
"Epoch 25/30\n",
"100/100 [==============================] - 0s 73us/step - loss: 0.9618 - acc: 0.9400\n",
"Epoch 26/30\n",
"100/100 [==============================] - 0s 69us/step - loss: 1.2810 - acc: 0.9200\n",
"Epoch 27/30\n",
"100/100 [==============================] - 0s 70us/step - loss: 0.6553 - acc: 0.9500\n",
"Epoch 28/30\n",
"100/100 [==============================] - 0s 84us/step - loss: 0.4838 - acc: 0.9700\n",
"Epoch 29/30\n",
"100/100 [==============================] - 0s 69us/step - loss: 0.5863 - acc: 0.9500\n",
"Epoch 30/30\n",
"100/100 [==============================] - 0s 74us/step - loss: 0.6430 - acc: 0.9600\n",
"45111/45111 [==============================] - 1s 22us/step\n"
]
}
],
"source": [
"model = Sequential()\n",
"model.add(Dense(51, input_dim=51, activation='relu', name=\"input\"))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(51, activation='relu', name=\"hidden\"))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(1, activation='sigmoid', name=\"output\"))\n",
"\n",
"model.compile(loss='binary_crossentropy',\n",
" optimizer='rmsprop',\n",
" metrics=['accuracy'])\n",
"\n",
"model.fit(x_train, y_train,\n",
" epochs=30,\n",
" batch_size=64)\n",
"score = model.evaluate(x_test, y_test, batch_size=64)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "Error when checking : expected input_input to have shape (51,) but got array with shape (1,)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-22-e3fab4dcbcc1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_classes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/models.py\u001b[0m in \u001b[0;36mpredict_classes\u001b[0;34m(self, x, batch_size, verbose, steps)\u001b[0m\n\u001b[1;32m 1136\u001b[0m \"\"\"\n\u001b[1;32m 1137\u001b[0m proba = self.predict(x, batch_size=batch_size, verbose=verbose,\n\u001b[0;32m-> 1138\u001b[0;31m steps=steps)\n\u001b[0m\u001b[1;32m 1139\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mproba\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1140\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mproba\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/models.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose, steps)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1024\u001b[0m return self.model.predict(x, batch_size=batch_size, verbose=verbose,\n\u001b[0;32m-> 1025\u001b[0;31m steps=steps)\n\u001b[0m\u001b[1;32m 1026\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1027\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_on_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose, steps)\u001b[0m\n\u001b[1;32m 1822\u001b[0m x = _standardize_input_data(x, self._feed_input_names,\n\u001b[1;32m 1823\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_feed_input_shapes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1824\u001b[0;31m check_batch_axis=False)\n\u001b[0m\u001b[1;32m 1825\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstateful\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1826\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36m_standardize_input_data\u001b[0;34m(data, names, shapes, check_batch_axis, exception_prefix)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;34m': expected '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnames\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' to have shape '\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m' but got array with shape '\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 123\u001b[0;31m str(data_shape))\n\u001b[0m\u001b[1;32m 124\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Error when checking : expected input_input to have shape (51,) but got array with shape (1,)"
]
}
],
"source": [
"model.predict_classes(x_train[3])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(51,)\n"
]
}
],
"source": [
"print(x_train[3].shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment