GitHub gist weberhen/5784ca521fd644dff8c37ba24460be39 (last active February 15, 2022 23:35)
import torch

def get_rgb_angular_error_torch(gt_render, pred_render):
    # Mean angular error (in radians) between two H x W x 3 RGB images.
    # Both images are first converted to normalized RGB:
    # r = R / (R+G+B), g = G / (R+G+B), b = B / (R+G+B).
    # The angular distance between pixels p1 and p2 is
    # cos^-1(p1 · p2 / (||p1|| * ||p2||)).
    gt = gt_render / torch.sum(gt_render, dim=2, keepdim=True)
    pred = pred_render / torch.sum(pred_render, dim=2, keepdim=True)
    num = torch.sum(gt * pred, dim=2, keepdim=True)
    den = torch.linalg.norm(gt, dim=2, keepdim=True) * torch.linalg.norm(pred, dim=2, keepdim=True)
    # Rounding can push the cosine slightly above 1, where arccos returns NaN;
    # clamp so that (nearly) identical pixels give 0 instead of being dropped.
    angular_distance = torch.arccos(torch.clamp(num / den, -1.0, 1.0))
    # Pixels with R+G+B = 0 (e.g. black) still give NaN (0/0); exclude them.
    angular_distance = angular_distance[~torch.isnan(angular_distance)]
    return angular_distance.mean()
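To cross-check the metric without PyTorch, here is a minimal pure-Python sketch of the same per-pixel computation. It is not the gist's code. It assumes images are flat lists of (R, G, B) tuples, and the helper names `pixel_angle` and `angular_error_reference` are hypothetical. It also shows that dividing by R+G+B does not change the angle, because the cosine is invariant to positive scaling.

```python
import math
from typing import Optional, Sequence, Tuple

Pixel = Tuple[float, float, float]


def pixel_angle(p1: Pixel, p2: Pixel) -> Optional[float]:
    """Angle in radians between two RGB pixels, or None if undefined."""
    s1, s2 = sum(p1), sum(p2)
    if s1 == 0 or s2 == 0:
        # Normalized RGB is 0/0 here, matching the NaN pixels the torch version drops.
        return None
    # Normalized RGB: each channel divided by R+G+B.
    n1 = [c / s1 for c in p1]
    n2 = [c / s2 for c in p2]
    dot = sum(a * b for a, b in zip(n1, n2))
    norm = math.sqrt(sum(a * a for a in n1)) * math.sqrt(sum(b * b for b in n2))
    # Clamp to [-1, 1]: math.acos raises ValueError (not NaN) outside that range.
    return math.acos(max(-1.0, min(1.0, dot / norm)))


def angular_error_reference(gt: Sequence[Pixel], pred: Sequence[Pixel]) -> float:
    """Mean angle over pixel pairs, skipping pairs whose angle is undefined."""
    angles = [a for a in (pixel_angle(g, p) for g, p in zip(gt, pred)) if a is not None]
    return sum(angles) / len(angles)


# Hypothetical example: a scaled copy (angle ~0), an orthogonal pair (pi/2),
# and a black ground-truth pixel that is skipped.
gt = [(0.2, 0.4, 0.6), (1.0, 0.0, 0.0), (0.0, 0.0, 0.0)]
pred = [(0.4, 0.8, 1.2), (0.0, 1.0, 0.0), (0.5, 0.5, 0.5)]
print(angular_error_reference(gt, pred))  # approximately math.pi / 4
```

The clamp matters even more here than in torch. Without it, a rounding error on an identical pair would raise an exception instead of silently producing a NaN.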