Skip to content

Instantly share code, notes, and snippets.

@dvdhfnr
dvdhfnr / midas_loss.py
Last active April 14, 2024 15:15
Loss function of MiDaS
import torch
import torch.nn as nn
def compute_scale_and_shift(prediction, target, mask):
# system matrix of the 2x2 least-squares normal equations: A = [[a_00, a_01], [a_10, a_11]] (symmetric, so a_10 == a_01)
a_00 = torch.sum(mask * prediction * prediction, (1, 2))
a_01 = torch.sum(mask * prediction, (1, 2))
a_11 = torch.sum(mask, (1, 2))