Skip to content

Instantly share code, notes, and snippets.

@ThomasA
Created July 3, 2015 11:09
Show Gist options
  • Save ThomasA/f0f45ce5012f5c0c8bc3 to your computer and use it in GitHub Desktop.
Save ThomasA/f0f45ce5012f5c0c8bc3 to your computer and use it in GitHub Desktop.
This script attempts to prove that the ISWT2 implementation from https://groups.google.com/d/msg/pywavelets/xk3UeFz9ZK0/mylNG7qUyKsJ does not work properly.
"""
This script attempts to prove that the ISWT2 implementation from
https://groups.google.com/d/msg/pywavelets/xk3UeFz9ZK0/mylNG7qUyKsJ
does not work properly.
"""
import numpy as np
import pywt
import pywt_addon
# `iswt2` appears to work from a round-trip test perspective
A = np.random.rand(32,32)
coeffs = pywt.swt2(A, 'db10', 2)
B = pywt_addon.iswt2(coeffs, 'db10')
assert np.allclose(A, B), 'The original matrix and its round-trip transform are not identical'
# However, it only reconstructs the matrix based on the level-1
# coefficients. Proof:
coeffs_new1 = (coeffs[0], (np.random.rand(32,32),
(np.random.rand(32,32),
np.random.rand(32,32),
np.random.rand(32,32))))
C = pywt_addon.iswt2(coeffs_new1, 'db10')
assert np.allclose(A, C), 'The original matrix and its round-trip transform are not identical'
# See? The above does not fail although I have turned the level-2
# wavelet coefficients into complete gibberish. It does fail on the
# first level, though:
coeffs_new0 = ((np.random.rand(32,32),
(np.random.rand(32,32),
np.random.rand(32,32),
np.random.rand(32,32))), coeffs[1])
D = pywt_addon.iswt2(coeffs_new0, 'db10')
assert np.allclose(A, D), 'The original matrix and its round-trip transform are not identical'
# This means that this `iswt2` implementation only works reliably for
# 1-layer transforms. For multi-layer transforms, it will not work
# correctly in applications that modify the layer > 1 transform
# coefficients.
"""
This module is supposed to supplement Pywavelets (pywt) with inverse
stationary wavelet transforms. The code originates from:
https://groups.google.com/d/msg/pywavelets/xk3UeFz9ZK0/mylNG7qUyKsJ. The
current version has a couple of minor modifications:
* Imports numpy in the common way (`as np`).
* Construction of `Range` in the `iswt2` function has been altered
for Python 3 compatibility.
"""
import pywt
import numpy as np
def iswt(coefficients, wavelet):
"""
Input parameters:
coefficients
approx and detail coefficients, arranged in level value
exactly as output from swt:
e.g. [(cA1, cD1), (cA2, cD2), ..., (cAn, cDn)]
wavelet
Either the name of a wavelet or a Wavelet object
"""
output = coefficients[0][0].copy() # Avoid modification of input data
#num_levels, equivalent to the decomposition level, n
num_levels = len(coefficients)
for j in range(num_levels,0,-1):
step_size = int(pow(2, j-1))
last_index = step_size
_, cD = coefficients[num_levels - j]
for first in range(last_index): # 0 to last_index - 1
# Getting the indices that we will transform
indices = np.arange(first, len(cD), step_size)
#print first, indices
# select the even indices
even_indices = indices[0::2]
# select the odd indices
odd_indices = indices[1::2]
# perform the inverse dwt on the selected indices,
# making sure to use periodic boundary conditions
x1 = pywt.idwt(output[even_indices], cD[even_indices], wavelet, 'per')
x2 = pywt.idwt(output[odd_indices], cD[odd_indices], wavelet, 'per')
# perform a circular shift right
# original:
#x2 = roll(x2, 1)
# average and insert into the correct indices
#output[indices] = (x1 + x2)/2.
#modified to allow exact reconstruction of original data, if swt2 is used
# with start_level = 0, and wavelet is haar or db1
output[even_indices] = x1[0::2]
output[odd_indices] = x2[0::2]
return output
def iswt2(coefficients, wav):
"""
Input parameters:
coefficients
Approx and detail coefficients, arranged in level value
exactly as output from swt2:
e.g. [(cA_n, (cH_n, cV_n, cD_n)), (cA_n+1, (cH_n+1, cV_n+1, cD_n+1)), ...]
Note: for accurate reconstruction of original data, swt2 must be used with
start_level = 0, unless wavelet is haar or db1 (see modification of iswt).
wavelet
The name of a wavelet
"""
Level = len(coefficients)
Range = np.arange(Level)[::-1]
Shape = coefficients[0][1][0].shape
Out = np.zeros(Shape,'d')
for iRange in Range:
C1 = coefficients[iRange]
approx = C1[0]
LL = np.transpose(approx)
LH = np.transpose(C1[1][0])
HL = np.transpose(C1[1][1])
HH = np.transpose(C1[1][2])
H = np.zeros(Shape,'d')
L = np.zeros(Shape,'d')
for i in range(H.shape[0]):
coef = [(HL[i], HH[i])]
out = iswt(coef, wav)
H[i] = out
H = H.T
for i in range(L.shape[0]):
coef = [(LL[i], LH[i])]
out = iswt(coef, wav)
L[i] = out
L = L.T
for i in range(Out.shape[0]):
coef = [(L[i], H[i])]
out = iswt(coef, wav)
Out[i] = out
return Out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment