@Ed-Optalysys
Last active June 29, 2021 12:06
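import torch
import torch.nn as nn
import torch.nn.functional as F

def mge_loss(prediction, target, mge_loss_scale=1.0):
    """Mean gradient error between two (N, 3, H, W) batches; the scale default of 1.0 is assumed."""
    opts = dict(device=prediction.device, dtype=prediction.dtype)
    # Sobel kernel for intensity changes along the height (row) axis.
    x_filter = torch.tensor([[-1, -2, -1],
                             [0, 0, 0],
                             [1, 2, 1.0]], **opts)
    # Sobel kernel for intensity changes along the width (column) axis.
    y_filter = torch.tensor([[-1, 0, 1],
                             [-2, 0, 2],
                             [-1, 0, 1.0]], **opts)
    # Shape (1, 3, 3, 3): one output channel that sums the response over RGB, normalised by 8.
    x_filter = x_filter.unsqueeze(0).repeat(3, 1, 1).unsqueeze(0) / 8
    y_filter = y_filter.unsqueeze(0).repeat(3, 1, 1).unsqueeze(0) / 8
    # Replicate edge pixels so the gradient maps keep the input's H x W.
    replication_pad = nn.ReplicationPad2d(1)
    padded_prediction, padded_target = replication_pad(prediction), replication_pad(target)
    gx_prediction = F.conv2d(padded_prediction, x_filter)
    gy_prediction = F.conv2d(padded_prediction, y_filter)
    gx_target = F.conv2d(padded_target, x_filter)
    gy_target = F.conv2d(padded_target, y_filter)
    mge_x = F.mse_loss(gx_prediction, gx_target) / 2.0
    mge_y = F.mse_loss(gy_prediction, gy_target) / 2.0
    return (mge_x + mge_y) * mge_loss_scale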
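Standard Sobel kernels have entries that sum to zero, so the gradient maps ignore a uniform brightness offset. This is one way to see what a gradient-based loss measures that plain MSE does not. The self-contained sketch below uses a hypothetical helper, `sobel_gradients`, that stacks both kernels into a single (2, 3, 3, 3) weight. Averaging the squared difference over both output channels is equal to `mge_x + mge_y` when the scale factor is 1. The random images and the 0.1 offset are illustrative assumptions, not data from the original gist.

```python
import torch
import torch.nn.functional as F

def sobel_gradients(images):
    """Hypothetical helper: Sobel responses of an (N, 3, H, W) batch, summed over channels."""
    kx = torch.tensor([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]])
    ky = kx.t()  # [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
    # Stack both kernels into a (2, 3, 3, 3) weight, normalised by 8.
    weight = torch.stack([kx, ky]).unsqueeze(1).repeat(1, 3, 1, 1) / 8
    # Replicate-pad by one pixel so the output keeps the input's H x W.
    padded = F.pad(images, (1, 1, 1, 1), mode="replicate")
    # .to(images) matches the input's dtype and device.
    return F.conv2d(padded, weight.to(images))  # shape (N, 2, H, W)

torch.manual_seed(0)
target = torch.rand(2, 3, 16, 16, dtype=torch.float64)
shifted = target + 0.1  # uniform brightness offset, no change in structure

mse = F.mse_loss(shifted, target)
# Mean over both gradient channels == (mse_x + mse_y) / 2 for equal-sized maps.
mge = F.mse_loss(sobel_gradients(shifted), sobel_gradients(target))
print(f"MSE: {mse.item():.4f}, MGE: {mge.item():.2e}")
```

The MSE comes out at about 0.01 (the squared offset). The MGE is zero up to floating-point rounding, because every kernel's weights cancel a constant shift. If the weights did not sum to zero, the offset would leak into the gradient term.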