Skip to content

Instantly share code, notes, and snippets.

@MercuriXito
Last active March 20, 2020 07:52
Show Gist options
  • Save MercuriXito/c1be227dde6c56df1e9f40a11ee09bad to your computer and use it in GitHub Desktop.
Save MercuriXito/c1be227dde6c56df1e9f40a11ee09bad to your computer and use it in GitHub Desktop.
PyTorch-based implementations of image-quality evaluation metrics: MSE, PSNR, SSIM, and MS-SSIM
import torch
import torch.nn.functional as F
from torch.nn.functional import conv2d
import numpy as np
def mse(x, y):
    """Return the mean squared error between two array-likes of equal shape.

    Both inputs are converted to float64 before subtracting: with unsigned
    integer images (e.g. ``uint8``) ``x - y`` would otherwise wrap around
    modulo 256 and yield a wildly wrong error.

    Args:
        x, y: array-likes (NumPy arrays, lists, CPU tensors) of the same
            (or broadcast-compatible) shape.

    Returns:
        The mean of the squared element-wise differences (NumPy float64).
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.mean(diff ** 2)
def psnr(image1, image2, L = 255, eps = 1e-6):
    """Peak signal-to-noise ratio in decibels: ``10 * log10(L**2 / MSE)``.

    Args:
        image1, image2: array-likes of the same shape.
        L: peak signal value (255 for 8-bit images).
        eps: added to the MSE so identical inputs give a large finite value
            (``10 * log10(L**2 / eps)``) instead of dividing by zero.

    Returns:
        PSNR in dB (NumPy float64).
    """
    # Cast up front so unsigned-integer images (e.g. uint8) cannot wrap
    # around when subtracted inside mse().
    m = mse(np.asarray(image1, dtype=np.float64),
            np.asarray(image2, dtype=np.float64))
    return 10 * np.log10(L ** 2 / (m + eps))
def filter(image, w):
    """Valid (unpadded, stride-1) 2-D cross-correlation of ``image`` with ``w``.

    Thin wrapper over ``torch.nn.functional.conv2d``: each spatial output
    dimension shrinks by ``kernel_size - 1``. The name shadows the builtin
    ``filter`` but is kept for API compatibility.

    Args:
        image: batched tensor of shape ``(N, C, H, W)``.
        w: weight tensor of shape ``(out_channels, C, kH, kW)``.

    Returns:
        Tensor of shape ``(N, out_channels, H - kH + 1, W - kW + 1)``.

    Raises:
        ValueError: if ``image`` is not 4-D.
    """
    if image.dim() != 4:
        raise ValueError(
            f"expected a 4-D (N, C, H, W) tensor, got {image.dim()}-D")
    return conv2d(image, w, stride=1, padding=0)
def ssim(image1, image2, window_size = 11, K = (0.01, 0.03),
         range = 255, nonegtive = False):
    """Compute the SSIM map between two images and summarise it.

    Local statistics come from a uniform ``window_size x window_size`` box
    window that spans *all* channels jointly (every weight equals
    ``1 / (C * window_size**2)``), so one SSIM map is produced even for a
    multi-channel image. The original SSIM paper uses an 11x11 Gaussian
    window (sigma 1.5) per channel, so values differ somewhat from that
    reference.

    Args:
        image1, image2: images of the same shape, either ``(H, W, C)``
            (channels last) or ``(H, W)`` (grayscale). NumPy arrays or torch
            tensors; values are cast to float32.
        window_size: side of the square averaging window; must not exceed
            ``H`` or ``W``.
        K: ``(K1, K2)``; the stabilising constants are
            ``C1 = (K1 * range)**2`` and ``C2 = (K2 * range)**2``.
        range: dynamic range of the pixel values (255 for 8-bit images).
            Shadows the builtin but is kept for API compatibility.
        nonegtive: if True, clamp the contrast-structure term at 0 before
            multiplying it with the luminance term.

    Returns:
        ``(mean, std)`` of the SSIM map as Python floats.
    """
    dtype = torch.float

    def to_nchw(image):
        # (H, W, C) -> (1, C, H, W). A bare .view() would only reinterpret
        # memory and scramble pixels across channels; the axes must be
        # permuted.
        t = torch.as_tensor(image).to(dtype)
        if t.dim() == 2:
            t = t.unsqueeze(-1)
        return t.permute(2, 0, 1).unsqueeze(0)

    image1 = to_nchw(image1)
    image2 = to_nchw(image2)

    channels = image1.shape[1]
    window = torch.full((1, channels, window_size, window_size),
                        1.0 / (channels * window_size * window_size),
                        dtype=dtype)
    C1, C2 = ((k * range) ** 2 for k in K)

    def local_mean(t):
        # Valid (unpadded) box-filter average over the window and channels.
        return F.conv2d(t, window)

    mu1 = local_mean(image1)
    mu2 = local_mean(image2)
    sigma1 = local_mean(image1 * image1) - mu1 ** 2
    sigma2 = local_mean(image2 * image2) - mu2 ** 2
    cov = local_mean(image1 * image2) - mu1 * mu2

    luminance = (2 * mu1 * mu2 + C1) / (mu1 ** 2 + mu2 ** 2 + C1)
    contrast_structure = (2 * cov + C2) / (sigma1 + sigma2 + C2)
    if nonegtive:
        contrast_structure = F.relu(contrast_structure)

    ssim_map = (luminance * contrast_structure).flatten()
    return ssim_map.mean().item(), ssim_map.std().item()
def ms_ssim(image1, image2, N = 5, windows_size = 12, nonegative = True,
            params = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
    """Multi-scale SSIM (Wang, Simoncelli & Bovik, 2003).

        MS-SSIM = l_N ** w_N * prod_j (c_j ** w_j * s_j ** w_j)

    ``l``, ``c`` and ``s`` are the luminance, contrast and structure terms.
    Scale 1 is the input resolution, and each later scale is obtained by
    2x2 average pooling. Local statistics use a uniform box window that
    spans all channels jointly. At each scale the three term maps are
    averaged separately before being combined.

    Args:
        image1, image2: images of the same shape, ``(H, W, C)`` (channels
            last) or ``(H, W)``. NumPy arrays or torch tensors; cast to
            float32. The stabilising constants assume a dynamic range of 255.
        N: number of scales.
        windows_size: side of the box window; clipped to the image size at
            coarse scales.
        nonegative: clamp the structure term at 0. With ``False`` a negative
            mean structure term raised to a fractional weight yields NaN.
        params: per-scale exponents, finest scale first; must have length N.

    Returns:
        The MS-SSIM index as a float (1.0 for identical images).

    Raises:
        ValueError: if ``N < 1`` or ``len(params) != N``.
    """
    if N < 1:
        raise ValueError(f"N must be at least 1, got {N}")
    weights = np.asarray(params, dtype=np.float64)
    if weights.shape != (N,):
        raise ValueError(
            f"params must contain exactly N={N} weights, got {weights.size}")

    dtype = torch.float
    # K1 must be non-zero: with C1 == 0 the luminance term is 0/0 (NaN)
    # wherever both local means are zero, e.g. black regions.
    K1, K2 = 0.01, 0.03
    dynamic_range = 255
    C1 = (K1 * dynamic_range) ** 2
    C2 = (K2 * dynamic_range) ** 2
    C3 = C2 / 2

    def to_nchw(image):
        # (H, W, C) -> (1, C, H, W); .view() would scramble pixels across
        # channels, so the axes are permuted instead.
        t = torch.as_tensor(image).to(dtype)
        if t.dim() == 2:
            t = t.unsqueeze(-1)
        return t.permute(2, 0, 1).unsqueeze(0)

    def scale_stats(x, y, window_size):
        # Mean luminance, contrast and structure terms at one scale.
        channels = x.shape[1]
        window = torch.full((1, channels, window_size, window_size),
                            1.0 / (channels * window_size * window_size),
                            dtype=dtype)
        mu_x = F.conv2d(x, window)
        mu_y = F.conv2d(y, window)
        # E[X^2] - E[X]^2 can dip just below zero through float cancellation;
        # clamp so the square root below stays real.
        var_x = (F.conv2d(x * x, window) - mu_x ** 2).clamp(min=0)
        var_y = (F.conv2d(y * y, window) - mu_y ** 2).clamp(min=0)
        cov = F.conv2d(x * y, window) - mu_x * mu_y
        std_xy = torch.sqrt(var_x * var_y)

        luminance = (2 * mu_x * mu_y + C1) / (mu_x ** 2 + mu_y ** 2 + C1)
        contrast = (2 * std_xy + C2) / (var_x + var_y + C2)
        structure = (cov + C3) / (std_xy + C3)
        if nonegative:
            structure = F.relu(structure)
        return (luminance.mean().item(), contrast.mean().item(),
                structure.mean().item())

    x = to_nchw(image1)
    y = to_nchw(image2)
    stats = []
    for scale in range(N):
        if scale > 0:
            # Halve the resolution; odd sides are padded by one so no
            # row/column is dropped.
            H, W = x.shape[-2:]
            padding = (H % 2, W % 2)
            x = F.avg_pool2d(x, 2, padding=padding)
            y = F.avg_pool2d(y, 2, padding=padding)
        # Read the real size after pooling; the window cannot exceed it.
        H, W = x.shape[-2:]
        stats.append(scale_stats(x, y, min(windows_size, H, W)))

    l, c, s = np.asarray(stats, dtype=np.float64).T
    # The weights are exponents, not multipliers; luminance enters only at
    # the coarsest scale.
    return float(l[-1] ** weights[-1]
                 * np.prod(c ** weights) * np.prod(s ** weights))
if __name__ == "__main__":
    # Demo: compare two independent random 8-bit RGB images.
    # np.random.randint's upper bound is exclusive, so 256 is needed to
    # cover the full 0..255 range.
    x = np.random.randint(0, 256, size=(128, 128, 3))
    y = np.random.randint(0, 256, size=(128, 128, 3))
    print(mse(x, y))
    print(psnr(x, y))
    print(ssim(x, y, range=255))
    print(ms_ssim(x, y))

    # matplotlib is only needed for this visual check, so import it lazily.
    import matplotlib.pyplot as plt
    plt.subplot(1, 2, 1)
    plt.imshow(x)
    plt.subplot(1, 2, 2)
    plt.imshow(y)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment