Last active
December 30, 2015 07:49
-
-
Save jnothman/7798757 to your computer and use it in GitHub Desktop.
Polymorphic handling of metrics over multilabel formats in scikit-learn.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class _SparseMultiLabelHelper(object):
    """Count label-set statistics for a pair of sparse label-indicator matrices.

    Both inputs are converted to (private copies in) CSR format. Any stored
    non-zero entry counts as "label present"; explicitly stored zeros and
    duplicate entries are cleaned up first so they are not miscounted.

    Every ``count_*`` method accepts ``axis``:

    * ``None`` -- a single total over the whole matrix;
    * ``1`` (or ``-1``) -- one count per row (sample);
    * ``0`` (or ``-2``) -- one count per column (label).

    Any other ``axis`` raises ``ValueError``.
    """

    def __init__(self, y_true, y_pred):
        self.y_true = self._canonical_csr(y_true)
        self.y_pred = self._canonical_csr(y_pred)
        self.shape = y_true.shape

    @staticmethod
    def _canonical_csr(y):
        """Return a cleaned CSR copy of sparse matrix ``y``."""
        # Copy: ``tocsr()`` returns the input itself when it is already CSR,
        # and the clean-up below must not mutate the caller's matrix.
        y = y.tocsr(copy=True)
        # Duplicate entries would be counted twice, and explicit zeros would
        # be counted as present labels, by nnz / indptr / indices.
        y.sum_duplicates()
        y.eliminate_zeros()
        return y

    def count_union(self, axis=None):
        """Count labels present in ``y_true`` or ``y_pred`` (or both)."""
        # Assumes non-negative (e.g. 0/1) values so that the sum is non-zero
        # exactly where either input is non-zero.
        return self._count_nnz(self.y_true + self.y_pred, axis)

    def count_intersection(self, axis=None):
        """Count labels present in both ``y_true`` and ``y_pred``."""
        return self._count_nnz(self.y_true.multiply(self.y_pred), axis)

    def count_difference(self, axis=None):
        """Count labels present in exactly one of ``y_true`` / ``y_pred``."""
        # |A xor B| == |A or B| - |A and B|; avoids relying on sparse != / xor.
        return self.count_union(axis) - self.count_intersection(axis)

    def count_true(self, axis=None):
        """Count labels present in ``y_true``."""
        return self._count_nnz(self.y_true, axis)

    def count_pred(self, axis=None):
        """Count labels present in ``y_pred``."""
        return self._count_nnz(self.y_pred, axis)

    @staticmethod
    def _count_nnz(X, axis=None):
        """Count stored entries of CSR matrix ``X`` in total, per row or per column."""
        if axis is None:
            return X.nnz
        if axis in (1, -1):
            # Row lengths come straight from the CSR index pointer.
            return np.diff(X.indptr)
        if axis in (0, -2):
            # Each stored entry contributes one count to its column.
            return np.bincount(X.indices, minlength=X.shape[1])
        raise ValueError('Unsupported axis: {0}'.format(axis))
class _DenseMultiLabelHelper(object):
    """Count label-set statistics for a pair of dense label-indicator arrays.

    An entry counts as "label present" only when it equals 1; any other value
    (0, -1, 2, ...) is treated as absent. Inputs may be any array-like
    (ndarray, nested lists, ...); they are converted with ``np.asarray`` so
    per-axis counts are always plain 1-D ndarrays.

    Every ``count_*`` method takes ``axis`` with ``ndarray.sum`` semantics:
    ``None`` for a grand total, ``1`` for one count per row (sample), ``0``
    for one count per column (label).
    """

    def __init__(self, y_true, y_pred):
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        self.y_true = y_true == 1
        self.y_pred = y_pred == 1
        self.shape = y_true.shape

    def count_union(self, axis=None):
        """Count labels present in ``y_true`` or ``y_pred`` (or both)."""
        return np.logical_or(self.y_true, self.y_pred).sum(axis)

    def count_intersection(self, axis=None):
        """Count labels present in both ``y_true`` and ``y_pred``."""
        return np.logical_and(self.y_true, self.y_pred).sum(axis)

    def count_difference(self, axis=None):
        """Count labels present in exactly one of ``y_true`` / ``y_pred``."""
        return (self.y_true != self.y_pred).sum(axis)

    def count_true(self, axis=None):
        """Count labels present in ``y_true``."""
        return self.y_true.sum(axis)

    def count_pred(self, axis=None):
        """Count labels present in ``y_pred``."""
        return self.y_pred.sum(axis)
class _SequencesMultiLabelHelper(object):
    """Count label-set statistics for two sequences of label sequences.

    ``y_true[i]`` and ``y_pred[i]`` are iterables of the labels assigned to
    sample ``i``; duplicates within a sample are ignored. ``labels`` is the
    ordered collection of all labels and must contain every label that
    appears in ``y_true`` or ``y_pred``. It fixes the column order of
    per-label counts: position ``j`` of an ``axis=0`` result is the count
    for ``labels[j]``.

    Every ``count_*`` method takes ``axis``: ``None`` for a grand total,
    ``1`` (or ``-1``) for one count per sample, ``0`` (or ``-2``) for one
    count per label. Any other ``axis`` raises ``ValueError``.
    """

    def __init__(self, y_true, y_pred, labels):
        self.y_true = y_true
        self.y_pred = y_pred
        self.labels = labels
        self.shape = (len(y_true), len(labels))
        # Labels may be arbitrary hashables (e.g. strings), so they are
        # mapped to column positions rather than passed to bincount directly.
        self._label_index = {label: i for i, label in enumerate(labels)}

    def _zipped(self):
        """Pair each sample's true labels with its predicted labels."""
        return zip(self.y_true, self.y_pred)

    def count_union(self, axis=None):
        """Count labels present in the true or the predicted set."""
        return self._count_nnz((set(true) | set(pred)
                                for true, pred in self._zipped()), axis)

    def count_intersection(self, axis=None):
        """Count labels present in both the true and the predicted set."""
        return self._count_nnz((set(true) & set(pred)
                                for true, pred in self._zipped()), axis)

    def count_difference(self, axis=None):
        """Count labels present in exactly one of the true / predicted sets."""
        return self._count_nnz((set(true) ^ set(pred)
                                for true, pred in self._zipped()), axis)

    def count_true(self, axis=None):
        """Count distinct true labels."""
        return self._count_nnz((set(true) for true in self.y_true), axis)

    def count_pred(self, axis=None):
        """Count distinct predicted labels."""
        return self._count_nnz((set(pred) for pred in self.y_pred), axis)

    def _count_nnz(self, rows, axis=None):
        """Count labels in ``rows``, an iterable of per-sample label sets."""
        if axis is None:
            return sum(len(row) for row in rows)
        if axis in (1, -1):
            return np.array([len(row) for row in rows], dtype=np.intp)
        if axis in (0, -2):
            columns = [self._label_index[label] for row in rows for label in row]
            # Explicit integer dtype so an empty list is still bincount-able.
            return np.bincount(np.asarray(columns, dtype=np.intp),
                               minlength=self.shape[1])
        raise ValueError('Unsupported axis: {0}'.format(axis))
def _multilabel_helper(y_true, y_pred, y_type, binarize=False):
    """Return a helper that counts label-set statistics for y_true / y_pred.

    Parameters
    ----------
    y_true, y_pred : multilabel targets
        Either sequences of label sequences (when ``y_type`` is
        ``'multilabel-sequences'``) or sparse/dense indicator matrices.
    y_type : str
        Target type. Only ``'multilabel-sequences'`` is special-cased; any
        other value is handled as an indicator-matrix format.
    binarize : bool, default False
        For ``'multilabel-sequences'`` input, convert both targets with
        ``label_binarize`` instead of counting on the sequences directly.

    Returns
    -------
    helper : _SequencesMultiLabelHelper, _SparseMultiLabelHelper or \
_DenseMultiLabelHelper
        Object exposing ``shape`` and ``count_union``, ``count_intersection``,
        ``count_difference``, ``count_true`` and ``count_pred``.
    """
    if y_type == 'multilabel-sequences':
        labels = unique_labels(y_true, y_pred)
        if not binarize:
            return _SequencesMultiLabelHelper(y_true, y_pred, labels)
        # Binarize each target from its own labels, using the shared label set.
        y_true = label_binarize(y_true, labels, multilabel=True)
        y_pred = label_binarize(y_pred, labels, multilabel=True)

    if sp.issparse(y_true) and sp.issparse(y_pred):
        return _SparseMultiLabelHelper(y_true, y_pred)

    # Mixed or dense input: densify whichever target is sparse-like. The two
    # checks are independent so neither input is left un-densified.
    if hasattr(y_true, 'toarray'):
        y_true = y_true.toarray()
    if hasattr(y_pred, 'toarray'):
        y_pred = y_pred.toarray()
    return _DenseMultiLabelHelper(y_true, y_pred)
### jaccard_similarity_score uses mlh.count_intersection(axis=1) / mlh.count_union(axis=1) (rows whose union count is 0 need special handling to avoid 0/0)
### accuracy_score uses mlh.count_difference(axis=1) == 0 (a row is correct only if its true and predicted label sets are identical)
### precision_recall_fscore_support uses tp_sum = mlh.count_intersection(axis=sum_axis), pred_sum = mlh.count_pred(axis=sum_axis), true_sum = mlh.count_true(axis=sum_axis)
### hamming_loss uses mlh.count_difference() / mlh.shape[0] / mlh.shape[1] (requires true division, not integer division)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment