# Train a small CNN on MNIST, print the elapsed training time, then report
# test-set accuracy.
# Source: GitHub Gist by @khirotaka (gist 9cf16561df688878e2e91a2afdc6f28a),
# last active August 10, 2019.
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Let cuDNN auto-tune its convolution algorithms; every batch here has the
# same image size, which is the case where auto-tuning pays off.
torch.backends.cudnn.benchmark = True

# MNIST train/test splits stored under ./data (downloaded when missing); each
# image is converted to a float tensor by ToTensor.
train_ds, test_ds = (
    datasets.MNIST("./data", train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

# Batches of 128 loaded by 4 worker processes; only the training split is
# shuffled.
# NOTE(review): with num_workers > 0 and no `if __name__ == "__main__":` guard,
# platforms that spawn DataLoader workers (e.g. Windows) re-import this module
# in every worker; consider moving the script body into a main() function.
train_loader, test_loader = (
    DataLoader(ds, shuffle=is_train, batch_size=128, num_workers=4)
    for ds, is_train in ((train_ds, True), (test_ds, False))
)
class Network(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images, 10 classes.

    The conv stack maps a (N, 1, 28, 28) batch to (N, 64, 12, 12): two
    unpadded 3x3 convolutions shrink 28 -> 26 -> 24, and the 2x2 max-pool
    halves that to 12. The fully connected head therefore expects
    12 * 12 * 64 = 9216 features per sample.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # A nonlinearity is required between the two Linear layers: without it
        # the pair is equivalent to a single linear map, and the 1024-unit
        # hidden layer adds parameters but no expressive power.
        # NOTE: inserting the ReLU moves the second Linear from ``fc.1`` to
        # ``fc.2`` in the state_dict.
        self.fc = nn.Sequential(
            nn.Linear(12 * 12 * 64, 1024),
            nn.ReLU(),
            nn.Linear(1024, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores of shape (N, 10).

        Args:
            x: Image batch, expected to be (N, 1, 28, 28) so that the conv
                output flattens to the 9216 features the head was built for.

        Raises:
            RuntimeError: If the flattened feature size does not match the
                first Linear layer (e.g. the images are not 28x28).
        """
        x = self.conv(x)
        # Flatten per sample while keeping the batch dimension fixed. The old
        # ``view(-1, 9216)`` could reinterpret a mis-sized input as a
        # different number of rows instead of raising.
        x = torch.flatten(x, 1)
        return self.fc(x)
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Run on", device)

model = Network().to(device)
# The network emits unnormalized scores; CrossEntropyLoss consumes them
# directly together with integer class labels.
criterion = nn.CrossEntropyLoss()
# Adam with the library-default hyper-parameters.
optimizer = optim.Adam(model.parameters())
# Train for `num_epochs` passes over `train_loader`, logging the loss every
# 100 steps, then print the elapsed wall-clock training time in seconds.
print("Start Training")
model.train()  # explicit: the evaluation code switches the model to eval mode
num_epochs = 5

if device.type == "cuda":
    # CUDA kernels run asynchronously; drain queued work (e.g. the model copy)
    # so the timer measures only the training loop.
    torch.cuda.synchronize()
start = time.perf_counter()  # monotonic clock intended for interval timing

for epoch in range(num_epochs):
    print("{}/{}".format(epoch + 1, num_epochs))
    for step, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            # .item() copies the scalar loss to the host as a Python float.
            print("step: {} - Loss: {:.4f}".format(step, loss.item()))

if device.type == "cuda":
    # Wait for the final optimizer steps to finish before stopping the clock;
    # otherwise the measured time can omit still-running GPU work.
    torch.cuda.synchronize()
end = time.perf_counter()
print(end - start)
# Measure classification accuracy of the trained model on the test split.
total = 0
correct = 0
model.eval()
with torch.no_grad():  # no autograd graph is needed for evaluation
    for data, label in test_loader:
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        # Predicted class = index of the highest score in each row.
        _, predict = torch.max(output, 1)
        total += label.size(0)
        # .item() yields a Python int; .detach() is redundant under no_grad.
        correct += (predict == label).sum().item()
print("Test Accuracy: {:.4%}".format(correct / total))
# (GitHub page footer: "Sign up for free to join this conversation on GitHub.")