Skip to content

Instantly share code, notes, and snippets.

@rsilveira79
Forked from gpleiss/reliability_diagram.py
Created August 6, 2020 16:50
Show Gist options
  • Save rsilveira79/08df50d65c4034932d3fe8ace3753038 to your computer and use it in GitHub Desktop.
Save rsilveira79/08df50d65c4034932d3fe8ace3753038 to your computer and use it in GitHub Desktop.
Reliability diagram code
import torch
import numpy as np
from matplotlib import pyplot as plt
def make_model_diagrams(outputs, labels, n_bins=10):
    """
    Draw a reliability diagram (per-bin accuracy vs. confidence) for a classifier.

    outputs - a torch tensor (size n x num_classes) with the outputs from the final linear layer
        - NOT the softmaxes
    labels - a torch tensor (size n) with the labels
    n_bins - number of equal-width confidence bins over [0, 1]; must be >= 1

    Returns the matplotlib figure containing the diagram.

    Raises ValueError if n_bins is smaller than 1.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer")

    softmaxes = torch.nn.functional.softmax(outputs, 1)
    confidences, predictions = softmaxes.max(1)
    # Cast to float: torch.mean is not defined for bool tensors.
    accuracies = torch.eq(predictions, labels).float()

    # Detach and move to CPU so the values can be converted to numpy for
    # plotting even when `outputs` requires grad or lives on an accelerator.
    confidences = confidences.detach().cpu()
    accuracies = accuracies.detach().cpu()

    # A single axis is enough; subplots(1, 2) would return an array of axes.
    f, rel_ax = plt.subplots(1, 1, figsize=(4, 2.5))

    # Reliability diagram
    bins = torch.linspace(0, 1, n_bins + 1)
    # Compute the width before nudging the last edge so it stays exact.
    width = 1.0 / n_bins
    # Nudge the last edge so that a confidence of exactly 1.0 falls in the final bin.
    bins[-1] = 1.0001

    # Empty bins keep a height of 0 instead of the NaN that mean() of an
    # empty selection would produce.
    bin_corrects = torch.zeros(n_bins)
    bin_scores = torch.zeros(n_bins)
    for i, (bin_lower, bin_upper) in enumerate(zip(bins[:-1], bins[1:])):
        in_bin = confidences.ge(bin_lower) & confidences.lt(bin_upper)
        if in_bin.any():
            bin_corrects[i] = accuracies[in_bin].mean()
            bin_scores[i] = confidences[in_bin].mean()

    left_edges = bins[:-1].numpy()
    corrects = bin_corrects.numpy()
    gap_heights = (bin_scores - bin_corrects).numpy()

    # align='edge' places each bar over [bin_lower, bin_lower + width] so the
    # bars line up with the y = x diagonal drawn below.
    confs = rel_ax.bar(left_edges, corrects, width=width, align='edge')
    gaps = rel_ax.bar(left_edges, gap_heights, bottom=corrects, color=(1, 0.7, 0.7),
                      alpha=0.5, width=width, align='edge', hatch='//', edgecolor='r')
    rel_ax.plot([0, 1], [0, 1], '--', color='gray')
    rel_ax.legend([confs, gaps], ['Outputs', 'Gap'], loc='best', fontsize='small')

    # Clean up
    rel_ax.set_ylabel('Accuracy')
    rel_ax.set_xlabel('Confidence')
    f.tight_layout()
    return f
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment