Skip to content

Instantly share code, notes, and snippets.

@NicolasHug
Created September 18, 2020 20:03
Show Gist options
  • Save NicolasHug/2db607b01482988fa549eb2c8770f79f to your computer and use it in GitHub Desktop.
Save NicolasHug/2db607b01482988fa549eb2c8770f79f to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.utils.validation import check_random_state\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.utils import shuffle"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Changes to Estimators and `clone()`:\n",
"\n",
"- A random seed is drawn in `__init__()`. `set_params()` is updated accordingly\n",
"- `clone()` can now explicitly support strict clones and statistical clones.\n",
"- `fit()` and `get_params()` are unchanged"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def _sample_seed(random_state):\n",
"    \"\"\"Return the integer seed to store as an estimator's `random_state`.\n",
"\n",
"    Integer seeds (Python ints and numpy integers) are passed through as a\n",
"    plain Python int. Anything else (None, a RandomState instance) is given to\n",
"    `check_random_state` and used to draw a fresh seed in [0, 2**32).\n",
"    \"\"\"\n",
"    if isinstance(random_state, (int, np.integer)):\n",
"        # Normalize numpy integers so that a get_params()/set_params() round-trip\n",
"        # sees a plain int and passes it through instead of re-drawing a seed.\n",
"        return int(random_state)\n",
"    # An explicit 64-bit dtype is required: with the platform default integer\n",
"    # (int32 on Windows with numpy < 2), high=2**32 is out of bounds and raises.\n",
"    return int(check_random_state(random_state).randint(0, 2**32, dtype=np.int64))\n",
"\n",
"\n",
"class Estimator:\n",
"    \"\"\"Toy estimator whose only parameter is `random_state`.\n",
"\n",
"    `random_state` is resolved to an integer seed at construction time (and in\n",
"    `set_params`), so every call to `fit` on the same instance uses the same RNG.\n",
"    \"\"\"\n",
"    def __init__(self, random_state=None):\n",
"        self.random_state = _sample_seed(random_state)\n",
"\n",
"    def fit(self, X=None, y=None):\n",
"        # unchanged: an RNG is rebuilt from the stored seed on every call\n",
"        rng = check_random_state(self.random_state)\n",
"        print(rng.randint(0, 100))\n",
"        return self\n",
"\n",
"    def get_params(self, deep=True):\n",
"        # unchanged; `deep` is accepted for API compatibility with scikit-learn\n",
"        # and ignored, since this estimator has no nested sub-estimators\n",
"        return {'random_state': self.random_state}\n",
"\n",
"    def set_params(self, random_state=None):\n",
"        # Same seed resolution as __init__: ints are kept, None/RandomState draw a new seed.\n",
"        self.random_state = _sample_seed(random_state)\n",
"        # Return self, as scikit-learn's set_params does, to allow chaining.\n",
"        return self\n",
"\n",
"    def score(self, X, y):\n",
"        return 0  # irrelevant\n",
"\n",
"\n",
"def _check_statistical_clone_possible(est):\n",
"    \"\"\"Raise ValueError if `est` does not expose a `random_state` parameter.\n",
"\n",
"    Statistical clones differ from the original only by their random_state,\n",
"    so they cannot be made for an estimator that has none.\n",
"    \"\"\"\n",
"    has_random_state = 'random_state' in est.get_params()\n",
"    if not has_random_state:\n",
"        raise ValueError(\"This estimator isn't random and can only have exact clones\")\n",
"\n",
"\n",
"def clone(est, statistical=False):\n",
"    \"\"\"Return a strict clone or a statistical clone of `est`.\n",
"\n",
"    statistical:\n",
"    - False (default): strict clone, built from identical parameters,\n",
"      including the same random_state.\n",
"    - True: statistical clone whose new random_state is drawn from an RNG\n",
"      seeded with `est`'s random_state.\n",
"    - None, int, or RandomState instance: statistical clone whose new\n",
"      random_state is drawn from check_random_state(statistical). A shared\n",
"      RandomState is advanced by each call, so successive clones draw\n",
"      successive seeds from it.\n",
"    \"\"\"\n",
"    params = est.get_params()\n",
"\n",
"    if statistical is False:\n",
"        return est.__class__(**params)\n",
"\n",
"    _check_statistical_clone_possible(est)\n",
"    if statistical is True:\n",
"        seed_source = params['random_state']\n",
"    else:\n",
"        seed_source = statistical\n",
"    # Going through check_random_state forces a new seed to be *drawn*, even\n",
"    # when seed_source is an int (which _sample_seed alone would pass through).\n",
"    params['random_state'] = _sample_seed(check_random_state(seed_source))\n",
"    return est.__class__(**params)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Illustration of estimator behavior"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"25\n",
"25\n",
"95\n",
"95\n"
]
}
],
"source": [
"# Multiple calls to fit on the same instance produce the same rng\n",
"# Also, fit is truely idempotent\n",
"\n",
"a = Estimator(random_state=None).fit().fit()\n",
"b = Estimator(random_state=None).fit().fit()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"exact clones:\n",
"25\n",
"95\n",
"statistical clones (different RNGs):\n",
"24\n",
"30\n",
"statistical clones with random_state=int: Different RNG can still be obtained\n",
"44\n",
"63\n"
]
},
{
"data": {
"text/plain": [
"<__main__.Estimator at 0x7f3d17b9fd90>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Users can explicitly create exact and statistical clones\n",
"# Exact clones can be obtained even if None/instances are passed (this is impossible in master)\n",
"\n",
"print(\"exact clones:\")\n",
"clone(a).fit()\n",
"clone(b).fit()\n",
"\n",
"print(\"statistical clones (different RNGs):\")\n",
"clone(a, statistical=True).fit()\n",
"clone(b, statistical=True).fit()\n",
"\n",
"# Also, statistical clones can be obtained even if ints are passed.\n",
"# In master, None/instances can only give statistical clones, and ints can only give exact clones\n",
"\n",
"print(\"statistical clones with random_state=int: Different RNG can still be obtained\")\n",
"with_int = Estimator(random_state=0)\n",
"with_int.fit()\n",
"clone(with_int, statistical=True).fit()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a's RNG is unchanged\n",
"25\n",
"set a's RNG to that of b\n",
"95\n"
]
},
{
"data": {
"text/plain": [
"<__main__.Estimator at 0x7f3d4821ed60>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Using set_params and get_params allows to get the exact same RNG as another esitmator\n",
"\n",
"print(\"a's RNG is unchanged\")\n",
"a.set_params(random_state=a.get_params()['random_state'])\n",
"a.fit()\n",
"\n",
"print(\"set a's RNG to that of b\")\n",
"a.set_params(random_state=b.get_params()['random_state'])\n",
"a.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CV routines: users now have explicit control over the CV strategy"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Example of what CV routines would look like.\n",
"# The behaviour of the CV procedure is now explicit, and doesn't depend on the estimator's random_state.\n",
"# Use-case C works with any estimator; use-case D requires the estimator to have a random_state.\n",
"\n",
"def cross_val_score(est, X, y, cv, use_exact_clones=True):\n",
"    \"\"\"Return the list of per-fold scores of clones of `est` over `cv.split(X, y)`.\n",
"\n",
"    use_exact_clones:\n",
"    - True: every fold is fit with an exact clone, i.e. the same estimator RNG (use-case C)\n",
"    - False: every fold is fit with a statistical clone, i.e. a different estimator RNG\n",
"      per fold (use-case D); raises ValueError if `est` has no random_state\n",
"    TODO: maybe the default should be 'auto': False if the estimator has a random_state, True otherwise\n",
"    \"\"\"\n",
"    if not use_exact_clones:\n",
"        # One local RNG, seeded from the estimator, is shared by all folds so that\n",
"        # each statistical clone draws its own random_state from it.\n",
"        _check_statistical_clone_possible(est)\n",
"        statistical = np.random.RandomState(est.random_state)\n",
"    else:\n",
"        statistical = False\n",
"\n",
"    scores = []\n",
"    for train, test in cv.split(X, y):\n",
"        # Apart from the clone() call, this loop matches the existing implementation.\n",
"        fitted = clone(est, statistical=statistical).fit(X[train], y[train])\n",
"        scores.append(fitted.score(X[test], y[test]))\n",
"    return scores\n",
"\n",
"X = y = np.arange(10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Constant estimator RNG across folds, different estimator RNG across executions\n",
"19\n",
"19\n",
"19\n",
"19\n",
"19\n"
]
}
],
"source": [
"# Exact clones: every fold reuses the estimator's seed; random_state=None draws a new seed on each run.\n",
"print(\"Constant estimator RNG across folds, different estimator RNG across executions\")\n",
"_ = cross_val_score(Estimator(random_state=None), X, y, cv=KFold(), use_exact_clones=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Different estimator RNG across folds, different estimator RNG across executions\n",
"76\n",
"5\n",
"1\n",
"49\n",
"72\n"
]
}
],
"source": [
"print(\"Different estimator RNG across folds, different estimator RNG across executions\")\n",
"_ = cross_val_score(Estimator(random_state=None), X, y, cv=KFold(), use_exact_clones=False)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Constant estimator RNG across folds, constant estimator RNG across executions\n",
"44\n",
"44\n",
"44\n",
"44\n",
"44\n"
]
}
],
"source": [
"print(\"Constant estimator RNG across folds, constant estimator RNG across executions\")\n",
"_ = cross_val_score(Estimator(random_state=0), X, y, cv=KFold(), use_exact_clones=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Different estimator RNG across folds, constant estimator RNG across executions\n",
"63\n",
"82\n",
"89\n",
"93\n",
"34\n"
]
}
],
"source": [
"print(\"Different estimator RNG across folds, constant estimator RNG across executions\")\n",
"_ = cross_val_score(Estimator(random_state=0), X, y, cv=KFold(), use_exact_clones=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Changes to CV Splitters\n",
"\n",
"Similar changes as for estimators: a seed is drawn in `__init__`\n",
"\n",
"`split` is unchanged"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class TwoKFold:\n",
"    \"\"\"Toy CV class that does shuffled 2-fold CV.\n",
"\n",
"    Like the estimators, the seed is resolved once in `__init__`, so every\n",
"    call to `split` on the same instance yields the same two folds.\n",
"    \"\"\"\n",
"    def __init__(self, random_state=None):\n",
"        self.random_state = _sample_seed(random_state)\n",
"\n",
"    def get_n_splits(self, X=None, y=None, groups=None):\n",
"        \"\"\"Return the number of folds, which is always 2.\"\"\"\n",
"        return 2\n",
"\n",
"    def split(self, X, y=None, groups=None):\n",
"        \"\"\"Yield (train, test) index arrays for the two halves of a shuffled range.\n",
"\n",
"        `groups` is accepted for splitter-API compatibility and ignored.\n",
"        With an odd number of samples, the second half gets the extra sample.\n",
"        \"\"\"\n",
"        # An RNG is rebuilt from the stored seed on every call, so splits are repeatable.\n",
"        rng = check_random_state(self.random_state)\n",
"        n_samples = X.shape[0]\n",
"        indices = shuffle(np.arange(n_samples), random_state=rng)\n",
"        mid = n_samples // 2\n",
"\n",
"        yield indices[:mid], indices[mid:]\n",
"        yield indices[mid:], indices[:mid]\n",
"\n",
"X = np.arange(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Illustration of CV splitter behavior"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Multiple calls to split yield the same splits:\n",
"[(array([2, 1, 0, 8, 9]), array([4, 5, 3, 7, 6])), (array([4, 5, 3, 7, 6]), array([2, 1, 0, 8, 9]))]\n",
"[(array([2, 1, 0, 8, 9]), array([4, 5, 3, 7, 6])), (array([4, 5, 3, 7, 6]), array([2, 1, 0, 8, 9]))]\n"
]
}
],
"source": [
"print(\"Multiple calls to split yield the same splits:\")\n",
"cv = TwoKFold(random_state=None)\n",
"print(list(cv.split(X)))\n",
"print(list(cv.split(X)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment