Created
May 19, 2024 09:28
-
-
Save OverLordGoldDragon/b6709fd266929f90c7979fb1d5635c4b to your computer and use it in GitHub Desktop.
PhysioNet, edf -> model(x, y)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Minimal-code rewrite of BrainDecode's
https://braindecode.org/stable/auto_examples/applied_examples/plot_sleep_staging_chambon2018.html
https://github.com/braindecode/braindecode/blob/master/examples/applied_examples/plot_sleep_staging_chambon2018.py
It
  - eliminates all (explicit) dependence on `braindecode`
  - minimizes dependence on `mne`
My motivation was to have understandable code that decodes PhysioNet into
`(x, y)` that can be fed to a model. I found the original example too
high-level, and its source code too dense to manipulate for my purposes.
The rewritten code is
  - heavily cut down in total number of lines
  - tested to completely reproduce the original BrainDecode script's results
    (note: requires `torch.use_deterministic_algorithms(True)`).
    Caveat: that's with one preprocessing step (lowpass filtering) omitted;
    just comment it out in the original source (or reimplement it here).
  - includes some fixes and extensions of functionality
  - commented, explaining/justifying the omission of code with respect to
    the `braindecode` source code (meant to be read alongside it)
  - split into "helpers" (function definitions) and "execution"; the former
    could be moved into its own file and imported
  - a rewrite of v0.8.0,
    https://github.com/braindecode/braindecode/tree/v0.8
Tip: download PhysioNet from the Google Cloud bucket, then set `data_loaddir`;
it was 10x faster for me than downloading directly from physionet.org (which
is what the original script does).
Disclaimer: while it was free for me, the bucket is described as paid, so I
can't guarantee that. Instructions at https://physionet.org/content/sleep-edfx/1.0.0/
Note: this example uses different `subject_ids` and `recording_ids` than the
original.
"""
# ############################################################################
# User config
# -----------
# Directory where the PhysioNet source data is already stored, if it exists;
# if None, the data is downloaded from scratch.
data_loaddir = None
# Directory to save processed data to; None -> current working directory.
data_savedir = None
# Subject and recording IDs to load.
subject_ids = [0, 1, 2, 3]
recording_ids = [1, 2]
# Use the GPU if one is available.
use_gpu = True
#%% ##########################################################################
# Imports
# -------
# The env var should be set before running the other imports.
import os
if data_loaddir is not None:
    # environment values must be strings (also accepts e.g. `pathlib.Path`)
    os.environ['PHYSIONET_SLEEP_PATH'] = str(data_loaddir)
import random
import bisect
import numpy as np
import torch
import torch.nn as nn
# torch.use_deterministic_algorithms(True)
import mne
from mne.datasets.sleep_physionet.age import fetch_data
from sklearn.preprocessing import scale as standard_scale
############################################################################## | |
# HELPER FUNCTIONS | |
# **************** | |
#%% ########################################################################## | |
# Loading & saving data | |
# --------------------- | |
# Files are named e.g. | |
# | |
# SC4001E0-PSG.edf | |
# 01234567 | |
# | |
# where | |
# | |
# `34` = `subject_id` (i.e. 0) | |
# `5` = `recording_id` (i.e. 1) | |
# | |
# and, | |
# | |
# SC4001EC-Hypnogram.edf | |
# | |
# where | |
# | |
# `PSG` = data | |
# `Hypnogram` = metadata (including labels) | |
# | |
# Notes: | |
# | |
# - "Data" includes stuff other than EEG, exclude via `exclude_chs`. | |
# - `p = paths[0]` is a pair of paths, where `p[0]` = PSG, `p[1]` = Hypnogram. | |
# | |
# define a function to reuse later | |
def process_path(p, exclude_chs, data_savedir):
    """Load, crop, and rescale one PSG/Hypnogram pair, then save it as `.npz`.

    Parameters
    ----------
    p : tuple[str, str]
        `(psg_path, hypnogram_path)`.
    exclude_chs : tuple[str]
        Channel names to exclude when reading the PSG file.
    data_savedir : str
        Directory the `.npz` file is written to (created if missing).

    Returns
    -------
    data : dict
        The arrays that were saved.
    savepath : str
        Path of the written `.npz` file, e.g. `<data_savedir>/SC4001E0.npz`.
    """
    # subject / recording numbers aren't needed here
    raw, annots, _, _ = _load_data(p, exclude_chs)
    raw = _drop_unlabeled_data(raw, annots)
    raw = _preprocess(raw)
    # by reading further code (window processing), we determine that we need
    # to keep only the following:
    #
    #   for `raw`:
    #     `raw._data`
    #   for `annotations` (== `raw.annotations`):
    #     `annotations.onset`
    #     `annotations.description`
    #     `annotations.duration`
    #
    data = {
        'data': raw._data,
        'onset': raw.annotations.onset,
        'description': raw.annotations.description,
        'duration': raw.annotations.duration,
        '_first_time': raw._first_time,  # for `braindecode_ver = True`
    }
    # 'SC4001E0-PSG.edf' -> 'SC4001E0.npz'. Build the extension explicitly:
    # `np.savez` appends '.npz' when it's missing, so the returned path must
    # end in '.npz' too, whatever the case of the source extension.
    stem = os.path.splitext(os.path.basename(p[0]))[0].replace('-PSG', '')
    os.makedirs(data_savedir, exist_ok=True)
    savepath = os.path.join(data_savedir, stem + '.npz')
    np.savez(savepath, **data)
    return data, savepath
def _load_data(p, exclude_chs):
    """Read a PSG EDF file and its Hypnogram annotations.

    Parameters
    ----------
    p : tuple[str, str]
        `(psg_path, hypnogram_path)`.
    exclude_chs : tuple[str]
        Channel names not to load from the PSG file.

    Returns
    -------
    raw : mne.io.Raw
        The PSG recording, preloaded into RAM.
    annots : mne.Annotations
        Hypnogram annotations (labels).
    subj_num, sess_num : int
        Subject and recording numbers parsed from the PSG file name.
    """
    # load data & labels -----------------------------------------------------
    p_data, p_meta = p
    # (`preload` to load the data into RAM)
    raw = mne.io.read_raw_edf(p_data, preload=True, exclude=exclude_chs)
    annots = mne.read_annotations(p_meta)

    # Get subject and recording number ---------------------------------------
    # names look like 'SC4001E0-PSG.edf': characters 3-4 are the subject ID,
    # character 5 the recording ID
    basename = os.path.basename(p_data)
    subj_num = int(basename[3:5])
    sess_num = int(basename[5])
    return raw, annots, subj_num, sess_num
def _drop_unlabeled_data(raw, annots, braindecode_ver=True,
                         crop_wake_mins=30):
    """Attach `annots` to `raw` and crop `raw` to the labeled sleep period.

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded recording. Its annotations are set, it is cropped, and its
        EEG channels are renamed via its own methods; the same object is
        returned.
    annots : mne.Annotations
        Hypnogram annotations for `raw`.
    braindecode_ver : bool
        If True (default), reproduce braindecode's example: keep from
        `crop_wake_mins` before the first non-wake sleep stage
        ('1', '2', '3', '4', 'R') to `crop_wake_mins` after the end of the
        last one. If False, keep from the first annotation not ending in '?'
        to the end of the last such annotation.
    crop_wake_mins : float
        Minutes kept on each side of the sleep period; only used when
        `braindecode_ver` is True.

    Returns
    -------
    raw

    Raises
    ------
    ValueError
        If no annotation qualifies. With `braindecode_ver=False`, also if
        the annotations don't end with exactly one "Sleep stage ?".
    """
    # set `raw.annotations` from `annots` ------------------------------------
    # - each `a = annots[0]` has `a['onset']` and `a['duration']`
    # - `a` are cropped such that they remain within (inclusive) `tmin = 0` and
    #   `tmax = raw.times[-1] + 1 / raw.info['sfreq']`, where sfreq = sampling
    #   freq = 100 Hz. "Cropped" means their `'onset'` and `'duration'` are
    #   adjusted.
    raw.set_annotations(annots, emit_warning=False)

    # crop data to exclude unlabeled segments --------------------------------
    if not braindecode_ver:
        # "labels" are over data's time segments. E.g. `x[:20000]` can be
        # "Sleep stage W" (wake), and `x[20000:25000]` be "Sleep stage 1", etc.
        # "Sleep stage ?" is unlabeled.
        mask = [x[-1] != '?' for x in annots.description]
        # below assumes there's only one such sleep stage, and that it's the
        # last one; check both assumptions (explicitly, so `-O` keeps them)
        if mask.count(False) != 1 or mask[-1]:
            raise ValueError(
                "expected exactly one 'Sleep stage ?' annotation, at the end; "
                f"got mask={mask}")
    else:
        # see `not braindecode_ver` comments; here wake ('W') is excluded
        mask = [
            x[-1] in ['1', '2', '3', '4', 'R'] for x in annots.description]
    sleep_event_inds = np.where(mask)[0]
    if len(sleep_event_inds) == 0:
        raise ValueError("no annotation selected to crop the recording to")

    # Crop raw (also crops labels)
    # determine `tmax` as first timestamp of last stage before
    # "Sleep stage ?", plus the duration of that stage, minus one sample (`dT`)
    # (otherwise we end up at first timestamp of "Sleep stage ?", and
    # `crop` is inclusive on `tmax`).
    dT = 1 / raw.info["sfreq"]
    a_tmin = annots[sleep_event_inds[0]]
    a_tmax = annots[sleep_event_inds[-1]]
    if not braindecode_ver:
        tmin = a_tmin['onset']
        tmax = a_tmax['onset'] + a_tmax['duration'] - dT
    else:
        tmin = a_tmin['onset'] - crop_wake_mins * 60
        tmax = (a_tmax['onset'] + a_tmax['duration'] - dT
                + crop_wake_mins * 60)

    # internally, this converts `tmin`, `tmax` to indices, and does something
    # like `x = x[tmin:tmax]`.
    # it also correspondingly crops labels, via `set_annotations`.
    # Internals (for labels):
    #   - `tmin` for `set_annotations` is set from `_first_time` (and another
    #     var we can treat as zero).
    #     This is updated in `crop` from the `tmin` we specify (see `_first_time`
    #     definition as `@property`).
    #   - `tmax` for `set_annotations` is set from `times[-1]` (+ dT).
    #     This is updated in `crop` from the `tmax` we specify (see `times`
    #     definition as `@property`).
    #   - This drops annotations that are completely out of range of `tmin`,
    #     `tmax`. Here, it amounts to dropping the annotation for "Sleep stage ?".
    # (clamped so the wake margins can't extend past the recording)
    raw.crop(tmin=max(tmin, raw.times[0]),
             tmax=min(tmax, raw.times[-1]))

    # Rename EEG channels ----------------------------------------------------
    raw.rename_channels({nm: nm.replace('EEG ', '') for nm in raw.ch_names})
    return raw
def _preprocess(raw): | |
# internally, this justs updates `raw._data`, with | |
# - handling for multiprocessing | |
# - checking that out.shape == in.shape | |
# - loading `_data` if it isn't (for us, it is, via `preload=True`) | |
# V -> uV | |
raw._data = raw._data * 1e6 | |
return raw | |
#%% ########################################################################## | |
# Creating windows | |
# ---------------- | |
def _create_windows_from_events(p, mapping, sfreq, window_size_samples=3000,
                                window_stride_samples=3000,
                                drop_last_window=False):
    """Cut one saved recording into labeled, fixed-size windows.

    Parameters
    ----------
    p : str
        Path to an `.npz` file written by `process_path`.
    mapping : dict[str, int]
        Annotation description -> integer label; annotations whose
        description isn't a key are ignored.
    sfreq : float
        Sampling frequency in Hz (converts annotation seconds to samples).
    window_size_samples, window_stride_samples : int
        Window length and stride, in samples (3000 = 30 s at 100 Hz).
    drop_last_window : bool
        Passed to `_compute_window_inds`.

    Returns
    -------
    WindowsDataset
        Windows over the loaded data, each labeled with its annotation's code.

    Raises
    ------
    ValueError
        If an annotation extends past the end of the data.
    """
    # load data; the `with` closes the underlying file once arrays are read.
    # `a_` for `annotations_`
    with np.load(p) as d:
        data, a_description, a_onset, a_duration, _first_time = [
            d[nm] for nm in
            ('data', 'description', 'onset', 'duration', '_first_time')
        ]

    events = _events_from_annotations(
        a_description, a_onset, _first_time, mapping, sfreq)
    # the lib method also returns `event_id_`, which is redundant for us
    #  - (it is a copy of `mapping`, unless `mapping` has some keys that
    #    `annotations.description` doesn't, but we then only use `event_id_`
    #    to check whether what's in `annotations.description` is in `event_id_`,
    #    which is circular and same as checking directly against `mapping`)
    # Onsets are relative to the beginning of the recording
    onsets = events[:, 0]
    filtered_durations = np.array([
        dur for dur, desc in zip(a_duration, a_description)
        if desc in mapping
    ])
    stops = onsets + (filtered_durations * sfreq).astype(int)

    # sanity check; note, `stops` is used exclusively, i.e. `start:stop`
    # don't need the `raw.first_samp` in `raw.first_samp + raw.n_times` since we
    # commented out `+= raw.first_samp` (upon `onsets`, in
    # `_events_from_annotations`) earlier; and,
    # `raw.n_times == raw._data.shape[-1]`
    # (i.e. the full check is `stops[-1] <= raw.first_samp + raw.n_times`)
    if len(stops) and stops[-1] > data.shape[-1]:
        raise ValueError(
            f"last annotation stops at sample {stops[-1]}, past the end of "
            f"the data ({data.shape[-1]} samples) in {p}")
    # this no longer executes since we commented out `+= raw.first_samp` earlier
    # onsets = onsets - raw.first_samp
    # stops = stops - raw.first_samp

    # generate windows
    i_trials, starts, stops = _compute_window_inds(
        onsets, stops, window_size_samples, window_stride_samples,
        drop_last_window)

    # generate window events: each window takes its trial's label (the
    # braindecode version builds a full `[start, size, label]` events array
    # only to read its last column back)
    description = events[:, -1]
    description_windows = description[np.asarray(i_trials, dtype=int)]

    windows_ds = WindowsDataset(
        data,
        target=description_windows,
        i_start_in_trial=starts,
        i_stop_in_trial=stops
    )
    return windows_ds
def _events_from_annotations(description, onset, _first_time, mapping, fs): | |
# minimally implements `_select_annotations_based_on_description` -------- | |
event_sel = [i for i, d in enumerate(description) if d in mapping] | |
# Convert onsets to sample indices. Internals: =========================== | |
# - `annotations.onset` = timestamps, in seconds, of start times of labels | |
# (where each "label", again, is over a time interval over data, e.g. | |
# `x[20000:25000]`). | |
# - `len(annotations.onset) == len(labels)`. | |
# - `annotations.orig_time` = `raw.info["meas_date"]`, as long as | |
# `annotations = mne.read_annotations(path)` is used. | |
# - `raw.info["meas_date"]` = measurement date, a `datetime` object created | |
# from metadata in the (-PSG) EDF file. | |
# minimally implements --------------------------------------------------- | |
# | |
# inds = raw.time_as_index(times=annotations.onset, use_rounding=True, | |
# origin=annotations.orig_time) | |
# | |
# we won't end up needing `origin` (see below), so don't fetch it. | |
# origin = annotations.orig_time | |
times = onset | |
# Internals: | |
# - `self._first_time = self.first_samp / self.info['sfreq']` | |
# - `self.first_samp = self._cropped_samp` | |
# - `self._cropped_samp = first_samps[0]` if `raw.crop()` wasn't used | |
# with `tmin != 0`, else it's modified (in this case to | |
# `tmin * self.info['sfreq']`) | |
# - `first_samps = (0,)` | |
# - (`braindecode_ver = False` only) Hence, `raw._first_time == 0`, and | |
# since `origin == raw.info["meas_date"]` (see above), `delta == 0`, | |
# so we can skip all below code (and `times` is already 1d) | |
# Since `braindecode_ver = True` is supported, execute the relevant portion | |
# in `raw.time_as_index`, but rewritten: | |
# | |
# `origin - first_samp_in_abs_time` | |
# <=> | |
# `raw.info["meas_date"] - (raw.info["meas_date"] + raw._first_time)` | |
# <=> | |
# `- raw._first_time` | |
delta = - _first_time | |
times += delta | |
# `raw.times[0]` is always zero in our case (RawEDF loaded the way we have) | |
# index = (np.atleast_1d(times) - raw.times[0]) * fs | |
index = times * fs | |
inds = np.round(index).astype(int) | |
# ======================================================================== | |
# Executes if `if annotations.orig_time is not None:`, which is the case here. | |
# Do not execute this, so we don't have to `-= raw.first_samp` later, so | |
# we don't have to store `raw.first_samp` anywhere | |
# inds += raw.first_samp | |
# `annotations.description` -> numeric values based on `mapping`. E.g. if | |
# | |
# mapping == {'Sleep stage W': 0, 'Sleep stage 1': 1}` | |
# annotations.description == ['Sleep stage 1', 'Sleep stage W', | |
# 'Sleep stage W'] | |
# | |
# then | |
# | |
# values == [1, 0, 0] | |
# | |
# but ignoring `event_sel` (indices of selected labels based on whether they | |
# were in `mapping`). | |
# | |
values = [mapping[kk] for kk in description[event_sel]] | |
# Apply `event_sel` to `inds` | |
inds = inds[event_sel] | |
# This simply concatenates the arrays into `(n_events, 3)`, and casts to int | |
events = np.c_[inds, np.zeros(len(inds)), values].astype(int) | |
return events | |
def _compute_window_inds(starts, stops, size, stride, drop_last_window): | |
assert not any(size > (stops-starts)) | |
i_trials, window_starts = [], [] | |
for start_i, (start, stop) in enumerate(zip(starts, stops)): | |
# Generate possible window starts, with given stride, between | |
# starts and stops (i.e. original trial onsets and stops, shifted by | |
# start_offset and stop_offset, respectively) | |
possible_starts = np.arange(start, stop, stride) | |
# Possible window start is actually a start, if window size fits in | |
# trial start and stop | |
for i_window, s in enumerate(possible_starts): | |
if (s + size) <= stop: | |
window_starts.append(s) | |
i_trials.append(start_i) | |
# If the last window start + window size is not the same as | |
# stop + stop_offset, create another window that overlaps and stops | |
# at onset + stop_offset | |
if not drop_last_window: | |
if window_starts[-1] + size != stop: | |
window_starts.append(stop - size) | |
i_trials.append(start_i) | |
# Set window stops to be event stops (rather than trial stops) | |
window_stops = np.array(window_starts) + size | |
assert len(i_trials) == len(window_starts) == len(window_stops) | |
return i_trials, window_starts, window_stops | |
def _preprocess_windows(windows_ds):
    """Standardize `windows_ds.data` along axis 1 and cast it to float32.

    Operates on the whole continuous array the windows index into, so each
    row (channel) is scaled to zero mean / unit variance over the full
    recording. Replaces `windows_ds.data`; returns None.
    """
    # `astype` already returns a new array; `order='C'` keeps the result
    # C-contiguous as the former `.copy().astype(...)` did
    windows_ds.data = standard_scale(windows_ds.data, axis=1).astype(
        'float32', order='C')
class WindowsDataset:
    """Labeled fixed-size windows over one continuous recording.

    Parameters
    ----------
    data : np.ndarray
        `(n_channels, n_times)` array; windows slice it along axis 1.
    target : array-like of int
        One label per window.
    i_start_in_trial, i_stop_in_trial : array-like of int
        Start (inclusive) and stop (exclusive) sample of each window.
    """

    def __init__(self, data, target, i_start_in_trial, i_stop_in_trial):
        self.data = data
        self.y = np.asarray(target, dtype='int64')  # skorch expects int64
        # kept as given; `Sampler` only uses its length
        self.i_start_in_trial = i_start_in_trial
        # `(n_windows, 2)` of `[start, stop)`; int so it is a valid slice
        # bound even when built from empty inputs
        self.inds = np.c_[i_start_in_trial, i_stop_in_trial].astype(np.int64)

    def __getitem__(self, index):
        """Return `(X, y)`: window `index` of all channels, and its label."""
        i_start, i_end = self.inds[index]
        return self.data[:, i_start:i_end], self.y[index]

    def __len__(self):
        return len(self.y)
#%% ########################################################################## | |
# Creating dataset objects | |
# ------------------------ | |
class Sampler:
    """Yield tuples of `n_windows` consecutive window indices.

    Sequences never cross recording boundaries. Indices refer to positions
    in the concatenation of all recordings (as in `ConcatDataset`).

    Parameters
    ----------
    i_start_in_trial_all : list of sized
        One entry per recording; only each entry's length (its number of
        windows) is used.
    n_windows : int
        Windows per sequence.
    n_windows_stride : int
        Step between consecutive sequence starts within a recording.
    randomize : bool
        If True, each iteration visits the sequences in a new order drawn
        from NumPy's global RNG.
    """

    def __init__(self, i_start_in_trial_all, n_windows, n_windows_stride,
                 randomize=False):
        self.n_windows = n_windows
        self.n_windows_stride = n_windows_stride
        self.randomize = randomize

        # braindecode applies `groupby` by `'subject'`, `'recording'` upon a
        # dataframe, then resets the index, and operates on said indices; so,
        # generate global indices per recording, pick valid sequence starts
        # within each, then concatenate.
        per_recording = []
        offset = 0
        for starts in i_start_in_trial_all:
            n = len(starts)
            idxs = np.arange(offset, offset + n)
            # last valid start is `n - n_windows`; clamp at 0 so recordings
            # shorter than one sequence contribute nothing (this also handles
            # `n_windows == 1`, where a `1 - n_windows` slice end is 0)
            per_recording.append(
                idxs[:max(n - n_windows + 1, 0):n_windows_stride])
            offset += n
        self.start_inds = (np.concatenate(per_recording) if per_recording
                           else np.array([], dtype=int))

    def __len__(self):
        return len(self.start_inds)

    def __iter__(self):
        if self.randomize:
            start_inds = np.random.permutation(self.start_inds)
        else:
            start_inds = self.start_inds
        for start_ind in start_inds:
            yield tuple(range(start_ind, start_ind + self.n_windows))
class ConcatDataset(torch.utils.data.Dataset):
    """Concatenate datasets and index them with tuples of global indices.

    Merges braindecode's `BaseConcatDataset` and `ConcatDataset` classes.
    `skorch` requires this to be an instance of `Dataset`.

    Parameters
    ----------
    datasets : list
        Datasets whose items are `(X, y)` pairs.
    target_transform : callable or None
        Applied to the stacked targets of each `__getitem__` call; None means
        identity. In this script it is assigned after construction.
    """

    def __init__(self, datasets, target_transform=None):
        self.datasets = datasets
        self.target_transform = (target_transform
                                 if target_transform is not None else
                                 lambda x: x)
        # `cumulative_sizes[k]` = total items in datasets `0..k`
        self.cumulative_sizes = np.cumsum([len(ds) for ds in self.datasets])

    def __getitem__(self, idxs):
        """
        idxs : tuple / list
            Indices of windows and targets to return (concatenated).
            The target output can be modified on the fly by the
            ``target_transform`` parameter.

        Returns `(X, y)`: `X` stacks the windows along a new axis 0; `y` is
        `target_transform(np.array(targets))`.
        """
        X, y = [], []
        for idx in idxs:
            X_i, y_i = self._getitem(idx)
            X.append(X_i)
            y.append(y_i)
        return np.stack(X, axis=0), self.target_transform(np.array(y))

    def _getitem(self, idx):
        """Map global index `idx` to `(dataset, local index)`; return item."""
        if not 0 <= idx < len(self):
            raise IndexError(
                f"index {idx} out of range for dataset of length {len(self)}")
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        sample_idx = (idx if dataset_idx == 0 else
                      idx - self.cumulative_sizes[dataset_idx - 1])
        return self.datasets[dataset_idx][sample_idx]

    def __len__(self):
        # plain int (not a NumPy scalar); 0 when there are no datasets
        return (int(self.cumulative_sizes[-1]) if len(self.cumulative_sizes)
                else 0)
#%% ########################################################################## | |
# Creating model | |
# -------------- | |
class SleepStagerChambon2018(nn.Module):
    """Feature extractor only (no classification head).

    Maps a batch of EEG windows `(batch_size, n_chans, n_times)` to flattened
    features `(batch_size, len_last_layer)`.

    Parameters
    ----------
    n_chans : int
        Number of EEG channels; must be > 1 (a spatial conv is always used).
    n_outputs : int
        Stored only; the head that produces outputs is built elsewhere.
    n_times : int
        Samples per window.
    sfreq : float
        Sampling frequency in Hz; sets conv/pool/pad sizes.
    n_conv_chs : int
        Channels of the temporal convolutions.
    apply_batch_norm : bool
        Insert `BatchNorm2d` after each temporal convolution.
    """

    def __init__(self, n_chans, n_outputs, n_times, sfreq,
                 n_conv_chs=8, apply_batch_norm=False):
        super().__init__()
        self.n_chans = n_chans
        self.n_outputs = n_outputs
        self.n_times = n_times
        self.n_conv_chs = n_conv_chs
        self.apply_batch_norm = apply_batch_norm
        if self.n_chans <= 1:
            raise ValueError(f"`n_chans` must be > 1, got {n_chans}")

        # handle params: kernel/pool/pad lengths given in seconds -> samples
        time_conv_size_s = 0.5
        max_pool_size_s = time_conv_size_s / 4
        pad_size_s = time_conv_size_s / 2
        time_conv_size, max_pool_size, pad_size = (
            int(np.ceil(x * sfreq)) for x in
            (time_conv_size_s, max_pool_size_s, pad_size_s)
        )

        # handle certain layers
        self.spatial_conv = nn.Conv2d(1, self.n_chans, (self.n_chans, 1))
        batch_norm = (nn.BatchNorm2d if apply_batch_norm else
                      nn.Identity)

        # make feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(
                1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
            batch_norm(n_conv_chs),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
            nn.Conv2d(
                n_conv_chs, n_conv_chs, (1, time_conv_size),
                padding=(0, pad_size)),
            batch_norm(n_conv_chs),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
            nn.Flatten(),
        )

        # length of last layer (for later), via a dummy forward pass. Use
        # zeros rather than uninitialized memory, and eval mode, so BatchNorm
        # layers (if any) don't fold the dummy batch into their running stats.
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 1, self.n_chans, self.n_times)
            self.len_last_layer = self.feature_extractor(dummy).numel()
        self.feature_extractor.train(was_training)

    def forward(self, x):
        """x: batch of EEG windows of shape (batch_size, n_channels, n_times)"""
        x = x.unsqueeze(1)          # (B, 1, C, T)
        x = self.spatial_conv(x)    # (B, C, 1, T)
        x = x.transpose(1, 2)       # (B, 1, C, T)
        return self.feature_extractor(x)
class TimeDistributed(nn.Module):
    """Apply `module` to every window of a sequence independently.

    `(batch_size, seq_len, n_channels, n_times)` ->
    `(batch_size, seq_len, output_size)`

    Useful with sequence-to-prediction models (e.g. a sleep stager which maps
    a sequence of consecutive windows to one label).
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """
        x: sequence of windows of shape
            (batch_size, seq_len, n_channels, n_times)
        Returns output of shape
            (batch_size, seq_len, output_size)
        """
        b, s, c, t = x.shape
        # fold sequence into batch; `reshape` (not `view`) so non-contiguous
        # inputs work, copying only when needed
        out = self.module(x.reshape(b * s, c, t))
        return out.reshape(b, s, -1)
#%% ########################################################################## | |
# Creating train loop | |
# ------------------- | |
from skorch.helper import predefined_split
from skorch.classifier import NeuralNetClassifier
from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
from skorch.utils import noop, train_loss_score, valid_loss_score


class EEGClassifier(NeuralNetClassifier):
    """
    Is `NeuralNetClassifier`, with `_default_callbacks` overridden.
    All arguments are passed straight into `NeuralNetClassifier`.

    Note, `NeuralNetClassifier` doesn't assume softmax activation and calls
    the loss function directly (without applying e.g. log).

    Parameter note: `iterator_train__shuffle` (default True) defines whether
    the train dataset is shuffled. `skorch` does not shuffle the train
    dataset by default; this class overrides that default.
    `iterator_train__drop_last` likewise defaults to True here.
    """

    def __init__(
        self,
        module,
        criterion=None,
        callbacks=None,
        iterator_train__shuffle=True,
        iterator_train__drop_last=True,
        **kwargs
    ):
        super().__init__(
            module,
            criterion=criterion,
            callbacks=callbacks,
            iterator_train__shuffle=iterator_train__shuffle,
            iterator_train__drop_last=iterator_train__drop_last,
            **kwargs,
        )

    @property
    def _default_callbacks(self):
        """Timer, batch-level train/valid loss, log printer, valid accuracy."""
        return [
            ("epoch_timer", EpochTimer()),
            (
                "train_loss",
                BatchScoring(
                    train_loss_score,
                    name="train_loss",
                    on_train=True,
                    target_extractor=noop,
                ),
            ),
            (
                "valid_loss",
                BatchScoring(
                    valid_loss_score, name="valid_loss", target_extractor=noop,
                ),
            ),
            ("print_log", PrintLog()),
            (
                "valid_acc",
                EpochScoring(
                    "accuracy",
                    name="valid_acc",
                    lower_is_better=False,
                )
            )
        ]
        # `skorch` default is
        # return [
        #     ('epoch_timer', EpochTimer()),
        #     ('train_loss', PassthroughScoring(
        #         name='train_loss',
        #         on_train=True,
        #     )),
        #     ('valid_loss', PassthroughScoring(
        #         name='valid_loss',
        #     )),
        #     ('print_log', PrintLog()),
        # ]
        # this unites `_EEGNeuralNet._default_callbacks` and
        # `EEGClassifier._default_callbacks` (latter excluding the fact that it
        # appends to former)

    # Excluded inherited classes explanation ---------------------------------
    # _EEGNeuralNet:
    #   Running original script with and without this class changed nothing.
    #   Inspecting the code, it appears to handle certain configurations that
    #   aren't used here (e.g. "cropping").

    # Excluded arguments explanation -----------------------------------------
    # aggregate_predictions:
    #   Was only used in `predict_proba`, which was dropped

    # Excluded methods explanation -------------------------------------------
    # get_iterator:
    #   only does something via `ThrowAwayIndexLoader`, which only does
    #   something if iterator returns x.ndim==3, which doesn't happen here
    # predict_proba:
    #   only does something if `cropped=True`, which isn't the case here
    # get_loss:
    #   only does something if `isinstance(self.criterion_, torch.nn.NLLLoss)`,
    #   which isn't the case here
    # predict:
    #   completely identical to inherited class's definition, is likely
    #   redefined for docs clarity or future changes
    # predict_trials:
    #   meant to be used with `cropped=True`, which isn't the case here
    # _get_n_outputs:
    #   unused in `clf.fit()`, checked by inserting `1/0` here

    # Excluded attributes explanation ----------------------------------------
    # _last_window_inds_:
    #   TL;DR unused
#%% ########################################################################## | |
# EXECUTION | |
# ********* | |
#%% ########################################################################## | |
# Convert and save data as numpy arrays | |
# ------------------------------------- | |
# handle configs
if data_savedir is None:
    data_savedir = os.getcwd()

# set excluded channels
exclude_chs = ('EOG horizontal', 'Resp oro-nasal', 'EMG submental',
               'Temp rectal', 'Event marker')
# merge stages 3 and 4 following AASM standards
mapping = {
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,
    'Sleep stage R': 4
}
# sampling freq, obtained via `raw.info['sfreq']`
# (hard-coded here for cleaner code)
sfreq = 100

# Fetch paths
paths = fetch_data(subject_ids, recording=recording_ids, on_missing='warn')

# For the rest of this script, generalize the original script's example to
# any `subject_ids` and `recording_ids`. Ids are parsed from each PSG file
# name ('SC4001E0-PSG.edf': chars 3-4 = subject, char 5 = recording) rather
# than zipping `paths` with the full `subject_ids x recording_ids` product:
# if `fetch_data` skips a missing recording (it's told to 'warn' rather than
# raise), such a zip would silently pair ids with the wrong files.
ipaths = {}
for p in paths:
    basename = os.path.basename(p[0])
    ipaths[(int(basename[3:5]), int(basename[5]))] = p
# ids of the recordings actually fetched, in `fetch_data` order
ids = list(ipaths)

# Process and save each recording, remembering where it was saved
isavepaths = {}
for id_, p in ipaths.items():
    _, isavepaths[id_] = process_path(p, exclude_chs, data_savedir)
#%% ########################################################################## | |
# Create windows | |
# -------------- | |
# Cut each saved recording into labeled 30-s windows, then standardize it
iwindows_ds = {}
for id_, p in isavepaths.items():
    windows_ds = _create_windows_from_events(p, mapping, sfreq)
    _preprocess_windows(windows_ds)  # modifies `windows_ds.data`
    iwindows_ds[id_] = windows_ds
#%% ########################################################################## | |
# Create dataset objects, make train-test split | |
# --------------------------------------------- | |
# split by subject, so train and validation have different subjects:
# every other entry of `subject_ids` goes to train, the rest to validation
train_subjects = set(subject_ids[::2])
valid_subjects = set(subject_ids[1::2])
# iterate the recordings that were actually processed (keys of `iwindows_ds`)
# rather than every requested (subject, recording) pair, so a recording that
# wasn't fetched can't cause a KeyError
split_ids = dict(
    train=[id_ for id_ in iwindows_ds if id_[0] in train_subjects],
    valid=[id_ for id_ in iwindows_ds if id_[0] in valid_subjects],
)
for split_name, split in split_ids.items():
    if not split:
        raise ValueError(
            f"no recordings ended up in the {split_name!r} split; check "
            "`subject_ids` / `recording_ids`")

