@monajalal
Last active November 12, 2020 00:05
Trying to train a VAE on my custom images.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"Bad key savefig.frameon in file /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle, line 421 ('savefig.frameon : True')\n",
"You probably need to get an updated matplotlibrc file from\n",
"https://github.com/matplotlib/matplotlib/blob/v3.3.2/matplotlibrc.template\n",
"or from the matplotlib source distribution\n",
"\n",
"Bad key verbose.level in file /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle, line 472 ('verbose.level : silent # one of silent, helpful, debug, debug-annoying')\n",
"You probably need to get an updated matplotlibrc file from\n",
"https://github.com/matplotlib/matplotlib/blob/v3.3.2/matplotlibrc.template\n",
"or from the matplotlib source distribution\n",
"\n",
"Bad key verbose.fileo in file /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle, line 473 ('verbose.fileo : sys.stdout # a log filename, sys.stdout or sys.stderr')\n",
"You probably need to get an updated matplotlibrc file from\n",
"https://github.com/matplotlib/matplotlib/blob/v3.3.2/matplotlibrc.template\n",
"or from the matplotlib source distribution\n",
"In /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
"The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
"In /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
"The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
"In /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is deprecated since 3.3 and will be removed two minor releases later; use 'mathtext.fallback : 'cm' instead.\n",
"In /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
"The validate_bool_maybe_none function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
"In /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
"The savefig.jpeg_quality rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
"In /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
"The keymap.all_axes rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
"In /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
"The animation.avconv_path rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n",
"In /home/mona/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: \n",
"The animation.avconv_args rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.\n"
]
}
],
"source": [
"# https://github.com/TDehaene/blogposts/blob/master/vae_new_food/notebooks/vae_pytorch.ipynb\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"from skimage import io, transform\n",
"from torch import nn, optim\n",
"from torch.nn import functional as F\n",
"from torchvision import datasets, transforms\n",
"from torch.autograd import Variable\n",
"from torchvision.utils import save_image"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f9afba7c290>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch_size = 8\n",
"epochs = 50\n",
"no_cuda = False\n",
"seed = 1\n",
"log_interval = 50\n",
"\n",
"cuda = not no_cuda and torch.cuda.is_available()\n",
"\n",
"torch.manual_seed(seed)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"device is cuda and kwargs is {'num_workers': 1, 'pin_memory': True}\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if cuda else \"cpu\")\n",
"kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}\n",
"print('device is {} and kwargs is {}'.format(device, kwargs))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"train_root = 'labeled-data/train_moth'\n",
"val_root = 'labeled-data/val_moth'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"my_transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Resize((100,100))\n",
" ])\n",
"\n",
"train_loader_food = torch.utils.data.DataLoader(\n",
" datasets.ImageFolder(train_root, transform = my_transform),\n",
" batch_size = batch_size, shuffle=True, **kwargs)\n",
"\n",
"val_loader_food = torch.utils.data.DataLoader(\n",
" datasets.ImageFolder(val_root, transform = my_transform),\n",
" batch_size = batch_size, shuffle=True, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class VAE_CNN(nn.Module):\n",
" def __init__(self):\n",
" super(VAE_CNN, self).__init__()\n",
"\n",
" # Encoder\n",
" self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(16)\n",
" self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(32)\n",
" self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
" self.bn3 = nn.BatchNorm2d(64)\n",
" self.conv4 = nn.Conv2d(64, 16, kernel_size=3, stride=2, padding=1, bias=False)\n",
" self.bn4 = nn.BatchNorm2d(16)\n",
"\n",
" # Latent vectors mu and sigma\n",
" self.fc1 = nn.Linear(25 * 25 * 16, 2048)\n",
" self.fc_bn1 = nn.BatchNorm1d(2048)\n",
" self.fc21 = nn.Linear(2048, 2048)\n",
" self.fc22 = nn.Linear(2048, 2048)\n",
"\n",
" # Sampling vector\n",
" self.fc3 = nn.Linear(2048, 2048)\n",
" self.fc_bn3 = nn.BatchNorm1d(2048)\n",
" self.fc4 = nn.Linear(2048, 25 * 25 * 16)\n",
" self.fc_bn4 = nn.BatchNorm1d(25 * 25 * 16)\n",
"\n",
" # Decoder\n",
" self.conv5 = nn.ConvTranspose2d(16, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)\n",
" self.bn5 = nn.BatchNorm2d(64)\n",
" self.conv6 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False)\n",
" self.bn6 = nn.BatchNorm2d(32)\n",
" self.conv7 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)\n",
" self.bn7 = nn.BatchNorm2d(16)\n",
" self.conv8 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1, bias=False)\n",
"\n",
" self.relu = nn.ReLU()\n",
"\n",
" def encode(self, x):\n",
" conv1 = self.relu(self.bn1(self.conv1(x)))\n",
" conv2 = self.relu(self.bn2(self.conv2(conv1)))\n",
" conv3 = self.relu(self.bn3(self.conv3(conv2)))\n",
" conv4 = self.relu(self.bn4(self.conv4(conv3))).view(-1, 25 * 25 * 16)\n",
"\n",
" fc1 = self.relu(self.fc_bn1(self.fc1(conv4)))\n",
"\n",
" r1 = self.fc21(fc1)\n",
" r2 = self.fc22(fc1)\n",
" \n",
" return r1, r2\n",
"\n",
" def reparameterize(self, mu, logvar):\n",
" if self.training:\n",
" std = logvar.mul(0.5).exp_()\n",
" eps = Variable(std.data.new(std.size()).normal_())\n",
" return eps.mul(std).add_(mu)\n",
" else:\n",
" return mu\n",
"\n",
" def decode(self, z):\n",
" fc3 = self.relu(self.fc_bn3(self.fc3(z)))\n",
" fc4 = self.relu(self.fc_bn4(self.fc4(fc3))).view(-1, 16, 25, 25)\n",
"\n",
" conv5 = self.relu(self.bn5(self.conv5(fc4)))\n",
" conv6 = self.relu(self.bn6(self.conv6(conv5)))\n",
" conv7 = self.relu(self.bn7(self.conv7(conv6)))\n",
" return self.conv8(conv7).view(-1, 3, 100, 100)\n",
"\n",
" def forward(self, x):\n",
" mu, logvar = self.encode(x)\n",
" z = self.reparameterize(mu, logvar)\n",
" return self.decode(z), mu, logvar"
]
},
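{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape sanity check (a minimal sketch added for clarity; it assumes the cells above have run). Two stride-2 convolutions take a 100×100 input down to 25×25 with 16 channels, which is why `fc1` flattens to `25 * 25 * 16`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Throwaway model in eval mode (so BatchNorm uses running stats and\n",
"# reparameterize returns mu); confirms encoder/decoder shapes line up.\n",
"with torch.no_grad():\n",
"    _m = VAE_CNN().to(device).eval()\n",
"    _x = torch.randn(2, 3, 100, 100).to(device)\n",
"    _recon, _mu, _logvar = _m(_x)\n",
"    print(_recon.shape, _mu.shape, _logvar.shape)\n",
"    # expected: torch.Size([2, 3, 100, 100]) torch.Size([2, 2048]) torch.Size([2, 2048])"
]
},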
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class customLoss(nn.Module):\n",
" def __init__(self):\n",
" super(customLoss, self).__init__()\n",
" self.mse_loss = nn.MSELoss(reduction=\"sum\")\n",
"\n",
" def forward(self, x_recon, x, mu, logvar):\n",
" loss_MSE = self.mse_loss(x_recon, x)\n",
" loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
"\n",
" return loss_MSE + loss_KLD"
]
},
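{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, this is the standard VAE objective with a Gaussian prior: a sum-reduced reconstruction term plus the closed-form KL divergence between $\\mathcal{N}(\\mu, \\sigma^2)$ and $\\mathcal{N}(0, I)$, with `logvar` storing $\\log\\sigma^2$:\n",
"\n",
"$$\\mathcal{L} = \\lVert x - \\hat{x} \\rVert^2 - \\tfrac{1}{2} \\sum_j \\left(1 + \\log\\sigma_j^2 - \\mu_j^2 - \\sigma_j^2\\right)$$"
]
},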
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"model = VAE_CNN().to(device)\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"loss_mse = customLoss()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"val_losses = []\n",
"train_losses = []"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def train(epoch):\n",
" model.train()\n",
" train_loss = 0\n",
" for batch_idx, (data, _) in enumerate(train_loader_food):\n",
" data = data.to(device) \n",
" ##data = transforms.ToPILImage()(data) # with or without still get the pil / tensor error\n",
" optimizer.zero_grad()\n",
" recon_batch, mu, logvar = model(data)\n",
" loss = loss_mse(recon_batch, data, mu, logvar)\n",
" loss.backward()\n",
" train_loss += loss.item()\n",
" optimizer.step()\n",
" if batch_idx % log_interval == 0:\n",
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" epoch, batch_idx * len(data), len(train_loader_food.dataset),\n",
" 100. * batch_idx / len(train_loader_food),\n",
" loss.item() / len(data)))\n",
"\n",
" print('====> Epoch: {} Average loss: {:.4f}'.format(\n",
" epoch, train_loss / len(train_loader_food.dataset)))\n",
" train_losses.append(train_loss / len(train_loader_food.dataset))"
]
},
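{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the loss is sum-reduced, dividing the accumulated `train_loss` by `len(train_loader_food.dataset)` reports a per-sample average,\n",
"\n",
"$$\\bar{\\mathcal{L}} = \\frac{1}{N} \\sum_{\\text{batches}} \\mathcal{L}_{\\text{batch}},$$\n",
"\n",
"and the per-batch print divides by `len(data)` for the same reason."
]
},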
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def test(epoch):\n",
" model.eval()\n",
" test_loss = 0\n",
" with torch.no_grad():\n",
" for i, (data, _) in enumerate(val_loader_food):\n",
" data = data.to(device)\n",
" recon_batch, mu, logvar = model(data)\n",
" test_loss += loss_mse(recon_batch, data, mu, logvar).item()\n",
" if i == 0:\n",
" n = min(data.size(0), 8)\n",
" comparison = torch.cat([data[:n],\n",
" recon_batch.view(batch_size, 3, 100, 100)[:n]])\n",
" save_image(comparison.cpu(),\n",
" 'VAE_results/reconstruction_' + str(epoch) + '.png', nrow=n)\n",
"\n",
" test_loss /= len(val_loader_food.dataset)\n",
" print('====> Test set loss: {:.4f}'.format(test_loss))\n",
" val_losses.append(test_loss)"
]
},
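{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: `save_image` does not create missing directories, so `VAE_results/` has to exist before the loop below writes into it. A one-liner (assuming the directory was not created outside the notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.makedirs('VAE_results', exist_ok=True)"
]
},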
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"ename": "TypeError",
"evalue": "Caught TypeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py\", line 185, in _worker_loop\n data = fetcher.fetch(index)\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\", line 44, in fetch\n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\", line 44, in <listcomp>\n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/datasets/folder.py\", line 139, in __getitem__\n sample = self.transform(sample)\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py\", line 61, in __call__\n img = t(img)\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py\", line 244, in __call__\n return F.resize(img, self.size, self.interpolation)\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py\", line 319, in resize\n raise TypeError('img should be PIL Image. Got {}'.format(type(img)))\nTypeError: img should be PIL Image. Got <class 'torch.Tensor'>\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-18-37f467c4f834>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0msample\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2048\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-16-2c647e00a259>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(epoch)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mtrain_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader_food\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransforms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mToPILImage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 363\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 364\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 987\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_task_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 989\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_process_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 990\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 991\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_try_put_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_process_data\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 1012\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_try_put_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1013\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mExceptionWrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1014\u001b[0;31m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1015\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1016\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/_utils.py\u001b[0m in \u001b[0;36mreraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0;31m# (https://bugs.python.org/issue2651), so we work around it.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 394\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mKeyErrorMessage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 395\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m: Caught TypeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py\", line 185, in _worker_loop\n data = fetcher.fetch(index)\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\", line 44, in fetch\n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py\", line 44, in <listcomp>\n data = [self.dataset[idx] for idx in possibly_batched_index]\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/datasets/folder.py\", line 139, in __getitem__\n sample = self.transform(sample)\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py\", line 61, in __call__\n img = t(img)\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py\", line 244, in __call__\n return F.resize(img, self.size, self.interpolation)\n File \"/home/mona/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py\", line 319, in resize\n raise TypeError('img should be PIL Image. Got {}'.format(type(img)))\nTypeError: img should be PIL Image. Got <class 'torch.Tensor'>\n"
]
}
],
"source": [
"for epoch in range(1, epochs + 1):\n",
" train(epoch)\n",
" test(epoch)\n",
" with torch.no_grad():\n",
" sample = torch.randn(2, 2048).to(device)\n",
" sample = model.decode(sample).cpu()\n",
" save_image(sample.view(2, 3, 100, 100),\n",
" 'VAE_results/sample_' + str(epoch) + '.png')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@monajalal (Author) commented:

Changing my_transform to the following didn't solve the problem:

my_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Resize((100,100))
])
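
The likely fix: ImageFolder already yields PIL Images, so prepending ToPILImage() fails for the same reason (that transform expects a tensor or ndarray, not a PIL Image). With the torchvision this traceback comes from (pre-0.8, where Resize only accepts PIL Images), the resize has to happen before the tensor conversion. A minimal sketch:

my_transform = transforms.Compose([
    transforms.Resize((100, 100)),  # operate on the PIL Image first
    transforms.ToTensor(),          # then convert to a CHW float tensor in [0, 1]
])

On torchvision 0.8 and later, Resize also accepts tensors, so the original ordering would work there as well.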
