@maximsch2
Created March 16, 2021 19:23
from typing import Tuple, Union, List

import torch
from pytorch_lightning.metrics import Metric
from pytorch_lightning.metrics.utils import METRIC_EPS, to_onehot


# From Lightning's AveragePrecision metric
def _average_precision_compute(
    precision: torch.Tensor,
    recall: torch.Tensor,
    num_classes: int,
) -> Union[List[torch.Tensor], torch.Tensor]:
    # Return the step function integral
    # The following works because the last entry of precision is
    # guaranteed to be 1, as returned by precision_recall_curve
    if num_classes == 1:
        recall = recall[0, :]
        precision = precision[0, :]
        return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])

    res = []
    for p, r in zip(precision, recall):
        res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1]))
    return res
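
# Worked example (an illustrative addition, values made up): for a single-class
# curve with recall = [1.0, 0.5, 0.0] and precision = [1.0, 0.5, 1.0] (recall
# decreasing, precision ending at 1, as precision_recall_curve orders them),
# the negated step integral above gives
#     -((0.5 - 1.0) * 1.0 + (0.0 - 0.5) * 0.5) = 0.75,
# i.e. each recall step is weighted by the precision at its left edge.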


# pyre-fixme[13]: Attribute `FNs` is never initialized....
# pyre-fixme[13]: Attribute `FPs` is never initialized....
# pyre-fixme[13]: Attribute `TPs` is never initialized....
# pyre-fixme[13]: Attribute `thresholds` is never initialized....
class BinnedPrecisionRecallCurve(Metric):
    """Computes a binned precision-recall curve over `num_thresholds` evenly
    spaced thresholds.

    Counts are accumulated per class, so the metric applies to multi-label
    inputs; compute() returns (precisions, recalls, thresholds).
    """

    TPs: torch.Tensor
    FPs: torch.Tensor
    FNs: torch.Tensor
    thresholds: torch.Tensor

    def __init__(
        self,
        num_classes: int,
        num_thresholds: int = 100,
        compute_on_step: bool = False,  # will ignore this
        **kwargs
    ):
        # TODO: enable assert after changing testing code in Lightning
        # assert not compute_on_step, "computation on each step is not supported"
        super().__init__(compute_on_step=False, **kwargs)
        self.num_classes = num_classes
        self.num_thresholds = num_thresholds
        thresholds = torch.arange(num_thresholds) / num_thresholds
        self.register_buffer("thresholds", thresholds)

        for name in ("TPs", "FPs", "FNs"):
            self.add_state(
                name=name,
                default=torch.zeros(num_classes, num_thresholds, dtype=torch.long),
                dist_reduce_fx="sum",
            )

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """
        Args:
            preds: (n_samples, n_classes) tensor
            targets: (n_samples, n_classes) tensor
        """
        # binary case
        if len(preds.shape) == len(targets.shape) == 1:
            preds = preds.reshape(-1, 1)
            targets = targets.reshape(-1, 1)

        if len(preds.shape) == len(targets.shape) + 1:
            targets = to_onehot(targets, num_classes=self.num_classes)

        targets = targets == 1
        for i in range(self.num_thresholds):
            predictions = preds >= self.thresholds[i]
            self.TPs[:, i] += (targets & predictions).sum(dim=0)
            self.FPs[:, i] += ((~targets) & (predictions)).sum(dim=0)
            self.FNs[:, i] += ((targets) & (~predictions)).sum(dim=0)
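
    # Note (an illustrative addition): with num_thresholds=100 the thresholds
    # are 0.00, 0.01, ..., 0.99, so a prediction score of 0.37 counts as
    # positive for the first 38 thresholds and negative for the rest. Each
    # update() call only increments the per-bin TP/FP/FN counters, so batches
    # can be streamed without keeping every prediction in memory.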

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns (precisions, recalls, thresholds); the first two are float
        tensors of shape (num_classes, num_thresholds)."""
        precisions = self.TPs / (self.TPs + self.FPs + METRIC_EPS)
        recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS)
        return (precisions, recalls, self.thresholds)


class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
    # pyre-fixme[15]: `compute` overrides method defined in `BinnedPrecisionRecallCur...
    def compute(self) -> Union[List[torch.Tensor], torch.Tensor]:
        precisions, recalls, thresholds = super().compute()
        return _average_precision_compute(precisions, recalls, self.num_classes)
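
# Hypothetical usage sketch (not part of the original gist):
#     ap = BinnedAveragePrecision(num_classes=3)
#     ap.update(preds, targets)    # accumulate over as many batches as needed
#     per_class_ap = ap.compute()  # list of 3 scalar tensors, one AP per class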


class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
    """Returns a tensor of recalls for a fixed precision threshold.
    It is a tensor instead of a single number, because it applies to multi-label inputs.
    """

    def __init__(
        self,
        num_classes: int,
        min_precision: float,
        num_thresholds: int = 100,
        compute_on_step: bool = False,  # will ignore this
        **kwargs
    ):
        # TODO: enable assert after changing testing code in Lightning
        # assert not compute_on_step, "computation on each step is not supported"
        super().__init__(
            num_classes=num_classes,
            num_thresholds=num_thresholds,
            compute_on_step=compute_on_step,
            **kwargs
        )
        self.min_precision = min_precision

    # pyre-fixme[15]: `compute` overrides method defined in `BinnedPrecisionRecallCur...
    def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns (recalls, thresholds), each a float tensor of size n_classes"""
        precisions, recalls, thresholds = super().compute()
        thresholds = thresholds.repeat(self.num_classes, 1)
        condition = precisions >= self.min_precision

        # Highest recall among bins meeting the precision target ...
        recalls_at_p = (
            torch.where(
                condition, recalls, torch.scalar_tensor(0.0, device=condition.device)
            )
            .max(dim=1)
            .values
        )
        # ... and the lowest qualifying threshold (recall is non-increasing in
        # the threshold, so this is the threshold attaining that recall);
        # 1e6 marks classes where no bin qualifies.
        thresholds_at_p = (
            torch.where(
                condition, thresholds, torch.scalar_tensor(1e6, device=condition.device)
            )
            .min(dim=1)
            .values
        )
        return (recalls_at_p, thresholds_at_p)
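

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # gist); assumes a pytorch_lightning version where the imports above
    # resolve. All shapes and values are made up for the example.
    torch.manual_seed(0)
    preds = torch.rand(1000, 3)                   # (n_samples, n_classes) scores
    targets = (torch.rand(1000, 3) > 0.5).long()  # multi-label 0/1 targets

    metric = BinnedRecallAtFixedPrecision(num_classes=3, min_precision=0.4)
    metric.update(preds, targets)                 # can be called once per batch
    recalls_at_p, thresholds_at_p = metric.compute()
    print("recall at precision >= 0.4, per class:", recalls_at_p)
    print("threshold achieving it, per class:", thresholds_at_p)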