# Source: GitHub gist e96031413/177d0a3cb425189b734a46c3ad6404ce (created March 21, 2023).
# PyTorch implementation of IQA metrics including PSNR, SSIM, LPIPS, NIQE, and LOE.
# NOTE: GitHub flagged this file as containing bidirectional Unicode text; review it
# in an editor that reveals hidden Unicode characters.
""" | |
Implementation of IQA metrics in PyTorch, including PSNR, SSIM, LPIPS, NIQE, and LOE. | |
""" | |
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
import torchvision.transforms.functional as F | |
from torch.nn.functional import conv2d | |
from IQA_pytorch import SSIM, LPIPSvgg | |
import torchvision.models as models | |
from torchvision.transforms.functional import rgb_to_grayscale | |
import numpy as np | |
import cv2 | |
import warnings | |
# Mute every warning category for the whole process (not just this module).
# NOTE(review): this also hides deprecation notices from libraries used below,
# e.g. torchvision's `pretrained=` argument; consider a narrower filter.
warnings.simplefilter('ignore')
# PSNR
class PSNR(nn.Module):
    """Peak signal-to-noise ratio (PSNR) between two image tensors, in dB.

    Computes ``20*log10(peak) - 10*log10(MSE)``. The MSE is averaged over
    every element of the inputs (all images in the batch and all channels),
    so a batch yields one score rather than one per image.

    Args:
        max_val: peak signal value, e.g. 1.0 for images scaled to [0, 1] or
            255 for 8-bit images. Non-positive values, including the default
            0, mean a peak of 1.0, which is what ``PSNR()`` has always computed.

    Buffers (kept for ``state_dict`` compatibility):
        base10: ``ln(10)``.
        max_val: the peak expressed in dB, ``20 * log10(peak)``.
    """

    def __init__(self, max_val=0):
        super().__init__()
        # The old default of 0 stored log(0) = -inf here and the value was never
        # used; map non-positive peaks to 1.0 so default behaviour is unchanged.
        peak = float(max_val) if max_val > 0 else 1.0
        base10 = torch.log(torch.tensor(10.0))
        self.register_buffer('base10', base10)
        self.register_buffer('max_val', 20 * torch.log(torch.tensor(peak)) / base10)

    def forward(self, a, b):
        """Return the PSNR of ``a`` against ``b`` as a Python float.

        Defined as ``forward`` (not ``__call__``) so nn.Module hooks still run.
        Identical inputs (zero MSE) return ``float('inf')``.
        """
        mse = torch.mean((a.float() - b.float()) ** 2)
        if mse == 0:
            # PSNR is unbounded for a perfect match; 0 dB would rank a perfect
            # reconstruction below every imperfect one.
            return float('inf')
        # Convert both terms to Python floats so the CPU buffer never has to
        # meet tensors living on another device.
        return self.max_val.item() - 10 * torch.log10(mse).item()
# LOE
def lightness_order_error(pred, target):
    """Count positions where two brightness proxies disagree on the change direction.

    For each tensor, two per-pixel brightness proxies are taken over dim 1
    (presumably the channel axis of NCHW images; confirm with callers):
      * lightness = (max over dim 1 + min over dim 1) / 2
      * gray      = mean over dim 1
    A position counts when sign(pred - target) for lightness differs from
    sign(pred - target) for gray. A zero sign on one side against a non-zero
    sign on the other also counts.

    NOTE(review): this is not the Lightness Order Error of Wang et al. (2013),
    which compares the relative lightness order of pixel *pairs* within each
    image and reports a mean. This value is an unnormalized count summed over
    the whole batch, so it grows with image size.

    Returns:
        int: number of disagreeing positions.
    """
    def _brightness(img):
        top = img.max(dim=1).values
        bottom = img.min(dim=1).values
        return (top + bottom) / 2.0, img.mean(dim=1)

    light_pred, gray_pred = _brightness(pred)
    light_tgt, gray_tgt = _brightness(target)
    disagree = (light_pred - light_tgt).sign() != (gray_pred - gray_tgt).sign()
    return disagree.sum().item()
# NIQE
def niqe(img_tensor):
    """Weighted gradient-statistics score (named ``niqe``, but not the NIQE model).

    NOTE(review): this does not implement NIQE (Mittal et al., 2013), which fits
    a multivariate Gaussian to natural-scene statistics and measures its
    distance from a pristine-image model. Here the score is a fixed weighted sum
    of four gradient-magnitude statistics. The origin of the weights is not
    documented, so confirm before reporting results as NIQE.

    Steps:
      1. Convert to grayscale with torchvision's ``rgb_to_grayscale``.
      2. Sobel gradient magnitude ``mag``.
      3. Global mean and std of ``mag``.
      4. Local std of ``mag`` over ``block_size x block_size`` windows
         (zero-padded box filter), with ``block_size ~= sqrt(H*W) / 10``.
      5. Mean and std of those local stds.
    Every statistic pools over the whole batch, so a batch yields one score.

    Args:
        img_tensor: 4-D image batch (N, C, H, W). torchvision's grayscale
            conversion accepts C in {1, 3}. Non-float inputs are cast to
            float32.

    Returns:
        float: the weighted score.
    """
    gray = F.rgb_to_grayscale(img_tensor)
    if not gray.is_floating_point():
        # conv2d needs a floating dtype; integer (e.g. uint8) images would fail.
        gray = gray.float()
    device, dtype = gray.device, gray.dtype

    # Sobel kernels built in the input's dtype so float64/float16 inputs work.
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=dtype, device=device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=dtype, device=device).view(1, 1, 3, 3)
    dx = torch.abs(conv2d(gray, sobel_x, padding=1))
    dy = torch.abs(conv2d(gray, sobel_y, padding=1))
    mag = torch.sqrt(dx ** 2 + dy ** 2)

    _, _, rows, cols = img_tensor.shape
    # Window side is about 10% of the geometric-mean image side. The floor of 1
    # stops tiny images (below ~5x5 px) from producing an empty kernel.
    block_size = max(1, int(np.round(np.sqrt(rows * cols) / 10)))
    kernel = torch.ones((1, 1, block_size, block_size), dtype=dtype, device=device) / (block_size ** 2)
    pad = block_size // 2
    local_mean = conv2d(mag, kernel, padding=pad)
    local_mean_sq = conv2d(mag ** 2, kernel, padding=pad)
    # E[x^2] - E[x]^2 can dip slightly below zero through float cancellation
    # (e.g. flat regions); clamp so sqrt does not turn those windows into NaN.
    local_std = torch.sqrt(torch.clamp(local_mean_sq - local_mean ** 2, min=0))

    features = torch.stack([mag.mean(), mag.std(), local_std.mean(), local_std.std()])
    weights = torch.tensor([0.0278, 0.1581, 0.7834, 0.0306], dtype=features.dtype, device=device)
    return torch.sum(features * weights).item()
class LPIPS(torch.nn.Module):
    """Perceptual distance between two images from pretrained CNN features.

    Both inputs are resized to 224x224 and normalized with the ImageNet
    per-channel mean/std. They are then passed through a truncated pretrained
    backbone and compared by the L2 norm of their feature difference. The
    per-image distances are averaged over the batch. The backbone runs under
    ``torch.no_grad()``, so no gradients flow back to the inputs.

    NOTE(review): despite the name this is not the learned LPIPS metric
    (Zhang et al., 2018). There is no channel-wise feature normalization, no
    learned linear weights, and only one feature depth is used. Scores are not
    comparable with published LPIPS numbers.

    Args:
        net: 'alex' uses the first 12 modules of AlexNet's ``features``. Any
            other value (default 'vgg16') uses the first 23 modules of VGG16's
            ``features``.
        use_gpu: run on CUDA when requested and available; otherwise CPU.
    """

    def __init__(self, net='vgg16', use_gpu=True):
        super().__init__()
        if net == 'alex':
            backbone = models.alexnet(pretrained=True).features[:12]
        else:
            backbone = models.vgg16(pretrained=True).features[:23]
        self.net = backbone
        on_cuda = use_gpu and torch.cuda.is_available()
        self.device = torch.device('cuda' if on_cuda else 'cpu')
        self.net.to(self.device)
        self.net.eval()
        # Plain attributes, not buffers: Module.to() will not move these.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

    def _prepare(self, img):
        """Resize to 224x224, apply ImageNet normalization, move to ``self.device``."""
        resized = F.resize(img, (224, 224))
        return F.normalize(resized, mean=self.mean, std=self.std).to(self.device)

    def forward(self, img1, img2):
        """Return the batch-mean feature distance between ``img1`` and ``img2`` as a float.

        Inputs are image batches (N, 3, H, W), presumably RGB in [0, 1], since
        the ImageNet normalization assumes that range (confirm with callers).
        """
        x1 = self._prepare(img1)
        x2 = self._prepare(img2)
        with torch.no_grad():
            feats1 = self.net(x1)
            feats2 = self.net(x2)
        per_image = (feats1 - feats2).pow(2).sum(dim=(1, 2, 3)).sqrt()
        return per_image.mean().item()
if __name__ == '__main__':
    # Smoke test on random images. PSNR here assumes a peak of 1.0 and
    # IQA_pytorch's SSIM documents a 0~1 data range, so draw uniform [0, 1)
    # pixels instead of Gaussian noise (which spans negative values).
    torch.manual_seed(0)
    img1 = torch.rand(1, 3, 64, 64)
    img2 = torch.rand(1, 3, 64, 64)
    ssim = SSIM()
    lpips = LPIPS(net='vgg16', use_gpu=False)
    psnr = PSNR()
    print("PSNR: ", psnr(img1, img2))
    # IQA_pytorch's SSIM returns the loss (1 - SSIM) unless as_loss=False.
    print("SSIM: ", ssim(img1, img2, as_loss=False).item())
    print("LOE: ", lightness_order_error(img1, img2))
    print("NIQE: ", niqe(img1))
    print("LPIPS: ", lpips(img1, img2))
# (GitHub page footer: sign in or sign up to comment on this gist.)