@danielcwq
Created November 28, 2021 07:15
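import torch

def batch_accuracy(xb, yb):
    # xb holds raw model outputs (logits); sigmoid maps them into (0, 1)
    preds = xb.sigmoid()
    # A prediction is correct when its thresholded value matches the 0/1 label
    correct = (preds > 0.5) == yb
    return correct.float().mean()

def validate_epoch(model):
    # valid_dl is the validation DataLoader, expected to be defined elsewhere
    accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)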
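
A minimal usage sketch of the two functions above, under stated assumptions: `valid_dl` is replaced here by a plain list of `(xb, yb)` batches, and `identity_model` is a hypothetical stand-in that passes its inputs through as logits. Neither appears in the original gist; they exist only so the snippet runs on its own.

```python
import torch

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds > 0.5) == yb
    return correct.float().mean()

# Assumption: a list of (inputs, labels) batches stands in for a real DataLoader.
valid_dl = [
    (torch.tensor([[2.0], [-1.0]]), torch.tensor([[1], [0]])),  # both correct
    (torch.tensor([[0.5], [3.0]]), torch.tensor([[0], [1]])),   # one of two correct
]

def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)

# Hypothetical model: treats its inputs as logits without changing them.
def identity_model(xb):
    return xb

acc = validate_epoch(identity_model)
print(acc)  # prints 0.75
```

Note that `validate_epoch` averages the per-batch accuracies, so if the last batch is smaller than the others the result can differ slightly from the accuracy computed over all validation items at once.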