Skip to content

Instantly share code, notes, and snippets.

@anirudhshenoy
Created December 2, 2019 06:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anirudhshenoy/ee429f48be7a2fbec71312c970b741b4 to your computer and use it in GitHub Desktop.
Save anirudhshenoy/ee429f48be7a2fbec71312c970b741b4 to your computer and use it in GitHub Desktop.
CNN for MNIST
import torch
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
    """Small CNN classifier for single-channel 28x28 images (e.g. MNIST).

    Pipeline: conv(1->20, k=5) -> max-pool(2) -> ReLU
    -> conv(20->40, k=5) -> channel dropout -> max-pool(2) -> ReLU
    -> flatten (40 * 4 * 4 = 640) -> linear(640->150) -> ReLU -> dropout
    -> linear(150->10) -> log-softmax over the class dimension.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 40, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Spatial size: 28 -> conv5 -> 24 -> pool2 -> 12 -> conv5 -> 8 -> pool2 -> 4,
        # so the flattened conv features have 40 * 4 * 4 = 640 elements.
        self.fc1 = nn.Linear(640, 150)
        self.fc2 = nn.Linear(150, 10)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return per-class log-probabilities with shape ``(N, 10)``.

        Args:
            x: Input tensor; it is reshaped to ``(N, 1, 28, 28)``, so it is
                assumed to hold 784 values per sample (e.g. ``(N, 784)`` or
                ``(N, 1, 28, 28)``).
        """
        x = x.view(-1, 1, 28, 28)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Dropout is active only while the module is in training mode.
        x = F.dropout(x, training=self.training)
        # No ReLU on the output layer: logits must be allowed to go negative,
        # otherwise every negative class score collapses to 0 and the
        # log-softmax can no longer tell those classes apart.
        x = self.fc2(x)
        return self.log_softmax(x)
# Use the GPU when one is available; otherwise fall back to the CPU instead
# of failing outright (`.cuda()` raises on machines without CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment