Skip to content

Instantly share code, notes, and snippets.

@djgagne
Created October 18, 2015 03:47
Show Gist options
  • Save djgagne/64516e3ea268ec31fb34 to your computer and use it in GitHub Desktop.
Save djgagne/64516e3ea268ec31fb34 to your computer and use it in GitHub Desktop.
Performance Diagram plotting code
import numpy as np
import pandas as pd
class DistributedROC(object):
    """
    Store statistics for calculating receiver operating characteristic (ROC) curves and performance diagrams and
    permit easy aggregation of ROC curves from many small datasets.

    One 2x2 contingency table (TP, FP, FN, TN) is kept per probability threshold. A forecast counts as a
    predicted positive when ``forecast >= threshold`` and an observation counts as an actual positive when
    ``observation >= obs_threshold``.

    Parameters
    ----------
    thresholds : numpy.ndarray of floats
        List of probability thresholds in increasing order.
    obs_threshold : float
        Observation value used as the split point for determining positives.
    input_str : str
        String in the format output by the __str__ method so that initialization of the object can be done
        from items in a text file.

    Attributes
    ----------
    contingency_tables : pandas.DataFrame
        One row per threshold with the integer columns "TP", "FP", "FN" and "TN".
    """

    # Column order of the contingency table; update() relies on this order positionally.
    _COLUMNS = ["TP", "FP", "FN", "TN"]

    def __init__(self, thresholds=None, obs_threshold=None, input_str=None):
        # np.asarray returns ndarray inputs unchanged and lets list-like thresholds work too.
        self.thresholds = None if thresholds is None else np.asarray(thresholds)
        self.obs_threshold = obs_threshold
        if self.thresholds is not None:
            self.contingency_tables = pd.DataFrame(np.zeros((self.thresholds.size, 4), dtype=int),
                                                   columns=self._COLUMNS)
        else:
            self.contingency_tables = pd.DataFrame(columns=self._COLUMNS)
        if input_str is not None:
            self.from_str(input_str)

    def update(self, forecasts, observations):
        """
        Update the ROC curve with a set of forecasts and observations

        :param forecasts: 1D array of forecast values
        :param observations: 1D array of observation values.
        :return: None; the counts are accumulated into ``contingency_tables`` in place.
        """
        forecasts = np.asarray(forecasts)
        observations = np.asarray(observations)
        # The observation masks do not depend on the forecast threshold, so compute them once.
        # Both masks are kept explicitly (rather than negating one) so that NaN values, which fail
        # both comparisons, are counted in no cell.
        obs_pos = observations >= self.obs_threshold
        obs_neg = observations < self.obs_threshold
        counts = np.zeros((len(self.thresholds), len(self._COLUMNS)), dtype=int)
        for t, threshold in enumerate(self.thresholds):
            fore_pos = forecasts >= threshold
            fore_neg = forecasts < threshold
            counts[t] = [np.count_nonzero(fore_pos & obs_pos),
                         np.count_nonzero(fore_pos & obs_neg),
                         np.count_nonzero(fore_neg & obs_pos),
                         np.count_nonzero(fore_neg & obs_neg)]
        # Positional element-wise add; replaces the per-row DataFrame.ix access that newer
        # pandas releases no longer provide.
        self.contingency_tables += counts

    def __add__(self, other):
        """
        Add two DistributedROC objects together and combine their contingency table values.

        The result uses this object's thresholds and obs_threshold; the thresholds of ``other`` are not
        checked here.

        :param other: Another DistributedROC object.
        :return: A new DistributedROC object holding the summed contingency tables.
        """
        sum_roc = DistributedROC(self.thresholds, self.obs_threshold)
        sum_roc.contingency_tables = self.contingency_tables + other.contingency_tables
        return sum_roc

    def merge(self, other_roc):
        """
        Ingest the values of another DistributedROC object into this one and update the statistics inplace.

        If the thresholds of the two objects differ, nothing is merged and a message is printed.

        :param other_roc: another DistributedROC object.
        :return: None
        """
        if other_roc.thresholds.size == self.thresholds.size and np.all(other_roc.thresholds == self.thresholds):
            self.contingency_tables += other_roc.contingency_tables
        else:
            print("Input table thresholds do not match.")

    def roc_curve(self):
        """
        Generate a ROC curve from the contingency table by calculating the probability of detection (TP/(TP+FN))
        and the probability of false detection (FP/(FP+TN)).

        Zero denominators are not guarded against.

        :return: A pandas.DataFrame containing the POD, POFD, and the corresponding probability thresholds.
        """
        tables = self.contingency_tables
        pod = tables["TP"].astype(float) / (tables["TP"] + tables["FN"])
        pofd = tables["FP"].astype(float) / (tables["FP"] + tables["TN"])
        return pd.DataFrame({"POD": pod, "POFD": pofd, "Thresholds": self.thresholds},
                            columns=["POD", "POFD", "Thresholds"])

    def performance_curve(self):
        """
        Calculate the Probability of Detection and False Alarm Ratio in order to output a performance diagram.

        POD is TP/(TP+FN) and FAR is FP/(FP+TP). Zero denominators are not guarded against.

        :return: pandas.DataFrame containing POD, FAR, and probability thresholds.
        """
        tables = self.contingency_tables
        # Cast to float so the ratios are computed the same way as in roc_curve().
        pod = tables["TP"].astype(float) / (tables["TP"] + tables["FN"])
        far = tables["FP"].astype(float) / (tables["FP"] + tables["TP"])
        return pd.DataFrame({"POD": pod, "FAR": far, "Thresholds": self.thresholds},
                            columns=["POD", "FAR", "Thresholds"])

    def auc(self):
        """
        Calculate the Area Under the ROC Curve (AUC).

        The area is integrated with the trapezoidal rule over POFD; the absolute value is taken so the
        sign does not depend on the direction in which POFD runs.

        :return: The area under the ROC curve as a float.
        """
        roc_curve = self.roc_curve()
        # Prefer np.trapezoid where it exists and fall back to np.trapz on older NumPy releases.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        return np.abs(trapezoid(roc_curve["POD"].values, x=roc_curve["POFD"].values))

    def __str__(self):
        """
        Output the information within the DistributedROC object to a string.

        Fields are separated by ";" and each field is "Name:space separated values". Thresholds are
        written with two decimal places.

        :return: String that can be parsed again by from_str.
        """
        fields = ["Obs_Threshold:{0:0.2f}".format(self.obs_threshold),
                  "Thresholds:" + " ".join(["{0:0.2f}".format(t) for t in self.thresholds])]
        for col in self.contingency_tables.columns:
            fields.append(col + ":" + " ".join(["{0:d}".format(t) for t in self.contingency_tables[col]]))
        return ";".join(fields)

    def __repr__(self):
        return self.__str__()

    def from_str(self, in_str):
        """
        Read the object string and parse the contingency table values from it.

        Unknown field names and empty fields (e.g. from a trailing ";") are ignored.

        :param in_str: String in the format produced by __str__.
        :return: None; obs_threshold, thresholds and contingency_tables are set in place.
        """
        counts = {}
        for part in in_str.split(";"):
            var_name, _, value = part.partition(":")
            if var_name == "Obs_Threshold":
                self.obs_threshold = float(value)
            elif var_name == "Thresholds":
                self.thresholds = np.array(value.split(), dtype=float)
            elif var_name in self._COLUMNS:
                counts[var_name] = np.array(value.split(), dtype=int)
        if not counts:
            return
        if len(self.contingency_tables.index) == 0:
            # No rows yet (object created without thresholds): build the table in one step instead
            # of growing an empty frame column by column.
            self.contingency_tables = pd.DataFrame(counts, columns=self._COLUMNS)
        else:
            for col, values in counts.items():
                self.contingency_tables[col] = values
