Last active
April 29, 2024 01:18
-
-
Save CookieBox26/58d97abbe8657a1e217bf3e020e4f93d to your computer and use it in GitHub Desktop.
カスタムバッチサンプラー
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "16f44a39-6611-43f0-a568-78bc0a6496f5", | |
"metadata": {}, | |
"source": [ | |
"# カスタムバッチサンプラー" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "ad99dd78-eb82-4d4d-b81a-097b9dce33b4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"===== epoch 0 =====\n", | |
"----- batch 0 -----\n", | |
"tensor([[10, 10, 10],\n", | |
" [ 7, 7, 7],\n", | |
" [ 9, 9, 9],\n", | |
" [ 0, 0, 0]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[2, 2, 2],\n", | |
" [1, 1, 1],\n", | |
" [8, 8, 8],\n", | |
" [4, 4, 4]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[3, 3, 3],\n", | |
" [6, 6, 6],\n", | |
" [5, 5, 5]])\n", | |
"===== epoch 1 =====\n", | |
"----- batch 0 -----\n", | |
"tensor([[10, 10, 10],\n", | |
" [ 7, 7, 7],\n", | |
" [ 6, 6, 6],\n", | |
" [ 1, 1, 1]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[0, 0, 0],\n", | |
" [4, 4, 4],\n", | |
" [9, 9, 9],\n", | |
" [2, 2, 2]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[8, 8, 8],\n", | |
" [3, 3, 3],\n", | |
" [5, 5, 5]])\n" | |
] | |
} | |
], | |
"source": [ | |
import pandas as pd
import torch
from torch.utils.data import DataLoader
"\n", | |
class MyDataset:
    """
    Map-style dataset wrapping a pandas DataFrame so it can be passed to DataLoader.
    https://pytorch.org/docs/stable/data.html#map-style-datasets

    Each item is one row of the DataFrame, returned as a NumPy array.

    Attributes
    ----------
    df : pandas.DataFrame
        The wrapped data; one row per sample.
    n_sample : int
        Number of rows (samples) in ``df``.
    """
    def __init__(self, df):
        self.df = df
        self.n_sample = len(df)

    def __getitem__(self, batch_idx):
        # `batch_idx` is the *position* of one sample (0 .. n_sample - 1), as handed
        # out by the (batch) sampler. Positional .iloc is used instead of label-based
        # .loc so that a DataFrame with a non-default index (e.g. after filtering or
        # reordering rows) still returns the intended row instead of a KeyError or
        # a wrong row.
        return self.df.iloc[batch_idx].to_numpy()

    def __len__(self):
        return self.n_sample
"\n", | |
# Slice the following dummy data into batches.
# Every column of row i holds the value i, so each printed row shows directly
# which sample index ended up in which batch.
df = pd.DataFrame({col: list(range(11)) for col in ['a', 'b', 'c']})
dataset = MyDataset(df)
batch_size = 4
# Seed the shuffle so that "Restart & Run All" reproduces the same batches;
# successive epochs still draw different permutations from the generator.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
for i_epoch in range(2):
    print(f'===== epoch {i_epoch} =====')
    for i_batch, data in enumerate(dataloader):
        print(f'----- batch {i_batch} -----')
        print(data)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "da9c56b9-5d4c-4502-a713-dca0b907a8a0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"===== epoch 0 =====\n", | |
"----- batch 0 -----\n", | |
"tensor([[0, 0, 0],\n", | |
" [1, 1, 1],\n", | |
" [4, 4, 4],\n", | |
" [8, 8, 8]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[2, 2, 2],\n", | |
" [3, 3, 3],\n", | |
" [6, 6, 6],\n", | |
" [7, 7, 7]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[ 5, 5, 5],\n", | |
" [ 9, 9, 9],\n", | |
" [10, 10, 10]])\n", | |
"===== epoch 1 =====\n", | |
"----- batch 0 -----\n", | |
"tensor([[0, 0, 0],\n", | |
" [1, 1, 1],\n", | |
" [4, 4, 4],\n", | |
" [8, 8, 8]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[2, 2, 2],\n", | |
" [3, 3, 3],\n", | |
" [6, 6, 6],\n", | |
" [7, 7, 7]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[ 5, 5, 5],\n", | |
" [ 9, 9, 9],\n", | |
" [10, 10, 10]])\n" | |
] | |
} | |
], | |
"source": [ | |
# A plain list of index lists can serve as batch_sampler: the DataLoader
# iterates over it directly, so every epoch yields these fixed batches in order.
dataloader = DataLoader(
    dataset=dataset,
    batch_sampler=[[0, 1, 4, 8], [2, 3, 6, 7], [5, 9, 10]],
)
for epoch in range(2):
    print(f'===== epoch {epoch} =====')
    for batch_no, batch in enumerate(dataloader):
        print(f'----- batch {batch_no} -----')
        print(batch)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "0cc77d9d-639d-4700-99d8-db5fcc35d08a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"===== epoch 0 =====\n", | |
"----- batch 0 -----\n", | |
"tensor([[0, 0, 0],\n", | |
" [1, 1, 1],\n", | |
" [2, 2, 2],\n", | |
" [3, 3, 3]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[4, 4, 4],\n", | |
" [5, 5, 5],\n", | |
" [6, 6, 6],\n", | |
" [7, 7, 7]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[ 8, 8, 8],\n", | |
" [ 9, 9, 9],\n", | |
" [10, 10, 10]])\n", | |
"===== epoch 1 =====\n", | |
"----- batch 0 -----\n", | |
"tensor([[0, 0, 0],\n", | |
" [1, 1, 1],\n", | |
" [2, 2, 2],\n", | |
" [3, 3, 3]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[4, 4, 4],\n", | |
" [5, 5, 5],\n", | |
" [6, 6, 6],\n", | |
" [7, 7, 7]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[ 8, 8, 8],\n", | |
" [ 9, 9, 9],\n", | |
" [10, 10, 10]])\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
class MyBatchSampler:
    """
    Basic batch sampler that slices the data into batches in order, without any shuffling.

    Iterating yields lists of sample indices: [0 .. batch_size-1], [batch_size ..], ...,
    with the last batch shorter when n_sample is not a multiple of batch_size.

    Parameters
    ----------
    n_sample : int
        Number of samples; indices 0 .. n_sample - 1 are handed out.
    batch_size : int
        Maximum number of indices per batch (must be >= 1).

    Raises
    ------
    ValueError
        If batch_size is smaller than 1.
    """
    def __init__(self, n_sample, batch_size):
        if batch_size < 1:
            raise ValueError(f'batch_size must be a positive integer, got {batch_size}')
        self.n_sample = n_sample
        self.batch_size = batch_size
        # Integer ceiling division: number of batches including a partial last one.
        self.n_batch = (self.n_sample + batch_size - 1) // batch_size

    def __len__(self):
        # DataLoader delegates len() to its batch sampler; without this method
        # len(dataloader) raises TypeError.
        return self.n_batch

    def __iter__(self):
        # Each time iteration starts (i.e. once per epoch), reset the batch counter.
        self.i_batch = -1
        return self

    def _get_i_batch(self, i_batch):
        # Return the sample indices belonging to batch number `i_batch`.
        # Clipping at n_sample makes the final batch shorter when needed.
        start = i_batch * self.batch_size
        stop = min(start + self.batch_size, self.n_sample)
        return list(range(start, stop))

    def __next__(self):
        self.i_batch += 1
        if self.i_batch >= self.n_batch:
            raise StopIteration()
        return self._get_i_batch(self.i_batch)
"\n", | |
# The sequential sampler yields the same batches in the same order every epoch.
dataloader = DataLoader(
    dataset=dataset,
    batch_sampler=MyBatchSampler(dataset.n_sample, batch_size),
)
for epoch in range(2):
    print(f'===== epoch {epoch} =====')
    for batch_no, batch in enumerate(dataloader):
        print(f'----- batch {batch_no} -----')
        print(batch)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "0662caad-acc6-4af1-93b7-dfc92c1a9c2a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"===== epoch 0 =====\n", | |
"----- batch 0 -----\n", | |
"tensor([[ 0, 0, 0],\n", | |
" [10, 10, 10],\n", | |
" [ 8, 8, 8],\n", | |
" [ 7, 7, 7]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[9, 9, 9],\n", | |
" [6, 6, 6],\n", | |
" [4, 4, 4],\n", | |
" [1, 1, 1]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[3, 3, 3],\n", | |
" [2, 2, 2],\n", | |
" [5, 5, 5]])\n", | |
"===== epoch 1 =====\n", | |
"----- batch 0 -----\n", | |
"tensor([[1, 1, 1],\n", | |
" [2, 2, 2],\n", | |
" [0, 0, 0],\n", | |
" [8, 8, 8]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[3, 3, 3],\n", | |
" [5, 5, 5],\n", | |
" [6, 6, 6],\n", | |
" [9, 9, 9]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[10, 10, 10],\n", | |
" [ 4, 4, 4],\n", | |
" [ 7, 7, 7]])\n" | |
] | |
} | |
], | |
"source": [ | |
class MyBatchSamplerShuffle(MyBatchSampler):
    """
    Batch sampler that shuffles all of the data.

    The order of sample indices is re-shuffled at the start of every epoch, so
    each batch is a collection of scattered sample indices.

    Parameters
    ----------
    n_sample : int
        Number of samples; indices 0 .. n_sample - 1 are handed out.
    batch_size : int
        Maximum number of indices per batch.
    seed : int or None, optional
        If given, shuffling uses a private ``np.random.default_rng(seed)`` and is
        reproducible independently of any other code. If None (default),
        NumPy's global RNG (``np.random``) is used, so ``np.random.seed(...)``
        controls it.
    """
    def __init__(self, n_sample, batch_size, seed=None):
        super().__init__(n_sample, batch_size)
        self.sample_ids_shuffled = list(range(self.n_sample))
        self._rng = np.random if seed is None else np.random.default_rng(seed)

    def __iter__(self):
        # Called whenever iteration starts, i.e. once per epoch: re-shuffle the
        # sample index order, then let the base class reset the batch counter.
        self._rng.shuffle(self.sample_ids_shuffled)
        return super().__iter__()

    def __next__(self):
        # Take the base class's in-order positions for this batch and map them
        # through the shuffled sample index list.
        positions = super().__next__()
        return [self.sample_ids_shuffled[i] for i in positions]
"\n", | |
# Without an explicit seed, MyBatchSamplerShuffle shuffles with NumPy's global
# RNG; seed it so that "Restart & Run All" reproduces the same batches.
np.random.seed(0)
dataloader = DataLoader(
    dataset=dataset,
    batch_sampler=MyBatchSamplerShuffle(dataset.n_sample, batch_size),
)
for i_epoch in range(2):
    print(f'===== epoch {i_epoch} =====')
    for i_batch, data in enumerate(dataloader):
        print(f'----- batch {i_batch} -----')
        print(data)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "e804b009-e13d-4dff-ad0e-aafe6eb46b41", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"===== epoch 0 =====\n", | |
"只今のバッチサイズ: 4\n", | |
"----- batch 0 -----\n", | |
"tensor([[ 5, 5, 5],\n", | |
" [ 2, 2, 2],\n", | |
" [10, 10, 10],\n", | |
" [ 4, 4, 4]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[6, 6, 6],\n", | |
" [1, 1, 1],\n", | |
" [7, 7, 7],\n", | |
" [3, 3, 3]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[0, 0, 0],\n", | |
" [9, 9, 9],\n", | |
" [8, 8, 8]])\n", | |
"===== epoch 1 =====\n", | |
"只今のバッチサイズ: 2\n", | |
"----- batch 0 -----\n", | |
"tensor([[ 8, 8, 8],\n", | |
" [10, 10, 10]])\n", | |
"----- batch 1 -----\n", | |
"tensor([[7, 7, 7],\n", | |
" [2, 2, 2]])\n", | |
"----- batch 2 -----\n", | |
"tensor([[1, 1, 1],\n", | |
" [6, 6, 6]])\n", | |
"----- batch 3 -----\n", | |
"tensor([[3, 3, 3],\n", | |
" [5, 5, 5]])\n", | |
"----- batch 4 -----\n", | |
"tensor([[4, 4, 4],\n", | |
" [9, 9, 9]])\n", | |
"----- batch 5 -----\n", | |
"tensor([[0, 0, 0]])\n" | |
] | |
} | |
], | |
"source": [ | |
class MyBatchSamplerDecaying(MyBatchSamplerShuffle):
    """
    Shuffling batch sampler whose batch size shrinks every epoch.

    At epoch ``k`` (counted from 0) the batch size is
    ``max(1, ceil(batch_size * exp(-decay_rate * k)))``, and the number of
    batches is recomputed to cover all samples.

    Parameters
    ----------
    n_sample : int
        Number of samples; indices 0 .. n_sample - 1 are handed out.
    batch_size : int
        Batch size used in the first epoch.
    decay_rate : float, optional
        Exponential decay rate per epoch. The default 1.0 shrinks the batch
        size by a factor of e per epoch (before rounding up).
    """
    def __init__(self, n_sample, batch_size, decay_rate=1.0):
        super().__init__(n_sample, batch_size)
        self.batch_size_org = batch_size
        self.decay_rate = decay_rate
        # __iter__ increments this before use, so the first epoch is epoch 0.
        self.i_epoch = -1

    def __iter__(self):
        self.i_epoch += 1
        decayed = self.batch_size_org * np.exp(-self.decay_rate * self.i_epoch)
        # Clamp at 1: for large epochs exp() underflows to 0.0, and a zero batch
        # size would make the n_batch computation below divide by zero.
        self.batch_size = max(1, int(np.ceil(decayed)))
        print('只今のバッチサイズ:', self.batch_size)
        self.n_batch = int(np.ceil(self.n_sample / self.batch_size))
        return super().__iter__()
"\n", | |
# The decaying sampler inherits the shuffle, which uses NumPy's global RNG
# when no seed is given; seed it so that "Restart & Run All" is reproducible.
np.random.seed(0)
dataloader = DataLoader(
    dataset=dataset,
    batch_sampler=MyBatchSamplerDecaying(dataset.n_sample, batch_size),
)
for i_epoch in range(2):
    print(f'===== epoch {i_epoch} =====')
    for i_batch, data in enumerate(dataloader):
        print(f'----- batch {i_batch} -----')
        print(data)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "3fa2b33d-d2e6-4a21-9078-3f7ca048cd12", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment