Skip to content

Instantly share code, notes, and snippets.

@bepuca
Last active June 14, 2022 15:19
Show Gist options
  • Save bepuca/efb2b2ad08ad53a882f68883e736e126 to your computer and use it in GitHub Desktop.
Save bepuca/efb2b2ad08ad53a882f68883e736e126 to your computer and use it in GitHub Desktop.
object detection error article - Mean Average Precision wrapper exposing needed interface.
# Copyright © 2022 Bernat Puig Camps
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision
class MyMeanAveragePrecision:
    """Wrapper for the torchmetrics MeanAveragePrecision exposing API we need.

    Instances are callables that take a dataframe of ground-truth boxes and a
    dataframe of predicted boxes and return a single float. The wrapped
    metric is reset after every call, so no state is carried between calls.
    """

    # Columns holding the box corners, in the order they are stacked into
    # the "boxes" tensors.
    _BOX_COLUMNS = ["xmin", "ymin", "xmax", "ymax"]

    def __init__(self, foreground_threshold: float):
        """Create the wrapped metric on the GPU when available, else the CPU.

        Args:
            foreground_threshold: the only IoU threshold handed to
                ``MeanAveragePrecision(iou_thresholds=[...])``.
        """
        self.device = (
            torch.device("cuda:0")
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        self.map = MeanAveragePrecision(
            iou_thresholds=[foreground_threshold]
        ).to(self.device)

    def __call__(self, targets_df, preds_df) -> float:
        """Return the "map" entry of the metric for the given dataframes.

        Args:
            targets_df: dataframe with columns ``image_id``, ``xmin``,
                ``ymin``, ``xmax``, ``ymax`` and ``label_id``.
            preds_df: dataframe with the same columns plus ``score``.

        Returns:
            The ``"map"`` value computed by the wrapped metric, as a float.
        """
        targets, preds = self._format_inputs(targets_df, preds_df)
        try:
            self.map.update(preds=preds, target=targets)
            return self.map.compute()["map"].item()
        finally:
            # Always clear the accumulated state, even when update/compute
            # raises, so a failed call cannot contaminate the next one.
            self.map.reset()

    def _to_tensor(self, values, dtype):
        """Convert a pandas Series/DataFrame into a tensor on ``self.device``.

        The cast is done by pandas/numpy first so that object-dtype data
        (e.g. the columns of an empty ``pd.DataFrame(columns=[...])``) is
        accepted; ``torch.as_tensor`` rejects object arrays.

        Args:
            values: pandas Series or DataFrame to convert.
            dtype: numpy dtype name, e.g. ``"float32"`` or ``"int64"``; the
                resulting tensor has the matching torch dtype.
        """
        return torch.as_tensor(
            values.to_numpy(dtype=dtype), device=self.device
        )

    def _format_inputs(self, targets_df, preds_df):
        """Turn the dataframes into the per-image dict lists the metric takes.

        Every image id present in either dataframe yields one entry in each
        returned list, at the same position. An image missing from one of
        the dataframes gets empty tensors there (boxes of shape ``(0, 4)``).

        Returns:
            ``(targets, preds)``: two equally long lists of dicts. Target
            dicts hold ``"boxes"`` (float32) and ``"labels"`` (int64); pred
            dicts additionally hold ``"scores"`` (float32).
        """
        image_ids = set(targets_df["image_id"]) | set(preds_df["image_id"])

        # Split each dataframe once instead of filtering the whole frame
        # again for every image id.
        targets_by_image = {
            image_id: group
            for image_id, group in targets_df.groupby("image_id", sort=False)
        }
        preds_by_image = {
            image_id: group
            for image_id, group in preds_df.groupby("image_id", sort=False)
        }
        # Zero-row slices keep the column layout for images that appear in
        # only one of the two dataframes.
        no_targets = targets_df.iloc[:0]
        no_preds = preds_df.iloc[:0]

        targets, preds = [], []
        for image_id in image_ids:
            im_targets_df = targets_by_image.get(image_id, no_targets)
            im_preds_df = preds_by_image.get(image_id, no_preds)
            targets.append(
                {
                    "boxes": self._to_tensor(
                        im_targets_df[self._BOX_COLUMNS], "float32"
                    ),
                    "labels": self._to_tensor(
                        im_targets_df["label_id"], "int64"
                    ),
                }
            )
            preds.append(
                {
                    "boxes": self._to_tensor(
                        im_preds_df[self._BOX_COLUMNS], "float32"
                    ),
                    "labels": self._to_tensor(
                        im_preds_df["label_id"], "int64"
                    ),
                    "scores": self._to_tensor(
                        im_preds_df["score"], "float32"
                    ),
                }
            )
        return targets, preds
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment