Skip to content

Instantly share code, notes, and snippets.

@nigma
Created April 21, 2011 11:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nigma/934166 to your computer and use it in GitHub Desktop.
Save nigma/934166 to your computer and use it in GitHub Desktop.
Multi-level n-dimensional wavelet transform with PyWavelets.
#!/usr/bin/env python
#-*- coding: utf-8 -*-
__author__ = 'Filip Wasilewski <en@ig.ma>'
from pywt import Wavelet, dwtn, dwt_max_level
from pywt.numerix import as_float_array
def wavedecn(data, wavelet, mode='sym', level=None):
    """Compute a multi-level n-dimensional discrete wavelet decomposition.

    ``dwtn`` is applied repeatedly; each level after the first operates on
    the approximation coefficients (the ``'a' * ndim`` entry) produced by
    the previous level.

    Parameters
    ----------
    data : array_like
        Input data; converted with ``as_float_array``.
    wavelet : Wavelet or value accepted by the ``Wavelet`` constructor
        Wavelet to use. Anything that is not already a ``Wavelet``
        instance is passed to ``Wavelet(...)``.
    mode : str, optional
        Signal extension mode forwarded to ``dwtn`` (default ``'sym'``).
    level : int, optional
        Number of decomposition levels. When ``None`` it is computed with
        ``dwt_max_level`` from the smallest axis length of ``data`` and
        the wavelet's ``dec_len``.

    Returns
    -------
    list of dict
        One ``dwtn`` result per level, ordered from the last computed
        (coarsest) level to the first computed (finest) one. The list is
        empty when ``level`` is 0.

    Raises
    ------
    ValueError
        If ``level`` is negative.
    """
    data = as_float_array(data)
    dim = len(data.shape)

    if not isinstance(wavelet, Wavelet):
        wavelet = Wavelet(wavelet)

    if level is None:
        # The shortest axis limits how many times the data can be halved.
        size = min(data.shape)
        level = dwt_max_level(size, wavelet.dec_len)
    elif level < 0:
        raise ValueError(
            "Level value of %d is too low. Minimum level is 0." % level)

    approx_key = 'a' * dim  # key of the approximation coefficients array
    coeffs_list = []
    a = data
    # `range` instead of `xrange` keeps this working on Python 2 and 3.
    for _ in range(level):
        coeffs = dwtn(a, wavelet, mode)
        coeffs_list.append(coeffs)
        a = coeffs[approx_key]

    # Levels were collected finest-first; callers get coarsest-first.
    coeffs_list.reverse()
    return coeffs_list
def test():
    """Decompose random 8x8x8 data with 'db1' and pretty-print the result.

    Every coefficient array of every level is rounded to three decimals
    before printing.
    """
    import numpy
    import pprint

    sample = numpy.random.randn(8, 8, 8)
    # Small deterministic alternative input:
    # sample = [[1,2,3,4], [4,5,6,7], [6,7,8,9], [8,9,10,11]]
    rounded_levels = [
        {name: array.round(3) for name, array in level_coeffs.items()}
        for level_coeffs in wavedecn(sample, 'db1')
    ]
    pprint.pprint(rounded_levels)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment