Skip to content

Instantly share code, notes, and snippets.

@simon-donike
Last active July 10, 2023 11:03
Show Gist options
  • Save simon-donike/3bd7ae8e55ab7c0085e674dbf6a56853 to your computer and use it in GitHub Desktop.
Save simon-donike/3bd7ae8e55ab7c0085e674dbf6a56853 to your computer and use it in GitHub Desktop.
Mean Gradient Error - PyTorch
class MeanGradientError:
    """
    Calculate the Mean Gradient Error (MGE) between two batches of images.

    The class holds no state. The two helper methods are static, so they can
    be called either on the class or on an instance.

    Methods
    -------
    sobel_x_y_gradients(img):
        Calculates the gradients in x and y directions using Sobel filters
        for the input image tensor.
    calculate_pixel_gradients(grad_x, grad_y):
        Combines the x and y gradients to calculate the pixel gradient
        magnitude.
    calculate(y_true, y_pred):
        Calculates the Mean Gradient Error (MGE) between two images
        (y_true and y_pred).
    """

    def __init__(self):
        # Nothing is stored on the instance. The import only makes a missing
        # PyTorch installation fail at construction time; each method imports
        # what it needs itself, because a function-local import is not
        # visible from other methods.
        import torch  # noqa: F401

    @staticmethod
    def sobel_x_y_gradients(img):
        """
        Calculate the x and y Sobel gradients of every channel of every image.

        Parameters
        ----------
        img : torch.Tensor
            A 4D tensor representing a batch of images. Shape should be
            (batch_size, num_channels, height, width).

        Returns
        -------
        gradients_x : torch.Tensor
            A tensor of the same shape as `img` representing the x gradients
            of the images.
        gradients_y : torch.Tensor
            A tensor of the same shape as `img` representing the y gradients
            of the images.

        Raises
        ------
        TypeError
            If `img` is not a PyTorch tensor.
        ValueError
            If `img` is not 4-dimensional.
        """
        import torch
        import torch.nn.functional as F

        # Ensure img is a PyTorch tensor
        if not isinstance(img, torch.Tensor):
            raise TypeError('img must be a PyTorch tensor')
        # Ensure img has the correct dimensions
        if img.dim() != 4:
            raise ValueError('img must be a 4D tensor')

        # Both Sobel kernels stacked as two output channels, shape
        # (2, 1, 3, 3). They are created with the dtype and on the device of
        # the input so conv2d also accepts non-CPU tensors.
        kernels = torch.tensor(
            [
                [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],      # x direction
                [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],      # y direction
            ],
            dtype=img.dtype,
            device=img.device,
        ).unsqueeze(1)

        # Fold batch and channel into one axis so each (H, W) plane is
        # filtered independently by a single conv2d call instead of a Python
        # loop over batch and channel. padding=1 keeps the spatial size.
        batch, channels, height, width = img.shape
        planes = img.reshape(batch * channels, 1, height, width)
        response = F.conv2d(planes, kernels, padding=1)  # (B*C, 2, H, W)

        gradients_x = response[:, 0].reshape(img.shape)
        gradients_y = response[:, 1].reshape(img.shape)
        return gradients_x, gradients_y

    @staticmethod
    def calculate_pixel_gradients(grad_x, grad_y):
        """
        Combine the x and y gradients into the pixel gradient magnitude.

        Parameters
        ----------
        grad_x : torch.Tensor
            A tensor representing the x gradients of an image.
        grad_y : torch.Tensor
            A tensor representing the y gradients of an image.

        Returns
        -------
        result : torch.Tensor
            A tensor representing the magnitude of the pixel gradients,
            computed element-wise as sqrt(grad_x**2 + grad_y**2).
        """
        import torch

        # NOTE(review): the derivative of sqrt is unbounded at 0; if this is
        # back-propagated through perfectly flat regions, a small epsilon
        # under the root may be needed -- confirm against the training setup.
        return torch.sqrt(torch.pow(grad_x, 2) + torch.pow(grad_y, 2))

    def calculate(self, y_true, y_pred):
        """
        Calculate the Mean Gradient Error (MGE) between two images.

        Parameters
        ----------
        y_true : torch.Tensor
            A 4D tensor representing the ground truth image(s).
        y_pred : torch.Tensor
            A 4D tensor representing the predicted image(s).

        Returns
        -------
        mge : torch.Tensor
            A scalar tensor: the mean squared difference between the Sobel
            gradient magnitudes of `y_pred` and `y_true`, divided by
            height * width.

        Raises
        ------
        TypeError
            If either input is not a PyTorch tensor.
        ValueError
            If either input is not 4-dimensional.
        """
        import torch

        # 1. x and y gradients of both images
        true_gx, true_gy = self.sobel_x_y_gradients(y_true)
        pred_gx, pred_gy = self.sobel_x_y_gradients(y_pred)

        # 2. per-pixel gradient magnitudes
        true_magnitude = self.calculate_pixel_gradients(true_gx, true_gy)
        pred_magnitude = self.calculate_pixel_gradients(pred_gx, pred_gy)

        # 3. mean squared error of the gradient magnitudes
        # NOTE(review): torch.mean already averages over every element, so
        # the extra division by height * width scales the result by the image
        # size a second time. The formula is kept as originally written so
        # returned values do not change; confirm whether this is intended.
        height, width = pred_magnitude.shape[2:4]
        mge = torch.mean((pred_magnitude - true_magnitude) ** 2) / (height * width)
        return mge
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment