Skip to content

Instantly share code, notes, and snippets.

@Chris-hughes10
Created July 16, 2021 09:56
Show Gist options
  • Save Chris-hughes10/0cf9030076346b2b38db2279ccd91e90 to your computer and use it in GitHub Desktop.
Save Chris-hughes10/0cf9030076346b2b38db2279ccd91e90 to your computer and use it in GitHub Desktop.
from objdetecteval.metrics.coco_metrics import get_coco_stats
@patch
def validation_epoch_end(self: EfficientDetModel, outputs):
    """Summarise a validation epoch: mean loss plus COCO detection stats.

    Nothing is logged here; the results are only returned.

    Args:
        outputs: the per-step validation outputs collected for the epoch.
            Each entry must support ``entry["loss"]``, and the loss tensors
            must share a shape so they can be stacked. The whole sequence is
            also handed to ``self.aggregate_prediction_outputs``.

    Returns:
        A dict with two keys:
        - ``"val_loss"``: the mean of the stacked per-step losses.
        - ``"metrics"``: the ``'All'`` entry of the result of
          ``get_coco_stats``.
    """
    step_losses = [step_output["loss"] for step_output in outputs]
    val_loss = torch.stack(step_losses).mean()

    aggregated = self.aggregate_prediction_outputs(outputs)
    (
        pred_labels,
        pred_image_ids,
        pred_boxes,
        pred_scores,
        targets,
    ) = aggregated

    # Convert ground-truth fields to plain Python values, detached from any
    # autograd graph.
    gt_image_ids = [tgt["image_id"].detach().item() for tgt in targets]
    # Column order [1, 0, 3, 2] swaps the coordinates within each pair. The
    # intent is to convert boxes to xyxy for evaluation, which presumes the
    # targets store boxes as yxyx (TODO confirm against the dataset adaptor).
    xyxy_column_order = [1, 0, 3, 2]
    gt_boxes = [
        tgt["bboxes"].detach()[:, xyxy_column_order].tolist() for tgt in targets
    ]
    gt_labels = [tgt["labels"].detach().tolist() for tgt in targets]

    coco_stats = get_coco_stats(
        prediction_image_ids=pred_image_ids,
        predicted_class_confidences=pred_scores,
        predicted_bboxes=pred_boxes,
        predicted_class_labels=pred_labels,
        target_image_ids=gt_image_ids,
        target_bboxes=gt_boxes,
        target_class_labels=gt_labels,
    )
    return {"val_loss": val_loss, "metrics": coco_stats["All"]}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment