Skip to content

Instantly share code, notes, and snippets.

@khalido
Created December 9, 2019 05:10
Show Gist options
  • Save khalido/b4f9a19b4b3eb17e922f764e69e20d7e to your computer and use it in GitHub Desktop.
Save khalido/b4f9a19b4b3eb17e922f764e69e20d7e to your computer and use it in GitHub Desktop.
[sklearn] Useful scikit-learn snippets: training a classifier and plotting a confusion matrix #sklearn
from sklearn.model_selection import train_test_split
from sklearn import metrics # for evaluation
from sklearn.ensemble import RandomForestClassifier
# Train a random forest and evaluate it with a confusion matrix.
# NOTE(review): assumes x_train / y_train (features and class labels) are
# defined earlier in the notebook/script -- they are not created here.
# Hold out part of the data for evaluation: predicting on the same rows the
# forest was fit on gives an almost perfect, misleading confusion matrix.
x_fit, x_eval, y_fit, y_eval = train_test_split(
    x_train, y_train, test_size=0.25, random_state=42
)

# initiate a classifier and train it on the fitting split only
rf = RandomForestClassifier(n_jobs=-1)  # n_jobs=-1: use all available cores
rf.fit(x_fit, y_fit)

# predict on the held-out rows the model has never seen
y_predict = rf.predict(x_eval)

# make confusion matrix using sklearn
# (entry [i, j] counts samples of true class i predicted as class j)
cm = metrics.confusion_matrix(y_eval, y_predict)
# Plot a confusion matrix as an annotated heatmap; pass normalize=True to
# show per-true-class fractions instead of raw counts.
def plot_cm(cm, labels=None, normalize=False):
    """Plot a confusion matrix as an annotated seaborn heatmap.

    Args:
        cm: Square confusion matrix, e.g. from ``metrics.confusion_matrix``;
            rows are true labels and columns are predicted labels.
        labels: Class names for the rows/columns, in the same order as ``cm``.
            Defaults to ``None``, which shows integer positions 0..n-1.
        normalize: If True, divide each row by its total so every cell shows
            the fraction of that true class assigned to each predicted class.
            Rows with no samples are shown as 0.

    Displays the figure with ``plt.show()`` and returns None.
    """
    df = pd.DataFrame(cm, columns=labels, index=labels)

    if normalize:
        # Divide positionally (ndarray, not Series) so duplicate labels can't
        # break index alignment; an all-zero row gives 0/0 = NaN -> show 0.
        row_totals = df.sum(axis=1).to_numpy()
        df = df.div(row_totals, axis=0).fillna(0)
        fmt = ".2f"
        title = "Normalized Confusion Matrix"
    else:
        fmt = "d"  # raw integer counts
        title = "Confusion Matrix"

    _, ax = plt.subplots(figsize=(10, 6))
    ax.set_title(title)
    sns.heatmap(df, annot=True, fmt=fmt, annot_kws={"size": 8}, ax=ax, cmap="YlGnBu")
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment