Skip to content

Instantly share code, notes, and snippets.

@thuliumsystems
Created May 11, 2023 03:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thuliumsystems/1e74da5788b5fe736a31f489a10ba372 to your computer and use it in GitHub Desktop.
Save thuliumsystems/1e74da5788b5fe736a31f489a10ba372 to your computer and use it in GitHub Desktop.
Roc
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import (RandomForestClassifier)
# Labelled sample data for the classifier comparison further down.
# Each row is 10 integer feature values followed by the class label (0 or 1)
# in the last position; the script splits every row into features (all but the
# last element) and label (the last element).
# NOTE(review): several rows occur more than once (for example
# [10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1]), so identical samples can end up on
# both sides of a random train/test split -- confirm the duplicates are
# intended.
training = [
[10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
[21, 16, 12, 26, 6, 0, 0, 0, 0, 0, 1],
[2, 5, 2, 3, 3, 3, 2, 3, 2, 4, 1],
[13, 6, 3, 4, 4, 1, 1, 0, 0, 1, 1],
[2, 3, 4, 2, 5, 2, 1, 3, 2, 1, 1],
[2, 3, 2, 3, 4, 2, 4, 3, 2, 2, 1],
[1, 0, 1, 1, 1, 4, 33, 2, 9, 4, 0],
[6, 4, 3, 4, 2, 6, 2, 4, 1, 4, 1],
[29, 12, 3, 6, 6, 2, 1, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 10, 27, 8, 8, 13, 0],
[21, 5, 11, 12, 9, 0, 0, 0, 0, 0, 1],
[5, 2, 2, 2, 2, 4, 5, 3, 4, 2, 1],
[14, 20, 6, 8, 8, 0, 0, 0, 0, 0, 1],
[2, 3, 2, 1, 1, 3, 7, 9, 3, 3, 0],
[5, 2, 5, 2, 4, 2, 3, 3, 2, 1, 1],
[8, 6, 4, 3, 2, 3, 4, 2, 1, 3, 1],
[5, 2, 0, 0, 1, 7, 5, 9, 5, 2, 0],
[6, 19, 2, 3, 4, 2, 0, 1, 1, 0, 1],
[0, 0, 0, 0, 0, 6, 16, 6, 19, 22, 0],
[17, 5, 28, 23, 4, 0, 0, 0, 0, 0, 1],
[4, 15, 2, 7, 11, 0, 0, 0, 0, 0, 1],
[3, 5, 4, 3, 8, 1, 2, 1, 2, 3, 1],
[1, 1, 0, 1, 0, 7, 7, 5, 28, 6, 0],
[5, 5, 3, 3, 8, 1, 1, 2, 1, 1, 1],
[2, 1, 1, 1, 1, 9, 7, 8, 2, 6, 0],
[7, 5, 5, 2, 4, 2, 2, 1, 1, 1, 1],
[3, 23, 3, 5, 6, 2, 1, 2, 2, 1, 1],
[1, 1, 1, 0, 1, 8, 5, 3, 2, 6, 0],
[11, 9, 6, 1, 3, 0, 2, 1, 1, 2, 1],
[11, 3, 8, 2, 5, 1, 1, 2, 1, 1, 1],
[8, 11, 8, 5, 10, 1, 0, 0, 1, 0, 1],
[16, 19, 6, 8, 2, 0, 0, 0, 0, 1, 1],
[7, 3, 5, 3, 5, 5, 4, 1, 1, 2, 1],
[1, 1, 1, 0, 1, 24, 5, 5, 5, 4, 0],
[1, 0, 1, 0, 1, 12, 8, 10, 7, 4, 0],
[6, 16, 5, 16, 7, 1, 0, 1, 0, 0, 1],
[4, 4, 2, 2, 3, 5, 7, 1, 1, 2, 0],
[4, 8, 3, 3, 5, 2, 1, 1, 1, 3, 1],
[9, 10, 5, 7, 5, 1, 1, 1, 0, 0, 1],
[4, 1, 1, 1, 2, 22, 4, 19, 4, 2, 0],
[0, 0, 0, 0, 1, 13, 9, 10, 4, 11, 0],
[14, 14, 9, 4, 4, 1, 1, 0, 0, 0, 1],
[21, 24, 4, 8, 5, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 13, 6, 6, 34, 4, 0],
[7, 5, 3, 3, 1, 9, 2, 1, 3, 4, 1],
[0, 0, 0, 0, 0, 6, 10, 4, 13, 30, 0],
[5, 8, 10, 9, 3, 1, 0, 1, 0, 1, 1],
[7, 6, 5, 4, 2, 1, 3, 2, 1, 1, 1],
[1, 1, 1, 1, 1, 5, 5, 7, 5, 5, 0],
[1, 1, 2, 1, 2, 7, 4, 4, 3, 2, 0],
[1, 0, 0, 0, 0, 31, 11, 6, 5, 7, 0],
[1, 1, 3, 1, 1, 2, 4, 10, 4, 3, 0],
[1, 2, 1, 1, 3, 5, 8, 9, 3, 7, 0],
[7, 4, 4, 7, 3, 1, 2, 3, 2, 1, 1],
[0, 1, 0, 1, 0, 12, 7, 22, 4, 16, 0],
[1, 4, 3, 1, 1, 4, 5, 11, 2, 7, 0],
[3, 2, 2, 3, 9, 7, 8, 1, 1, 2, 0],
[3, 3, 1, 2, 2, 4, 7, 4, 3, 5, 0],
[6, 7, 29, 3, 6, 0, 1, 0, 1, 1, 1],
[0, 0, 1, 0, 0, 3, 10, 12, 10, 10, 0],
[1, 1, 1, 2, 1, 8, 4, 5, 5, 4, 0],
[0, 1, 0, 1, 0, 4, 22, 7, 4, 3, 0],
[5, 3, 4, 5, 3, 3, 1, 2, 2, 1, 1],
[1, 0, 1, 1, 0, 6, 8, 13, 6, 4, 0],
[2, 1, 1, 0, 0, 18, 6, 38, 5, 8, 0],
[1, 0, 3, 2, 2, 5, 8, 4, 3, 4, 0],
[5, 18, 6, 5, 5, 0, 1, 0, 0, 1, 1],
[0, 1, 2, 1, 3, 4, 6, 12, 2, 13, 0],
[0, 1, 1, 1, 0, 7, 15, 6, 4, 7, 0],
[0, 0, 1, 1, 1, 5, 18, 7, 3, 6, 0],
[6, 7, 5, 9, 4, 1, 1, 0, 1, 1, 1],
[0, 3, 0, 0, 0, 8, 6, 7, 19, 2, 0],
[0, 1, 0, 0, 0, 7, 13, 11, 13, 6, 0],
[0, 0, 1, 0, 0, 7, 5, 19, 8, 10, 0],
[0, 1, 2, 0, 0, 19, 28, 10, 7, 5, 0],
[17, 10, 10, 9, 5, 1, 0, 0, 0, 0, 1],
[20, 9, 7, 8, 14, 0, 0, 0, 0, 0, 1],
[0, 2, 0, 1, 0, 14, 4, 40, 4, 2, 0],
[0, 0, 0, 0, 0, 5, 10, 6, 11, 11, 0],
[2, 3, 3, 5, 3, 4, 5, 3, 2, 1, 0],
[17, 3, 3, 4, 12, 2, 1, 0, 1, 1, 1],
[7, 3, 6, 5, 4, 5, 2, 1, 1, 0, 1],
[3, 3, 1, 1, 1, 4, 1, 6, 2, 3, 0],
[4, 5, 3, 2, 4, 1, 4, 1, 1, 0, 1],
[5, 7, 5, 2, 10, 1, 1, 1, 1, 2, 1],
[1, 1, 0, 0, 1, 7, 5, 13, 10, 6, 0],
[1, 0, 0, 1, 0, 21, 8, 6, 3, 9, 0],
[8, 3, 1, 3, 5, 2, 3, 0, 2, 2, 1],
[4, 1, 0, 1, 2, 20, 6, 2, 5, 3, 0],
[1, 1, 1, 0, 0, 4, 11, 6, 4, 9, 0],
[9, 3, 8, 2, 2, 2, 3, 0, 1, 1, 1],
[7, 5, 6, 4, 2, 4, 3, 1, 1, 1, 1],
[10, 6, 4, 6, 7, 0, 2, 1, 0, 1, 1],
[4, 6, 5, 8, 8, 1, 0, 1, 1, 0, 1],
[3, 4, 4, 3, 7, 2, 3, 1, 1, 2, 1],
[6, 9, 10, 4, 8, 1, 1, 1, 2, 1, 1],
[5, 9, 4, 3, 3, 1, 1, 0, 0, 1, 1],
[9, 1, 2, 2, 2, 5, 5, 3, 3, 4, 0],
[0, 0, 2, 1, 0, 9, 6, 6, 9, 7, 0],
[14, 11, 7, 4, 2, 2, 1, 1, 1, 1, 1],
[10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
[7, 6, 7, 4, 22, 0, 1, 0, 0, 1, 1],
[10, 2, 8, 2, 2, 2, 3, 1, 4, 2, 1],
[11, 32, 15, 8, 5, 0, 0, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 3, 14, 7, 6, 5, 0],
[6, 19, 8, 19, 16, 0, 0, 0, 0, 0, 1],
[2, 3, 1, 0, 1, 13, 6, 5, 3, 2, 0],
[17, 17, 36, 3, 3, 0, 0, 0, 1, 1, 1],
[1, 3, 1, 2, 1, 16, 4, 4, 1, 5, 0],
[9, 3, 21, 7, 3, 1, 0, 1, 1, 0, 1],
[17, 18, 3, 19, 5, 1, 0, 0, 0, 0, 1],
[1, 3, 1, 1, 2, 11, 3, 4, 3, 4, 0],
[6, 3, 6, 3, 2, 1, 2, 3, 2, 3, 1],
[2, 2, 2, 2, 2, 4, 3, 5, 3, 6, 0],
[6, 14, 11, 4, 17, 0, 0, 1, 0, 0, 1],
[10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
[25, 5, 3, 3, 3, 1, 1, 0, 1, 2, 1],
[6, 31, 6, 5, 8, 0, 1, 0, 0, 0, 1],
[4, 3, 3, 2, 1, 3, 4, 2, 2, 4, 1],
[1, 1, 1, 1, 0, 5, 5, 13, 13, 10, 0],
[18, 9, 18, 5, 19, 0, 0, 0, 0, 0, 1],
[0, 1, 2, 0, 0, 4, 14, 5, 3, 6, 0],
[5, 6, 2, 3, 2, 2, 2, 2, 2, 1, 1],
[23, 5, 2, 3, 4, 4, 2, 1, 2, 1, 1],
[7, 5, 8, 8, 7, 2, 1, 1, 0, 0, 1],
[1, 3, 1, 2, 2, 10, 7, 4, 3, 4, 0],
[3, 3, 2, 2, 2, 2, 6, 4, 2, 3, 1],
[10, 10, 21, 3, 4, 1, 1, 0, 1, 0, 1],
[15, 6, 11, 7, 5, 0, 1, 0, 1, 0, 1],
[11, 3, 8, 2, 5, 1, 1, 2, 1, 1, 1],
[11, 5, 11, 3, 3, 1, 2, 1, 1, 1, 1],
[8, 29, 6, 9, 7, 1, 0, 0, 0, 1, 1],
[6, 16, 5, 16, 7, 1, 0, 1, 0, 0, 1],
[3, 2, 8, 2, 1, 3, 3, 3, 1, 1, 1],
[21, 24, 4, 8, 5, 0, 0, 0, 0, 0, 1],
[0, 1, 1, 1, 1, 6, 8, 5, 17, 4, 0],
[19, 16, 12, 4, 4, 0, 0, 1, 1, 0, 1],
[3, 1, 2, 1, 1, 10, 7, 8, 3, 10, 0],
[1, 0, 0, 0, 0, 31, 11, 6, 5, 7, 0],
[16, 7, 18, 8, 2, 0, 1, 1, 0, 0, 1],
[0, 1, 0, 0, 0, 7, 13, 11, 13, 6, 0],
[0, 0, 1, 0, 0, 7, 5, 19, 8, 10, 0],
[17, 10, 10, 9, 5, 1, 0, 0, 0, 0, 1],
[2, 1, 0, 1, 1, 14, 10, 6, 5, 5, 0],
[12, 7, 4, 5, 6, 0, 1, 1, 0, 0, 1],
[1, 1, 1, 0, 0, 8, 13, 9, 9, 6, 0],
[1, 2, 1, 1, 1, 18, 7, 9, 3, 1, 0],
[7, 2, 3, 4, 2, 3, 5, 3, 3, 2, 1],
[2, 3, 3, 5, 3, 4, 5, 3, 2, 1, 0],
[10, 5, 7, 3, 3, 1, 2, 1, 1, 1, 1],
[8, 9, 5, 5, 2, 2, 2, 3, 0, 0, 1],
[4, 3, 10, 3, 1, 1, 3, 3, 4, 1, 1],
[7, 3, 8, 3, 8, 1, 1, 1, 2, 1, 1],
[1, 3, 4, 1, 3, 2, 2, 5, 2, 1, 1],
[1, 1, 0, 0, 1, 7, 5, 13, 10, 6, 0],
[6, 5, 4, 2, 2, 1, 4, 2, 1, 1, 1],
[1, 3, 1, 2, 1, 16, 4, 4, 1, 5, 0],
[4, 6, 4, 7, 3, 2, 1, 2, 2, 1, 1],
[4, 4, 16, 5, 16, 3, 0, 1, 0, 1, 1],
[5, 10, 14, 16, 3, 0, 0, 0, 0, 0, 1],
[23, 5, 2, 3, 4, 4, 2, 1, 2, 1, 1],
[2, 1, 2, 1, 3, 2, 5, 5, 3, 3, 0],
[5, 5, 8, 4, 2, 3, 3, 1, 1, 2, 1],
[3, 5, 3, 2, 4, 2, 3, 3, 0, 2, 1],
[6, 19, 4, 4, 8, 0, 1, 0, 0, 0, 1],
[1, 2, 0, 1, 1, 19, 3, 6, 9, 3, 0],
[2, 2, 2, 0, 1, 4, 10, 9, 3, 5, 0],
[2, 0, 0, 0, 0, 13, 6, 23, 9, 11, 0],
[5, 34, 14, 14, 4, 0, 0, 1, 1, 0, 1],
[19, 2, 5, 2, 21, 0, 6, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 21, 18, 20, 6, 9, 0],
[13, 16, 22, 5, 13, 0, 0, 0, 0, 0, 1],
[6, 1, 1, 1, 0, 5, 12, 6, 5, 2, 0],
[5, 12, 4, 10, 4, 2, 1, 1, 1, 1, 1],
[17, 18, 6, 8, 10, 1, 0, 0, 0, 0, 1],
[9, 15, 10, 2, 4, 2, 1, 1, 2, 3, 1],
[1, 0, 0, 0, 0, 16, 19, 10, 4, 6, 0],
[6, 5, 3, 1, 3, 3, 2, 2, 1, 1, 1],
[6, 4, 5, 8, 8, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 1, 4, 5, 22, 3, 4, 0],
[2, 2, 1, 0, 1, 7, 11, 6, 2, 3, 0],
[6, 10, 3, 2, 5, 1, 4, 0, 1, 0, 1],
[2, 2, 5, 2, 1, 4, 4, 3, 4, 2, 1],
[2, 1, 0, 1, 1, 5, 12, 7, 7, 4, 0],
[8, 22, 4, 4, 3, 1, 1, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 26, 15, 15, 7, 9, 0],
[10, 5, 3, 3, 3, 2, 2, 3, 1, 1, 1],
[0, 0, 0, 0, 0, 21, 9, 21, 13, 2, 0],
[13, 4, 7, 7, 5, 1, 1, 1, 1, 0, 1],
[12, 25, 6, 2, 6, 5, 1, 0, 0, 1, 1],
[17, 4, 1, 1, 1, 4, 4, 8, 1, 4, 0],
[0, 0, 0, 0, 0, 10, 22, 19, 5, 9, 0],
[2, 5, 4, 2, 2, 9, 2, 2, 1, 1, 0],
]
# Compare three classifiers on the `training` table by plotting their ROC
# curves and reporting the area under each curve.

# One seed shared by the split and the randomized estimators so that repeated
# runs produce the same curves.
RANDOM_STATE = 0

# Set to True to add the logistic-regression curve to the figure; it is left
# off by default so the figure shows only the two tree-based models.
PLOT_LOGISTIC_REGRESSION = False

# Each row of `training` is the feature values followed by the class label.
X = [row[:-1] for row in training]
y = [row[-1] for row in training]

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=RANDOM_STATE
)

# Standardize the features for logistic regression. The scaler is fitted on
# the training features only, so neither the label column nor any test-set
# statistics leak into the transformation. The tree-based models are trained
# on the raw features.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

rfc = RandomForestClassifier(
    max_depth=10, min_samples_split=2, n_estimators=50, random_state=RANDOM_STATE
)
rfc.fit(X_train, y_train)

dtc = DecisionTreeClassifier(
    max_depth=5, min_samples_leaf=1, min_samples_split=5, random_state=RANDOM_STATE
)
dtc.fit(X_train, y_train)

lr = LogisticRegression()
lr.fit(X_train_scaled, y_train)

# Predict the positive-class (label 1) probability of the test data for each
# classifier
y_prob_rfc = rfc.predict_proba(X_test)[:, 1]
y_prob_dtc = dtc.predict_proba(X_test)[:, 1]
y_prob_lr = lr.predict_proba(X_test_scaled)[:, 1]

# Calculate the false positive rate, true positive rate, and thresholds for
# each classifier
fpr_rfc, tpr_rfc, thresholds_rfc = roc_curve(y_test, y_prob_rfc)
fpr_dtc, tpr_dtc, thresholds_dtc = roc_curve(y_test, y_prob_dtc)
fpr_lr, tpr_lr, thresholds_lr = roc_curve(y_test, y_prob_lr)

# Calculate the area under the ROC curve for each classifier
roc_auc_rfc = roc_auc_score(y_test, y_prob_rfc)
roc_auc_dtc = roc_auc_score(y_test, y_prob_dtc)
roc_auc_lr = roc_auc_score(y_test, y_prob_lr)

# Plot the ROC curves
plt.plot(
    fpr_rfc, tpr_rfc, color="green", label=f"Random Forest (area = {roc_auc_rfc:.2f})"
)
plt.plot(
    fpr_dtc, tpr_dtc, color="blue", label=f"Decision Tree (area = {roc_auc_dtc:.2f})"
)
if PLOT_LOGISTIC_REGRESSION:
    plt.plot(
        fpr_lr,
        tpr_lr,
        color="yellow",
        label=f"Logistic Regression (area = {roc_auc_lr:.2f})",
    )
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC) Curves")
plt.legend(loc="lower right")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment