{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Loading"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"batch_size = 64\n",
"\n",
"transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
"])\n",
"\n",
"\n",
"trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
"                                        download=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
"                                          shuffle=True, num_workers=2)\n",
"\n",
"testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
"                                       download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
"                                         shuffle=False, num_workers=2)\n",
"\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
"           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n"
]
},
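{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick sanity check on the loaders (a minimal sketch, assuming the cell above has run): with `ToTensor` followed by the `(0.5, 0.5, 0.5)` mean/std normalization, a training batch should have shape `64 x 3 x 32 x 32` with pixel values in `[-1, 1]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: inspect one training batch to confirm shapes and the normalized value range.\n",
"images, labels = next(iter(trainloader))\n",
"print(images.shape, labels.shape)  # expected: [64, 3, 32, 32] and [64]\n",
"print(images.min().item(), images.max().item())  # expected to lie within [-1, 1]\n"
]
},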
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Architectures\n",
"\n",
"## Traditional CNN"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class CNN(nn.Module):\n",
"    def __init__(self):\n",
"        super(CNN, self).__init__()\n",
"        # convolutional layer (sees 32x32x3 image tensor)\n",
"        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)\n",
"        # convolutional layer (sees 16x16x8 tensor)\n",
"        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)\n",
"        # convolutional layer (sees 8x8x16 tensor)\n",
"        self.conv3 = nn.Conv2d(16, 16, 3, padding=1)\n",
"        # max pooling layer\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        # linear layer (16 * 4 * 4 -> 100)\n",
"        self.fc1 = nn.Linear(16 * 4 * 4, 100)\n",
"        # linear layer (100 -> 10)\n",
"        self.fc2 = nn.Linear(100, 10)\n",
"        # dropout layer (p=0.25)\n",
"        self.dropout = nn.Dropout(0.25)\n",
"\n",
"    def forward(self, x):\n",
"        # add sequence of convolutional and max pooling layers\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        x = self.pool(F.relu(self.conv3(x)))\n",
"        # flatten image input\n",
"        x = x.view(-1, 16 * 4 * 4)\n",
"        # add dropout layer\n",
"        x = self.dropout(x)\n",
"        # add 1st hidden layer, with relu activation function\n",
"        x = F.relu(self.fc1(x))\n",
"        # add dropout layer\n",
"        x = self.dropout(x)\n",
"        # output layer: raw logits (CrossEntropyLoss applies log-softmax)\n",
"        x = self.fc2(x)\n",
"        return x\n",
"\n",
"\n",
"cnn = CNN().to(device)"
]
},
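{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hard-coded `16 * 4 * 4` comes from the three pooling stages halving the spatial size from 32x32 to 16x16, 8x8, and finally 4x4. A minimal shape check (sketch, assuming the cells above have run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: push a dummy input through the conv/pool stack to confirm the flattened size.\n",
"with torch.no_grad():\n",
"    dummy = torch.zeros(1, 3, 32, 32, device=device)\n",
"    out = cnn.pool(F.relu(cnn.conv1(dummy)))\n",
"    out = cnn.pool(F.relu(cnn.conv2(out)))\n",
"    out = cnn.pool(F.relu(cnn.conv3(out)))\n",
"    print(out.shape)  # expected: torch.Size([1, 16, 4, 4])\n"
]
},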
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bayesian CNN"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from blitz.modules import BayesianLinear, BayesianConv2d\n",
"from blitz.utils import variational_estimator\n",
"\n",
"@variational_estimator\n",
"class BNN(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # convolutional layer (sees 32x32x3 image tensor)\n",
"        self.conv1 = BayesianConv2d(3, 8, (3,3), padding=1)\n",
"        # convolutional layer (sees 16x16x8 tensor)\n",
"        self.conv2 = BayesianConv2d(8, 16, (3,3), padding=1)\n",
"        # convolutional layer (sees 8x8x16 tensor)\n",
"        self.conv3 = BayesianConv2d(16, 16, (3,3), padding=1)\n",
"        # max pooling layer\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        # linear layer (16 * 4 * 4 -> 100)\n",
"        self.fc1 = BayesianLinear(16 * 4 * 4, 100)\n",
"        # linear layer (100 -> 10)\n",
"        self.fc2 = BayesianLinear(100, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # add sequence of convolutional and max pooling layers\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        x = self.pool(F.relu(self.conv3(x)))\n",
"        # flatten image input\n",
"        x = x.view(-1, 16 * 4 * 4)\n",
"        # add 1st hidden layer, with relu activation function\n",
"        x = F.relu(self.fc1(x))\n",
"        return self.fc2(x)"
]
},
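{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unlike the deterministic CNN, the Bayesian layers draw a fresh weight sample on each forward pass (BLiTZ's default behaviour, to my understanding), so repeated predictions on the same input generally differ. A quick illustration (sketch, using a throwaway untrained `BNN` instance):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: two forward passes on the same input should typically give different logits,\n",
"# because BayesianConv2d/BayesianLinear resample their weights each time.\n",
"with torch.no_grad():\n",
"    bnn_demo = BNN().to(device)\n",
"    x_demo = torch.zeros(1, 3, 32, 32, device=device)\n",
"    print(bnn_demo(x_demo))\n",
"    print(bnn_demo(x_demo))\n"
]
},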
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BNN + Softplus"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"@variational_estimator\n",
"class BNN_softplus(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # convolutional layer (sees 32x32x3 image tensor)\n",
"        self.conv1 = BayesianConv2d(3, 8, (3,3), padding=1)\n",
"        # convolutional layer (sees 16x16x8 tensor)\n",
"        self.conv2 = BayesianConv2d(8, 16, (3,3), padding=1)\n",
"        # convolutional layer (sees 8x8x16 tensor)\n",
"        self.conv3 = BayesianConv2d(16, 16, (3,3), padding=1)\n",
"        # max pooling layer\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        # linear layer (16 * 4 * 4 -> 100)\n",
"        self.fc1 = BayesianLinear(16 * 4 * 4, 100)\n",
"        # linear layer (100 -> 10)\n",
"        self.fc2 = BayesianLinear(100, 10)\n",
"\n",
"    def forward(self, x):\n",
"        # add sequence of convolutional and max pooling layers\n",
"        x = self.pool(F.softplus(self.conv1(x)))\n",
"        x = self.pool(F.softplus(self.conv2(x)))\n",
"        x = self.pool(F.softplus(self.conv3(x)))\n",
"        # flatten image input\n",
"        x = x.view(-1, 16 * 4 * 4)\n",
"        # add 1st hidden layer, with softplus activation function\n",
"        x = F.softplus(self.fc1(x))\n",
"        return self.fc2(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linear BNN"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"@variational_estimator\n",
"class BNN_Linear(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.fc1 = BayesianLinear(32 * 32 * 3, 100)\n",
"        self.fc2 = BayesianLinear(100, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = x.view(-1, 32 * 32 * 3)\n",
"        x = F.softplus(self.fc1(x))\n",
"        return self.fc2(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training\n",
"## Traditional CNN"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"cnn_optimizer = optim.SGD(cnn.parameters(), lr=0.01)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 \tTraining Loss: 2.300790\n",
"Epoch: 1 \tTraining Loss: 2.267595\n",
"Epoch: 2 \tTraining Loss: 2.110062\n",
"Epoch: 3 \tTraining Loss: 1.999698\n",
"Epoch: 4 \tTraining Loss: 1.904039\n",
"Epoch: 5 \tTraining Loss: 1.788206\n",
"Epoch: 6 \tTraining Loss: 1.677056\n",
"Epoch: 7 \tTraining Loss: 1.621456\n",
"Epoch: 8 \tTraining Loss: 1.577198\n",
"Epoch: 9 \tTraining Loss: 1.547966\n",
"Finished Training\n"
]
}
],
"source": [
"for epoch in range(10): # loop over the dataset multiple times\n",
"\n",
"    running_loss = 0.0\n",
"    for i, (inputs, labels) in enumerate(trainloader, 0):\n",
"        # move the inputs and labels to the device\n",
"        inputs, labels = inputs.to(device), labels.to(device)\n",
"\n",
"        # zero the parameter gradients\n",
"        cnn_optimizer.zero_grad()\n",
"\n",
"        # forward + backward + optimize\n",
"        outputs = cnn(inputs)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward()\n",
"        cnn_optimizer.step()\n",
"\n",
"        # accumulate the summed loss over the epoch\n",
"        running_loss += loss.item() * inputs.size(0)\n",
"\n",
"    # print training statistics\n",
"    running_loss = running_loss / len(trainloader.sampler)\n",
"    print('Epoch: {} \\tTraining Loss: {:.6f}'.format(epoch, running_loss))\n",
"print('Finished Training')\n",
"\n",
"CNN_PATH = './cifar_cnn.pth'\n",
"torch.save(cnn.state_dict(), CNN_PATH)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bayesian CNN"
]
},
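{
"cell_type": "markdown",
"metadata": {},
"source": [
"As I understand BLiTZ's `sample_elbo`, it draws `sample_nbr` weight samples $w_s \\sim q(w)$ and averages, over those samples, the data-fit loss (`criterion`, here cross-entropy) plus a Monte Carlo estimate of the KL complexity cost scaled by `complexity_cost_weight` $\\beta$ (set below to `1 / len(trainloader)` so the cost is spread across the batches of an epoch):\n",
"\n",
"$$\\mathcal{L} \\approx \\frac{1}{S} \\sum_{s=1}^{S} \\Big[ \\mathrm{CE}\\big(f_{w_s}(x), y\\big) + \\beta \\, \\big( \\log q(w_s) - \\log p(w_s) \\big) \\Big]$$"
]
},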
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"def train_bnn(net, optimizer):\n",
"    for epoch in range(10): # loop over the dataset multiple times\n",
"        iteration = total_loss = 0\n",
"        for i, (inputs, labels) in enumerate(trainloader, 0):\n",
"            # move the inputs and labels to the device\n",
"            inputs, labels = inputs.to(device), labels.to(device)\n",
"\n",
"            # zero the parameter gradients\n",
"            optimizer.zero_grad()\n",
"\n",
"            # forward + backward + optimize (ELBO estimated from 5 weight samples)\n",
"            loss = net.sample_elbo(inputs=inputs,\n",
"                                   labels=labels,\n",
"                                   criterion=criterion,\n",
"                                   sample_nbr=5,\n",
"                                   complexity_cost_weight=1 / len(trainloader))\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            # accumulate the loss as a plain float\n",
"            total_loss += loss.item()\n",
"            iteration += 1\n",
"\n",
"            if iteration % 50 == 0:\n",
"                print(\"Epoch: {}.\\t Loss: {:.4f}\".format(epoch, total_loss/iteration), end=\"\\r\")\n",
"\n",
"        # print training statistics\n",
"        print('Epoch: {} \\tTraining Loss: {:.6f}'.format(epoch, total_loss / len(trainloader)))\n",
"    print('Finished Training')\n"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 \tTraining Loss: 4.698672\n",
"Epoch: 1 \tTraining Loss: 0.584409\n",
"Epoch: 2 \tTraining Loss: 0.518806\n",
"Epoch: 3 \tTraining Loss: 0.494763\n",
"Epoch: 4 \tTraining Loss: 0.483292\n",
"Epoch: 5 \tTraining Loss: 0.476598\n",
"Epoch: 6 \tTraining Loss: 0.477477\n",
"Epoch: 7 \tTraining Loss: 0.473085\n",
"Epoch: 8 \tTraining Loss: 0.471659\n",
"Epoch: 9 \tTraining Loss: 0.470142\n",
"Finished Training\n"
]
}
],
"source": [
"bnn_conv = BNN().to(device)\n",
"bnn_conv_optimizer = optim.SGD(bnn_conv.parameters(), lr=0.001)\n",
"train_bnn(bnn_conv, bnn_conv_optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BNN + Softplus"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 \tTraining Loss: 4.018325\n",
"Epoch: 1 \tTraining Loss: 0.534618\n",
"Epoch: 2 \tTraining Loss: 0.496891\n",
"Epoch: 3 \tTraining Loss: 0.495052\n",
"Epoch: 4 \tTraining Loss: 0.476044\n",
"Epoch: 5 \tTraining Loss: 0.474423\n",
"Epoch: 6 \tTraining Loss: 0.472371\n",
"Epoch: 7 \tTraining Loss: 0.473131\n",
"Epoch: 8 \tTraining Loss: 0.469505\n",
"Epoch: 9 \tTraining Loss: 0.467890\n",
"Finished Training\n"
]
}
],
"source": [
"bnn_softplus = BNN_softplus().to(device)\n",
"bnn_softplus_optimizer = optim.SGD(bnn_softplus.parameters(), lr=0.001)\n",
"train_bnn(bnn_softplus, bnn_softplus_optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Linear BNN\n"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 0 \tTraining Loss: 0.717498\n",
"Epoch: 1 \tTraining Loss: 0.567976\n",
"Epoch: 2 \tTraining Loss: 0.529524\n",
"Epoch: 3 \tTraining Loss: 0.506957\n",
"Epoch: 4 \tTraining Loss: 0.490132\n",
"Epoch: 5 \tTraining Loss: 0.477800\n",
"Epoch: 6 \tTraining Loss: 0.467948\n",
"Epoch: 7 \tTraining Loss: 0.459822\n",
"Epoch: 8 \tTraining Loss: 0.452051\n",
"Epoch: 9 \tTraining Loss: 0.446135\n",
"Finished Training\n"
]
}
],
"source": [
"bnn_linear = BNN_Linear().to(device)\n",
"bnn_linear_optimizer = optim.SGD(bnn_linear.parameters(), lr=0.001)\n",
"train_bnn(bnn_linear, bnn_linear_optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test Performance "
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"def run_tests(net):\n",
"\n",
"    correct = total = 0\n",
"    class_correct = list(0. for i in range(10))\n",
"    class_total = list(0. for i in range(10))\n",
"    with torch.no_grad():\n",
"        for images, labels in testloader:\n",
"            images, labels = images.to(device), labels.to(device)\n",
"            outputs = net(images)\n",
"            _, predicted = torch.max(outputs, 1)\n",
"            total += labels.size(0)\n",
"            correct += (predicted == labels).sum().item()\n",
"            c = (predicted == labels).squeeze()\n",
"            # tally per-class results over the whole batch, not just its first 4 images\n",
"            for i in range(labels.size(0)):\n",
"                label = labels[i]\n",
"                class_correct[label] += c[i].item()\n",
"                class_total[label] += 1\n",
"\n",
"    print('Accuracy of the network on the 10000 test images: %d %%\\n' % (\n",
"        100 * correct / total))\n",
"    for i in range(10):\n",
"        print('Accuracy of %5s : %2d %%' % (\n",
"            classes[i], 100 * class_correct[i] / class_total[i]))"
]
},
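{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `run_tests` scores the Bayesian models with a single stochastic forward pass per batch. A common alternative is to average the predicted class probabilities over several weight samples before taking the argmax; the sketch below shows one way to do that (`run_tests_mc` and `n_samples` are names introduced here, and the cell was not part of the runs above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def run_tests_mc(net, n_samples=10):\n",
"    # Sketch: Monte Carlo evaluation that averages softmax outputs over n_samples\n",
"    # stochastic forward passes (each pass resamples the Bayesian weights).\n",
"    correct = total = 0\n",
"    with torch.no_grad():\n",
"        for images, labels in testloader:\n",
"            images, labels = images.to(device), labels.to(device)\n",
"            probs = torch.stack([F.softmax(net(images), dim=1)\n",
"                                 for _ in range(n_samples)]).mean(dim=0)\n",
"            _, predicted = torch.max(probs, 1)\n",
"            total += labels.size(0)\n",
"            correct += (predicted == labels).sum().item()\n",
"    print('MC accuracy on the 10000 test images: %d %%' % (100 * correct / total))\n",
"\n",
"# Example usage (not run here): run_tests_mc(bnn_conv)\n"
]
},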
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of the network on the 10000 test images: 43 %\n",
"\n",
"Accuracy of plane : 51 %\n",
"Accuracy of car : 54 %\n",
"Accuracy of bird : 18 %\n",
"Accuracy of cat : 23 %\n",
"Accuracy of deer : 23 %\n",
"Accuracy of dog : 45 %\n",
"Accuracy of frog : 62 %\n",
"Accuracy of horse : 40 %\n",
"Accuracy of ship : 58 %\n",
"Accuracy of truck : 52 %\n"
]
}
],
"source": [
"run_tests(cnn)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of the network on the 10000 test images: 10 %\n",
"\n",
"Accuracy of plane : 0 %\n",
"Accuracy of car : 84 %\n",
"Accuracy of bird : 0 %\n",
"Accuracy of cat : 0 %\n",
"Accuracy of deer : 0 %\n",
"Accuracy of dog : 0 %\n",
"Accuracy of frog : 0 %\n",
"Accuracy of horse : 1 %\n",
"Accuracy of ship : 15 %\n",
"Accuracy of truck : 0 %\n"
]
}
],
"source": [
"run_tests(bnn_conv)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of the network on the 10000 test images: 10 %\n",
"\n",
"Accuracy of plane : 0 %\n",
"Accuracy of car : 0 %\n",
"Accuracy of bird : 88 %\n",
"Accuracy of cat : 0 %\n",
"Accuracy of deer : 5 %\n",
"Accuracy of dog : 0 %\n",
"Accuracy of frog : 0 %\n",
"Accuracy of horse : 0 %\n",
"Accuracy of ship : 0 %\n",
"Accuracy of truck : 0 %\n"
]
}
],
"source": [
"run_tests(bnn_softplus)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of the network on the 10000 test images: 25 %\n",
"\n",
"Accuracy of plane : 48 %\n",
"Accuracy of car : 34 %\n",
"Accuracy of bird : 22 %\n",
"Accuracy of cat : 20 %\n",
"Accuracy of deer : 14 %\n",
"Accuracy of dog : 22 %\n",
"Accuracy of frog : 35 %\n",
"Accuracy of horse : 20 %\n",
"Accuracy of ship : 48 %\n",
"Accuracy of truck : 26 %\n"
]
}
],
"source": [
"run_tests(bnn_linear)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}