Skip to content

Instantly share code, notes, and snippets.

Created May 10, 2019 03:14
Show Gist options
  • Save mmsamiei/e047234bdd65636eb65228eecfe5e359 to your computer and use it in GitHub Desktop.
Save mmsamiei/e047234bdd65636eb65228eecfe5e359 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "simple_cnn_mnist.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
"cells": [
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
"source": [
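"<a href=\"https://colab.research.google.com/gist/mmsamiei/e047234bdd65636eb65228eecfe5e359/simple_cnn_mnist.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"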
"cell_type": "code",
"metadata": {
"id": "VPQzc5ayY7PT",
"colab_type": "code",
"colab": {}
"source": [
"import torch\n",
"import torchvision.datasets"
"execution_count": 0,
"outputs": []
"cell_type": "code",
"metadata": {
"id": "XtpTyrRtZePs",
"colab_type": "code",
"colab": {}
"source": [
"import torchvision.transforms as transforms\n",
"\n",
"# ToTensor converts each image to a float tensor in [0, 1]; Normalize with\n",
"# mean 0 and std 1 then leaves the values unchanged.\n",
"trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0,), (1.0,))])\n",
"# download=True fetches the files into ./mnist on a fresh runtime; without it\n",
"# the constructor raises when the data is not already on disk.\n",
"MNIST_dataset = torchvision.datasets.MNIST('./mnist', train=True, transform=trans, download=True)"
"execution_count": 0,
"outputs": []
"cell_type": "code",
"metadata": {
"id": "6i4BcXkkZpFc",
"colab_type": "code",
"colab": {}
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"# Batches of 128 samples, served in dataset order (no shuffling is requested).\n",
"train_dl = DataLoader(MNIST_dataset, batch_size=128)"
"execution_count": 0,
"outputs": []
"cell_type": "code",
"metadata": {
"id": "dhQWM-x4bDdr",
"colab_type": "code",
"colab": {}
"source": [
"from torch import nn\n",
"class ImanNet(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(1, 3, 3)\n",
" self.conv2 = nn.Conv2d(3, 6, 3)\n",
" self.pool1 = nn.MaxPool2d(2)\n",
" self.conv3 = nn.Conv2d(6, 9, 5)\n",
" self.lin = nn.Linear(576, 10)\n",
" self.softmax = nn.Softmax(1)\n",
" def forward(self, x):\n",
" temp = self.conv1(x)\n",
" temp = nn.functional.relu(temp)\n",
" temp = self.conv2(temp)\n",
" temp = nn.functional.relu(temp)\n",
" temp = self.pool1(temp)\n",
" temp = self.conv3(temp)\n",
" temp = nn.functional.relu(temp)\n",
" temp = temp.view(-1, 576)\n",
" temp = self.lin(temp)\n",
" out = self.softmax(temp)\n",
" return out"
"execution_count": 0,
"outputs": []
"cell_type": "code",
"metadata": {
"id": "-WxMLSbMdBGD",
"colab_type": "code",
"colab": {}
"source": [
"test = torch.randn(1,1,28,28)"
"execution_count": 0,
"outputs": []
"cell_type": "code",
"metadata": {
"id": "VubQsjlDdavc",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 53
"outputId": "047da79c-4d35-41dc-ff68-5166980ec75a"
"source": [
"model = ImanNet()\n",
"# Forward the random sample through the untrained network; as the last\n",
"# expression of the cell, its value is what the cell displays.\n",
"model(test)"
"execution_count": 122,
"outputs": [
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0.1055, 0.0977, 0.1070, 0.0914, 0.0856, 0.1202, 0.0983, 0.0903, 0.0961,\n",
" 0.1079]], grad_fn=<SoftmaxBackward>)"
"metadata": {
"tags": []
"execution_count": 122
"cell_type": "code",
"metadata": {
"id": "zYRkP6dxe4ot",
"colab_type": "code",
"colab": {}
"source": [
"opt = torch.optim.Adam(model.parameters())"
"execution_count": 0,
"outputs": []
"cell_type": "code",
"metadata": {
"id": "uW4am5pSeMyW",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 70
"outputId": "a870b32c-4df1-4fb6-f590-37c54eb27f66"
"source": [
"# Per-batch loss values, kept for later inspection.\n",
"loss_arr = []\n",
"for epoch in range(3):\n",
"    for xb, yb in train_dl:\n",
"        out = model(xb)\n",
"        # `model` already returns softmax probabilities, while `cross_entropy`\n",
"        # expects raw logits and applies log-softmax itself. Feeding it\n",
"        # probabilities applies softmax twice, which keeps the loss above\n",
"        # log(1 + 9/e) ~= 1.46 and weakens the gradients. Take the log of the\n",
"        # probabilities and use the NLL loss instead; the clamp avoids log(0).\n",
"        loss = nn.functional.nll_loss(torch.log(out.clamp_min(1e-12)), yb)\n",
"        loss.backward()\n",
"        opt.step()\n",
"        opt.zero_grad()\n",
"        loss_arr.append(loss.item())\n",
"    # Reports the loss of the last batch of the epoch.\n",
"    print(\"epoch \",epoch,\", loss = \",loss)"
"execution_count": 124,
"outputs": [
"output_type": "stream",
"text": [
"epoch 0 , loss = tensor(1.6238, grad_fn=<NllLossBackward>)\n",
"epoch 1 , loss = tensor(1.6036, grad_fn=<NllLossBackward>)\n",
"epoch 2 , loss = tensor(1.5770, grad_fn=<NllLossBackward>)\n"
"name": "stdout"
"cell_type": "code",
"metadata": {
"id": "UwJkZ-81r-Z1",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
"outputId": "8131f829-7818-46f0-cef2-5f54c0ba114a"
"source": [
"correct = 0\n",
"total = 0\n",
"# NOTE: `train_dl` serves the training split, so this measures training\n",
"# accuracy even though the message below says \"test images\".\n",
"with torch.no_grad():\n",
"    for data in train_dl:\n",
"        images, labels = data\n",
"        out = model(images)\n",
"        # Index of the highest-scoring class for each sample.\n",
"        _, predicted = torch.max(out, 1)\n",
"        total += labels.size(0)\n",
"        correct += (predicted == labels).sum().item()\n",
"print('Accuracy of the network on the 60000 test images: %d %%' % (\n",
"    100 * correct / total))\n"
"execution_count": 125,
"outputs": [
"output_type": "stream",
"text": [
"Accuracy of the network on the 60000 test images: 86 %\n"
"name": "stdout"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment