Created
March 16, 2021 19:23
-
-
Save maximsch2/2b55bab6deba629a5686258cb8152e53 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple, Union, List | |
import torch | |
from pytorch_lightning.metrics import Metric | |
from pytorch_lightning.metrics.utils import METRIC_EPS, to_onehot | |
# From Lightning's AveragePrecision metric | |
def _average_precision_compute( | |
precision: torch.Tensor, | |
recall: torch.Tensor, | |
num_classes: int, | |
) -> Union[List[torch.Tensor], torch.Tensor]: | |
# Return the step function integral | |
# The following works because the last entry of precision is | |
# guaranteed to be 1, as returned by precision_recall_curve | |
if num_classes == 1: | |
recall = recall[0, :] | |
precision = precision[0, :] | |
return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) | |
res = [] | |
for p, r in zip(precision, recall): | |
res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) | |
return res | |
# pyre-fixme[13]: Attribute `FNs` is never initialized....
# pyre-fixme[13]: Attribute `FPs` is never initialized....
# pyre-fixme[13]: Attribute `TPs` is never initialized....
# pyre-fixme[13]: Attribute `thresholds` is never initialized....
class BinnedPrecisionRecallCurve(Metric):
    """Precision-recall curve evaluated on a fixed grid of thresholds.

    Instead of storing every prediction, the metric keeps running counts of
    true positives, false positives and false negatives for each class at each
    of ``num_thresholds`` evenly spaced thresholds ``0, 1/n, ..., (n-1)/n``.
    Its state therefore has shape ``(num_classes, num_thresholds)`` regardless
    of how many samples are seen. The count states are registered with
    ``dist_reduce_fx="sum"``.

    ``compute`` returns ``(precisions, recalls, thresholds)``; see its docstring.
    """

    TPs: torch.Tensor
    FPs: torch.Tensor
    FNs: torch.Tensor
    thresholds: torch.Tensor

    def __init__(
        self,
        num_classes: int,
        num_thresholds: int = 100,
        compute_on_step: bool = False,  # will ignore this
        **kwargs
    ):
        # TODO: enable assert after changing testing code in Lightning
        # assert not compute_on_step, "computation on each step is not supported"
        super().__init__(compute_on_step=False, **kwargs)
        self.num_classes = num_classes
        self.num_thresholds = num_thresholds
        # Build the grid in floating point explicitly: dividing an integer
        # tensor with `/` was floor division / an error on older torch versions.
        thresholds = (
            torch.arange(num_thresholds, dtype=torch.get_default_dtype())
            / num_thresholds
        )
        self.register_buffer("thresholds", thresholds)
        for name in ("TPs", "FPs", "FNs"):
            self.add_state(
                name=name,
                default=torch.zeros(num_classes, num_thresholds, dtype=torch.long),
                dist_reduce_fx="sum",
            )

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate per-threshold TP/FP/FN counts for one batch.

        Args:
            preds: ``(n_samples, n_classes)`` scores, or ``(n_samples,)`` in the
                binary case. Presumably probabilities in ``[0, 1]``, since the
                threshold grid only covers that range (TODO confirm with callers).
            targets: ``(n_samples, n_classes)`` 0/1 labels; ``(n_samples,)`` 0/1
                labels in the binary case; or ``(n_samples,)`` class indices when
                ``preds`` has one more dimension (converted with ``to_onehot``).
        """
        # Binary case: a flat input is treated as a single-class column.
        # NOTE(review): with num_classes > 1 this (n, 1) input broadcasts the
        # same counts into every class row instead of failing.
        if preds.ndim == targets.ndim == 1:
            preds = preds.reshape(-1, 1)
            targets = targets.reshape(-1, 1)
        # Class-index targets: expand to one-hot columns matching `preds`.
        if preds.ndim == targets.ndim + 1:
            targets = to_onehot(targets, num_classes=self.num_classes)
        targets = targets == 1
        # One pass per threshold keeps peak memory at (n_samples, n_classes)
        # instead of materialising an (n_samples, n_classes, n_thresholds) mask.
        for i in range(self.num_thresholds):
            predictions = preds >= self.thresholds[i]
            self.TPs[:, i] += (targets & predictions).sum(dim=0)
            self.FPs[:, i] += (~targets & predictions).sum(dim=0)
            self.FNs[:, i] += (targets & ~predictions).sum(dim=0)

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(precisions, recalls, thresholds)``.

        ``precisions`` and ``recalls`` are float tensors of shape
        ``(num_classes, num_thresholds)``. Column ``i`` corresponds to
        ``thresholds[i]``, and thresholds are ascending, so recall is
        non-increasing along each row. ``METRIC_EPS`` keeps the denominators
        non-zero: a threshold with no predicted positives gets precision 0, and
        a class with no positive targets gets recall 0.
        """
        precisions = self.TPs / (self.TPs + self.FPs + METRIC_EPS)
        recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS)
        return (precisions, recalls, self.thresholds)
class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
    """Average precision derived from the binned precision-recall curve.

    ``compute`` yields a 0-d tensor when ``num_classes == 1`` and otherwise a
    list holding one 0-d tensor per class (the layout produced by
    ``_average_precision_compute``).
    """

    # pyre-fixme[15]: `compute` overrides method defined in `BinnedPrecisionRecallCur...
    def compute(self) -> Union[List[torch.Tensor], torch.Tensor]:
        """Integrate the accumulated precision-recall curve into average precision."""
        # The threshold grid is not needed for the integral.
        curve_precision, curve_recall, _ = super().compute()
        return _average_precision_compute(
            precision=curve_precision,
            recall=curve_recall,
            num_classes=self.num_classes,
        )
class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
    """Best recall reachable while keeping precision at or above a floor.

    For each class, the binned precision-recall curve is restricted to the
    thresholds whose precision is ``>= min_precision``. The metric reports
    the highest recall among them together with the smallest such threshold.
    """

    def __init__(
        self,
        num_classes: int,
        min_precision: float,
        num_thresholds: int = 100,
        compute_on_step: bool = False,  # will ignore this
        **kwargs
    ):
        # TODO: enable assert after changing testing code in Lightning
        # assert not compute_on_step, "computation on each step is not supported"
        super().__init__(
            num_classes=num_classes,
            num_thresholds=num_thresholds,
            compute_on_step=compute_on_step,
            **kwargs
        )
        self.min_precision = min_precision

    # pyre-fixme[15]: `compute` overrides method defined in `BinnedPrecisionRecallCur...
    def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(recalls_at_p, thresholds_at_p)``, each of shape ``(num_classes,)``.

        ``recalls_at_p[c]`` is the maximum recall of class ``c`` over the
        thresholds whose precision is at least ``min_precision``, or 0.0 if no
        threshold qualifies. ``thresholds_at_p[c]`` is the smallest qualifying
        threshold, or the sentinel 1e6 if none qualifies. Thresholds ascend,
        so recall is non-increasing. The smallest qualifying threshold
        therefore also attains the maximum recall.
        """
        precisions, recalls, thresholds = super().compute()
        # One copy of the threshold grid per class, aligned with `precisions`.
        thresholds = thresholds.repeat(self.num_classes, 1)
        condition = precisions >= self.min_precision
        # Fallbacks for entries that miss the precision floor.
        no_recall = torch.scalar_tensor(0.0, device=condition.device)
        no_threshold = torch.scalar_tensor(1e6, device=condition.device)
        recalls_at_p = torch.where(condition, recalls, no_recall).max(dim=1).values
        thresholds_at_p = (
            torch.where(condition, thresholds, no_threshold).min(dim=1).values
        )
        return (recalls_at_p, thresholds_at_p)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment