Skip to content

Instantly share code, notes, and snippets.

@andres-fr
Last active Jan 12, 2022
Embed
What would you like to do?
Wiener deconvolution study
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Wiener deconvolution study.
Copyright (C) 2021 aferro (ORCID: 0000-0003-3830-3595)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# ##############################################################################
# # HELPER FUNCTIONS
# ##############################################################################
def rbf_kernel2d(size, covmat):
    """
    Evaluate an unnormalized Gaussian (RBF) bump on a ``(height, width)`` grid.

    Each pixel is expressed as an integer offset ``d`` from the grid centre
    ``(height // 2, width // 2)`` and mapped to ``exp(-0.5 * d^T covmat^-1 d)``,
    so the centre pixel has value exactly 1.

    :param size: ``(height, width)`` of the output grid.
    :param covmat: A 2x2 covariance matrix. Note that the first dimension is
      vertical downwards, and the second dimension horizontal from left to
      right.
    :returns: Array of shape ``(height, width)``.
    """
    height, width = size
    # (2, height * width) integer offsets of every pixel from the centre pixel
    offsets = np.indices((height, width)).reshape(2, -1)
    offsets[0] -= height // 2
    offsets[1] -= width // 2
    precision = np.linalg.inv(covmat)
    # quadratic form d^T P d, evaluated column-wise for every offset d
    sq_mahalanobis = (offsets * (precision @ offsets)).sum(axis=0)
    return np.exp(-0.5 * sq_mahalanobis.reshape(height, width))
def roll_half_hw(arr):
    """
    Circularly shift the first two axes so the middle pixel lands at ``[0, 0]``.

    Rows are shifted by ``-(h // 2)`` and columns by ``-(w // 2)``; any
    trailing axes (e.g. channels) are left untouched, so both ``(h, w)`` and
    ``(h, w, channels)`` arrays are accepted.

    :param arr: Array with at least two dimensions.
    :returns: A new, shifted array with the same shape and dtype.
    """
    rows, cols = arr.shape[0], arr.shape[1]
    return np.roll(arr, shift=(-(rows // 2), -(cols // 2)), axis=(0, 1))
def fft_conv2d(signal, kernel):
    """
    Circularly convolve ``signal`` with ``kernel`` over the first two axes.

    ``kernel`` is expected to have the same ``(h, w, ...)`` shape and dtype as
    ``signal``; trailing axes (e.g. channels) are convolved independently.
    The origin of a circular convolution is index ``[0, 0]``, so a kernel
    centred in the middle of its array shifts the result circularly by
    ``(h // 2, w // 2)``. The kernel should have at least 25% zero-padding on
    each margin to keep circular wrap-around artifacts away from the content
    (the convolution is circular regardless).

    :returns: The real part of the circular convolution, shape of ``signal``.
    """
    axes = (0, 1)
    # With norm="backward" the forward FFTs are unscaled and the inverse
    # divides by h*w, so ifft2(fft2(s) * fft2(k)) is exactly the circular
    # convolution of s and k, without any extra scale factor.
    product = (np.fft.fft2(signal, axes=axes, norm="backward")
               * np.fft.fft2(kernel, axes=axes, norm="backward"))
    return np.fft.ifft2(product, axes=axes, norm="backward").real
def energy(sig):
    """
    Return the energy ``sum(|x|^2)`` of ``sig``.

    Real floating arrays give exactly ``(sig**2).sum()``. Complex arrays
    (e.g. spectra) use the squared magnitude, so Parseval checks such as
    ``energy(x) ~= energy(np.fft.fft2(x, axes=(0, 1))) / (h * w)`` hold.
    Integer/bool inputs are promoted to float64 first, since squaring e.g. a
    uint8 image in its own dtype silently overflows.

    :param sig: Array-like of any shape.
    :returns: Scalar energy (real).
    """
    arr = np.asarray(sig)
    if np.iscomplexobj(arr):
        return (arr.real ** 2 + arr.imag ** 2).sum()
    if arr.dtype.kind in "biu":
        arr = arr.astype(np.float64)
    return (arr ** 2).sum()
def white_noise(shape, variance, rng=None):
    """
    Return zero-mean uniform white noise with approximately the given variance.

    Samples are drawn uniformly from [-1, 1), re-centred to have zero empirical
    mean, then scaled by ``x = sqrt(3 * variance)`` (uniform noise on [-1, 1]
    has variance 1/3). Because of the re-centring, values lie approximately,
    not strictly, within ``[-x, x]``.

    :param shape: Output shape.
    :param variance: Desired (approximate) variance; must be non-negative.
    :param rng: Optional ``np.random.Generator`` for reproducible noise. If
      ``None`` (default), numpy's global random state is used, as before.
    :raises ValueError: If ``variance`` is negative (its square root would
      otherwise be complex/NaN).
    """
    if variance < 0:
        raise ValueError(f"variance must be non-negative, got {variance}")
    uniform = np.random.random(shape) if rng is None else rng.random(shape)
    noise = uniform * 2 - 1
    noise -= noise.mean()
    # white noise in [-1, 1] has variance=1/3, so this gives approx desired var
    noise *= (3 * variance) ** 0.5
    return noise
class WienerEpsilon:
    """
    Correction term ``epsilon`` such that ``W_perp + epsilon == W_o``.

    Notation (all spectra in Fourier space, element-wise operations, ``*``
    after a symbol meaning complex conjugate): ``O = S R + N``,
    ``Delta = |R|^2 + |N|^2 / |S|^2``, ``W_perp = R* / Delta`` and
    ``W_o = S O* / |O|^2``. Then

        epsilon = (S N* - 2 Re(R S N*) W_perp) / |O|^2

    Writing the cross-spectrum as ``S N* = a + i b``, epsilon is linear in
    ``a`` and ``b``:

        epsilon = a * re_coeff + i * b * im_coeff
        re_coeff = (1 - (2 / Delta) * R* * Re(R)) / |O|^2
        im_coeff = (1 - (2 / Delta) * i * R* * Im(R)) / |O|^2

    Both coefficient arrays depend only on the constructor arguments, so they
    are computed once here; each call then only needs the cross-spectrum.
    """

    def __init__(self, R, Smag, Nmag, Omag):
        """
        All inputs given in Fourier space. The ``mag`` parameters correspond to
        square magnitudes, i.e. complex spectrum times its conjugate.

        :param R: Kernel spectrum. Must have a complex dtype, since the
          coefficient arrays are created like ``R`` and get their imaginary
          part assigned.
        :param Smag: ``|S|^2``, squared magnitude of the signal spectrum.
        :param Nmag: ``|N|^2``, squared magnitude of the noise spectrum.
        :param Omag: ``|O|^2``, squared magnitude of the observed spectrum.
        """
        neg_two_over_delta = -2.0 / (abs(R) ** 2 + Nmag / Smag)
        re_times_im = R.real * R.imag
        # R* Re(R) = Re(R)^2 - i Re(R) Im(R)
        self.re_coeff = self._scaled_coeff(
            R, R.real ** 2, -re_times_im, neg_two_over_delta, Omag)
        # i R* Im(R) = Im(R)^2 + i Re(R) Im(R)
        self.im_coeff = self._scaled_coeff(
            R, R.imag ** 2, re_times_im, neg_two_over_delta, Omag)

    @staticmethod
    def _scaled_coeff(template, real_part, imag_part, ratio, Omag):
        """Return ``(1 + ratio * (real_part + i imag_part)) / Omag`` shaped like ``template``."""
        coeff = np.zeros_like(template)
        coeff.real = real_part
        coeff.imag = imag_part
        coeff *= ratio
        coeff += 1
        coeff /= Omag
        return coeff

    def __call__(self, SNconj):
        """
        :param SNconj: spectral cross-correlation ``S * N.conj()`` between
          signal and noise, both of same shape as the spectra given at
          construction.
        :returns: A new complex array epsilon, so that ``W_perp + epsilon``
          equals ``W_o``.
        """
        eps = SNconj.real * self.re_coeff
        imag_term = SNconj.imag * self.im_coeff
        # add i * imag_term, using i * (c + i d) = -d + i c
        eps.real -= imag_term.imag
        eps.imag += imag_term.real
        return eps
def plot_img(hwc):
    """
    Display an image after min-max normalizing it to ``[0, 1]``.

    :param hwc: Image array, e.g. ``(h, w)`` or ``(h, w, channels)`` (whatever
      ``plt.imshow`` accepts). Integer inputs are converted to float64 so the
      normalization works, and a constant image is shown as all zeros instead
      of producing NaNs from a 0/0 division. The input is not modified.
    """
    p = np.asarray(hwc, dtype=np.float64)
    p = p - p.min()  # new array: never mutates the caller's data
    peak = p.max()
    if peak > 0:
        p /= peak
    plt.clf()
    plt.imshow(p)
    plt.show()
def plot_residual(residual, cmap="pink"):
    """
    Show the Euclidean norm of ``residual`` along its last axis as a heat map.

    For an ``(h, w, channels)`` residual this displays an ``(h, w)`` map of
    the per-pixel error magnitude across channels.

    :param residual: Array whose last axis is reduced with ``np.linalg.norm``.
    :param cmap: Matplotlib colormap name. Interesting colormaps: pink,
      gist_heat_r, hot, copper_r; see
      https://matplotlib.org/stable/tutorials/colors/colormaps.html
    """
    magnitude = np.linalg.norm(residual, axis=-1)
    plt.clf()
    plt.imshow(magnitude, cmap=cmap)
    plt.show()
def save_arr(arr, outpath, minval=0, maxval=255):
    """
    Min-max rescale ``arr`` to ``[minval, maxval]`` and save it as 8-bit image.

    :param arr: Array that PIL can turn into an image once cast to uint8,
      e.g. ``(h, w)`` or ``(h, w, 3)``. It is not modified.
    :param outpath: Destination passed to ``Image.save``.
    :param minval: Value that the minimum of ``arr`` is mapped to.
    :param maxval: Value that the maximum of ``arr`` is mapped to.

    Integer inputs are converted to float64 so the in-place division works. A
    constant array maps to ``minval`` everywhere instead of producing NaNs.
    Values are rounded to the nearest integer and clipped to ``[0, 255]``
    before the uint8 cast, because ``astype(np.uint8)`` truncates and is not
    guaranteed to saturate out-of-range values.
    """
    # normalize between 0 and 1
    arr = np.asarray(arr, dtype=np.float64)
    arr = arr - arr.min()  # new array: never mutates the caller's data
    peak = arr.max()
    if peak > 0:
        arr /= peak
    # normalize between minval and maxval
    arr *= maxval - minval
    arr += minval
    pixels = np.clip(np.rint(arr), 0, 255).astype(np.uint8)
    im = Image.fromarray(pixels)
    im.save(outpath)
# ##############################################################################
# # MAIN ROUTINE
# ##############################################################################
# globals
SNR = 10  # var(convolved) / var(noise), linear scale (not dB)
DTYPE = np.float64
IMG_PATH = "christkindlmarkt_crop.jpg"
KERNEL_PATH = "conv_kernel.png"
# load image from disk and normalize to zero mean and unit standard deviation
img = Image.open(IMG_PATH)
img_arr = np.array(img).astype(DTYPE)
# original value range, used to map results back when saving to disk
img_min, img_max = img_arr.min(), img_arr.max()
img_arr -= img_arr.mean()  # plot_img(img_arr)
img_arr /= img_arr.std()
h, w, c = img_arr.shape
# create a conv kernel from the first channel of the kernel image (or from
# the image itself if it has a single channel)
k = Image.open(KERNEL_PATH)
k_raw = np.array(k)
k_arr = (k_raw if k_raw.ndim == 2 else k_raw[:, :, 0]).astype(DTYPE)
k_h, k_w = k_arr.shape
if k_h > h or k_w > w:
    raise ValueError(
        f"Kernel of shape {(k_h, k_w)} does not fit in image of shape {(h, w)}")
#
# smooth the kernel with a centred isotropic Gaussian (variance 5 per axis)
gaussian = rbf_kernel2d(k_arr.shape, np.array([[5.0, 0], [0, 5.0]]))
# k_gaussian = k_arr
# the centred Gaussian shifts the circular conv by half; roll it back
k_gaussian = roll_half_hw(fft_conv2d(k_arr, gaussian))
k_gaussian -= k_gaussian.min()  # shift so that the minimum is exactly 0
k_gaussian /= (k_gaussian**2).sum()**0.5  # normalize kernel to energy=1
# embed the kernel, centred, into an image-sized zero array
kernel_arr = np.zeros_like(img_arr[:, :, 0])  # plot_img(kernel_arr)
kernel_arr[(h//2)-(k_h//2): (h//2)+k_h-(k_h//2),
           (w//2)-(k_w//2): (w//2)+k_w-(k_w//2)] = k_gaussian
# Replicate the kernel across all c image channels so that its shape matches
# img_arr for any channel count (e.g. RGB or RGBA)
kernel_arr = kernel_arr[:, :, np.newaxis].repeat(c, axis=-1)
# distort image by convolving with kernel and adding noise with desired SNR.
# Since the kernel is centred, ``convolved`` (and ``observed``) come out
# circularly shifted by (h//2, w//2); see roll_half_hw.
convolved = fft_conv2d(img_arr, kernel_arr)
noise = white_noise(img_arr.shape, convolved.var() / SNR)
observed = convolved + noise  # plot_img(observed)
# auxiliary computations for deconvolution. With norm="backward" the forward
# transforms are unscaled: per channel, energy = sum(|X|^2) / (h * w)
S = np.fft.fft2(img_arr, axes=(0, 1), norm="backward")  # "unknown"
R = np.fft.fft2(kernel_arr, axes=(0, 1), norm="backward")  # assumed
N = np.fft.fft2(noise, axes=(0, 1), norm="backward")  # "unknown"
# observed signal is known
O = np.fft.fft2(observed, axes=(0, 1), norm="backward")
Omag = abs(O)**2  # (O * O.conj()).real
# perform "oracle" Wiener deconvolutions with optimal parameters. In real life
# we would have to estimate them from the data
Smag_opt = abs(S)**2  # normally must be assumed
Nmag_opt = abs(N)**2  # normally must be assumed
Rmag_opt = abs(R)**2  # normally must be assumed
# Wiener filtering: perfect recovery W_o and orthogonal assumption W_perp
W_o = S * O.conj() / Omag
W_perp = R.conj() / (Rmag_opt + (Nmag_opt / Smag_opt))
#
recons_o = np.fft.ifft2(W_o * O, axes=(0, 1), norm="backward").real
recons_perp = np.fft.ifft2(W_perp * O, axes=(0, 1), norm="backward").real
# Compute epsilon (W_o - W_perp) using the two alternative formulations
# (closed form, and the precomputed WienerEpsilon) and the residual recons
eps = (S * N.conj() - 2 * W_perp * (S * R * N.conj()).real) / Omag
eps2 = WienerEpsilon(R, Smag_opt, Nmag_opt, Omag)(S*N.conj())
recons_eps = np.fft.ifft2(eps * O, axes=(0, 1), norm="backward").real
recons_eps2 = np.fft.ifft2(eps2 * O, axes=(0, 1), norm="backward").real
# ##############################################################################
# # TESTS
# ##############################################################################
print("Image energy:", energy(img_arr))
print("Distorted energy:", energy(observed))
print("Residual energies:")
# observation artifacts introduce a lot of unwanted energy. The centred kernel
# leaves ``observed`` circularly shifted by (h//2, w//2) w.r.t. ``img_arr``,
# so roll it back first; otherwise the residual also measures misalignment
# instead of only blur + noise.
print(" image->distorted:", energy(img_arr - roll_half_hw(observed)))
# W_o provides a perfect recovery
print(" image - W_o:", energy(img_arr - recons_o))
# W_perp provides a decent recovery
print(" image - W_perp:", energy(img_arr - recons_perp))
# And eps provides correction from perp to perfect (both formulations)
print(" image - (W_perp + eps1) :",
      energy(img_arr - (recons_perp + recons_eps)))
print(" image - (W_perp + eps2) :",
      energy(img_arr - (recons_perp + recons_eps2)))
"""
Algebraic identities (e.g. to evaluate interactively at the breakpoint below):
1. Parseval (norm="backward", so divide by the number of pixels per channel):
(energy(img_arr), (abs(S)**2).sum() / S[:, :, 0].size)
(energy(kernel_arr), (abs(R)**2).sum() / R[:, :, 0].size)
(energy(noise), (abs(N)**2).sum() / N[:, :, 0].size)
2. Decomposition of power spectrum:
np.allclose( ((S*R) * (S*R).conj()).real, Smag_opt * Rmag_opt )
3. Conv theorem for SR
np.allclose(np.fft.fft2(convolved, axes=(0, 1), norm="backward"), S*R)
4. Conv theorem for O = SR+N
np.allclose(O, S*R + N)
5. magnitude(SR) = magnitude(S) * magnitude(R)
np.allclose(abs(S*R), abs(S) * abs(R))
6. |O|^2 = |SR|^2 + |N|^2 + gamma, with gamma = 2 Re(S R N*)
gamma = 2 * (S * R * N.conj()).real
np.allclose(abs(O)**2, abs(S*R)**2 + abs(N)**2 + gamma)
np.linalg.norm(abs(O)**2 - (abs(S*R)**2 + abs(N)**2 + gamma))
7. Check formula for f_r(SÑ):
vareps1 = (S * N.conj() - 2 * W_perp * (S * R * N.conj()).real) / Omag
we = WienerEpsilon(R, Smag_opt, Nmag_opt, Omag)
vareps2 = we(S*N.conj())
np.allclose(W_o, W_perp + vareps1)
np.allclose(W_o, W_perp + vareps2)
8. Finally check that (recons_perp + recons_eps) yields perfect reconstruction
np.allclose(recons_o, recons_perp + recons_eps)
np.allclose(recons_o, recons_perp + recons_eps2)
"""
# gallery:
# plot_img(img_arr)
# plot_img(kernel_arr)
# plot_img(roll_half_hw(observed))
# plot_img(recons_o)
# plot_img(recons_perp)
# plot_residual(recons_eps, "pink")
# save to disk:
# save_arr(roll_half_hw(observed), "observed.jpg", img_min, img_max)
# save_arr(recons_perp, "recons_perp.jpg", img_min, img_max)
breakpoint()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment