Skip to content

Instantly share code, notes, and snippets.

@mjamroz
Last active March 28, 2020 11:14
Show Gist options
  • Save mjamroz/451dddbc3ccf02a4d71be0aec9e822cf to your computer and use it in GitHub Desktop.
Save mjamroz/451dddbc3ccf02a4d71be0aec9e822cf to your computer and use it in GitHub Desktop.
Custom statistics for "binary" (one-vs-rest) classification: the model predicts more than two classes, and we want binary metrics for one selected class against all the rest.
from mxnet import metric, nd
class BinarySelectedStatistics(metric._BinaryClassificationMetrics):
    """One-vs-rest binary statistics for a multi-class classifier.

    One class index is designated "positive"; every other class is treated
    as negative. ``update`` accumulates the confusion counts
    ``true_positives``, ``false_positives``, ``false_negatives`` and
    ``true_negatives`` together with their ``global_*`` counterparts on
    ``self`` (these attributes presumably come from, and are consumed by,
    the mxnet ``_BinaryClassificationMetrics`` base class -- confirm against
    the installed mxnet version).

    Separately, ``sum_metric`` / ``num_inst`` track plain multi-class
    accuracy (predicted class == true class), exposed as ``score``. Note
    this is NOT the one-vs-rest accuracy ``(TP + TN) / N``.

    Example with ``positive=1``::

        label = [1, 1, 0, 2, 3, 3, 3, 1, 1, 1]
        pred  = [1, 0, 2, 2, 2, 1, 1, 2, 3, 0]
        # TP=1, TN=3, FP=2, FN=4; score = 2/10 (only indices 0 and 3 match)
    """

    def __init__(self, positive=1):
        """Create the metric.

        Args:
            positive: class index treated as the positive class. Defaults
                to 1, the value previously hard-coded here.
        """
        super().__init__()
        self.positive = positive
        self.num_inst = 0
        self.sum_metric = 0.0

    def set_positive_index(self, index):
        """Select which class index counts as the positive class."""
        self.positive = index

    @property
    def score(self):
        """Multi-class accuracy over all samples seen since the last reset.

        Returns NaN (rather than raising ZeroDivisionError) when no sample
        has been accumulated yet.
        """
        if self.num_inst == 0:
            return float("nan")
        return self.sum_metric / self.num_inst

    def reset(self):
        """Clear the base-class confusion counts and the accuracy totals."""
        self.reset_stats()  # base-class helper for the confusion counters
        self.num_inst = 0
        self.sum_metric = 0.0

    def update(self, label, outputs):
        """Accumulate statistics for one batch.

        Args:
            label: NDArray (or list of NDArrays) of true class indices.
            outputs: NDArray (or list of NDArrays) of per-class scores; the
                predicted class is the argmax along axis 1.

        Errors raised by ``metric.check_label_shapes`` (used to make sure
        labels and predictions line up) propagate to the caller.
        """
        labels, preds = metric.check_label_shapes(label, outputs, True)
        for batch_label, batch_scores in zip(labels, preds):
            predicted = nd.argmax(batch_scores, axis=1)
            predicted = predicted.asnumpy().astype('int32').ravel()
            actual = batch_label.asnumpy().astype('int32').ravel()
            metric.check_label_shapes(actual, predicted)

            is_actual_pos = actual == self.positive
            is_pred_pos = predicted == self.positive
            # Mask intersections + .sum() give the same counts as
            # ndarray.sum(where=...) but also work on NumPy < 1.17.
            true_positives = int((is_pred_pos & is_actual_pos).sum())
            false_positives = int((is_pred_pos & ~is_actual_pos).sum())
            false_negatives = int((~is_pred_pos & is_actual_pos).sum())
            true_negatives = int((~is_pred_pos & ~is_actual_pos).sum())

            # Local (per-epoch) counters ...
            self.true_positives += true_positives
            self.false_positives += false_positives
            self.false_negatives += false_negatives
            self.true_negatives += true_negatives
            # ... and global counters, updated identically.
            self.global_true_positives += true_positives
            self.global_false_positives += false_positives
            self.global_false_negatives += false_negatives
            self.global_true_negatives += true_negatives

            # Multi-class accuracy bookkeeping backing ``score``.
            self.num_inst += len(predicted)
            self.sum_metric += int((predicted == actual).sum())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment