import cv2
import sys

def main():
    """Denoise a noisy colour image with four OpenCV filters and score each result.

    Reads ``data/gt_image.png`` (ground truth) and ``data/noisy_image.png``,
    applies average, Gaussian, median and bilateral filtering to the noisy
    image, prints PSNR and SSIM against the ground truth for the noisy input
    and for every filtered output, then writes all of those images to the
    current working directory.

    Exits with a non-zero status (message on stderr) if either input image
    cannot be read or the two images differ in shape.
    """
    print("OpenCV Version: " + str(cv2.__version__) + "\n")

    # Loading image data (COLOR)
    filename1 = "data/gt_image.png"
    filename2 = "data/noisy_image.png"

    gt_image = cv2.imread(filename1, cv2.IMREAD_COLOR)
    noisy_image = cv2.imread(filename2, cv2.IMREAD_COLOR)

    # cv2.imread signals failure by returning None rather than raising.
    # sys.exit(<str>) writes the message to stderr and exits with status 1,
    # so a calling script/CI can tell the run failed (bare sys.exit() is 0).
    if gt_image is None:
        sys.exit("Cannot find image1 : " + filename1)
    if noisy_image is None:
        sys.exit("Cannot find image2 : " + filename2)

    # PSNR/SSIM compare the images pixel by pixel, so their shapes must match.
    if gt_image.shape != noisy_image.shape:
        sys.exit(
            f"Image size mismatch: {filename1} {gt_image.shape} "
            f"vs {filename2} {noisy_image.shape}"
        )

    # Average filter (image, window_size)
    denoised_image1 = cv2.blur(noisy_image, ksize=(9, 9))

    # Gaussian filter (image, window_size).
    # The third positional parameter of GaussianBlur is sigmaX, not the border
    # type, so both are passed by keyword. sigmaX=0 lets OpenCV derive sigma
    # from the 9x9 kernel size; passing cv2.BORDER_DEFAULT positionally would
    # instead use that enum's integer value as sigma.
    denoised_image2 = cv2.GaussianBlur(
        noisy_image, (9, 9), sigmaX=0, borderType=cv2.BORDER_DEFAULT
    )

    # Median Filter (image, w)
    denoised_image3 = cv2.medianBlur(noisy_image, 9)

    # Bilateral Filter (image, d, sigmaColor, sigmaSpace)
    denoised_image4 = bilateralFilter(noisy_image, 11, 100, 10)

    # (output filename, image) pairs. The unfiltered noisy input comes first
    # so its scores serve as the baseline for the filtered results.
    outputs = [
        ("noisy_image.png", noisy_image),
        ("denoised_image1(Average).png", denoised_image1),
        ("denoised_image2(Gaussian).png", denoised_image2),
        ("denoised_image3(Median).png", denoised_image3),
        ("denoised_image4(Bilateral).png", denoised_image4),
    ]

    # Evaluation with PSNR and SSIM
    for _, image in outputs:
        PSNREval(gt_image, image, 255)
        SSIMEval(gt_image, image)

    for path, image in outputs:
        # cv2.imwrite returns False instead of raising when it cannot write.
        if not cv2.imwrite(path, image):
            print("Failed to write " + path, file=sys.stderr)
    
def bilateralFilter(noisy_image, d, sigma_color, sigma_space):
    """Return an edge-preserving, bilateral-filtered copy of ``noisy_image``.

    Thin wrapper around ``cv2.bilateralFilter``; the arguments are forwarded
    positionally and unchanged.

    Args:
        noisy_image: image to smooth.
        d: diameter of the pixel neighbourhood used by the filter.
        sigma_color: filter sigma in colour (intensity) space.
        sigma_space: filter sigma in coordinate space.

    Returns:
        The filtered image produced by OpenCV.
    """
    return cv2.bilateralFilter(noisy_image, d, sigma_color, sigma_space)

def PSNREval(image1, image2, R):
    """Print per-channel and mean PSNR between two colour images.

    Args:
        image1: reference image (e.g. the ground truth).
        image2: image compared against ``image1``; same size and type.
        R: maximum possible pixel value (255 for 8-bit images). It is
            forwarded to OpenCV as ``maxPixelValue``.

    Returns:
        The score tuple returned by ``cv2.quality.QualityPSNR_compute``.
        Entries 0, 1 and 2 are printed as the Blue, Green and Red channel
        PSNR.
    """
    # Forward R: without maxPixelValue OpenCV always assumes a peak of 255,
    # which mis-scores data in any other range (e.g. float images in [0, 1]).
    psnr, _ = cv2.quality.QualityPSNR_compute(image1, image2, maxPixelValue=R)
    blue, green, red = psnr[0], psnr[1], psnr[2]
    print("PSNR Evaluation Results")
    print("   PSNR (Blue): " + str(round(blue, 3)))
    print("   PSNR (Green): " + str(round(green, 3)))
    print("   PSNR (Red): " + str(round(red, 3)))
    print("   PSNR (RGB Average): " + str(round(((blue + green + red) / 3), 3)) + "\n")
    return psnr

def SSIMEval(image1, image2):
    """Print the per-channel and mean SSIM between two colour images.

    ``cv2.quality.QualitySSIM_compute`` yields one score per channel. The
    first three are reported as Blue, Green and Red, which is the channel
    order of images loaded with ``cv2.imread``. Their arithmetic mean follows.

    Args:
        image1: reference image.
        image2: image compared against ``image1``.
    """
    scores, _quality_map = cv2.quality.QualitySSIM_compute(image1, image2)
    print("SSIM Evaluation Results")
    for idx, channel in enumerate(("Blue", "Green", "Red")):
        print("   SSIM (" + channel + "): " + str(round(scores[idx], 3)))
    mean_score = sum(scores[:3]) / 3
    print("   SSIM (RGB Average): " + str(round(mean_score, 3)) + "\n")

# Run the denoise-and-evaluate pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()