Skip to content

Instantly share code, notes, and snippets.

@hirokidaichi
Created August 20, 2021 09:16
Show Gist options
  • Save hirokidaichi/3013ceffeb7f43bbb0b4b85b96b41af4 to your computer and use it in GitHub Desktop.
Save hirokidaichi/3013ceffeb7f43bbb0b4b85b96b41af4 to your computer and use it in GitHub Desktop.
IRT少ない回答数から、DXCの総合点数を推論する。
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "IRT少ない回答数から、DXCの総合点数を推論する。",
"provenance": [],
"mount_file_id": "1vkuam8m25efJHhWgyv8dgHgL6FraduIx",
"authorship_tag": "ABX9TyPEico2Dtu2nY6fwq4R2L70",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/hirokidaichi/3013ceffeb7f43bbb0b4b85b96b41af4/irt-dxc.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kksD11KqWE1t",
"outputId": "17eed0db-3815-4e1e-93ff-d3bdad9fa788"
},
"source": [
"!pip install numpyro"
],
"execution_count": 172,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: numpyro in /usr/local/lib/python3.7/dist-packages (0.7.2)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from numpyro) (4.62.0)\n",
"Requirement already satisfied: jax>=0.2.13 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.2.19)\n",
"Requirement already satisfied: jaxlib>=0.1.65 in /usr/local/lib/python3.7/dist-packages (from numpyro) (0.1.70+cuda110)\n",
"Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (0.12.0)\n",
"Requirement already satisfied: numpy>=1.18 in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (1.19.5)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.2.13->numpyro) (3.3.0)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.65->numpyro) (1.4.1)\n",
"Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.65->numpyro) (1.12)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax>=0.2.13->numpyro) (1.15.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xhWfo56bWGsv"
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import numpyro\n",
"import jax\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from numpyro import distributions as dist\n",
"from numpyro.infer import NUTS,MCMC\n",
"\n"
],
"execution_count": 173,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZHe2P1PIaLP3"
},
"source": [
"import jax.numpy as jnp\n",
"def L2P(a, b, x):\n",
"    \"\"\"Two-parameter logistic (2PL) item response function.\n",
"\n",
"    Args:\n",
"        a: Slope (discrimination) of the curve.\n",
"        b: Location (difficulty): the value of ``x`` where the result is 0.5.\n",
"        x: Latent ability at which the curve is evaluated.\n",
"\n",
"    Returns:\n",
"        ``1 / (1 + exp(-a * (x - b)))`` as a JAX array.\n",
"    \"\"\"\n",
"    # jax.nn.sigmoid evaluates the same logistic curve in a numerically\n",
"    # stable way, instead of overflowing in exp() for large |a * (x - b)|.\n",
"    return jax.nn.sigmoid(a * (x - b))"
],
"execution_count": 174,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "VRY-FWLexltz"
},
"source": [
"微分可能な2パラメータ・ロジスティック（2PL）関数 L2P を JAX で定義する。\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "8e_iwKqgjNAB"
},
"source": [
"# Item table: the CSV's first (unnamed) column is renamed to \"name\" and used as the index.\n",
"item_score = (\n",
"    pd.read_csv(\"/content/drive/MyDrive/01. DXクライテリアWG/IRT分析/item.csv\")\n",
"    .rename(columns={\"Unnamed: 0\": \"name\"})\n",
"    .set_index(\"name\")\n",
")\n",
"# {item name: {column: value}} mapping for lookups by item name.\n",
"item_score_dict = item_score.to_dict(orient=\"index\")"
],
"execution_count": 175,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nNnR4HYUx88u"
},
"source": [
"pyirtで生成したスコアを元に、データを読み込む。各クライテリアの項目反応曲線のパラメータを取得する。\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "W3jLtCX3k6hm",
"outputId": "7bd38772-8c7b-4eee-9118-308e09dab089"
},
"source": [
"# Observed answers as [item name, answer] pairs.\n",
"# NOTE(review): \"team-3-1\" and \"system-3-1\" each appear twice (the latter with\n",
"# answers 0 and 1) -- confirm the duplicates are intentional.\n",
"sample_data = [\n",
"    [\"team-1-1\", 1],\n",
"    [\"team-2-1\", 1],\n",
"    [\"team-3-1\", 1],\n",
"    [\"team-4-1\", 1],\n",
"    [\"team-5-5\", 1],\n",
"    [\"team-3-1\", 1],\n",
"    [\"team-2-3\", 0.5],\n",
"    [\"team-8-2\", 1],\n",
"    [\"system-1-1\", 1],\n",
"    [\"system-2-1\", 1],\n",
"    [\"system-3-1\", 0],\n",
"    [\"system-4-1\", 0.5],\n",
"    [\"system-5-5\", 1],\n",
"    [\"system-3-1\", 1],\n",
"    [\"system-2-3\", 0.5],\n",
"    [\"system-8-2\", 1],\n",
"]\n",
"\n",
"# Unpack the second entry and display its item name.\n",
"name, score = sample_data[1]\n",
"name"
],
"execution_count": 176,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'team-2-1'"
]
},
"metadata": {},
"execution_count": 176
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wej_NioVk8FM",
"outputId": "255bb13b-ca04-4a92-89c1-0f142c4e00a2"
},
"source": [
"\n",
"def skill_model(test_data):\n",
"    \"\"\"NumPyro model of a single latent skill given observed item answers.\n",
"\n",
"    Args:\n",
"        test_data: Iterable of ``(item name, answer)`` pairs. Each item name\n",
"            must be a key of ``item_score_dict`` whose entry provides\n",
"            ``\"alpha\"`` and ``\"beta\"``.\n",
"\n",
"    Returns:\n",
"        The value of the last observed site, or ``None`` when ``test_data``\n",
"        is empty (only the prior on ``skill`` is sampled in that case).\n",
"    \"\"\"\n",
"    # Uniform prior: the skill (ability) value lies between -4 and 4.\n",
"    skill = numpyro.sample(\"skill\", dist.Uniform(-4, 4))\n",
"\n",
"    # Default return value so that an empty test_data does not leave `r` unbound.\n",
"    r = None\n",
"    # Every answer follows Bernoulli(p) with p = L2P(alpha, beta, skill).\n",
"    for idx, (name, answer) in enumerate(test_data):\n",
"        params = item_score_dict[name]\n",
"        r = numpyro.sample(\n",
"            f\"r_{idx}\",\n",
"            dist.Bernoulli(L2P(params[\"alpha\"], params[\"beta\"], skill)),\n",
"            obs=answer,\n",
"        )\n",
"    return r\n",
"\n",
"\n",
"from numpyro import handlers\n",
"\n",
"# Dry run of the model outside MCMC; the seed handler supplies the PRNG key\n",
"# that numpyro.sample needs.\n",
"with handlers.seed(rng_seed=0):  # random.PRNGKey(0) is used\n",
"    print(skill_model(sample_data))\n",
"\n",
"## MCMC sampling\n",
"# Use the NUTS / MCMC names imported from numpyro.infer: `pyro` is never\n",
"# imported in this notebook, so `pyro.infer.*` raises NameError on a fresh kernel.\n",
"kernel = NUTS(skill_model)\n",
"mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)\n",
"\n"
],
"execution_count": 177,
"outputs": [
{
"output_type": "stream",
"text": [
"1\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Mqir06gXooYv",
"outputId": "b2005c9b-2c97-407d-ca48-c052d4ba82b6"
},
"source": [
"mcmc.run(jax.random.PRNGKey(1),sample_data)"
],
"execution_count": 178,
"outputs": [
{
"output_type": "stream",
"text": [
"sample: 100%|██████████| 2000/2000 [00:04<00:00, 423.12it/s, 3 steps of size 5.62e-01. acc. prob=0.87]\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BNzGhSz5o0Rz",
"outputId": "f25241a3-e6bf-437c-95f0-7ad1f0d5ffb6"
},
"source": [
"mcmc.print_summary()"
],
"execution_count": 179,
"outputs": [
{
"output_type": "stream",
"text": [
"\n",
" mean std median 5.0% 95.0% n_eff r_hat\n",
" skill 1.76 0.69 1.75 0.63 2.86 238.20 1.00\n",
"\n",
"Number of divergences: 0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 282
},
"id": "1lBacH1OumDJ",
"outputId": "15b17e6c-d4fb-4daf-a60f-8391ee4c9da4"
},
"source": [
"import seaborn as sns\n",
"sns.histplot(mcmc.get_samples()[\"skill\"],kde=True)"
],
"execution_count": 180,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x7fd90b2f1fd0>"
]
},
"metadata": {},
"execution_count": 180
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cRB-uOAtvv-7"
},
"source": [
"skill_sampling = mcmc.get_samples()[\"skill\"]"
],
"execution_count": 181,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "M9GkFG53zDXK"
},
"source": [
""
],
"execution_count": 181,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment