Skip to content

Instantly share code, notes, and snippets.

@raphaelvallat
Last active September 23, 2018 00:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save raphaelvallat/b01ffd67a34fd60ca78738555ec5e0e0 to your computer and use it in GitHub Desktop.
Save raphaelvallat/b01ffd67a34fd60ca78738555ec5e0e0 to your computer and use it in GitHub Desktop.
A simple and efficient wavelet-based spindles detector
import numpy as np
def spindles_detect(x, sf, perc_threshold=90, wlt_params=None):
    """Simple spindles detector based on a Morlet wavelet.

    The signal is convolved with a complex Morlet wavelet; every sample whose
    wavelet magnitude reaches the ``perc_threshold`` percentile of the
    magnitude is flagged, and each run of consecutive flagged samples is
    reported as one spindle.

    Parameters
    ----------
    x : 1D-array
        EEG signal.
    sf : float
        Sampling frequency.
    perc_threshold : float
        Percentile (0-100) of the wavelet magnitude used as detection
        threshold.
    wlt_params : dict or None
        Wavelet parameters. Missing keys (or ``None``) fall back to the
        defaults ``{'n_cycles': 7, 'central_freq': 'auto'}``. The dictionary
        passed by the caller is never modified. ::

        'n_cycles' : number of oscillations (lower = better time
                     resolution, higher = better frequency resolution)
        'central_freq' : central spindles frequency ('auto' uses the
                         frequency of the Welch PSD peak of x between
                         11 and 16 Hz)

    Returns
    -------
    spindles : 1D-array (boolean)
        Boolean array indicating for each point whether it belongs to a
        spindle.
    spindles_param : dict
        Spindles parameters dictionary with one value per detected spindle:
        'Duration (ms)', 'Frequency (Hz)' and 'Amplitude (uV)'. The amplitude
        is the peak-to-peak value of the detrended segment, expressed in the
        units of ``x``. The frequency is NaN for a spindle made of a single
        sample.

    Raises
    ------
    ValueError
        If ``x`` is empty or is not one-dimensional.
    """
    x = np.asarray(x)
    if x.ndim != 1 or x.size == 0:
        raise ValueError('x must be a non-empty one-dimensional array.')

    from scipy.signal import detrend
    from mne.time_frequency import morlet, psd_array_welch

    # Work on a private copy merged over the defaults: writing the resolved
    # frequency back into the argument (as a shared default dict would allow)
    # makes 'auto' stick to the first signal ever analysed.
    params = {'n_cycles': 7, 'central_freq': 'auto'}
    if wlt_params is not None:
        params.update(wlt_params)

    central_freq = params['central_freq']
    if isinstance(central_freq, str) and central_freq == 'auto':
        psd, freqs = psd_array_welch(x, sf, fmin=11, fmax=16,
                                     n_fft=int(2 * sf), verbose=0)
        central_freq = freqs[np.argmax(psd)]

    # Compute the wavelet and convolve with data
    wlt = morlet(sf, [central_freq], n_cycles=params['n_cycles'])[0]
    analytic = np.convolve(x, wlt, mode='same')
    magnitude = np.abs(analytic)
    phase = np.angle(analytic)

    # Find supra-threshold values and indices
    supra_thresh_bool = magnitude >= np.percentile(magnitude, q=perc_threshold)
    supra_thresh_idx = np.where(supra_thresh_bool)[0]

    # Split the indices wherever they stop being consecutive: one run of
    # indices per spindle.
    breaks = np.where(np.diff(supra_thresh_idx) != 1)[0] + 1
    runs = np.split(supra_thresh_idx, breaks)

    # Duration is measured from the first to the last sample of each run.
    starts = np.array([run[0] for run in runs])
    ends = np.array([run[-1] for run in runs])
    sp_dur = (ends - starts) / sf * 1000

    sp_amp = np.zeros(len(runs))
    sp_freq = np.zeros(len(runs))
    for i, run in enumerate(runs):
        sp_amp[i] = np.ptp(detrend(x[run]))
        if run.size < 2:
            # No phase increment can be computed from a single sample.
            sp_freq[i] = np.nan
            continue
        # Unwrap first: np.angle wraps at +/- pi, and each wrap would add a
        # spurious jump of about -2*pi that drags the median frequency down.
        phase_steps = np.diff(np.unwrap(phase[run]))
        sp_freq[i] = np.median(sf / (2 * np.pi) * phase_steps)

    sp_params = {'Duration (ms)': sp_dur,
                 'Frequency (Hz)': sp_freq,
                 'Amplitude (uV)': sp_amp}
    return supra_thresh_bool, sp_params
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment