Created
November 14, 2020 00:05
-
-
Save jshirius/90ac6c9b4881bf4c9b96ba4d0007496a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "PCA(主成分分析)によるデータ水増しテスト.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "HG4e2sPbf_F9" | |
}, | |
"source": [ | |
"# 主成分分析(PCA)を使ってデータの水増し(行を増やす)ができるかの検証\n", | |
"- 結論は、可能と判断\n", | |
"- ポイントは以下の通り\n", | |
" - PCAの前に標準化すること\n", | |
" - 圧縮のパラメータ(n_components)は、「元のカラム数 - 1」が良い\n", | |
"- 実際にkaggleで、少ないラベルのところに適用したらスコアの上昇を確認できた" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ilh3dyyfWQw5" | |
}, | |
"source": [ | |
"from sklearn.datasets import load_iris\n", | |
"import pandas as pd\n", | |
"from sklearn.preprocessing import StandardScaler\n", | |
"from sklearn.decomposition import PCA" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "pMXjOh1dffh7" | |
}, | |
"source": [ | |
"# データの準備\n", | |
"- 今回は定番のirisを使う\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "G2fX6RBxXmQa" | |
}, | |
"source": [ | |
"iris = load_iris()\n", | |
"df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", | |
"df['target'] = iris.target\n", | |
"\n", | |
"#標準化したいカラム取り出す\n", | |
"#今回は4カラム分\n", | |
"X = df.iloc[:, 0:4]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "o950e99aZKFD", | |
"outputId": "3968cb0c-a662-4c48-90c6-bb2b33486f53", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 419 | |
} | |
}, | |
"source": [ | |
"X" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>sepal length (cm)</th>\n", | |
" <th>sepal width (cm)</th>\n", | |
" <th>petal length (cm)</th>\n", | |
" <th>petal width (cm)</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>5.1</td>\n", | |
" <td>3.5</td>\n", | |
" <td>1.4</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>4.9</td>\n", | |
" <td>3.0</td>\n", | |
" <td>1.4</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>4.7</td>\n", | |
" <td>3.2</td>\n", | |
" <td>1.3</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>4.6</td>\n", | |
" <td>3.1</td>\n", | |
" <td>1.5</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5.0</td>\n", | |
" <td>3.6</td>\n", | |
" <td>1.4</td>\n", | |
" <td>0.2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>145</th>\n", | |
" <td>6.7</td>\n", | |
" <td>3.0</td>\n", | |
" <td>5.2</td>\n", | |
" <td>2.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>146</th>\n", | |
" <td>6.3</td>\n", | |
" <td>2.5</td>\n", | |
" <td>5.0</td>\n", | |
" <td>1.9</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>147</th>\n", | |
" <td>6.5</td>\n", | |
" <td>3.0</td>\n", | |
" <td>5.2</td>\n", | |
" <td>2.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>148</th>\n", | |
" <td>6.2</td>\n", | |
" <td>3.4</td>\n", | |
" <td>5.4</td>\n", | |
" <td>2.3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>149</th>\n", | |
" <td>5.9</td>\n", | |
" <td>3.0</td>\n", | |
" <td>5.1</td>\n", | |
" <td>1.8</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>150 rows × 4 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n", | |
"0 5.1 3.5 1.4 0.2\n", | |
"1 4.9 3.0 1.4 0.2\n", | |
"2 4.7 3.2 1.3 0.2\n", | |
"3 4.6 3.1 1.5 0.2\n", | |
"4 5.0 3.6 1.4 0.2\n", | |
".. ... ... ... ...\n", | |
"145 6.7 3.0 5.2 2.3\n", | |
"146 6.3 2.5 5.0 1.9\n", | |
"147 6.5 3.0 5.2 2.0\n", | |
"148 6.2 3.4 5.4 2.3\n", | |
"149 5.9 3.0 5.1 1.8\n", | |
"\n", | |
"[150 rows x 4 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "6Xdm-0Tqfp7z" | |
}, | |
"source": [ | |
"# 標準化のテスト\n", | |
"- データを標準化して、さらにもとの値に戻るかの確認をしておく" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "34UjOW3fXxOE", | |
"outputId": "40415f4a-ddb9-4e8b-d3f8-5364ad7381b6", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"sc = StandardScaler()\n", | |
"result = sc.fit(X).transform(X) #標準化\n", | |
"#sc.inverse_transform(result) #標準化を元に戻す\n", | |
"result*sc.scale_ + sc.mean_ #標準化を元に戻す(sc.inverse_transform(result) と同じ計算)\n", | |
"\n", | |
"#標準化前の値と、標準化から復元した値が一致することがわかった" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[5.1, 3.5, 1.4, 0.2],\n", | |
" [4.9, 3. , 1.4, 0.2],\n", | |
" [4.7, 3.2, 1.3, 0.2],\n", | |
" [4.6, 3.1, 1.5, 0.2],\n", | |
" [5. , 3.6, 1.4, 0.2],\n", | |
" [5.4, 3.9, 1.7, 0.4],\n", | |
" [4.6, 3.4, 1.4, 0.3],\n", | |
" [5. , 3.4, 1.5, 0.2],\n", | |
" [4.4, 2.9, 1.4, 0.2],\n", | |
" [4.9, 3.1, 1.5, 0.1],\n", | |
" [5.4, 3.7, 1.5, 0.2],\n", | |
" [4.8, 3.4, 1.6, 0.2],\n", | |
" [4.8, 3. , 1.4, 0.1],\n", | |
" [4.3, 3. , 1.1, 0.1],\n", | |
" [5.8, 4. , 1.2, 0.2],\n", | |
" [5.7, 4.4, 1.5, 0.4],\n", | |
" [5.4, 3.9, 1.3, 0.4],\n", | |
" [5.1, 3.5, 1.4, 0.3],\n", | |
" [5.7, 3.8, 1.7, 0.3],\n", | |
" [5.1, 3.8, 1.5, 0.3],\n", | |
" [5.4, 3.4, 1.7, 0.2],\n", | |
" [5.1, 3.7, 1.5, 0.4],\n", | |
" [4.6, 3.6, 1. , 0.2],\n", | |
" [5.1, 3.3, 1.7, 0.5],\n", | |
" [4.8, 3.4, 1.9, 0.2],\n", | |
" [5. , 3. , 1.6, 0.2],\n", | |
" [5. , 3.4, 1.6, 0.4],\n", | |
" [5.2, 3.5, 1.5, 0.2],\n", | |
" [5.2, 3.4, 1.4, 0.2],\n", | |
" [4.7, 3.2, 1.6, 0.2],\n", | |
" [4.8, 3.1, 1.6, 0.2],\n", | |
" [5.4, 3.4, 1.5, 0.4],\n", | |
" [5.2, 4.1, 1.5, 0.1],\n", | |
" [5.5, 4.2, 1.4, 0.2],\n", | |
" [4.9, 3.1, 1.5, 0.2],\n", | |
" [5. , 3.2, 1.2, 0.2],\n", | |
" [5.5, 3.5, 1.3, 0.2],\n", | |
" [4.9, 3.6, 1.4, 0.1],\n", | |
" [4.4, 3. , 1.3, 0.2],\n", | |
" [5.1, 3.4, 1.5, 0.2],\n", | |
" [5. , 3.5, 1.3, 0.3],\n", | |
" [4.5, 2.3, 1.3, 0.3],\n", | |
" [4.4, 3.2, 1.3, 0.2],\n", | |
" [5. , 3.5, 1.6, 0.6],\n", | |
" [5.1, 3.8, 1.9, 0.4],\n", | |
" [4.8, 3. , 1.4, 0.3],\n", | |
" [5.1, 3.8, 1.6, 0.2],\n", | |
" [4.6, 3.2, 1.4, 0.2],\n", | |
" [5.3, 3.7, 1.5, 0.2],\n", | |
" [5. , 3.3, 1.4, 0.2],\n", | |
" [7. , 3.2, 4.7, 1.4],\n", | |
" [6.4, 3.2, 4.5, 1.5],\n", | |
" [6.9, 3.1, 4.9, 1.5],\n", | |
" [5.5, 2.3, 4. , 1.3],\n", | |
" [6.5, 2.8, 4.6, 1.5],\n", | |
" [5.7, 2.8, 4.5, 1.3],\n", | |
" [6.3, 3.3, 4.7, 1.6],\n", | |
" [4.9, 2.4, 3.3, 1. ],\n", | |
" [6.6, 2.9, 4.6, 1.3],\n", | |
" [5.2, 2.7, 3.9, 1.4],\n", | |
" [5. , 2. , 3.5, 1. ],\n", | |
" [5.9, 3. , 4.2, 1.5],\n", | |
" [6. , 2.2, 4. , 1. ],\n", | |
" [6.1, 2.9, 4.7, 1.4],\n", | |
" [5.6, 2.9, 3.6, 1.3],\n", | |
" [6.7, 3.1, 4.4, 1.4],\n", | |
" [5.6, 3. , 4.5, 1.5],\n", | |
" [5.8, 2.7, 4.1, 1. ],\n", | |
" [6.2, 2.2, 4.5, 1.5],\n", | |
" [5.6, 2.5, 3.9, 1.1],\n", | |
" [5.9, 3.2, 4.8, 1.8],\n", | |
" [6.1, 2.8, 4. , 1.3],\n", | |
" [6.3, 2.5, 4.9, 1.5],\n", | |
" [6.1, 2.8, 4.7, 1.2],\n", | |
" [6.4, 2.9, 4.3, 1.3],\n", | |
" [6.6, 3. , 4.4, 1.4],\n", | |
" [6.8, 2.8, 4.8, 1.4],\n", | |
" [6.7, 3. , 5. , 1.7],\n", | |
" [6. , 2.9, 4.5, 1.5],\n", | |
" [5.7, 2.6, 3.5, 1. ],\n", | |
" [5.5, 2.4, 3.8, 1.1],\n", | |
" [5.5, 2.4, 3.7, 1. ],\n", | |
" [5.8, 2.7, 3.9, 1.2],\n", | |
" [6. , 2.7, 5.1, 1.6],\n", | |
" [5.4, 3. , 4.5, 1.5],\n", | |
" [6. , 3.4, 4.5, 1.6],\n", | |
" [6.7, 3.1, 4.7, 1.5],\n", | |
" [6.3, 2.3, 4.4, 1.3],\n", | |
" [5.6, 3. , 4.1, 1.3],\n", | |
" [5.5, 2.5, 4. , 1.3],\n", | |
" [5.5, 2.6, 4.4, 1.2],\n", | |
" [6.1, 3. , 4.6, 1.4],\n", | |
" [5.8, 2.6, 4. , 1.2],\n", | |
" [5. , 2.3, 3.3, 1. ],\n", | |
" [5.6, 2.7, 4.2, 1.3],\n", | |
" [5.7, 3. , 4.2, 1.2],\n", | |
" [5.7, 2.9, 4.2, 1.3],\n", | |
" [6.2, 2.9, 4.3, 1.3],\n", | |
" [5.1, 2.5, 3. , 1.1],\n", | |
" [5.7, 2.8, 4.1, 1.3],\n", | |
" [6.3, 3.3, 6. , 2.5],\n", | |
" [5.8, 2.7, 5.1, 1.9],\n", | |
" [7.1, 3. , 5.9, 2.1],\n", | |
" [6.3, 2.9, 5.6, 1.8],\n", | |
" [6.5, 3. , 5.8, 2.2],\n", | |
" [7.6, 3. , 6.6, 2.1],\n", | |
" [4.9, 2.5, 4.5, 1.7],\n", | |
" [7.3, 2.9, 6.3, 1.8],\n", | |
" [6.7, 2.5, 5.8, 1.8],\n", | |
" [7.2, 3.6, 6.1, 2.5],\n", | |
" [6.5, 3.2, 5.1, 2. ],\n", | |
" [6.4, 2.7, 5.3, 1.9],\n", | |
" [6.8, 3. , 5.5, 2.1],\n", | |
" [5.7, 2.5, 5. , 2. ],\n", | |
" [5.8, 2.8, 5.1, 2.4],\n", | |
" [6.4, 3.2, 5.3, 2.3],\n", | |
" [6.5, 3. , 5.5, 1.8],\n", | |
" [7.7, 3.8, 6.7, 2.2],\n", | |
" [7.7, 2.6, 6.9, 2.3],\n", | |
" [6. , 2.2, 5. , 1.5],\n", | |
" [6.9, 3.2, 5.7, 2.3],\n", | |
" [5.6, 2.8, 4.9, 2. ],\n", | |
" [7.7, 2.8, 6.7, 2. ],\n", | |
" [6.3, 2.7, 4.9, 1.8],\n", | |
" [6.7, 3.3, 5.7, 2.1],\n", | |
" [7.2, 3.2, 6. , 1.8],\n", | |
" [6.2, 2.8, 4.8, 1.8],\n", | |
" [6.1, 3. , 4.9, 1.8],\n", | |
" [6.4, 2.8, 5.6, 2.1],\n", | |
" [7.2, 3. , 5.8, 1.6],\n", | |
" [7.4, 2.8, 6.1, 1.9],\n", | |
" [7.9, 3.8, 6.4, 2. ],\n", | |
" [6.4, 2.8, 5.6, 2.2],\n", | |
" [6.3, 2.8, 5.1, 1.5],\n", | |
" [6.1, 2.6, 5.6, 1.4],\n", | |
" [7.7, 3. , 6.1, 2.3],\n", | |
" [6.3, 3.4, 5.6, 2.4],\n", | |
" [6.4, 3.1, 5.5, 1.8],\n", | |
" [6. , 3. , 4.8, 1.8],\n", | |
" [6.9, 3.1, 5.4, 2.1],\n", | |
" [6.7, 3.1, 5.6, 2.4],\n", | |
" [6.9, 3.1, 5.1, 2.3],\n", | |
" [5.8, 2.7, 5.1, 1.9],\n", | |
" [6.8, 3.2, 5.9, 2.3],\n", | |
" [6.7, 3.3, 5.7, 2.5],\n", | |
" [6.7, 3. , 5.2, 2.3],\n", | |
" [6.3, 2.5, 5. , 1.9],\n", | |
" [6.5, 3. , 5.2, 2. ],\n", | |
" [6.2, 3.4, 5.4, 2.3],\n", | |
" [5.9, 3. , 5.1, 1.8]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 5 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nRViW7rxXx1a", | |
"outputId": "d2627b47-e689-400e-ad77-dd17c36c88fa", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
" sc.mean_, sc.scale_" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([5.84333333, 3.05733333, 3.758 , 1.19933333]),\n", | |
" array([0.82530129, 0.43441097, 1.75940407, 0.75969263]))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "zxRI6WKXhPth" | |
}, | |
"source": [ | |
"# 標準化なしのPCAによるデータの水増しが出来るか確認\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "IiKrw7NwZ32H", | |
"outputId": "b42e66e0-b436-447f-f9f8-6f0c581c7104", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"#PCAによるデータ水増し実験\n", | |
"n_comp = 3#(元のカラム数 - 1)\n", | |
"\n", | |
"#圧縮\n", | |
"pca = PCA(n_components=n_comp, random_state=42)\n", | |
"pca_res = pca.fit_transform(X)\n", | |
"\n", | |
"#圧縮をもとに戻す\n", | |
"restore_list = pca.inverse_transform(pca_res)\n", | |
"\n", | |
"\n", | |
"#1列目の差分を取る\n", | |
"diff = 0\n", | |
"for index, i in enumerate(X.values.tolist()):\n", | |
" diff += abs(i[0] - restore_list[index][0])\n", | |
" #print(i[0] )\n", | |
"\n", | |
"#1列目の元データと圧縮後の復元の差分を確認する\n", | |
"print(diff)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"5.482752051480244\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "WgMQiyothwG0" | |
}, | |
"source": [ | |
"# 標準化ありのPCAによるデータの水増しが出来るか確認\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "d9R96PqGbyHs", | |
"outputId": "313ab784-c825-4fb4-8fe1-013d7daf2eed", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"#PCAによるデータ水増し(標準化あり)\n", | |
"n_comp = 3#(元のカラム数 - 1)\n", | |
"\n", | |
"sc = StandardScaler()\n", | |
"X_sc = sc.fit(X).transform(X) #標準化\n", | |
"\n", | |
"#PCAによる圧縮\n", | |
"pca = PCA(n_components=n_comp, random_state=42)\n", | |
"pca_res = pca.fit_transform(X_sc)\n", | |
"#PCAの復元化\n", | |
"restore_list = pca.inverse_transform(pca_res) \n", | |
"\n", | |
"#標準化を元に戻す\n", | |
"restore_list = restore_list*sc.scale_ + sc.mean_ \n", | |
"\n", | |
"#1列目の差分を取る\n", | |
"diff = 0\n", | |
"for index, i in enumerate(X.values.tolist()):\n", | |
" diff += abs(i[0] - restore_list[index][0])\n", | |
" #print(i[0] )\n", | |
"print(diff)\n", | |
"\n", | |
"#実際にデータを行単位で水増しするときは、「restore_list」を、元データに追加することになる。\n" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"3.5309526623433074\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "mG2orkskiWCD" | |
}, | |
"source": [ | |
"これまでの結果から、標準化してからPCA圧縮したほうが、\n", | |
"より元のデータに近い形でデータの水増しが出来ることがわかった。\n", | |
"(注: この比較は1列目の絶対誤差の合計のみに基づいているため、他の列の誤差は別途確認が必要。)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "i4K_0DApefVK", | |
"outputId": "92c8e6ca-e1e6-4b34-b314-0cf4ed5f964a", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"#これまでの結果を一般化するため関数にまとめる\n", | |
"def data_augmentation(X, n_comp):\n", | |
" #Xは元データ\n", | |
" #n_compはn_componentsの圧縮パラメータ\n", | |
"\n", | |
" sc = StandardScaler()\n", | |
" X_sc = sc.fit(X).transform(X) #標準化\n", | |
"\n", | |
" #PCAによる圧縮\n", | |
" pca = PCA(n_components=n_comp, random_state=42)\n", | |
" pca_res = pca.fit_transform(X_sc)\n", | |
"\n", | |
" #PCAの復元化\n", | |
" restore_list = pca.inverse_transform(pca_res) \n", | |
"\n", | |
" #標準化を元に戻す\n", | |
" restore_list = restore_list*sc.scale_ + sc.mean_ \n", | |
"\n", | |
" #1列目の差分も評価用に出力\n", | |
" diff = 0\n", | |
" for index, i in enumerate(X):\n", | |
" diff += abs(i[0] - restore_list[index][0])\n", | |
" #print(i[0] )\n", | |
"\n", | |
" return restore_list, diff\n", | |
"\n", | |
"aug_data, diff = data_augmentation(X.values.tolist(), 3)\n", | |
"diff" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"3.5309526623433074" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 9 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LkftBNwelK6F", | |
"outputId": "cd6ca611-d352-401c-95c2-d842584709c1", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 419 | |
} | |
}, | |
"source": [ | |
"#pandasのDataFrameを渡すと、水増し対象の列(aug_col)だけをPCAで復元した値に置き換え、それ以外の列は元の値のままコピーして返す\n", | |
"def pd_augmentation(df, aug_col, n_comp):\n", | |
" #dfのコピー\n", | |
" df_cp = df.copy()\n", | |
" df_cp = df_cp.reset_index(drop=True)\n", | |
"\n", | |
" #水増し該当列のみ水増し対応\n", | |
" aug_data, diff = data_augmentation(df_cp[aug_col].values.tolist(), n_comp)\n", | |
"\n", | |
" #再びdfに変換する\n", | |
" aug_data_df = pd.DataFrame(aug_data, columns=aug_col)\n", | |
"\n", | |
" #水増し側のデータに上書きする\n", | |
" for col in aug_col:\n", | |
" df_cp[col] = aug_data_df[col]\n", | |
"\n", | |
" return df_cp\n", | |
"\n", | |
"\n", | |
"aug_df = pd_augmentation(df, iris.feature_names, 3)\n", | |
"aug_df" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>sepal length (cm)</th>\n", | |
" <th>sepal width (cm)</th>\n", | |
" <th>petal length (cm)</th>\n", | |
" <th>petal width (cm)</th>\n", | |
" <th>target</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>5.094788</td>\n", | |
" <td>3.501297</td>\n", | |
" <td>1.434079</td>\n", | |
" <td>0.190387</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>4.877788</td>\n", | |
" <td>3.005527</td>\n", | |
" <td>1.545247</td>\n", | |
" <td>0.159027</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>4.693881</td>\n", | |
" <td>3.201523</td>\n", | |
" <td>1.340014</td>\n", | |
" <td>0.188712</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>4.614223</td>\n", | |
" <td>3.096461</td>\n", | |
" <td>1.406998</td>\n", | |
" <td>0.226235</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5.007746</td>\n", | |
" <td>3.598073</td>\n", | |
" <td>1.349346</td>\n", | |
" <td>0.214289</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>145</th>\n", | |
" <td>6.616061</td>\n", | |
" <td>3.020885</td>\n", | |
" <td>5.748881</td>\n", | |
" <td>2.145164</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>146</th>\n", | |
" <td>6.252518</td>\n", | |
" <td>2.511814</td>\n", | |
" <td>5.310487</td>\n", | |
" <td>1.812414</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>147</th>\n", | |
" <td>6.474302</td>\n", | |
" <td>3.006394</td>\n", | |
" <td>5.368040</td>\n", | |
" <td>1.952597</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>148</th>\n", | |
" <td>6.194366</td>\n", | |
" <td>3.401402</td>\n", | |
" <td>5.436843</td>\n", | |
" <td>2.289607</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>149</th>\n", | |
" <td>5.935166</td>\n", | |
" <td>2.991250</td>\n", | |
" <td>4.870048</td>\n", | |
" <td>1.864868</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>150 rows × 5 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" sepal length (cm) sepal width (cm) ... petal width (cm) target\n", | |
"0 5.094788 3.501297 ... 0.190387 0\n", | |
"1 4.877788 3.005527 ... 0.159027 0\n", | |
"2 4.693881 3.201523 ... 0.188712 0\n", | |
"3 4.614223 3.096461 ... 0.226235 0\n", | |
"4 5.007746 3.598073 ... 0.214289 0\n", | |
".. ... ... ... ... ...\n", | |
"145 6.616061 3.020885 ... 2.145164 2\n", | |
"146 6.252518 2.511814 ... 1.812414 2\n", | |
"147 6.474302 3.006394 ... 1.952597 2\n", | |
"148 6.194366 3.401402 ... 2.289607 2\n", | |
"149 5.935166 2.991250 ... 1.864868 2\n", | |
"\n", | |
"[150 rows x 5 columns]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 20 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "U_DCipcACmHK" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment