@dsaint31x
Last active June 6, 2024 12:56
dl_torch_multiclass_classification.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"toc_visible": true,
"authorship_tag": "ABX9TyNj1kdaUZgptbBRBwIow1U/",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/dsaint31x/736ea85c52b704c40e0214785b456351/dl_torch_multiclass_classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Dataset"
],
"metadata": {
"id": "XwRyZEthq7FW"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oLUs_h-Ysn6i",
"outputId": "22348492-0fe4-45d9-d587-07ecbfc2b550"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'sklearn.utils._bunch.Bunch'>\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['DESCR',\n",
" 'data',\n",
" 'data_module',\n",
" 'feature_names',\n",
" 'filename',\n",
" 'frame',\n",
" 'target',\n",
" 'target_names']"
]
},
"metadata": {},
"execution_count": 1
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"\n",
"iris = load_iris()\n",
"print(type(iris))\n",
"dir(iris)"
]
},
{
"cell_type": "markdown",
"source": [
"# Exploratory Data Analysis (EDA)\n",
"\n",
"Here we take a brief exploratory look at the data using a pandas DataFrame (`df`)."
],
"metadata": {
"id": "BigXS9S_E9n7"
}
},
{
"cell_type": "code",
"source": [
"from IPython import display\n",
"display.Markdown(iris.DESCR)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 872
},
"id": "TWQMre-us0bD",
"outputId": "b7d17519-2666-47bd-fe05-c53aa0e24f04"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": ".. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n :Number of Instances: 150 (50 in each of three classes)\n :Number of Attributes: 4 numeric, predictive attributes and the class\n :Attribute Information:\n - sepal length in cm\n - sepal width in cm\n - petal length in cm\n - petal width in cm\n - class:\n - Iris-Setosa\n - Iris-Versicolour\n - Iris-Virginica\n \n :Summary Statistics:\n\n ============== ==== ==== ======= ===== ====================\n Min Max Mean SD Class Correlation\n ============== ==== ==== ======= ===== ====================\n sepal length: 4.3 7.9 5.84 0.83 0.7826\n sepal width: 2.0 4.4 3.05 0.43 -0.4194\n petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n ============== ==== ==== ======= ===== ====================\n\n :Missing Attribute Values: None\n :Class Distribution: 33.3% for each of 3 classes.\n :Creator: R.A. Fisher\n :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher's paper. Note that it's the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature. Fisher's paper is a classic in the field and\nis referenced frequently to this day. (See Duda & Hart, for example.) The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant. One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n - Fisher, R.A. \"The use of multiple measurements in taxonomic problems\"\n Annual Eugenics, 7, Part II, 179-188 (1936); also in \"Contributions to\n Mathematical Statistics\" (John Wiley, NY, 1950).\n - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n - Dasarathy, B.V. (1980) \"Nosing Around the Neighborhood: A New System\n Structure and Classification Rule for Recognition in Partially Exposed\n Environments\". IEEE Transactions on Pattern Analysis and Machine\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\n - Gates, G.W. (1972) \"The Reduced Nearest Neighbor Rule\". IEEE Transactions\n on Information Theory, May 1972, 431-433.\n - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al\"s AUTOCLASS II\n conceptual clustering system finds 3 classes in the data.\n - Many, many more ..."
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
"df['label'] = iris.target\n",
"\n",
"df.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "dhnsMaE0Dl3s",
"outputId": "ec5aba95-c39c-4020-fa6e-279527b202de"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
"0                5.1               3.5                1.4               0.2   \n",
"1                4.9               3.0                1.4               0.2   \n",
"2                4.7               3.2                1.3               0.2   \n",
"3                4.6               3.1                1.5               0.2   \n",
"4                5.0               3.6                1.4               0.2   \n",
"\n",
"   label  \n",
"0      0  \n",
"1      0  \n",
"2      0  \n",
"3      0  \n",
"4      0  "
],
"text/html": [
"\n",
" <div id=\"df-f51c7776-70b5-4e58-af69-f9d5b1c7e702\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal length (cm)</th>\n",
" <th>sepal width (cm)</th>\n",
" <th>petal length (cm)</th>\n",
" <th>petal width (cm)</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <div class=\"colab-df-buttons\">\n",
"\n",
" <div class=\"colab-df-container\">\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f51c7776-70b5-4e58-af69-f9d5b1c7e702')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" .colab-df-buttons div {\n",
" margin-bottom: 4px;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-f51c7776-70b5-4e58-af69-f9d5b1c7e702 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-f51c7776-70b5-4e58-af69-f9d5b1c7e702');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
"\n",
"\n",
"<div id=\"df-5fda5829-6ec5-4cf6-9f15-1d82a10c9cc5\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-5fda5829-6ec5-4cf6-9f15-1d82a10c9cc5')\"\n",
" title=\"Suggest charts\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" --bg-color: #E8F0FE;\n",
" --fill-color: #1967D2;\n",
" --hover-bg-color: #E2EBFA;\n",
" --hover-fill-color: #174EA6;\n",
" --disabled-fill-color: #AAA;\n",
" --disabled-bg-color: #DDD;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" --bg-color: #3B4455;\n",
" --fill-color: #D2E3FC;\n",
" --hover-bg-color: #434B5C;\n",
" --hover-fill-color: #FFFFFF;\n",
" --disabled-bg-color: #3B4455;\n",
" --disabled-fill-color: #666;\n",
" }\n",
"\n",
" .colab-df-quickchart {\n",
" background-color: var(--bg-color);\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: var(--fill-color);\n",
" height: 32px;\n",
" padding: 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: var(--hover-bg-color);\n",
" box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: var(--button-hover-fill-color);\n",
" }\n",
"\n",
" .colab-df-quickchart-complete:disabled,\n",
" .colab-df-quickchart-complete:disabled:hover {\n",
" background-color: var(--disabled-bg-color);\n",
" fill: var(--disabled-fill-color);\n",
" box-shadow: none;\n",
" }\n",
"\n",
" .colab-df-spinner {\n",
" border: 2px solid var(--fill-color);\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" animation:\n",
" spin 1s steps(1) infinite;\n",
" }\n",
"\n",
" @keyframes spin {\n",
" 0% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" border-left-color: var(--fill-color);\n",
" }\n",
" 20% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 30% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 40% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 60% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 80% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" 90% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const quickchartButtonEl =\n",
" document.querySelector('#' + key + ' button');\n",
" quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
" quickchartButtonEl.classList.add('colab-df-spinner');\n",
" try {\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" } catch (error) {\n",
" console.error('Error during call to suggestCharts:', error);\n",
" }\n",
" quickchartButtonEl.classList.remove('colab-df-spinner');\n",
" quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
" }\n",
" (() => {\n",
" let quickchartButtonEl =\n",
" document.querySelector('#df-5fda5829-6ec5-4cf6-9f15-1d82a10c9cc5 button');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
" })();\n",
" </script>\n",
"</div>\n",
"\n",
" </div>\n",
" </div>\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "df",
"summary": "{\n \"name\": \"df\",\n \"rows\": 150,\n \"fields\": [\n {\n \"column\": \"sepal length (cm)\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.828066127977863,\n \"min\": 4.3,\n \"max\": 7.9,\n \"num_unique_values\": 35,\n \"samples\": [\n 6.2,\n 4.5,\n 5.6\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"sepal width (cm)\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.4358662849366982,\n \"min\": 2.0,\n \"max\": 4.4,\n \"num_unique_values\": 23,\n \"samples\": [\n 2.3,\n 4.0,\n 3.5\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"petal length (cm)\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.7652982332594662,\n \"min\": 1.0,\n \"max\": 6.9,\n \"num_unique_values\": 43,\n \"samples\": [\n 6.7,\n 3.8,\n 3.7\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"petal width (cm)\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.7622376689603465,\n \"min\": 0.1,\n \"max\": 2.5,\n \"num_unique_values\": 22,\n \"samples\": [\n 0.2,\n 1.2,\n 1.3\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"label\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
}
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"source": [
"df.describe()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"id": "HtdJ8b29ETZb",
"outputId": "8506418c-e2c2-4fbf-97f1-f564e5857e07"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"       sepal length (cm)  sepal width (cm)  petal length (cm)  \\\n",
"count         150.000000        150.000000         150.000000   \n",
"mean            5.843333          3.057333           3.758000   \n",
"std             0.828066          0.435866           1.765298   \n",
"min             4.300000          2.000000           1.000000   \n",
"25%             5.100000          2.800000           1.600000   \n",
"50%             5.800000          3.000000           4.350000   \n",
"75%             6.400000          3.300000           5.100000   \n",
"max             7.900000          4.400000           6.900000   \n",
"\n",
"       petal width (cm)       label  \n",
"count        150.000000  150.000000  \n",
"mean           1.199333    1.000000  \n",
"std            0.762238    0.819232  \n",
"min            0.100000    0.000000  \n",
"25%            0.300000    0.000000  \n",
"50%            1.300000    1.000000  \n",
"75%            1.800000    2.000000  \n",
"max            2.500000    2.000000  "
],
"text/html": [
"\n",
" <div id=\"df-4cfd8458-1294-4138-b725-12b6242a9c73\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal length (cm)</th>\n",
" <th>sepal width (cm)</th>\n",
" <th>petal length (cm)</th>\n",
" <th>petal width (cm)</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>150.000000</td>\n",
" <td>150.000000</td>\n",
" <td>150.000000</td>\n",
" <td>150.000000</td>\n",
" <td>150.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>5.843333</td>\n",
" <td>3.057333</td>\n",
" <td>3.758000</td>\n",
" <td>1.199333</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.828066</td>\n",
" <td>0.435866</td>\n",
" <td>1.765298</td>\n",
" <td>0.762238</td>\n",
" <td>0.819232</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>4.300000</td>\n",
" <td>2.000000</td>\n",
" <td>1.000000</td>\n",
" <td>0.100000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>5.100000</td>\n",
" <td>2.800000</td>\n",
" <td>1.600000</td>\n",
" <td>0.300000</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>5.800000</td>\n",
" <td>3.000000</td>\n",
" <td>4.350000</td>\n",
" <td>1.300000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>6.400000</td>\n",
" <td>3.300000</td>\n",
" <td>5.100000</td>\n",
" <td>1.800000</td>\n",
" <td>2.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>7.900000</td>\n",
" <td>4.400000</td>\n",
" <td>6.900000</td>\n",
" <td>2.500000</td>\n",
" <td>2.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <div class=\"colab-df-buttons\">\n",
"\n",
" <div class=\"colab-df-container\">\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-4cfd8458-1294-4138-b725-12b6242a9c73')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" .colab-df-buttons div {\n",
" margin-bottom: 4px;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-4cfd8458-1294-4138-b725-12b6242a9c73 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-4cfd8458-1294-4138-b725-12b6242a9c73');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
"\n",
"\n",
"<div id=\"df-632ebb4c-f65e-4da1-8fa8-27b4cb38c406\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-632ebb4c-f65e-4da1-8fa8-27b4cb38c406')\"\n",
" title=\"Suggest charts\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" --bg-color: #E8F0FE;\n",
" --fill-color: #1967D2;\n",
" --hover-bg-color: #E2EBFA;\n",
" --hover-fill-color: #174EA6;\n",
" --disabled-fill-color: #AAA;\n",
" --disabled-bg-color: #DDD;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" --bg-color: #3B4455;\n",
" --fill-color: #D2E3FC;\n",
" --hover-bg-color: #434B5C;\n",
" --hover-fill-color: #FFFFFF;\n",
" --disabled-bg-color: #3B4455;\n",
" --disabled-fill-color: #666;\n",
" }\n",
"\n",
" .colab-df-quickchart {\n",
" background-color: var(--bg-color);\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: var(--fill-color);\n",
" height: 32px;\n",
" padding: 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: var(--hover-bg-color);\n",
" box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: var(--button-hover-fill-color);\n",
" }\n",
"\n",
" .colab-df-quickchart-complete:disabled,\n",
" .colab-df-quickchart-complete:disabled:hover {\n",
" background-color: var(--disabled-bg-color);\n",
" fill: var(--disabled-fill-color);\n",
" box-shadow: none;\n",
" }\n",
"\n",
" .colab-df-spinner {\n",
" border: 2px solid var(--fill-color);\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" animation:\n",
" spin 1s steps(1) infinite;\n",
" }\n",
"\n",
" @keyframes spin {\n",
" 0% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" border-left-color: var(--fill-color);\n",
" }\n",
" 20% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 30% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 40% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 60% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 80% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" 90% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const quickchartButtonEl =\n",
" document.querySelector('#' + key + ' button');\n",
" quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
" quickchartButtonEl.classList.add('colab-df-spinner');\n",
" try {\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" } catch (error) {\n",
" console.error('Error during call to suggestCharts:', error);\n",
" }\n",
" quickchartButtonEl.classList.remove('colab-df-spinner');\n",
" quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
" }\n",
" (() => {\n",
" let quickchartButtonEl =\n",
" document.querySelector('#df-632ebb4c-f65e-4da1-8fa8-27b4cb38c406 button');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
" })();\n",
" </script>\n",
"</div>\n",
"\n",
" </div>\n",
" </div>\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"summary": "{\n \"name\": \"df\",\n \"rows\": 8,\n \"fields\": [\n {\n \"column\": \"sepal length (cm)\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 51.24711349471842,\n \"min\": 0.828066127977863,\n \"max\": 150.0,\n \"num_unique_values\": 8,\n \"samples\": [\n 5.843333333333334,\n 5.8,\n 150.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"sepal width (cm)\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 52.08617800869865,\n \"min\": 0.4358662849366982,\n \"max\": 150.0,\n \"num_unique_values\": 8,\n \"samples\": [\n 3.0573333333333337,\n 3.0,\n 150.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"petal length (cm)\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 51.83521261418364,\n \"min\": 1.0,\n \"max\": 150.0,\n \"num_unique_values\": 8,\n \"samples\": [\n 3.7580000000000005,\n 4.35,\n 150.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"petal width (cm)\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 52.63664824261751,\n \"min\": 0.1,\n \"max\": 150.0,\n \"num_unique_values\": 8,\n \"samples\": [\n 1.1993333333333336,\n 1.3,\n 150.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"label\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 52.69404575122032,\n \"min\": 0.0,\n \"max\": 150.0,\n \"num_unique_values\": 5,\n \"samples\": [\n 1.0,\n 2.0,\n 0.8192319205190405\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
}
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"source": [
"df.info()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VQ3w6RP0EbIC",
"outputId": "48174f67-62af-4478-e545-11e7b7b76cf4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 150 entries, 0 to 149\n",
"Data columns (total 5 columns):\n",
" #   Column             Non-Null Count  Dtype  \n",
"---  ------             --------------  -----  \n",
" 0   sepal length (cm)  150 non-null    float64\n",
" 1   sepal width (cm)   150 non-null    float64\n",
" 2   petal length (cm)  150 non-null    float64\n",
" 3   petal width (cm)   150 non-null    float64\n",
" 4   label              150 non-null    int64  \n",
"dtypes: float64(4), int64(1)\n",
"memory usage: 6.0 KB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"df['label'].value_counts()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RJ6EnxywFQ4M",
"outputId": "ef77b276-971a-4ff3-ea22-e7dc3f89b359"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"label\n",
"0    50\n",
"1    50\n",
"2    50\n",
"Name: count, dtype: int64"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"np.unique(iris.target_names) # 0: setosa, 1: versicolor, 2: virginica"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nR5npuudGSV6",
"outputId": "53c123c6-394d-46f2-bb18-cbe28eac92dc"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array(['setosa', 'versicolor', 'virginica'], dtype='<U10')"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "markdown",
"source": [
"# Data Processing"
],
"metadata": {
"id": "lb5YoeHDq4hl"
}
},
{
"cell_type": "code",
"source": [
"x_raw = iris.data\n",
"y_raw = iris.target\n",
"\n",
"print(f'{type(x_raw)=}:{x_raw.shape=}')\n",
"print(f'{type(y_raw)=}:{y_raw.shape=}')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HFn8uURBtCP5",
"outputId": "5149406e-32a9-4a6b-9b81-4c63f4af7409"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"type(x_raw)=<class 'numpy.ndarray'>:x_raw.shape=(150, 4)\n",
"type(y_raw)=<class 'numpy.ndarray'>:y_raw.shape=(150,)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"iris.feature_names"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1frbC4zpuCkL",
"outputId": "08deb9dd-5c57-450e-957f-e5d25eeded3c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['sepal length (cm)',\n",
" 'sepal width (cm)',\n",
" 'petal length (cm)',\n",
" 'petal width (cm)']"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"source": [
"from torch.utils.data import Dataset\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hold out 20% of the data as the test set, preserving the class ratios.\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    x_raw, y_raw, test_size = .2,\n",
"    stratify=y_raw,\n",
")\n",
"\n",
"# Split 20% of the remaining data off as the validation set.\n",
"x_train, x_val, y_train, y_val = train_test_split(\n",
"    x_train, y_train, test_size = .2,\n",
"    stratify=y_train,\n",
")\n",
"\n",
"print(len(x_train), len(x_val), len(x_test))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DcTbgjUDuW1Q",
"outputId": "a1360d7e-45d7-4aaa-f9dd-a1e0a68a6ba2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"96 24 30\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"test_r = np.unique(y_test , return_counts = True)\n",
"train_r = np.unique(y_train, return_counts = True)\n",
"val_r = np.unique(y_val, return_counts = True)\n",
"\n",
"print(\"test's versicolor ratio : \", np.round( test_r[1][1]/np.sum( test_r[1]),2))\n",
"print(\"train's versicolor ratio: \", np.round(train_r[1][1]/np.sum(train_r[1]),2))\n",
"print(\"val's versicolor ratio. : \", np.round( val_r[1][1]/np.sum( val_r[1]),2))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Sm5GtCYGu8Up",
"outputId": "8cfef970-9d7c-4147-dff1-aa597088b1c0"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"test's versicolor ratio : 0.33\n",
"train's versicolor ratio: 0.33\n",
"val's versicolor ratio. : 0.33\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f' {x_test.shape=}\\n{x_train.shape=}\\n{ x_val.shape=}')\n",
"print(f' {y_test.shape=}\\n{y_train.shape=}\\n{ y_val.shape=}')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rYQVGnODvNsV",
"outputId": "641326ba-bbcc-46ed-b490-f027f49ea5f7"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" x_test.shape=(30, 4)\n",
"x_train.shape=(96, 4)\n",
" x_val.shape=(24, 4)\n",
" y_test.shape=(30,)\n",
"y_train.shape=(96,)\n",
" y_val.shape=(24,)\n"
]
}
]
},
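{
"cell_type": "code",
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Fit the scaler on the training split only, so that no statistics\n",
"# from the validation or test data leak into preprocessing.\n",
"f_scaler = StandardScaler()\n",
"x_norm = f_scaler.fit_transform(x_train)"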
],
"metadata": {
"id": "YySDUs0HS2u5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Dataset and DataLoader"
],
"metadata": {
"id": "Jnwq_Fg8lELm"
}
},
{
"cell_type": "code",
"source": [
"from torch.utils.data import Dataset\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class IrisDataset(Dataset):\n",
"    \"\"\"Holds features and labels as tensors; features are scaled on construction.\"\"\"\n",
"\n",
"    def __init__(self, r_vec, r_label, transform):\n",
"        # transform is an already-fitted scaler (here, the StandardScaler f_scaler).\n",
"        self.data = torch.tensor(transform.transform(r_vec)).float()\n",
"        self.label = torch.tensor(r_label).long()\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        # Return one (features, label) pair.\n",
"        return self.data[idx], self.label[idx]\n",
"\n",
"train_ds = IrisDataset(x_train, y_train, f_scaler)\n",
"val_ds = IrisDataset(x_val, y_val, f_scaler)\n",
"test_ds = IrisDataset(x_test, y_test, f_scaler)"
],
"metadata": {
"id": "xCk7YXHRyQ3f"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"next(iter(train_ds))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pWDO_6s_K-ch",
"outputId": "28dbb47a-6d26-4fcd-924b-41618cdd259a"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(tensor([0.6833, 1.0169, 1.0851, 1.6177]), tensor(2))"
]
},
"metadata": {},
"execution_count": 15
}
]
},
{
"cell_type": "code",
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"train_loader = DataLoader(\n",
"    dataset = train_ds,\n",
"    batch_size = 32,\n",
"    shuffle = True,\n",
")\n",
"# Shuffling is only needed for training; evaluation order does not matter.\n",
"valid_loader = DataLoader(\n",
"    dataset = val_ds,\n",
"    batch_size = 32,\n",
"    shuffle = False,\n",
")\n",
"test_loader = DataLoader(\n",
"    dataset = test_ds,\n",
"    batch_size = 32,\n",
"    shuffle = False,\n",
")"
],
"metadata": {
"id": "1JnWNey3Q-tO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Train Loop"
],
"metadata": {
"id": "F2HehIUClKx0"
}
},
{
"cell_type": "code",
"source": [
"def train_loop(\n",
"    n_epoch,\n",
"    train_loader, val_loader,\n",
"    model, optimizer, loss_fnc,\n",
"    device = 'cpu',):\n",
"\n",
"    log_hist = torch.zeros([0,3]).to(device)\n",
"    # Move the model to the target device once, before the epoch loop.\n",
"    model = model.to(device)\n",
"\n",
"    for epoch in range(n_epoch):\n",
"        # print(f'\\r{epoch:>6}', end='')\n",
"\n",
"        model.train()\n",
"\n",
"        for x_tensor, y_tensor in train_loader:\n",
"\n",
"            x_tensor = x_tensor.to(device)\n",
"            y_tensor = y_tensor.to(device)\n",
"\n",
"            y_pred = model(x_tensor)\n",
"            # print(y_tensor.shape)\n",
"            loss_train = loss_fnc(y_pred, y_tensor)\n",
"\n",
"            if torch.isinf(loss_train):\n",
"                print('Error: loss is infinity!')\n",
"                break\n",
"\n",
"            optimizer.zero_grad()\n",
"            loss_train.backward()\n",
"            optimizer.step()\n",
"\n",
"        # Switch to evaluation mode once per epoch, outside the batch loop.\n",
"        model.eval()\n",
"        with torch.no_grad():\n",
"            for x_tensor, y_tensor in val_loader:\n",
"                x_tensor = x_tensor.to(device)\n",
"                y_tensor = y_tensor.to(device)\n",
"                pred = model(x_tensor)\n",
"                loss_val = loss_fnc(pred, y_tensor)\n",
"\n",
"        # Note: the logged values are the losses of the last batch in each loader.\n",
"        if epoch % 200 == 0:\n",
"            tmp = torch.tensor([epoch, loss_train.item(), loss_val.item()]).to(log_hist.device)\n",
"            # print(tmp)\n",
"            log_hist = torch.concat( (log_hist, tmp.reshape(1,-1)), dim=0 )\n",
"            print(f'{epoch} Epoch / loss {loss_train.item():.4f} / val_loss {loss_val.item():.4f}')\n",
"\n",
"        # range(n_epoch) ends at n_epoch - 1, so that is the final epoch.\n",
"        if epoch == n_epoch - 1:\n",
"            print(f'{epoch} Epoch / loss {loss_train.item():.4f} / val_loss {loss_val.item():.4f}')\n",
"\n",
"    return model, log_hist\n"
],
"metadata": {
"id": "ANWVidl1LB8C"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Model"
],
"metadata": {
"id": "-xE9AeCXlNTm"
}
},
{
"cell_type": "code",
"source": [
"class SimpleModel(nn.Module):\n",
"\n",
"    def __init__(self, n_feature, n_classes):\n",
"        super().__init__()\n",
"\n",
"        self.l0 = nn.Linear(n_feature, 16)\n",
"        self.a0 = nn.ReLU()\n",
"        self.l1 = nn.Linear(16,16)\n",
"        self.a1 = nn.ReLU()\n",
"        self.lf = nn.Linear(16,n_classes)\n",
"        # No softmax layer: nn.CrossEntropyLoss expects raw logits.\n",
"        # self.out = nn.Softmax(dim=-1)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.l0(x)\n",
"        x = self.a0(x)\n",
"        x = self.l1(x)\n",
"        x = self.a1(x)\n",
"        x = self.lf(x)\n",
"        # x = self.out(x)\n",
"\n",
"        return x"
],
"metadata": {
"id": "3jQjVqvBO7x1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Check the number of components in the input vector.\n",
"n_features = x_raw.shape[-1]\n",
"print(f'{n_features =}')\n",
"# Check the number of components in the output vector (number of classes).\n",
"n_classes = len(np.unique(y_raw))\n",
"print(f'{n_classes =}')\n",
"\n",
"\n",
"# Run the model on 5 random input vectors\n",
"# and inspect its predictions\n",
"# to verify the model's input/output shapes.\n",
"model = SimpleModel(n_features, n_classes)\n",
"x = torch.randn( (5, n_features) )\n",
"print(f'{x.shape =}')\n",
"print(f'{model(x).shape=}')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2PyGzAoIP4XY",
"outputId": "c3847a1d-2904-4e5d-9e48-0164a0604622"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"n_features =4\n",
"n_classes =3\n",
"x.shape =torch.Size([5, 4])\n",
"model(x).shape=torch.Size([5, 3])\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Training"
],
"metadata": {
"id": "PoVUr8UrleVU"
}
},
{
"cell_type": "code",
"source": [
"# Select the device.\n",
"device = (\n",
"    \"cuda\" if torch.cuda.is_available()\n",
"    else \"mps\"\n",
"    if torch.backends.mps.is_available()\n",
"    else \"cpu\"\n",
")\n",
"print(f\"{device=}\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4_8JdhNbERd-",
"outputId": "6a727d9d-a8b3-4187-d055-bb20fb625e78"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"device='cpu'\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# number of epochs\n",
"n_epoch = 7000\n",
"# learning rate\n",
"lr = 4e-4\n",
"\n",
"# set the loss function\n",
"loss_fnc = nn.CrossEntropyLoss()\n",
"# create and initialize the model\n",
"model = SimpleModel(n_features, n_classes)\n",
"# create and initialize the optimizer\n",
"# optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=lr) # recommended\n",
"\n",
"m, h = train_loop(\n",
"    n_epoch,\n",
"    train_loader, valid_loader,\n",
"    model, optimizer, loss_fnc,\n",
"    device = device)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3ffpgj5jMbsZ",
"outputId": "7635d05a-a7f9-44ce-c2a3-ed84ae785501"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 Epoch / loss 1.1006 / val_loss 1.1116\n",
"200 Epoch / loss 0.3397 / val_loss 0.3407\n",
"400 Epoch / loss 0.1005 / val_loss 0.0738\n",
"600 Epoch / loss 0.0817 / val_loss 0.0198\n",
"800 Epoch / loss 0.0367 / val_loss 0.0076\n",
"1000 Epoch / loss 0.0088 / val_loss 0.0031\n",
"1200 Epoch / loss 0.0354 / val_loss 0.0016\n",
"1400 Epoch / loss 0.0181 / val_loss 0.0010\n",
"1600 Epoch / loss 0.0008 / val_loss 0.0006\n",
"1800 Epoch / loss 0.0062 / val_loss 0.0004\n",
"2000 Epoch / loss 0.0005 / val_loss 0.0003\n",
"2200 Epoch / loss 0.0026 / val_loss 0.0002\n",
"2400 Epoch / loss 0.0008 / val_loss 0.0001\n",
"2600 Epoch / loss 0.0047 / val_loss 0.0001\n",
"2800 Epoch / loss 0.0023 / val_loss 0.0001\n",
"3000 Epoch / loss 0.0017 / val_loss 0.0000\n",
"3200 Epoch / loss 0.0006 / val_loss 0.0000\n",
"3400 Epoch / loss 0.0001 / val_loss 0.0000\n",
"3600 Epoch / loss 0.0007 / val_loss 0.0000\n",
"3800 Epoch / loss 0.0004 / val_loss 0.0000\n",
"4000 Epoch / loss 0.0000 / val_loss 0.0000\n",
"4200 Epoch / loss 0.0004 / val_loss 0.0000\n",
"4400 Epoch / loss 0.0002 / val_loss 0.0000\n",
"4600 Epoch / loss 0.0000 / val_loss 0.0000\n",
"4800 Epoch / loss 0.0000 / val_loss 0.0000\n",
"5000 Epoch / loss 0.0000 / val_loss 0.0000\n",
"5200 Epoch / loss 0.0000 / val_loss 0.0000\n",
"5400 Epoch / loss 0.0000 / val_loss 0.0000\n",
"5600 Epoch / loss 0.0000 / val_loss 0.0000\n",
"5800 Epoch / loss 0.0000 / val_loss 0.0000\n",
"6000 Epoch / loss 0.0000 / val_loss 0.0000\n",
"6200 Epoch / loss 0.0000 / val_loss 0.0000\n",
"6400 Epoch / loss 0.0000 / val_loss 0.0000\n",
"6600 Epoch / loss 0.0000 / val_loss 0.0000\n",
"6800 Epoch / loss 0.0000 / val_loss 0.0000\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Learning Curve"
],
"metadata": {
"id": "QdTtvBTyllHo"
}
},
{
"cell_type": "code",
"source": [
"h"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZM1WKxeSItRM",
"outputId": "9f2ed918-668a-47d9-c939-5c2eaabd366c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0.0000e+00, 1.1006e+00, 1.1116e+00],\n",
" [2.0000e+02, 3.3965e-01, 3.4072e-01],\n",
" [4.0000e+02, 1.0054e-01, 7.3833e-02],\n",
" [6.0000e+02, 8.1709e-02, 1.9839e-02],\n",
" [8.0000e+02, 3.6681e-02, 7.6313e-03],\n",
" [1.0000e+03, 8.8194e-03, 3.1158e-03],\n",
" [1.2000e+03, 3.5368e-02, 1.6438e-03],\n",
" [1.4000e+03, 1.8125e-02, 9.6418e-04],\n",
" [1.6000e+03, 7.6428e-04, 5.9564e-04],\n",
" [1.8000e+03, 6.1815e-03, 4.1257e-04],\n",
" [2.0000e+03, 4.8867e-04, 2.7939e-04],\n",
" [2.2000e+03, 2.5835e-03, 1.9565e-04],\n",
" [2.4000e+03, 8.2186e-04, 1.3476e-04],\n",
" [2.6000e+03, 4.7214e-03, 9.3336e-05],\n",
" [2.8000e+03, 2.2601e-03, 6.5094e-05],\n",
" [3.0000e+03, 1.7150e-03, 4.8414e-05],\n",
" [3.2000e+03, 6.3290e-04, 3.5402e-05],\n",
" [3.4000e+03, 1.4205e-04, 2.6710e-05],\n",
" [3.6000e+03, 7.0753e-04, 1.9487e-05],\n",
" [3.8000e+03, 3.9150e-04, 1.4536e-05],\n",
" [4.0000e+03, 1.4901e-06, 1.0757e-05],\n",
" [4.2000e+03, 4.2672e-04, 7.6883e-06],\n",
" [4.4000e+03, 1.9806e-04, 5.4783e-06],\n",
" [4.6000e+03, 2.7865e-06, 4.0033e-06],\n",
" [4.8000e+03, 1.6243e-05, 3.2583e-06],\n",
" [5.0000e+03, 4.2322e-05, 2.5133e-06],\n",
" [5.2000e+03, 2.8833e-06, 1.9272e-06],\n",
" [5.4000e+03, 4.0924e-05, 1.4355e-06],\n",
" [5.6000e+03, 3.2992e-05, 9.3380e-07],\n",
" [5.8000e+03, 6.8054e-06, 6.0598e-07],\n",
" [6.0000e+03, 1.1435e-05, 3.7253e-07],\n",
" [6.2000e+03, 1.0538e-05, 2.4338e-07],\n",
" [6.4000e+03, 4.8165e-06, 1.5398e-07],\n",
" [6.6000e+03, 1.1884e-06, 9.9341e-08],\n",
" [6.8000e+03, 2.2352e-08, 6.4572e-08]])"
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"h = h.cpu()\n",
"h0 = h.detach().numpy()\n",
"print(h0.shape)\n",
"plt.plot(h0[:,0], h0[:,1], label='train')\n",
"plt.plot(h0[:,0], h0[:,2], label='valid')\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 448
},
"id": "2PbbyHveHWp0",
"outputId": "17d782de-00db-41b5-bd96-985f016aa317"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(35, 3)\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Test"
],
"metadata": {
"id": "XzE1E_fDlssv"
}
},
{
"cell_type": "code",
"source": [
"f_m = model.cpu()\n",
"f_m.eval()\n",
"with torch.no_grad():\n",
"    is_first = True\n",
"    for x_tensor, y_tensor in test_loader:\n",
"        x_tensor = x_tensor.cpu()\n",
"        y_tensor = y_tensor.cpu()\n",
"        pred_prob = f_m(x_tensor) # raw logits, one score per class\n",
"        # Take the index of the largest score as the predicted class\n",
"        # (softmax is monotonic, so this is also the most probable class).\n",
"        pred = torch.argmax(pred_prob, dim=-1)\n",
"\n",
"        # print(y_tensor.numpy().shape)\n",
"        if is_first:\n",
"            test_pred = pred.numpy().copy()\n",
"            test_label = y_tensor.numpy().copy()\n",
"            is_first = False\n",
"        else:\n",
"            test_label = np.concatenate((test_label, y_tensor.numpy()), axis=0)\n",
"            test_pred = np.concatenate((test_pred, pred.numpy()), axis=0)\n",
"\n",
"print(test_label.shape, test_label.dtype)\n",
"print(test_pred.shape, test_pred.dtype)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bJEygiJETGIt",
"outputId": "398e37a4-962b-43ea-fa53-9ef358abdd0f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(30,) int64\n",
"(30,) int64\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.metrics import (\n",
"    confusion_matrix,\n",
"    precision_score,\n",
"    recall_score,\n",
"    f1_score,\n",
"    fbeta_score,\n",
")\n",
"pred = test_pred\n",
"label = test_label\n",
"\n",
"print('Confusion Matrix\\n',\n",
"      str(confusion_matrix(label,pred))\n",
"      )\n",
"print(f'Precision :{np.round(precision_score(label,pred,average=None),2)}')\n",
"print(f'Precision (macro) :{np.round(precision_score(label,pred,average=\"macro\"),2)}')\n",
"print(f'Precision (micro) :{np.round(precision_score(label,pred,average=\"micro\"),2)}')\n",
"print(f'Precision (weighted):{np.round(precision_score(label,pred,average=\"weighted\"),2)}')\n",
"print(f'Recall :{np.round(recall_score(label,pred,average=None),2)}')\n",
"print(f'Recall (macro) :{np.round(recall_score(label,pred,average=\"macro\"),2)}')\n",
"print(f'Recall (micro) :{np.round(recall_score(label,pred,average=\"micro\"),2)}')\n",
"print(f'Recall (weighted):{np.round(recall_score(label,pred,average=\"weighted\"),2)}')\n",
"print(f'F1-score :{np.round(f1_score(label,pred,average=None),2)}')\n",
"print(f'F1-score (macro) :{np.round(f1_score(label,pred,average=\"macro\"),2)}')\n",
"print(f'F1-score (micro) :{np.round(f1_score(label,pred,average=\"micro\"),2)}')\n",
"print(f'F1-score (weighted):{np.round(f1_score(label,pred,average=\"weighted\"),2)}')\n",
"print(f'F2-score :{np.round(fbeta_score(label,pred,beta=2,average=None),2)}')\n",
"print(f'F2-score (macro) :{np.round(fbeta_score(label,pred,beta=2,average=\"macro\"),2)}')\n",
"print(f'F2-score (micro) :{np.round(fbeta_score(label,pred,beta=2,average=\"micro\"),2)}')\n",
"print(f'F2-score (weighted):{np.round(fbeta_score(label,pred,beta=2,average=\"weighted\"),2)}')\n",
"\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BhUC7lATWE4g",
"outputId": "db725383-a1c8-4b45-d680-d3eb2d2e7d08"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Confusion Matrix\n",
" [[10 0 0]\n",
" [ 0 8 2]\n",
" [ 0 0 10]]\n",
"Precision :[1. 1. 0.83]\n",
"Precision (macro) :0.94\n",
"Precision (micro) :0.93\n",
"Precision (weighted):0.94\n",
"Recall :[1. 0.8 1. ]\n",
"Recall (macro) :0.93\n",
"Recall (micro) :0.93\n",
"Recall (weighted):0.93\n",
"F1-score :[1. 0.89 0.91]\n",
"F1-score (macro) :0.93\n",
"F1-score (micro) :0.93\n",
"F1-score (weighted):0.93\n",
"F2-score :[1. 0.83 0.96]\n",
"F2-score (macro) :0.93\n",
"F2-score (micro) :0.93\n",
"F2-score (weighted):0.93\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "Cx3NqQukU-7t"
},
"execution_count": null,
"outputs": []
}
]
}