Skip to content

Instantly share code, notes, and snippets.

@NicolasHug
Last active November 27, 2019 12:48
Show Gist options
  • Save NicolasHug/1169ee253a4669ff993c947507ae2cb5 to your computer and use it in GitHub Desktop.
Save NicolasHug/1169ee253a4669ff993c947507ae2cb5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.utils.validation import check_random_state\n",
"from sklearn.base import clone"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class Estimator():\n",
" def __init__(self, random_state=None):\n",
" self.random_state = _get_state(random_state)\n",
" \n",
" def fit(self):\n",
" rng = np.random.RandomState()\n",
" rng.set_state(self.random_state)\n",
" \n",
" print(rng.randint(0, 100))\n",
" \n",
" def get_params(self):\n",
" # return a RandomState instance whose state is\n",
" # self.random_state\n",
" random_state = np.random.RandomState()\n",
" random_state.set_state(self.random_state)\n",
" return {'random_state': random_state}\n",
" \n",
" def set_params(self, random_state=None):\n",
" # same as in __init__\n",
" self.random_state = _get_state(random_state)\n",
" \n",
"def clone(est): # just like base.clone\n",
" params = est.get_params()\n",
" return Estimator(**params)\n",
"\n",
"def _get_state(random_state):\n",
" rng = check_random_state(random_state)\n",
" \n",
" if random_state is None:\n",
" # rng is numpy's singleton. We don't want to use its state directly,\n",
" # because the singleton may not always be consumed between\n",
" # estimator instanciations.\n",
" BIG_INT = 1000000\n",
" rng = np.random.RandomState(rng.randint(0, BIG_INT))\n",
" \n",
" return rng.get_state()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27\n",
"27\n",
"8\n",
"8\n"
]
}
],
"source": [
"# Multiple calls to fit/split on the same instances use the same rng\n",
"a = Estimator(random_state=None)\n",
"b = Estimator(random_state=None)\n",
"\n",
"a.fit()\n",
"a.fit()\n",
"\n",
"b.fit()\n",
"b.fit()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27\n",
"27\n",
"8\n",
"8\n"
]
}
],
"source": [
"# Clone has a natural behaviour\n",
"# c and d have the same rng as a and b, respectively\n",
"c = clone(a)\n",
"d = clone(b)\n",
"\n",
"c.fit()\n",
"c.fit()\n",
"\n",
"d.fit()\n",
"d.fit()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27\n",
"8\n"
]
}
],
"source": [
"# Does not change a's rng\n",
"a.set_params(random_state=a.get_params()['random_state'])\n",
"a.fit()\n",
"\n",
"# Now a and b have the same rng state\n",
"a.set_params(random_state=b.get_params()['random_state'])\n",
"a.fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment