@YodaEmbedding
Last active April 20, 2021 09:05
Gist: YodaEmbedding/a6f8d4cf4094405edb8c65bac623a4fa
Please see the appendix (attached in the Canvas comment) for usage instructions.
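```python
# example.py
import torch
import torch.nn.functional as F

from common import setup  # project-specific helper, not included in this gist

args, model, data_module = setup()
test_data_loader = data_module.test_dataloader()

# Assumes `model` is a torch.nn.Module; eval() disables dropout/batch-norm updates.
model.eval()

with torch.no_grad():  # inference only, so skip gradient tracking
    for inputs, targets in test_data_loader:
        # Flatten each 8 x 8 output map into 64 class scores per sample.
        logits = model(inputs).view(-1, 8 * 8)
        y = F.softmax(logits, dim=1)
        preds = y.argmax(dim=1)

        # Top-3 predicted class indices per sample, shape (batch, 3).
        top3 = logits.topk(k=3, dim=1).indices
        top1_acc = (preds == targets).sum().item() / len(targets)
        # A sample counts as a top-3 hit if its target is any of its 3 best predictions.
        top3_acc = (top3 == targets.unsqueeze(1)).any(dim=1).sum().item() / len(targets)

        print("labels: {}".format(targets))
        print("predictions: {}".format(preds))
        print("top-1 acc: {:.3f}".format(top1_acc))
        print("top-3 acc: {:.3f}".format(top3_acc))
        break  # evaluate only the first test batch
```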
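Since the `common` module is not part of this gist, here is a minimal, dependency-free sketch of what the loop computes, for checking the metric logic without a model. It assumes integer class targets (one per sample) and a row-major flattening of the 8 x 8 output; the names `top_k_indices`, `top_k_accuracy`, `flat_to_cell`, and the toy batch are hypothetical illustrations, not part of the original script. Tie-breaking may differ from `torch.topk`, whose order for equal scores is not guaranteed.

```python
from typing import List, Sequence, Tuple

# Width of the grid behind `.view(-1, 8 * 8)` in example.py; reading it as an
# 8 x 8 board is an assumption, the script only shows the shape.
GRID_WIDTH = 8


def top_k_indices(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest scores, highest first (one row of topk(...).indices)."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def top_k_accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Fraction of samples whose target class is among that sample's k highest logits."""
    hits = sum(1 for row, target in zip(logits, targets) if target in top_k_indices(row, k))
    return hits / len(targets)


def flat_to_cell(index: int, width: int = GRID_WIDTH) -> Tuple[int, int]:
    """Map a row-major flat class index back to (row, col) on a width x width grid."""
    return divmod(index, width)


# Hypothetical toy batch: 3 samples, 4 classes (not real model output).
example_logits = [
    [0.1, 2.0, 0.3, 1.5],
    [1.0, 0.2, 0.1, 0.9],
    [0.0, 0.5, 3.0, 0.4],
]
example_targets = [3, 0, 1]

print("top-1 acc: {:.3f}".format(top_k_accuracy(example_logits, example_targets, 1)))
print("top-3 acc: {:.3f}".format(top_k_accuracy(example_logits, example_targets, 3)))
print(flat_to_cell(19))
```

On the toy batch only the second sample's best guess is correct, so top-1 accuracy is 0.333, while every target appears among its row's three highest scores, so top-3 accuracy is 1.000. `flat_to_cell(19)` returns `(2, 3)`, since 19 = 2 * 8 + 3.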
YodaEmbedding (author) commented Apr 20, 2021
