Created
January 2, 2019 18:20
-
-
Save johnfink8/1b8554d138b78a31c1d48110c393403b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"from sklearn.model_selection import GridSearchCV,StratifiedKFold\n", | |
"import numpy as np\n", | |
"from keras.models import *\n", | |
"from keras.layers import *\n", | |
"from keras.callbacks import *\n", | |
"from keras.metrics import *\n", | |
"from keras.datasets import cifar10\n", | |
"import matplotlib.pyplot as plt\n", | |
"from keras.preprocessing.image import ImageDataGenerator\n", | |
"import cv2\n", | |
"from keras.wrappers.scikit_learn import KerasClassifier" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# load_data() returns two (images, labels) pairs: the train split, then the test split.\n", | |
"[x_train, y_train], [x_test, y_test] = cifar10.load_data()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(50000, 32, 32, 3)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Images and labels are indexed together by the cross-validation splits below,\n", | |
"# so their lengths must agree.\n", | |
"assert len(x_train) == len(y_train), (len(x_train), len(y_train))\n", | |
"x_train.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Distinct class labels in the training targets (flattened explicitly;\n", | |
"# np.unique would flatten its input anyway).\n", | |
"np.unique(np.ravel(y_train))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def simple_discrim(kernel_size, input_shape=None, num_classes=10):\n", | |
"    \"\"\"Build and compile a small strided CNN classifier.\n", | |
"\n", | |
"    Architecture: Dropout(0.1) on the input, then three Conv2D layers\n", | |
"    (64, 128 and 256 filters; stride 2, 'same' padding, ReLU), the first two\n", | |
"    each followed by BatchNormalization, then Flatten and a softmax Dense.\n", | |
"\n", | |
"    Args:\n", | |
"        kernel_size: Conv2D kernel size; the hyperparameter being grid-searched.\n", | |
"        input_shape: shape of a single input image. Defaults to\n", | |
"            ``x_train.shape[1:]``, read from the notebook global at call time;\n", | |
"            pass it explicitly to build the model without that global.\n", | |
"        num_classes: number of softmax outputs (the CIFAR-10 labels are 0-9).\n", | |
"\n", | |
"    Returns:\n", | |
"        A keras ``Model`` compiled with Adam, sparse categorical\n", | |
"        cross-entropy (integer labels) and an accuracy metric.\n", | |
"    \"\"\"\n", | |
"    if input_shape is None:\n", | |
"        input_shape = x_train.shape[1:]\n", | |
"    inp = Input(shape=input_shape)\n", | |
"    x = Dropout(0.1)(inp)\n", | |
"    base_filters = 64\n", | |
"    multipliers = (1, 2, 4)\n", | |
"    for i in multipliers:\n", | |
"        x = Conv2D(base_filters * i, kernel_size, strides=2, padding='same',\n", | |
"                   activation='relu')(x)\n", | |
"        # BatchNormalization after every conv except the last one.\n", | |
"        if i < multipliers[-1]:\n", | |
"            x = BatchNormalization()(x)\n", | |
"    x = Flatten()(x)\n", | |
"    outp = Dense(num_classes, activation='softmax')(x)\n", | |
"    model = Model(inputs=inp, outputs=outp)\n", | |
"    model.compile('adam', 'sparse_categorical_crossentropy', metrics=['acc'])\n", | |
"    return model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"C:\\Conda\\Anaconda3\\envs\\tf12\\lib\\site-packages\\sklearn\\model_selection\\_split.py:2053: FutureWarning: You should specify a value for 'cv' instead of relying on the default value. The default value will change from 3 to 5 in version 0.22.\n", | |
" warnings.warn(CV_WARNING, FutureWarning)\n", | |
"[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Fitting 3 folds for each of 4 candidates, totalling 12 fits\n", | |
"[CV] kernel_size=1 ...................................................\n", | |
"16667/16667 [==============================] - 1s 49us/step\n", | |
"33333/33333 [==============================] - 2s 46us/step\n", | |
"[CV] ......... kernel_size=1, score=0.41597168056996486, total= 24.4s\n", | |
"[CV] kernel_size=1 ...................................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 25.9s remaining: 0.0s\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"16667/16667 [==============================] - 1s 57us/step\n", | |
"33333/33333 [==============================] - 1s 44us/step\n", | |
"[CV] ......... kernel_size=1, score=0.41465170694797965, total= 23.7s\n", | |
"[CV] kernel_size=1 ...................................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 51.1s remaining: 0.0s\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"16666/16666 [==============================] - 1s 53us/step\n", | |
"33334/33334 [==============================] - 2s 48us/step \n", | |
"[CV] .......... kernel_size=1, score=0.4193567742852765, total= 23.6s\n", | |
"[CV] kernel_size=3 ...................................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.3min remaining: 0.0s\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"16667/16667 [==============================] - 1s 64us/step\n", | |
"33333/33333 [==============================] - 2s 57us/step\n", | |
"[CV] .......... kernel_size=3, score=0.6086278274613321, total= 28.4s\n", | |
"[CV] kernel_size=3 ...................................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=1)]: Done 4 out of 4 | elapsed: 1.8min remaining: 0.0s\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"16667/16667 [==============================] - 1s 63us/step\n", | |
"33333/33333 [==============================] - 2s 55us/step\n", | |
"[CV] .......... kernel_size=3, score=0.5650686986331799, total= 26.9s\n", | |
"[CV] kernel_size=3 ...................................................\n", | |
"16666/16666 [==============================] - 1s 66us/step\n", | |
"33334/33334 [==============================] - 2s 56us/step\n", | |
"[CV] .......... kernel_size=3, score=0.5270610824468741, total= 27.4s\n", | |
"[CV] kernel_size=5 ...................................................\n", | |
"16667/16667 [==============================] - 1s 82us/step\n", | |
"33333/33333 [==============================] - 2s 72us/step\n", | |
"[CV] .......... kernel_size=5, score=0.6045479090471835, total= 36.1s\n", | |
"[CV] kernel_size=5 ...................................................\n", | |
"16667/16667 [==============================] - 1s 86us/step\n", | |
"33333/33333 [==============================] - 2s 72us/step\n", | |
"[CV] ........... kernel_size=5, score=0.599988000236419, total= 36.4s\n", | |
"[CV] kernel_size=5 ...................................................\n", | |
"16666/16666 [==============================] - 1s 86us/step\n", | |
"33334/33334 [==============================] - 2s 74us/step\n", | |
"[CV] .......... kernel_size=5, score=0.5964238569828896, total= 36.7s\n", | |
"[CV] kernel_size=7 ...................................................\n", | |
"16667/16667 [==============================] - 2s 110us/step\n", | |
"33333/33333 [==============================] - 3s 95us/step\n", | |
"[CV] .......... kernel_size=7, score=0.5818083638649092, total= 51.7s\n", | |
"[CV] kernel_size=7 ...................................................\n", | |
"16667/16667 [==============================] - 2s 111us/step\n", | |
"33333/33333 [==============================] - 3s 96us/step\n", | |
"[CV] .......... kernel_size=7, score=0.6043679126632044, total= 51.9s\n", | |
"[CV] kernel_size=7 ...................................................\n", | |
"16666/16666 [==============================] - 2s 113us/step\n", | |
"33334/33334 [==============================] - 3s 97us/step\n", | |
"[CV] .......... kernel_size=7, score=0.6012240489977227, total= 52.1s\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=1)]: Done 12 out of 12 | elapsed: 7.5min finished\n" | |
] | |
} | |
], | |
"source": [ | |
"# Pass the CV splitter explicitly. Leaving cv unset raised sklearn's FutureWarning\n", | |
"# (default changing from 3 to 5 folds), and the fold sizes it produced\n", | |
"# (16667/16667/16666) are consistent with plain, unstratified KFold.\n", | |
"# StratifiedKFold preserves each class's proportion in every fold.\n", | |
"CV_FOLDS = 3\n", | |
"EPOCHS = 5\n", | |
"\n", | |
"model = KerasClassifier(build_fn=simple_discrim)\n", | |
"param_grid = {\n", | |
"    'kernel_size': [1, 3, 5, 7],\n", | |
"}\n", | |
"cv_splitter = StratifiedKFold(n_splits=CV_FOLDS)\n", | |
"grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=cv_splitter,\n", | |
"                    n_jobs=1, verbose=5)\n", | |
"grid_result = grid.fit(x_train, y_train, epochs=EPOCHS, verbose=0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def report(gridsearch, n_top=3):\n", | |
"    \"\"\"Print mean/std CV score and parameters for the top-ranked candidates.\n", | |
"\n", | |
"    Walks ranks 1..n_top of ``gridsearch.cv_results_['rank_test_score']``;\n", | |
"    tied candidates share a rank and are all printed, in index order.\n", | |
"    \"\"\"\n", | |
"    cv_res = gridsearch.cv_results_\n", | |
"    for rank in range(1, n_top + 1):\n", | |
"        for idx in np.flatnonzero(cv_res['rank_test_score'] == rank):\n", | |
"            print(\"Model with rank: {0}\".format(rank))\n", | |
"            mean_score = cv_res['mean_test_score'][idx]\n", | |
"            std_score = cv_res['std_test_score'][idx]\n", | |
"            print(\"Mean validation score: {0:.3f} (std: {1:.3f})\".format(\n", | |
"                mean_score, std_score))\n", | |
"            print(\"Parameters: {0}\".format(cv_res['params'][idx]))\n", | |
"            print(\"\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model with rank: 1\n", | |
"Mean validation score: 0.600 (std: 0.003)\n", | |
"Parameters: {'kernel_size': 5}\n", | |
"\n", | |
"Model with rank: 2\n", | |
"Mean validation score: 0.596 (std: 0.010)\n", | |
"Parameters: {'kernel_size': 7}\n", | |
"\n", | |
"Model with rank: 3\n", | |
"Mean validation score: 0.567 (std: 0.033)\n", | |
"Parameters: {'kernel_size': 3}\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# Show the three top-ranked kernel_size candidates from the fitted search.\n", | |
"report(grid, n_top=3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment