@VineethKumar7 · Created February 21, 2024 14:32
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def calculate_mse_loss(
    model: nn.Module,
    dataloader: DataLoader,
    max_samples: Optional[int] = None,
    device: torch.device = torch.device("cpu"),
) -> float:
    """
    Calculate the Mean Squared Error (MSE) loss between the predictions of a model
    and the actual target values in a dataset.

    Parameters:
    - model (nn.Module): The neural network model to evaluate.
    - dataloader (DataLoader): The DataLoader providing the dataset for evaluation.
    - max_samples (int, optional): The maximum number of samples to evaluate.
      If None, evaluates the entire dataset.
    - device (torch.device): The device (CPU or GPU) to perform the calculations on.

    Returns:
    - float: The average MSE loss over the evaluated samples.
    """
    model.to(device)
    model.eval()  # disable dropout and use running stats in batch norm during evaluation
    total_loss = 0.0
    total_samples = 0
    mse_loss_function = nn.MSELoss(reduction="sum")

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            # If this batch would exceed max_samples, evaluate only the remaining
            # samples so the averaged loss matches the number of samples counted.
            if max_samples is not None and total_samples + inputs.size(0) > max_samples:
                remaining = max_samples - total_samples
                inputs, targets = inputs[:remaining], targets[:remaining]

            predictions = model(inputs)
            loss = mse_loss_function(predictions, targets)
            total_loss += loss.item()
            total_samples += inputs.size(0)

            if max_samples is not None and total_samples >= max_samples:
                break

    average_loss = total_loss / total_samples
    return average_loss
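
A minimal usage sketch. The regression model, tensor shapes, and batch size below are illustrative assumptions, not part of the original gist:

if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    # Hypothetical regression setup: 1000 samples, 10 features, 1 target value each.
    features = torch.randn(1000, 10)
    targets = torch.randn(1000, 1)
    dataset = TensorDataset(features, targets)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    # A simple linear model standing in for any nn.Module.
    model = nn.Linear(10, 1)

    # Evaluate on at most 500 samples.
    avg_loss = calculate_mse_loss(model, loader, max_samples=500)
    print(f"Average MSE loss: {avg_loss:.4f}")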