@weberhen
Last active February 15, 2022 23:35
import torch


def get_rgb_angular_error_torch(gt_render, pred_render):
    # Channels are assumed to be on dim=2, e.g. (H, W, 3) RGB images.
    # The error needs to be computed on the normalized RGB image:
    # r = R / (R+G+B), g = G / (R+G+B), b = B / (R+G+B)
    # The angular distance is the angle between the RGB vectors p1 and p2 of
    # corresponding pixels: cos^-1(p1·p2 / (||p1|| * ||p2||)), in radians.
    gt_norm = gt_render / torch.sum(gt_render, dim=2, keepdim=True)
    pred_norm = pred_render / torch.sum(pred_render, dim=2, keepdim=True)
    num = torch.sum(gt_norm * pred_norm, dim=2, keepdim=True)
    den = torch.sqrt(torch.sum(gt_norm ** 2, dim=2, keepdim=True)) * torch.sqrt(torch.sum(pred_norm ** 2, dim=2, keepdim=True))
    cos = num / den
    # Black pixels (R+G+B = 0) produce 0/0 = NaN; exclude them from the mean.
    cos = cos[~torch.isnan(cos)]
    # Clamp to [-1, 1]: rounding can push the cosine slightly above 1 for
    # (near-)identical colors, and arccos would then return NaN.
    angular_distance = torch.arccos(torch.clamp(cos, -1.0, 1.0))
    return angular_distance.mean()
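
The comments in the function describe the error on normalized RGB. Because the angle between two vectors does not change when either vector is multiplied by a positive constant, arccos(p1·p2 / (||p1|| * ||p2||)) gives the same value on the raw RGB values for any pixel whose channel sum is positive. The sketch below relies on that, assuming non-negative (H, W, 3) inputs; it swaps the explicit sums for `torch.nn.functional.cosine_similarity` and reports degrees instead of radians. The name `rgb_angular_error_deg` is hypothetical and not part of this gist.

```python
import math

import torch
import torch.nn.functional as F


def rgb_angular_error_deg(gt_render, pred_render):
    """Mean per-pixel angle, in degrees, between two (H, W, 3) RGB images.

    Hypothetical helper for illustration (not part of the gist). It assumes
    non-negative RGB values and skips normalization by R+G+B, which does not
    change the angle for pixels with a positive channel sum.
    """
    # Mirror the gist's NaN filtering: keep only pixels where both images
    # have a positive channel sum (black pixels are excluded).
    valid = (gt_render.sum(dim=2) > 0) & (pred_render.sum(dim=2) > 0)
    # Cosine of the angle between corresponding RGB vectors, shape (H, W).
    cos = F.cosine_similarity(gt_render, pred_render, dim=2)
    # Clamp guards arccos against rounding slightly outside [-1, 1].
    angle = torch.arccos(cos.clamp(-1.0, 1.0))
    return (angle[valid] * (180.0 / math.pi)).mean()


# Pixel 1: pure red vs pure green (90 degrees apart).
# Pixel 2: the same blue at different intensities (0 degrees apart).
# Pixel 3: black in gt, so it is excluded from the mean.
gt = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]])
pred = torch.tensor([[[0.0, 1.0, 0.0], [0.0, 0.0, 3.0], [1.0, 1.0, 1.0]]])
print(rgb_angular_error_deg(gt, pred))  # mean of 90 and 0 degrees, about 45
```

Pixel 2 shows the scale invariance: tripling the intensity leaves the chromaticity, and therefore the angle, unchanged.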