Last active
November 27, 2019 12:48
-
-
Save NicolasHug/1169ee253a4669ff993c947507ae2cb5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from sklearn.utils.validation import check_random_state\n", | |
"from sklearn.base import clone" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class Estimator:
    """Toy estimator that stores a snapshot of an RNG state instead of a seed.

    The ``random_state`` argument is converted once, at construction, into
    the tuple returned by ``np.random.RandomState.get_state()``. Every
    ``fit`` call rebuilds a generator from that same snapshot, so repeated
    fits of one instance draw identical numbers.

    Parameters
    ----------
    random_state : None, int or np.random.RandomState, default=None
        Passed through ``_get_state``; see that function for how ``None``
        is handled.
    """

    def __init__(self, random_state=None):
        # Stored as a state tuple, not as the object that was passed in.
        self.random_state = _get_state(random_state)

    def _make_rng(self):
        """Return a new RandomState positioned at the stored state."""
        rng = np.random.RandomState()
        # Overwrite the freshly-seeded generator with the stored snapshot.
        rng.set_state(self.random_state)
        return rng

    def fit(self):
        """Print one integer in [0, 100) drawn from a copy of the stored state.

        The stored state itself is never advanced, so every call prints the
        same number for a given instance.
        """
        print(self._make_rng().randint(0, 100))

    def get_params(self, deep=True):
        """Return the constructor parameters.

        ``random_state`` is returned as a fresh RandomState instance whose
        state equals the stored one. Callers can pass it back to
        ``__init__`` or ``set_params`` without altering this instance.

        Parameters
        ----------
        deep : bool, default=True
            Accepted for signature compatibility with scikit-learn's
            ``get_params``. This estimator has no sub-estimators, so the
            value is ignored.

        Returns
        -------
        dict
            ``{'random_state': RandomState}``.
        """
        return {'random_state': self._make_rng()}

    def set_params(self, random_state=None):
        """Replace the stored state, using the same conversion as ``__init__``.

        Calling this with no argument draws a brand-new state, exactly as
        ``random_state=None`` does in ``__init__``.

        Returns
        -------
        self : Estimator
            Returned to allow chaining, following scikit-learn's convention.
        """
        self.random_state = _get_state(random_state)
        return self
" \n", | |
def clone(est):  # simplified stand-in; shadows sklearn.base.clone imported above
    """Return a new, unfitted estimator built from ``est.get_params()``.

    Like ``sklearn.base.clone``, which rebuilds from the estimator's own
    class, the result is an instance of ``type(est)``. Subclasses, and any
    class with a compatible ``get_params``/``__init__`` pair, are therefore
    cloned as themselves.

    Parameters
    ----------
    est : object
        Estimator whose ``get_params()`` returns keyword arguments accepted
        by its class's ``__init__``.

    Returns
    -------
    object
        A new instance of ``type(est)``.
    """
    params = est.get_params()
    return type(est)(**params)
"\n", | |
def _get_state(random_state):
    """Convert a ``random_state`` argument into a RandomState state tuple.

    Parameters
    ----------
    random_state : None, int, np.random.RandomState or the np.random module
        Any value accepted by ``sklearn.utils.check_random_state``.

    Returns
    -------
    tuple
        The output of ``np.random.RandomState.get_state()``. For an int or a
        RandomState instance, this is a snapshot of the corresponding
        generator; taking the snapshot does not advance it.
    """
    rng = check_random_state(random_state)

    if random_state is None or random_state is np.random:
        # check_random_state returned numpy's global singleton (it does so
        # for both None and the np.random module). We don't want to use its
        # state directly, because the singleton may not be consumed between
        # two estimator instantiations, which would give them identical
        # states. Instead, draw a fresh seed from it; this advances the
        # global RNG.
        # Seeds span the full non-negative int32 range rather than a small
        # constant such as 10**6. This keeps the chance that independently
        # created estimators share a seed negligible.
        seed = rng.randint(0, np.iinfo(np.int32).max)
        rng = np.random.RandomState(seed)

    return rng.get_state()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"27\n", | |
"27\n", | |
"8\n", | |
"8\n" | |
] | |
} | |
], | |
"source": [ | |
# Multiple calls to fit (or split) on the same instance use the same rng:
# fit() rebuilds a generator from the stored state on every call and never
# advances that state, so each instance prints the same number twice.
# a and b each draw their own seed from the global numpy RNG at
# construction, so they will usually (not always) print different numbers.
a = Estimator(random_state=None)
b = Estimator(random_state=None)

a.fit()
a.fit()

b.fit()
b.fit()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"27\n", | |
"27\n", | |
"8\n", | |
"8\n" | |
] | |
} | |
], | |
"source": [ | |
# Clone has a natural behaviour:
# c and d have the same rng as a and b, respectively.
# clone() passes get_params()['random_state'] (a RandomState set to the
# source's stored state) to Estimator(), which snapshots that state again.
# Each clone therefore prints the same numbers as its source.
c = clone(a)
d = clone(b)

c.fit()
c.fit()

d.fit()
d.fit()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"27\n", | |
"8\n" | |
] | |
} | |
], | |
"source": [ | |
# Does not change a's rng: get_params() hands back a generator positioned at
# a's own stored state, and set_params() snapshots that same state again.
a.set_params(random_state=a.get_params()['random_state'])
a.fit()

# Now a and b have the same rng state: b's stored state is copied into a,
# so a prints the number b printed above.
a.set_params(random_state=b.get_params()['random_state'])
a.fit()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment