@gemeinl
Last active May 11, 2020 11:59
braindecode datasets examples
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, \"/home/gemeinl/code/braindecode/\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import mne"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating RawArray with float64 data, n_channels=50, n_times=500\n",
" Range : 0 ... 499 = 0.000 ... 4.990 secs\n",
"Ready.\n",
"Used Annotations descriptions: ['test_trial']\n",
"10 matching events found\n",
"No baseline correction applied\n",
"Not setting metadata\n",
"0 projection items activated\n",
"Loading data for 10 events and 20 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=500\n",
" Range : 0 ... 499 = 0.000 ... 4.990 secs\n",
"Ready.\n",
"Used Annotations descriptions: ['test_trial']\n",
"10 matching events found\n",
"No baseline correction applied\n",
"Not setting metadata\n",
"0 projection items activated\n",
"Loading data for 10 events and 10 original time points ...\n",
"0 bad epochs dropped\n"
]
}
],
"source": [
"import numpy as np\n",
"n_channels = 50\n",
"n_times = 500\n",
"sfreq = 100\n",
"rng = np.random.RandomState(34834)\n",
"all_epochs = []\n",
"datas = []\n",
"for i_raw in range(2):\n",
" data = rng.rand(n_channels, n_times)\n",
" ch_names = [f'ch{i}' for i in range(n_channels)]\n",
" ch_types = ['eeg'] * n_channels\n",
" info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)\n",
" raw = mne.io.RawArray(data, info)\n",
"\n",
" n_anns = 10\n",
" inds = np.linspace(0,n_times,n_anns,endpoint=False).astype(int)\n",
" onsets = raw.times[inds]\n",
" if i_raw == 0:\n",
" trial_dur = 0.2 # in sec\n",
" else:\n",
" trial_dur = 0.1\n",
" durations = np.ones(n_anns) * trial_dur\n",
" descriptions = ['test_trial'] * len(durations)\n",
" anns = mne.Annotations(onsets, durations, descriptions)\n",
" raw = raw.set_annotations(anns)\n",
" events, event_id = mne.events_from_annotations(raw,)\n",
" epochs = mne.Epochs(raw, events, event_id=event_id, preload=True,\n",
" baseline=None,\n",
" tmin=0, tmax=trial_dur - 1e-2)\n",
" all_epochs.append(epochs)\n",
" datas.append(data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This now creates windows that pass the tests. \n",
"It has duplicate code with create_fixed_length_windows and can probably improved"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from braindecode.datasets.base import WindowsDataset, BaseConcatDataset\n",
"from braindecode.datautil.windowers import _check_windowing_arguments\n",
"\n",
"\n",
"def create_from_mne_epochs(list_of_epochs, supercrop_size_samples,\n",
" supercrop_stride_samples, drop_samples):\n",
" \"\"\"Create WindowsDatasets from mne.Epochs\n",
"\n",
" Parameters\n",
" ----------\n",
" list_of_epochs: array-like\n",
" list of mne.Epochs\n",
" supercrop_size_samples: int\n",
" supercrop size\n",
" supercrop_stride_samples: int\n",
" stride between supercrops\n",
" drop_samples: bool\n",
" whether or not have a last overlapping supercrop/window, when\n",
" supercrops/windows do not equally divide the continuous signal\n",
"\n",
" Returns\n",
" -------\n",
" windows_datasets: BaseConcatDataset\n",
" X and y transformed to a dataset format that is compativle with skorch\n",
" and braindecode\n",
" \"\"\"\n",
" _check_windowing_arguments(0, 0, supercrop_size_samples,\n",
" supercrop_stride_samples)\n",
"\n",
" list_of_windows_ds = []\n",
" for epochs in list_of_epochs:\n",
" event_descriptions = epochs.events[:, 2]\n",
" original_trial_starts = epochs.events[:, 0]\n",
" stop = len(epochs.times) - supercrop_size_samples\n",
"\n",
" # already includes last incomplete supercrop start\n",
" starts = np.arange(0, stop + 1, supercrop_stride_samples)\n",
"\n",
" if not drop_samples and starts[-1] < stop:\n",
" # if last supercrop does not end at trial stop, make it stop there\n",
" starts = np.append(starts, stop)\n",
"\n",
" fake_events = [[start, supercrop_size_samples, -1] for start in\n",
" starts]\n",
"\n",
" for trial_i, trial in enumerate(epochs):\n",
" metadata = pd.DataFrame({\n",
" 'i_supercrop_in_trial': np.arange(len(fake_events)),\n",
" 'i_start_in_trial': starts + original_trial_starts[trial_i],\n",
" 'i_stop_in_trial': starts + original_trial_starts[\n",
" trial_i] + supercrop_size_samples,\n",
" 'target': len(fake_events) * [event_descriptions[trial_i]]\n",
" })\n",
" # supercrop size - 1, since tmax is inclusive\n",
" mne_epochs = mne.Epochs(\n",
" mne.io.RawArray(trial, epochs.info), fake_events,\n",
" baseline=None,\n",
" tmin=0,\n",
" tmax=(supercrop_size_samples - 1) / epochs.info[\"sfreq\"],\n",
" metadata=metadata)\n",
"\n",
" mne_epochs.drop_bad(reject=None, flat=None)\n",
"\n",
" windows_ds = WindowsDataset(mne_epochs)\n",
" list_of_windows_ds.append(windows_ds)\n",
"\n",
" return BaseConcatDataset(list_of_windows_ds)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=20\n",
" Range : 0 ... 19 = 0.000 ... 0.190 secs\n",
"Ready.\n",
"9 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 9 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n",
"Creating RawArray with float64 data, n_channels=50, n_times=10\n",
" Range : 0 ... 9 = 0.000 ... 0.090 secs\n",
"Ready.\n",
"4 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 4 events and 5 original time points ...\n",
"0 bad epochs dropped\n"
]
}
],
"source": [
"windows = create_from_mne_epochs(all_epochs, supercrop_size_samples=5, supercrop_stride_samples=2, drop_samples=False)"
]
},
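{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the result. This is a minimal sketch; it assumes that indexing `windows` returns `(x, y, (i_supercrop_in_trial, i_start_in_trial, i_stop_in_trial))` per window, which is also what the assertions further below rely on."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# total number of supercrops/windows across both recordings\n",
"print(len(windows))\n",
"# first window: signal crop, target and supercrop indices\n",
"x, y, (i_supercrop_in_trial, i_start, i_stop) = windows[0]\n",
"print(x.shape, y, i_supercrop_in_trial, i_start, i_stop)"
]
},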
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n",
"Loading data for 1 events and 5 original time points ...\n"
]
}
],
"source": [
"# windows per trial: 0-5,2-7,4-9,6-11,...,14-19,15-20\n",
"# and then: 0-5,2-7,4-9,5-10\n",
"assert len(windows) == 9 * n_anns + 4 * n_anns\n",
"for i_w, (x,y,(i_w_in_t, i_start, i_stop)) in enumerate(windows):\n",
" if i_w < 9 * n_anns:\n",
" assert i_w_in_t == i_w % 9\n",
" i_t = i_w // 9\n",
" assert i_start == inds[i_t] + i_w_in_t * 2 - (i_w_in_t == 8)\n",
" assert i_stop == inds[i_t] + i_w_in_t * 2 - (i_w_in_t == 8) + 5\n",
" np.testing.assert_allclose(x, datas[0][:,i_start:i_stop],\n",
" atol=1e-5, rtol=1e-5)\n",
" else:\n",
" assert i_w_in_t == (i_w - n_anns*9) % 4\n",
" i_t = ((i_w - n_anns*9) // 4)\n",
" assert i_start == inds[i_t] + i_w_in_t * 2 - (i_w_in_t == 3)\n",
" assert i_stop == inds[i_t] + i_w_in_t * 2 - (i_w_in_t == 3) + 5"
]
},
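{
"cell_type": "markdown",
"metadata": {},
"source": [
"The docstring above states that the result is compatible with skorch and braindecode. As a minimal sketch of that claim (assuming PyTorch is installed in this environment and that the default collate function handles the numpy arrays and index tuples returned per window), the windows can be batched with a plain `torch.utils.data.DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"# default collation stacks x into (batch_size, n_channels, supercrop_size_samples)\n",
"loader = DataLoader(windows, batch_size=8, shuffle=True)\n",
"X_batch, y_batch, inds_batch = next(iter(loader))\n",
"print(X_batch.shape, y_batch.shape)"
]
},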
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "braindecode_v2",
"language": "python",
"name": "braindecode_v2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}