Last active
December 28, 2015 16:11
-
-
Save raghavrv/a4de1d848a5ea6caafd4 to your computer and use it in GitHub Desktop.
Decorator function that records whether an estimator has been fitted
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"metadata": { | |
"name": "", | |
"signature": "sha256:ee93ac627aa6027de5b43cee374610a87019d885d5c329c87c258c7f9d29ce3a" | |
}, | |
"nbformat": 3, | |
"nbformat_minor": 0, | |
"worksheets": [ | |
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"from sklearn.base import BaseEstimator, ClassifierMixin, MetaEstimatorMixin\n", | |
"from sklearn.svm import LinearSVC\n", | |
"from sklearn.utils import check_consistent_length\n", | |
"from sklearn.externals.joblib import Parallel, delayed\n", | |
"from sklearn.multiclass import _fit_binary, _fit_ovo_binary, OneVsOneClassifier\n", | |
"from sklearn import datasets\n", | |
"import numpy as np\n", | |
"import functools" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 1 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"def set_is_fitted(fit_method):\n", | |
" @functools.wraps(fit_method)\n", | |
" def wrapped_fit_method(self, *args, **kwargs):\n", | |
" self.is_fitted_ = False\n", | |
" return_val = fit_method(self, *args, **kwargs)\n", | |
" # If the fit was successful\n", | |
" self.is_fitted_ = True\n", | |
" return return_val\n", | |
" return wrapped_fit_method\n", | |
"\n", | |
"@property\n", | |
"def _is_fitted(self):\n", | |
" if hasattr(self, 'is_fitted_'):\n", | |
" return self.is_fitted_\n", | |
" return None # May be True or False\n", | |
"\n", | |
"BaseEstimator.is_fitted = _is_fitted" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 2 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"# @set_is_fitted to be used over all the desired fit methods\n", | |
"\n", | |
"@set_is_fitted\n", | |
"def _fit(self, X, y):\n", | |
" \"\"\"Fit underlying estimators.\n", | |
"\n", | |
" Parameters\n", | |
" ----------\n", | |
" X : (sparse) array-like, shape = [n_samples, n_features]\n", | |
" Data.\n", | |
"\n", | |
" y : array-like, shape = [n_samples]\n", | |
" Multi-class targets.\n", | |
"\n", | |
" Returns\n", | |
" -------\n", | |
" self\n", | |
" \"\"\"\n", | |
" y = np.asarray(y)\n", | |
" check_consistent_length(X, y)\n", | |
"\n", | |
" self.classes_ = np.unique(y)\n", | |
" n_classes = self.classes_.shape[0]\n", | |
" self.estimators_ = Parallel(n_jobs=self.n_jobs)(\n", | |
" delayed(_fit_ovo_binary)(\n", | |
" self.estimator, X, y, self.classes_[i], self.classes_[j])\n", | |
" for i in range(n_classes) for j in range(i + 1, n_classes))\n", | |
"\n", | |
" return self\n", | |
"\n", | |
"OneVsOneClassifier.fit = _fit" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [], | |
"prompt_number": 3 | |
}, | |
{ | |
"cell_type": "code", | |
"collapsed": false, | |
"input": [ | |
"iris = datasets.load_iris()\n", | |
"rng = np.random.RandomState(0)\n", | |
"perm = rng.permutation(iris.target.size)\n", | |
"iris.data = iris.data[perm]\n", | |
"iris.target = iris.target[perm]\n", | |
"n_classes = 3\n", | |
"\n", | |
"ovo = OneVsOneClassifier(LinearSVC())\n", | |
"ovo.fit(iris.data, iris.target)\n", | |
"ovo.is_fitted" | |
], | |
"language": "python", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"metadata": {}, | |
"output_type": "pyout", | |
"prompt_number": 4, | |
"text": [ | |
"True" | |
] | |
} | |
], | |
"prompt_number": 4 | |
} | |
], | |
"metadata": {} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment