Created
November 7, 2017 18:01
-
-
Save mbeyeler/101cd3d2c5d1e718cc43cfe3dde9461f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# How to wrap an OpenCV estimator for scikit-learn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn import datasets\n", | |
"iris = datasets.load_iris()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import cv2\n", | |
"from sklearn.base import BaseEstimator, ClassifierMixin\n", | |
"class MyKnn(BaseEstimator, ClassifierMixin):\n", | |
" def __init__(self, k=1):\n", | |
" \"\"\"An OpenCV-based k-nearest neighbor classifier wrapped for scikit-learn\n", | |
" \n", | |
" Parameters\n", | |
" ----------\n", | |
" k : int, optional, default: 1\n", | |
" The number of neighbors to use by default.\n", | |
" \"\"\"\n", | |
" self.k = k\n", | |
" self.knn = cv2.ml.KNearest_create()\n", | |
" self.knn.setDefaultK(k)\n", | |
" \n", | |
" def get_params(self, deep=True):\n", | |
" \"\"\"Get parameters for this estimator\"\"\"\n", | |
" return {'k': self.k}\n", | |
" \n", | |
" def set_params(self, **params):\n", | |
" \"\"\"Set parameters for this estimator\"\"\"\n", | |
" for param, value in params.items():\n", | |
" setattr(self, param, value)\n", | |
" return self\n", | |
" \n", | |
" def predict(self, X):\n", | |
" \"\"\"Predict the class labels for the provided data\n", | |
" \n", | |
" Parameters\n", | |
" ----------\n", | |
" X : array-like, shape (n_query, n_features)\n", | |
" Test samples.\n", | |
"\n", | |
" Returns\n", | |
" -------\n", | |
" y : array of shape [n_samples] or [n_samples, n_outputs]\n", | |
" Class labels for each data sample.\n", | |
" \"\"\"\n", | |
" ret, y_pred = self.knn.predict(X.astype(np.float32))\n", | |
" return y_pred\n", | |
"\n", | |
" def fit(self, X, y):\n", | |
" \"\"\"Fit the model using X as training data and y as target values\n", | |
"\n", | |
" Parameters\n", | |
" ----------\n", | |
" X : array of shape [n_samples, n_features]\n", | |
" Training data.\n", | |
" y : array of shape [n_samples]\n", | |
" \"\"\"\n", | |
" self.knn.train(X.astype(np.float32), cv2.ml.ROW_SAMPLE, y)\n", | |
" return self" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## How to call the classifier" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"knn = MyKnn(3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"MyKnn(k=3)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"knn.fit(iris.data, iris.target)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.95999999999999996" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"knn.score(iris.data, iris.target)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## How to use the classifier in cross-validation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0.96666667, 0.96666667, 0.93333333, 0.96666667, 1. ])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from sklearn.model_selection import cross_val_score\n", | |
"cross_val_score(MyKnn(3), iris.data, iris.target, cv=5)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## How to use the classifier in grid search with cross-validation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"GridSearchCV(cv=None, error_score='raise', estimator=MyKnn(k=1),\n", | |
" fit_params={}, iid=True, n_jobs=1,\n", | |
" param_grid={'k': array([1, 2, 3, 4, 5, 6, 7, 8, 9])},\n", | |
" pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n", | |
" scoring=None, verbose=0)" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from sklearn.model_selection import GridSearchCV\n", | |
"grid_search = GridSearchCV(MyKnn(), {'k': np.arange(1, 10)})\n", | |
"grid_search.fit(iris.data, iris.target)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'k': 1}" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"grid_search.best_params_" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.97333333333333338" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"grid_search.best_score_" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment