Skip to content

Instantly share code, notes, and snippets.

@jdkoen
Last active August 7, 2022 18:50
Show Gist options
  • Save jdkoen/f27a0dc9b6c27bd2ea47913e2d15f62e to your computer and use it in GitHub Desktop.
Save jdkoen/f27a0dc9b6c27bd2ea47913e2d15f62e to your computer and use it in GitHub Desktop.
# Functions to estimate a variety of ERP amplitude and latency features; they operate on mne-python objects
import numpy as np
import pandas as pd
from scipy.integrate import trapezoid
from scipy.signal import find_peaks
from mne import Evoked, pick_channels
# Define EEG/ERP Functions
def _handle_picks(ch_names, picks):
    """Resolve ``picks`` into channel indices for ``ch_names``.

    Parameters
    ----------
    ch_names : list of str
        Channel names of the instance being processed.
    picks : list of str | list of int | None
        Channels to select. ``None`` selects every channel. A list of
        strings is converted to indices with ``pick_channels``. A list of
        integers (Python or NumPy) is returned unchanged.

    Returns
    -------
    picks : array-like of int
        Indices into ``ch_names``.

    Raises
    ------
    ValueError
        If ``picks`` is neither all strings nor all integers.
    """
    if picks is None:
        return np.arange(len(ch_names))
    if all(isinstance(pick, str) for pick in picks):
        return pick_channels(ch_names, picks)
    # np.integer is accepted so that index arrays (e.g. from np.arange)
    # are valid picks as well as plain Python ints.
    if all(isinstance(pick, (int, np.integer)) for pick in picks):
        return picks
    raise ValueError('picks must be a list of strings or list of integers')
def _find_nearest(a, a0, axis=-1, return_index=False):
    """Locate the entry of ``a`` closest to ``a0`` along ``axis``.

    Parameters
    ----------
    a : ndarray
        Values to search.
    a0 : float | ndarray
        Target value(s); must broadcast against ``a``.
    axis : int (defaults -1)
        Axis along which the minimum absolute difference is located.
    return_index : bool (defaults False)
        If True, return the index (or indices) along ``axis`` instead of
        the value.

    Returns
    -------
    nearest : int | ndarray | float
        The index/indices when ``return_index`` is True, otherwise the
        element of ``a`` at that position. The value is looked up through
        ``a.flat``, i.e. the index is applied to the flattened array, so
        the value form matches the index only for 1-D ``a``.
    """
    nearest = np.argmin(np.abs(a - a0), axis=axis)
    return nearest if return_index else a.flat[nearest]
def _get_tmin_tmax(times, tmin=None, tmax=None):
    """Resolve window limits to values that exist in ``times``.

    Parameters
    ----------
    times : ndarray
        Time points to search.
    tmin : float | None (defaults None)
        Requested start. ``None`` uses ``times.min()``; otherwise the
        nearest entry of ``times`` is used.
    tmax : float | None (defaults None)
        Requested end. ``None`` uses ``times.max()``; otherwise the
        nearest entry of ``times`` is used.

    Returns
    -------
    tmin, tmax : float
        The resolved limits.
    """
    if tmin is None:
        lower = times.min()
    else:
        lower = _find_nearest(times, tmin, axis=0)

    if tmax is None:
        upper = times.max()
    else:
        upper = _find_nearest(times, tmax, axis=0)

    return lower, upper
def _get_time_win(times, tmin, tmax, return_sample=False):
    """Build a boolean mask selecting the window ``[tmin, tmax)``.

    The limits are first resolved with ``_get_tmin_tmax``. The window is
    half-open: the sample at the resolved ``tmax`` is excluded.

    Parameters
    ----------
    times : ndarray
        Time points of the data.
    tmin : float | None
        Requested window start (``None`` means the earliest time).
    tmax : float | None
        Requested window end (``None`` means the latest time).
    return_sample : bool (defaults False)
        If True, also return the first and last sample index inside the
        window.

    Returns
    -------
    time_mask : ndarray of bool
        True for samples inside the window.
    smin, smax : int
        First and last selected sample index. Only returned when
        ``return_sample`` is True.
    """
    start, stop = _get_tmin_tmax(times, tmin=tmin, tmax=tmax)
    in_window = (times >= start) & (times < stop)
    if not return_sample:
        return in_window
    selected = np.nonzero(in_window)[0]
    return in_window, selected[0], selected[-1]
def frac_area_latency(inst, mode='abs', frac=None, tmin=None, tmax=None):
    """Compute the rectified area and fractional area latency per channel.

    Parameters
    ----------
    inst : object with ``data``, ``times`` and ``info`` attributes
        Presumably an :class:~mne.Evoked instance; ``info`` must provide
        ``'ch_names'`` and ``'sfreq'``.
    mode : str (defaults 'abs')
        ``'pos'`` zeroes negative samples and ``'neg'`` zeroes positive
        samples before rectification. Any other value rectifies the
        signal as is.
    frac : float | None (defaults None)
        Fraction of the total area whose latency is requested. If None or
        1.0, only the area is returned.
    tmin : float | None (defaults None)
        Start of the measurement window (None means the first sample).
    tmax : float | None (defaults None)
        End of the measurement window (None means the last time point).
        The window is built by ``_get_time_win``.

    Returns
    -------
    ch_names : list of str
        ``inst.info['ch_names']``.
    area : ndarray, shape (n_channels,)
        Trapezoidal area of the rectified data inside the window, computed
        on data scaled by 1e6 (presumably V to microvolts, confirm the
        unit of ``inst.data``).
    frac_lat_times : ndarray, shape (n_channels,)
        Time at which the running area is closest to ``frac * area``.
        Only returned when ``frac`` is neither None nor 1.0.
    """
    ch_names = inst.info['ch_names']
    # squeeze() would collapse single-channel data to 1-D and break the
    # channel-axis indexing below, so the channel axis is restored.
    data = np.atleast_2d(np.squeeze(inst.data)) * 1e6
    times = inst.times
    speriod = 1 / inst.info['sfreq']
    time_win, smin, smax = _get_time_win(times, tmin=tmin, tmax=tmax,
                                         return_sample=True)

    # Keep only the requested polarity, then always rectify.
    if mode == 'pos':
        data[data < 0] = 0
    elif mode == 'neg':
        data[data > 0] = 0
    data = np.abs(data)

    # Total area inside the window.
    windowed = data[:, time_win]
    area = trapezoid(windowed, dx=speriod, axis=1)
    if frac is None or frac == 1.0:
        return ch_names, area

    # Running (cumulative) trapezoidal area at every window sample. It
    # starts at 0 on the first sample and reaches the total area on the
    # last, so every sample is a valid candidate in the nearest search.
    frac_area = area * frac
    segments = 0.5 * (windowed[:, 1:] + windowed[:, :-1]) * speriod
    running_area = np.concatenate(
        [np.zeros((windowed.shape[0], 1)), np.cumsum(segments, axis=1)],
        axis=1)

    # Map the in-window index of the nearest running area back to a time.
    frac_lat_samples = _find_nearest(running_area, frac_area[:, None],
                                     axis=1, return_index=True)
    search_samples = np.arange(smin, smax + 1)
    frac_lat_times = times[search_samples[frac_lat_samples]]

    return ch_names, area, frac_lat_times
def peak_amp_lat(inst, mode='pos', tmin=None, tmax=None, picks=None,
                 return_microvolts=True, width=0):
    """Measure peak amplitude and latency.

    This can be run on ERPs for peak latency and amplitude.
    Uses :func:~scipy.signal.find_peaks to locate peaks.

    Parameters
    ----------
    inst : :class:~mne.Evoked object
        A single instance of an :class:~mne.Evoked object.
    mode : {'pos', 'neg', 'abs'} (defaults 'pos')
        Controls whether positive ('pos'), negative ('neg') or absolute
        ('abs') peaks are detected. 'pos' searches for a positive going peak
        (but the peak can take on a negative voltage). 'neg' searches for
        negative going peaks by scaling the voltages by -1 (but the peaks
        can take on a positive voltage). 'abs' finds the largest peak
        regardless of sign.
    tmin : float | None (defaults None)
        The minimum point in time to be considered for peak getting. If None
        (default), the beginning of the data is used.
    tmax : float | None (defaults None)
        The maximum point in time to be considered for peak getting. If None
        (default), the end of the data is used.
    picks : str|list|int|None (defaults None)
        Channels to include. integers and lists of integers will be interpreted
        as channel indices. str and lists of strings will be interpreted as
        channel names.
    return_microvolts : bool (defaults True)
        If True, returns the peak amplitude in μV.
    width : int|ndarray|list (default 0)
        Required width of peaks in samples. An integer is treated as the
        minimal required width (with no maximum). A ndarray or list of
        integers specifies the minimal and maximal widths, respectively.

    Returns
    -------
    data : instance of :class:~pandas.DataFrame
        A :class:~pandas.DataFrame with one row per channel in picks and
        the columns ``ch_name``, ``tmin``, ``tmax``, ``peak_amplitude``
        and ``peak_latency``. The amplitude is the signed value of the
        data at the peak. Amplitude and latency are None when no peak is
        found in the window.

    Raises
    ------
    TypeError
        If ``inst`` is not an :class:~mne.Evoked instance.
    ValueError
        If ``mode`` is not one of 'pos', 'neg' or 'abs'.
    """
    # Check inst input
    if not isinstance(inst, Evoked):
        raise TypeError('inst must be of Evoked type')

    # Check mode
    if mode not in ('neg', 'pos', 'abs'):
        raise ValueError("mode must be 'pos', 'neg', or 'abs'")

    # Handle picks (a single name or index becomes a one-element list)
    if isinstance(picks, (int, np.integer, str)):
        picks = [picks]
    picks = _handle_picks(inst.ch_names, picks)

    # Extract data (use copy to avoid write-in place)
    data = inst.copy().data
    if return_microvolts:
        data *= 1e6

    # Extract times and resolve the window. The limits need not coincide
    # exactly with a sample time; the mask below is inclusive on both ends.
    times = inst.times
    if tmin is None:
        tmin = times.min()
    if tmax is None:
        tmax = times.max()
    time_mask = np.logical_and(times >= tmin, times <= tmax)
    time_window = times[time_mask]

    # Loop through channels
    rows = []
    for pick in picks:
        # Boolean-mask indexing returns a copy, so the transformed search
        # signal never alters the signed values in raw_window.
        raw_window = data[pick, time_mask]
        if mode == 'neg':
            search_window = -raw_window
        elif mode == 'abs':
            search_window = np.abs(raw_window)
        else:
            search_window = raw_window

        # Find the peak indices; a ValueError from find_peaks is treated
        # as "no peak" (best effort, as before).
        try:
            peaks, _ = find_peaks(search_window, width=width)
        except ValueError:
            peaks = np.array([], dtype=int)

        # Extract peak information; find_peaks returns an empty array
        # when the window holds no local maximum.
        if len(peaks) > 0:
            peak_index = peaks[np.argmax(search_window[peaks])]
            peak_latency = time_window[peak_index]
            peak_amplitude = raw_window[peak_index]
        else:
            peak_amplitude = None
            peak_latency = None

        rows.append([inst.ch_names[pick], tmin, tmax,
                     peak_amplitude, peak_latency])

    # Build the frame once; ``.at`` is a scalar accessor and is not meant
    # for assigning whole rows. Object dtype keeps None for missing peaks.
    return pd.DataFrame(rows,
                        columns=['ch_name', 'tmin', 'tmax',
                                 'peak_amplitude', 'peak_latency'],
                        dtype=object)
def mean_amplitude(inst, tmin=None, tmax=None, picks=None,
                   return_microvolts=True):
    """Measure the mean amplitude in a time window for each channel.

    Parameters
    ----------
    inst : :class:~mne.Evoked object
        A single instance of an :class:~mne.Evoked object.
    tmin : float | None (defaults None)
        Start of the window (inclusive). If None, the beginning of the
        data is used.
    tmax : float | None (defaults None)
        End of the window (inclusive). If None, the end of the data is
        used.
    picks : str|list|int|None (defaults None)
        Channels to include, as channel names or channel indices. None
        selects every channel.
    return_microvolts : bool (defaults True)
        If True, the mean amplitude is multiplied by 1e6.

    Returns
    -------
    data : instance of :class:~pandas.DataFrame
        One row per channel in picks with the columns ``ch_name``,
        ``tmin``, ``tmax`` and ``mean_amplitude``.

    Raises
    ------
    TypeError
        If ``inst`` is not an :class:~mne.Evoked instance.
    """
    # Check inst input
    if not isinstance(inst, Evoked):
        raise TypeError('inst must be of Evoked type')

    # Handle picks (a single name or index becomes a one-element list)
    ch_names = inst.ch_names
    if isinstance(picks, (int, np.integer, str)):
        picks = [picks]
    picks = _handle_picks(ch_names, picks)

    # Extract times and resolve the window. The limits need not coincide
    # exactly with a sample time; the mask is inclusive on both ends.
    times = inst.times
    if tmin is None:
        tmin = times.min()
    if tmax is None:
        tmax = times.max()
    time_mask = np.logical_and(times >= tmin, times <= tmax)

    # Boolean-mask indexing returns a copy, so inst.data is only read and
    # no copy of the whole instance is needed.
    data = inst.data

    # Loop through channels
    rows = []
    for pick in picks:
        mean_amp = data[pick, time_mask].mean(axis=-1)
        if return_microvolts:
            mean_amp = mean_amp * 1e6
        rows.append([ch_names[pick], tmin, tmax, mean_amp])

    # Build the frame once; ``.at`` is a scalar accessor and is not meant
    # for assigning whole rows.
    return pd.DataFrame(rows,
                        columns=['ch_name', 'tmin', 'tmax',
                                 'mean_amplitude'],
                        dtype=object)
# A code snippet of fractional peak onset/offset.
# It expects two names to exist already: `evoked` (the Evoked instance) and
# `peaks` (the DataFrame returned by peak_amp_lat for that instance).
# NOTE(review): peak_amp_lat scales amplitudes by 1e6 when
# return_microvolts=True, while evoked.get_data() is used unscaled here.
# Confirm `peaks` was computed with return_microvolts=False, otherwise the
# threshold and the data are on different scales.
frac_peak = .5
all_data = evoked.get_data()
for ch_i, ch_row in peaks.iterrows():
    # Channels without a detected peak have no onset to search for.
    if pd.isna(ch_row['peak_latency']):
        continue
    # Rows of `peaks` are numbered by position in picks, not by channel
    # index, so the channel is looked up by name.
    this_data = all_data[evoked.ch_names.index(ch_row['ch_name']), :]
    onset_lat = np.where(evoked.times == ch_row['peak_latency'])[0][0]
    thresh = np.abs(ch_row['peak_amplitude'] * frac_peak)
    # Walk backwards from the peak to the first sample below threshold,
    # stopping at the first sample so the index cannot wrap around.
    while onset_lat > 0 and np.abs(this_data[onset_lat]) >= thresh:
        onset_lat -= 1
    peaks.at[ch_i, '50%_onset_latency'] = evoked.times[onset_lat]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment