import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import argparse

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
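

# LeNet-style CNN for 32x32 RGB CIFAR-10 inputs: two conv+pool blocks
# followed by three fully connected layers producing 10 class logits.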
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
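

# Build CIFAR-10 train/test DataLoaders; the dataset is downloaded to ./data
# on first use and each channel is normalized from [0, 1] to [-1, 1].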
def get_data(args):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                              shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader
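

# Run one training epoch and show the running average loss on the shared tqdm bar.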
def train(net, trainloader, criterion, optimizer, pbar):
    running_loss = 0.0
    for i, data in enumerate(trainloader, start=1):
        images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pbar.set_description(f'loss: {running_loss / i:.3f}')
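

# Evaluate top-1 accuracy on the test set; gradients are disabled for inference.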
def test(net, testloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Test Accuracy: {100 * correct / total:.2f}%')
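

# Parse hyperparameters, build the model and optimizer, then alternate one
# training epoch and one test-set evaluation per epoch.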
def main():
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    args = parser.parse_args()
    trainloader, testloader = get_data(args)
    net = Net()
    net.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)
    print("""
    Hyperparameters
    ---------------
    Batch size:    {batch_size}
    Learning rate: {lr}
    Epochs:        {epochs}
    """.format(batch_size=args.batch_size, lr=args.lr, epochs=args.epochs))
    pbar = tqdm(range(args.epochs), ascii=True)
    for epoch in pbar:
        train(net, trainloader, criterion, optimizer, pbar)
        test(net, testloader)


if __name__ == '__main__':
    main()
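
# Example invocation (the filename below is illustrative; use whatever name
# this script is saved under):
#   python cifar10_train.py --batch-size 64 --epochs 10 --lr 0.001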