Skip to content

Instantly share code, notes, and snippets.

@primaryobjects
Created July 15, 2024 02:22
Show Gist options
  • Save primaryobjects/d6d7f0dacea993df319c648ec4e6c3a5 to your computer and use it in GitHub Desktop.
Save primaryobjects/d6d7f0dacea993df319c648ec4e6c3a5 to your computer and use it in GitHub Desktop.
Machine learning tutorial with confusion matrix calculations in Python https://replit.com/@primaryobjects/MachineLearning
# make predictions
from pandas import read_csv
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# --- Data ---------------------------------------------------------------
# Iris CSV without a header row, so the column names are supplied here.
url = "https://raw.githubusercontent.com/jbrownlee/Datasets/master/iris.csv"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = read_csv(url, names=names)

# First four columns are the measurements (features), the fifth is the label.
array = dataset.values
X, y = array[:, 0:4], array[:, 4]

# Hold out 20% of the rows for validation; the fixed seed keeps it reproducible.
X_train, X_validation, Y_train, Y_validation = train_test_split(
    X, y, test_size=0.20, random_state=1
)

# --- Model --------------------------------------------------------------
# Fit a support-vector classifier on the training split, then score the
# held-out rows (fit() returns the estimator itself, so the calls chain).
model = SVC(gamma='auto')
predictions = model.fit(X_train, Y_train).predict(X_validation)

# --- Evaluation ---------------------------------------------------------
# Overall accuracy, the raw confusion matrix, and per-class precision/recall/F1.
print(accuracy_score(Y_validation, predictions))
cm = confusion_matrix(Y_validation, predictions)
print(cm)
print(classification_report(Y_validation, predictions))
def calculate_metrics(confusion_matrix):
    """Return per-class TP, TN, FP and FN counts from a square confusion matrix.

    Uses the layout produced by ``sklearn.metrics.confusion_matrix``: entry
    ``[i, j]`` counts samples whose actual class is ``i`` and whose predicted
    class is ``j`` (rows = actual, columns = predicted).

    Args:
        confusion_matrix: square 2-D array-like of counts (numpy array or
            nested lists). The parameter name shadows the sklearn function of
            the same name inside this function only; it is kept so existing
            keyword-argument callers keep working.

    Returns:
        Tuple ``(TP, TN, FP, FN)`` of 1-D numpy arrays, one entry per class.

    Raises:
        ValueError: if the input is not a square 2-D matrix.
    """
    cm = np.asarray(confusion_matrix)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(
            f"confusion matrix must be square and 2-D, got shape {cm.shape}"
        )
    TP = np.diag(cm)
    # Column sum = everything predicted as the class; minus the hits = FP.
    FP = cm.sum(axis=0) - TP
    # Row sum = everything actually in the class; minus the hits = FN.
    FN = cm.sum(axis=1) - TP
    # Whatever remains lies outside both the class's row and its column.
    TN = cm.sum() - (FP + FN + TP)
    return TP, TN, FP, FN
# https://stackoverflow.com/a/43331484
# Explicitly indexed per-class counts for the 3-class Iris confusion matrix.
# In `cm` (from sklearn.metrics.confusion_matrix) rows are the actual classes
# and columns are the predicted classes.
# True Positives = the diagonal of the confusion matrix, top-left to bottom-right.
tps = [cm[0, 0], cm[1, 1], cm[2, 2]]
# True negatives = everything except for the row and column for the class.
tns = [
    cm[1, 1] + cm[1, 2] + cm[2, 1] + cm[2, 2],
    cm[0, 0] + cm[2, 0] + cm[0, 2] + cm[2, 2],
    cm[0, 0] + cm[0, 1] + cm[1, 0] + cm[1, 1],
]
# False positives = everything in the column for the class, except for the true positive value.
fps = [cm[1, 0] + cm[2, 0], cm[0, 1] + cm[2, 1], cm[0, 2] + cm[1, 2]]
# False negatives = everything in the row for the class, except for the true positive value.
fns = [cm[0, 1] + cm[0, 2], cm[1, 0] + cm[1, 2], cm[2, 0] + cm[2, 1]]

# Print each count list on its own line, space-separated.
for counts in (tps, tns, fps, fns):
    for value in counts:
        print(value, end=' ')
    print()

# Cross-check the manual counts against the vectorized computation.
tps1, tns1, fps1, fns1 = calculate_metrics(cm)
print(f'TP: {tps1}, TN: {tns1}, FP: {fps1}, FN: {fns1}')

# Precision, recall and F1 for every class (numbered from 1). A zero
# denominator yields 0.0 instead of a nan-producing division.
for k, (tp, fp, fn) in enumerate(zip(tps, fps, fns), start=1):
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0
    print(f'precision{k}: {precision}, recall{k}: {recall}, f1{k}: {f1}')

# Plot the confusion matrix. Rows of `cm` (actual classes) run down the y axis
# and columns (predicted classes) run across the x axis; fmt='d' keeps the
# annotations as plain integer counts.
sns.heatmap(cm, annot=True, fmt='d')
plt.ylabel('Actual', fontsize=13)
plt.xlabel('Prediction', fontsize=13)
plt.title('Confusion Matrix', fontsize=17)
plt.show()

See also https://machinelearningmastery.com/machine-learning-in-python-step-by-step/

Output

0.9666666666666667
[[11  0  0]
 [ 0 12  1]
 [ 0  0  6]]
                 precision    recall  f1-score   support

    Iris-setosa       1.00      1.00      1.00        11
Iris-versicolor       1.00      0.92      0.96        13
 Iris-virginica       0.86      1.00      0.92         6

       accuracy                           0.97        30
      macro avg       0.95      0.97      0.96        30
   weighted avg       0.97      0.97      0.97        30

11 12 6 
19 17 23 
0 0 1 
0 1 0 
TP: [11 12  6], TN: [19 17 23], FP: [0 0 1], FN: [0 1 0]
precision1: 1.0, recall1: 1.0, f11: 1.0
precision2: 1.0, recall2: 0.9230769230769231, f12: 0.9600000000000001
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment