Last active
August 7, 2022 18:50
-
-
Save jdkoen/f27a0dc9b6c27bd2ea47913e2d15f62e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# These are functions to estimate a variety of ERP amplitude and latency features that interact with mne-python
import numpy as np
import pandas as pd
from mne import Evoked, pick_channels
from scipy.integrate import cumulative_trapezoid, trapezoid
from scipy.signal import find_peaks
# Define EEG/ERP Functions | |
def _handle_picks(ch_names, picks): | |
if picks is None: | |
picks = np.arange(len(ch_names)) | |
else: | |
if all(isinstance(pick, str) for pick in picks): | |
picks = pick_channels(ch_names, picks) | |
elif all(isinstance(pick, int) for pick in picks): | |
pass | |
else: | |
ValueError('picks must be a list of strings or list of integers') | |
return picks | |
def _find_nearest(a, a0, axis=-1, return_index=False): | |
idx = np.abs(a-a0).argmin(axis=axis) | |
if return_index: | |
return idx | |
else: | |
return a.flat[idx] | |
def _get_tmin_tmax(times, tmin=None, tmax=None): | |
tmin = times.min() if tmin is None else _find_nearest(times, tmin, axis=0) | |
tmax = times.max() if tmax is None else _find_nearest(times, tmax, axis=0) | |
return tmin, tmax | |
def _get_time_win(times, tmin, tmax, return_sample=False):
    """Build a boolean mask selecting the samples from tmin to tmax.

    Both bounds are first snapped to the nearest entries of ``times`` (None
    means the first/last sample). The window is inclusive at both ends, like
    the ``times <= tmax`` windows used elsewhere in this module.

    Parameters
    ----------
    times : ndarray
        Sample times.
    tmin, tmax : float | None
        Requested window bounds.
    return_sample : bool
        If True, also return the first and last sample indices in the window.

    Returns
    -------
    time_mask : ndarray of bool
        True for samples inside the window.
    smin, smax : int
        Only when ``return_sample`` is True: first and last selected indices.
        An empty window (tmax before tmin) raises IndexError in this case.
    """
    tmin, tmax = _get_tmin_tmax(times, tmin=tmin, tmax=tmax)
    # Inclusive upper bound: a strict '<' dropped the tmax sample, which with
    # tmax=None (times.max()) silently discarded the last sample of the data.
    time_mask = (times >= tmin) & (times <= tmax)
    if not return_sample:
        return time_mask
    samples = np.flatnonzero(time_mask)
    return time_mask, samples[0], samples[-1]
def frac_area_latency(inst, mode='abs', frac=None, tmin=None, tmax=None):
    """Measure rectified area and, optionally, fractional area latency.

    Parameters
    ----------
    inst : instance of :class:`~mne.Evoked` (or similar)
        Must expose ``data`` (one row per channel), ``times`` and
        ``info['ch_names']`` / ``info['sfreq']``. Data are scaled by 1e6
        (volts to µV, as in the rest of this module).
    mode : {'abs', 'pos', 'neg'}
        'pos' zeroes negative voltages, 'neg' zeroes positive voltages and
        'abs' keeps both. The result is always rectified before integrating.
    frac : float | None
        Fraction of the total area whose latency is requested. None or 1.0
        returns only the area.
    tmin, tmax : float | None
        Window bounds in seconds, snapped to the nearest samples. None uses
        the start/end of the data.

    Returns
    -------
    ch_names : list of str
        Channel names from ``inst.info``.
    area : ndarray, shape (n_channels,)
        Trapezoidal area within the window (µV * s).
    frac_lat_times : ndarray, shape (n_channels,)
        Only when ``frac`` is given and not 1.0: the time at which the
        cumulative area is closest to ``frac * area``.

    Raises
    ------
    ValueError
        If ``mode`` is invalid or the window contains no samples.
    """
    if mode not in ('pos', 'neg', 'abs'):
        raise ValueError("mode must be 'pos', 'neg', or 'abs'")
    ch_names = inst.info['ch_names']
    # reshape rather than squeeze so a single-channel recording stays 2-D.
    data = np.reshape(inst.data, (len(ch_names), -1)) * 1e6
    times = inst.times
    speriod = 1 / inst.info['sfreq']
    time_win = _get_time_win(times, tmin=tmin, tmax=tmax)
    if not time_win.any():
        raise ValueError('the tmin-tmax window contains no samples')

    win_data = data[:, time_win]
    if mode == 'pos':
        win_data = np.maximum(win_data, 0)
    elif mode == 'neg':
        win_data = np.minimum(win_data, 0)
    win_data = np.abs(win_data)  # Always rectify

    area = trapezoid(win_data, dx=speriod, axis=1)
    if frac is None or frac == 1.0:
        return ch_names, area

    # Cumulative area at every window sample (0 at the first one); the last
    # entry equals ``area``. Replaces a quadratic loop that never filled the
    # final sample and left a placeholder value there.
    running_area = cumulative_trapezoid(win_data, dx=speriod, axis=1,
                                        initial=0)
    frac_idx = _find_nearest(running_area, (area * frac)[:, None],
                             axis=1, return_index=True)
    frac_lat_times = times[time_win][frac_idx]
    return ch_names, area, frac_lat_times
def peak_amp_lat(inst, mode='pos', tmin=None, tmax=None, picks=None,
                 return_microvolts=True, width=0):
    """Measure peak amplitude and latency.

    This can be run on ERPs for peak latency and amplitude. Fractional peak
    onset is better conducted on difference waves and is not computed here.
    Uses :func:`~scipy.signal.find_peaks` to locate peaks.

    Parameters
    ----------
    inst : instance of :class:`~mne.Evoked`
        A single instance of an :class:`~mne.Evoked` object.
    mode : {'pos', 'neg', 'abs'} (default 'pos')
        Controls whether positive ('pos'), negative ('neg') or absolute
        ('abs') peaks are detected. 'pos' searches for a positive-going peak
        (but the peak can take on a negative voltage). 'neg' searches for
        negative-going peaks by scaling the voltages by -1 (but the peak can
        take on a positive voltage). 'abs' finds the largest peak regardless
        of sign.
    tmin : float | None (default None)
        The earliest time (inclusive) considered for peak finding. If None,
        the beginning of the data is used.
    tmax : float | None (default None)
        The latest time (inclusive) considered for peak finding. If None,
        the end of the data is used.
    picks : str | list | int | None (default None)
        Channels to include. Integers and lists of integers are interpreted
        as channel indices; strings and lists of strings as channel names.
        None selects all channels.
    return_microvolts : bool (default True)
        If True, returns the peak amplitude in µV.
    width : int | ndarray | list (default 0)
        Required width of peaks in samples. An integer is treated as the
        minimal required width (with no maximum). An ndarray or list of
        integers specifies the minimal and maximal widths, respectively.

    Returns
    -------
    data : instance of :class:`~pandas.DataFrame`
        One row per channel in ``picks`` with columns ``ch_name``, ``tmin``,
        ``tmax``, ``peak_amplitude`` (signed voltage at the peak) and
        ``peak_latency``. Amplitude and latency are missing (None/NaN) when
        no peak satisfying ``width`` is found.

    Raises
    ------
    TypeError
        If ``inst`` is not an :class:`~mne.Evoked`.
    ValueError
        If ``mode`` is invalid or the tmin-tmax window contains no samples.
    """
    if not isinstance(inst, Evoked):
        raise TypeError('inst must be of Evoked type')
    if mode not in ['neg', 'pos', 'abs']:
        raise ValueError("mode must be 'pos', 'neg', or 'abs'")
    if isinstance(picks, (int, str)):
        picks = [picks]
    picks = _handle_picks(inst.ch_names, picks)

    # Multiplication returns a new array, so inst.data is never modified.
    data = inst.data * (1e6 if return_microvolts else 1.0)

    # Resolve the (inclusive) search window once for all channels.
    times = inst.times
    tmin = times.min() if tmin is None else tmin
    tmax = times.max() if tmax is None else tmax
    time_mask = (times >= tmin) & (times <= tmax)
    if not time_mask.any():
        raise ValueError('tmin and tmax must span at least one sample of '
                         'inst.times')
    time_window = times[time_mask]

    rows = []
    for pick in picks:
        signed_window = data[pick, time_mask]
        # Transform only the copy searched for peaks; amplitudes are read
        # back from the untransformed (signed) window.
        if mode == 'neg':
            search_window = -signed_window
        elif mode == 'abs':
            search_window = np.abs(signed_window)
        else:
            search_window = signed_window
        peaks, _ = find_peaks(search_window, width=width)
        if peaks.size:
            peak_index = peaks[np.argmax(search_window[peaks])]
            peak_amplitude = signed_window[peak_index]
            peak_latency = time_window[peak_index]
        else:
            peak_amplitude = None
            peak_latency = None
        rows.append([inst.ch_names[pick], tmin, tmax,
                     peak_amplitude, peak_latency])

    return pd.DataFrame(rows, columns=['ch_name', 'tmin', 'tmax',
                                       'peak_amplitude', 'peak_latency'])
def mean_amplitude(inst, tmin=None, tmax=None, picks=None,
                   return_microvolts=True):
    """Measure the mean amplitude of each channel within a time window.

    Parameters
    ----------
    inst : instance of :class:`~mne.Evoked`
        The data to measure.
    tmin : float | None (default None)
        Window start (inclusive). If None, the beginning of the data is used.
    tmax : float | None (default None)
        Window end (inclusive). If None, the end of the data is used.
    picks : str | list | int | None (default None)
        Channels to include. Integers and lists of integers are interpreted
        as channel indices; strings and lists of strings as channel names.
        None selects all channels.
    return_microvolts : bool (default True)
        If True, the mean amplitude is scaled by 1e6 (µV).

    Returns
    -------
    out_df : instance of :class:`~pandas.DataFrame`
        One row per channel in ``picks`` with columns ``ch_name``, ``tmin``,
        ``tmax`` and ``mean_amplitude``.

    Raises
    ------
    TypeError
        If ``inst`` is not an :class:`~mne.Evoked`.
    ValueError
        If the tmin-tmax window contains no samples.
    """
    if not isinstance(inst, Evoked):
        raise TypeError('inst must be of Evoked type')
    ch_names = inst.ch_names
    if isinstance(picks, (int, str)):
        picks = [picks]
    picks = _handle_picks(ch_names, picks)

    times = inst.times
    tmin = times.min() if tmin is None else tmin
    tmax = times.max() if tmax is None else tmax
    time_mask = (times >= tmin) & (times <= tmax)
    if not time_mask.any():
        raise ValueError('tmin and tmax must span at least one sample of '
                         'inst.times')

    scale = 1e6 if return_microvolts else 1.0
    # Read inst.data directly: nothing is modified, so the per-channel
    # copy of the whole Evoked made previously was unnecessary.
    rows = [[ch_names[pick], tmin, tmax,
             inst.data[pick, time_mask].mean(axis=-1) * scale]
            for pick in picks]
    return pd.DataFrame(rows, columns=['ch_name', 'tmin', 'tmax',
                                       'mean_amplitude'])
# Fractional peak onset latency.
# Default fraction of the absolute peak amplitude used by frac_peak_onset.
frac_peak = .5


def frac_peak_onset(evoked, peaks, frac_peak=frac_peak,
                    peaks_in_microvolts=True):
    """Estimate a fractional peak onset latency for each row of ``peaks``.

    Starting at each row's ``peak_latency``, step backwards one sample at a
    time while the rectified voltage is at least
    ``|peak_amplitude * frac_peak|``. The onset latency is the time of the
    first sample (moving backwards) that falls below that threshold, or the
    first sample of the data if the threshold is never crossed.

    Parameters
    ----------
    evoked : instance of :class:`~mne.Evoked`
        The data the peaks were measured on; ``get_data()``, ``times`` and
        ``ch_names`` are used.
    peaks : instance of :class:`~pandas.DataFrame`
        Peak table such as the output of ``peak_amp_lat``; the ``ch_name``,
        ``peak_amplitude`` and ``peak_latency`` columns are used.
    frac_peak : float (default 0.5)
        Fraction of the peak amplitude that defines the threshold.
    peaks_in_microvolts : bool (default True)
        Whether ``peak_amplitude`` is in µV (``peak_amp_lat`` scales by 1e6
        when ``return_microvolts=True``). If True, the evoked data are scaled
        by 1e6 so that data and threshold share units.

    Returns
    -------
    out : instance of :class:`~pandas.DataFrame`
        A copy of ``peaks`` with an added ``'<percent>_onset_latency'``
        column (e.g. ``'50%_onset_latency'`` for ``frac_peak=0.5``). The value
        is NaN for rows without a peak.
    """
    data = evoked.get_data()
    if peaks_in_microvolts:
        data = data * 1e6
    times = evoked.times
    ch_names = list(evoked.ch_names)

    onsets = []
    for _, row in peaks.iterrows():
        if pd.isna(row['peak_latency']) or pd.isna(row['peak_amplitude']):
            onsets.append(np.nan)
            continue
        # Look channels up by name: the row index of ``peaks`` enumerates
        # picks, which need not match channel positions in ``evoked``.
        rectified = np.abs(data[ch_names.index(row['ch_name'])])
        sample = int(np.argmin(np.abs(times - row['peak_latency'])))
        thresh = np.abs(row['peak_amplitude'] * frac_peak)
        # Stop at sample 0 so the search cannot wrap around to the end.
        while sample > 0 and rectified[sample] >= thresh:
            sample -= 1
        onsets.append(times[sample])

    out = peaks.copy()
    out[f'{frac_peak:.0%}_onset_latency'] = onsets
    return out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment