Skip to content

Instantly share code, notes, and snippets.

@rhiever
Created March 8, 2016 18:47
Show Gist options
  • Save rhiever/bf270295b5ffac042aa3 to your computer and use it in GitHub Desktop.
Save rhiever/bf270295b5ffac042aa3 to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import recall_score

try:
    # scikit-learn >= 0.18 ships train_test_split in model_selection;
    # the cross_validation module was removed in 0.20.
    from sklearn.model_selection import train_test_split
except ImportError:
    # Fallback for legacy scikit-learn releases (< 0.18).
    from sklearn.cross_validation import train_test_split
# Load the 10-class handwritten digits dataset. `n_class` is passed by keyword
# because recent scikit-learn releases reject it as a positional argument.
digits = load_digits(n_class=10)
features, labels = digits['data'], digits['target']

# Hold out 25% of the samples for evaluation. No random_state is set, so the
# split (and therefore the printed scores) differs from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    features, labels, train_size=0.75, test_size=0.25)

# Fit a 100-tree random forest; n_jobs=-1 asks scikit-learn to use all cores.
clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
clf.fit(X_train, y_train)
def balanced_accuracy(result):
    """Return the balanced accuracy (mean per-class recall) of predictions.

    For every distinct label in the ``'class'`` column, the fraction of rows
    with that true label whose ``'guess'`` matches is computed; the result is
    the unweighted mean of those per-class fractions, so every class counts
    equally regardless of how many samples it has.

    Args:
        result: pandas DataFrame with a ``'class'`` column (true labels) and
            a ``'guess'`` column (predicted labels), one row per sample.

    Returns:
        The mean per-class recall as a float in [0, 1], or NaN when
        ``result`` has no rows.

    Raises:
        KeyError: if the ``'class'`` or ``'guess'`` column is missing.
    """
    true_labels = result['class']
    is_correct = result['guess'] == true_labels

    # No rows means no classes to average over; return NaN explicitly rather
    # than relying on the mean of an empty collection (which also warns).
    if len(result) == 0:
        return float('nan')

    # Group the correctness mask by true label: the mean of each boolean
    # group is that class's recall. Grouping by the raw values (not the
    # Series) avoids any index alignment between the two operands.
    per_class_recall = is_correct.groupby(true_labels.values).mean()
    return per_class_recall.mean()
# Score the held-out samples with the fitted forest.
predictions = clf.predict(X_test)

# Reference value: scikit-learn's macro-averaged recall.
macro_recall = recall_score(y_test, predictions, average='macro')
print('Macro-averaged recall:\t', macro_recall)

# Same quantity from the hand-written helper, which expects the true labels
# in a 'class' column and the predictions in a 'guess' column.
data = pd.DataFrame({'class': y_test, 'guess': predictions})
print('Balanced accuracy:\t', balanced_accuracy(data))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment