"""
Implementation of IQA metrics in PyTorch, including PSNR, SSIM, LPIPS, NIQE, and LOE.
"""
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms.functional as F
from torch.nn.functional import conv2d
from IQA_pytorch import SSIM, LPIPSvgg  # SSIM metric used below; LPIPSvgg is a learned LPIPS variant
import numpy as np
import warnings
warnings.filterwarnings('ignore')
# PSNR
class PSNR(nn.Module):
    """Peak Signal-to-Noise Ratio: 20*log10(max_val) - 10*log10(MSE)."""
    def __init__(self, max_val=1.0):
        super().__init__()
        # Pre-compute the 20*log10(max_val) term and keep it with the module
        self.register_buffer('max_val_db', 20 * torch.log10(torch.tensor(float(max_val))))

    def forward(self, a, b):
        mse = torch.mean((a.float() - b.float()) ** 2)
        if mse == 0:
            # Identical images: PSNR is unbounded
            return float('inf')
        return (self.max_val_db - 10 * torch.log10(mse)).item()
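# Quick sanity check of the formula (assuming inputs scaled to [0, 1]):
# a flat 0.5 image versus the same image shifted by 0.1 has MSE = 0.01,
# so PSNR = 20*log10(1) - 10*log10(0.01) = 20 dB, e.g.
#   PSNR()(torch.full((1, 3, 8, 8), 0.5), torch.full((1, 3, 8, 8), 0.6))  # ~20.0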
# LOE (simplified): counts pixels where two brightness estimates disagree on the
# ordering between prediction and target. This is a cheap per-pixel proxy rather
# than the pairwise relative-order LOE from the literature (see the sketch below).
def lightness_order_error(pred, target):
    # Per-pixel brightness as the mean over the RGB channels
    pred_gray = torch.mean(pred, dim=1)
    target_gray = torch.mean(target, dim=1)
    # Per-pixel lightness as (max + min) / 2 over the RGB channels
    pred_lightness = (torch.max(pred, dim=1)[0] + torch.min(pred, dim=1)[0]) / 2.0
    target_lightness = (torch.max(target, dim=1)[0] + torch.min(target, dim=1)[0]) / 2.0
    # Count pixels where the lightness-based and mean-based orderings disagree
    error = torch.sum(torch.sign(pred_lightness - target_lightness) != torch.sign(pred_gray - target_gray))
    return error.item()
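# For reference, the commonly cited LOE (Wang et al., "Naturalness Preserved
# Enhancement Algorithm for Non-Uniform Illumination Images") compares the relative
# lightness order of pixel pairs. A minimal sketch, assuming max-channel lightness
# and random pixel subsampling to keep the pairwise comparison tractable
# (the function name and n_samples are illustrative, not from the original gist):
def lightness_order_error_pairwise(pred, target, n_samples=256):
    pred_l = torch.amax(pred, dim=1).flatten(1)      # (B, H*W), lightness = channel max
    target_l = torch.amax(target, dim=1).flatten(1)
    idx = torch.randperm(pred_l.shape[1])[:n_samples]
    p, t = pred_l[:, idx], target_l[:, idx]          # subsample pixel locations
    order_p = p.unsqueeze(2) >= p.unsqueeze(1)       # (B, n, n) pairwise order maps
    order_t = t.unsqueeze(2) >= t.unsqueeze(1)
    # Fraction of pixel pairs whose lightness order differs between the two images
    return (order_p ^ order_t).float().mean().item()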
# NIQE (simplified): a gradient-statistics heuristic with hand-picked weights,
# not the reference NIQE, which fits a multivariate Gaussian model to natural
# scene statistics features.
def niqe(img_tensor):
    # Convert the image to grayscale
    gray_tensor = F.rgb_to_grayscale(img_tensor)
    # Gradient magnitudes via Sobel filters
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=img_tensor.device).unsqueeze(0).unsqueeze(0)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=img_tensor.device).unsqueeze(0).unsqueeze(0)
    dx = torch.abs(conv2d(gray_tensor, sobel_x, padding=1))
    dy = torch.abs(conv2d(gray_tensor, sobel_y, padding=1))
    mag = torch.sqrt(dx**2 + dy**2)
    # Global statistics of the gradient magnitudes
    mean_mag = torch.mean(mag)
    std_mag = torch.std(mag)
    # Block size derived from the image size
    batch_size, channels, rows, cols = img_tensor.shape
    block_size = int(np.round(np.sqrt(rows * cols) / 10))
    # Block-wise local statistics of the gradient magnitudes (box filter)
    kernel = torch.ones((1, 1, block_size, block_size), dtype=torch.float32, device=img_tensor.device) / (block_size**2)
    block_means = conv2d(mag, kernel, padding=block_size // 2)
    block_means_sq = conv2d(mag**2, kernel, padding=block_size // 2)
    # Clamp to avoid NaNs from small negative values caused by floating-point error
    block_stds = torch.sqrt(torch.clamp(block_means_sq - block_means**2, min=0.0))
    # Statistics of the block-wise standard deviations
    mean_block_std = torch.mean(block_stds)
    std_block_std = torch.std(block_stds)
    # Weighted combination of the four features
    features = torch.stack([mean_mag, std_mag, mean_block_std, std_block_std])
    weights = torch.tensor([0.0278, 0.1581, 0.7834, 0.0306], dtype=torch.float32, device=img_tensor.device)
    niqe_score = torch.sum(features * weights)
    return niqe_score.item()
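# If the standard NIQE score is needed, an off-the-shelf implementation such as
# the pyiqa (IQA-PyTorch) package could be used instead of the heuristic above;
# a sketch, assuming pyiqa is installed and registers NIQE under the name 'niqe':
#   import pyiqa
#   niqe_metric = pyiqa.create_metric('niqe')
#   score = niqe_metric(img_tensor)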
# LPIPS (simplified): an un-weighted deep-feature distance on ImageNet-pretrained
# backbones. The original LPIPS additionally unit-normalizes each feature map and
# applies learned per-channel weights.
class LPIPS(torch.nn.Module):
    def __init__(self, net='vgg16', use_gpu=True):
        super(LPIPS, self).__init__()
        # Truncated ImageNet-pretrained feature extractor (AlexNet or VGG16)
        self.net = models.alexnet(pretrained=True).features[:12] if net == 'alex' else models.vgg16(pretrained=True).features[:23]
        if use_gpu and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.net.to(self.device)
        self.net.eval()
        # ImageNet normalization statistics
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(self.device)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(self.device)

    def forward(self, img1, img2):
        # Move inputs to the model device before resizing and normalizing
        img1 = F.normalize(F.resize(img1.to(self.device), (224, 224)), mean=self.mean, std=self.std)
        img2 = F.normalize(F.resize(img2.to(self.device), (224, 224)), mean=self.mean, std=self.std)
        with torch.no_grad():
            feat1 = self.net(img1)
            feat2 = self.net(img2)
        # Per-image Euclidean distance between the feature maps, averaged over the batch
        diff = (feat1 - feat2) ** 2
        dist = torch.sum(diff, dim=(1, 2, 3))
        lpips_score = torch.mean(torch.sqrt(dist))
        return lpips_score.item()
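# The LPIPSvgg class imported from IQA_pytorch at the top of the file is a learned
# LPIPS variant and could be used instead of the simplified class above; a sketch,
# assuming the call signature shown in IQA_pytorch's README:
#   lpips_learned = LPIPSvgg()
#   score = lpips_learned(img1, img2, as_loss=False)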
if __name__ == '__main__':
    # Random tensors as a quick smoke test; see the note below for real images
    img1 = torch.randn(1, 3, 64, 64)
    img2 = torch.randn(1, 3, 64, 64)
    ssim = SSIM()
    lpips = LPIPS(net='vgg16', use_gpu=False)
    psnr = PSNR()
    print("PSNR: ", psnr(img1, img2))
    print("SSIM: ", ssim(img1, img2, as_loss=False).item())
    print("LOE: ", lightness_order_error(img1, img2))
    print("NIQE: ", niqe(img1))
    print("LPIPS: ", lpips(img1, img2))