Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Classification model metrics and plots
import matplotlib.pyplot as plt
import numpy as np
import torch
import pytorch_lightning.metrics.functional as FM
import seaborn as sns
from sklearn import metrics
class Metrics:
    """Accumulates per-batch predictions and targets for later evaluation.

    Call the instance once per batch with ``(y_pred, y_true)``; tensors are
    detached and moved to CPU so the accumulator never holds GPU memory or
    autograd graphs. Use :meth:`get` to build a ``Performance`` report over
    everything accumulated so far, and :meth:`reset` between epochs.
    """

    def __init__(self):
        # Start with empty accumulators; reset() owns the (re)initialization.
        self.reset()

    def __call__(self, y_pred, y_true):
        """Store one batch of predictions and targets on the CPU."""
        self.preds.append(y_pred.detach().cpu())
        self.trues.append(y_true.cpu())

    def reset(self):
        """Drop every accumulated batch."""
        self.trues, self.preds = [], []

    def get(self, class_id):
        """Concatenate all batches and return a Performance report for ``class_id``."""
        all_trues = torch.cat(self.trues, dim=0)
        all_preds = torch.cat(self.preds, dim=0)
        return Performance(all_preds, all_trues, class_id=class_id)
class Performance:
    """Binary one-vs-rest classification report for one class.

    Given per-sample scores (shape ``[N, C]``; raw logits by default, softmaxed
    internally) and integer targets, picks the decision threshold that
    maximizes Youden's J statistic (sensitivity + specificity) on the ROC
    curve, then derives TP/FP/TN/FN counts and the usual threshold-based
    metrics, plus ROC/PR curves and plots.

    NOTE(review): the confusion counts treat ``y_true == 1`` as the positive
    label regardless of ``class_id`` — confirm targets are 0/1 (or remapped)
    when ``class_id != 1``.

    Args:
        y_pred: score tensor, shape ``[N, C]`` (logits or probabilities).
        y_true: integer target tensor, shape ``[N]``.
        logits: if True, apply ``softmax`` over dim 1 to obtain probabilities.
        class_id: column of ``y_pred`` treated as the positive class.
    """

    def __init__(self, y_pred, y_true, logits=True, class_id=1):
        self.probs, self.trues = y_pred, y_true
        self.class_id = class_id
        if logits:
            # Convert raw logits into per-class probabilities.
            self.probs = y_pred.softmax(1)
        fpr, tpr, thr = self.get_roc(class_id)
        # Youden's J: threshold maximizing (sens: tpr) + (spec: 1 - fpr).
        # tpr + (1 - fpr) is always >= 0, so no abs() is needed.
        best_thr = thr[np.argmax(tpr + (1 - fpr))]
        y_pred = self.probs[:, self.class_id] > best_thr
        # Vectorized confusion counts (positive label is 1, negative is 0;
        # any other target value counts as a mismatch, as before).
        self.TP = int((y_pred & (y_true == 1)).sum())
        self.FP = int((y_pred & (y_true != 1)).sum())
        self.TN = int((~y_pred & (y_true == 0)).sum())
        self.FN = int((~y_pred & (y_true != 0)).sum())

    def __repr__(self):
        return f"TP:{self.TP}, FP:{self.FP}, TN:{self.TN}, FN:{self.FN}"

    @property
    def accuracy(self):
        """(TP + TN) / total; 0 when there are no samples."""
        try:
            acc = (self.TP + self.TN) / (self.TP + self.TN + self.FP + self.FN)
        except ZeroDivisionError:
            acc = 0
        return acc

    @property
    def auroc(self):
        """Area under the ROC curve for ``class_id``."""
        fpr, tpr, _ = self.get_roc(self.class_id)
        return metrics.auc(fpr, tpr)

    @property
    def aupr(self):
        """Area under the precision-recall curve for ``class_id``."""
        pr, rec, _ = self.get_pr(self.class_id)
        return metrics.auc(rec, pr)

    @property
    def average_precision_score(self):
        """sklearn average precision over the positive-class probabilities."""
        return metrics.average_precision_score(self.trues.numpy(), self.probs[:, self.class_id].numpy())

    @property
    def confusion_matrix(self):
        """Heatmap figure of the confusion matrix.

        Multi-class: sklearn's matrix from argmax predictions.
        Binary: laid out as [[TP, FP], [FN, TN]] (rows = predicted pos/neg).
        """
        classes = np.unique(self.trues)
        num_classes = len(classes)
        fig, ax = plt.subplots(1, 1, figsize=(num_classes * 2, num_classes))
        if num_classes > 2:
            cm = metrics.confusion_matrix(self.trues, self.probs.argmax(1), labels=classes)
        else:
            cm = np.array([[self.TP, self.FP], [self.FN, self.TN]])
        sns.heatmap(cm, annot=True, ax=ax, square=True)
        return fig

    @property
    def fpr(self):
        """False-positive rate FP / (FP + TN); 0 when undefined."""
        try:
            fpr = self.FP / (self.FP + self.TN)
        except ZeroDivisionError:
            fpr = 0
        return fpr

    @property
    def negative_predictive_value(self):
        """TN / (TN + FN); 0 when undefined."""
        try:
            npv = self.TN / (self.TN + self.FN)
        except ZeroDivisionError:
            npv = 0
        return npv

    @property
    def mcc(self):
        """Matthews correlation coefficient; 0 when the denominator vanishes.

        Fixes a precedence bug: the original ``(...) ** 1 / 2`` divided the
        product by 2 instead of taking its square root.
        """
        num = self.TP * self.TN - self.FP * self.FN
        den = (
            (self.TP + self.FP) * (self.TP + self.FN) * (self.TN + self.FP) * (self.TN + self.FN)
        ) ** 0.5
        # Guard against a zero denominator, consistent with the other metrics.
        return num / den if den else 0

    @property
    def precision(self):
        """TP / (TP + FP); 0 when undefined."""
        try:
            prec = self.TP / (self.TP + self.FP)
        except ZeroDivisionError:
            prec = 0
        return prec

    @property
    def sensitivity(self):
        """Recall / TPR: TP / (TP + FN); 0 when undefined."""
        try:
            sens = self.TP / (self.TP + self.FN)
        except ZeroDivisionError:
            sens = 0
        return sens

    @property
    def specificity(self):
        """TNR: TN / (TN + FP); 0 when undefined."""
        try:
            spec = self.TN / (self.TN + self.FP)
        except ZeroDivisionError:
            spec = 0
        return spec

    def get_pr(self, class_id=1):
        """Precision-recall curve (precision, recall, thresholds) as numpy arrays."""
        pr, rec, thr = FM.precision_recall_curve(self.probs[:, class_id], self.trues)
        pr, rec, thr = pr.numpy(), rec.numpy(), thr.numpy()
        return pr, rec, thr

    def get_roc(self, class_id=1):
        """ROC curve (fpr, tpr, thresholds) as numpy arrays."""
        fpr, tpr, thr = FM.roc(self.probs[:, class_id], self.trues)
        fpr, tpr, thr = fpr.numpy(), tpr.numpy(), thr.numpy()
        return fpr, tpr, thr

    def plot_roc_pr_curves(self, figsize=(16, 4), title=""):
        """Side-by-side ROC and PR plots with thresholds on twin y-axes.

        Returns the matplotlib figure.
        """
        fpr, tpr, roc_thr = self.get_roc(self.class_id)
        pr, rec, pr_thr = self.get_pr(self.class_id)
        fig, ax = plt.subplots(1, 2, figsize=figsize)
        ax[0].plot(fpr, tpr)
        ax[0].set_title(f"ROC. AUC: {self.auroc:.4f}")
        ax[0].set_xlabel("fpr")
        ax[0].set_ylabel("recall (tpr)")
        ax_thr = ax[0].twinx()
        # Skip the first ROC threshold (sentinel value added by the library).
        ax_thr.plot(fpr[1:], roc_thr[1:], c="orange")
        ax_thr.tick_params(axis="y", labelcolor="orange")
        ax_thr.set_ylabel("threshold", color="orange")
        ax[1].step(rec, pr)
        ax[1].set_title(f"PR. AUC: {self.aupr:.4f}")
        ax[1].set_xlabel("recall (tpr)")
        ax[1].set_ylabel("precision")
        ax_thr = ax[1].twinx()
        # PR thresholds have one fewer entry than the precision/recall arrays.
        ax_thr.plot(rec[:-1], pr_thr, c="orange")
        ax_thr.tick_params(axis="y", labelcolor="orange")
        ax_thr.set_ylabel("threshold", color="orange")
        fig.suptitle(f"Class ID: {self.class_id} {title}", y=1.01)
        return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.