Skip to content

Instantly share code, notes, and snippets.

@YHaruoka
Created May 2, 2022 07:43
Show Gist options
  • Save YHaruoka/b6dd59da9c524c9ba9aaa0d41f7f84a1 to your computer and use it in GitHub Desktop.
Save YHaruoka/b6dd59da9c524c9ba9aaa0d41f7f84a1 to your computer and use it in GitHub Desktop.
import cv2
import sys
import math
import numpy as np
from matplotlib import pyplot as plt
def main():
    """Sweep low-pass cutoffs over a noisy image and save the best result.

    Loads ``noisy_image.png`` and the ground truth ``gt_image.png`` as
    grayscale, applies ``FFTLowpassFilter`` for every cutoff in
    ``range(max(height, width) // 2)``, scores each output with ``PSNREval``
    against the ground truth, and writes two figures:

    * ``FFTresult.png`` -- input, frequency mask, filtered image and
      magnitude spectrum for the cutoff with the highest PSNR.
    * ``PSNRresult.png`` -- PSNR as a function of the cutoff.

    Exits with status 1 if either image cannot be read or if the two images
    differ in size.
    """
    print("OpenCV Version: " + str(cv2.__version__) + "\n")

    # Loading noisy image data
    filename = "noisy_image.png"
    image = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
    # Loading groundtruth image data
    filename_gt = "gt_image.png"
    image_gt = cv2.imread(filename_gt, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None; exit non-zero so callers
    # (shell scripts, CI) can tell the run failed.
    if image is None:
        print("Cannot find image : " + filename)
        sys.exit(1)
    if image_gt is None:
        print("Cannot find image : " + filename_gt)
        sys.exit(1)
    # PSNR compares pixel-for-pixel, so the shapes must agree exactly.
    if image.shape != image_gt.shape:
        print("Image size mismatch : " + filename + " " + str(image.shape)
              + " vs " + filename_gt + " " + str(image_gt.shape))
        sys.exit(1)

    image_tm = np.asarray(image)
    image_tm_gt = np.asarray(image_gt)
    # Baseline: PSNR of the unfiltered noisy input (printed by PSNREval).
    PSNREval(image_tm_gt, image_tm, 255)

    cutoffs = list(range(max(image_tm.shape) // 2))
    PSNR_array = []
    # Start below any attainable PSNR so the first cutoff always becomes the
    # initial best, and so max_i is defined even if the sweep is empty.
    max_i = 0
    maxPSNR = -math.inf
    for i in cutoffs:
        print(i)
        # FFT Lowpass filter, then evaluation with PSNR
        image_tm_inverse = FFTLowpassFilter(image_tm, i)[0]
        PSNR = PSNREval(image_tm_gt, image_tm_inverse, 255)
        PSNR_array.append(PSNR)
        if PSNR > maxPSNR:
            max_i = i
            maxPSNR = PSNR

    # FFT Lowpass filter (Best PSNR pattern)
    result = FFTLowpassFilter(image_tm, max_i)
    # Calculating magnitude spectrum of the masked, centred spectrum
    magnitude_spectrum = 20 * np.log(np.abs(result[1]))

    # Output result: 2x2 grid of (subplot position, data, title)
    fig_image = plt.figure()
    panels = [
        (221, image_tm, 'Image (input)'),
        (222, result[2], 'Frequency filter'),
        (223, result[0], 'Image (inverse)'),
        (224, magnitude_spectrum, 'Magnitude spectrum'),
    ]
    for position, data, title in panels:
        plt.subplot(position)
        plt.imshow(data, cmap='gray')
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
    fig_image.savefig("FFTresult.png")

    fig_graph = plt.figure()
    plt.xlabel('CUTOFF FREQUENCY')
    plt.ylabel('PSNR [dB]')
    plt.plot(cutoffs, PSNR_array)
    fig_graph.savefig("PSNRresult.png")
def FFTLowpassFilter(image, CUTOFF_FREQ):
    """Apply an ideal square low-pass filter to a 2-D image via the FFT.

    The centred (``fftshift``-ed) spectrum is multiplied by a mask that is
    1 where ``|x - center_x| < CUTOFF_FREQ and |y - center_y| < CUTOFF_FREQ``
    and ``1e-10`` elsewhere, then transformed back to the spatial domain.

    Args:
        image: 2-D array of shape ``(height, width)``; square or not.
        CUTOFF_FREQ: Half-width of the square pass band, in frequency bins
            from the spectrum centre. ``0`` passes nothing at full gain.

    Returns:
        Tuple ``(image_tm_inverse, image_fq_shifted, mask)``:
        the filtered image (real part of the inverse FFT), the masked,
        centred spectrum, and the mask itself -- all of shape
        ``image.shape``.
    """
    length_y, length_x = image.shape
    center_y = length_y / 2
    center_x = length_x / 2

    # FFT (spatial domain -> frequency domain), zero frequency moved to centre
    image_fq_shifted = np.fft.fftshift(np.fft.fft2(image))

    # Frequency mask creation. Axis 0 is y (rows) and axis 1 is x (columns),
    # so the mask is the outer AND of a per-row and a per-column test. This
    # keeps it correctly oriented for non-square images. Outside the pass
    # band a tiny gain is used instead of 0, which keeps log(|spectrum|)
    # finite (main() takes the log of the masked spectrum for display).
    rows_in = np.abs(np.arange(length_y) - center_y) < CUTOFF_FREQ
    cols_in = np.abs(np.arange(length_x) - center_x) < CUTOFF_FREQ
    inside = rows_in[:, np.newaxis] & cols_in[np.newaxis, :]
    mask = np.where(inside, 1.0, 1e-10)

    # Frequency mask
    image_fq_shifted = image_fq_shifted * mask
    # Undo the centring with ifftshift: fftshift is not its own inverse on
    # odd-length axes, so applying it twice would misplace the spectrum.
    image_fq = np.fft.ifftshift(image_fq_shifted)
    # IFFT (frequency domain -> spatial domain)
    image_tm_inverse = np.fft.ifft2(image_fq).real
    return image_tm_inverse, image_fq_shifted, mask
def PSNREval(image1, image2, R):
    """Return the peak signal-to-noise ratio between two images, in dB.

    ``PSNR = 10 * log10(R**2 / MSE)``, where MSE is the mean squared
    difference over every element of the arrays (so any number of
    dimensions, e.g. colour channels, is handled correctly).

    Args:
        image1: Reference image (array-like).
        image2: Image to evaluate; same shape as ``image1``.
        R: Maximum possible pixel value, e.g. 255 for 8-bit images.

    Returns:
        The PSNR as a float; ``math.inf`` when the images are identical.

    The value is also printed as ``"PSNR: <value>"``.
    """
    # Convert to float before subtracting so unsigned inputs cannot wrap.
    diff = np.asarray(image1, dtype=float) - np.asarray(image2, dtype=float)
    MSE = np.mean(diff ** 2)
    if MSE == 0:
        # Identical images: no noise, the ratio is unbounded.
        PSNR = math.inf
    else:
        PSNR = 10 * math.log10(R * R / MSE)
    # print("MSE: " + str(MSE))
    print("PSNR: " + str(PSNR))
    return PSNR
if __name__ == "__main__":
    # Run the cutoff sweep only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment