Skip to content

Instantly share code, notes, and snippets.

@brookisme
Last active November 19, 2020 01:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save brookisme/8f9f06286251af02bb9372fc35bb7fd8 to your computer and use it in GitHub Desktop.
Save brookisme/8f9f06286251af02bb9372fc35bb7fd8 to your computer and use it in GitHub Desktop.
PyTorch and Numpy Confusion Matrix, Precision, Recall
#
# CONFIG
#
# Default beta for fbeta(); beta > 1 weights recall more heavily than precision
# in (1 + beta**2) * p * r / (beta**2 * p + r).
BETA: int = 2
# Whether stats() appends the raw [tp, fp, fn, tn] counts to its result.
RETURN_CMATRIX: bool = True
# Fallbacks used by _precision_recall() when its denominator is zero:
#   INVALID_* when the third count is nonzero, VALID_* when it is zero too.
# (The "DIVISON" spelling is part of the public names and is kept as-is.)
INVALID_ZERO_DIVISON: bool = False
VALID_ZERO_DIVISON: float = 1.0
#
# METHODS
#
def confusion_matrix(target, prediction, value, ignore_value=None):
    """Return one-vs-rest confusion-matrix counts for a single label.

    ``target`` and ``prediction`` must support elementwise ``==``/``!=``,
    ``&``, ``~`` and ``.sum()`` (e.g. NumPy arrays or PyTorch tensors);
    presumably they share a shape -- TODO confirm at call sites.

    Args:
        target: ground-truth labels.
        prediction: predicted labels.
        value: the label treated as the positive class; every other label
            is negative.
        ignore_value: a label in ``target`` whose elements are excluded from
            all four counts. ``None`` keeps every element.

    Returns:
        The four counts ``tp, fp, fn, tn`` as returned by ``_get_items``
        (Python scalars when ``.item()`` conversion succeeds).
    """
    # Positive-ness of the truth and of the prediction must be judged
    # independently; deriving both from ``target`` confuses FP with FN.
    actual = target == value
    predicted = prediction == value
    tp = actual & predicted
    fp = ~actual & predicted
    fn = actual & ~predicted
    tn = ~actual & ~predicted
    if ignore_value is not None:
        # Apply the mask to every cell so ignored elements never count.
        keep = target != ignore_value
        tp, fp, fn, tn = (mask & keep for mask in (tp, fp, fn, tn))
    return _get_items(tp.sum(), fp.sum(), fn.sum(), tn.sum())
def precision(tp, fp, fn):
    """Return precision, ``tp / (tp + fp)``.

    ``fn`` is only read when ``tp + fp`` is zero. In that case the result is
    INVALID_ZERO_DIVISON if ``fn`` is nonzero, otherwise VALID_ZERO_DIVISON.
    """
    return _precision_recall(a=tp, b=fp, c=fn)
def recall(tp, fn, fp):
    """Return recall, ``tp / (tp + fn)``.

    ``fp`` is only read when ``tp + fn`` is zero. In that case the result is
    INVALID_ZERO_DIVISON if ``fp`` is nonzero, otherwise VALID_ZERO_DIVISON.
    """
    return _precision_recall(a=tp, b=fn, c=fp)
def fbeta(p=None, r=None, beta=BETA, tp=None, fp=None, fn=None):
    """Return the F-beta score of precision ``p`` and recall ``r``.

    F_beta = (1 + beta**2) * p * r / (beta**2 * p + r)

    If ``p`` or ``r`` is ``None``, it is computed from the confusion-matrix
    counts via ``precision(tp, fp, fn)`` / ``recall(tp, fn, fp)``. All three
    counts are then required, because both helpers use the third count to
    choose their zero-division fallback.

    Returns:
        The F-beta score, or 0 when the denominator is zero.

    Raises:
        ValueError: if ``p`` or ``r`` is None and any of tp/fp/fn is missing.
    """
    if p is None or r is None:
        if tp is None or fp is None or fn is None:
            raise ValueError(
                "fbeta: tp, fp and fn are required when p or r is None")
        if p is None:
            p = precision(tp, fp, fn)
        if r is None:
            r = recall(tp, fn, fp)
    beta_sq = beta ** 2
    denominator = beta_sq * p + r
    if not denominator:
        return 0
    return (1 + beta_sq) * (p * r) / denominator
def stats(
        target,
        prediction,
        value,
        ignore_value=None,
        beta=BETA,
        return_cmatrix=RETURN_CMATRIX):
    """Compute precision/recall (and optionally F-beta and raw counts).

    The returned list is laid out as:
        [precision, recall]
        + [fbeta]            unless ``_is_false(beta)`` (beta disabled)
        + [tp, fp, fn, tn]   when ``return_cmatrix`` is truthy
    """
    tp, fp, fn, tn = confusion_matrix(
        target, prediction, value, ignore_value=ignore_value)
    prec = precision(tp, fp, fn)
    rec = recall(tp, fn, fp)
    f_part = [] if _is_false(beta) else [fbeta(prec, rec, beta=beta)]
    counts_part = [tp, fp, fn, tn] if return_cmatrix else []
    return [prec, rec] + f_part + counts_part
#
# INTERNAL
#
def _precision_recall(a, b, c):
    """Return ``a / (a + b)``, with a configurable fallback for a zero denominator.

    ``precision`` calls this as (a=tp, b=fp, c=fn) and ``recall`` as
    (a=tp, b=fn, c=fp). When ``a + b`` is zero, the result is
    INVALID_ZERO_DIVISON if ``c`` is nonzero, otherwise VALID_ZERO_DIVISON.
    """
    total = a + b
    if total:
        return a / total
    return INVALID_ZERO_DIVISON if c else VALID_ZERO_DIVISON
def _is_false(value):
    """Return True only for the "disabled" sentinels ``False`` and ``None``.

    Identity checks are used on purpose: ``0 in [False, None]`` is True
    because ``0 == False``. With that test, a numeric ``beta=0`` (whose
    F-score equals precision) would be treated as "disabled".
    """
    return value is False or value is None
def _get_items(*args):
    """Convert each argument to a Python scalar with its ``.item()`` method.

    Returns a list of converted values. If any argument cannot be converted,
    all arguments are returned unchanged as a tuple (all-or-nothing).

    Only the errors that a failed ``.item()`` conversion raises are caught:
    AttributeError (no ``.item()``, e.g. plain Python numbers), ValueError
    (NumPy arrays with more than one element) and RuntimeError (PyTorch
    tensors with more than one element). Anything else, including
    KeyboardInterrupt and SystemExit, propagates.
    """
    try:
        return [arg.item() for arg in args]
    except (AttributeError, ValueError, RuntimeError):
        return args
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment