@enakai00
Created December 1, 2021 13:07
Logistic regression example.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Logistic regression example.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/enakai00/f31aa0e553ca729359768c740487a7d4/logistic-regression-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VJO3PPzqsq8d"
},
"source": [
"Import the modules required to run this notebook."
]
},
{
"cell_type": "code",
"metadata": {
"id": "gB5UUoAXIVmC"
},
"source": [
"import numpy as np\n",
"from numpy.random import multivariate_normal, permutation\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers, models\n",
"\n",
"np.random.seed(20190220)\n",
"tf.random.set_seed(20190220)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "yz2h7_8St1wi"
},
"source": [
"Generate the training data from random numbers."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ASgzWK5AjWvn"
},
"source": [
"# Class t=0: 20 points from a 2D Gaussian centered at (10, 11), variance 20.\n",
"n0, mu0, variance0 = 20, [10, 11], 20\n",
"data0 = multivariate_normal(mu0, np.eye(2)*variance0, n0)\n",
"df0 = DataFrame(data0, columns=['x1', 'x2'])\n",
"df0['t'] = 0\n",
"\n",
"# Class t=1: 15 points from a 2D Gaussian centered at (18, 20), variance 22.\n",
"n1, mu1, variance1 = 15, [18, 20], 22\n",
"data1 = multivariate_normal(mu1, np.eye(2)*variance1, n1)\n",
"df1 = DataFrame(data1, columns=['x1', 'x2'])\n",
"df1['t'] = 1\n",
"\n",
"# Merge both classes, shuffle the rows, then split into features and labels.\n",
"df = pd.concat([df0, df1], ignore_index=True)\n",
"train_set = df.reindex(permutation(df.index)).reset_index(drop=True)\n",
"train_x = train_set[['x1', 'x2']].values\n",
"train_t = train_set['t'].values"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "txozw2MaFclz"
},
"source": [
"Display the generated data."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Hp4EnlqvToYN",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "9fddf951-a299-4c04-f9da-af154f67c8c3"
},
"source": [
"train_set"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x1</th>\n",
" <th>x2</th>\n",
" <th>t</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>11.148678</td>\n",
" <td>12.178698</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>8.628574</td>\n",
" <td>16.936525</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6.751810</td>\n",
" <td>2.686665</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>14.613345</td>\n",
" <td>22.415744</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>-0.582185</td>\n",
" <td>9.712311</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>8.720424</td>\n",
" <td>20.263025</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>14.689335</td>\n",
" <td>11.718604</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>15.174583</td>\n",
" <td>18.703856</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>18.932923</td>\n",
" <td>20.026993</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>10.199965</td>\n",
" <td>19.306527</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>8.047290</td>\n",
" <td>9.257321</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>7.973561</td>\n",
" <td>1.842595</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>9.367123</td>\n",
" <td>12.547001</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>13.379836</td>\n",
" <td>17.101564</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>10.050234</td>\n",
" <td>15.911195</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>21.531288</td>\n",
" <td>15.107301</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>9.636055</td>\n",
" <td>10.316380</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>5.415042</td>\n",
" <td>10.557410</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>15.189524</td>\n",
" <td>10.026291</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>8.816570</td>\n",
" <td>23.696075</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>17.037535</td>\n",
" <td>12.352113</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>11.378429</td>\n",
" <td>7.172675</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>23.073613</td>\n",
" <td>9.808894</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>14.247538</td>\n",
" <td>19.111286</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>16.194011</td>\n",
" <td>19.591581</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>21.138444</td>\n",
" <td>27.280888</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>17.428121</td>\n",
" <td>13.953785</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>16.275732</td>\n",
" <td>25.689385</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>13.915948</td>\n",
" <td>24.620724</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>11.609024</td>\n",
" <td>11.623596</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>13.652538</td>\n",
" <td>3.678494</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31</th>\n",
" <td>3.424919</td>\n",
" <td>26.121595</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>19.663299</td>\n",
" <td>21.072943</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33</th>\n",
" <td>3.977605</td>\n",
" <td>7.202972</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>34</th>\n",
" <td>3.360409</td>\n",
" <td>17.789434</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" x1 x2 t\n",
"0 11.148678 12.178698 0\n",
"1 8.628574 16.936525 0\n",
"2 6.751810 2.686665 0\n",
"3 14.613345 22.415744 1\n",
"4 -0.582185 9.712311 0\n",
"5 8.720424 20.263025 0\n",
"6 14.689335 11.718604 0\n",
"7 15.174583 18.703856 1\n",
"8 18.932923 20.026993 1\n",
"9 10.199965 19.306527 1\n",
"10 8.047290 9.257321 0\n",
"11 7.973561 1.842595 0\n",
"12 9.367123 12.547001 0\n",
"13 13.379836 17.101564 1\n",
"14 10.050234 15.911195 0\n",
"15 21.531288 15.107301 1\n",
"16 9.636055 10.316380 0\n",
"17 5.415042 10.557410 0\n",
"18 15.189524 10.026291 1\n",
"19 8.816570 23.696075 1\n",
"20 17.037535 12.352113 1\n",
"21 11.378429 7.172675 0\n",
"22 23.073613 9.808894 0\n",
"23 14.247538 19.111286 0\n",
"24 16.194011 19.591581 1\n",
"25 21.138444 27.280888 1\n",
"26 17.428121 13.953785 0\n",
"27 16.275732 25.689385 1\n",
"28 13.915948 24.620724 1\n",
"29 11.609024 11.623596 0\n",
"30 13.652538 3.678494 0\n",
"31 3.424919 26.121595 1\n",
"32 19.663299 21.072943 1\n",
"33 3.977605 7.202972 0\n",
"34 3.360409 17.789434 0"
]
},
"metadata": {},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 450
},
"id": "hD-HJuti3rqe",
"outputId": "3e9467e7-fa89-45a3-a32f-cba5cee4a6b6"
},
"source": [
"train_set0 = train_set[train_set['t']==0]\n",
"train_set1 = train_set[train_set['t']==1]\n",
"\n",
"fig = plt.figure(figsize=(7, 7))\n",
"subplot = fig.add_subplot(1, 1, 1)\n",
"subplot.set_ylim([0, 30])\n",
"subplot.set_xlim([0, 30])\n",
"subplot.scatter(train_set1.x1, train_set1.x2, marker='o')\n",
"subplot.scatter(train_set0.x1, train_set0.x2, marker='x')"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x7fcbec84c690>"
]
},
"metadata": {},
"execution_count": 14
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fmnjQdqAvQRw"
},
"source": [
"Define the model that computes the probability $P(x_1, x_2)$, that is, the probability that $t = 1$ at the point $(x_1, x_2)$."
]
},
{
"cell_type": "code",
"metadata": {
"id": "BakcuKxdQoSL",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d4604fd1-ab36-484d-e67a-3ee9be45532d"
},
"source": [
"model = models.Sequential()\n",
"model.add(layers.Dense(1, activation='sigmoid', input_shape=(2,),\n",
"                       name='logistic_regression'))\n",
"\n",
"model.summary()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"sequential_1\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" logistic_regression (Dense) (None, 1) 3 \n",
" \n",
"=================================================================\n",
"Total params: 3\n",
"Trainable params: 3\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fBltXsSRvZn0"
},
"source": [
"Specify that training should minimize the binary cross-entropy."
]
},
{
"cell_type": "code",
"metadata": {
"id": "LlQCTsKKXkr5"
},
"source": [
"# 'rmsprop' is the Keras default optimizer, stated explicitly here.\n",
"model.compile(optimizer='rmsprop', loss='binary_crossentropy')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "6ZJDVflWv6bm"
},
"source": [
"Run the training process for the model."
]
},
{
"cell_type": "code",
"metadata": {
"id": "R6aG8FEZSLdr",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "91a82051-72b4-4bd1-a649-a74f4fc14b45"
},
"source": [
"# Full-batch training: each epoch is one update over all training points.\n",
"model.fit(train_x, train_t,\n",
"          batch_size=len(train_x), epochs=5000, verbose=0)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fcbeca53250>"
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DrFqiUwcwSS4"
},
"source": [
"Check the values of the parameters $w_1, w_2, b$ obtained after training."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ffVp0em2Sn4U",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "54778ae0-aa2d-46b4-ed6a-5a0aece7f2a1"
},
"source": [
"model.get_weights()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[array([[0.10842396],\n",
" [0.20109653]], dtype=float32), array([-4.6493626], dtype=float32)]"
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8vbl6mtdwi_z"
},
"source": [
"Using the obtained parameter values, plot the decision boundary and the probability $P(x_1, x_2)$."
]
},
{
"cell_type": "code",
"metadata": {
"id": "EQCm_ZqJzV7T",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 450
},
"outputId": "7749ed45-3db5-49a3-e77c-d4e87ec14118"
},
"source": [
"# Unpack the trained weights (w1, w2) and the bias b.\n",
"[[w1], [w2]], [b] = model.get_weights()\n",
"\n",
"train_set0 = train_set[train_set['t']==0]\n",
"train_set1 = train_set[train_set['t']==1]\n",
"\n",
"fig = plt.figure(figsize=(7, 7))\n",
"subplot = fig.add_subplot(1, 1, 1)\n",
"subplot.set_ylim([0, 30])\n",
"subplot.set_xlim([0, 30])\n",
"subplot.scatter(train_set1.x1, train_set1.x2, marker='o')\n",
"subplot.scatter(train_set0.x1, train_set0.x2, marker='x')\n",
"\n",
"# Decision boundary: w1*x1 + w2*x2 + b = 0, solved for x2.\n",
"xs = np.linspace(0, 30, 10)\n",
"ys = - (w1*xs/w2 + b/w2)\n",
"subplot.plot(xs, ys)\n",
"\n",
"# Sigmoid probability on a 100x100 grid, drawn as a grayscale map.\n",
"field = [[(1 / (1 + np.exp(-(w1*x1 + w2*x2 + b))))\n",
"          for x1 in np.linspace(0, 30, 100)]\n",
"         for x2 in np.linspace(0, 30, 100)]\n",
"subplot.imshow(field, origin='lower', extent=(0, 30, 0, 30),\n",
"               vmin=0, vmax=1, cmap=plt.cm.gray_r, alpha=0.5)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fcbec6d5850>"
]
},
"metadata": {},
"execution_count": 19
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 504x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
}
]
}