import matplotlib.pyplot as plt
import numpy as np
from DistributedROC import DistributedROC
def performance_diagram(roc_objs, obj_labels, colors, markers, filename, figsize=(9, 8), xlabel="Success Ratio (1-FAR)",
                        ylabel="Probability of Detection", ticks=np.arange(0, 1.1, 0.1), dpi=300, csi_cmap="Blues",
                        csi_label="Critical Success Index", title="Performance Diagram",
                        legend_params=dict(loc=4, fontsize=12, framealpha=1, frameon=True)):
    """
    Draws a performance diagram from a set of DistributedROC objects.

    The background shows filled contours of the critical success index and dashed contours of the frequency
    bias, both computed on a regular success-ratio / POD grid. Each object's performance curve is drawn on top
    as 1 - FAR against POD, and the figure is written to ``filename``.

    :param roc_objs: list or array of DistributedROC Objects.
    :param obj_labels: list or array of labels describing each DistributedROC object. "_dist" is removed and
        "-" and "_" are replaced by spaces before the label is shown in the legend.
    :param colors: list of color strings
    :param markers: list of markers.
    :param filename: output filename.
    :param figsize: tuple with size of the figure in inches.
    :param xlabel: Label for the x-axis
    :param ylabel: Label for the y-axis
    :param ticks: Array of ticks used for x and y axes
    :param dpi: DPI of the output image
    :param csi_cmap: Colormap used for the CSI contours
    :param csi_label: Label for the CSI colorbar
    :param title: Title drawn above the diagram
    :param legend_params: dict of keyword arguments passed to the legend call
    :return: None
    """
    plt.figure(figsize=figsize)
    grid_ticks = np.arange(0, 1.01, 0.01)
    sr_g, pod_g = np.meshgrid(grid_ticks, grid_ticks)
    # The grid includes sr = 0 and pod = 0, so these ratios divide by zero along the axes. The resulting
    # inf/nan values are the same as before; only the floating-point warnings are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        bias = pod_g / sr_g
        csi = 1.0 / (1.0 / sr_g + 1.0 / pod_g - 1.0)
    csi_contour = plt.contourf(sr_g, pod_g, csi, np.arange(0.1, 1.1, 0.1), extend="max", cmap=csi_cmap)
    b_contour = plt.contour(sr_g, pod_g, bias, [0.5, 1, 1.5, 2, 4], colors="k", linestyles="dashed")
    # Fixed label positions for the bias contours.
    plt.clabel(b_contour, fmt="%1.1f", manual=[(0.2, 0.9), (0.4, 0.9), (0.6, 0.9), (0.7, 0.7)])
    for r, roc_obj in enumerate(roc_objs):
        perf_data = roc_obj.performance_curve()
        # x-axis is the success ratio, i.e. 1 - FAR.
        plt.plot(1 - perf_data["FAR"], perf_data["POD"], marker=markers[r], color=colors[r],
                 label=obj_labels[r].replace("_dist", "").replace("-", " ").replace("_", " "))
    cbar = plt.colorbar(csi_contour)
    cbar.set_label(csi_label, fontsize=14)
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.xticks(ticks)
    plt.yticks(ticks)
    plt.title(title, fontsize=14, fontweight="bold")
    plt.text(0.48, 0.6, "Frequency Bias", fontdict=dict(fontsize=14, rotation=45))
    plt.legend(**legend_params)
    plt.savefig(filename, dpi=dpi, bbox_inches="tight")
@Jawad1061032
Copy link

Can you please tell me what type of data roc_objs is? If you could show an example, I would be very thankful.

@kresguerra02
Copy link

Same question: what data format is needed to run this code? I would greatly appreciate some examples or instructions on how to use it. Thank you.

@djgagne
Copy link
Author

djgagne commented Jun 26, 2023

@kresguerra02 An updated version of this code can be found at https://github.com/djgagne/hagelslag/blob/master/hagelslag/evaluation/ProbabilityMetrics.py. It includes a code example. The data just needs to be in numpy arrays for the forecasts and observations.

@kresguerra02
Copy link

Hi, I have reviewed the code. My question is: what is the difference between thresholds and obs_threshold?

image

@djgagne
Copy link
Author

djgagne commented Jun 30, 2023

thresholds are the forecast values at which the contingency table is calculated, assuming that forecast values >= threshold are positive cases and forecast values < threshold are negative cases. obs_threshold is used to split a continuous set of observations into binary labels: observations >= obs_threshold are positive and the rest are negative.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment