@dengemann
Last active February 9, 2023 22:56
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import re
import os
import glob
import datetime
from joblib import Parallel, delayed
from pathlib import Path
import mne
import numpy as np
import pandas as pd
from mne_bids import write_raw_bids, print_dir_tree, make_report, BIDSPath
from mne.io.edf.edf import _get_info
from braindecode.datasets.base import BaseDataset, BaseConcatDataset
SEX_TO_MNE = {'n/a': 0, 'm': 1, 'f': 2}
# In[2]:
from jupyterthemes.stylefx import set_nb_theme
set_nb_theme('oceans16')
# In[3]:
mne.set_log_level('warning')
# In[4]:
def _read_edf_header(file_path):
    with open(file_path, "rb") as f:
        header = f.read(88)
    return header
def _parse_age_and_gender_from_edf_header(file_path):
header = _read_edf_header(file_path)
# bytes 8 to 88 contain ascii local patient identification
# see https://www.teuniz.net/edfbrowser/edf%20format%20description.html
patient_id = header[8:].decode("ascii")
age = -1
found_age = re.findall(r"Age:(\d+)", patient_id)
if len(found_age) == 1:
age = int(found_age[0])
gender = "X"
    found_gender = re.findall(r"\s([FM])\s", patient_id)
if len(found_gender) == 1:
gender = found_gender[0]
return age, gender
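# Illustrative sketch (synthetic string, not read from an actual file): the
# local-patient-identification field of a TUH EDF header looks roughly like the
# string below, and the parser above would return (54, 'M') for it.
_demo_patient_id = "aaaaaaav M 09-AUG-2012 aaaaaaav Age:54 "
assert re.findall(r"Age:(\d+)", _demo_patient_id) == ["54"]
assert re.findall(r"\s([FM])\s", _demo_patient_id) == ["M"]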
# In[31]:
def _parse_description_from_file_path(file_path):
# stackoverflow.com/questions/3167154/how-to-split-a-dos-path-into-its-components-in-python # noqa
file_path = os.path.normpath(file_path)
tokens = file_path.split(os.sep)
# version 3.0
# expect file paths as file_type/split/status/reference/aaaaaaav_s004_t000.edf
# edf/train/normal/01_tcp_ar/aaaaaaav_s004_t000.edf
version = 'V3.0'
info, *_ = _get_info(
file_path, stim_channel='auto', eog=None,
misc=None, exclude=(), infer_types=False, preload=False)
date = info['meas_date']
fname = tokens[-1].replace('.edf', '')
subject_id, session, segment = fname.split('_')
return {
'path': file_path,
'version': version,
'year': date.year,
'month': date.month,
'day': date.day,
        'subject': subject_id,  # V3.0 no longer has numeric subject IDs
'session': int(session[1:]),
'segment': int(segment[1:]),
}
def _create_chronological_description(file_paths):
# this is the first loop (fast)
descriptions = []
for file_path in file_paths:
description = _parse_description_from_file_path(file_path)
descriptions.append(pd.Series(description))
descriptions = pd.concat(descriptions, axis=1)
# order descriptions chronologically
descriptions.sort_values(
["subject", "session", "segment", "year", "month", "day"],
axis=1, inplace=True)
# https://stackoverflow.com/questions/42284617/reset-column-index-pandas
descriptions = descriptions.T.reset_index(drop=True).T
return descriptions
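# The resulting `descriptions` object is a DataFrame with one *column* per
# recording and the parsed fields ('path', 'version', 'year', ..., 'segment')
# as its index, sorted by subject and then chronologically within subject.
# A hypothetical call (the path is a placeholder, so this is left commented out):
#
#     descriptions = _create_chronological_description(
#         ['edf/train/normal/01_tcp_ar/aaaaaaav_s004_t000.edf'])
#     descriptions[0].loc['subject']  # -> 'aaaaaaav'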
# In[32]:
class TUH(BaseConcatDataset):
"""Temple University Hospital (TUH) EEG Corpus
(www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tueg).
Parameters
----------
path: str
Parent directory of the dataset.
recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read (order matters and
        will overwrite the default chronological order; e.g. if
        recording_ids=[1, 0], the first recording returned by this class
        will be chronologically later than the second one. Provide
        recording_ids in ascending order to preserve chronological order).
target_name: str
Can be 'gender', or 'age'.
preload: bool
If True, preload the data of the Raw objects.
add_physician_reports: bool
If True, the physician reports will be read from disk and added to the
description.
n_jobs: int
Number of jobs to be used to read files in parallel.
"""
def __init__(self, path, recording_ids=None, target_name=None,
preload=False, add_physician_reports=False, n_jobs=1):
# create an index of all files and gather easily accessible info
# without actually touching the files
file_paths = glob.glob(os.path.join(path, '**/*.edf'), recursive=True)
descriptions = _create_chronological_description(file_paths)
# limit to specified recording ids before doing slow stuff
if recording_ids is not None:
descriptions = descriptions[recording_ids]
        # this is the second loop (slow):
        # create datasets, gathering more info about the files by touching
        # them, reading the raws and potentially preloading the data
# disable joblib for tests. mocking seems to fail otherwise
if n_jobs == 1:
base_datasets = [self._create_dataset(
descriptions[i], target_name, preload, add_physician_reports)
for i in descriptions.columns]
else:
base_datasets = Parallel(n_jobs)(delayed(
self._create_dataset)(
descriptions[i], target_name, preload, add_physician_reports
) for i in descriptions.columns)
super().__init__(base_datasets)
@staticmethod
def _create_dataset(description, target_name, preload,
add_physician_reports):
file_path = description.loc['path']
# parse age and gender information from EDF header
age, gender = _parse_age_and_gender_from_edf_header(file_path)
raw = mne.io.read_raw_edf(file_path, preload=preload)
# Use recording date from path as EDF header is sometimes wrong
        meas_date = datetime.datetime(1, 1, 1, tzinfo=datetime.timezone.utc) \
            if raw.info['meas_date'] is None else raw.info['meas_date']
raw.set_meas_date(meas_date.replace(
*description[['year', 'month', 'day']]))
# read info relevant for preprocessing from raw without loading it
d = {
'age': int(age),
'gender': gender,
}
        if add_physician_reports:
            # NOTE: _read_physician_report is not defined in this script, so
            # add_physician_reports=True would raise a NameError here.
            physician_report = _read_physician_report(file_path)
            d['report'] = physician_report
additional_description = pd.Series(d)
description = pd.concat([description, additional_description])
base_dataset = BaseDataset(raw, description,
target_name=target_name)
return base_dataset
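# Minimal usage sketch for the TUH class above (commented out because it needs a
# local copy of the corpus; the path is a placeholder):
#
#     tuh = TUH('path/to/tuh_eeg/edf', recording_ids=[0, 1, 2], n_jobs=1)
#     tuh.description[['subject', 'session', 'segment', 'age', 'gender']]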
# In[33]:
class TUHAbnormal(TUH):
"""Temple University Hospital (TUH) Abnormal EEG Corpus.
see www.isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml#c_tuab
Parameters
----------
path: str
Parent directory of the dataset.
recording_ids: list(int) | int
        A (list of) int of recording id(s) to be read (order matters and
        will overwrite the default chronological order; e.g. if
        recording_ids=[1, 0], the first recording returned by this class
        will be chronologically later than the second one. Provide
        recording_ids in ascending order to preserve chronological order).
target_name: str
Can be 'pathological', 'gender', or 'age'.
preload: bool
If True, preload the data of the Raw objects.
add_physician_reports: bool
If True, the physician reports will be read from disk and added to the
        description.
    n_jobs: int
        Number of jobs to be used to read files in parallel.
    """
def __init__(self, path, recording_ids=None, target_name='pathological',
preload=False, add_physician_reports=False, n_jobs=1):
super().__init__(path=path, recording_ids=recording_ids,
preload=preload, target_name=target_name,
add_physician_reports=add_physician_reports,
n_jobs=n_jobs)
additional_descriptions = []
for file_path in self.description.path:
additional_description = (
self._parse_additional_description_from_file_path(file_path))
additional_descriptions.append(additional_description)
additional_descriptions = pd.DataFrame(additional_descriptions)
self.set_description(additional_descriptions, overwrite=True)
@staticmethod
def _parse_additional_description_from_file_path(file_path):
file_path = os.path.normpath(file_path)
tokens = file_path.split(os.sep)
# expect paths as version/file type/data_split/pathology status/
# reference/subset/subject/recording session/file
# e.g. v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/
# s004_2013_08_15/00000021_s004_t000.edf
assert ('abnormal' in tokens or 'normal' in tokens), (
'No pathology labels found.')
assert ('train' in tokens or 'eval' in tokens), (
'No train or eval set information found.')
return {
'version': 'V3.0',
'train': 'train' in tokens,
'pathological': 'abnormal' in tokens,
}
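# Small sanity check of the path parsing above (added for illustration; it only
# inspects the example path string from the comment, no file access is needed):
_demo_desc = TUHAbnormal._parse_additional_description_from_file_path(
    'v2.0.0/edf/train/normal/01_tcp_ar/000/00000021/'
    's004_2013_08_15/00000021_s004_t000.edf')
assert _demo_desc == {'version': 'V3.0', 'train': True, 'pathological': False}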
# In[34]:
def rename_tuh_channels(ch_name):
"""Rename TUH channels and ignore non-EEG and custom channels.
Rules:
- 'Z' should always be lowercase.
- 'P' following a 'F' should be lowercase.
"""
exclude = [ # Defined by hand - do we really want to remove them?
'LOC',
'ROC',
'EKG1',
]
if 'EEG' in ch_name:
out = ch_name.replace('EEG ', '').replace('-REF', '')
out = out.replace('FP', 'Fp').replace('Z', 'z') # Apply rules
else:
out = ch_name
if out in exclude:
out = ch_name
return out
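# A few illustrative input/output pairs for the renaming rules above (plain
# string transformations, so they can be checked directly):
assert rename_tuh_channels('EEG FP1-REF') == 'Fp1'
assert rename_tuh_channels('EEG CZ-REF') == 'Cz'
# Excluded channels keep their original name and are dropped later when
# intersecting with the standard_1005 montage:
assert rename_tuh_channels('EEG LOC-REF') == 'EEG LOC-REF'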
def _convert_tuh_recording_to_bids(ds, bids_save_dir, desc=None):
"""Convert single TUH recording to BIDS.
Parameters
----------
ds : braindecode.datasets.BaseDataset
TUH recording to convert to BIDS.
    bids_save_dir : str
Directory where to save the BIDS version of the dataset.
desc : None | pd.Series
Description of the recording, containing subject and recording
information. If None, use `ds.description`.
"""
raw = ds.raw
raw.pick_types(eeg=True) # Only keep EEG channels
if desc is None:
desc = ds.description
# Extract reference
# XXX Not supported yet in mne-bids: see mne-bids/mne_bids/write.py::766
ref = re.findall(r'\_tcp\_(\w\w)', desc['path'])
if len(ref) != 1:
raise ValueError('Expecting one directory level with tcp in it.')
elif ref[0] == 'ar': # average reference
reference = ''
elif ref[0] == 'le': # linked ears
reference = ''
else:
raise ValueError(f'Unknown reference found in file name: {ref[0]}.')
# Rename channels to a format readable by MNE
raw.rename_channels(rename_tuh_channels)
# Ignore channels that are not in the 10-5 system
montage = mne.channels.make_standard_montage('standard_1005')
ch_names = np.intersect1d(raw.ch_names, montage.ch_names)
raw.pick_channels(ch_names)
raw.set_montage(montage)
# Make up birthday based on recording date and age to allow mne-bids to
# compute age
birthday = datetime.datetime(desc['year'] - desc['age'], desc['month'], 1)
birthday -= datetime.timedelta(weeks=4)
sex = desc['gender'].lower() # This assumes gender=sex
# Add additional data required by BIDS
mrn = str(desc['subject']).zfill(4) # MRN: Medical Record Number
session_nb = str(desc['session']).zfill(3)
subject_info = {
'participant_id': mrn,
'subject': desc['subject_orig'],
'birthday': (birthday.year, birthday.month, birthday.day),
'sex': SEX_TO_MNE[sex],
'train': desc['train'],
'pathological': desc['pathological'],
'handedness': None # Not available
}
raw.info['line_freq'] = 60. # Data was collected in North America
raw.info['subject_info'] = subject_info
task = 'rest'
bids_path = BIDSPath(
subject=mrn, session=session_nb, task=task, run=desc['segment'],
root=bids_save_dir, datatype='eeg', check=True)
write_raw_bids(raw, bids_path, overwrite=True, allow_preload=True,
format='BrainVision')
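# With these settings, each recording ends up in a BIDS tree roughly like
# (illustrative layout; the exact labels depend on the description fields):
#
#     <bids_save_dir>/sub-0001/ses-001/eeg/sub-0001_ses-001_task-rest_run-001_eeg.vhdr
#
# plus the matching BrainVision .eeg/.vmrk files and the BIDS sidecar JSON/TSV
# files written by mne-bids.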
def convert_tuab_to_bids(concat_ds, bids_save_dir, healthy_only=False,
reset_session_indices=True, concat_split_files=True,
n_jobs=1):
"""Convert TUAB dataset to BIDS format.
Parameters
----------
    concat_ds : braindecode.datasets.BaseConcatDataset
        TUAB recordings to convert (e.g. a TUHAbnormal instance loaded from
        the original dataset directory, such as
        `/tuh_eeg/www.isip.piconepress.com/projects/tuh_eeg/downloads/tuh_eeg_abnormal/v2.0.0/edf`).
bids_save_dir : str
Directory where to save the BIDS version of the dataset.
healthy_only : bool
If True, only convert recordings with "normal" EEG.
    reset_session_indices : bool
        If True, reset session indices so that each subject's sessions are
        numbered consecutively starting at 1, with no gaps.
concat_split_files : bool
If True, concatenate recordings that were split into a single file.
This is based on the "token" field of the original TUH file paths.
n_jobs : None | int
Number of jobs for parallelization.
"""
if healthy_only:
concat_ds = concat_ds.split(by='pathological')['False']
description = concat_ds.description # Make a copy because `description` is
# made on-the-fly
if concat_split_files:
n_segments_per_session = description.groupby(
['subject', 'session'])['segment'].apply(list).apply(len)
        if not (n_segments_per_session == 1).all():
raise NotImplementedError(
'Concatenation of split files is not implemented yet.')
else:
description['segment'] = '001'
if reset_session_indices:
description['session'] = description.groupby(
'subject')['session'].transform(lambda x: np.arange(len(x)) + 1)
for ds, (_, desc) in zip(concat_ds.datasets, description.iterrows()):
assert ds.description['path'] == desc['path']
_convert_tuh_recording_to_bids(
ds, bids_save_dir, desc=desc)
# In[35]:
tuh_data_dir = Path(...)
bids_save_dir = Path(...)
concat_ds = TUHAbnormal(tuh_data_dir, recording_ids=None, n_jobs=16)
# Map the original string subject IDs (e.g. 'aaaaaaav') to integer category
# codes, keeping the originals under 'subject_orig'; the integer codes are
# later zero-padded to build the BIDS participant IDs.
subjects = concat_ds.description.subject.astype('category').cat.codes
concat_ds.set_description({'subject_orig': concat_ds.description.subject})
concat_ds.set_description({'subject': subjects}, overwrite=True)
# In[37]:
convert_tuab_to_bids(concat_ds, bids_save_dir=bids_save_dir)
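# Optional sanity checks after the conversion, using the mne-bids helpers already
# imported at the top of this script (commented out; they only make sense once
# bids_save_dir points to a finished conversion):
#
#     print_dir_tree(bids_save_dir, max_depth=4)
#     print(make_report(bids_save_dir))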