Skip to content

Instantly share code, notes, and snippets.

@vinhkhuc
Last active August 29, 2015 14:16
Show Gist options
  • Save vinhkhuc/671cc1cbb9ccc16b456b to your computer and use it in GitHub Desktop.
Save vinhkhuc/671cc1cbb9ccc16b456b to your computer and use it in GitHub Desktop.
"""
Code to replicate Ron Kohavi's cross-validation experiment on the Iris data set.
"""
from sklearn import datasets, svm
from sklearn.cross_validation import cross_val_score, KFold, LeavePOut
import matplotlib.pyplot as plt
# Destination image for the accuracy chart.
output_file = "cross-validation-experiment-iris.png"

# Iris feature matrix and class labels.
iris = datasets.load_iris()
X, y = iris.data, iris.target
n = X.shape[0]  # number of rows in X

# Linear SVM evaluated under every cross-validation scheme below.
clf = svm.LinearSVC()

# Positive k -> K-fold with k folds; non-positive k -> leave-p-out with p = |k|.
folds = [2, 5, 10, 20, -2, -1]
n_folds = len(folds)  # number of schemes, i.e. number of points on the chart

# Mean accuracy (in percent) per scheme, filled in the same order as `folds`.
accuracies = []
# Run K-folds: evaluate the classifier once per entry of `folds`.
# A positive k selects K-fold with k folds; a non-positive k selects
# leave-p-out with p = |k| (so -1 is leave-one-out, -2 is leave-two-out).
for k in folds:
    cv = KFold(n, n_folds=k) if k > 0 else LeavePOut(n, p=abs(k))
    scores = cross_val_score(clf, X, y, cv=cv)
    # Record the mean score over all splits as a percentage.
    accuracies.append(100 * scores.mean())
    print("K = %d, accuracy: %0.2f%%" % (k, accuracies[-1]))
# Print chart: one point per scheme at x = 1..n_folds, mean accuracy on y.
plt.figure()
# Use 5% for the error bars (a fixed value, not derived from the scores).
plt.errorbar(range(1, n_folds + 1), accuracies, yerr=[5] * n_folds)
# Ticks at 0 and n_folds + 1 carry empty labels; the ticks in between are
# labelled with the corresponding entries of `folds`.
plt.xticks(range(0, n_folds + 2), [''] + [str(k) for k in folds] + [''])
plt.yticks(range(30, 110, 10))
plt.title("K-fold Cross-validation")
plt.xlabel("Folds")
plt.ylabel("% Acc")
plt.savefig(output_file)
print("Saved the chart into " + output_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment