@alex-vasilchenko-md
Last active June 20, 2022 17:15
A small composition wrapper for PyTorch loss functions. It computes a weighted sum of several loss functions applied to the same output and target.
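from collections.abc import Sequence

from torch import nn


class WeightedSumCompositionLoss(nn.Module):
    def __init__(self, loss_funcs: Sequence[nn.Module], weights: Sequence[float]):
        super().__init__()
        if len(loss_funcs) != len(weights):
            raise ValueError("loss_funcs and weights must have the same length")
        # nn.ModuleList registers the losses as submodules, so .to(device), .eval()
        # and state_dict() also reach losses that own parameters (e.g. a VGG loss).
        self.loss_funcs = nn.ModuleList(loss_funcs)
        self.weights = list(weights)

    def forward(self, output, target):
        loss = 0.0
        for loss_fn, weight in zip(self.loss_funcs, self.weights):
            # Call the module itself (not .forward) so registered hooks still run.
            loss = loss + weight * loss_fn(output, target)
        return loss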
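The class computes L = sum_i w_i * L_i(output, target). For cases where an nn.Module is not needed, the same weighted sum can be written as a plain function. The sketch below assumes a hypothetical helper `weighted_sum_loss` (not part of this gist) and checks it against a hand-written sum; MSELoss, L1Loss and the weights 1.0 and 0.5 are arbitrary choices for the check.

```python
import torch
from torch import nn


def weighted_sum_loss(loss_funcs, weights, output, target):
    """Hypothetical helper: returns sum_i weights[i] * loss_funcs[i](output, target)."""
    if len(loss_funcs) != len(weights):
        raise ValueError("loss_funcs and weights must have the same length")
    total = 0.0
    for loss_fn, weight in zip(loss_funcs, weights):
        total = total + weight * loss_fn(output, target)
    return total


torch.manual_seed(0)
output = torch.randn(4, 3)
target = torch.randn(4, 3)

mse, l1 = nn.MSELoss(), nn.L1Loss()
combined = weighted_sum_loss([mse, l1], [1.0, 0.5], output, target)
manual = mse(output, target) + 0.5 * l1(output, target)
print(torch.allclose(combined, manual))  # prints True
```

The result is a scalar tensor, so it can be passed to `.backward()` like any single loss.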
@alex-vasilchenko-md (author) commented:

Example usage:

import torch

# VGGPerceptualLoss is not defined in this gist; it has to come from your own code.
loss_fn = WeightedSumCompositionLoss(
    loss_funcs=[
        VGGPerceptualLoss(resize=True),
        torch.nn.BCEWithLogitsLoss(reduction='mean'),
    ],
    weights=[1, 1],
)
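With weights=[1, 1] both terms contribute equally. In general each weight scales its term's share of the gradient, because the gradient of a weighted sum is the weighted sum of the individual gradients. The sketch below checks this using only built-in losses (MSELoss stands in for VGGPerceptualLoss, which this gist does not provide); the tiny nn.Linear model, the random data and the weights 1.0 and 0.25 are assumptions chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
x = torch.randn(8, 3)
target = torch.rand(8, 1)  # BCEWithLogitsLoss expects targets in [0, 1]

losses = [nn.MSELoss(), nn.BCEWithLogitsLoss(reduction="mean")]
weights = [1.0, 0.25]


def grads_of(loss):
    # Gradient of a scalar loss with respect to the model's weight matrix.
    (g,) = torch.autograd.grad(loss, model.weight)
    return g


# Gradient of the combined (weighted-sum) loss.
logits = model(x)
combined = sum(w * fn(logits, target) for fn, w in zip(losses, weights))
g_combined = grads_of(combined)

# Recompute each term with a fresh forward pass and combine the gradients by hand.
g_manual = sum(w * grads_of(fn(model(x), target)) for fn, w in zip(losses, weights))

print(torch.allclose(g_combined, g_manual, atol=1e-6))
```

In practice this means the weights are the knobs for balancing terms whose raw magnitudes differ, such as a perceptual loss against a per-pixel BCE.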
