Skip to content

Instantly share code, notes, and snippets.

@jstypka
Last active April 26, 2016 08:26
Show Gist options
  • Save jstypka/abbf5df1fe5048ce8b92bb2fecbd1977 to your computer and use it in GitHub Desktop.
Save jstypka/abbf5df1fe5048ce8b92bb2fecbd1977 to your computer and use it in GitHub Desktop.
from itertools import compress
import numpy as np
from sklearn.metrics import auc
def compute_precision_recall_auc(y_true, y_pred, precision_recall_fun):
    """
    Compute the area under a sample-averaged precision-recall curve.

    Every distinct value found in ``y_pred`` (plus an extra threshold of -1)
    is used to binarise the predictions with ``y_pred > threshold``. For each
    threshold ``precision_recall_fun`` is evaluated row by row and the per-row
    results are averaged, which yields one (recall, precision) point of the
    curve. Thresholds are visited from the highest to the lowest.

    :param y_true: ground truth rows; must support ``len`` and integer
        indexing by row
    :param y_pred: prediction scores, one row per entry of ``y_true``;
        anything ``np.asarray`` accepts
    :param precision_recall_fun: callable taking ``(y_true_row,
        y_pred_bin_row)`` and returning a ``(precision, recall)`` pair
    :return: the value of ``auc(recall_means, precision_means)``
    """
    # Coerce once so that the element-wise comparison below also works for
    # nested lists, not only for ndarrays.
    y_pred = np.asarray(y_pred)

    # NOTE(review): the extra -1 threshold presumably assumes non-negative
    # scores, so that the last binarisation marks everything as predicted -
    # confirm against the callers.
    thresholds = np.append(np.unique(y_pred), [-1])
    thresholds.sort()

    n_rows = len(y_true)
    precision = np.zeros(n_rows)  # reused for every binarisation
    recall = np.zeros(n_rows)
    precision_means = np.zeros(len(thresholds))
    recall_means = np.zeros(len(thresholds))

    for i, threshold in enumerate(reversed(thresholds)):
        y_pred_bin = y_pred > threshold
        # `range` (not `xrange`) keeps this runnable on Python 2 and 3 alike.
        for row in range(n_rows):
            precision[row], recall[row] = precision_recall_fun(
                y_true[row], y_pred_bin[row])
        precision_means[i] = np.mean(precision)
        recall_means[i] = np.mean(recall)

    return auc(recall_means, precision_means)
def descendant_precision_recall(y_t, y_p):
    """
    Compute precision and recall for one sample over descendant-expanded labels.

    Both masks select entries of the module-level ``labels`` sequence; each
    selection is then replaced by the keys returned from
    ``ontology.get_descendants_of_labels`` before the two sets are compared.
    ``labels`` and ``ontology`` are read from the enclosing module and are
    not defined in this function.

    :param y_t: mask over ``labels`` marking the true labels
    :param y_p: mask over ``labels`` marking the predicted labels
    :return: ``(precision, recall)`` tuple of floats; a value is 1.0 when the
        set it would be divided by is empty
    """
    def expand(mask):
        # Labels switched on by the mask, expanded through the ontology.
        selected = set(compress(labels, mask))
        return set(ontology.get_descendants_of_labels(selected).keys())

    pred_labels = expand(y_p)
    true_labels = expand(y_t)
    n_common = float(len(true_labels & pred_labels))

    # float() forces true division: on Python 2 a plain int / int would
    # floor every partial match down to 0.
    precision = n_common / len(pred_labels) if pred_labels else 1.0
    recall = n_common / len(true_labels) if true_labels else 1.0
    return precision, recall
class Ontology(object):
    # ...
    # some other stuff here
    # ...
    def get_descendants_of_labels(self, starting_labels):
        """
        Walk the graph downwards (with a BFS) from given nodes towards the leaves.

        Only edges whose ``relation`` attribute is ``SKOS.narrower`` or
        ``SKOS.composite`` are followed. The starting nodes themselves are
        part of the result, at distance 0.

        :param starting_labels: iterable of labels of the nodes to start from
        :return: dict mapping the canonical label of every reached node to
            its smallest recorded distance from a starting node,
            e.g. {'node1': 0, 'node2': 1, ...}
        :raises ValueError: if no URI is found for one of the starting labels
        """
        # Materialise once: the labels are iterated twice below, and a
        # one-shot iterator would silently skip the validation loop.
        starting_labels = list(starting_labels)

        relations = {SKOS.narrower, SKOS.composite}
        parsed = [self.parse_label(lab) for lab in starting_labels]  # constant time
        uris = [self.get_uri_from_label(lab) for lab in parsed]  # constant time

        for lab, uri in zip(starting_labels, uris):
            if not uri:
                # %-formatting keeps this a ValueError even for non-string
                # labels; the text is unchanged for string labels.
                raise ValueError('Label %s not in the ontology graph' % (lab,))

        distances = {}
        queue = deque([(uri, 0) for uri in uris])
        while queue:
            node, distance = queue.popleft()
            node_label = self.get_canonical_label_from_uri(node)  # constant time

            if node_label in distances:
                # Already reached: keep the shorter distance and do not
                # expand the node a second time.
                distances[node_label] = min(distances[node_label], distance)
                continue
            distances[node_label] = distance

            # NOTE(review): out_edges_iter looks like the networkx 1.x graph
            # API - confirm before upgrading the graph library.
            for edge_tuple in self.graph.out_edges_iter(nbunch=[node], data=True):
                target, attributes = edge_tuple[1], edge_tuple[2]
                if attributes.get('relation') in relations:
                    queue.append((target, distance + 1))

        return distances
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment