Skip to content

Instantly share code, notes, and snippets.

@alperyeg
Last active October 9, 2019 18:14
Show Gist options
  • Save alperyeg/5c5fe6d2515963ffff03d127f95be275 to your computer and use it in GitHub Desktop.
Save alperyeg/5c5fe6d2515963ffff03d127f95be275 to your computer and use it in GitHub Desktop.
simple network in pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import optim
from torchvision import datasets, transforms
# Fix the RNG seed so weight initialisation and batch shuffling are repeatable.
torch.manual_seed(1)
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Extra DataLoader options that are only useful when feeding a GPU.
# The check must use ``device.type``: a ``torch.device`` object does not
# compare equal to the plain string 'cuda', so ``device == 'cuda'`` would
# always select the empty dict.
kwargs = {'num_workers': 1, 'pin_memory': True} if device.type == 'cuda' else {}
# MNIST data pipeline. ``ToTensor`` converts each image to a float tensor
# with values scaled to [0, 1].

# Training split: downloaded into the current directory if not present.
mnist_train = datasets.MNIST(
    root='.',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader_mnist = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=128,
    shuffle=True,
    **kwargs
)

# Test split: no download flag, so the files must already be under root.
mnist_test = datasets.MNIST(
    root='.',
    train=False,
    transform=transforms.ToTensor(),
)
test_loader_mnist = torch.utils.data.DataLoader(
    mnist_test,
    batch_size=128,
    shuffle=True,
    **kwargs
)
class Net(nn.Module):
    """Fully connected classifier: 784 inputs -> 500 hidden units -> 10 outputs.

    ``forward`` returns raw, unnormalised class scores (logits). It does not
    apply ReLU or softmax to the output layer, because ``nn.CrossEntropyLoss``
    applies log-softmax itself; feeding it already-normalised probabilities
    squashes the loss range and weakens the gradients. Apply
    ``F.softmax(logits, dim=1)`` at the call site if probabilities are needed;
    ``argmax`` over the logits gives the predicted class directly.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Affine layers, y = Wx + b.
        self.fc1 = nn.Linear(784, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10) for input ``x``.

        ``x`` is reshaped to (-1, 784), so any tensor whose element count is
        a multiple of 784 is accepted.
        """
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        # No activation on the output layer: the loss expects logits.
        return self.fc2(x)
# Build the model and place its parameters on the selected device.
net = Net()
net = net.to(device)
print(net)

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=net.parameters(), lr=0.001)

# Cross-entropy classification loss.
criterion = nn.CrossEntropyLoss()
def train(epoch):
    """Run one optimisation pass over ``train_loader_mnist``.

    Uses the module-level ``net``, ``optimizer``, ``criterion`` and
    ``device``. Prints the batch loss every 10th batch and, at the end, the
    loss averaged over batches.

    Args:
        epoch: Epoch number; only used in the progress messages.
    """
    net.train()
    train_loss = 0.0
    for idx, (img, target) in enumerate(train_loader_mnist):
        # The model lives on ``device``; its inputs must be there as well,
        # otherwise the forward pass fails when running on a GPU.
        img = img.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        # network prediction for the image
        output = net(img)
        # calculate the loss
        loss = criterion(output, target)
        # backprop
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        if idx % 10 == 0:
            print('Loss {} in epoch {}, idx {}'.format(
                batch_loss, epoch, idx))

    # ``criterion`` already averages within each batch, so the epoch average
    # is the sum divided by the number of batches, not the number of samples.
    num_batches = max(len(train_loader_mnist), 1)
    print('Average loss: {} epoch:{}'.format(
        train_loss / num_batches, epoch))
def test(epoch):
    """Evaluate ``net`` on ``test_loader_mnist`` without tracking gradients.

    Uses the module-level ``net``, ``criterion`` and ``device``. Prints the
    batch loss every 10th batch and, at the end, the accuracy in percent
    together with the loss averaged over batches.

    Args:
        epoch: Epoch number; only used in the progress messages.
    """
    net.eval()
    test_accuracy = 0
    test_loss = 0.0
    with torch.no_grad():
        for idx, (img, target) in enumerate(test_loader_mnist):
            # Inputs must be on the same device as the model.
            img = img.to(device)
            target = target.to(device)

            output = net(img)
            loss = criterion(output, target)
            batch_loss = loss.item()
            test_loss += batch_loss
            # network prediction: index of the highest score per sample
            pred = output.argmax(1, keepdim=True)
            # count correctly classified images by comparing with the targets
            test_accuracy += pred.eq(target.view_as(pred)).sum().item()
            if idx % 10 == 0:
                print('Test Loss {} in epoch {}, idx {}'.format(
                    batch_loss, epoch, idx))

    # Accuracy is counted per sample, so it is normalised by dataset size.
    # The loss is a sum of per-batch means, so it is normalised by the
    # number of batches instead.
    num_samples = max(len(test_loader_mnist.dataset), 1)
    num_batches = max(len(test_loader_mnist), 1)
    print('Test accuracy: {} Average test loss: {} epoch:{}'.format(
        100 * test_accuracy / num_samples,
        test_loss / num_batches, epoch))
if __name__ == "__main__":
    # Train for two epochs (numbered 1 and 2), then evaluate once.
    n_epochs = 2
    for ep in range(1, n_epochs + 1):
        train(ep)
    print('training done')
    test(n_epochs)
    print('test done')

    print('saving network weigts')
    state_dict = net.state_dict()
    # ``.cpu()`` is required before ``.numpy()`` when the model is on a GPU;
    # it is a no-op for tensors that are already on the CPU.
    weights = [state_dict[key].cpu().numpy()
               for key in ('fc1.weight', 'fc2.weight')]
    # The two matrices have different shapes (500x784 and 10x500), so they
    # cannot form a regular ndarray. Fill an explicit 1-D object array
    # instead; load it back with ``np.load(..., allow_pickle=True)``.
    stacked = np.empty(len(weights), dtype=object)
    for i, matrix in enumerate(weights):
        stacked[i] = matrix
    np.save('./weights.npy', stacked)
    # Also keep the full state dict (weights and biases) in PyTorch format.
    torch.save(net.state_dict(), 'net.pt')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment