Created
June 6, 2019 05:02
-
-
Save bdhammel/ca7c12ccb24e326a8521594a7f7ef208 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision | |
import torchvision.transforms as transforms | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
from tqdm import tqdm | |
import argparse | |
# Compute device shared by the whole script: the first CUDA GPU when PyTorch
# reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
class Net(nn.Module):
    """Small LeNet-style CNN: two conv/ReLU/max-pool stages and three linear layers.

    Takes 3-channel images and returns unnormalized logits of shape
    ``(batch, 10)``. The ``16 * 5 * 5`` input size of ``fc1`` assumes 32x32
    images (e.g. CIFAR-10): 32 -> conv 28 -> pool 14 -> conv 10 -> pool 5.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)     # 3 -> 6 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)      # 2x2 max-pool, stride 2; reused after both convs
        self.conv2 = nn.Conv2d(6, 16, 5)    # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)        # one logit per class

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten each sample while keeping the batch dimension intact. The
        # previous view(-1, 16 * 5 * 5) could silently fold spatial elements
        # into the batch dimension for wrongly sized inputs. It also fails on
        # non-contiguous tensors. With flatten, a size mismatch surfaces as an
        # error in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
def get_data(args):
    """Create the CIFAR-10 training and test DataLoaders.

    Both splits are stored under ``./data`` (downloaded if missing), converted
    with ``ToTensor`` and normalized per channel with mean 0.5 and std 0.5.

    Args:
        args: parsed command-line namespace; only ``args.batch_size`` is read.

    Returns:
        Tuple ``(trainloader, testloader)``. Only the training loader shuffles;
        both use ``num_workers=2``.
    """
    channel_stats = (0.5, 0.5, 0.5)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_stats, channel_stats),
    ])

    def _build_loader(is_train):
        # Shuffle only the training split; evaluation order stays fixed.
        dataset = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=preprocess)
        return torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=is_train, num_workers=2)

    train_loader = _build_loader(True)
    test_loader = _build_loader(False)
    return train_loader, test_loader
def train(net, trainloader, criterion, optimizer, pbar, device=None):
    """Run one training epoch of ``net`` over ``trainloader``.

    Args:
        net: model to optimize; switched to training mode before the loop.
            Assumed to already live on ``device`` (TODO confirm at call sites).
        trainloader: iterable of batches where ``batch[0]`` are the inputs and
            ``batch[1]`` the targets.
        criterion: loss callable, ``criterion(outputs, labels) -> scalar tensor``.
        optimizer: optimizer over ``net``'s parameters; stepped once per batch.
        pbar: object with ``set_description``; after every batch it receives
            the running mean loss formatted as ``'loss: X.XXX'``.
        device: where each batch is moved; defaults to the module-level ``DEVICE``.

    Returns:
        Mean per-batch loss over the epoch as a float (0.0 if the loader is empty).
    """
    if device is None:
        device = DEVICE
    # Set the mode explicitly: the caller (or a previous evaluation) may have
    # left the model in eval mode, which changes dropout/batch-norm behaviour.
    net.train()
    running_loss = 0.0
    num_batches = 0
    for num_batches, batch in enumerate(trainloader, start=1):
        images, labels = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pbar.set_description(f'loss: {running_loss / num_batches:.3f}')
    return running_loss / num_batches if num_batches else 0.0
def test(net, testloader, device=None):
    """Measure and print the top-1 classification accuracy of ``net``.

    The model is put in eval mode for the pass and restored to its previous
    mode afterwards; gradients are disabled.

    Args:
        net: classifier returning per-class scores of shape ``(batch, classes)``.
        testloader: iterable of batches where ``batch[0]`` are the inputs and
            ``batch[1]`` the integer class labels.
        device: where each batch is moved; defaults to the module-level ``DEVICE``.

    Returns:
        Accuracy in percent, or None when ``testloader`` yields no samples
        (previously this case raised ZeroDivisionError).
    """
    if device is None:
        device = DEVICE
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in testloader:
                images, labels = batch[0].to(device), batch[1].to(device)
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Leave the model in the mode the caller had it in.
        net.train(was_training)
    if total == 0:
        print('Test Accuracy: n/a (empty test set)')
        return None
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy
def main():
    """Parse hyperparameters, train ``Net`` on CIFAR-10, then report test accuracy.

    Command-line options:
        --batch-size  mini-batch size for both loaders (default: 4)
        --epochs      number of passes over the training set (default: 10)
        --lr          SGD learning rate; momentum is fixed at 0.9 (default: 0.001)
    """
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Example')
    # %(default)s keeps the help text in sync with the actual default values.
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training and testing (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: %(default)s)')
    args = parser.parse_args()

    trainloader, testloader = get_data(args)

    net = Net()
    net.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)

    print("""
Hyperparameters
---------------
Batch size: {batch_size}
learning rate: {lr}
epochs: {epochs}
""".format(batch_size=args.batch_size, lr=args.lr, epochs=args.epochs))

    # Honour --epochs (previously hard-coded to 10); the bar ticks once per epoch
    # and train() writes the running loss into its description.
    pbar = tqdm(range(args.epochs), ascii=True)
    for _ in pbar:
        train(net, trainloader, criterion, optimizer, pbar)

    test(net, testloader)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment