-
-
Save gravitino/4ee16eaa6aadef65a6082ffe232f4f8c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training a Variational Autoencoder (VAE)\n", | |
"\n", | |
"In the last step, we train a Variational Autoencoder (VAE) using PyTorch. Hence, let us first install the package torch via pip. That might take a while." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Installing PyTorch...\n", | |
"\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"print(\"Installing PyTorch...\")\n", | |
"# %pip (unlike !pip) installs into the environment of the running kernel.\n", | |
"# These are CUDA 11.0 (+cu110) wheels; change the tag to match your CUDA setup.\n", | |
"%pip install -q torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Subsequently, we define the network topology. Here, we use a convolutional version, but you could also experiment with a classical [MLP VAE](https://github.com/pytorch/examples/blob/master/vae/main.py)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"class Swish(torch.nn.Module):\n", | |
"    \"\"\"Swish activation with a learnable slope: f(x) = x * sigmoid(alpha * x).\"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        # A single trainable scalar slope, initialised to 1.0.\n", | |
"        self.alpha = torch.nn.Parameter(torch.ones(1))\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        # Put alpha on x's device in case the module was not moved alongside the input.\n", | |
"        slope = self.alpha.to(x.device)\n", | |
"        return x * torch.sigmoid(slope * x)\n", | |
"\n", | |
"class Downsample1d(torch.nn.Module):\n", | |
"    \"\"\"Depthwise [1, 2, 1] filtering followed by stride-2 subsampling.\n", | |
"\n", | |
"    Each channel is filtered independently (groups = number of channels) with\n", | |
"    padding 1, so an input of length L becomes length (L - 1) // 2 + 1.\n", | |
"    The taps are not normalised (they sum to 4).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        # Shape (1, 1, 3): one output channel, one input channel, kernel width 3.\n", | |
"        self.filter = torch.tensor([1.0, 2.0, 1.0]).view(1, 1, 3)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        channels = x.shape[1]\n", | |
"        # self.filter is a plain attribute (not a parameter/buffer), so module.to(...)\n", | |
"        # does not move it. Match BOTH device and dtype of x so that float64/float16\n", | |
"        # inputs do not hit a dtype mismatch inside conv1d.\n", | |
"        w = self.filter.to(device=x.device, dtype=x.dtype).repeat(channels, 1, 1)\n", | |
"        return torch.nn.functional.conv1d(x, w, stride=2, padding=1, groups=channels)\n", | |
"\n", | |
"class LightVAE(torch.nn.Module):\n", | |
"    \"\"\"Small 1-D convolutional VAE with a 2-dimensional latent space.\n", | |
"\n", | |
"    Encoder: three conv -> Swish -> Downsample1d stages (1 -> 2 -> 4 -> 8 channels),\n", | |
"    then two conv heads averaged over the length axis give mu and logvar, each of\n", | |
"    shape (batch, 2). Decoder: a linear layer maps z to 8 channels of length\n", | |
"    num_dims // 8, followed by three Upsample(x2) + conv stages (Swish after the\n", | |
"    first two) back to shape (batch, num_dims).\n", | |
"\n", | |
"    Args:\n", | |
"        num_dims: length of each input vector. Must be a power of two and >= 8,\n", | |
"            because the decoder starts from num_dims // 8 positions and upsamples 3x.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, num_dims):\n", | |
"        super().__init__()\n", | |
"\n", | |
"        self.num_dims = num_dims\n", | |
"        # A pure power-of-two test also accepts 0, 1, 2 and 4, which break decode().\n", | |
"        assert num_dims >= 8 and num_dims & (num_dims - 1) == 0, (\n", | |
"            \"num_dims must be a power of 2 and at least 8\")\n", | |
"\n", | |
"        self.down = Downsample1d()\n", | |
"        self.up = torch.nn.Upsample(scale_factor=2)\n", | |
"        # One Swish instance (one shared alpha) is reused after every hidden conv.\n", | |
"        self.sigma = Swish()\n", | |
"\n", | |
"        # Encoder convolutions (kernel 3, padding 1: length-preserving).\n", | |
"        self.conv0 = torch.nn.Conv1d(1, 2, kernel_size=3, stride=1, padding=1)\n", | |
"        self.conv1 = torch.nn.Conv1d(2, 4, kernel_size=3, stride=1, padding=1)\n", | |
"        self.conv2 = torch.nn.Conv1d(4, 8, kernel_size=3, stride=1, padding=1)\n", | |
"        # Latent heads: convA -> mu, convB -> logvar.\n", | |
"        self.convA = torch.nn.Conv1d(8, 2, kernel_size=3, stride=1, padding=1)\n", | |
"        self.convB = torch.nn.Conv1d(8, 2, kernel_size=3, stride=1, padding=1)\n", | |
"\n", | |
"        # Produces 8 channels x (num_dims // 8) positions = num_dims features.\n", | |
"        self.restore = torch.nn.Linear(2, num_dims)\n", | |
"\n", | |
"        # Decoder convolutions (kernel 3, padding 1: length-preserving).\n", | |
"        self.conv3 = torch.nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1)\n", | |
"        self.conv4 = torch.nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1)\n", | |
"        self.conv5 = torch.nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1)\n", | |
"\n", | |
"    def encode(self, x):\n", | |
"        \"\"\"Map x (reshaped to (batch, 1, num_dims)) to (mu, logvar), each (batch, 2).\"\"\"\n", | |
"        # reshape (not view) so non-contiguous batches are accepted as well.\n", | |
"        x = x.reshape(-1, 1, self.num_dims)\n", | |
"        for conv in (self.conv0, self.conv1, self.conv2):\n", | |
"            x = self.down(self.sigma(conv(x)))\n", | |
"        # Global average over the remaining length axis.\n", | |
"        return self.convA(x).mean(dim=2), self.convB(x).mean(dim=2)\n", | |
"\n", | |
"    def reparameterize(self, mu, logvar):\n", | |
"        \"\"\"Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I).\"\"\"\n", | |
"        std = torch.exp(0.5 * logvar)\n", | |
"        eps = torch.randn_like(std)\n", | |
"        return mu + eps * std\n", | |
"\n", | |
"    def decode(self, z):\n", | |
"        \"\"\"Map latent z of shape (batch, 2) to a reconstruction of shape (batch, num_dims).\"\"\"\n", | |
"        x = self.restore(z).view(-1, 8, self.num_dims // 8)\n", | |
"        x = self.sigma(self.conv3(self.up(x)))\n", | |
"        x = self.sigma(self.conv4(self.up(x)))\n", | |
"        return self.conv5(self.up(x)).view(-1, self.num_dims)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Return (reconstruction, mu, logvar) for a batch x.\"\"\"\n", | |
"        mu, logvar = self.encode(x)\n", | |
"        z = self.reparameterize(mu, logvar)\n", | |
"        return self.decode(z), mu, logvar\n", | |
" \n", | |
"def loss_function(recon_x, x, mu, logvar, kld_weight=0.1):\n", | |
"    \"\"\"Weighted VAE objective, summed over the batch.\n", | |
"\n", | |
"    MSE: squared error averaged over features (dim 1), summed over samples.\n", | |
"    KLD: the closed-form Gaussian KL integrand averaged over latent dims (dim 1),\n", | |
"    summed over samples and scaled by kld_weight (default 0.1).\n", | |
"\n", | |
"    Returns:\n", | |
"        0-d tensor MSE + KLD.\n", | |
"    \"\"\"\n", | |
"    MSE = torch.sum(torch.mean(torch.square(recon_x - x), dim=1))\n", | |
"\n", | |
"    # See Appendix B of Kingma and Welling, Auto-Encoding Variational Bayes, ICLR 2014\n", | |
"    # (https://arxiv.org/abs/1312.6114): KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).\n", | |
"    # Here the latent sum is a mean and 0.5 is replaced by kld_weight; for a 2-D latent\n", | |
"    # and the default 0.1 this equals 0.1 * the exact KL (a beta-VAE-style weighting).\n", | |
"    KLD = -kld_weight * torch.sum(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))\n", | |
"\n", | |
"    return MSE + KLD" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"PyTorch expects its own tensor type, so we need to map the CuPy array `data_cupy` to a `FloatTensor`. Again, we do this with zero-copy functionality via DLPack. The remaining code is a plain PyTorch program that trains the VAE on the training set for 10 epochs using the Adam optimizer." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch.utils import dlpack\n", | |
"\n", | |
"# Shuffle into a NEW array instead of in place (costs one extra GPU copy). data_cupy\n", | |
"# stays untouched, and re-running this cell with the same seed reproduces the same\n", | |
"# split. An in-place shuffle of an already-shuffled array would give a different\n", | |
"# split on every re-run (train/test leakage) and would also scramble the zero-copy\n", | |
"# tensors from earlier runs, because they alias the shuffled buffer.\n", | |
"cp.random.seed(42)\n", | |
"data_shuffled = data_cupy[cp.random.permutation(len(data_cupy))]\n", | |
"\n", | |
"split = int(0.75*len(data_shuffled))\n", | |
"# zero-copy to pytorch tensors using dlpack. .float() returns the same tensor for\n", | |
"# float32 data and casts (copies) otherwise, since the model's weights are float32.\n", | |
"trn_torch = dlpack.from_dlpack(data_shuffled[:split].toDlpack()).float()\n", | |
"tst_torch = dlpack.from_dlpack(data_shuffled[split:].toDlpack()).float()\n", | |
"\n", | |
"dim = trn_torch.shape[1]\n", | |
"# Build the model on the same GPU that already holds the data.\n", | |
"model = LightVAE(dim).to(trn_torch.device)\n", | |
"optimizer = torch.optim.Adam(model.parameters())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"====> Epoch: 0 Average loss: 0.0186\n", | |
"====> Epoch: 1 Average loss: 0.0077\n", | |
"====> Epoch: 2 Average loss: 0.0066\n", | |
"====> Epoch: 3 Average loss: 0.0063\n", | |
"====> Epoch: 4 Average loss: 0.0061\n", | |
"====> Epoch: 5 Average loss: 0.0060\n", | |
"====> Epoch: 6 Average loss: 0.0059\n", | |
"====> Epoch: 7 Average loss: 0.0059\n", | |
"====> Epoch: 8 Average loss: 0.0058\n", | |
"====> Epoch: 9 Average loss: 0.0058\n" | |
] | |
} | |
], | |
"source": [ | |
"# let's train a VAE\n", | |
"NUM_EPOCHS = 10\n", | |
"BATCH_SIZE = 1024\n", | |
"\n", | |
"trn_loader = torch.utils.data.DataLoader(trn_torch, batch_size=BATCH_SIZE, shuffle=True)\n", | |
"tst_loader = torch.utils.data.DataLoader(tst_torch, batch_size=BATCH_SIZE, shuffle=False)\n", | |
"\n", | |
"model.train()\n", | |
"for epoch in range(NUM_EPOCHS):\n", | |
"    trn_loss = 0.0\n", | |
"    for data in trn_loader:\n", | |
"        optimizer.zero_grad()\n", | |
"        recon_batch, mu, logvar = model(data)\n", | |
"        loss = loss_function(recon_batch, data, mu, logvar)\n", | |
"        loss.backward()\n", | |
"        trn_loss += loss.item()\n", | |
"        optimizer.step()\n", | |
"\n", | |
"    # loss_function sums over the batch, so dividing by the dataset size gives a per-sample mean.\n", | |
"    print('====> Epoch: {} Average loss: {:.4f}'.format(\n", | |
"        epoch, trn_loss / len(trn_loader.dataset)))\n", | |
"\n", | |
"# Evaluate once on the held-out split. reparameterize still samples z in eval mode,\n", | |
"# so this estimate is stochastic.\n", | |
"model.eval()\n", | |
"tst_loss = 0.0\n", | |
"with torch.no_grad():\n", | |
"    for data in tst_loader:\n", | |
"        recon_batch, mu, logvar = model(data)\n", | |
"        tst_loss += loss_function(recon_batch, data, mu, logvar).item()\n", | |
"print('====> Test set loss: {:.4f}'.format(tst_loss / len(tst_loader.dataset)))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment