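# Feed-forward classifier for MNIST, trained with the L-BFGS optimizer.
# (GitHub gist page navigation text captured along with this code was turned
# into this comment so the file parses as Python.)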
#Load packages
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
import torch.nn.functional as F
# Mini-batch size shared by both loaders. The original script used
# `batch_size` without ever defining it; 100 is a conventional MNIST choice.
batch_size = 100

# MNIST training split as float tensors (ToTensor); downloaded into `root`
# if not already present.
train_dataset = dsets.MNIST(root='/root/workspace/data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
# MNIST test split. torchvision's `train` flag defaults to True, so without
# train=False this would silently be the *training* split and the reported
# test accuracy would be measured on training data.
test_dataset = dsets.MNIST(root='/root/workspace/data',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=True)

# Shuffle only the training data; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
class FFN(nn.Module):
    """Three-layer fully connected classifier (784 -> 120 -> 84 -> 10 by default).

    The two hidden layers use ReLU. The output layer returns raw logits with
    no softmax, which is the input form ``nn.CrossEntropyLoss`` expects.

    Args:
        in_features: Size of each flattened input sample (default 28*28).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 28 * 28, hidden1: int = 120,
                 hidden2: int = 84, num_classes: int = 10):
        super(FFN, self).__init__()
        # Linear (affine) layers
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of inputs to class logits.

        ``x`` may already be flattened to ``(batch, in_features)``, or may be
        an image batch such as ``(batch, 1, 28, 28)``. Any tensor with more
        than two dimensions is flattened from dim 1 onward.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        # Non-linearity; F.relu could be swapped for e.g. torch.tanh or torch.sigmoid.
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        # Linear readout: raw logits, no activation applied.
        return self.fc3(out)
def _evaluate(net, loader):
    """Return the classification accuracy of *net* on *loader*, in percent.

    Runs without autograd and in eval mode, restoring the module's previous
    training/eval mode afterwards. Returns NaN if *loader* yields no samples.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            # Flatten each 28x28 image into a 784-vector for the input layer.
            outputs = net(images.view(-1, 28 * 28))
            # Predicted class = index of the largest logit.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train(was_training)
    return 100.0 * correct / total if total else float('nan')


model = FFN()
# Softmax + cross-entropy on the raw logits produced by the model.
criterion = nn.CrossEntropyLoss()
# NOTE(review): L-BFGS is designed for deterministic (full-batch) objectives;
# using it on shuffled mini-batches, as here, can be unstable. lr kept as-is.
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
epochs = 1
log_interval = 100  # print the training loss every this many mini-batches

for epoch in range(epochs):
    last_loss = float('nan')
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each 28x28 image into a 784-vector for the input layer.
        images = images.view(-1, 28 * 28)

        def closure():
            # L-BFGS may re-evaluate the objective several times per step, so
            # the closure must clear stale gradients, recompute the loss and
            # backpropagate. Without backward() there are no gradients and the
            # optimizer makes no progress.
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            return loss

        # Update parameters; LBFGS.step returns the loss from the first closure call.
        last_loss = optimizer.step(closure).item()
        if (i + 1) % log_interval == 0:
            print('Epoch: {}, Loss: {}'.format(epoch, last_loss))

    # Calculate accuracy on the test set after each epoch.
    accuracy = _evaluate(model, test_loader)
    print('Epoch: {}, Loss: {}, Accuracy on testset: {}'.format(epoch, last_loss, accuracy))
# (End of script. GitHub gist page footer text captured here was turned into this comment.)