@nyk510
Created January 4, 2021 07:21
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: scikit-learn==0.24.0 in /home/penguin/.conda/lib/python3.7/site-packages (0.24.0)\n",
"Requirement already satisfied: joblib>=0.11 in /home/penguin/.conda/lib/python3.7/site-packages (from scikit-learn==0.24.0) (0.16.0)\n",
"Requirement already satisfied: numpy>=1.13.3 in /home/penguin/.conda/lib/python3.7/site-packages (from scikit-learn==0.24.0) (1.18.5)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /home/penguin/.conda/lib/python3.7/site-packages (from scikit-learn==0.24.0) (2.1.0)\n",
"Requirement already satisfied: scipy>=0.19.1 in /home/penguin/.conda/lib/python3.7/site-packages (from scikit-learn==0.24.0) (1.5.0)\n"
]
}
],
"source": [
"!pip install -U scikit-learn==0.24.0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.model_selection import KFold\n",
"\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The problem with BaseSearchCV and sample_weight\n",
"\n",
"* In CV classes that inherit from [`BaseSearchCV`](https://github.com/scikit-learn/scikit-learn/blob/0.20.2/sklearn/model_selection/_search.py#L409), `sample_weight` is applied when the estimator is fitted on each CV split\n",
"* However, `sample_weight` is not applied when the score on the validation set is computed\n",
" * Specifically, see around https://github.com/scikit-learn/scikit-learn/blob/0.24.0/sklearn/model_selection/_validation.py#L620\n",
" * For example, if the positive / negative samples carry 1 / 1000 and 999 / 1000 of the total weight, a model that always predicts negative should get a weighted score of 0.999.\n",
" * In practice the weights are not applied, so the score stays at 0.5."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"cv = np.array([0, 0, 1, 1])\n",
"X = np.ones(shape=(4, 1))\n",
"y = np.array([1, 0, 1, 0])\n",
"\n",
"fold = np.array([\n",
" [[0, 1], [2, 3]],\n",
" [[2, 3], [0, 1]],\n",
"])\n",
"\n",
"# Give the negative samples 999 / 1000 of the total weight (positives get 1 / 1000)\n",
"sample_weight_for_zeros = [\n",
" 1, 999, 1, 999\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=array([[[0, 1],\n",
" [2, 3]],\n",
"\n",
" [[2, 3],\n",
" [0, 1]]]),\n",
" estimator=LogisticRegression(), param_grid={'random_state': [42]},\n",
" return_train_score=True, scoring='accuracy')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grid = GridSearchCV(\n",
" estimator=LogisticRegression(), \n",
" param_grid={ 'random_state': [42] }, \n",
" scoring='accuracy', \n",
" cv=fold, \n",
" return_train_score=True\n",
")\n",
"grid.fit(X, y, sample_weight=sample_weight_for_zeros)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Ideally this would be 0.999\n",
"grid.best_score_"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.99900002, 0.00099998],\n",
" [0.99900002, 0.00099998],\n",
" [0.99900002, 0.00099998],\n",
" [0.99900002, 0.00099998]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The predicted probability of the negative class is 0.999 (because sample_weight is applied during training itself)\n",
"grid.predict_proba(X)"
]
},
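{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Summary\n",
"\n",
"* Training runs with sample_weight as expected, but the reported score is computed with sample_weight ignored.\n",
"* Parameter searches that inherit from BaseSearchCV, such as GridSearchCV and RandomizedSearchCV, select the parameters that are best in terms of that score. \n",
"* As a result, even when you want the parameters that are best with sample_weight taken into account, the model that is best under the unweighted score is selected, which is a problem."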
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
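
A possible workaround, sketched here rather than taken from the notebook: run the CV loop by hand and pass the validation-fold weights to `accuracy_score`, which accepts a `sample_weight` argument. The names `folds`, `mean_unweighted`, and `mean_weighted` are made up for this illustration; the data and weights mirror the toy example in the notebook.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Same toy data as the notebook: a constant feature and alternating labels.
X = np.ones(shape=(4, 1))
y = np.array([1, 0, 1, 0])
sample_weight = np.array([1, 999, 1, 999])

# (train indices, validation indices) pairs, as in the notebook's `fold`.
folds = [([0, 1], [2, 3]), ([2, 3], [0, 1])]

unweighted, weighted = [], []
for train_idx, valid_idx in folds:
    model = LogisticRegression(random_state=42)
    # Training uses the weights of the training rows only.
    model.fit(X[train_idx], y[train_idx], sample_weight=sample_weight[train_idx])
    pred = model.predict(X[valid_idx])

    # Score once without weights (what GridSearchCV reports in the notebook) ...
    unweighted.append(accuracy_score(y[valid_idx], pred))
    # ... and once with the weights of the validation rows.
    weighted.append(
        accuracy_score(y[valid_idx], pred, sample_weight=sample_weight[valid_idx])
    )

mean_unweighted = float(np.mean(unweighted))
mean_weighted = float(np.mean(weighted))
print(mean_unweighted, mean_weighted)
```

On this toy data the model predicts negative for every row, so the unweighted mean is 0.5 while the weighted mean is 0.999. Selecting hyperparameters by `mean_weighted` would then have to be done manually over the parameter grid, since this loop bypasses `GridSearchCV`.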