Skip to content

Instantly share code, notes, and snippets.

@malteos
Last active May 23, 2016 17:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save malteos/a5196181a41e53b5faa4bdb4c9ab2a5c to your computer and use it in GitHub Desktop.
Save malteos/a5196181a41e53b5faa4bdb4c9ab2a5c to your computer and use it in GitHub Desktop.
AUC + GAMMAIDX
def auc(y_true, y_val, plot=False):
    """
    Computes the AUC (area under the receiver operating characteristic curve).

    For example, y_val could be the output of a learning algorithm (binary
    logistic regression, a linear classifier's decision function, ...).
    Every distinct value of y_val is used as a decision threshold
    ("predict +1 if y_val >= threshold"). The resulting ROC curve always runs
    from (0, 0), where nothing is predicted positive, to (1, 1), where
    everything is predicted positive. The area under it is integrated with
    the trapezoidal rule, so a group of tied scores contributes one straight
    segment.

    :param y_true: true labels in {-1,+1}; samples with any other label are ignored
    :param y_val: predicted value where lower values tend to correspond to label -1 and higher values to label +1
    :param plot: whether to plot the ROC curve or not (uses the module-level ``plt``)
    :return: returns the AUC as a float in [0, 1]
    :raises ValueError: if y_true and y_val have different lengths, or if y_true
        does not contain at least one +1 and one -1 label (AUC is undefined then)
    """
    # asarray lets plain lists work; with a list, `y_true == 1` would be a single bool.
    y_true = np.asarray(y_true).ravel()
    y_val = np.asarray(y_val, dtype=float).ravel()
    if y_true.shape != y_val.shape:
        raise ValueError("y_true and y_val must have the same length, got %d and %d"
                         % (y_true.size, y_val.size))

    is_pos = y_true == 1
    is_neg = y_true == -1
    sum_cond_pos = np.count_nonzero(is_pos)
    sum_cond_neg = np.count_nonzero(is_neg)
    if sum_cond_pos == 0 or sum_cond_neg == 0:
        raise ValueError("AUC is undefined unless y_true contains both +1 and -1 labels")

    # Walk the samples from the highest score to the lowest. After the first i
    # samples, the cumulative counts are the TP / FP counts for the threshold
    # "y_val >= score_i". A stable sort keeps the walk deterministic.
    order = np.argsort(-y_val, kind="mergesort")
    scores = y_val[order]
    true_pos = np.cumsum(is_pos[order])
    false_pos = np.cumsum(is_neg[order])

    # Tied samples are classified together, so only the last sample of each run
    # of equal scores yields a valid ROC point.
    run_end = np.append(scores[1:] != scores[:-1], True)

    # Prepend (0, 0). The last point is always (1, 1) because the final run end
    # includes every sample. Points stay paired and ordered by threshold.
    fpr = np.concatenate(([0.0], false_pos[run_end] / float(sum_cond_neg)))
    tpr = np.concatenate(([0.0], true_pos[run_end] / float(sum_cond_pos)))

    if plot:
        plt.scatter(fpr, tpr)
        plt.show()

    # Trapezoidal rule written out explicitly: np.trapz is deprecated in
    # NumPy 2.0 (renamed np.trapezoid), and this form works on every version.
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1])) / 2.0)
def gammaidx(X, k):
    """
    Computes the Gamma Index: a point's average distance to its k nearest neighbors.

    Distances are Euclidean over all d coordinates. Each point's zero distance
    to itself is excluded, so the k neighbors are the k closest *other* points.
    An exact duplicate of a point counts as a neighbor at distance 0.

    :param X: (n, d) matrix with n data points of dimension d
    :param k: number of neighbors, 1 <= k <= n - 1
    :return: the vector with gamma indices for each datapoint, length n
    :raises ValueError: if X is not 2-D or k is outside [1, n - 1]
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError("X must be an (n, d) matrix, got shape %r" % (X.shape,))
    n = X.shape[0]
    if not 1 <= k <= n - 1:
        raise ValueError("k must satisfy 1 <= k <= n - 1 = %d, got %r" % (n - 1, k))

    # Pairwise differences via broadcasting: (n, 1, d) - (1, n, d) -> (n, n, d).
    diff = X[:, np.newaxis, :] - X[np.newaxis, :, :]
    # Euclidean norm over every coordinate. The original hypot used only the first two.
    D = np.sqrt(np.sum(diff * diff, axis=2))  # (n, n)
    # Column 0 of each sorted row is a zero: the point's distance to itself.
    # Average the next k columns.
    return np.mean(np.sort(D, axis=1)[:, 1:k + 1], axis=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment