Skip to content

Instantly share code, notes, and snippets.

@ranftlr
Last active May 30, 2024 08:41
Show Gist options
  • Save ranftlr/1d6194db2e1dffa0a50c9b0a9549cbd2 to your computer and use it in GitHub Desktop.
Save ranftlr/1d6194db2e1dffa0a50c9b0a9549cbd2 to your computer and use it in GitHub Desktop.
def trimmed_mae_loss(prediction, target, mask, trim=0.2):
    """Trimmed mean-absolute-error loss, reduced over the whole batch.

    Absolute residuals ``|prediction - target|`` are gathered at pixels where
    ``mask`` is nonzero and sorted ascending. The largest ``trim`` fraction is
    discarded. The remaining residuals are summed and divided by twice the
    total mask sum.

    Args:
        prediction: Predicted tensor. Presumably (batch, height, width), since
            the mask is summed over dims (1, 2) -- confirm against callers.
        target: Ground-truth tensor, same shape as ``prediction``.
        mask: Validity mask, same shape. Nonzero entries mark pixels to use.
        trim: Fraction of the largest residuals to drop, in [0, 1].

    Returns:
        A scalar tensor. Zero (rather than the NaN of 0/0) when the mask sums
        to zero.
    """
    M = torch.sum(mask, (1, 2))
    res = (prediction - target)[mask.bool()].abs()
    keep = int(res.numel() * (1.0 - trim))
    # torch.sort returns a (values, indices) pair. Take the values first and
    # slice afterwards: slicing the pair itself silently keeps every residual.
    trimmed = torch.sort(res, descending=False).values[:keep]
    total = trimmed.sum()
    divisor = 2 * M.sum()
    if divisor == 0:
        # No valid pixels: return an exact zero instead of NaN. Multiplying
        # keeps the autograd graph, dtype and device of ``total``.
        return total * 0.0
    return total / divisor
def reduction_batch_based(image_loss, M):
    """Average the loss over every valid pixel in the batch.

    Args:
        image_loss: Loss values; all entries are summed.
        M: Valid-pixel counts (e.g. per-image mask sums); all entries are summed.

    Returns:
        ``sum(image_loss) / sum(M)``, or the integer ``0`` when ``sum(M)`` is
        zero. When ``sum(M)`` is zero, ``sum(image_loss)`` is expected to be
        zero as well, so nothing is lost by skipping the division.
    """
    n_valid = torch.sum(M)
    if n_valid != 0:
        return torch.sum(image_loss) / n_valid
    return 0
def normalize_prediction_robust(target, mask):
    """Shift and scale each image so its valid pixels have median 0 and unit MAD.

    For every image along dim 0:

    * the shift is the median of ``target`` over pixels where ``mask`` is
      nonzero;
    * the scale is the mask-weighted mean of ``|target - shift|``, clamped to
      at least 1e-6.

    Images whose mask sums to zero are returned unchanged (shift 0, scale 1).
    For an even number of valid pixels, ``torch.median`` returns the lower of
    the two middle values.

    Args:
        target: Tensor to normalize. The sums run over dims (1, 2), so this is
            presumably (batch, height, width) -- confirm against callers.
        mask: Same shape as ``target``. Nonzero marks a valid pixel. Bool,
            integer and float masks are all accepted.

    Returns:
        A new tensor with ``target``'s shape and dtype.
    """
    # Work in target's dtype. With a bool/int mask, torch.sum yields int64,
    # and shift/scale buffers built from it would truncate to integers.
    weights = mask.to(target.dtype)
    ssum = torch.sum(weights, (1, 2))
    valid = ssum > 0

    # Median over the valid pixels only. Multiplying by the mask would inject
    # zeros for invalid pixels and drag the median toward 0.
    shift = target.new_zeros(target.shape[0])
    for i, (img, px) in enumerate(zip(target, mask.bool())):
        if px.any():
            shift[i] = img[px].median()
    centered = target - shift.view(-1, 1, 1)

    sq = torch.sum(weights * centered.abs(), (1, 2))
    scale = torch.ones_like(ssum)
    scale[valid] = torch.clamp(sq[valid] / ssum[valid], min=1e-6)
    return centered / scale.view(-1, 1, 1)
class TrimmedProcrustesLoss(nn.Module):
    """Trimmed data term plus an optional multi-scale gradient regularizer.

    Both prediction and target are passed through
    ``normalize_prediction_robust`` with the same mask before either loss
    term is evaluated.

    Args:
        alpha: Weight of the gradient regularizer. The regularizer is only
            evaluated when ``alpha > 0``.
        scales: Number of scales handed to ``GradientLoss``.
        reduction: Reduction mode forwarded to both loss terms.
    """

    def __init__(self, alpha=0.5, scales=4, reduction="batch-based"):
        super().__init__()
        # NOTE(review): TrimmedMAELoss is not defined in this file; presumably
        # an nn.Module wrapper around trimmed_mae_loss -- confirm it exists.
        self.__data_loss = TrimmedMAELoss(reduction=reduction)
        self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
        self.__alpha = alpha
        self.__prediction_ssi = None

    def forward(self, prediction, target, mask):
        """Return the data loss plus ``alpha`` times the gradient loss.

        The normalized prediction is cached and exposed via ``prediction_ssi``.
        """
        pred_n = normalize_prediction_robust(prediction, mask)
        self.__prediction_ssi = pred_n
        target_n = normalize_prediction_robust(target, mask)

        loss = self.__data_loss(pred_n, target_n, mask)
        if self.__alpha > 0:
            loss += self.__alpha * self.__regularization_loss(pred_n, target_n, mask)
        return loss

    @property
    def prediction_ssi(self):
        """Normalized prediction from the latest ``forward`` call (``None`` before any)."""
        return self.__prediction_ssi
class GradientLoss(nn.Module):
    """Sum of ``gradient_loss`` over several subsampling levels.

    Level ``k`` keeps every ``2**k``-th element along dims 1 and 2 of the
    prediction, target and mask.

    Args:
        scales: Number of subsampling levels.
        reduction: ``'batch-based'`` selects ``reduction_batch_based``.
            Any other value selects ``reduction_image_based``.
    """

    def __init__(self, scales=4, reduction='batch-based'):
        super().__init__()
        # NOTE(review): unrecognized reduction strings fall back to the
        # image-based reduction without any warning.
        self.__reduction = (
            reduction_batch_based if reduction == 'batch-based' else reduction_image_based
        )
        self.__scales = scales

    def forward(self, prediction, target, mask):
        """Accumulate ``gradient_loss`` across all subsampling levels."""
        total = 0
        for level in range(self.__scales):
            stride = 2 ** level
            # Equivalent to x[:, ::stride, ::stride]
            sub = (slice(None), slice(None, None, stride), slice(None, None, stride))
            total += gradient_loss(
                prediction[sub],
                target[sub],
                mask[sub],
                reduction=self.__reduction,
            )
        return total
@baldgirllulu
Copy link

I would like to know what mask means and how to calculate it. This will help me a lot and I look forward to your answer. Thank you!

@yassineAlouini
Copy link

I would like to know what mask means and how to calculate it. This will help me a lot and I look forward to your answer. Thank you!

As far as I understand, a mask is anything used to mark which depth (or disparity) values are valid. For example, the mask could mark the locations where depth values are nonzero, i.e. mask = (depth != 0).

I hope this helps!

@ayush29feb
Copy link

I don't see TrimmedMAELoss, which is used on L45, declared anywhere.

Is it just trimmed_mae_loss?

@wch1996
Copy link

wch1996 commented Feb 26, 2022

Hello, I want to know the meaning of "M = torch.sum(mask, (1, 2))". Is "mask" a tensor with shape [n, c, h, w] where c equals 1, and does this code change the tensor shape to [n, h, w]?

@jiaqixuac
Copy link

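Hi @wch1996, I think mask and target should have shape [n, h, w] rather than [n, 1, h, w]. With that shape, torch.sum(mask, (1, 2)) sums over the two spatial dimensions and gives the number of valid pixels in each image.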

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment