Confusion chart metrics from matrix and class values
import numpy as np


def calc_class_metrics(class_labels, full_confusion_matrix):
    """
    Calculate per-class metrics.

    Parameters
    ----------
    class_labels : list
        List of all class labels, ordered by their indices.
    full_confusion_matrix : CxC matrix
        Confusion matrix for all classes (rows are actual, columns are predicted).

    Returns
    -------
    metrics: list[dict]
        List of class metrics in one-vs-all form.
    """
    total_count = np.sum(full_confusion_matrix)
    actual_counts = np.sum(full_confusion_matrix, axis=1)
    predicted_counts = np.sum(full_confusion_matrix, axis=0)
    class_count = len(class_labels)
    # We iterate over the class indices many times; for whatever reason,
    # iterating directly over range() is slower than over a pre-built list.
    class_idx_list = list(range(class_count))
    result = []
    for class_idx in class_idx_list:
        # Fill basic values.
        # Convert all numpy types to native Python types via .item()
        class_metrics = {
            'actual_count': actual_counts[class_idx].item(),
            'predicted_count': predicted_counts[class_idx].item(),
            'class_name': class_labels[class_idx],
        }
        # Base values for the one-vs-all 2x2 confusion matrix and metrics
        true_negatives = (
            total_count
            - actual_counts[class_idx]
            - predicted_counts[class_idx]
            + full_confusion_matrix[class_idx, class_idx]
        ).item()
        false_positives = (
            predicted_counts[class_idx] - full_confusion_matrix[class_idx, class_idx]
        ).item()
        false_negatives = (
            actual_counts[class_idx] - full_confusion_matrix[class_idx, class_idx]
        ).item()
        true_positives = (full_confusion_matrix[class_idx, class_idx]).item()
        class_metrics['confusion_matrix_one_vs_all'] = [
            [true_negatives, false_positives],
            [false_negatives, true_positives],
        ]
        if true_positives + false_positives > 0:
            class_metrics['precision'] = true_positives / (true_positives + false_positives)
        else:
            class_metrics['precision'] = 0.0
        if true_positives + false_negatives > 0:
            class_metrics['recall'] = true_positives / (true_positives + false_negatives)
        else:
            class_metrics['recall'] = 0.0
        if class_metrics['precision'] + class_metrics['recall'] > 0:
            class_metrics['f1'] = (
                2
                * class_metrics['precision']
                * class_metrics['recall']
                / (class_metrics['precision'] + class_metrics['recall'])
            )
        else:
            class_metrics['f1'] = 0.0
        # Row of the full matrix normalized by the actual count: where the
        # samples that actually belong to this class were predicted to go.
        if actual_counts[class_idx] > 0:
            actual_percentages = full_confusion_matrix[class_idx, :] / actual_counts[class_idx]
        else:
            actual_percentages = np.zeros(shape=(class_count,), dtype=float)
        class_metrics['was_actual_percentages'] = [
            {
                'other_class_name': class_labels[other_cls_label],
                'percentage': actual_percentages[other_cls_label].item(),
            }
            for other_cls_label in class_idx_list
        ]
        # Column of the full matrix normalized by the predicted count: which
        # actual classes ended up being predicted as this class.
        if predicted_counts[class_idx] > 0:
            predicted_percentages = (
                full_confusion_matrix[:, class_idx] / predicted_counts[class_idx]
            )
        else:
            predicted_percentages = np.zeros(shape=(class_count,), dtype=float)
        class_metrics['was_predicted_percentages'] = [
            {
                'other_class_name': class_labels[other_cls_label],
                'percentage': predicted_percentages[other_cls_label].item(),
            }
            for other_cls_label in class_idx_list
        ]
        result.append(class_metrics)
    return result
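
A minimal usage sketch, not part of the gist: the labels, true/predicted values, and the use of sklearn to build the matrix below are illustrative assumptions, shown only because sklearn's confusion_matrix follows the same rows-actual / columns-predicted convention the function expects.

# Illustrative usage with made-up data (assumes scikit-learn is available).
from sklearn.metrics import confusion_matrix

labels = ['cat', 'dog', 'bird']
y_true = ['cat', 'cat', 'dog', 'dog', 'dog', 'bird', 'bird']
y_pred = ['cat', 'dog', 'dog', 'dog', 'cat', 'bird', 'dog']

# Rows are actual classes, columns are predicted classes, in `labels` order.
cm = confusion_matrix(y_true, y_pred, labels=labels)

metrics = calc_class_metrics(labels, cm)
for m in metrics:
    print(m['class_name'], round(m['precision'], 3), round(m['recall'], 3), round(m['f1'], 3))
# For this toy data, 'cat' gets precision 0.5 and recall 0.5
# (1 true positive out of 2 predicted and 2 actual 'cat' samples).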