@gemeinl
Last active May 11, 2020 11:59
braindecode datasets examples
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import mne\n",
"\n",
"# 5, 6, 9, 10, 13, 14 are the codes for executed and imagined hands/feet\n",
"subject_id = 22\n",
"event_codes = [5,6,9,10,13,14]\n",
"#event_codes = [3,4,5,6,7,8,9,10,11,12,13,14]\n",
"\n",
"# This will download the files if you don't have them yet,\n",
"# and then return the paths to the files.\n",
"physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n",
"\n",
"# Load each of the files\n",
"parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto', verbose='WARNING')\n",
"         for path in physionet_paths]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also use Braindecode if you have preprocessed data given as X and y, where X holds the data cut into trials and y holds the targets corresponding to those trials. In addition, you need to know the channel names of your data as well as the sampling frequency of the signals."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"sfreq = parts[0].info[\"sfreq\"]\n",
"ch_names = parts[0].info[\"ch_names\"]\n",
"X = [raw.get_data() for raw in parts]\n",
"y = event_codes"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/gemeinl/anaconda3/envs/braindecode_v2/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.metrics.scorer module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.metrics. Anything that cannot be imported from sklearn.metrics is now part of the private API.\n",
" warnings.warn(message, FutureWarning)\n"
]
}
],
"source": [
"import mne\n",
"import pandas as pd\n",
"\n",
"from braindecode.datasets.base import BaseDataset, BaseConcatDataset\n",
"from braindecode.datautil.windowers import create_fixed_length_windows"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- First, create an mne.RawArray for every x in X, using the sampling frequency and channel names to build the required mne.Info. \n",
"- Second, wrap each mne.RawArray in a BaseDataset. Store the per-trial target in a pandas.Series and set target_name to the key under which the target is stored. \n",
"- Third, concatenate the BaseDatasets into a BaseConcatDataset."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n"
]
}
],
"source": [
"base_datasets = []\n",
"for x, target in zip(X, y):\n",
"    info = mne.create_info(ch_names=ch_names, sfreq=sfreq)\n",
"    raw = mne.io.RawArray(x, info)\n",
"    base_dataset = BaseDataset(raw, pd.Series({\"target\": target}), target_name=\"target\")\n",
"    base_datasets.append(base_dataset)\n",
"base_datasets = BaseConcatDataset(base_datasets)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use create_fixed_length_windows to do the supercrop computation for you, which is needed for decoding with skorch and braindecode. Set the parameters such that exactly one supercrop fits into each of your pre-cut trials."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n"
]
}
],
"source": [
"windows_datasets = create_fixed_length_windows(\n",
"    base_datasets,\n",
"    start_offset_samples=0,\n",
"    stop_offset_samples=0,\n",
"    # one supercrop per trial: x still holds the last trial from the loop\n",
"    # above, and all trials have equal length\n",
"    supercrop_size_samples=x.shape[1],\n",
"    supercrop_stride_samples=x.shape[1],\n",
"    drop_samples=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading data for 1 events and 20000 original time points ...\n",
"(64, 20000) 5 [0, 0, 20000]\n"
]
}
],
"source": [
"# take the first supercrop: data, target, and [i_supercrop, start, stop]\n",
"for x, y, ind in windows_datasets:\n",
"    break\n",
"print(x.shape, y, ind)"
]
},
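{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since BaseConcatDataset builds on torch.utils.data.ConcatDataset, the windowed dataset can also be fed to a PyTorch DataLoader for mini-batch training. A minimal sketch (the batch_size value is only illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"# each batch yields stacked data, targets, and supercrop indices\n",
"loader = DataLoader(windows_datasets, batch_size=2, shuffle=True)\n",
"for batch_x, batch_y, batch_ind in loader:\n",
"    print(batch_x.shape, batch_y)\n",
"    break"
]
},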
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/gemeinl/anaconda3/envs/braindecode_v2/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.metrics.scorer module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.metrics. Anything that cannot be imported from sklearn.metrics is now part of the private API.\n",
" warnings.warn(message, FutureWarning)\n"
]
}
],
"source": [
"# TODO: add this function to braindecode\n",
"import mne\n",
"import pandas as pd\n",
"\n",
"from braindecode.datasets.base import BaseDataset, BaseConcatDataset\n",
"from braindecode.datautil.windowers import create_fixed_length_windows\n",
"\n",
"def create_from_X_y(X, y, sfreq, ch_names, supercrop_size_samples=None,\n",
"                    supercrop_stride_samples=None, drop_samples=False):\n",
"    base_datasets = []\n",
"    for x, target in zip(X, y):\n",
"        info = mne.create_info(ch_names=ch_names, sfreq=sfreq)\n",
"        raw = mne.io.RawArray(x, info)\n",
"        base_dataset = BaseDataset(raw, pd.Series({\"target\": target}), target_name=\"target\")\n",
"        base_datasets.append(base_dataset)\n",
"    base_datasets = BaseConcatDataset(base_datasets)\n",
"    \n",
"    # by default, fit exactly one supercrop per trial\n",
"    # (assumes all trials have equal length)\n",
"    if supercrop_size_samples is None:\n",
"        supercrop_size_samples = X[0].shape[1]\n",
"    if supercrop_stride_samples is None:\n",
"        supercrop_stride_samples = X[0].shape[1]\n",
"    windows_datasets = create_fixed_length_windows(\n",
"        base_datasets,\n",
"        start_offset_samples=0,\n",
"        stop_offset_samples=0,\n",
"        supercrop_size_samples=supercrop_size_samples,\n",
"        supercrop_stride_samples=supercrop_stride_samples,\n",
"        drop_samples=drop_samples\n",
"    )\n",
"    return windows_datasets"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"Creating RawArray with float64 data, n_channels=64, n_times=20000\n",
" Range : 0 ... 19999 = 0.000 ... 124.994 secs\n",
"Ready.\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n",
"1 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 1 events and 20000 original time points ...\n",
"0 bad epochs dropped\n"
]
}
],
"source": [
"windows_dataset = create_from_X_y(X, y, sfreq, ch_names)"
]
},
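{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a sketch, not part of the original helper), the dataset returned by create_from_X_y should contain one supercrop per trial and support indexing like the manually built dataset above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# expect one supercrop per trial, each of shape (n_channels, n_times)\n",
"print(len(windows_dataset))\n",
"x0, y0, ind0 = windows_dataset[0]\n",
"print(x0.shape, y0, ind0)"
]
},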
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "braindecode_v2",
"language": "python",
"name": "braindecode_v2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}