train_set = ConcatDataset([iwindows_ds[id_] for id_ in split_ids['train']])
valid_set = ConcatDataset([iwindows_ds[id_] for id_ in split_ids['valid']])

# make samplers: non-overlapping sequences of 3 consecutive windows
n_windows = 3
n_windows_stride = 3
train_sampler = Sampler(
    [windows_ds.i_start_in_trial for windows_ds in train_set.datasets],
    n_windows, n_windows_stride, randomize=True
)
valid_sampler = Sampler(
    [windows_ds.i_start_in_trial for windows_ds in valid_set.datasets],
    n_windows, n_windows_stride
)
#%% Make label transformer
# Label each sequence of windows by one of its windows' labels
def get_center_label(x):
    """Return the label at index `ceil(len(x) / 2)` of the label sequence `x`.

    Scalars are returned unchanged. Sequences of length <= 1 are returned
    as-is (not unwrapped).

    NOTE: for odd lengths this is one past the true center (e.g. index 2 of
    3 windows). It's kept as-is, presumably to match braindecode's
    implementation, since this script aims to reproduce its results;
    `x[len(x) // 2]` would be the actual center.
    """
    if isinstance(x, (int, np.integer)):
        return x
    return x[np.ceil(len(x) / 2).astype(int)] if len(x) > 1 else x
train_set.target_transform = get_center_label
valid_set.target_transform = get_center_label

#%% Compute class weights
# label of every training sequence the sampler yields
y_train = [train_set[idx][1] for idx in train_sampler]
# replicate `sklearn.utils.compute_class_weight` for `class_weight='balanced'`:
#   weight_c = n_samples / (n_classes * count_c)
classes, class_counts = np.unique(y_train, return_counts=True)
class_weights = len(y_train) / (len(classes) * class_counts)
#%% ########################################################################## | |
# Create model | |
# ------------ | |
# check if GPU is available, assuming the user wants it
cuda = use_gpu and torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    # faster but lowers reproducibility
    torch.backends.cudnn.benchmark = True

# set seeds for reproducibility (before the model is built, so its weight
# initialization is seeded)
seed = 31
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed_all(seed)
# Reproducibility caveats:
#  - More info on reproducibility in torch:
#    https://pytorch.org/docs/stable/notes/randomness.html
#  - In some cases, may need to set `PYTHONHASHSEED` env var before running script:
#    https://forums.fast.ai/t/solved-reproducibility-where-is-the-randomness-coming-in/31628/14
#  - `torch.use_deterministic_algorithms(True)` isn't used, also plays a role

# number of distinct labels in `mapping` (stages 3 and 4 share one -> 5)
n_classes = len(set(mapping.values()))
# Extract number of channels and time steps from dataset
# (`train_set[(0,)]` -> (X of shape (1, n_channels, n_times), y))
n_channels, input_size_samples = train_set[(0,)][0][0].shape

feat_extractor = SleepStagerChambon2018(
    n_channels,
    n_outputs=n_classes,
    n_times=input_size_samples,
    sfreq=sfreq,
)
model = nn.Sequential(
    TimeDistributed(feat_extractor),  # apply model on each 30-s window
    nn.Sequential(  # apply linear layer on concatenated feature vectors
        nn.Flatten(start_dim=1),
        nn.Dropout(0.5),
        nn.Linear(feat_extractor.len_last_layer * n_windows, n_classes)
    )
)
# no-op on CPU; moves parameters to the GPU otherwise
model = model.to(device)
#%% ########################################################################## | |
# Create train loop | |
# ----------------- | |
lr = 1e-3
batch_size = 32
n_epochs = 10

# balanced accuracy on the train and validation sets, reported every epoch
train_bal_acc = EpochScoring(
    scoring='balanced_accuracy', on_train=True, name='train_bal_acc',
    lower_is_better=False)
valid_bal_acc = EpochScoring(
    scoring='balanced_accuracy', on_train=False, name='valid_bal_acc',
    lower_is_better=False)
callbacks = [
    ('train_bal_acc', train_bal_acc),
    ('valid_bal_acc', valid_bal_acc)
]

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    # class-balanced loss weights, on the same device as the model
    criterion__weight=torch.Tensor(class_weights).to(device),
    optimizer=torch.optim.Adam,
    # the sampler decides the (randomized) order, so disable `shuffle`
    iterator_train__shuffle=False,
    iterator_train__sampler=train_sampler,
    iterator_valid__sampler=valid_sampler,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    batch_size=batch_size,
    callbacks=callbacks,
    device=device,
    classes=np.unique(y_train),
)
#%% ########################################################################## | |
# Run training | |
# ------------ | |
# Model training for a specified number of epochs. `y` is None as it is already
# supplied in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)

#%% ##########################################################################
# The rest of the code (e.g. plotting) is the same as in the original script.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment