Skip to content

Instantly share code, notes, and snippets.

@antoinebrl
Created November 30, 2023 22:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save antoinebrl/d744697dbf996ef42be7ad91c7330db7 to your computer and use it in GitHub Desktop.
Save antoinebrl/d744697dbf996ef42be7ad91c7330db7 to your computer and use it in GitHub Desktop.
MNIST MLP - 97% accuracy
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
# Per-sample preprocessing: convert the image to a tensor, then flatten it
# to 1-D so it can be fed to the fully connected model.
transform = Compose([ToTensor(), Lambda(torch.flatten)])

# MNIST train/test splits stored under ./data; only the train split
# requests a download.
train_dataset = datasets.MNIST(
    root='data',
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.MNIST(
    root='data',
    train=False,
    transform=transform,
)

# Training batches are reshuffled each epoch and a trailing partial batch
# (if any) is dropped; test batches keep their order.
train_loader = DataLoader(
    train_dataset,
    batch_size=250,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=250,
    shuffle=False,
)
# Two-layer perceptron: 28*28 flattened inputs -> 256 hidden units with
# ReLU -> 10 outputs (used as logits by the loss).
model = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=256, bias=True),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=10, bias=True),
)
# Loss: cross-entropy computed from the model's raw outputs and the
# integer class labels.
criteria = nn.CrossEntropyLoss()

# Optimizer: plain SGD with no momentum.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.2, momentum=0)
# Alternative optimizer, kept for experimentation:
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Train for 10 epochs. Every 100 training steps, measure accuracy over the
# whole test loader and print it next to the loss of the most recent
# training batch.
for epoch in range(10):
    for step, (images, labels) in enumerate(train_loader):
        # Standard optimisation step: clear stale gradients, forward,
        # compute loss, backpropagate, update parameters.
        optimizer.zero_grad()
        logits = model(images)
        loss = criteria(logits, labels)
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            correct = 0
            total = 0
            # Evaluation needs no gradients; disabling autograd avoids
            # building a graph for every test batch.
            with torch.no_grad():
                # Distinct names so the training batch variables above are
                # not overwritten by the evaluation loop.
                for test_images, test_labels in test_loader:
                    predicted = model(test_images).argmax(dim=1)
                    # .item() keeps the running count a plain Python int.
                    correct += (predicted == test_labels).sum().item()
                    total += len(test_labels)
            print(f'epoch: {epoch:2}, step: {step:4}, loss: {loss.item():.4f}, acc: {correct/total*100:.2f}%')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment