{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "LightGBM でかんたん Learning to Rank",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyN7fLbHJdO8DI3KOUxnQ8y1",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/knuu/3b978a2d458df5d7910d7e314f9d74f3/lightgbm-learning-to-rank.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JPduAh6O7CBH",
"colab_type": "text"
},
"source": [
"# Easy Learning to Rank with LightGBM\n",
"\n",
"LightGBM ships with LambdaRank, a Learning to Rank method, along with sample data. This notebook uses them to try Learning to Rank in practice.\n",
"\n",
"The notebook covers the following steps in order:\n",
"\n",
"- Fetching and loading the data\n",
"- Training LambdaRank\n",
"- Computing the evaluation metric (NDCG@10)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hutAFPKGzikb",
"colab_type": "code",
"colab": {}
},
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"from sklearn.datasets import load_svmlight_file\n",
"from sklearn.metrics import ndcg_score\n",
"import pandas as pd"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "OblvUvxI68pn",
"colab_type": "text"
},
"source": [
"## Fetching and loading the data\n",
"\n",
"The official LightGBM repository includes sample data, so the first step is to clone the repository."
]
},
{
"cell_type": "code",
"metadata": {
"id": "X-cJ9cQz3OVe",
"colab_type": "code",
"outputId": "07df9390-d32e-4572-d2ff-c8db4a84ff1c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
}
},
"source": [
"!git clone https://github.com/microsoft/LightGBM.git"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'LightGBM'...\n",
"remote: Enumerating objects: 13, done.\n",
"remote: Counting objects: 100% (13/13), done.\n",
"remote: Compressing objects: 100% (11/11), done.\n",
"remote: Total 17384 (delta 1), reused 3 (delta 1), pack-reused 17371\n",
"Receiving objects: 100% (17384/17384), 11.84 MiB | 25.35 MiB/s, done.\n",
"Resolving deltas: 100% (12660/12660), done.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4JngjFXD5CBr",
"colab_type": "code",
"outputId": "710125c8-7adc-4346-9976-645e7b43cdd9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"!ls LightGBM/examples/lambdarank"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"predict.conf rank.test.query rank.train.query train.conf\n",
"rank.test rank.train README.md\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1paNKTxADIKF",
"colab_type": "text"
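},
"source": [
"rank.train and rank.test contain the training data and the test data respectively, both in svmlight format. That is, each line holds the features of one document in the form `<relevance> <feature index>:<value> <feature index>:<value> ...`."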
]
},
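{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"As a rough illustration of the line format described above, the following sketch parses a single svmlight-style line by hand. The helper `demo_parse_svmlight_line` and the sample line are hypothetical and exist only for this example; in practice a library loader is the better choice. Features that do not appear on a line are implicitly zero, which is why the format suits sparse data."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"def demo_parse_svmlight_line(line):\n",
"    # Split '<relevance> <index>:<value> ...' into a label and a sparse feature dict\n",
"    label, *pairs = line.split()\n",
"    features = {}\n",
"    for pair in pairs:\n",
"        index, value = pair.split(':')\n",
"        features[int(index)] = float(value)\n",
"    return float(label), features\n",
"\n",
"# Hypothetical sample line in the same shape as the rows of rank.train\n",
"demo_label, demo_features = demo_parse_svmlight_line('1 1:0.69 11:0.64 12:0.51')\n",
"print(demo_label, demo_features)"
],
"execution_count": 0,
"outputs": []
},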
{
"cell_type": "code",
"metadata": {
"id": "3PGmspfnfnkL",
"colab_type": "code",
"outputId": "05403923-1cb2-4387-9b40-c2195fb9a229",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
}
},
"source": [
"!head -n5 LightGBM/examples/lambdarank/rank.train"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"0 10:0.89 11:0.75 12:0.01 17:0.45 18:0.91 21:0.78 27:0.72 29:0.77 30:0.76 39:0.65 43:0.79 44:0.88 45:0.88 66:0.64 69:0.25 70:0.41 71:0.83 74:0.79 77:0.70 83:0.56 85:0.72 86:0.93 91:0.35 98:0.70 101:0.96 108:0.20 122:0.86 123:0.19 124:0.47 127:0.43 129:0.05 133:0.45 139:0.92 145:0.47 146:0.58 147:0.63 149:0.84 154:0.73 155:0.50 159:0.31 170:0.90 172:0.67 173:0.40 174:0.79 177:0.88 178:0.55 179:0.52 187:0.21 192:0.48 195:0.59 197:0.83 204:0.80 208:0.50 212:0.82 216:0.29 222:0.48 235:0.22 239:0.88 241:0.21 242:0.40 243:0.69 245:0.77 247:0.76 253:0.55 265:0.70 266:0.27 271:0.34 276:0.58 281:0.91 282:0.75 300:0.43\n",
"1 1:0.69 11:0.64 12:0.51 17:0.53 18:0.75 21:0.77 27:0.45 29:0.71 30:0.45 34:0.81 36:0.21 37:0.50 39:0.48 43:0.10 45:0.66 55:0.52 64:0.77 66:0.52 69:0.71 70:0.33 71:0.97 74:0.50 81:0.37 83:0.60 86:0.80 91:0.08 98:0.24 99:0.83 101:0.52 108:0.54 114:0.53 122:0.80 123:0.52 124:0.50 126:0.54 127:0.17 129:0.05 133:0.43 135:0.56 139:0.77 145:0.49 146:0.31 147:0.38 149:0.08 150:0.58 154:0.72 155:0.58 159:0.25 165:0.70 172:0.27 173:0.72 176:0.73 177:0.70 178:0.55 179:0.61 187:0.68 192:0.59 201:0.74 208:0.64 212:0.52 215:0.36 216:0.77 222:0.48 235:0.99 241:0.21 242:0.55 243:0.61 245:0.48 247:0.39 253:0.36 254:0.90 259:0.21 265:0.26 266:0.27 267:0.69 271:0.60 276:0.25 290:0.53 297:0.36 300:0.25\n",
"0 1:0.69 11:0.64 12:0.51 17:0.34 18:0.85 21:0.13 27:0.18 29:0.71 30:0.17 34:0.47 36:0.30 37:0.85 39:0.48 43:0.65 45:0.77 66:0.20 69:0.10 70:0.68 71:0.97 74:0.59 81:0.62 83:0.60 86:0.85 91:0.31 97:0.89 98:0.50 99:0.83 101:0.52 104:0.84 108:0.09 114:0.75 122:0.82 123:0.70 124:0.67 126:0.44 127:0.52 129:0.59 133:0.61 135:0.34 139:0.81 146:0.59 147:0.65 149:0.46 150:0.61 154:0.83 155:0.58 158:0.79 159:0.56 165:0.70 172:0.41 173:0.15 176:0.66 177:0.70 178:0.55 179:0.73 187:0.21 192:0.59 201:0.68 208:0.54 212:0.19 215:0.61 216:0.82 222:0.48 235:0.22 241:0.50 242:0.33 243:0.57 245:0.60 247:0.67 253:0.54 254:0.90 259:0.21 265:0.48 266:0.71 267:0.19 271:0.60 276:0.47 279:0.82 283:0.82 290:0.75 297:0.36 300:0.43\n",
"1 11:0.64 12:0.51 17:0.11 18:0.82 21:0.78 27:0.45 29:0.71 30:0.45 34:0.82 36:0.27 37:0.50 39:0.48 43:0.12 45:0.66 55:0.59 66:0.63 69:0.72 70:0.61 71:0.97 74:0.41 86:0.83 91:0.17 98:0.47 101:0.52 108:0.56 122:0.51 123:0.53 124:0.45 127:0.29 129:0.05 133:0.37 135:0.39 139:0.79 145:0.50 146:0.57 147:0.63 149:0.11 154:0.66 159:0.43 172:0.41 173:0.71 177:0.70 178:0.55 179:0.74 187:0.68 202:0.36 212:0.50 216:0.77 222:0.48 235:0.68 241:0.32 242:0.14 243:0.68 245:0.54 247:0.67 253:0.32 254:0.96 259:0.21 265:0.49 266:0.47 267:0.44 271:0.60 276:0.32 300:0.43\n",
"0 1:0.69 7:0.72 11:0.64 12:0.51 17:0.76 18:0.61 21:0.47 27:0.72 29:0.71 30:0.74 32:0.69 34:0.55 36:0.78 39:0.48 43:0.72 45:0.87 66:0.43 69:0.81 70:0.37 71:0.97 74:0.75 77:0.98 81:0.37 83:0.60 86:0.71 91:0.72 98:0.63 99:0.83 101:0.52 104:0.58 108:0.66 114:0.59 122:0.75 123:0.63 124:0.59 126:0.44 127:0.86 129:0.05 133:0.43 135:0.32 139:0.68 145:1.00 146:0.53 147:0.48 149:0.75 150:0.54 154:0.58 155:0.58 158:0.74 159:0.37 165:0.70 172:0.45 173:0.94 176:0.66 177:0.38 178:0.55 179:0.74 192:0.59 195:0.59 201:0.68 204:0.85 208:0.54 212:0.37 215:0.41 216:0.03 222:0.48 230:0.74 232:0.78 233:0.73 235:0.60 241:0.79 242:0.51 243:0.47 245:0.72 247:0.60 253:0.46 254:0.95 257:0.65 259:0.21 265:0.64 266:0.63 267:0.36 271:0.60 276:0.48 282:0.77 290:0.59 297:0.36 300:0.33\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "smeJ8EBKftRy",
"colab_type": "text"
},
"source": [
"The svmlight format can be read easily with load_svmlight_file from sklearn.datasets."
]
},
{
"cell_type": "code",
"metadata": {
"id": "zs7ToOaq5Zic",
"colab_type": "code",
"outputId": "a880e28b-ffc5-4866-dd6f-068e4f36d5c1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Load the data\n",
"data_dir_path = \"LightGBM/examples/lambdarank/\"\n",
"X_train_all, y_train_all = load_svmlight_file(data_dir_path + \"rank.train\")\n",
"X_test, y_test = load_svmlight_file(data_dir_path + \"rank.test\")\n",
"X_train_all.shape, y_train_all.shape, X_test.shape, y_test.shape"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((3005, 300), (3005,), (768, 300), (768,))"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "um6a5Z1FEqyg",
"colab_type": "code",
"outputId": "50ae0646-7abc-4aa3-c289-f7a056bc1218",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
}
},
"source": [
"# Distribution of the relevance labels\n",
"# In this dataset they range from 0 to 4\n",
"pd.Series(np.concatenate([y_train_all, y_test])).value_counts(sort=False).sort_index()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0 851\n",
"1.0 1467\n",
"2.0 1110\n",
"3.0 266\n",
"4.0 79\n",
"dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8tbzwCH-DYBh",
"colab_type": "text"
},
"source": [
"The corresponding query information is stored in rank.train.query and rank.test.query. The first 5 lines of rank.train.query look like this."
]
},
{
"cell_type": "code",
"metadata": {
"id": "mbtg3KXqf2_L",
"colab_type": "code",
"outputId": "79b80dcc-4708-4a1f-fae6-b94a28442633",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
}
},
"source": [
"!head -n5 LightGBM/examples/lambdarank/rank.train.query"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"1\n",
"13\n",
"5\n",
"8\n",
"19\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "svDOkqjMf1rg",
"colab_type": "text"
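},
"source": [
"Each line gives, in order from the top, the number of data rows in rank.train that belong to the same query. For example, the 1 on the first line means that the first row of rank.train belongs to some query $q_1$, and the 13 on the second line means that the next 13 rows of rank.train, i.e. rows 2 through 14, belong to the next query $q_2$.\n",
"\n",
"Reference: https://lightgbm.readthedocs.io/en/latest/Parameters.html#query-data\n",
"\n",
"Because these values are row counts, the numbers in rank.train.query and rank.test.query sum to the number of rows in rank.train and rank.test respectively."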
]
},
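{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The mapping from group sizes to row ranges can be sketched with NumPy as follows. The sizes are the first five values shown in the output above; the `demo_` variable names are hypothetical and used only in this example. Row numbers here are 0-based, so they are one less than the 1-based row numbers used in the explanation."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"# First five group sizes from rank.train.query\n",
"demo_sizes = np.array([1, 13, 5, 8, 19])\n",
"\n",
"# Exclusive end row of each query block, and the matching start row (0-based)\n",
"demo_ends = demo_sizes.cumsum()\n",
"demo_starts = demo_ends - demo_sizes\n",
"\n",
"# Per-row query index: query 0 once, query 1 thirteen times, and so on\n",
"demo_query_ids = np.repeat(np.arange(len(demo_sizes)), demo_sizes)\n",
"\n",
"for i, (start, end) in enumerate(zip(demo_starts, demo_ends)):\n",
"    print(f'query {i}: rows {start}..{end - 1} ({end - start} rows)')"
],
"execution_count": 0,
"outputs": []
},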
{
"cell_type": "code",
"metadata": {
"id": "zvrUg90oDhtd",
"colab_type": "code",
"outputId": "93caa806-ec02-4b72-8e69-15d11ebb0b62",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"q_train_all = np.loadtxt(data_dir_path + \"rank.train.query\")\n",
"q_test = np.loadtxt(data_dir_path + \"rank.test.query\")\n",
"q_train_all.shape, q_test.shape"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((201,), (50,))"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bKPGbxh-Balm",
"colab_type": "code",
"outputId": "eecd8d45-2071-4b97-e30c-e3778a3d4310",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# The sums of q_train_all and q_test match the row counts of X_train_all and X_test\n",
"q_train_all.sum(), q_test.sum()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(3005.0, 768.0)"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-0Mis6QvcMZk",
"colab_type": "text"
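},
"source": [
"Split the full training data into train and valid sets for validation (note: this is done in order to use LightGBM's early_stopping_rounds). Rows that belong to the same query must stay together, so a random row-level split is not possible. Instead, split from the top so that train:valid is roughly 3:1."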
]
},
{
"cell_type": "code",
"metadata": {
"id": "R-0o-OgYI1Do",
"colab_type": "code",
"outputId": "3aa111a5-9409-48e5-eed9-4e3e9bd18487",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
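},
"source": [
"# Take the cumulative sum of q_train_all and find the query where the running total reaches 75%\n",
"q_train_cumsum = q_train_all.cumsum()\n",
"q_idx = int(np.searchsorted(q_train_cumsum, q_train_all.sum() * 0.75))\n",
"X_idx = int(q_train_cumsum[q_idx])\n",
"# Split at that query boundary so that no query is divided between train and valid\n",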
"X_train, X_valid = X_train_all[:X_idx], X_train_all[X_idx:]\n",
"y_train, y_valid = y_train_all[:X_idx], y_train_all[X_idx:]\n",
"q_train, q_valid = q_train_all[:q_idx+1], q_train_all[q_idx+1:]\n",
"X_train.shape, X_valid.shape, y_train.shape, y_valid.shape, q_train.sum(), q_valid.sum()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((2258, 300), (747, 300), (2258,), (747,), 2258.0, 747.0)"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gAiVy2BlLhKb",
"colab_type": "code",
"outputId": "3f0a477a-df5a-4497-8310-e884a0f246a4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# The split is roughly train:valid = 3:1\n",
"q_train.sum() / q_train_cumsum[-1], q_valid.sum() / q_train_cumsum[-1]"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.751414309484193, 0.24858569051580698)"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AN1xBy3ZdLeN",
"colab_type": "text"
},
"source": [
"Finally, create the datasets for LightGBM."
]
},
{
"cell_type": "code",
"metadata": {
"id": "IbwTpn_rz_q0",
"colab_type": "code",
"colab": {}
},
"source": [
"train = lgb.Dataset(X_train, y_train, group=q_train)\n",
"valid = lgb.Dataset(X_valid, y_valid, reference=train, group=q_valid)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "uBQm0-oHFDFt",
"colab_type": "text"
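},
"source": [
"## Training LambdaRank\n",
"\n",
"Using the loaded data, train a ranking model with the LambdaRank implementation in LightGBM. All that is needed is to set objective to lambdarank and metric to ndcg in the params passed to lgb.train, and to use ndcg_eval_at to specify how many top positions are included in the evaluation.\n",
"\n",
"After that, predict on the test data, build a table for evaluating the results, and evaluate with the mean of the per-query NDCG@10."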
]
},
{
"cell_type": "code",
"metadata": {
"id": "lfgbS8k87sQU",
"colab_type": "code",
"colab": {}
},
"source": [
"params = {\n",
" 'objective': 'lambdarank',\n",
" 'metric': 'ndcg',\n",
" 'lambdarank_truncation_level': 10,\n",
" 'ndcg_eval_at': [10, 5, 20],\n",
" 'n_estimators': 10000,\n",
" 'boosting_type': 'gbdt',\n",
" 'random_state': 0,\n",
"}"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oBl4EIUp7stb",
"colab_type": "code",
"outputId": "1e4a2124-d155-46bd-edc0-101d14d0051b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 340
}
},
"source": [
"model = lgb.train(\n",
" params, train, valid_sets=valid, \n",
" early_stopping_rounds=50,\n",
" verbose_eval=5 # print the metric every 5 rounds\n",
")"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/lightgbm/engine.py:118: UserWarning: Found `n_estimators` in params. Will use it instead of argument\n",
" warnings.warn(\"Found `{}` in params. Will use it instead of argument\".format(alias))\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Training until validation scores don't improve for 50 rounds.\n",
"[5]\tvalid_0's ndcg@10: 0.76134\tvalid_0's ndcg@5: 0.76134\tvalid_0's ndcg@20: 1.00234\n",
"[10]\tvalid_0's ndcg@10: 0.772455\tvalid_0's ndcg@5: 0.772455\tvalid_0's ndcg@20: 1.00358\n",
"[15]\tvalid_0's ndcg@10: 0.762085\tvalid_0's ndcg@5: 0.762085\tvalid_0's ndcg@20: 1.00114\n",
"[20]\tvalid_0's ndcg@10: 0.778217\tvalid_0's ndcg@5: 0.778217\tvalid_0's ndcg@20: 1.02469\n",
"[25]\tvalid_0's ndcg@10: 0.7761\tvalid_0's ndcg@5: 0.7761\tvalid_0's ndcg@20: 1.01994\n",
"[30]\tvalid_0's ndcg@10: 0.775275\tvalid_0's ndcg@5: 0.775275\tvalid_0's ndcg@20: 1.01337\n",
"[35]\tvalid_0's ndcg@10: 0.772209\tvalid_0's ndcg@5: 0.772209\tvalid_0's ndcg@20: 1.02\n",
"[40]\tvalid_0's ndcg@10: 0.781399\tvalid_0's ndcg@5: 0.781399\tvalid_0's ndcg@20: 1.02128\n",
"[45]\tvalid_0's ndcg@10: 0.776702\tvalid_0's ndcg@5: 0.776702\tvalid_0's ndcg@20: 1.01964\n",
"[50]\tvalid_0's ndcg@10: 0.77714\tvalid_0's ndcg@5: 0.77714\tvalid_0's ndcg@20: 1.01888\n",
"[55]\tvalid_0's ndcg@10: 0.780429\tvalid_0's ndcg@5: 0.780429\tvalid_0's ndcg@20: 1.02396\n",
"[60]\tvalid_0's ndcg@10: 0.773283\tvalid_0's ndcg@5: 0.773283\tvalid_0's ndcg@20: 1.01724\n",
"[65]\tvalid_0's ndcg@10: 0.778825\tvalid_0's ndcg@5: 0.778825\tvalid_0's ndcg@20: 1.02615\n",
"[70]\tvalid_0's ndcg@10: 0.775108\tvalid_0's ndcg@5: 0.775108\tvalid_0's ndcg@20: 1.02012\n",
"Early stopping, best iteration is:\n",
"[22]\tvalid_0's ndcg@10: 0.779497\tvalid_0's ndcg@5: 0.779497\tvalid_0's ndcg@20: 1.02672\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ldFaReReKzU",
"colab_type": "text"
},
"source": [
"Use the trained model to compute predictions."
]
},
{
"cell_type": "code",
"metadata": {
"id": "HUCiR52v2IRY",
"colab_type": "code",
"outputId": "911cd33e-b60f-45fe-df7e-01a7fde539c8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"pred = model.predict(X_test, num_iteration=model.best_iteration)\n",
"pred.shape, y_test.shape"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((768,), (768,))"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qghV65_mdXny",
"colab_type": "text"
},
"source": [
"## Computing the evaluation metric (NDCG@10)\n",
"\n",
"Compute NDCG@10 for each query and use the mean as the evaluation score."
]
},
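{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"For reference, here is a minimal sketch of NDCG@k for a single query. It assumes a linear gain (the relevance label itself) and a log2 position discount, which should correspond to the default formulation of `sklearn.metrics.ndcg_score`. LightGBM's built-in `ndcg` metric uses an exponential gain (`2^label - 1`) by default, so its values are not expected to match this one exactly. The function `demo_ndcg_at_k` and the toy inputs are hypothetical, and tied scores are not treated specially here."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"def demo_ndcg_at_k(true_rel, scores, k=10):\n",
"    # Linear-gain NDCG@k for one query (assumption: no special handling of ties)\n",
"    true_rel = np.asarray(true_rel, dtype=float)\n",
"    # Document indices sorted by predicted score, best first, cut at k\n",
"    order = np.argsort(scores)[::-1][:k]\n",
"    # Best achievable ordering: relevance labels sorted in descending order\n",
"    ideal = np.sort(true_rel)[::-1][:k]\n",
"    # Position discount 1 / log2(rank + 1) for ranks 1..k\n",
"    discounts = 1.0 / np.log2(np.arange(2, k + 2))\n",
"    dcg = (true_rel[order] * discounts[:len(order)]).sum()\n",
"    idcg = (ideal * discounts[:len(ideal)]).sum()\n",
"    return dcg / idcg if idcg > 0 else 0.0\n",
"\n",
"# Toy query with four documents\n",
"demo_score = demo_ndcg_at_k([3, 2, 0, 1], [0.1, 0.9, 0.2, 0.3], k=10)\n",
"print(demo_score)"
],
"execution_count": 0,
"outputs": []
},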
{
"cell_type": "code",
"metadata": {
"id": "9Jro_XL2Vw4X",
"colab_type": "code",
"outputId": "71b4e8cf-90f9-47a9-a76f-3a7256cacbda",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 514
}
},
"source": [
"# Attach a query ID and the true relevance label to each prediction\n",
"pred_df = pd.DataFrame({\n",
" \"query_id\": np.repeat(np.arange(q_test.shape[0]), q_test.astype(int)),\n",
" \"pred\": pred,\n",
" \"true\": y_test,\n",
"})\n",
"pred_df.head(15)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>query_id</th>\n",
" <th>pred</th>\n",
" <th>true</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>-0.047711</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>0.030261</td>\n",
" <td>3.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>-0.233895</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>-0.025301</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>-0.093171</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0</td>\n",
" <td>-0.088286</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>0</td>\n",
" <td>-0.154476</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0</td>\n",
" <td>0.046090</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>0</td>\n",
" <td>-0.173321</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>0</td>\n",
" <td>-0.419258</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>0</td>\n",
" <td>-0.237129</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>0</td>\n",
" <td>-0.422494</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>1</td>\n",
" <td>-0.393775</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>1</td>\n",
" <td>-0.522380</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>1</td>\n",
" <td>-0.176059</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" query_id pred true\n",
"0 0 -0.047711 2.0\n",
"1 0 0.030261 3.0\n",
"2 0 -0.233895 2.0\n",
"3 0 -0.025301 0.0\n",
"4 0 -0.093171 2.0\n",
"5 0 -0.088286 1.0\n",
"6 0 -0.154476 2.0\n",
"7 0 0.046090 0.0\n",
"8 0 -0.173321 2.0\n",
"9 0 -0.419258 1.0\n",
"10 0 -0.237129 2.0\n",
"11 0 -0.422494 1.0\n",
"12 1 -0.393775 0.0\n",
"13 1 -0.522380 1.0\n",
"14 1 -0.176059 1.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "e1fuUng5WPeT",
"colab_type": "code",
"outputId": "5cd3d5a6-9b65-4b23-b0b2-876f444fc096",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"# Compute NDCG@10 for each query ID and take the mean\n",
"pred_df.groupby(\"query_id\").apply(\n",
" lambda d: ndcg_score([d[\"true\"]], [d[\"pred\"]], k=10)).mean()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.7723474171927661"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
}
]
}