Skip to content

Instantly share code, notes, and snippets.

@raghavrv
Last active December 28, 2015 16:11
Show Gist options
  • Save raghavrv/a4de1d848a5ea6caafd4 to your computer and use it in GitHub Desktop.
Save raghavrv/a4de1d848a5ea6caafd4 to your computer and use it in GitHub Desktop.
Decorator function for tracking whether an estimator is fitted
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"name": "",
"signature": "sha256:ee93ac627aa6027de5b43cee374610a87019d885d5c329c87c258c7f9d29ce3a"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin\n",
"from sklearn.svm import LinearSVC\n",
"from sklearn.utils import check_consistent_length\n",
"from sklearn.externals.joblib import Parallel, delayed\n",
"from sklearn.multiclass import _fit_binary, _fit_ovo_binary, OneVsOneClassifier\n",
"from sklearn import datasets\n",
"import numpy as np\n",
"import functools"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def set_is_fitted(fit_method):\n",
" @functools.wraps(fit_method)\n",
" def wrapped_fit_method(self, *args, **kwargs):\n",
" self.is_fitted_ = False\n",
" return_val = fit_method(self, *args, **kwargs)\n",
" # If the fit was successful\n",
" self.is_fitted_ = True\n",
" return return_val\n",
" return wrapped_fit_method\n",
"\n",
"@property\n",
"def _is_fitted(self):\n",
" if hasattr(self, 'is_fitted_'):\n",
" return self.is_fitted_\n",
" return None # May be True or False\n",
"\n",
"BaseEstimator.is_fitted = _is_fitted"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# @set_is_fitted to be used over all the desired fit methods\n",
"\n",
"@set_is_fitted\n",
"def _fit(self, X, y):\n",
" \"\"\"Fit underlying estimators.\n",
"\n",
" Parameters\n",
" ----------\n",
" X : (sparse) array-like, shape = [n_samples, n_features]\n",
" Data.\n",
"\n",
" y : array-like, shape = [n_samples]\n",
" Multi-class targets.\n",
"\n",
" Returns\n",
" -------\n",
" self\n",
" \"\"\"\n",
" y = np.asarray(y)\n",
" check_consistent_length(X, y)\n",
"\n",
" self.classes_ = np.unique(y)\n",
" n_classes = self.classes_.shape[0]\n",
" self.estimators_ = Parallel(n_jobs=self.n_jobs)(\n",
" delayed(_fit_ovo_binary)(\n",
" self.estimator, X, y, self.classes_[i], self.classes_[j])\n",
" for i in range(n_classes) for j in range(i + 1, n_classes))\n",
"\n",
" return self\n",
"\n",
"OneVsOneClassifier.fit = _fit"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"iris = datasets.load_iris()\n",
"rng = np.random.RandomState(0)\n",
"perm = rng.permutation(iris.target.size)\n",
"iris.data = iris.data[perm]\n",
"iris.target = iris.target[perm]\n",
"n_classes = 3\n",
"\n",
"ovo = OneVsOneClassifier(LinearSVC())\n",
"ovo.fit(iris.data, iris.target)\n",
"ovo.is_fitted"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 4,
"text": [
"True"
]
}
],
"prompt_number": 4
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment