Last active
November 25, 2022 13:19
-
-
Save GrovesD2/8c7c548cc50f0072b8e4ede3b40248c8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from scipy import sparse
from scipy.ndimage import gaussian_filter
# --- Input ------------------------------------------------------------------
IMAGE = 'PATH_TO_JPG_IMAGE'  # Path to the JPG image to process

# --- Time stepping ----------------------------------------------------------
# The time step is k = P * h**2, so a lower P means smaller steps. The schemes
# are explicit, so P must stay small (about <= 0.25) for them to be stable.
P = 0.2
STEPS = 20  # Number of finite-difference time steps taken by each solver

# --- Perona-Malik -----------------------------------------------------------
CONVOLVE = True  # Gaussian-regularise the image before computing g
SIGMA = 0.1      # Width of that Gaussian; a higher sigma smooths more
ALPHA = 100      # Edge-preserving parameter; larger values let more
                 # diffusion happen across edges

# --- Output -----------------------------------------------------------------
PLOT = True  # Display the resulting images
SAVE = True  # Write the resulting images to disk

# --- Optional salt-and-pepper noise -----------------------------------------
APPLY_NOISE = False   # Corrupt the image with noise before smoothing
NOISE_CHANCE = 0.995  # Probability that an entry of a noise mask is left
                      # unset (higher = less noise)
def crop_image(I: np.array) -> np.array:
    '''
    Centre-crop an image to a square.

    A square n x n grid makes our lives easier when performing the finite
    difference calculations, since the same 1-D operator applies along both
    spatial axes.

    Parameters
    ----------
    I : np.array
        The image we are going to crop. The first two axes are the spatial
        (row, column) axes; any trailing axes (e.g. colour channels) are left
        untouched, so both 2-D greyscale and 3-D colour arrays are accepted.

    Returns
    -------
    np.array
        The cropped image, whose first two axes both have length
        min(I.shape[0], I.shape[1]). I itself is returned when it is already
        square; otherwise a sliced view of I is returned.
    '''
    rows, cols = I.shape[:2]
    if rows == cols:
        return I

    side = min(rows, cols)
    # Trim (larger - smaller) pixels in total: floor(diff/2) from the start
    # and the remainder from the end.
    start = (max(rows, cols) - side) // 2
    stop = start + side

    if rows < cols:
        # Wider than tall: trim columns.
        return I[:, start:stop, ...]
    # Taller than wide: trim rows.
    return I[start:stop, ...]
def get_diff_mat(n: int) -> Tuple[sparse.csr_matrix, sparse.csr_matrix,
                                  sparse.csr_matrix, sparse.csr_matrix]:
    '''
    Obtain the finite-difference differentiation matrices in x and y.

    The n x n grid is flattened in row-major order (as np.reshape does), so
    pixel (row i, column j) sits at index i*n + j. "x" runs along the columns
    (index j) and "y" along the rows (index i). First derivatives use centred
    differences and second derivatives the 3-point stencil. Neighbours
    outside the grid contribute nothing (the stencils are truncated at the
    edges).

    Parameters
    ----------
    n : int
        The number of discrete points in one spatial direction

    Returns
    -------
    (Dx, Dxx, Dy, Dyy) : tuple of scipy.sparse.csr_matrix
        The first and second differentiation matrices in x and y, each of
        shape (n**2, n**2).
    '''
    h = 1/n  # The spacing between discrete points
    c1 = 1/(2*h)  # centred first-derivative coefficient
    c2 = 1/h**2   # second-derivative coefficient

    # 1-D operators built directly in sparse form (no dense n x n
    # intermediates).
    D1 = sparse.diags([-c1, c1], [-1, 1], shape=(n, n))
    D2 = sparse.diags([c2, -2*c2, c2], [-1, 0, 1], shape=(n, n))
    eye_n = sparse.eye(n)

    # x-derivatives act within each row (fast index j); y-derivatives act
    # across rows (slow index i), i.e. with an offset of n in flattened form.
    Dx = sparse.kron(eye_n, D1, format='csr')
    Dxx = sparse.kron(eye_n, D2, format='csr')
    Dy = sparse.kron(D1, eye_n, format='csr')
    Dyy = sparse.kron(D2, eye_n, format='csr')
    return Dx, Dxx, Dy, Dyy
def get_bc_positions(n: int) -> np.array:
    '''
    Return the flattened indices of the pixels on the border of an n x n grid.

    The solvers use these indices to fix the border pixels to their initial
    values after every time step.

    Parameters
    ----------
    n : int
        The number of discrete points in one spatial direction

    Returns
    -------
    np.array
        Sorted 1-D integer array of row-major indices i*n + j for which pixel
        (i, j) lies in the first/last row or the first/last column.
    '''
    on_edge = np.ones((n, n), dtype=bool)
    on_edge[1:-1, 1:-1] = False  # interior pixels are free to evolve
    return np.flatnonzero(on_edge)
def heat_smooth(I: np.array, k: float) -> np.array:
    '''
    Smooth out the image using the heat equation u_t = u_xx + u_yy.

    The equation is advanced with an explicit (forward Euler) scheme for
    STEPS steps. Each channel diffuses independently, and the border pixels
    are held at their initial values throughout.

    Parameters
    ----------
    I : np.array
        The square image to smoothen, shape (n, n) or (n, n, channels). It is
        not modified.
    k : float
        The step size in time

    Returns
    -------
    np.array
        The smoothened image, with the same shape as I.

    Raises
    ------
    ValueError
        If the first two axes of I differ in length.
    '''
    n = I.shape[0]  # The amount of discrete points in one spatial direction
    if I.shape[1] != n:
        raise ValueError('heat_smooth expects a square image; '
                         'crop it with crop_image first')

    # One row per pixel, one column per channel, for the matrix-vector steps.
    U = np.reshape(I, (n**2, -1))

    # k*L is loop-invariant, so scale the Laplacian once up front.
    _, Dxx, _, Dyy = get_diff_mat(n)
    kL = k*(Dxx + Dyy)

    # Only the boundary values are needed to enforce the BCs.
    bc_pos = get_bc_positions(n)
    U_bc = U[bc_pos]  # advanced indexing returns a copy

    # Step forward in time to solve the heat equation and smoothen the image
    for _ in range(STEPS):
        U = U + kL@U
        U[bc_pos] = U_bc

    return np.reshape(U, I.shape)
def get_g(Dx: np.array,
          Dy: np.array,
          U: np.array,
          n: int,
          *,
          convolve: Optional[bool] = None,
          sigma: Optional[float] = None,
          alpha: Optional[float] = None) -> np.array:
    '''
    Obtain the g function (diffusivity) in the Perona-Malik PDE.

    g = 1 / (1 + (C_x**2 + C_y**2) / alpha), evaluated per pixel and per
    channel. C is U itself, or U blurred by a Gaussian when regularisation is
    on. g is close to 1 where the image is flat and small across strong
    gradients, which is what suppresses diffusion across edges.

    Parameters
    ----------
    Dx : scipy sparse matrix
        The differentiation matrix in x
    Dy : scipy sparse matrix
        The differentiation matrix in y
    U : np.array
        The flattened image we are smoothening, shape (n**2, channels)
    n : int
        The amount of discrete points in one spatial direction
    convolve : bool, optional
        Whether to Gaussian-blur U before differentiating. Defaults to the
        module-level CONVOLVE.
    sigma : float, optional
        Standard deviation of the Gaussian, in pixels. Defaults to SIGMA.
    alpha : float, optional
        The edge-preserving parameter. Defaults to ALPHA.

    Returns
    -------
    np.array
        The g function in the Perona-Malik PDE, with the same shape as U.
    '''
    if convolve is None:
        convolve = CONVOLVE
    if alpha is None:
        alpha = ALPHA

    if convolve:
        if sigma is None:
            sigma = SIGMA
        # Blur the two spatial axes only: a zero sigma on the channel axis
        # stops the colour channels from bleeding into each other.
        C = gaussian_filter(
            np.reshape(U, (n, n, -1)),
            sigma=(sigma, sigma, 0),
            order=0,
        )
        C = np.reshape(C, U.shape)
    else:
        C = U

    Cx = Dx@C
    Cy = Dy@C  # the y-gradient must use the y-derivative matrix
    return 1/(1 + (Cx**2 + Cy**2)/alpha)
def perona_malik(I: np.array, k: float) -> np.array:
    '''
    Solve the Perona-Malik PDE to smoothen the image.

    The PDE u_t = div(g grad u) is expanded as
        u_t = g_x u_x + g_y u_y + g (u_xx + u_yy)
    with g supplied by get_g. It is advanced with an explicit (forward Euler)
    scheme for STEPS steps, and the border pixels are held at their initial
    values throughout.

    Parameters
    ----------
    I : np.array
        The square image to smoothen (its first two axes must have equal
        length). It is not modified.
    k : float
        The step size in time.

    Returns
    -------
    np.array
        The Perona-Malik smoothened image, with the same shape as I.

    Raises
    ------
    ValueError
        If the first two axes of I differ in length.
    '''
    n = I.shape[0]  # The amount of discrete points in one spatial direction
    if I.shape[1] != n:
        raise ValueError('perona_malik expects a square image; '
                         'crop it with crop_image first')

    # One row per pixel, one column per channel, for the matrix-vector steps.
    U = np.reshape(I, (n**2, -1))

    # Obtain the differentiation matrices required
    Dx, Dxx, Dy, Dyy = get_diff_mat(n)
    L = Dxx + Dyy

    # Only the boundary values are needed to enforce the BCs.
    bc_pos = get_bc_positions(n)
    U_bc = U[bc_pos]  # advanced indexing returns a copy

    # Solve the Perona-Malik PDE
    for _ in range(STEPS):
        # The diffusivity depends on the current state, so recompute it.
        G = get_g(Dx, Dy, U, n)
        rhs = (Dx@G)*(Dx@U) + (Dy@G)*(Dy@U) + G*(L@U)
        U = U + k*rhs
        U[bc_pos] = U_bc

    return np.reshape(U, I.shape)
def reshape_to_image(U: np.array, n: int) -> np.array:
    '''
    Fold a flattened pixel array back into an n x n, 3-channel image.

    Parameters
    ----------
    U : np.array
        Array holding n*n*3 values, typically of shape (n**2, 3) with one row
        per pixel in row-major order.
    n : int
        The amount of discrete points in one spatial direction.

    Returns
    -------
    np.array
        The image with shape (n, n, 3).
    '''
    image_shape = (n, n, 3)
    return np.reshape(U, image_shape)
def salt_pepper_noise(I: np.array,
                      *,
                      chance: Optional[float] = None,
                      rng: Optional[np.random.Generator] = None) -> np.array:
    '''
    Apply salt and pepper noise to the image to test how well the PDEs can
    remove this noise.

    Two independent masks are drawn over the pixel grid. Pixels hit by the
    first become black (0), then pixels hit by the second become white (1).
    The masks cover whole pixels, so every channel of a hit pixel changes
    together.

    Parameters
    ----------
    I : np.array
        The original image. It is not modified.
    chance : float, optional
        Probability that a pixel is *not* hit by a given mask. Defaults to
        the module-level NOISE_CHANCE.
    rng : np.random.Generator, optional
        Generator to draw the masks from; NumPy's global random state is used
        when omitted.

    Returns
    -------
    np.array
        A copy of the image with the salt and pepper noise applied.
    '''
    noisy = np.array(I, copy=True)
    pixel_grid = noisy.shape[:2]
    noisy[get_noise_mask(pixel_grid, chance=chance, rng=rng)] = 0
    noisy[get_noise_mask(pixel_grid, chance=chance, rng=rng)] = 1
    return noisy


def get_noise_mask(shape: Optional[Tuple[int, ...]] = None,
                   *,
                   chance: Optional[float] = None,
                   rng: Optional[np.random.Generator] = None) -> np.array:
    '''
    Get a boolean mask marking which entries are going to change due to noise.

    Parameters
    ----------
    shape : tuple of int, optional
        Shape of the mask. When omitted, the shape of the module-level image
        ``I`` is used, so zero-argument calls keep working.
    chance : float, optional
        Probability that an entry is left False. Defaults to NOISE_CHANCE.
    rng : np.random.Generator, optional
        Generator to draw from; NumPy's global random state is used when
        omitted.

    Returns
    -------
    np.array
        Boolean array of the given shape; each entry is True with probability
        1 - chance.
    '''
    if shape is None:
        # NOTE: fallback relies on the global ``I`` created under __main__.
        shape = I.shape
    if chance is None:
        chance = NOISE_CHANCE
    if rng is None:
        draws = np.random.uniform(0, 1, size=shape)
    else:
        draws = rng.uniform(0, 1, size=shape)
    return draws >= chance
if __name__ == '__main__':
    # Load the image and bring it into the 0-1 range. plt.imread returns
    # integer data for JPGs but floats already in [0, 1] for PNGs, so only
    # integer images are rescaled.
    raw = plt.imread(IMAGE)
    if np.issubdtype(raw.dtype, np.integer):
        raw = raw / np.iinfo(raw.dtype).max
    # The solvers work on 3-channel images: expand greyscale, drop alpha.
    if raw.ndim == 2:
        raw = np.repeat(raw[..., np.newaxis], 3, axis=2)
    elif raw.shape[2] == 4:
        raw = raw[..., :3]

    # Crop the image to a square so that the finite difference approximations
    # are easier to apply. Copying as float guarantees a writable array.
    I = np.array(crop_image(raw), dtype=float)

    if APPLY_NOISE:
        I = salt_pepper_noise(I)

    # Determine the step size in time from the finite difference parameters
    h = 1/I.shape[0]
    k = P*h**2

    # Solve the PDEs to obtain the smoothened images
    I_heat = heat_smooth(deepcopy(I), k)
    I_pm = perona_malik(deepcopy(I), k)

    # Post-process the arrays to be back in the 0-1 range by clipping
    I_heat = np.clip(I_heat, a_min=0, a_max=1)
    I_pm = np.clip(I_pm, a_min=0, a_max=1)

    # (figure title, output file, image) for each result
    results = [
        ('Heat Equation', 'heat_eqn.jpg', I_heat),
        ('Perona Malik PDE', 'perona_malik.jpg', I_pm),
        ('Original Image', 'original.jpg', I),
    ]

    # Plot the figures
    if PLOT:
        for fig_num, (title, _, image) in enumerate(results):
            plt.figure(fig_num)
            plt.imshow(image, vmin=image.min(), vmax=image.max())
            plt.title(title)
            plt.axis('off')

    if SAVE:
        for _, filename, image in results:
            plt.imsave(filename, image, vmin=image.min(), vmax=image.max())

    if PLOT:
        # Display the figures; in a non-interactive session nothing is shown
        # without this call.
        plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment