@shamilnabiyev
Last active November 2, 2022 10:43
Confusion matrix for cross-validation. Per-fold results are saved as MLflow metrics and artifacts, then retrieved later to evaluate the model across all folds.
import mlflow
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import StratifiedKFold


SEED = 42
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("cross-validation-average")

# Data preparation
dataset = load_breast_cancer()
X = dataset.data
y = dataset.target

print(X.shape, y.shape)
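# Expected output: (569, 30) (569,)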

# Model training: stratified 5-fold cross-validation keeps the class
# distribution of the full dataset in every train/test split
skf = StratifiedKFold(n_splits=5, random_state=SEED, shuffle=True)

for train_index, test_index in skf.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    
    with mlflow.start_run():
        # Enable sklearn autologging; exclusive=False lets the autologged
        # params/metrics land in this manually started run alongside the
        # manual mlflow.log_* calls below
        mlflow.sklearn.autolog(exclusive=False)
        model = RandomForestClassifier(random_state=SEED)
        model.fit(X_train, y_train)
        
        y_pred = model.predict(X_test)
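        # Persist the fold's raw labels and predictions as a JSON artifact
        # so the confusion matrix can be recomputed later from the server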
        mlflow.log_dict({"y_test": [int(x) for x in y_test],
                         "y_pred": [int(x) for x in y_pred]
                        }, "ytest-ypred.json")
        
        test_acc = accuracy_score(y_test, y_pred)
        mlflow.log_metric("test_accuracy", test_acc)
        print("test_accuracy:", test_acc)

        test_precision, test_recall, test_f1, _ = precision_recall_fscore_support(
            y_test, 
            y_pred, 
            average='binary'
        )
        mlflow.log_metric("test_precision", test_precision)
        mlflow.log_metric("test_recall", test_recall)
        mlflow.log_metric("test_f1_score", test_f1)
        
        tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
        mlflow.log_metric("tn", tn)
        mlflow.log_metric("fp", fp)
        mlflow.log_metric("fn", fn)
        mlflow.log_metric("tp", tp)
        
        # normalize="true" divides each row by its true-label count, so the
        # normalized entries are comparable across folds of different sizes
        tn, fp, fn, tp = confusion_matrix(y_test, y_pred, normalize="true").ravel()
        mlflow.log_metric("tn_normalized", tn)
        mlflow.log_metric("fp_normalized", fp)
        mlflow.log_metric("fn_normalized", fn)
        mlflow.log_metric("tp_normalized", tp)
        
        mlflow.sklearn.autolog(disable=True)
  
# Results evaluation: average the row-normalized confusion matrix entries
# over all cross-validation folds. Look the experiment up by name instead
# of hardcoding its numeric id, so this matches the set_experiment call above
experiment = mlflow.get_experiment_by_name("cross-validation-average")
runs = mlflow.search_runs(experiment_ids=[experiment.experiment_id])
columns = ['metrics.tn_normalized', 'metrics.fp_normalized',
           'metrics.fn_normalized', 'metrics.tp_normalized']
mean_confusion_matrix = runs[columns].mean()
print(mean_confusion_matrix)
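
# The per-fold JSON artifacts logged above can also be pulled back from the
# tracking server to build one pooled confusion matrix over all test folds.
# A minimal sketch, assuming MLflow >= 2.0 (where mlflow.artifacts.load_dict
# is available) and the `runs` DataFrame from mlflow.search_runs above:
import mlflow.artifacts

all_true, all_pred = [], []
for run_id in runs["run_id"]:
    logged = mlflow.artifacts.load_dict(f"runs:/{run_id}/ytest-ypred.json")
    all_true.extend(logged["y_test"])
    all_pred.extend(logged["y_pred"])

# Pooling the raw predictions weights every sample equally, whereas the
# metric average above weights every fold equally
pooled_cm = confusion_matrix(all_true, all_pred)
print(pooled_cm)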