Last active
January 12, 2022 20:02
-
-
Save andres-fr/6046b08462d175f2690d159ea4f7203c to your computer and use it in GitHub Desktop.
Wiener deconvolution study
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Wiener deconvolution study.

Copyright (C) 2021 aferro (ORCID: 0000-0003-3830-3595)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
# ##############################################################################
# # HELPER FUNCTIONS
# ##############################################################################
def rbf_kernel2d(size, covmat):
    """
    Build an unnormalized 2D Gaussian (RBF) bump.

    :param size: ``(height, width)`` of the output array.
    :param covmat: A 2x2 covariance matrix. Note that the first dimension is
      vertical downwards, and the second dimension horizontal from left to
      right.
    :returns: Array of shape ``(height, width)`` with entries
      ``exp(-0.5 * d^T covmat^-1 d)``, where ``d`` is the offset of each
      pixel from ``(height // 2, width // 2)``. The center pixel equals 1.
    """
    height, width = size
    rows, cols = np.mgrid[0:height, 0:width]
    # one column per pixel: (row offset, col offset) from the center
    offsets = np.stack([(rows - height // 2).ravel(),
                        (cols - width // 2).ravel()])
    precision = np.linalg.inv(covmat)
    # squared Mahalanobis distance of every column of ``offsets``
    maha = np.einsum("ij,ik,kj->j", offsets, precision, offsets)
    return np.exp(-0.5 * maha.reshape(height, width))
def roll_half_hw(arr):
    """
    Circularly shift ``arr`` so its middle pixel lands at index ``[0, 0]``.

    :param arr: Array whose first two axes are height and width, e.g.
      ``(h, w)`` or ``(h, w, channels)``. Further axes are not shifted.
    :returns: A rolled copy of ``arr`` with the same shape. The pixel at
      ``(h // 2, w // 2)`` ends up at ``(0, 0)``.
    """
    shifts = (-(arr.shape[0] // 2), -(arr.shape[1] // 2))
    return np.roll(arr, shifts, axis=(0, 1))
def fft_conv2d(signal, kernel):
    """
    Circularly convolve ``signal`` with ``kernel`` over the first two axes.

    Both arrays are transformed with ``np.fft.fft2`` over axes ``(0, 1)``,
    multiplied elementwise and transformed back; the real part is returned.
    Both arrays need the same height and width; any trailing axes (e.g.
    channels) are convolved independently and must broadcast (typically
    the two arrays simply share the same ``(h, w, channels)`` shape).

    The unnormalized ``"backward"`` convention (no scaling on the forward
    transform, ``1 / (h * w)`` on the inverse) is what makes this spectral
    product equal to the circular convolution without extra scale factors.

    Since the convolution is circular, a kernel centered at
    ``(h // 2, w // 2)`` shifts the output by that amount (undo it with
    ``roll_half_hw``), and content near the borders wraps around. Giving the
    kernel at least 25% zero-padding on each margin keeps wrap-around
    artifacts small (circular convolution still happens).
    """
    def spectrum(arr):
        return np.fft.fft2(arr, axes=(0, 1), norm="backward")

    product = spectrum(signal) * spectrum(kernel)
    return np.fft.ifft2(product, axes=(0, 1), norm="backward").real
def energy(sig):
    """
    Return the energy of ``sig``: the sum of its squared entries.

    :param sig: NumPy array of any shape (real-valued in this script).
    :returns: Scalar sum over all elements of ``sig ** 2``.
    """
    return np.sum(np.square(sig))
def white_noise(shape, variance, rng=None):
    """
    Draw zero-mean uniform white noise with approximately ``variance``.

    Samples are drawn uniformly from ``[-1, 1)``, centered so that their
    sample mean is zero, and scaled by ``x = sqrt(3 * variance)``. Since the
    uniform distribution on ``[-1, 1]`` has variance ``1/3``, the result has
    approximately the requested variance and lies roughly within
    ``[-x, x]`` (the centering step can shift it slightly).

    :param shape: Shape of the returned array.
    :param variance: Desired non-negative variance.
    :param rng: Optional ``numpy.random.Generator`` for reproducible draws.
      If ``None`` (default), the global ``np.random`` state is used, as
      before.
    :returns: Float array of the given shape.
    :raises ValueError: If ``variance`` is negative (its square root would
      otherwise yield complex or NaN noise).
    """
    if variance < 0:
        raise ValueError(f"variance must be non-negative, got {variance}")
    draw = np.random.random if rng is None else rng.random
    noise = draw(shape) * 2 - 1
    noise -= noise.mean()
    # white noise in [-1, 1] has variance=1/3
    factor = (3 * variance) ** 0.5
    noise *= factor  # therefore this has approx desired var
    return noise
class WienerEpsilon:
    """
    Correction term ``epsilon`` such that ``W_perp + epsilon == W_o``.

    ``W_o = S O* / |O|^2`` is the oracle Wiener filter (``W_o * O == S``),
    and ``W_perp = R* / (|R|^2 + |N|^2 / |S|^2)`` is the classical Wiener
    filter, which assumes signal and noise are orthogonal. Expanding
    ``|O|^2 = |S R + N|^2`` gives::

        epsilon = (S N* - 2 W_perp Re(S R N*)) / |O|^2

    Writing ``S N* = a + ib``, this is linear in ``a`` and ``b``::

        epsilon = a * re_coeff + i * b * im_coeff

    where ``re_coeff`` and ``im_coeff`` depend only on ``R``, ``|S|^2``,
    ``|N|^2`` and ``|O|^2``. They are computed once at construction, so each
    call only needs a few elementwise products.
    """

    def __init__(self, R, Smag, Nmag, Omag):
        """
        All inputs are given in Fourier space, elementwise, with matching
        (broadcastable) shapes. The ``mag`` parameters correspond to square
        magnitudes, i.e. complex spectrum times its conjugate.

        :param R: Spectrum of the convolution kernel. Complex arrays are the
          normal case; real-valued arrays are accepted as well.
        :param Smag: ``|S|^2``, squared magnitude of the signal spectrum.
        :param Nmag: ``|N|^2``, squared magnitude of the noise spectrum.
        :param Omag: ``|O|^2``, squared magnitude of the observed spectrum.
        """
        R = np.asarray(R)
        # ratio = -2 / Delta, where Delta is the denominator of W_perp
        ratio = -2.0 / (abs(R) ** 2 + Nmag / Smag)
        R_conj = R.conj()
        # coefficient of Re(S N*): (1 - 2 R* Re(R) / Delta) / |O|^2
        #   == (1 + ratio * (Re(R)^2 - i Re(R) Im(R))) / |O|^2
        self.re_coeff = (1 + ratio * R.real * R_conj) / Omag
        # coefficient of i Im(S N*): (1 - 2i R* Im(R) / Delta) / |O|^2
        #   == (1 + ratio * (Im(R)^2 + i Re(R) Im(R))) / |O|^2
        # Built as closed-form expressions (instead of assigning .real/.imag
        # on np.zeros_like(R)) so that real-dtype R does not crash.
        self.im_coeff = (1 + ratio * 1j * R.imag * R_conj) / Omag

    def __call__(self, SNconj):
        """
        :param SNconj: spectral cross-correlation ``S * N.conj()`` between
          signal and noise, both of same shape as the spectra given at
          construction.
        :returns: A spectral filter epsilon, so that wiener_perp+epsilon
          equals wiener_optimal.
        """
        # epsilon = Re(S N*) * re_coeff + i * Im(S N*) * im_coeff
        return SNconj.real * self.re_coeff + 1j * (SNconj.imag * self.im_coeff)
def plot_img(hwc):
    """
    Min-max normalize an image to ``[0, 1]`` and display it with matplotlib.

    :param hwc: Numeric array in a layout ``plt.imshow`` accepts, e.g.
      ``(h, w)`` or ``(h, w, 3)``. Any dtype works: the data is converted to
      float first, so integer images (e.g. raw ``uint8``) no longer fail on
      the in-place division. A constant image is shown as all zeros instead
      of producing NaNs from a division by zero. The input is not modified.
    """
    p = np.asarray(hwc, dtype=np.float64)
    p = p - p.min()
    peak = p.max()
    if peak > 0:
        p /= peak
    plt.clf()
    plt.imshow(p)
    plt.show()
def plot_residual(residual, cmap="pink"):
    """
    Display the per-pixel Euclidean norm of ``residual`` as a heatmap.

    The norm is taken over the last axis (for an ``(h, w, channels)`` array
    this collapses the channels into one magnitude per pixel), and the
    resulting map is drawn with the matplotlib colormap ``cmap``.

    Interesting colormaps: pink, gist_heat_r, hot, copper_r
    https://matplotlib.org/stable/tutorials/colors/colormaps.html
    """
    magnitude_map = np.linalg.norm(residual, axis=-1)
    plt.clf()
    plt.imshow(magnitude_map, cmap=cmap)
    plt.show()
def save_arr(arr, outpath, minval=0, maxval=255):
    """
    Min-max rescale ``arr`` into ``[minval, maxval]`` and save it as an image.

    :param arr: Numeric array of any dtype in a layout PIL can build an
      image from, e.g. ``(h, w)`` or ``(h, w, 3)``. It is not modified.
    :param outpath: Destination path, passed to ``PIL.Image.Image.save``.
    :param minval: Value the minimum of ``arr`` is mapped to.
    :param maxval: Value the maximum of ``arr`` is mapped to.

    The data is processed in float, so integer inputs no longer fail on the
    in-place division. A constant ``arr`` maps to ``minval`` instead of
    producing NaNs. Before the ``uint8`` conversion, values are rounded to
    the nearest integer (instead of truncated) and clipped to ``[0, 255]``,
    so out-of-range targets saturate instead of wrapping around.
    """
    # normalize between 0 and 1 (float copy, safe for integer inputs)
    arr = np.asarray(arr, dtype=np.float64)
    arr = arr - arr.min()
    peak = arr.max()
    if peak > 0:
        arr /= peak
    # normalize between minval and maxval
    arr *= (maxval - minval)
    arr += minval
    # round and saturate to the representable uint8 range
    arr = np.clip(np.rint(arr), 0, 255)
    im = Image.fromarray(arr.astype(np.uint8))
    im.save(outpath)
# ##############################################################################
# # MAIN ROUTINE
# ##############################################################################
# globals
SNR = 10  # ratio of signal variance to noise variance (linear, not dB)
DTYPE = np.float64
IMG_PATH = "christkindlmarkt_crop.jpg"
KERNEL_PATH = "conv_kernel.png"
# load image from disk and normalize to zero mean, unit std. The context
# manager closes the file handle that Image.open keeps open.
with Image.open(IMG_PATH) as img:
    img_arr = np.array(img).astype(DTYPE)
img_min, img_max = img_arr.min(), img_arr.max()  # original range, for saving
img_arr -= img_arr.mean()  # plot_img(img_arr)
img_arr /= img_arr.std()
h, w, c = img_arr.shape
# create a conv kernel from the first channel of the kernel image
# (single-channel images, e.g. grayscale PNGs, are used as-is)
with Image.open(KERNEL_PATH) as k:
    k_arr = np.array(k)
if k_arr.ndim == 3:
    k_arr = k_arr[:, :, 0]
k_arr = k_arr.astype(DTYPE)
k_h, k_w = k_arr.shape
# blur the kernel with a small isotropic Gaussian. The Gaussian is centered,
# so the circular conv shifts the result by (k_h//2, k_w//2); roll it back.
gaussian = rbf_kernel2d(k_arr.shape, np.array([[5.0, 0], [0, 5.0]]))
# k_gaussian = k_arr
k_gaussian = roll_half_hw(fft_conv2d(k_arr, gaussian))
k_gaussian -= k_gaussian.min()
k_gaussian /= (k_gaussian**2).sum()**0.5  # normalize kernel to energy=1
# embed the kernel at the center of an image-sized canvas
kernel_arr = np.zeros_like(img_arr[:, :, 0])  # plot_img(kernel_arr)
kernel_arr[(h//2)-(k_h//2): (h//2)+k_h-(k_h//2),
           (w//2)-(k_w//2): (w//2)+k_w-(k_w//2)] = k_gaussian
# Expand kernel to the image's channel count ``c`` (not a hard-coded 3) so
# its shape matches img_arr for any number of channels, e.g. RGBA.
kernel_arr = kernel_arr[:, :, np.newaxis].repeat(c, axis=-1)
# distort image by convolving with kernel and adding noise with desired SNR.
# kernel_arr is centered, so convolved/observed are circularly shifted by
# (h//2, w//2); roll_half_hw(observed) realigns them with img_arr.
convolved = fft_conv2d(img_arr, kernel_arr)
noise = white_noise(img_arr.shape, convolved.var() / SNR)
observed = convolved + noise  # plot_img(observed)
# auxiliary computations for deconvolution. The unnormalized "backward" FFT
# makes the convolution theorem hold exactly: O == S * R + N
S = np.fft.fft2(img_arr, axes=(0, 1), norm="backward")  # "unknown"
R = np.fft.fft2(kernel_arr, axes=(0, 1), norm="backward")  # assumed
N = np.fft.fft2(noise, axes=(0, 1), norm="backward")  # "unknown"
# observed signal is known
O = np.fft.fft2(observed, axes=(0, 1), norm="backward")
Omag = abs(O)**2  # (O * O.conj()).real
# perform "oracle" Wiener deconvolutions with optimal parameters. In real life
# we would have to estimate them from the data
Smag_opt = abs(S)**2  # normally must be assumed
Nmag_opt = abs(N)**2  # normally must be assumed
Rmag_opt = abs(R)**2  # normally must be assumed
# Wiener filtering: perfect recovery W_o and orthogonal assumption W_perp
W_o = S * O.conj() / Omag
W_perp = R.conj() / (Rmag_opt + (Nmag_opt / Smag_opt))
#
recons_o = np.fft.ifft2(W_o * O, axes=(0, 1), norm="backward").real
recons_perp = np.fft.ifft2(W_perp * O, axes=(0, 1), norm="backward").real
# Compute epsilon (W_o - W_perp) using the two alternative formulations
# and compute residual reconstruction
eps = (S * N.conj() - 2 * W_perp * (S * R * N.conj()).real) / Omag
eps2 = WienerEpsilon(R, Smag_opt, Nmag_opt, Omag)(S*N.conj())
recons_eps = np.fft.ifft2(eps * O, axes=(0, 1), norm="backward").real
recons_eps2 = np.fft.ifft2(eps2 * O, axes=(0, 1), norm="backward").real
# ##############################################################################
# # TESTS
# ##############################################################################
print("Image energy:", energy(img_arr))
print("Distorted energy:", energy(observed))
print("Residual energies:")
# observation artifacts introduce a lot of unwanted energy. `observed` is
# circularly shifted by the centered kernel, so realign it before comparing;
# otherwise the residual mostly measures the misalignment.
print(" image->distorted:", energy(img_arr - roll_half_hw(observed)))
# W_o provides a perfect recovery
print(" image - W_o:", energy(img_arr - recons_o))
# W_perp provides a decent recovery
print(" image - W_perp:", energy(img_arr - recons_perp))
# And eps provides correction from perp to perfect (both formulations)
print(" image - (W_perp + eps1) :",
      energy(img_arr - (recons_perp + recons_eps)))
print(" image - (W_perp + eps2) :",
      energy(img_arr - (recons_perp + recons_eps2)))
"""
Algebraic identities:
1. Parseval:
(energy(img_arr), (abs(S)**2).sum() / S[:, :, 0].size)
(energy(kernel_arr), (abs(R)**2).sum() / R[:, :, 0].size)
(energy(noise), (abs(N)**2).sum() / N[:, :, 0].size)
2. Decomposition of power spectrum:
np.allclose( ((S*R) * (S*R).conj()).real, Smag_opt * Rmag_opt )
3. Conv theorem for SR
np.allclose(np.fft.fft2(convolved, axes=(0, 1), norm="backward"), S*R)
4. Conv theorem for O = SR+N
np.allclose(O, S*R + N)
5. magnitude(SR) = magnitude(S) * magnitude(R)
np.allclose(abs(S*R), abs(S) * abs(R))
6. |O|^2 = |SR|^2 + |N|^2 + gamma, with gamma = 2 Re(S R N*)
gamma = 2 * (S * R * N.conj()).real
np.allclose(abs(O)**2, abs(S*R)**2 + abs(N)**2 + gamma)
np.linalg.norm(abs(O)**2 - (abs(S*R)**2 + abs(N)**2 + gamma))
7. Check formula for f_r(S N*):
vareps1 = (S * N.conj() - 2 * W_perp * (S * R * N.conj()).real) / Omag
we = WienerEpsilon(R, Smag_opt, Nmag_opt, Omag)
vareps2 = we(S*N.conj())
np.allclose(W_o, W_perp + vareps1)
np.allclose(W_o, W_perp + vareps2)
8. Finally check that (recons_perp + recons_eps) yields perfect reconstruction
np.allclose(recons_o, recons_perp + recons_eps)
np.allclose(recons_o, recons_perp + recons_eps2)
"""
# gallery:
# plot_img(img_arr)
# plot_img(kernel_arr)
# plot_img(roll_half_hw(observed))
# plot_img(recons_o)
# plot_img(recons_perp)
# plot_residual(recons_eps, "pink")
# save to disk:
# save_arr(roll_half_hw(observed), "observed.jpg", img_min, img_max)
# save_arr(recons_perp, "recons_perp.jpg", img_min, img_max)
# drop into the debugger to inspect the results interactively
breakpoint()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment