Skip to content

Instantly share code, notes, and snippets.

@maximsch2
Created February 26, 2021 19:58
Show Gist options
  • Save maximsch2/ed6527ae27922e01915d3e4be7e108c6 to your computer and use it in GitHub Desktop.
Save maximsch2/ed6527ae27922e01915d3e4be7e108c6 to your computer and use it in GitHub Desktop.
Binned Recall@Precision metric
from typing import Tuple
import torch
from pytorch_lightning.metrics import Metric
class BinnedRecallAtFixedPrecision(Metric):
    """Per-class best recall subject to a minimum precision, on a fixed threshold grid.

    Confusion counts are accumulated at ``num_thresholds`` evenly spaced thresholds
    ``0, 1/T, 2/T, ..., (T-1)/T`` (T = ``num_thresholds``). A sample counts as a
    positive prediction for a class at threshold ``t`` when its score is ``>= t``.

    ``compute`` returns a tensor of recalls (one per class) rather than a single
    number because the metric applies to multi-label inputs.
    """

    def __init__(
        self,
        num_classes: int,
        min_precision: float,
        num_thresholds: int = 100,
        compute_on_step: bool = False,  # will ignore this
        **kwargs
    ):
        # TODO: enable assert after changing testing code in Lightning
        # assert not compute_on_step, "computation on each step is not supported"
        super().__init__(compute_on_step=False, **kwargs)
        self.num_classes = num_classes
        self.min_precision = min_precision
        self.num_thresholds = num_thresholds
        # Build the grid in floating point explicitly so the division can never
        # fall back to integer (floor) semantics and collapse every threshold to 0.
        thresholds = (
            torch.arange(num_thresholds, dtype=torch.float32) / num_thresholds
        )
        # A buffer (not a plain attribute) so it moves with the module's device.
        self.register_buffer("thresholds", thresholds)
        # (num_classes, num_thresholds) counters, summed across processes.
        # TNs is not used by compute(); it is kept so the state layout is unchanged.
        for name in ["TPs", "FPs", "TNs", "FNs"]:
            self.add_state(
                name=name,
                default=torch.zeros(num_classes, num_thresholds, dtype=torch.long),
                dist_reduce_fx="sum",
            )

    def update(self, predictions: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate TP/FP/TN/FN counts for a batch at every threshold.

        Args:
            predictions: (n_samples, n_classes) tensor of scores.
            targets: (n_samples, n_classes) tensor of labels; any non-zero value
                is treated as a positive.
        """
        # Broadcast scores against the threshold grid instead of materialising
        # num_thresholds copies of the inputs: (N, C, 1) vs (T,) -> (N, C, T).
        predicted = predictions.unsqueeze(-1) >= self.thresholds
        # Cast to bool so ``~``/``&`` are logical ops for any label dtype
        # (bitwise NOT on integers or floats would not be).
        actual = targets.bool().unsqueeze(-1)  # (N, C, 1), broadcasts over T
        not_actual = ~actual
        not_predicted = ~predicted
        self.TPs += (actual & predicted).sum(dim=0)
        self.FPs += (not_actual & predicted).sum(dim=0)
        self.TNs += (not_actual & not_predicted).sum(dim=0)
        self.FNs += (actual & not_predicted).sum(dim=0)

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-class recall and threshold at the required precision.

        Returns:
            A ``(recalls_at_p, thresholds_at_p)`` tuple of float tensors, each of
            size ``num_classes``:

            * ``recalls_at_p``: the largest recall among thresholds whose
              precision is ``>= min_precision``. It is 0.0 when no threshold
              qualifies.
            * ``thresholds_at_p``: the smallest such threshold. It is 1e6 when no
              threshold qualifies.
        """
        # The 1e-6 epsilon avoids 0/0; a threshold with no predicted positives
        # therefore gets precision 0 and cannot satisfy min_precision.
        precisions = self.TPs / (self.TPs + self.FPs + 1e-6)
        recalls = self.TPs / (self.TPs + self.FNs + 1e-6)
        thresholds = self.thresholds.repeat(self.num_classes, 1)
        condition = precisions >= self.min_precision
        # Fill values are created with *_like so they share device/dtype with the
        # states; a bare CPU scalar tensor can break torch.where on accelerators.
        recalls_at_p = (
            torch.where(condition, recalls, torch.zeros_like(recalls))
            .max(dim=1)
            .values
        )
        thresholds_at_p = (
            torch.where(condition, thresholds, torch.full_like(thresholds, 1e6))
            .min(dim=1)
            .values
        )
        return recalls_at_p, thresholds_at_p
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment