Skip to content

Instantly share code, notes, and snippets.

@jnothman
Last active December 30, 2015 07:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jnothman/7798757 to your computer and use it in GitHub Desktop.
Save jnothman/7798757 to your computer and use it in GitHub Desktop.
Polymorphic handling of metrics over multilabel formats in scikit-learn.
class _SparseMultiLabelHelper(object):
    """Count label-set statistics over a pair of sparse indicator matrices.

    Each stored entry (``nnz``) of a matrix is treated as a positive label.
    Inputs are converted to CSR so that per-row counts can be read from
    ``indptr`` and per-column counts from ``indices``.

    Every ``count_*`` method takes ``axis``:
      * ``None`` -- a single total count;
      * ``1``    -- one count per sample (row);
      * ``0``    -- one count per label (column).
    """

    def __init__(self, y_true, y_pred):
        self.y_true = y_true.tocsr()
        self.y_pred = y_pred.tocsr()
        self.shape = y_true.shape

    def count_union(self, axis=None):
        """Count entries positive in ``y_true`` or ``y_pred``."""
        return self._count_nnz(self.y_true + self.y_pred, axis)

    def count_intersection(self, axis=None):
        """Count entries positive in both ``y_true`` and ``y_pred``."""
        return self._count_nnz(self.y_true.multiply(self.y_pred), axis)

    def count_difference(self, axis=None):
        """Count entries positive in exactly one of ``y_true`` / ``y_pred``."""
        # pending availability of != or xor on sparse matrices:
        # |A xor B| = |A or B| - |A and B|
        return self.count_union(axis) - self.count_intersection(axis)

    def count_true(self, axis=None):
        """Count positive entries in ``y_true``."""
        return self._count_nnz(self.y_true, axis)

    def count_pred(self, axis=None):
        """Count positive entries in ``y_pred``."""
        return self._count_nnz(self.y_pred, axis)

    @staticmethod
    def _count_nnz(X, axis=None):
        """Count stored entries of sparse matrix ``X`` along ``axis``.

        Raises ValueError for any axis other than None, 0 or 1.
        """
        if axis is None:
            return X.nnz
        # Results of sparse arithmetic are not guaranteed to be CSR; the
        # indptr/indices interpretation below is only valid for CSR.
        # tocsr() is a no-op when X is already CSR.
        if axis == 1:
            return np.diff(X.tocsr().indptr)
        elif axis == 0:
            X = X.tocsr()
            return np.bincount(X.indices, minlength=X.shape[1])
        raise ValueError('Unsupported axis: {0}'.format(axis))
class _DenseMultiLabelHelper(object):
    """Count label-set statistics over a pair of dense indicator arrays.

    An entry equal to 1 is treated as a positive label; any other value is
    treated as negative. Inputs may be any array-like (lists, ndarrays,
    ``np.matrix``); they are coerced to plain ndarrays so that reductions
    along an axis always return 1-D results.

    Every ``count_*`` method takes ``axis``:
      * ``None`` -- a single total count;
      * ``1``    -- one count per sample (row);
      * ``0``    -- one count per label (column).
    """

    def __init__(self, y_true, y_pred):
        # np.asarray makes `== 1` elementwise for lists and drops the
        # np.matrix subclass, whose sum(axis) would stay 2-D.
        self.y_true = np.asarray(y_true) == 1
        self.y_pred = np.asarray(y_pred) == 1
        self.shape = self.y_true.shape

    def count_union(self, axis=None):
        """Count entries positive in ``y_true`` or ``y_pred``."""
        return np.logical_or(self.y_true, self.y_pred).sum(axis)

    def count_intersection(self, axis=None):
        """Count entries positive in both ``y_true`` and ``y_pred``."""
        return np.logical_and(self.y_true, self.y_pred).sum(axis)

    def count_difference(self, axis=None):
        """Count entries positive in exactly one of ``y_true`` / ``y_pred``."""
        return (self.y_true != self.y_pred).sum(axis)

    def count_true(self, axis=None):
        """Count positive entries in ``y_true``."""
        return self.y_true.sum(axis)

    def count_pred(self, axis=None):
        """Count positive entries in ``y_pred``."""
        return self.y_pred.sum(axis)
class _SequencesMultiLabelHelper(object):
def __init__(self, y_true, y_pred, labels):
self.y_true = y_true
self.y_pred = y_pred
self.labels = labels
self.shape = (len(y_true), len(labels))
def _zipped(self):
return zip(self.y_true, self.y_pred)
def count_union(self, axis=None)
return self._count_nnz((set(true) | pred
for true, pred in self._zipped()), axis)
def count_intersection(self, axis=None)
return self._count_nnz((set(true) & pred
for true, pred in self._zipped()), axis)
def count_difference(self, axis=None):
return self._count_nnz((set(true) ^ pred
for true, pred in self._zipped()), axis)
def count_true(self, axis=None):
return self._count_nnz((set(true) for true in self.y_true), axis)
def count_pred(self, axis=None)
return self._count_nnz((set(pred) for pred in self.y_pred), axis)
@staticmethod
def _count_nnz(rows, axis=None):
if axis is None:
return sum(len(row) for row in rows)
elif axis == 1:
return np.array([len(row) for row in rows])
elif axis == 0:
return np.bincount(list(itertools.chain.from_iterable(rows)),
minlength=self.shape[1])
def _multilabel_helper(y_true, y_pred, y_type, binarize=False):
    """Return a counting helper suited to the format of the multilabel input.

    Parameters
    ----------
    y_true, y_pred : label sequences, sparse matrices or dense arrays
        Ground-truth and predicted multilabel targets.
    y_type : str
        Target type; ``'multilabel-sequences'`` selects sequence handling,
        any other value is treated as an indicator format.
    binarize : bool, default False
        If True, sequence inputs are converted to indicator arrays (with
        ``label_binarize``) and handled by the dense helper rather than by
        ``_SequencesMultiLabelHelper``.

    Returns
    -------
    An object exposing ``shape`` and ``count_union``, ``count_intersection``,
    ``count_difference``, ``count_true`` and ``count_pred``.
    """
    if y_type == 'multilabel-sequences':
        labels = unique_labels(y_true, y_pred)
        if binarize:
            y_true = label_binarize(y_true, labels, multilabel=True)
            # Binarize the predictions themselves (previously y_true was
            # binarized twice, silently discarding y_pred).
            y_pred = label_binarize(y_pred, labels, multilabel=True)
        else:
            return _SequencesMultiLabelHelper(y_true, y_pred, labels)

    if sp.issparse(y_true) and sp.issparse(y_pred):
        return _SparseMultiLabelHelper(y_true, y_pred)

    # Mixed sparse/dense input: densify whichever side is sparse.
    if hasattr(y_true, 'toarray'):
        y_true = y_true.toarray()
    if hasattr(y_pred, 'toarray'):
        y_pred = y_pred.toarray()
    return _DenseMultiLabelHelper(y_true, y_pred)
# Intended use in the metric functions, where `mlh` is a helper returned by
# _multilabel_helper(y_true, y_pred, y_type):
# - jaccard_similarity_score: mlh.count_intersection(axis=1) / mlh.count_union(axis=1)
# - accuracy_score: mlh.count_difference(axis=1) == 0
# - precision_recall_fscore_support:
#       tp_sum = mlh.count_intersection(axis=sum_axis)
#       pred_sum = mlh.count_pred(axis=sum_axis)
#       true_sum = mlh.count_true(axis=sum_axis)
# - hamming_loss: mlh.count_difference() / mlh.shape[0] / mlh.shape[1]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment