Skip to content

Instantly share code, notes, and snippets.

@gravitino
Last active Aug 25, 2021
Embed
What would you like to do?
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training a Variational Autoencoder (VAE)\n",
"\n",
"In the last step, we train a Variational Autencoder (VAE) using Pytorch. Hence, let use first install the package torch via pip. That might take a while."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Installing PyTorch...\n",
"\u001b[33mWARNING: Running pip as root will break packages and permissions. You should install packages reliably by using venv: https://pip.pypa.io/warnings/venv\u001b[0m\n"
]
}
],
"source": [
"print(\"Installing PyTorch...\")\n",
"!pip -q install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Subsequently, we define the network topology. Here, we use a convolutional version but you could also experiment with a classical [MLP VAE](https://github.com/pytorch/examples/blob/master/vae/main.py)."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class Swish(torch.nn.Module): \n",
"\n",
" def __init__(self):\n",
" super().__init__() \n",
" self.alpha = torch.nn.Parameter(torch.tensor([1.0], requires_grad=True))\n",
" \n",
" def forward(self, x):\n",
" return x*torch.sigmoid(self.alpha.to(x.device)*x)\n",
"\n",
"class Downsample1d(torch.nn.Module): \n",
" \n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" self.filter = torch.tensor([1.0, 2.0, 1.0]).view(1, 1, 3)\n",
" \n",
" def forward(self, x):\n",
" w = torch.cat([self.filter]*x.shape[1], dim=0).to(x.device)\n",
" return torch.nn.functional.conv1d(x, w, stride=2, padding=1, groups=x.shape[1])\n",
"\n",
"class LightVAE(torch.nn.Module):\n",
" def __init__(self, num_dims):\n",
" super(LightVAE, self).__init__()\n",
" \n",
" self.num_dims = num_dims\n",
" assert num_dims & num_dims-1 == 0, \"num_dims must be power of 2\"\n",
" \n",
" self.down = Downsample1d()\n",
" self.up = torch.nn.Upsample(scale_factor=2)\n",
" self.sigma = Swish()\n",
" \n",
" self.conv0 = torch.nn.Conv1d(1, 2, kernel_size=3, stride=1, padding=1)\n",
" self.conv1 = torch.nn.Conv1d(2, 4, kernel_size=3, stride=1, padding=1)\n",
" self.conv2 = torch.nn.Conv1d(4, 8, kernel_size=3, stride=1, padding=1)\n",
" self.convA = torch.nn.Conv1d(8, 2, kernel_size=3, stride=1, padding=1)\n",
" self.convB = torch.nn.Conv1d(8, 2, kernel_size=3, stride=1, padding=1)\n",
"\n",
" self.restore = torch.nn.Linear(2, 8*num_dims//8)\n",
" \n",
" self.conv3 = torch.nn.Conv1d( 8, 4, kernel_size=3, stride=1, padding=1)\n",
" self.conv4 = torch.nn.Conv1d( 4, 2, kernel_size=3, stride=1, padding=1)\n",
" self.conv5 = torch.nn.Conv1d( 2, 1, kernel_size=3, stride=1, padding=1)\n",
" \n",
" def encode(self, x):\n",
" \n",
" x = x.view(-1, 1, self.num_dims)\n",
" x = self.down(self.sigma(self.conv0(x)))\n",
" x = self.down(self.sigma(self.conv1(x)))\n",
" x = self.down(self.sigma(self.conv2(x)))\n",
" \n",
" return torch.mean(self.convA(x), dim=(2,)), \\\n",
" torch.mean(self.convB(x), dim=(2,))\n",
"\n",
" def reparameterize(self, mu, logvar):\n",
" \n",
" std = torch.exp(0.5*logvar)\n",
" eps = torch.randn_like(std)\n",
" \n",
" return mu + eps*std\n",
"\n",
" def decode(self, z):\n",
" \n",
" x = self.restore(z).view(-1, 8, self.num_dims//8)\n",
" x = self.sigma(self.conv3(self.up(x)))\n",
" x = self.sigma(self.conv4(self.up(x))) \n",
" \n",
" return self.conv5(self.up(x)).view(-1, self.num_dims)\n",
"\n",
" def forward(self, x):\n",
" mu, logvar = self.encode(x)\n",
" z = self.reparameterize(mu, logvar)\n",
" return self.decode(z), mu, logvar\n",
" \n",
"# Reconstruction + KL divergence losses summed over all elements and batch\n",
"def loss_function(recon_x, x, mu, logvar):\n",
" MSE = torch.sum(torch.mean(torch.square(recon_x-x), dim=1))\n",
"\n",
" # see Appendix B from VAE paper:\n",
" # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014\n",
" # https://arxiv.org/abs/1312.6114\n",
" # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
" KLD = -0.1 * torch.sum(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))\n",
"\n",
" return MSE + KLD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pytorch expects its dedicated tensor type and thus we need to map the CuPy array data_cupy to a FloatTensor. We perform that again using zero-copy functionality via DLPack. The remaining code is plain Pytorch program that trains the VAE on the training set for 10 epochs using the Adam optimizer."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# zero-copy to pytorch tensors using dlpack\n",
"from torch.utils import dlpack\n",
"\n",
"cp.random.seed(42)\n",
"cp.random.shuffle(data_cupy)\n",
"\n",
"split = int(0.75*len(data_cupy))\n",
"trn_torch = dlpack.from_dlpack(data_cupy[:split].toDlpack())\n",
"tst_torch = dlpack.from_dlpack(data_cupy[split:].toDlpack())\n",
"\n",
"dim = trn_torch.shape[1]\n",
"model = LightVAE(dim).to('cuda')\n",
"optimizer = torch.optim.Adam(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"====> Epoch: 0 Average loss: 0.0186\n",
"====> Epoch: 1 Average loss: 0.0077\n",
"====> Epoch: 2 Average loss: 0.0066\n",
"====> Epoch: 3 Average loss: 0.0063\n",
"====> Epoch: 4 Average loss: 0.0061\n",
"====> Epoch: 5 Average loss: 0.0060\n",
"====> Epoch: 6 Average loss: 0.0059\n",
"====> Epoch: 7 Average loss: 0.0059\n",
"====> Epoch: 8 Average loss: 0.0058\n",
"====> Epoch: 9 Average loss: 0.0058\n"
]
}
],
"source": [
"# let's train a VAE\n",
"NUM_EPOCHS = 10\n",
"BATCH_SIZE = 1024\n",
"\n",
"trn_loader = torch.utils.data.DataLoader(trn_torch, batch_size=BATCH_SIZE, shuffle=True)\n",
"tst_loader = torch.utils.data.DataLoader(tst_torch, batch_size=BATCH_SIZE, shuffle=False)\n",
"\n",
"model.train()\n",
"for epoch in range(NUM_EPOCHS):\n",
" trn_loss = 0.0\n",
" for data in trn_loader:\n",
" optimizer.zero_grad()\n",
" recon_batch, mu, logvar = model(data)\n",
" loss = loss_function(recon_batch, data, mu, logvar)\n",
" loss.backward()\n",
" trn_loss += loss.item()\n",
" optimizer.step()\n",
" \n",
" print('====> Epoch: {} Average loss: {:.4f}'.format(\n",
" epoch, trn_loss / len(trn_loader.dataset)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment