Skip to content

Instantly share code, notes, and snippets.

@hirokidaichi
Created August 24, 2021 10:26
Show Gist options
  • Save hirokidaichi/fef380875fb6d28d8b567a4f526c113b to your computer and use it in GitHub Desktop.
Save hirokidaichi/fef380875fb6d28d8b567a4f526c113b to your computer and use it in GitHub Desktop.
04. 各テーマから6問ずつの30問で、偏差値を算出する。
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "04. 各テーマから6問ずつの30問で、偏差値を算出する。",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/hirokidaichi/fef380875fb6d28d8b567a4f526c113b/04-6-30.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kksD11KqWE1t",
"outputId": "9492ede3-18be-4ac6-b36e-a64f1a73a7cd"
},
"source": [
"!pip install numpyro"
],
"execution_count": 256,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: numpyro in /usr/local/lib/python3.7/dist-packages (0.7.2)\n",
"Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.2.19)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from numpyro) (4.62.0)\n",
"Requirement already satisfied: jaxlib>=0.1.65 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.1.70+cuda110)\n",
"Requirement already satisfied: numpy>=1.18 in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (1.19.5)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (3.3.0)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (0.12.0)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.65->numpyro) (1.4.1)\n",
"Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.65->numpyro) (1.12)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.13->numpyro) (1.15.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xhWfo56bWGsv",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "847c0d1a-b780-4666-eb00-d445f45f8d8f"
},
"source": [
"import arviz as az\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import numpyro\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from numpyro import distributions as dist\n",
"from numpyro.infer import NUTS,MCMC\n",
"from scipy import stats\n",
"plt.style.use('seaborn-darkgrid')\n",
"\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n"
],
"execution_count": 257,
"outputs": [
{
"output_type": "stream",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZHe2P1PIaLP3"
},
"source": [
"import jax.numpy as jnp\n",
"import jax\n",
"import time\n",
"\n",
"@jax.jit\n",
"def L2P(a, b, x):\n",
"    \"\"\"Two-parameter logistic curve: 1 / (1 + exp(-a * (x - b))).\n",
"\n",
"    Args:\n",
"        a: slope of the curve (callers in this notebook pass the item's `alpha`).\n",
"        b: location of the curve (callers pass the item's `beta`).\n",
"        x: point at which the curve is evaluated (callers pass the sampled skill).\n",
"\n",
"    Returns:\n",
"        Value(s) between 0 and 1, broadcast over the inputs.\n",
"    \"\"\"\n",
"    # jax.nn.sigmoid is the numerically stable form of the logistic function.\n",
"    # The hand-written 1 / (1 + exp(-z)) overflows inside exp() for very\n",
"    # negative z, which can turn gradients into NaN during gradient-based sampling.\n",
"    return jax.nn.sigmoid(a * (x - b))\n",
"\n"
],
"execution_count": 258,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "VRY-FWLexltz"
},
"source": [
"\n",
"\n",
"可微分なL2P関数をjaxで定義する。\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "8e_iwKqgjNAB",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1ca798e6-f082-4f7a-8eb3-598b0f1d0153"
},
"source": [
"\n",
"# Directory on the mounted Google Drive that holds the item-parameter CSVs.\n",
"# Kept in one place so the path is not repeated in several string literals.\n",
"IRT_RESULT_DIR = \"/content/drive/MyDrive/01. DXクライテリアWG/IRT分析\"\n",
"\n",
"\n",
"def _read_item_csv(name):\n",
"    \"\"\"Read `result-item-<name>.csv` from IRT_RESULT_DIR, indexed by its first column.\"\"\"\n",
"    return pd.read_csv(f\"{IRT_RESULT_DIR}/result-item-{name}.csv\", index_col=0)\n",
"\n",
"\n",
"def load_param(name):\n",
"    \"\"\"Load the item parameters of one model as a nested dict.\n",
"\n",
"    Args:\n",
"        name: suffix of the CSV file, e.g. 'all' or 'team'.\n",
"\n",
"    Returns:\n",
"        Dict mapping each row label (item name) to a dict of that row's columns,\n",
"        i.e. `DataFrame.to_dict(orient='index')`.\n",
"    \"\"\"\n",
"    return _read_item_csv(name).to_dict(orient=\"index\")\n",
"\n",
"\n",
"# Item table of the 'all' model, with two derived columns:\n",
"#   name  - copy of the index (the item name)\n",
"#   theme - the part of the item name before the first '-'\n",
"# sorted by item name.\n",
"item_score_all = (\n",
"    _read_item_csv(\"all\")\n",
"    .assign(name=lambda df: df.index)\n",
"    .assign(theme=lambda df: [x.split(\"-\")[0] for x in df[\"name\"].values])\n",
"    .sort_values(\"name\")\n",
")\n",
"\n",
"# Item parameters per model: the overall model plus one model per theme.\n",
"ITEM_PARAM = {\n",
"    name: load_param(name)\n",
"    for name in (\"all\", \"team\", \"system\", \"data\", \"design\", \"corporate\")\n",
"}\n",
"\n",
"\n",
"def gen_user_data(p, num=6, random_state=None):\n",
"    \"\"\"Generate a simulated answer sheet: `num` random items per theme, answered at random.\n",
"\n",
"    Args:\n",
"        p: probability that each simulated answer is 1.\n",
"        num: number of items drawn (without replacement) from each theme.\n",
"        random_state: optional integer seed for reproducible draws. None (the\n",
"            default) keeps the previous behaviour of using NumPy's global\n",
"            random state.\n",
"\n",
"    Returns:\n",
"        List of `[item_name, answer]` pairs, grouped by theme in the order\n",
"        team, system, data, design, corporate.\n",
"\n",
"    Raises:\n",
"        ValueError: raised by `DataFrame.sample` when a theme has fewer than\n",
"            `num` eligible items.\n",
"    \"\"\"\n",
"    # One shared generator, so that a single seed still gives each theme its own draw.\n",
"    rng = None if random_state is None else np.random.RandomState(random_state)\n",
"    # Only items with alpha > 0.2 and beta > -2 are eligible for sampling.\n",
"    eligible = item_score_all[(item_score_all[\"alpha\"] > 0.2) & (item_score_all[\"beta\"] > -2)]\n",
"    names = pd.concat([\n",
"        eligible[eligible[\"theme\"] == theme].sample(n=num, random_state=rng)\n",
"        for theme in (\"team\", \"system\", \"data\", \"design\", \"corporate\")\n",
"    ])[\"name\"].to_list()\n",
"    # `stats` is scipy.stats, imported in the notebook's import cell.\n",
"    answers = stats.bernoulli.rvs(p=p, size=len(names), random_state=rng)\n",
"    return [[item, answer] for item, answer in zip(names, answers)]\n",
"\n",
"\n",
"\n",
"def filter_theme(theme,data):\n",
"    \"\"\"Keep the entries of `data` whose item name has `theme` before its first '-'.\n",
"\n",
"    Args:\n",
"        theme: theme prefix to match, e.g. 'team'.\n",
"        data: sequence of entries whose first element is an item name string.\n",
"\n",
"    Returns:\n",
"        New list with the matching entries, in their original order.\n",
"    \"\"\"\n",
"    selected = []\n",
"    for entry in data:\n",
"        prefix = entry[0].partition(\"-\")[0]\n",
"        if prefix == theme:\n",
"            selected.append(entry)\n",
"    return selected\n",
"\n",
"gen_user_data(0.5)"
],
"execution_count": 264,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[['team-2-2', 1],\n",
" ['team-4-8', 1],\n",
" ['team-1-5', 0],\n",
" ['team-2-5', 1],\n",
" ['team-5-8', 1],\n",
" ['team-6-4', 1],\n",
" ['system-4-2', 1],\n",
" ['system-3-3', 1],\n",
" ['system-4-4', 0],\n",
" ['system-6-3', 1],\n",
" ['system-3-2', 0],\n",
" ['system-8-5', 1],\n",
" ['data-5-7', 1],\n",
" ['data-3-8', 0],\n",
" ['data-8-4', 0],\n",
" ['data-3-2', 1],\n",
" ['data-8-8', 0],\n",
" ['data-2-7', 0],\n",
" ['design-3-6', 0],\n",
" ['design-7-1', 0],\n",
" ['design-2-1', 0],\n",
" ['design-1-7', 1],\n",
" ['design-1-1', 1],\n",
" ['design-6-4', 0],\n",
" ['corporate-8-3', 1],\n",
" ['corporate-3-7', 1],\n",
" ['corporate-6-6', 1],\n",
" ['corporate-6-1', 0],\n",
" ['corporate-4-8', 1],\n",
" ['corporate-7-8', 0]]"
]
},
"metadata": {},
"execution_count": 264
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nNnR4HYUx88u"
},
"source": [
"pyirtで生成したスコアを元に、データを読み込む。各クライテリアの項目反応曲線のパラメータを取得する。\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "W3jLtCX3k6hm",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "974a31bd-4ead-4e68-9681-62d3c33f3c44"
},
"source": [
"sample_data = gen_user_data(0.5)\n",
"sample_data"
],
"execution_count": 265,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[['team-8-4', 0],\n",
" ['team-7-7', 1],\n",
" ['team-5-7', 1],\n",
" ['team-2-1', 0],\n",
" ['team-4-3', 1],\n",
" ['team-5-2', 0],\n",
" ['system-8-1', 1],\n",
" ['system-3-5', 0],\n",
" ['system-1-2', 1],\n",
" ['system-4-6', 1],\n",
" ['system-3-3', 0],\n",
" ['system-7-7', 0],\n",
" ['data-5-2', 0],\n",
" ['data-4-4', 0],\n",
" ['data-3-5', 1],\n",
" ['data-1-4', 0],\n",
" ['data-5-6', 1],\n",
" ['data-8-4', 0],\n",
" ['design-1-5', 0],\n",
" ['design-2-2', 0],\n",
" ['design-6-4', 0],\n",
" ['design-5-2', 1],\n",
" ['design-7-6', 0],\n",
" ['design-8-4', 0],\n",
" ['corporate-8-8', 1],\n",
" ['corporate-3-1', 1],\n",
" ['corporate-3-3', 0],\n",
" ['corporate-6-4', 0],\n",
" ['corporate-5-4', 1],\n",
" ['corporate-2-3', 0]]"
]
},
"metadata": {},
"execution_count": 265
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wej_NioVk8FM",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "99e84d45-a41b-486e-a7ed-eae780475839"
},
"source": [
"\n",
"\n",
"\n",
"def skill_model_by_theme(theme,data):\n",
"    \"\"\"NumPyro model for one theme: a single latent skill and Bernoulli-observed answers.\n",
"\n",
"    Looks up each item's `alpha` and `beta` in `ITEM_PARAM[theme]`, samples the\n",
"    site `<theme>_skill` from Normal(0, 1) and observes the answers with success\n",
"    probability `L2P(alpha, beta, skill)`.\n",
"\n",
"    Args:\n",
"        theme: key into `ITEM_PARAM`; also used to name the sample sites.\n",
"        data: sequence of `[item_name, answer]` pairs.\n",
"\n",
"    Returns:\n",
"        The value of the observed site `r_<theme>`.\n",
"    \"\"\"\n",
"    names = [row[0] for row in data]\n",
"    observed = np.array([row[1] for row in data])\n",
"    slopes = np.array([ITEM_PARAM[theme][n][\"alpha\"] for n in names])\n",
"    locations = np.array([ITEM_PARAM[theme][n][\"beta\"] for n in names])\n",
"    ability = numpyro.sample(f\"{theme}_skill\", dist.Normal(0, 1))\n",
"\n",
"    # One plate entry per answered item.\n",
"    with numpyro.plate(f\"plate_of_result_{theme}\", size=observed.size):\n",
"        return numpyro.sample(\n",
"            f\"r_{theme}\",\n",
"            dist.Bernoulli(L2P(slopes, locations, ability)),\n",
"            obs=observed,\n",
"        )\n",
"\n",
"\n",
"def skill_model(data):\n",
"    \"\"\"Joint model: one skill for all items plus one skill per theme.\n",
"\n",
"    Runs `skill_model_by_theme` once with theme 'all' on the full `data`, then\n",
"    once per theme on the entries selected by `filter_theme`.\n",
"\n",
"    Returns:\n",
"        The value returned by the 'all' sub-model.\n",
"    \"\"\"\n",
"    overall = skill_model_by_theme(\"all\", data)\n",
"    for theme in (\"team\", \"system\", \"data\", \"design\", \"corporate\"):\n",
"        subset = filter_theme(theme, data)\n",
"        skill_model_by_theme(theme, subset)\n",
"    return overall\n",
"\n",
"\n",
"def hensachi(sample):\n",
"    \"\"\"Summarise `sample` as three deviation scores (10 * value + 50), rounded up.\n",
"\n",
"    Args:\n",
"        sample: array-like of numeric draws.\n",
"\n",
"    Returns:\n",
"        NumPy array `[lower, median, higher]` built from the 10th percentile,\n",
"        the median and the 90th percentile of `sample`.\n",
"    \"\"\"\n",
"    cut_points = np.array([\n",
"        np.percentile(sample, 10),\n",
"        np.median(sample),\n",
"        np.percentile(sample, 90),\n",
"    ])\n",
"    return np.ceil(50 + 10 * cut_points)\n",
"\n",
"\n",
"\n",
"def estimate_rank(data, num_warmup=500, num_samples=1500, seed=0):\n",
"    \"\"\"Fit `skill_model` to `data` with NUTS and print a deviation-score range per theme.\n",
"\n",
"    For 'all' and each of the five themes, prints one line of the form\n",
"    `<theme>: <lower> ~ <median> ~ <upper>` using the values returned by `hensachi`.\n",
"\n",
"    Args:\n",
"        data: list of `[item_name, answer]` pairs, as produced by `gen_user_data`.\n",
"        num_warmup: number of warm-up iterations (default 500, as before).\n",
"        num_samples: number of posterior draws kept (default 1500, as before).\n",
"        seed: integer used to build the JAX PRNG key (default 0, as before).\n",
"    \"\"\"\n",
"    kernel = numpyro.infer.NUTS(skill_model)\n",
"    mcmc = numpyro.infer.MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)\n",
"    mcmc.run(jax.random.PRNGKey(seed), data)\n",
"    # Fetch the posterior draws once instead of on every loop iteration.\n",
"    posterior = mcmc.get_samples()\n",
"    for theme in (\"all\", \"team\", \"system\", \"data\", \"design\", \"corporate\"):\n",
"        lower, median, upper = hensachi(posterior[f\"{theme}_skill\"])\n",
"        print(f\"{theme}: {lower} ~ {median} ~ {upper}\")\n",
"\n",
"\n",
"with numpyro.handlers.seed(rng_seed=0): \n",
" print(skill_model(sample_data))\n"
],
"execution_count": 266,
"outputs": [
{
"output_type": "stream",
"text": [
"[0 1 1 0 1 0 1 0 1 1 0 0 0 0 1 0 1 0 0 0 0 1 0 0 1 1 0 0 1 0]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gWdb3u28lPgn",
"outputId": "7fa35f33-2cef-4601-9afe-de542f2e5a8d"
},
"source": [
"estimate_rank(sample_data)"
],
"execution_count": 267,
"outputs": [
{
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [00:06<00:00, 310.94it/s, 7 steps of size 7.72e-01. acc. prob=0.89]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"all: 44.0 ~ 49.0 ~ 53.0\n",
"team: 46.0 ~ 55.0 ~ 64.0\n",
"system: 52.0 ~ 61.0 ~ 69.0\n",
"data: 47.0 ~ 57.0 ~ 67.0\n",
"design: 32.0 ~ 42.0 ~ 51.0\n",
"corporate: 54.0 ~ 62.0 ~ 69.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BNzGhSz5o0Rz",
"outputId": "331c3c6e-b533-4f41-95b8-b7e9424cf03e"
},
"source": [
"\n",
"for p in range(6):\n",
" estimate_rank(gen_user_data(p/5))\n"
],
"execution_count": 268,
"outputs": [
{
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [00:05<00:00, 343.62it/s, 7 steps of size 6.61e-01. acc. prob=0.92]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"all: 18.0 ~ 26.0 ~ 34.0\n",
"team: 27.0 ~ 36.0 ~ 44.0\n",
"system: 32.0 ~ 44.0 ~ 54.0\n",
"data: 30.0 ~ 40.0 ~ 49.0\n",
"design: 30.0 ~ 40.0 ~ 48.0\n",
"corporate: 29.0 ~ 39.0 ~ 48.0\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [00:06<00:00, 328.57it/s, 3 steps of size 7.11e-01. acc. prob=0.90]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"all: 27.0 ~ 34.0 ~ 40.0\n",
"team: 28.0 ~ 37.0 ~ 46.0\n",
"system: 30.0 ~ 40.0 ~ 49.0\n",
"data: 34.0 ~ 44.0 ~ 54.0\n",
"design: 35.0 ~ 45.0 ~ 54.0\n",
"corporate: 29.0 ~ 39.0 ~ 48.0\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [00:06<00:00, 331.76it/s, 7 steps of size 7.52e-01. acc. prob=0.89]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"all: 45.0 ~ 49.0 ~ 53.0\n",
"team: 32.0 ~ 41.0 ~ 50.0\n",
"system: 42.0 ~ 52.0 ~ 62.0\n",
"data: 36.0 ~ 48.0 ~ 60.0\n",
"design: 45.0 ~ 52.0 ~ 60.0\n",
"corporate: 54.0 ~ 63.0 ~ 72.0\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [00:06<00:00, 321.78it/s, 7 steps of size 7.70e-01. acc. prob=0.88]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"all: 49.0 ~ 53.0 ~ 58.0\n",
"team: 40.0 ~ 47.0 ~ 55.0\n",
"system: 44.0 ~ 53.0 ~ 61.0\n",
"data: 47.0 ~ 56.0 ~ 65.0\n",
"design: 52.0 ~ 61.0 ~ 70.0\n",
"corporate: 47.0 ~ 55.0 ~ 63.0\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [00:05<00:00, 338.97it/s, 3 steps of size 8.62e-01. acc. prob=0.86]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"all: 53.0 ~ 57.0 ~ 61.0\n",
"team: 56.0 ~ 64.0 ~ 72.0\n",
"system: 48.0 ~ 56.0 ~ 65.0\n",
"data: 40.0 ~ 48.0 ~ 56.0\n",
"design: 45.0 ~ 51.0 ~ 58.0\n",
"corporate: 47.0 ~ 53.0 ~ 60.0\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [00:06<00:00, 319.62it/s, 7 steps of size 7.35e-01. acc. prob=0.90]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"all: 71.0 ~ 77.0 ~ 85.0\n",
"team: 61.0 ~ 69.0 ~ 77.0\n",
"system: 56.0 ~ 65.0 ~ 75.0\n",
"data: 56.0 ~ 65.0 ~ 74.0\n",
"design: 55.0 ~ 63.0 ~ 73.0\n",
"corporate: 59.0 ~ 67.0 ~ 75.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "T3qpYkTIou5L"
},
"source": [
""
],
"execution_count": 263,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment