Skip to content

Instantly share code, notes, and snippets.

@nkthiebaut
Created July 14, 2019 01:34
Show Gist options
  • Save nkthiebaut/bd0baa8e83443220e8640b4fa70239e5 to your computer and use it in GitHub Desktop.
Save nkthiebaut/bd0baa8e83443220e8640b4fa70239e5 to your computer and use it in GitHub Desktop.
Plot top k accuracies
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
# Switch matplotlib to its hand-drawn "xkcd" look before anything is plotted.
plt.xkcd()

# Digits dataset as a (features, labels) pair.
X, y = load_digits(return_X_y=True)

# Hold out part of the data for evaluation (train_test_split defaults).
X_train, X_test, y_train, y_test = train_test_split(X, y)

# fit() returns the estimator itself, so construction and training can chain.
clf = LogisticRegression().fit(X_train, y_train)

# Predicted class probabilities for the held-out samples.
y_prob = clf.predict_proba(X_test)
def top_k_accuracy(y_prob, y_true, k=5):
    """Return the fraction of samples whose true label is among the top ``k`` scores.

    Args:
        y_prob: 2-D NumPy array of scores, one row per sample and one column
            per class. Column indices are compared directly against the
            labels, so labels must be the integer column indices.
        y_true: Sequence of true labels, aligned row-by-row with ``y_prob``.
        k: How many of the highest-scoring classes count as a hit. Must be
            at least 1; values larger than the number of columns simply
            consider every class.

    Returns:
        The hit rate as a float between 0.0 and 1.0.

    Raises:
        ValueError: If ``k`` is smaller than 1.
        ZeroDivisionError: If ``y_true`` is empty.
    """
    # Guard the slice below: ``-0:`` would select *every* column (reporting
    # 100% accuracy) and a negative k would slice from the wrong end.
    if k < 1:
        raise ValueError("k must be a positive integer, got {!r}".format(k))

    # argsort is ascending, so the last k columns hold the k best classes.
    top_k_classes = y_prob.argsort(axis=1)[:, -k:]
    hits = sum(label in top for label, top in zip(y_true, top_k_classes))
    return hits / len(y_true)
def plot_top_k_accuracy(y_prob, y_true, k=5):
    """Plot top-i accuracy, in percent, for every i from 1 to ``k``.

    Draws on the current matplotlib axes via ``plt``; the figure is not
    shown or saved here.

    Args:
        y_prob: Score array passed through to ``top_k_accuracy``.
        y_true: True labels passed through to ``top_k_accuracy``.
        k: Largest cut-off to plot; the x axis runs over 1..k.
    """
    ks = list(range(1, k + 1))
    # Score against the y_true argument (not a module-level label array) so
    # the function works for whatever data the caller passes in.
    top_accuracies = [100 * top_k_accuracy(y_prob, y_true, i) for i in ks]
    plt.plot(ks, top_accuracies)
    plt.title("Top k accuracy")
    plt.xlabel("k")
    plt.ylabel("Accuracy (%)")
# Draw the top-1 through top-5 accuracy curve for the held-out predictions.
plot_top_k_accuracy(y_prob, y_test, k=5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment