@vinupriyesh
Created February 4, 2018 18:44
Hello world pytorch network
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Problem statement \n",
"Design a simple system that counts how many ones are present in an input of size 2.\n",
"\n",
"#### Input: \n",
"- Size: 2 nodes\n",
"- Values: 0 or 1\n",
"\n",
"#### Output:\n",
"- Size: 3 nodes\n",
"- Values: 0 or 1, with exactly one node set to 1 (one-hot)\n",
"  - Node 1 is selected when neither input is 1\n",
"  - Node 2 is selected when exactly one input is 1\n",
"  - Node 3 is selected when both inputs are 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Solution \n",
"First, import the required modules."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable  # deprecated since PyTorch 0.4; still works but returns a plain Tensor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next step is to define the model class: 2 inputs, a hidden layer of 3 tanh units, and 3 sigmoid output nodes."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class SimpleModel(nn.Module):\n",
"    def __init__(self):\n",
"        super(SimpleModel, self).__init__()\n",
"        self.fc1 = nn.Linear(2, 3)  # 2 inputs -> 3 hidden units\n",
"        self.fc2 = nn.Linear(3, 3)  # 3 hidden units -> 3 output nodes\n",
"\n",
"    def forward(self, x):\n",
"        z1 = self.fc1(x)\n",
"        a1 = torch.tanh(z1)\n",
"        z2 = self.fc2(a1)\n",
"        a2 = torch.sigmoid(z2)  # squash each output node into the 0..1 range\n",
"        return a2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now define helpers that generate a random input and its expected one-hot output."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def generate_input():\n",
"    # two random bits; each entry is 1 with overall probability 0.5\n",
"    x = torch.Tensor(1, 2).uniform_(0, 1)\n",
"    return torch.bernoulli(x)\n",
"\n",
"\n",
"def generate_output(x):\n",
"    # one-hot target: the index of the hot node is the number of ones in x\n",
"    y = torch.zeros(1, 3)\n",
"    sum_x = x.sum()\n",
"    if sum_x == 0:\n",
"        y[0, 0] = 1\n",
"    elif sum_x == 1:\n",
"        y[0, 1] = 1\n",
"    else:\n",
"        y[0, 2] = 1\n",
"    return y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check that the input and output generators work as expected:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" 0 1\n",
"[torch.FloatTensor of size 1x2]\n",
"\n",
"\n",
" 0 1 0\n",
"[torch.FloatTensor of size 1x3]\n",
"\n"
]
}
],
"source": [
"a = generate_input()\n",
"print(a)\n",
"b = generate_output(a)\n",
"print(b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Everything is ready: create the model instance, a loss function, an optimizer, and a loop to train the model."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss on iteration 0 is 0.4865916669368744\n",
"Loss on iteration 100 is 0.2068905383348465\n",
"Loss on iteration 200 is 0.10168895125389099\n",
"Loss on iteration 300 is 0.12561121582984924\n",
"Loss on iteration 400 is 0.048275481909513474\n",
"Loss on iteration 500 is 0.0448077954351902\n",
"Loss on iteration 600 is 0.004153769463300705\n",
"Loss on iteration 700 is 0.002344121690839529\n",
"Loss on iteration 800 is 0.0010068188421428204\n",
"Loss on iteration 900 is 0.0005072578205727041\n"
]
}
],
"source": [
"nn_model = SimpleModel()\n",
"criterion = nn.BCELoss()  # Binary Cross Entropy Loss\n",
"optimizer = torch.optim.Adam(nn_model.parameters(), lr=0.01)\n",
"\n",
"for i in range(1000):\n",
"    x = generate_input()\n",
"    y = generate_output(x)\n",
"    x = Variable(x)\n",
"    y = Variable(y)\n",
"    y_hat = nn_model(x)\n",
"    loss = criterion(y_hat, y)\n",
"    optimizer.zero_grad()  # clear gradients left over from the previous iteration\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    if i % 100 == 0:\n",
"        print(\"Loss on iteration {} is {}\".format(i, loss.item()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, check that the model selects the correct output node for every possible input."
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All done\n"
]
}
],
"source": [
"def check_model():\n",
"    # (input, expected output node): the node index is the number of ones\n",
"    cases = [([0, 0], 0), ([0, 1], 1), ([1, 0], 1), ([1, 1], 2)]\n",
"    for inputs, node in cases:\n",
"        x = Variable(torch.FloatTensor(inputs))\n",
"        y = nn_model(x).round().int()\n",
"        # the expected node must be 1 and every other node 0\n",
"        assert y.data[node] == 1 and y.data.sum() == 1\n",
"\n",
"check_model()\n",
"print(\"All done\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
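The notebook's `generate_input`/`generate_output` pair builds one sample at a time. As a hedged sketch (the `generate_batch` helper is hypothetical, not part of the notebook), the same kind of data can be produced for a whole batch with `torch.randint` and `torch.nn.functional.one_hot`. Because the number of ones in each row is exactly the index of the output node to select, the count can be fed straight into `one_hot`:

```python
import torch
import torch.nn.functional as F


def generate_batch(batch_size):
    """Hypothetical batched version of generate_input/generate_output."""
    # Each row holds two random bits (0 or 1), cast to float for nn.Linear.
    x = torch.randint(0, 2, (batch_size, 2)).float()
    # The number of ones per row (0, 1 or 2) is the index of the output node to select.
    counts = x.sum(dim=1).long()
    y = F.one_hot(counts, num_classes=3).float()
    return x, y


x, y = generate_batch(4)
print(x)
print(y)
```

Each row of `y` contains a single 1 at the position given by the count of ones in the matching row of `x`, which matches the output specification in the problem statement.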
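The notebook was written against an early PyTorch release: its printed tensors use the pre-0.4 `[torch.FloatTensor of size 1x2]` format. Below is a minimal, self-contained sketch of the same 2-3-3 model with `BCELoss` and Adam for current PyTorch versions. It calls `optimizer.zero_grad()` before every backward pass so gradients do not accumulate across iterations, and it reads the loss with `loss.item()`. Some details are assumptions rather than the author's setup: it trains full-batch on all four possible inputs instead of one random sample per iteration, and the `train` helper, step count and seed are illustrative choices.

```python
import itertools

import torch
from torch import nn


class SimpleModel(nn.Module):
    # Same 2 -> 3 -> 3 architecture as the notebook: tanh hidden layer, sigmoid outputs.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 3)

    def forward(self, x):
        return torch.sigmoid(self.fc2(torch.tanh(self.fc1(x))))


def train(steps=2000, lr=0.01, seed=0):
    """Hypothetical helper: trains SimpleModel on all four inputs, returns (model, losses)."""
    torch.manual_seed(seed)  # assumption: fixed seed for repeatable runs
    model = SimpleModel()
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Assumption: full batch of every possible input instead of one random sample per step.
    x = torch.tensor(list(itertools.product([0.0, 1.0], repeat=2)))
    y = nn.functional.one_hot(x.sum(dim=1).long(), num_classes=3).float()

    losses = []
    for step in range(steps):
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())  # .item() turns the 0-dim loss tensor into a Python float
        if step % 500 == 0:
            print(f"Loss on iteration {step} is {loss.item():.6f}")
    return model, losses


model, losses = train()
```

Training on all four inputs at once gives a smoother loss curve than single random samples. On a problem this small, each step is still cheap.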
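The notebook's `check_model` rounds each sigmoid output and asserts on the result. That check fails outright whenever an output sits near 0.5. A hedged alternative is to treat the output node with the largest value as the prediction and report the fraction of inputs classified correctly. The sketch below uses a hypothetical `accuracy` helper that works with any callable mapping a `(4, 2)` float tensor to a `(4, 3)` output. It runs under `torch.no_grad()` because evaluation needs no gradients.

```python
import itertools

import torch
import torch.nn.functional as F


def accuracy(model):
    """Hypothetical helper: fraction of the four possible inputs mapped to the correct node."""
    inputs = torch.tensor(list(itertools.product([0.0, 1.0], repeat=2)))
    expected = inputs.sum(dim=1).long()  # correct node index = number of ones
    with torch.no_grad():  # evaluation only, no gradients needed
        predicted = model(inputs).argmax(dim=1)  # node with the largest output wins
    return (predicted == expected).float().mean().item()


# Usage with a stand-in "perfect" model that returns the exact one-hot target.
def oracle(x):
    return F.one_hot(x.sum(dim=1).long(), num_classes=3).float()


print(accuracy(oracle))  # 1.0
```

In the notebook this could be called as `accuracy(nn_model)`. A value of 1.0 means every input is mapped to the correct output node.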