@sprintr
Last active April 24, 2022 23:52
Multiclass Classifier Accuracy
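import torch

# Assumes `net`, `mnist_train_X/_y` and `mnist_test_X/_y` are already defined,
# with labels one-hot encoded (one row per sample, one column per class).
def accuracy(net, X, y_onehot):
    with torch.no_grad():  # evaluation only, no gradients needed
        outputs = net(X)
    # Predicted class = index of the largest raw output. Thresholding at 0.5
    # first would send rows with several (or no) outputs above 0.5 to the
    # wrong class, so argmax is taken on the raw outputs instead.
    predicted = outputs.argmax(dim=1)
    actual = y_onehot.argmax(dim=1)  # index of the 1 in each one-hot label
    return (predicted == actual).sum().item() / y_onehot.shape[0]

print("Train Accuracy: {:.3f}, Test Accuracy: {:.3f}".format(
    accuracy(net, mnist_train_X, mnist_train_y),
    accuracy(net, mnist_test_X, mnist_test_y)))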
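Thresholding the network outputs at 0.5 before taking `argmax` can change which class is predicted: a row with two outputs above 0.5 becomes something like [1, 1, 0], and a row with none becomes all zeros, so `argmax` simply returns the first maximal index. The following is a minimal pure-Python sketch with made-up scores that compares the two approaches. The helpers `argmax` and `accuracy` are hypothetical and not part of the gist. The tie rule (first maximal index) is assumed to mirror what `torch.argmax` documents.

```python
def argmax(values):
    """Index of the first largest element (first-index tie-breaking)."""
    # max() returns the first maximal item it encounters.
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(outputs, onehot_labels, threshold=None):
    """Fraction of rows whose predicted class matches the one-hot label.

    If `threshold` is given, each output row is binarized first
    (1.0 where score > threshold, else 0.0) before taking argmax.
    """
    correct = 0
    for out, label in zip(outputs, onehot_labels):
        if threshold is not None:
            out = [1.0 if v > threshold else 0.0 for v in out]
        if argmax(out) == argmax(label):
            correct += 1
    return correct / len(onehot_labels)


# Made-up 3-class scores (e.g. sigmoid outputs) and one-hot labels.
outputs = [
    [0.9, 0.1, 0.2],  # one score > 0.5; true class 0
    [0.6, 0.8, 0.1],  # two scores > 0.5; true class 1
    [0.2, 0.3, 0.4],  # no score > 0.5; true class 2
]
labels = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]

print(accuracy(outputs, labels))                 # → 1.0
print(accuracy(outputs, labels, threshold=0.5))  # → 0.3333333333333333
```

On raw scores every row is classified correctly. After thresholding, the second row collapses to class 0 (the first of two 1s) and the third to class 0 (all zeros), so the measured accuracy drops to one third.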