Skip to content

Instantly share code, notes, and snippets.

@pannous
Created May 20, 2022
Embed
What would you like to do?
#!/usr/local/bin/python3
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
# Prefer PyTorch's 'mps' (Apple Metal) backend when it is available, but fall
# back to the CPU so the script also runs on machines without it.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
# Training hyper-parameters.
learning_rate = 1e-2    # Adam step size
training_epochs = 15    # number of passes over the training loader
batch_size = 60_000     # samples per optimisation step
# MNIST train/test splits. ToTensor() converts each image to a float tensor;
# download=True fetches the files into MNIST_data/ if they are not already there.
_mnist_kwargs = dict(root='MNIST_data/', transform=transforms.ToTensor(), download=True)
mnist_train = dsets.MNIST(train=True, **_mnist_kwargs)
mnist_test = dsets.MNIST(train=False, **_mnist_kwargs)
# Shuffled batches of the training split; a short final batch, if any, is dropped.
data_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
# The simplest possible classifier: one linear layer mapping a flattened
# 28x28 image to 10 raw logits.
model = nn.Linear(in_features=28 * 28, out_features=10, bias=True).to(device)
# CrossEntropyLoss applies log-softmax internally, so the model emits raw logits.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
def accuracy(net=None, dataset=None):
    """Evaluate a classifier on a full image dataset and print its accuracy.

    Args:
        net: model to evaluate; defaults to the module-level ``model``.
        dataset: object exposing ``.data`` (raw 0-255 pixel images, as
            torchvision's MNIST provides) and ``.targets`` (class indices);
            defaults to the module-level ``mnist_test``.

    Returns:
        The fraction of correctly classified samples, as a Python float.
    """
    net = model if net is None else net
    dataset = mnist_test if dataset is None else dataset
    # Evaluate on whatever device the model's weights live on.
    dev = next(net.parameters()).device
    with torch.no_grad():
        # `.data` holds raw pixel values (0-255). Training inputs went through
        # ToTensor(), which scales to [0, 1], so apply the same scaling here;
        # otherwise train and test inputs live on different scales.
        X_test = dataset.data.reshape(-1, 28 * 28).float().div(255).to(dev)
        Y_test = dataset.targets.to(dev)
        prediction = net(X_test)
        guesses = torch.argmax(prediction, 1)
        correct_prediction = guesses == Y_test
        print(guesses, Y_test)
        print("correct_predictions", correct_prediction)
        acc = correct_prediction.float().mean().item()
    print('Accuracy:', acc)
    return acc
# Train for `training_epochs` passes over `data_loader`, printing the
# sample-weighted mean loss of each epoch, then evaluate with accuracy().
for epoch in range(training_epochs):
    running_loss = 0.0
    seen = 0
    for X, Y in data_loader:
        # Flatten each image to a 784-vector; labels stay as class indices
        # (not one-hot), which is what CrossEntropyLoss takes.
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)
        optimizer.zero_grad()
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()
        # Convert the 0-d loss tensor to a float explicitly before accumulating.
        running_loss += cost.item() * X.size(0)
        seen += X.size(0)
    if seen == 0:
        # With drop_last=True, a batch_size larger than the dataset yields no batches.
        raise RuntimeError('data_loader produced no batches; is batch_size larger than the training set?')
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(running_loss / seen), flush=True)
accuracy()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment