Skip to content

Instantly share code, notes, and snippets.

@EtienneCmb
Last active August 24, 2022 06:59
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save EtienneCmb/f76d8b4aba0088aba6a8c07e397a33c2 to your computer and use it in GitHub Desktop.
Save EtienneCmb/f76d8b4aba0088aba6a8c07e397a33c2 to your computer and use it in GitHub Desktop.
Python implementation of the superlet transform (adaptive superresolution wavelets)
import numpy as np

try:
    # Optional GPU backend: CuPy arrays + cuSignal's FFT-based convolution.
    import cupy as cp
    from cusignal.convolution.convolve import fftconvolve
except Exception:
    # Best-effort fallback to the CPU stack whenever the GPU backend cannot be
    # imported. ``except Exception`` (instead of a bare ``except:``) keeps the
    # fallback but lets KeyboardInterrupt / SystemExit propagate.
    cp = np
    from scipy.signal import fftconvolve
    # NumPy has no ``asnumpy``; alias it so ``cp.asnumpy(...)`` works with
    # either backend. NOTE: this patches the global ``numpy`` module.
    np.asnumpy = np.asarray
def bw_cf(t, bw, cf):
    """Evaluate a complex Morlet wavelet at time point(s) ``t``.

    The result is a Gaussian envelope with standard deviation ``bw``, scaled
    by ``1 / (bw * sqrt(2 * pi))``, that modulates the complex carrier
    ``exp(2j * pi * cf * t)``.

    Parameters
    ----------
    t : float or ndarray
        Time point(s) at which the wavelet is evaluated.
    bw : float
        Standard deviation ("bandwidth") of the Gaussian envelope.
    cf : float
        Centre frequency of the complex carrier.

    Returns
    -------
    complex or ndarray of complex
        Wavelet coefficient(s), broadcast like ``t``.
    """
    norm = 1 / (bw * np.sqrt(2 * np.pi))
    envelope = norm * np.exp(-(t ** 2) / (2 * (bw ** 2)))
    carrier = np.exp(2j * np.pi * cf * t)
    return carrier * envelope
def gauss(t, sd):
    """Evaluate the zero-mean normal probability density at ``t``.

    Parameters
    ----------
    t : float or ndarray
        Time point(s) at which the density is evaluated.
    sd : float
        Standard deviation of the Gaussian.

    Returns
    -------
    float or ndarray
        ``exp(-t**2 / (2 * sd**2)) / (sd * sqrt(2 * pi))``, broadcast like ``t``.
    """
    exponent = -(t ** 2) / (2 * (sd ** 2))
    peak = 1 / (sd * np.sqrt(2 * np.pi))
    return peak * np.exp(exponent)
def cxmorlet(fc, n_cycles, sfreq):
    """Computes the sampled complex Morlet wavelet for a center frequency.

    Parameters
    ----------
    fc : float
        Center frequency (same unit as ``sfreq``, e.g. Hz).
    n_cycles : float
        Number of cycles of the wavelet.
    sfreq : float
        Sampling frequency.

    Returns
    -------
    w : ndarray of complex128, shape (wl,)
        Wavelet samples at times ``(i - wl // 2) / sfreq`` for
        ``i = 0 .. wl - 1``. ``wl`` is odd and covers about 6 standard
        deviations of the Gaussian envelope. The samples are divided by the
        sum of the Gaussian envelope evaluated at the same time points.
    """
    # we want to have the last peak at 2.5 SD
    sd = (n_cycles / 2) * (1 / fc) / 2.5
    # odd number of samples spanning ~6 SD, so the wavelet has a centre sample
    wl = int(2 * np.floor(np.fix(6 * sd * sfreq) / 2) + 1)
    off = wl // 2
    # evaluate every sample at once; bw_cf and gauss are elementwise, so this
    # replaces the former per-sample Python loop
    t = (np.arange(wl) - off) / sfreq
    w = np.asarray(bw_cf(t, sd, fc), dtype=np.complex128)
    return w / gauss(t, sd).sum()
def aslt(data, sfreq, foi, n_cycles, order=None, mult=False):
    """Adaptive superresolution wavelet (superlet) transform.

    Parameters
    ----------
    data : array_like, shape (n_epochs, n_times) or (n_times,)
        Input signal(s); a 1-D input is treated as a single epoch. The data
        are converted to float32.
    sfreq : float
        Sampling frequency.
    foi : array_like
        Central frequencies of interest (a scalar is treated as one frequency).
    n_cycles : int
        Number of cycles of the first wavelet in each set.
    order : array_like of shape (2,) | None
        Interval ``[first, last]`` of super-resolution orders. It is spread
        linearly over ``foi`` and truncated towards zero (e.g.
        ``order=[1, 30]``). ``None`` uses order 1 (a single wavelet) at every
        frequency. Every resulting order must be >= 1.
    mult : bool
        If True, use multiplicative super-resolution: the k-th wavelet
        (counting from 0) has ``n_cycles * (k + 1)`` cycles. Otherwise use
        additive super-resolution, with ``n_cycles + k`` cycles.

    Returns
    -------
    wtresult : ndarray of float32, shape (n_epochs, n_freqs, n_times)
        For each frequency, the squared geometric mean over the wavelet set
        of ``2 * |convolution of data with the wavelet|``.

    Raises
    ------
    ValueError
        If ``data`` has more than two dimensions, or if ``order`` yields an
        order below 1 for some frequency.
    """
    # inputs checking: explicit exceptions (not ``assert``) so the checks
    # survive ``python -O``; any array_like is accepted, as documented
    data = np.atleast_2d(np.asarray(data)).astype(np.float32)
    if data.ndim != 2:
        raise ValueError("data must be 1-D or 2-D (n_epochs, n_times), "
                         f"got shape {data.shape}")
    n_epochs, n_times = data.shape
    foi = np.atleast_1d(np.asarray(foi))
    n_freqs = len(foi)

    # super-resolution order used at each frequency. If not given, use an
    # order of 1 (single wavelet per set). The builtin ``int`` replaces
    # ``np.int``, which was removed in NumPy 1.24.
    if order is None:
        order_ls = np.ones((n_freqs,), dtype=int)
    else:
        order_ls = np.fix(np.linspace(order[0], order[1], n_freqs)).astype(int)
    if n_freqs and order_ls.min() < 1:
        # an empty wavelet set would make the geometric-mean root 1 / 0
        raise ValueError(f"super-resolution orders must be >= 1, got {order}")

    # wavelet sets, indexed as wavelets[i_freq][i_ord]
    wavelets = []
    for fc, n_ord in zip(foi, order_ls):
        wavelet_set = []
        for i_ord in range(n_ord):
            if mult:  # multiplicative superresolution
                n_cyc = n_cycles * (i_ord + 1)
            else:  # additive superresolution
                n_cyc = n_cycles + i_ord
            wavelet_set.append(cxmorlet(fc, n_cyc, sfreq))
        wavelets.append(wavelet_set)

    # size of the lateral zero-pads (half-size of the largest wavelet), which
    # serve to avoid border effects during convolution
    padding = max((len(w) // 2 for w_set in wavelets for w in w_set),
                  default=0)

    # the zero-padded buffer, with the input data in its central part
    buffer = cp.zeros((n_epochs, n_times + 2 * padding), dtype=cp.float32)
    bufbegin, bufend = padding, padding + n_times
    buffer[:, bufbegin:bufend] = cp.asarray(data)

    # the output scalogram
    wtresult = cp.zeros((n_epochs, n_freqs, n_times), dtype=cp.float32)
    for i_freq, wavelet_set in enumerate(wavelets):
        # The geometric mean is accumulated as a product of per-wavelet n-th
        # roots rather than a raw product rooted at the end. Multiplying up to
        # `order` float32 magnitudes first under/overflows (e.g. |x| ~ 1e-5
        # reaches 0 after ~8 factors).
        root = 1. / len(wavelet_set)
        temp = cp.ones((n_epochs, n_times), dtype=cp.float32)
        for w in wavelet_set:
            sw = cp.asarray(w).reshape(1, -1)
            # restricted convolution (input size == output size)
            conv = fftconvolve(buffer, sw, mode='same', axes=1)
            # magnitude times 2 to get the full spectral energy
            temp *= (2 * cp.abs(conv[:, bufbegin:bufend])) ** root
        # power of the geometric mean
        wtresult[:, i_freq, :] = temp ** 2
    return cp.asnumpy(wtresult)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment