Last active
July 10, 2023 11:03
-
-
Save simon-donike/3bd7ae8e55ab7c0085e674dbf6a56853 to your computer and use it in GitHub Desktop.
Mean Gradient Error - PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MeanGradientError:
    """
    Calculate the Mean Gradient Error (MGE) between two batches of images.

    Image gradients are estimated with 3x3 Sobel filters, combined into a
    per-pixel gradient magnitude, and the squared difference of the two
    magnitude maps is averaged.

    Methods
    -------
    sobel_x_y_gradients(img):
        Calculates the gradients in x and y directions using Sobel filters
        for the input image tensor.
    calculate_pixel_gradients(grad_x, grad_y):
        Combines the x and y gradients to calculate the pixel gradient
        magnitude.
    calculate(y_true, y_pred):
        Calculates the Mean Gradient Error (MGE) between two images
        (y_true and y_pred).
    """

    def __init__(self):
        # Imported here only to fail fast when PyTorch is not installed; the
        # methods below import what they need themselves, because a
        # function-scope import is not visible from other methods.
        import torch  # noqa: F401

    @staticmethod
    def sobel_x_y_gradients(img):
        """
        Calculate the x and y Sobel gradients of a batch of images.

        Every channel of every image is filtered independently with zero
        padding of one pixel, so the outputs keep the spatial size of the
        input.

        Parameters
        ----------
        img : torch.Tensor
            A 4D tensor of shape (batch_size, num_channels, height, width).

        Returns
        -------
        gradients_x : torch.Tensor
            A tensor of the same shape as `img` holding the x gradients.
        gradients_y : torch.Tensor
            A tensor of the same shape as `img` holding the y gradients.

        Raises
        ------
        TypeError
            If `img` is not a PyTorch tensor.
        ValueError
            If `img` is not 4-dimensional.
        """
        import torch
        import torch.nn.functional as F

        # Ensure img is a PyTorch tensor
        if not isinstance(img, torch.Tensor):
            raise TypeError('img must be a PyTorch tensor')
        # Ensure img has the correct dimensions
        if img.dim() != 4:
            raise ValueError('img must be a 4D tensor')

        batch_size, num_channels, height, width = img.shape

        # Define the Sobel filters on the same device/dtype as the input so
        # that GPU tensors do not trigger a device-mismatch error.
        filter_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            dtype=img.dtype,
            device=img.device,
        ).view(1, 1, 3, 3)
        filter_y = torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
            dtype=img.dtype,
            device=img.device,
        ).view(1, 1, 3, 3)

        # Fold the channel axis into the batch axis so that a single conv2d
        # call filters every (image, channel) plane independently, instead of
        # looping over batch and channel in Python.
        planes = img.reshape(batch_size * num_channels, 1, height, width)
        gradients_x = F.conv2d(planes, filter_x, padding=1).reshape(img.shape)
        gradients_y = F.conv2d(planes, filter_y, padding=1).reshape(img.shape)
        return gradients_x, gradients_y

    @staticmethod
    def calculate_pixel_gradients(grad_x, grad_y):
        """
        Combine x and y gradients into the pixel gradient magnitude.

        Parameters
        ----------
        grad_x : torch.Tensor
            A tensor representing the x gradients of an image.
        grad_y : torch.Tensor
            A tensor representing the y gradients of an image.

        Returns
        -------
        result : torch.Tensor
            Element-wise ``sqrt(grad_x ** 2 + grad_y ** 2)``.
        """
        import torch

        # NOTE: sqrt has an unbounded derivative at 0, so back-propagating
        # through perfectly flat regions can produce NaN gradients; callers
        # training with this value may want to clamp or offset it first.
        return torch.sqrt(torch.pow(grad_x, 2) + torch.pow(grad_y, 2))

    def calculate(self, y_true, y_pred):
        """
        Calculate the Mean Gradient Error (MGE) between two images.

        Parameters
        ----------
        y_true : torch.Tensor
            A 4D tensor representing the ground truth image(s).
        y_pred : torch.Tensor
            A 4D tensor representing the predicted image(s); must have the
            same shape as `y_true`.

        Returns
        -------
        mge : torch.Tensor
            A scalar tensor holding the MGE between the ground truth and
            predicted images.

        Raises
        ------
        TypeError
            If either input is not a PyTorch tensor.
        ValueError
            If either input is not 4-dimensional or the shapes differ.
        """
        import torch

        # 1. get x and y gradients of both images (also validates the inputs)
        true_grad_x, true_grad_y = self.sobel_x_y_gradients(y_true)
        pred_grad_x, pred_grad_y = self.sobel_x_y_gradients(y_pred)

        # Reject mismatched shapes explicitly rather than letting the
        # subtraction below silently broadcast.
        if y_true.shape != y_pred.shape:
            raise ValueError('y_true and y_pred must have the same shape')

        # 2. get pixel gradient magnitudes
        true_magnitude = self.calculate_pixel_gradients(true_grad_x, true_grad_y)
        pred_magnitude = self.calculate_pixel_gradients(pred_grad_x, pred_grad_y)

        # 3. mean squared difference of the gradient magnitudes
        height, width = pred_magnitude.shape[2:4]
        # NOTE(review): torch.mean already averages over every element, so the
        # additional division by height * width normalises twice. It is kept
        # to preserve the published numeric behaviour -- confirm the intent.
        mge = torch.mean((pred_magnitude - true_magnitude) ** 2) / (height * width)
        return mge
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment