@geblanco
Created March 5, 2021 12:44
Pretty-print a classification report from a previously computed classification report dict (useful when saving the dict to disk and printing it later without recalculating everything). Not tested for multi-label data.
# extracted from: https://github.com/scikit-learn/scikit-learn/blob/0fb307bf3/sklearn/metrics/_classification.py#L1825
def classification_report(data_dict):
    """Build a text report showing the main classification metrics.

    Read more in the :ref:`User Guide <classification_report>`.

    Parameters
    ----------
    data_dict : dict
        Classification metrics in dictionary form, as returned by
        :func:`sklearn.metrics.classification_report` when called with
        ``output_dict=True``. The dictionary has the following structure::

            {'label 1': {'precision': 0.5,
                         'recall': 1.0,
                         'f1-score': 0.67,
                         'support': 1},
             'label 2': { ... },
             ...
            }

    Returns
    -------
    report : string
        Text summary of the precision, recall, F1 score for each class.

    Notes
    -----
    The reported averages include macro average (averaging the unweighted
    mean per label), weighted average (averaging the support-weighted mean
    per label), and sample average (only for multilabel classification).
    Micro average (averaging the total true positives, false negatives and
    false positives) is only shown for multi-label or multi-class
    with a subset of classes, because it corresponds to accuracy otherwise.

    See also :func:`precision_recall_fscore_support` for more details
    on averages.

    Note that in binary classification, recall of the positive class
    is also known as "sensitivity"; recall of the negative class is
    "specificity".
    """
    non_label_keys = ["accuracy", "macro avg", "weighted avg"]
    y_type = "binary"
    digits = 2
    target_names = [
        "%s" % key for key in data_dict.keys() if key not in non_label_keys
    ]

    # labelled micro average
    micro_is_accuracy = (y_type == "multiclass" or y_type == "binary")

    headers = ["precision", "recall", "f1-score", "support"]
    p = [data_dict[l][headers[0]] for l in target_names]
    r = [data_dict[l][headers[1]] for l in target_names]
    f1 = [data_dict[l][headers[2]] for l in target_names]
    s = [data_dict[l][headers[3]] for l in target_names]
    rows = zip(target_names, p, r, f1, s)

    if y_type.startswith("multilabel"):
        average_options = ("micro", "macro", "weighted", "samples")
    else:
        average_options = ("micro", "macro", "weighted")

    longest_last_line_heading = "weighted avg"
    name_width = max(len(cn) for cn in target_names)
    width = max(name_width, len(longest_last_line_heading), digits)
    head_fmt = "{:>{width}s} " + " {:>9}" * len(headers)
    report = head_fmt.format("", *headers, width=width)
    report += "\n\n"
    row_fmt = "{:>{width}s} " + " {:>9.{digits}f}" * 3 + " {:>9}\n"
    for row in rows:
        report += row_fmt.format(*row, width=width, digits=digits)
    report += "\n"

    # compute all applicable averages
    for average in average_options:
        if average.startswith("micro") and micro_is_accuracy:
            line_heading = "accuracy"
        else:
            line_heading = average + " avg"

        if line_heading == "accuracy":
            avg = [data_dict[line_heading], sum(s)]
            row_fmt_accuracy = "{:>{width}s} " + \
                " {:>9.{digits}}" * 2 + " {:>9.{digits}f}" + \
                " {:>9}\n"
            report += row_fmt_accuracy.format(line_heading, "", "",
                                              *avg, width=width,
                                              digits=digits)
        else:
            avg = list(data_dict[line_heading].values())
            report += row_fmt.format(line_heading, *avg,
                                     width=width, digits=digits)

    return report
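
A minimal usage sketch of the save-to-disk workflow the description mentions (the labels, the sk_classification_report alias, and the report.json filename are illustrative, not part of the gist): compute the metrics once with scikit-learn's classification_report using output_dict=True, persist the dict as JSON, then reload it later and pretty-print it with the function above. The import alias avoids shadowing this file's classification_report.

import json

from sklearn.metrics import classification_report as sk_classification_report

# toy binary labels, purely illustrative
y_true = [0, 1, 1, 0, 1]
y_pred = [0, 1, 0, 0, 1]

# compute the metrics once and persist them as a dict
report_dict = sk_classification_report(y_true, y_pred, output_dict=True)
with open("report.json", "w") as fout:
    json.dump(report_dict, fout)

# later: reload the dict and pretty-print it without recomputing anything
with open("report.json") as fin:
    loaded = json.load(fin)
print(classification_report(loaded))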