@SrivastavaKshitij
Created September 3, 2018 23:38
Performance metrics in PyTorch
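import torch


def pf1(output, target, metric=None):
    """Return a binary-classification metric; F1 unless `metric` names another.

    `output` and `target` must be same-shape numeric tensors of 0/1 values.
    """
    d = output.detach()
    t = target.detach()
    # Each product is nonzero only where the matching confusion-matrix cell applies.
    TP = torch.nonzero(d * t).size(0)
    TN = torch.nonzero((d - 1) * (t - 1)).size(0)
    FP = torch.nonzero(d * (t - 1)).size(0)
    FN = torch.nonzero((d - 1) * t).size(0)
    total = TP + TN + FP + FN
    # Return 0.0 instead of raising ZeroDivisionError when a denominator is zero.
    precision = TP / (TP + FP) if TP + FP else 0.0
    recall = TP / (TP + FN) if TP + FN else 0.0
    F1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (TP + TN) / total if total else 0.0
    PPV = precision  # positive predictive value is another name for precision
    if metric == 'precision':
        return precision
    elif metric == 'recall':
        return recall
    elif metric == 'PPV':
        return PPV
    elif metric == 'accuracy':
        return accuracy
    else:
        return F1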
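
The counting in `pf1` relies on each of the four products being nonzero for exactly one combination of prediction and label. The following is a minimal pure-Python sketch of that idea, not part of the gist: it uses plain lists instead of tensors, `cell_for` is a hypothetical helper introduced for illustration, and the sample predictions and labels are made up.

```python
def cell_for(d, t):
    """Hypothetical helper: name the confusion-matrix cell whose product is nonzero."""
    products = {
        "TP": d * t,
        "TN": (d - 1) * (t - 1),
        "FP": d * (t - 1),
        "FN": (d - 1) * t,
    }
    hits = [name for name, value in products.items() if value != 0]
    assert len(hits) == 1  # exactly one product is nonzero for 0/1 inputs
    return hits[0]


# Made-up predictions and labels, for illustration only.
output = [1, 0, 1, 1, 0, 0]
target = [1, 0, 0, 1, 1, 0]

counts = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}
for d, t in zip(output, target):
    counts[cell_for(d, t)] += 1

precision = counts["TP"] / (counts["TP"] + counts["FP"])
recall = counts["TP"] / (counts["TP"] + counts["FN"])
f1 = 2 * precision * recall / (precision + recall)
print(counts)  # {'TP': 2, 'TN': 2, 'FP': 1, 'FN': 1}
```

With these inputs precision, recall, and F1 all equal 2/3, which is what `pf1` should return for the same values passed as integer tensors.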