Skip to content

Instantly share code, notes, and snippets.

@MHenderson
Last active November 21, 2019 15:15
Show Gist options
  • Save MHenderson/9fe8692a661b78115f696572161eabe3 to your computer and use it in GitHub Desktop.
Save MHenderson/9fe8692a661b78115f696572161eabe3 to your computer and use it in GitHub Desktop.
PyTorch tutorial
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Net(\n",
" (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))\n",
" (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))\n",
" (fc1): Linear(in_features=576, out_features=120, bias=True)\n",
" (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
" (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
")\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Net(nn.Module):\n",
" \n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 6, 3)\n",
" self.conv2 = nn.Conv2d(6, 16, 3)\n",
" self.fc1 = nn.Linear(16 * 6 * 6, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
" \n",
" def forward(self, x):\n",
" x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
" x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
" x = x.view(-1, self.num_flat_features(x))\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
" \n",
" def num_flat_features(self, x):\n",
" size = x.size()[1:]\n",
" num_features = 1\n",
" for s in size:\n",
" num_features *= s\n",
" return num_features\n",
" \n",
"net = Net()\n",
"print(net)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10\n",
"torch.Size([6, 1, 3, 3])\n"
]
}
],
"source": [
"params = list(net.parameters())\n",
"print(len(params))\n",
"print(params[0].size())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.0080, 0.0062, 0.1065, -0.0166, -0.0649, -0.0224, -0.0546, 0.0132,\n",
" 0.1019, -0.1757]], grad_fn=<ThAddmmBackward>)\n"
]
}
],
"source": [
"input = torch.randn(1, 1, 32, 32)\n",
"out = net(input)\n",
"print(out)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"net.zero_grad()\n",
"out.backward(torch.randn(1, 10))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"output = net(input)\n",
"target = torch.randn(10)\n",
"target = target.view(1, -1)\n",
"criterion = nn.MSELoss()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.9649, grad_fn=<MseLossBackward>)\n"
]
}
],
"source": [
"loss = criterion(output, target)\n",
"print(loss)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<MseLossBackward object at 0x7f8a34878eb8>\n",
"<ThAddmmBackward object at 0x7f8a34878518>\n",
"<ExpandBackward object at 0x7f8a34878eb8>\n"
]
}
],
"source": [
"print(loss.grad_fn)\n",
"print(loss.grad_fn.next_functions[0][0])\n",
"print(loss.grad_fn.next_functions[0][0].next_functions[0][0])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"net.zero_grad()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"conv1.bias.grad before backward\n",
"tensor([0., 0., 0., 0., 0., 0.])\n"
]
}
],
"source": [
"print('conv1.bias.grad before backward')\n",
"print(net.conv1.bias.grad)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"loss.backward()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"conv1.bias.grad after backward\n",
"tensor([ 0.0214, 0.0103, -0.0089, -0.0115, -0.0027, 0.0058])\n"
]
}
],
"source": [
"print('conv1.bias.grad after backward')\n",
"print(net.conv1.bias.grad)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"learning_rate = 0.01\n",
"for f in net.parameters():\n",
" f.data.sub_(f.grad.data * learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.SGD(net.parameters(), lr = 0.01)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"optimizer.zero_grad()\n",
"output = net(input)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"loss = criterion(output, target)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"loss.backward()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment