Skip to content

Instantly share code, notes, and snippets.

@Chris-hughes10
Created December 9, 2021 07:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Chris-hughes10/3cb7b46b24da2dfb5874a39abcfce9d2 to your computer and use it in GitHub Desktop.
Save Chris-hughes10/3cb7b46b24da2dfb5874a39abcfce9d2 to your computer and use it in GitHub Desktop.
Recommender Blog: Torchmetrics recommender callback
import math

import torchmetrics
from pytorch_accelerated.callbacks import TrainerCallback
class RecommenderMetricsCallback(TrainerCallback):
    """Track MSE, MAE and RMSE for a rating-style recommender during evaluation.

    Predictions and targets are accumulated over every evaluation step in a
    ``torchmetrics.MetricCollection``. At the end of each evaluation epoch the
    aggregated values are written to ``trainer.run_history`` under the names
    ``"mae"``, ``"mse"`` and ``"rmse"``. The accumulated metric state is then
    reset so the next epoch starts fresh.
    """

    def __init__(self):
        super().__init__()
        self.metrics = torchmetrics.MetricCollection(
            {
                "mse": torchmetrics.MeanSquaredError(),
                "mae": torchmetrics.MeanAbsoluteError(),
            }
        )

    def _move_to_device(self, trainer):
        """Place the metric state on ``trainer.device``.

        torchmetrics expects its internal state to be on the same device as
        the tensors passed to ``update``.
        """
        self.metrics.to(trainer.device)

    def on_training_run_start(self, trainer, **kwargs):
        """Move the metrics to the trainer's device before a training run."""
        self._move_to_device(trainer)

    def on_evaluation_run_start(self, trainer, **kwargs):
        """Move the metrics to the trainer's device before a standalone evaluation run."""
        self._move_to_device(trainer)

    def on_eval_step_end(self, trainer, batch, batch_output, **kwargs):
        """Accumulate one evaluation step's predictions against its targets.

        Args:
            trainer: The running trainer (unused here).
            batch: The evaluation batch; ``batch[1]`` is used as the targets.
            batch_output: Step output dict; ``batch_output["model_outputs"]``
                is used as the predictions.
        """
        preds = batch_output["model_outputs"]
        # NOTE(review): MSE/MAE require preds and targets of matching shape;
        # confirm the model does not emit an extra trailing dim such as (N, 1)
        # while the targets are (N,).
        self.metrics.update(preds, batch[1])

    def on_eval_epoch_end(self, trainer, **kwargs):
        """Compute the epoch's aggregated metrics, record them, and reset state.

        Args:
            trainer: The running trainer; its ``run_history`` receives the
                ``"mae"``, ``"mse"`` and ``"rmse"`` values.
        """
        metrics = self.metrics.compute()
        mse = metrics["mse"].cpu()
        trainer.run_history.update_metric("mae", metrics["mae"].cpu())
        trainer.run_history.update_metric("mse", mse)
        # RMSE is derived from the aggregated MSE rather than tracked as a
        # separate metric; ``.item()`` converts the 0-dim tensor to a float.
        trainer.run_history.update_metric("rmse", math.sqrt(mse.item()))
        # Clear accumulated state so the next evaluation epoch starts fresh.
        self.metrics.reset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment