Created February 26, 2021 19:58
-
-
Save maximsch2/ed6527ae27922e01915d3e4be7e108c6 to your computer and use it in GitHub Desktop.
Binned Recall@Precision metric
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple | |
import torch | |
from pytorch_lightning.metrics import Metric | |
class BinnedRecallAtFixedPrecision(Metric):
    """Per-class recall at a fixed minimum precision, using binned thresholds.

    Scores are compared against ``num_thresholds`` evenly spaced thresholds
    ``0, 1/num_thresholds, ..., (num_thresholds - 1)/num_thresholds``. True/false
    positive/negative counts are accumulated per (class, threshold) bin.
    ``compute`` reports, for every class:

    * the highest recall among thresholds whose precision is at least
      ``min_precision``;
    * the smallest such threshold.

    The result is a tensor rather than a single number because the metric
    applies to multi-label inputs: there is one value per class.
    """

    def __init__(
        self,
        num_classes: int,
        min_precision: float,
        num_thresholds: int = 100,
        compute_on_step: bool = False,  # accepted for API compatibility; ignored
        **kwargs,
    ):
        """
        Args:
            num_classes: number of classes, i.e. columns of the inputs.
            min_precision: precision a threshold must reach to qualify.
            num_thresholds: number of evenly spaced thresholds in ``[0, 1)``.
            compute_on_step: ignored. Per-step computation is not supported,
                so the base class is always built with ``compute_on_step=False``.
            **kwargs: forwarded unchanged to the ``Metric`` base class.
        """
        # TODO: reject compute_on_step=True once Lightning's testing code
        # stops passing it.
        super().__init__(compute_on_step=False, **kwargs)
        self.num_classes = num_classes
        self.min_precision = min_precision
        self.num_thresholds = num_thresholds

        # Give arange a float dtype explicitly so the division is true
        # division regardless of how the installed torch version treats
        # integer-tensor ``/``.
        thresholds = (
            torch.arange(num_thresholds, dtype=torch.get_default_dtype())
            / num_thresholds
        )
        self.register_buffer("thresholds", thresholds)

        # One (num_classes, num_thresholds) counter per confusion-matrix cell,
        # reduced with "sum". TNs is accumulated but not used by compute().
        for name in ["TPs", "FPs", "TNs", "FNs"]:
            self.add_state(
                name=name,
                default=torch.zeros(num_classes, num_thresholds, dtype=torch.long),
                dist_reduce_fx="sum",
            )

    def update(self, predictions: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate per-class, per-threshold confusion counts for one batch.

        Args:
            predictions: ``(n_samples, n_classes)`` scores. A score counts as a
                positive prediction at threshold ``t`` when it is ``>= t``.
            targets: ``(n_samples, n_classes)`` labels of any dtype; non-zero
                entries are treated as positives.
        """
        # The 2-D reshape keeps the original shape contract (1-D inputs fail
        # here as before). The trailing singleton axis broadcasts against the
        # (num_thresholds,) buffer, so no repeated copies are materialised.
        scores = predictions.reshape(predictions.shape[0], predictions.shape[1], 1)
        # .bool() makes & / ~ logical operations for int and float labels too;
        # bitwise ~ on integers is not logical negation.
        actual_pos = targets.reshape(targets.shape[0], targets.shape[1], 1).bool()
        actual_neg = ~actual_pos

        predicted_pos = scores >= self.thresholds  # (n_samples, n_classes, T)
        predicted_neg = ~predicted_pos

        self.TPs += (actual_pos & predicted_pos).sum(dim=0)
        self.FPs += (actual_neg & predicted_pos).sum(dim=0)
        self.TNs += (actual_neg & predicted_neg).sum(dim=0)
        self.FNs += (actual_pos & predicted_neg).sum(dim=0)

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the best recall at ``min_precision`` and its threshold, per class.

        Returns:
            A pair of ``(num_classes,)`` float tensors:

            * recalls: the maximum recall over thresholds whose precision is
              ``>= min_precision``, or ``0.0`` if no threshold qualifies.
            * thresholds: the smallest threshold whose precision is
              ``>= min_precision``, or ``1e6`` if no threshold qualifies.
        """
        # The 1e-6 terms avoid 0/0 for empty bins; such bins yield 0.0.
        precisions = self.TPs / (self.TPs + self.FPs + 1e-6)
        recalls = self.TPs / (self.TPs + self.FNs + 1e-6)
        qualifies = precisions >= self.min_precision

        # Build fill values with *_like so they share the device and dtype of
        # the accumulated state. A bare scalar tensor would be created on the
        # CPU regardless of where the metric lives.
        recalls_at_p = (
            torch.where(qualifies, recalls, torch.zeros_like(recalls))
            .max(dim=1)
            .values
        )
        # The (num_thresholds,) buffer broadcasts across the class axis.
        thresholds_at_p = (
            torch.where(
                qualifies, self.thresholds, torch.full_like(self.thresholds, 1e6)
            )
            .min(dim=1)
            .values
        )
        return recalls_at_p, thresholds_at_p
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.