Skip to content

Instantly share code, notes, and snippets.

@Eeman1113
Created May 22, 2024 20:47
Show Gist options
  • Save Eeman1113/2b53db14eb1bca16bfbc8ae7e5673562 to your computer and use it in GitHub Desktop.
Save Eeman1113/2b53db14eb1bca16bfbc8ae7e5673562 to your computer and use it in GitHub Desktop.
from .module import Module
class MSELoss(Module):
    """Mean squared error loss between predictions and labels.

    The loss is the sum of the squared element-wise differences divided by
    the number of elements reported by ``predictions``.
    """

    def __init__(self):
        """Create the loss object; it stores no configuration of its own."""
        # NOTE(review): the base ``Module`` initializer is not invoked here,
        # exactly as in the original code -- confirm that ``Module`` needs no
        # per-instance setup before adding a ``super().__init__()`` call.

    def forward(self, predictions, labels):
        """Return the mean squared error of ``predictions`` against ``labels``.

        Both arguments must expose ``shape`` and support subtraction and
        ``** 2``; the squared difference must provide ``sum()``.
        ``predictions.numel`` may be either a plain element count or a
        zero-argument method returning that count.

        Raises:
            AssertionError: if the two shapes are not equal.
        """
        # Raise explicitly instead of using an ``assert`` statement so the
        # check still runs when Python is started with -O; the exception type
        # and message are the same as before.
        if not (labels.shape == predictions.shape):
            raise AssertionError(
                "Labels and predictions shape does not match: {} and {}".format(
                    labels.shape, predictions.shape
                )
            )

        # Accept both attribute-style and method-style ``numel``.
        element_count = predictions.numel
        if callable(element_count):
            element_count = element_count()

        squared_error = (predictions - labels) ** 2
        return squared_error.sum() / element_count

    def __call__(self, *inputs):
        """Make the instance callable by forwarding ``inputs`` to ``forward``."""
        return self.forward(*inputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment