Skip to content

Instantly share code, notes, and snippets.

@BachiLi
Last active April 18, 2022 03:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BachiLi/89c702f246d9ed5c2299f8a13753c61c to your computer and use it in GitHub Desktop.
Save BachiLi/89c702f246d9ed5c2299f8a13753c61c to your computer and use it in GitHub Desktop.
import jax.numpy as np
import skimage
import skimage.io
import matplotlib.pyplot as plt
import math
# Weight of the data (intensity) term in the reconstruction further below.
lambda_d = 1

# Load the test image as floating point.
# NOTE(review): the grid construction below assumes a single-channel 2-D
# image (height, width); a colour PNG would need converting first.
img = skimage.img_as_float(skimage.io.imread('cameraman.png'))

# Periodic differences via a one-sample roll (wrapping at the border):
#   grad_x[i, j] = img[i, j - 1] - img[i, j]
#   grad_y[i, j] = img[i - 1, j] - img[i, j]
grad_x = np.roll(img, 1, axis=1) - img
grad_y = np.roll(img, 1, axis=0) - img

# 2-D DFTs of the image and of both gradient images (computed in this order).
img_freq, grad_x_freq, grad_y_freq = (
    np.fft.fft2(plane) for plane in (img, grad_x, grad_y)
)

# Per-pixel normalized frequencies laid out like fft2's output:
# sx[i, j] is the frequency of column j, sy[i, j] the frequency of row i.
sx, sy = np.meshgrid(np.fft.fftfreq(img.shape[1]), np.fft.fftfreq(img.shape[0]))
# Fourier transform of the periodic difference operators used to build
# grad_x / grad_y (np.roll by one sample, minus the identity).
# By the DFT shift theorem, np.roll(x, 1) along an axis multiplies the
# spectrum by exp(-2*pi*i*f), where f is the normalized frequency returned by
# np.fft.fftfreq. The 2*pi therefore belongs inside the exponent, and there is
# no outer scale factor; otherwise the operators do not match grad_*_freq.
Dx_freq = np.exp(-2j * math.pi * sx) - 1
Dy_freq = np.exp(-2j * math.pi * sy) - 1

# Apply the operators in the frequency domain. These spectra should equal
# grad_x_freq / grad_y_freq up to floating-point error.
my_grad_x_freq = Dx_freq * img_freq
my_grad_y_freq = Dy_freq * img_freq
my_grad_x = np.real(np.fft.ifft2(my_grad_x_freq))
my_grad_y = np.real(np.fft.ifft2(my_grad_y_freq))

# Least-squares (screened-Poisson) reconstruction, minimizing
#   lambda_d * |x - img|^2 + |Dx x - grad_x|^2 + |Dy x - grad_y|^2.
# All operators are diagonal in the Fourier domain, so the normal equations
# reduce to a per-frequency division. |D|^2 is taken as a real magnitude;
# at DC both |D|^2 terms are 0, so lambda_d must be nonzero.
recon_freq = (
    lambda_d * img_freq
    + np.conjugate(Dx_freq) * grad_x_freq
    + np.conjugate(Dy_freq) * grad_y_freq
) / (lambda_d + np.abs(Dx_freq) ** 2 + np.abs(Dy_freq) ** 2)
recon = np.real(np.fft.ifft2(recon_freq))
def _show_log_spectrum(spectrum):
    """Open a new figure showing log(|spectrum| + 1) with DC centered, on a fixed 0..15 scale."""
    plt.figure()
    centered = np.fft.fftshift(spectrum)
    plt.imshow(np.log(np.absolute(centered) + 1), vmin=0, vmax=15)


# Spectra of the image, of its roll-based gradients, of the operator-based
# gradients, and of the reconstruction, one figure each.
_show_log_spectrum(img_freq)
_show_log_spectrum(grad_x_freq)
_show_log_spectrum(grad_y_freq)
_show_log_spectrum(my_grad_x_freq)
_show_log_spectrum(my_grad_y_freq)
_show_log_spectrum(recon_freq)

# Finally the reconstructed image itself, then display every figure.
plt.figure()
plt.imshow(recon)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment