Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mbeyeler/101cd3d2c5d1e718cc43cfe3dde9461f to your computer and use it in GitHub Desktop.
Save mbeyeler/101cd3d2c5d1e718cc43cfe3dde9461f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to wrap an OpenCV estimator for scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"iris = datasets.load_iris(return_X_y=False)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import cv2\n",
"from sklearn.base import BaseEstimator, ClassifierMixin\n",
"# scikit-learn recommends listing mixins to the left of BaseEstimator.\n",
"class MyKnn(ClassifierMixin, BaseEstimator):\n",
"    \"\"\"An OpenCV-based k-nearest neighbor classifier wrapped for scikit-learn.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    k : int, optional, default: 1\n",
"        The number of neighbors to use by default.\n",
"\n",
"    Attributes\n",
"    ----------\n",
"    classes_ : array of shape [n_classes]\n",
"        The sorted distinct class labels seen during `fit`.\n",
"    knn_ : cv2.ml.KNearest\n",
"        The trained OpenCV k-nearest neighbor model.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, k=1):\n",
"        # Only store hyper-parameters here. GridSearchCV changes `k` through\n",
"        # set_params(), which merely reassigns the attribute, so everything\n",
"        # derived from `k` must be built in `fit`. (Configuring the OpenCV\n",
"        # model here made grid-search candidates silently keep the `k` that\n",
"        # was passed to the constructor.)\n",
"        self.k = k\n",
"\n",
"    def fit(self, X, y):\n",
"        \"\"\"Fit the model using X as training data and y as target values\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        X : array-like, shape (n_samples, n_features)\n",
"            Training data.\n",
"        y : array-like, shape (n_samples,)\n",
"            Class labels of the training data.\n",
"\n",
"        Returns\n",
"        -------\n",
"        self : MyKnn\n",
"            The fitted estimator.\n",
"        \"\"\"\n",
"        # OpenCV expects float32 samples and numeric responses, so the labels\n",
"        # are encoded as indices into `classes_` and decoded in `predict`.\n",
"        self.classes_, y_encoded = np.unique(y, return_inverse=True)\n",
"        self.knn_ = cv2.ml.KNearest_create()\n",
"        # int(): GridSearchCV may pass NumPy integers (e.g. from np.arange).\n",
"        self.knn_.setDefaultK(int(self.k))\n",
"        self.knn_.train(np.asarray(X, dtype=np.float32), cv2.ml.ROW_SAMPLE,\n",
"                        y_encoded.astype(np.float32))\n",
"        return self\n",
"\n",
"    def predict(self, X):\n",
"        \"\"\"Predict the class labels for the provided data\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        X : array-like, shape (n_query, n_features)\n",
"            Test samples.\n",
"\n",
"        Returns\n",
"        -------\n",
"        y : array of shape [n_query]\n",
"            Class labels for each data sample.\n",
"        \"\"\"\n",
"        _, y_encoded = self.knn_.predict(np.asarray(X, dtype=np.float32))\n",
"        # OpenCV returns a float32 column vector of encoded labels; flatten it\n",
"        # to 1-D and map the indices back to the original labels.\n",
"        return self.classes_[y_encoded.ravel().astype(np.intp)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How to call the classifier"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"knn = MyKnn(k=3)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"MyKnn(k=3)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"knn.fit(X=iris.data, y=iris.target)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.95999999999999996"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"knn.score(X=iris.data, y=iris.target)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How to use the classifier in cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.96666667, 0.96666667, 0.93333333, 0.96666667, 1. ])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"cross_val_score(estimator=MyKnn(k=3), X=iris.data, y=iris.target, cv=5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How to use the classifier in grid search with cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=None, error_score='raise', estimator=MyKnn(k=1),\n",
" fit_params={}, iid=True, n_jobs=1,\n",
" param_grid={'k': array([1, 2, 3, 4, 5, 6, 7, 8, 9])},\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n",
" scoring=None, verbose=0)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"param_grid = {'k': np.arange(1, 10)}\n",
"grid_search = GridSearchCV(MyKnn(), param_grid=param_grid)\n",
"grid_search.fit(X=iris.data, y=iris.target)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'k': 1}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grid_search.best_estimator_.get_params()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.97333333333333338"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grid_search.cv_results_['mean_test_score'][grid_search.best_index_]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment