Skip to content

Instantly share code, notes, and snippets.

@OverLordGoldDragon
Created May 19, 2024 09:28
Show Gist options
  • Save OverLordGoldDragon/b6709fd266929f90c7979fb1d5635c4b to your computer and use it in GitHub Desktop.
Save OverLordGoldDragon/b6709fd266929f90c7979fb1d5635c4b to your computer and use it in GitHub Desktop.
PhysioNet, edf -> model(x, y)
# -*- coding: utf-8 -*-
"""
Minimal-code rewriting of BrainDecode's
https://braindecode.org/stable/auto_examples/applied_examples/plot_sleep_staging_chambon2018.html
https://github.com/braindecode/braindecode/blob/master/examples/applied_examples/plot_sleep_staging_chambon2018.py
It
- Eliminates all (explicit) dependence on `braindecode`
- Minimizes dependence on `mne`
My motivation was to have understandable code that decodes PhysioNet into
`(x, y)` that can be fed to a model. I found the original example too
high-level, and its source code too dense to manipulate for my purposes.
Rewritten code is
- heavily cut down in total number of lines
- tested to completely reproduce the original BrainDecode script's results
(note, requires `torch.use_deterministic_algorithms(True)`).
Caveat: that's with a preprocessing step (lowpass filtering) omitted,
just comment that out in original source (or reimplement it here).
- includes some fixes and extensions of functionality
- commented, explaining/justifying omitting of code with respect to
`braindecode` source code (meant to be read alongside)
- split into "helpers" (function definitions) and "execution"; former
could be moved into its own file and imported
- a rewriting of v0.8.0,
https://github.com/braindecode/braindecode/tree/v0.8
Tip: download PhysioNet from Google Cloud bucket, then set `data_loaddir`, it
was x10 faster for me than downloading directly from physionet.org (what the
original script does).
Disclaimer, while it was free for me, it says it's paid, so I can't guarantee
that. Instructions at https://physionet.org/content/sleep-edfx/1.0.0/
Note: example uses different `subject_ids` and `recording_ids`.
"""
# ############################################################################
# User config
# -----------
# Directory where the PhysioNet source data is already stored, if it exists.
# When not None, it is exported as the `PHYSIONET_SLEEP_PATH` environment
# variable before `mne` is imported; when None, the data is fetched from
# scratch.
data_loaddir = None
# Directory the processed `.npz` files are saved to; None falls back to the
# current working directory.
data_savedir = None
# Subject and recording IDs, passed to `fetch_data`
subject_ids = [0, 1, 2, 3]
recording_ids = [1, 2]
# Use the GPU if one is available
use_gpu = True
#%% ##########################################################################
# Imports
# -------
# should set env var before running other imports
import os
if data_loaddir is not None:
os.environ['PHYSIONET_SLEEP_PATH'] = data_loaddir
import random
import bisect
import numpy as np
import torch
import torch.nn as nn
# torch.use_deterministic_algorithms(True)
import mne
from mne.datasets.sleep_physionet.age import fetch_data
from sklearn.preprocessing import scale as standard_scale
##############################################################################
# HELPER FUNCTIONS
# ****************
#%% ##########################################################################
# Loading & saving data
# ---------------------
# Files are named e.g.
#
# SC4001E0-PSG.edf
# 01234567
#
# where
#
# `34` = `subject_id` (i.e. 0)
# `5` = `recording_id` (i.e. 1)
#
# and,
#
# SC4001EC-Hypnogram.edf
#
# where
#
# `PSG` = data
# `Hypnogram` = metadata (including labels)
#
# Notes:
#
# - "Data" includes stuff other than EEG, exclude via `exclude_chs`.
# - `p = paths[0]` is a pair of paths, where `p[0]` = PSG, `p[1]` = Hypnogram.
#
# define a function to reuse later
def process_path(p, exclude_chs, data_savedir):
    """Convert one (PSG, Hypnogram) file pair to arrays and save as `.npz`.

    p : pair of paths, `p[0]` = PSG file, `p[1]` = Hypnogram file.
    exclude_chs : channel names to leave out when reading the PSG file.
    data_savedir : directory the `.npz` file is written into.

    Returns `(data, savepath)`, i.e. the dict of arrays that was saved and
    the path it was saved to.
    """
    # subject / session numbers are returned by the loader but unused here
    raw, annots, _, _ = _load_data(p, exclude_chs)
    raw = _preprocess(_drop_unlabeled_data(raw, annots))

    # The window-processing code only needs the signal itself plus the
    # onset / description / duration of every annotation; `_first_time` is
    # kept for `braindecode_ver = True`.
    annotations = raw.annotations
    data = dict(
        data=raw._data,
        onset=annotations.onset,
        description=annotations.description,
        duration=annotations.duration,
        _first_time=raw._first_time,
    )

    # e.g. `SC4001E0-PSG.edf` -> `SC4001E0.npz`
    psg_name = os.path.basename(p[0])
    savename = psg_name.replace('-PSG', '').replace('.edf', '.npz')
    savepath = os.path.join(data_savedir, savename)
    np.savez(savepath, **data)
    return data, savepath
def _load_data(p, exclude_chs):
    """Read one recording and its annotations from disk.

    p : pair of paths, `(PSG path, Hypnogram path)`.
    exclude_chs : channel names passed as `exclude` to the EDF reader.

    Returns `(raw, annots, subj_num, sess_num)`; the two numbers are parsed
    from fixed character positions of the PSG file name.
    """
    psg_path, hypnogram_path = p

    # `preload=True` loads the data into RAM
    raw = mne.io.read_raw_edf(psg_path, preload=True, exclude=exclude_chs)
    annots = mne.read_annotations(hypnogram_path)

    # e.g. `SC4001E0-PSG.edf`: characters 3-4 -> subject, character 5 -> recording
    fname = os.path.basename(psg_path)
    subj_num, sess_num = int(fname[3:5]), int(fname[5])
    return raw, annots, subj_num, sess_num
def _drop_unlabeled_data(raw, annots):
    """Attach `annots` to `raw`, crop `raw` to the labeled part, rename channels.

    raw : recording object; modified in place and also returned.
    annots : annotations read from the Hypnogram file. Each `a = annots[i]`
        has `a['onset']` and `a['duration']`.
    """
    # Annotations are cropped (onset / duration adjusted) so that they stay
    # within `tmin = 0` and `tmax = raw.times[-1] + 1 / sfreq`, inclusive.
    raw.set_annotations(annots, emit_warning=False)

    # Labels span time segments of the data, e.g. `x[:20000]` may be
    # "Sleep stage W" (wake) and `x[20000:25000]` "Sleep stage 1".
    # "Sleep stage ?" is unlabeled.
    braindecode_ver = True
    if braindecode_ver:
        sleep_suffixes = ['1', '2', '3', '4', 'R']
        mask = [desc[-1] in sleep_suffixes for desc in annots.description]
    else:
        mask = [desc[-1] != '?' for desc in annots.description]
        # this variant assumes a single unlabeled stage that comes last;
        # verify both
        assert mask.count(False) == 1, mask
        assert not mask[-1], mask
    sleep_event_inds = np.where(mask)[0]

    # `tmax` is the onset of the last selected stage plus its duration, minus
    # one sample (`dT`): otherwise we'd land on the first timestamp of the
    # following stage, and `crop` is inclusive on `tmax`.
    dT = 1 / raw.info["sfreq"]
    first_stage = annots[sleep_event_inds[0]]
    last_stage = annots[sleep_event_inds[-1]]
    tmin = first_stage['onset']
    tmax = last_stage['onset'] + last_stage['duration'] - dT
    if braindecode_ver:
        # keep this much wake time on either side of the sleep period
        crop_wake_mins = 30
        tmin = tmin - crop_wake_mins * 60
        tmax = tmax + crop_wake_mins * 60

    # Internally, `crop` converts `tmin`, `tmax` to indices and does something
    # like `x = x[tmin:tmax]`. It also crops the labels via `set_annotations`:
    #   - the labels' `tmin` comes from `_first_time`, which `crop` updates
    #     from the `tmin` given here (see the `_first_time` property)
    #   - the labels' `tmax` comes from `times[-1]` (+ dT), which `crop`
    #     updates from the `tmax` given here (see the `times` property)
    #   - annotations entirely outside `[tmin, tmax]` are dropped; here that
    #     amounts to dropping "Sleep stage ?".
    raw.crop(tmin=max(tmin, raw.times[0]),
             tmax=min(tmax, raw.times[-1]))

    # Strip the 'EEG ' prefix from channel names
    renames = {nm: nm.replace('EEG ', '') for nm in raw.ch_names}
    raw.rename_channels(renames)
    return raw
def _preprocess(raw):
    """Rescale the signal from volts to microvolts and return `raw`.

    The library routine this replaces just updates `raw._data`, with extra
    handling for multiprocessing, an output-shape check, and loading `_data`
    if it isn't loaded yet (here it already is, via `preload=True`).
    """
    uv_per_v = 1e6  # V -> uV
    scaled = raw._data * uv_per_v
    raw._data = scaled
    return raw
#%% ##########################################################################
# Creating windows
# ----------------
def _create_windows_from_events(p, mapping, sfreq, window_size_samples=3000,
                                window_stride_samples=3000,
                                drop_last_window=False):
    """Load one processed recording and cut it into labeled windows.

    p : path to an `.npz` file holding the arrays `'data'`, `'description'`,
        `'onset'`, `'duration'` and `'_first_time'`.
    mapping : dict, annotation description -> integer label. Annotations
        whose description is not a key are ignored.
    sfreq : sampling frequency of the stored data, in Hz.
    window_size_samples, window_stride_samples, drop_last_window :
        forwarded to `_compute_window_inds`. The defaults are the values that
        were previously hard-coded in this function.

    Returns a `WindowsDataset` over the recording.
    """
    fields = ('data', 'description', 'onset', 'duration', '_first_time')
    # `np.load` on an `.npz` returns an archive object that holds the file
    # open until it is closed; read everything inside a `with` so the handle
    # is released right away.
    # `a_` for `annotations_`
    with np.load(p) as d:
        data, a_description, a_onset, a_duration, _first_time = [
            d[nm] for nm in fields]

    events = _events_from_annotations(
        a_description, a_onset, _first_time, mapping, sfreq)
    # the lib method also returns `event_id_`, which is redundant for us
    #   - (it is a copy of `mapping`, unless `mapping` has some keys that
    #     `annotations.description` doesn't, but we then only use `event_id_`
    #     to check whether what's in `annotations.description` is in
    #     `event_id_`, which is circular and same as checking directly
    #     against `mapping`)

    # Onsets are relative to the beginning of the recording
    onsets = events[:, 0]
    filtered_durations = np.array([
        dur for dur, desc in zip(a_duration, a_description)
        if desc in mapping
    ])
    stops = onsets + (filtered_durations * sfreq).astype(int)

    # sanity check; note, `stops` is used exclusively, i.e. `start:stop`.
    # The library's full assert is `stops[-1] <= raw.first_samp + raw.n_times`;
    # `raw.first_samp` isn't needed since `_events_from_annotations` doesn't
    # add it to the onsets, and `raw.n_times == raw._data.shape[-1]`. For the
    # same reason there's no `onsets -= raw.first_samp` / `stops -= ...` here.
    assert stops[-1] <= data.shape[-1]

    # generate windows
    i_trials, starts, stops = _compute_window_inds(
        onsets, stops, window_size_samples, window_stride_samples,
        drop_last_window)

    # every window takes the label of the event (trial) it was cut from
    description = events[:, -1]
    description_windows = np.array(
        [description[i_trial] for i_trial in i_trials])

    windows_ds = WindowsDataset(
        data,
        target=description_windows,
        i_start_in_trial=starts,
        i_stop_in_trial=stops
    )
    return windows_ds
def _events_from_annotations(description, onset, _first_time, mapping, fs):
    """Turn annotations into an `(n_events, 3)` integer events array.

    description : 1D sequence of annotation descriptions (strings).
    onset : 1D sequence of annotation start times, in seconds. Not modified.
    _first_time : time of the first stored sample, in seconds; subtracted
        from `onset` so that sample indices are relative to the stored data.
    mapping : dict, description -> integer label; annotations whose
        description is not a key are dropped.
    fs : sampling frequency, in Hz.

    Returns an int array with columns `[sample index, 0, label]`.
    """
    description = np.asarray(description)

    # minimally implements `_select_annotations_based_on_description`
    event_sel = [i for i, d in enumerate(description) if d in mapping]

    # Convert onsets to sample indices. This minimally implements
    #
    #     inds = raw.time_as_index(times=annotations.onset, use_rounding=True,
    #                              origin=annotations.orig_time)
    #
    # Internals:
    #   - `annotations.onset` = timestamps, in seconds, of label start times;
    #     `len(annotations.onset) == len(labels)`.
    #   - `annotations.orig_time` = `raw.info["meas_date"]`, as long as
    #     `annotations = mne.read_annotations(path)` is used.
    #   - `raw._first_time = raw.first_samp / raw.info['sfreq']`, which is 0
    #     unless `raw.crop()` was used with `tmin != 0`.
    #   - the shift applied inside `raw.time_as_index` is
    #     `origin - first_samp_in_abs_time`
    #       <=> `meas_date - (meas_date + raw._first_time)`
    #       <=> `-raw._first_time`,
    #     so `origin` itself is never needed.
    #
    # Build a new array rather than `+=`-ing: `onset` belongs to the caller
    # and must not change.
    times = np.asarray(onset, dtype=float) - _first_time

    # `raw.times[0]` is always zero in our case (RawEDF loaded the way we
    # have), so `(times - raw.times[0]) * fs` reduces to `times * fs`
    inds = np.round(times * fs).astype(int)

    # The library then does `inds += raw.first_samp` (when
    # `annotations.orig_time is not None`, which is the case here). That is
    # deliberately skipped, so `raw.first_samp` needn't be stored or
    # subtracted again later.

    # `description` -> numeric values based on `mapping`, for the selected
    # annotations only. E.g. with
    #     mapping == {'Sleep stage W': 0, 'Sleep stage 1': 1}
    #     description == ['Sleep stage 1', 'Sleep stage W', 'Sleep stage W']
    # we get `values == [1, 0, 0]`.
    values = [mapping[kk] for kk in description[event_sel]]
    inds = inds[event_sel]

    # concatenate into `(n_events, 3)` and cast to int
    events = np.c_[inds, np.zeros(len(inds)), values].astype(int)
    return events
def _compute_window_inds(starts, stops, size, stride, drop_last_window):
    """Compute window start / stop indices within each trial.

    starts, stops : arrays of trial start and (exclusive) stop indices.
    size : window length, in samples.
    stride : step between consecutive window starts, in samples.
    drop_last_window : if False and the last regular window of a trial
        doesn't end exactly at the trial stop, one more window ending at the
        trial stop is added (it overlaps the previous one).

    Returns `(i_trials, window_starts, window_stops)`: for every window, the
    index of its trial, its start and its stop (`start + size`).
    """
    # every trial must be able to hold at least one window
    assert not any(size > (stops-starts))

    i_trials, window_starts = [], []
    for trial, (lo, hi) in enumerate(zip(starts, stops)):
        # candidate starts on the stride grid; keep those whose window fits
        # entirely within the trial
        fitting = [s for s in np.arange(lo, hi, stride) if s + size <= hi]
        window_starts.extend(fitting)
        i_trials.extend([trial] * len(fitting))

        # optionally add a final window flush with the end of the trial
        if not drop_last_window and window_starts[-1] + size != hi:
            window_starts.append(hi - size)
            i_trials.append(trial)

    # window stops follow from window starts (rather than from trial stops)
    window_stops = np.array(window_starts) + size
    assert len(i_trials) == len(window_starts) == len(window_stops)
    return i_trials, window_starts, window_stops
def _preprocess_windows(windows_ds):
    """Standardize `windows_ds.data` along axis 1 and store it as float32.

    Works in place on `windows_ds` (its `data` attribute is replaced) and
    returns None.
    """
    scaled = standard_scale(windows_ds.data, axis=1)
    # `astype` already returns a fresh array, so no separate `.copy()` of the
    # whole recording is needed first
    windows_ds.data = scaled.astype('float32')
class WindowsDataset():
    """Windows over a single recording.

    data : array indexed as `data[:, start:stop]`.
    target : one label per window; stored as int64 in `y`.
    i_start_in_trial, i_stop_in_trial : per-window start and (exclusive)
        stop indices along the last axis of `data`.

    Indexing with a window index returns `(X, y)`.
    """

    def __init__(self, data, target, i_start_in_trial, i_stop_in_trial):
        self.data = data
        # skorch expects int64
        self.y = np.asarray(target, dtype='int64')
        self.i_start_in_trial = i_start_in_trial
        # row `i` is `[start, stop]` of window `i`
        self.inds = np.c_[i_start_in_trial, i_stop_in_trial]

    def __getitem__(self, index):
        lo, hi = self.inds[index]
        window = self.data[:, lo:hi]
        return window, self.y[index]

    def __len__(self):
        return len(self.y)
#%% ##########################################################################
# Creating dataset objects
# ------------------------
class Sampler():
    """Yield index tuples of `n_windows` consecutive windows.

    i_start_in_trial_all : one sequence per recording; only its length (the
        recording's number of windows) is used.
    n_windows : number of consecutive windows per yielded sequence.
    n_windows_stride : step, in windows, between the starts of consecutive
        sequences within a recording.
    randomize : if True, sequences are yielded in a random order (drawn from
        `np.random` on every iteration).

    Sequences never straddle two recordings. Indices are global, i.e. they
    count windows across all recordings in the order given.
    """

    def __init__(self, i_start_in_trial_all, n_windows, n_windows_stride,
                 randomize=False):
        self.n_windows = n_windows
        self.n_windows_stride = n_windows_stride
        self.randomize = randomize

        # braindecode applies `groupby` by `'subject'`, `'recording'` upon a
        # dataframe, then resets the index and operates on those indices.
        # Equivalently: generate the sequence starts within each recording's
        # own block of global indices, then concatenate.
        #
        # The last `n_windows - 1` windows of a recording can't start a
        # sequence. With `n_windows == 1` nothing is excluded, so the slice
        # stop must be `None`: a stop of `0` would select nothing.
        end_offset = 1 - n_windows if n_windows > 1 else None

        start_inds = []
        offset = 0
        for length in map(len, i_start_in_trial_all):
            # running offset (instead of "last index of the previous
            # recording") so recordings with zero windows are handled
            idxs = range(offset, offset + length)
            start_inds.extend(idxs[:end_offset:n_windows_stride])
            offset += length
        self.start_inds = np.array(start_inds, dtype=int)

    def __len__(self):
        return len(self.start_inds)

    def __iter__(self):
        if self.randomize:
            start_inds = np.random.permutation(self.start_inds)
        else:
            start_inds = self.start_inds
        for start_ind in start_inds:
            start_ind = int(start_ind)
            yield tuple(range(start_ind, start_ind + self.n_windows))
class ConcatDataset(torch.utils.data.Dataset):
    """Concatenation of several datasets, indexed by tuples of indices.

    Merges the `BaseConcatDataset` and `ConcatDataset` classes. `skorch`
    requires this to be an instance of `Dataset`.

    datasets : sequence of datasets supporting `len()` and integer indexing
        that returns `(X, y)`.
    target_transform : callable applied to the array of targets of every
        returned sequence; defaults to the identity. May also be assigned
        later through the attribute of the same name.
    """

    def __init__(self, datasets, target_transform=None):
        self.datasets = datasets
        if target_transform is None:
            # for script readability, the real transform is assigned later
            target_transform = lambda x: x
        self.target_transform = target_transform
        # cumulative_sizes[i] = number of samples in datasets[0..i]
        self.cumulative_sizes = np.cumsum([len(ds) for ds in self.datasets])

    def __getitem__(self, idxs):
        """
        idxs : tuple / list
            Indices of windows and targets to return (concatenated).

        Returns `(X, y)`, where `X` stacks the windows along a new first
        axis and `y` is `target_transform` applied to the array of targets.
        """
        pairs = [self._getitem(idx) for idx in idxs]
        X = np.stack([pair[0] for pair in pairs], axis=0)
        targets = np.array([pair[1] for pair in pairs])
        return X, self.target_transform(targets)

    def _getitem(self, idx):
        """Return the single `(X, y)` at global index `idx`."""
        assert 0 <= idx < len(self)
        # first dataset whose cumulative size exceeds `idx`
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    def __len__(self):
        return self.cumulative_sizes[-1]
#%% ##########################################################################
# Creating model
# --------------
class SleepStagerChambon2018(nn.Module):
    """Feature extractor only.

    n_chans : number of input channels; must be greater than 1.
    n_outputs : stored on the instance, not otherwise used by this class.
    n_times : number of time samples per window.
    sfreq : sampling frequency, in Hz; converts the layer sizes, which are
        defined in seconds, into samples.
    n_conv_chs : number of channels of the two temporal convolutions.
    apply_batch_norm : if True, batch normalization follows each temporal
        convolution.

    `len_last_layer` holds the number of features produced per window.
    """

    def __init__(self, n_chans, n_outputs, n_times, sfreq,
                 n_conv_chs=8, apply_batch_norm=False):
        super().__init__()
        self.n_chans = n_chans
        self.n_outputs = n_outputs
        self.n_times = n_times
        self.n_conv_chs = n_conv_chs
        self.apply_batch_norm = apply_batch_norm
        assert self.n_chans > 1

        # layer sizes: defined in seconds, converted to samples
        time_conv_size_s = 0.5
        max_pool_size_s = time_conv_size_s / 4
        pad_size_s = time_conv_size_s / 2
        time_conv_size, max_pool_size, pad_size = (
            int(np.ceil(x * sfreq)) for x in
            (time_conv_size_s, max_pool_size_s, pad_size_s)
        )

        # NOTE: layers are created in the same order as before (spatial conv,
        # then the two temporal convs), so seeded weight initialization is
        # unchanged.
        self.spatial_conv = nn.Conv2d(1, self.n_chans, (self.n_chans, 1))
        batch_norm = (nn.BatchNorm2d if apply_batch_norm else
                      nn.Identity)

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(
                1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
            batch_norm(n_conv_chs),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
            nn.Conv2d(
                n_conv_chs, n_conv_chs, (1, time_conv_size),
                padding=(0, pad_size)),
            batch_norm(n_conv_chs),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
            nn.Flatten(),
        )

        # length of last layer (for later)
        self.len_last_layer = self._probe_len_last_layer()

    def _probe_len_last_layer(self):
        """Return the feature count for one window via a dummy forward pass."""
        # Probe in eval mode and on zeros. A training-mode pass would update
        # the BatchNorm running statistics (when `apply_batch_norm=True`),
        # and an uninitialized tensor would feed arbitrary values into them.
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        try:
            with torch.no_grad():
                out = self.feature_extractor(
                    torch.zeros(1, 1, self.n_chans, self.n_times))
        finally:
            self.feature_extractor.train(was_training)
        return out.numel()

    def forward(self, x):
        """x: batch of EEG windows of shape (batch_size, n_channels, n_times)"""
        x = x.unsqueeze(1)
        x = self.spatial_conv(x)
        x = x.transpose(1, 2)
        return self.feature_extractor(x)
class TimeDistributed(nn.Module):
    """Run a module on every window of a sequence and gather the outputs:

        `(batch_size, seq_len, n_channels, n_times)` ->
        `(batch_size, seq_len, output_size)`

    Useful with sequence-to-prediction models (e.g. a sleep stager which must
    map a sequence of consecutive windows to the label of the middle window
    in the sequence).
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """
        x: sequence of windows of shape
            (batch_size, seq_len, n_channels, n_times)

        Returns output of shape
            (batch_size, seq_len, output_size)
        """
        batch_size, seq_len, n_channels, n_times = x.shape
        # fold the sequence axis into the batch axis, so the wrapped module
        # sees one window per batch entry
        windows = x.view(batch_size * seq_len, n_channels, n_times)
        features = self.module(windows)
        # unfold back into (batch, sequence, features)
        return features.view(batch_size, seq_len, -1)
#%% ##########################################################################
# Creating train loop
# -------------------
from skorch.helper import predefined_split
from skorch.classifier import NeuralNetClassifier
from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
from skorch.utils import noop, train_loss_score, valid_loss_score
class EEGClassifier(NeuralNetClassifier):
    """
    `NeuralNetClassifier` with `_default_callbacks` overridden.

    All arguments are passed straight into `NeuralNetClassifier`.

    Note, `NeuralNetClassifier` doesn't assume softmax activation and calls
    the loss function directly (without applying e.g. log).

    Parameter note: `iterator_train__shuffle` controls whether the train
    dataset is shuffled. `skorch` doesn't shuffle by default; this class
    defaults it to True.
    """

    def __init__(
        self,
        module,
        criterion=None,
        callbacks=None,
        iterator_train__shuffle=True,
        iterator_train__drop_last=True,
        **kwargs
    ):
        super().__init__(
            module,
            criterion=criterion,
            callbacks=callbacks,
            iterator_train__shuffle=iterator_train__shuffle,
            iterator_train__drop_last=iterator_train__drop_last,
            **kwargs,
        )

    @property
    def _default_callbacks(self):
        # per-batch losses; `noop` as the target extractor
        train_loss = BatchScoring(
            train_loss_score,
            name="train_loss",
            on_train=True,
            target_extractor=noop,
        )
        valid_loss = BatchScoring(
            valid_loss_score, name="valid_loss", target_extractor=noop,
        )
        # per-epoch validation accuracy
        valid_acc = EpochScoring(
            "accuracy",
            name="valid_acc",
            lower_is_better=False,
        )
        return [
            ("epoch_timer", EpochTimer()),
            ("train_loss", train_loss),
            ("valid_loss", valid_loss),
            ("print_log", PrintLog()),
            ("valid_acc", valid_acc),
        ]
# `skorch` default is
# return [
# ('epoch_timer', EpochTimer()),
# ('train_loss', PassthroughScoring(
# name='train_loss',
# on_train=True,
# )),
# ('valid_loss', PassthroughScoring(
# name='valid_loss',
# )),
# ('print_log', PrintLog()),
# ]
# this unites `_EEGNeuralNet._default_callbacks` and
# `EEGClassifier._default_callbacks` (latter excluding the fact that it
# appends to former)
# Excluded inherited classes explanation ---------------------------------
# _EEGNeuralNet:
# Running original script with and without this class changed nothing.
# Inspecting the code, it appears to handle certain configurations that
# aren't used here (e.g. "cropping").
# Excluded arguments explanation -----------------------------------------
# aggregate_predictions:
# Was only used in `predict_proba`, which was dropped
# Excluded methods explanation -------------------------------------------
# get_iterator:
# only does something via `ThrowAwayIndexLoader`, which only does
# something if iterator returns x.ndim==3, which doesn't happen here
# predict_proba:
# only does something if `cropped=True`, which isn't the case here
# get_loss:
# only does something if `isinstance(self.criterion_, torch.nn.NLLLoss)`,
# which isn't the case here
# predict:
# completely identical to inherited class's definition, is likely
# redefined for docs clarity or future changes
# predict_trials:
# meant to be used with `cropped=True`, which isn't the case here
# _get_n_outputs:
# unused in `clf.fit()`, checked by inserting `1/0` here
# Excluded attributes explanation ----------------------------------------
# _last_window_inds_:
# TL;DR unused
#%% ##########################################################################
# EXECUTION
# *********
#%% ##########################################################################
# Convert and save data as numpy arrays
# -------------------------------------
# handle configs
if data_savedir is None:
    data_savedir = os.getcwd()

# set excluded channels
exclude_chs = ('EOG horizontal', 'Resp oro-nasal', 'EMG submental',
               'Temp rectal', 'Event marker')

# merge stages 3 and 4 following AASM standards
mapping = {
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,
    'Sleep stage R': 4
}

# sampling freq, obtained via `raw.info['sfreq']`
# (hard-coded here for cleaner code)
sfreq = 100

# Fetch paths; each `p` in `paths` is a pair, `p[0]` = PSG, `p[1]` = Hypnogram
paths = fetch_data(subject_ids, recording=recording_ids, on_missing='warn')

# For rest of this script, generalize original script's example to any
# `subject_ids` and `recording_ids` by not assuming there's two of former and
# one of latter.
#
# Generate the `(subject_id, recording_id)` ids from the returned file names
# (e.g. `SC4001E0-PSG.edf` -> `(0, 1)`) instead of from
# `subject_ids` x `recording_ids`. With `on_missing='warn'` a missing
# recording is skipped, so `paths` may be shorter than, or ordered differently
# from, the requested combinations, and zipping the two would attach the
# wrong id to a recording.
ids = []
for p in paths:
    psg_name = os.path.basename(p[0])
    ids.append((int(psg_name[3:5]), int(psg_name[5])))

# Map ids to paths
ipaths = dict(zip(ids, paths))

# Store processed data paths
isavepaths = {}
for id_, p in ipaths.items():
    _, psave = process_path(p, exclude_chs, data_savedir)
    isavepaths[id_] = psave
#%% ##########################################################################
# Create windows
# --------------
# One standardized `WindowsDataset` per saved recording, keyed by the same
# id as `isavepaths`
iwindows_ds = {}
for rec_id, savepath in isavepaths.items():
    rec_windows = _create_windows_from_events(savepath, mapping, sfreq)
    _preprocess_windows(rec_windows)  # works in place
    iwindows_ds[rec_id] = rec_windows
#%% ##########################################################################
# Create dataset objects, make train-test split
# ---------------------------------------------
# split by subject, so train and validation have different subjects
# Every other subject goes to train, the rest to validation
train_subjects = subject_ids[::2]
valid_subjects = subject_ids[1::2]
split_ids = {
    'train': [id_ for id_ in ids if id_[0] in train_subjects],
    'valid': [id_ for id_ in ids if id_[0] in valid_subjects],
}
train_set = ConcatDataset([iwindows_ds[id_] for id_ in split_ids['train']])
valid_set = ConcatDataset([iwindows_ds[id_] for id_ in split_ids['valid']])

# make samplers: sequences of `n_windows` consecutive windows, with starts
# `n_windows_stride` windows apart
n_windows = 3
n_windows_stride = 3
train_starts = [ds.i_start_in_trial for ds in train_set.datasets]
valid_starts = [ds.i_start_in_trial for ds in valid_set.datasets]
train_sampler = Sampler(
    train_starts, n_windows, n_windows_stride, randomize=True)
valid_sampler = Sampler(valid_starts, n_windows, n_windows_stride)
#%% Make label transformer
# Use label of center window in the sequence
def get_center_label(x):
    """Return the label of the center window of a sequence of labels.

    x : an integer label (returned as is), or a sequence of labels. A
        sequence of length 1 is returned unchanged; otherwise the element at
        index `len(x) // 2` is returned, which is the middle element for odd
        lengths.

    NOTE: the previous `ceil(len(x) / 2)` index picked the element *after*
    the center for odd lengths (e.g. the last of 3). Even lengths are
    unaffected. This deviates from the original BrainDecode example, which
    uses the `ceil` form.
    """
    # NumPy integer scalars aren't `int` instances; treat them the same
    if isinstance(x, (int, np.integer)):
        return x
    return x[len(x) // 2] if len(x) > 1 else x
for split_set in (train_set, valid_set):
    split_set.target_transform = get_center_label

#%% Compute class weights
# one target per training sequence, in the order the sampler yields them
y_train = [train_set[seq_idxs][1] for seq_idxs in train_sampler]

# replicate `sklearn.utils.compute_class_weight` for `class_weight='balanced'`:
# weight = n_samples / (n_classes * n_samples_in_class)
classes = np.unique(y_train)
class_counts = np.array([y_train.count(c) for c in classes])
class_weights = len(y_train) / (len(classes) * class_counts)
#%% ##########################################################################
# Create model
# ------------
# check if GPU is available, assuming the user wants it
# use the GPU only if the user asked for it and one is available
cuda = use_gpu and torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    # faster but lowers reproducibility
    torch.backends.cudnn.benchmark = True

# set seeds for reproducibility
seed = 31
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
if cuda:
    torch.cuda.manual_seed_all(seed)
# Reproducibility caveats:
#   - More info on reproducibility in torch:
#     https://pytorch.org/docs/stable/notes/randomness.html
#   - In some cases, may need to set `PYTHONHASHSEED` env var before running
#     script:
#     https://forums.fast.ai/t/solved-reproducibility-where-is-the-randomness-coming-in/31628/14
#   - `torch.use_deterministic_algorithms(True)` isn't used, also plays a role

n_classes = 5
# Extract number of channels and time steps from the first window of the
# first training sequence
first_window = train_set[(0,)][0][0]
n_channels, input_size_samples = first_window.shape

feat_extractor = SleepStagerChambon2018(
    n_channels,
    n_outputs=n_classes,
    n_times=input_size_samples,
    sfreq=sfreq,
)
# linear layer applied on the concatenated feature vectors of a sequence
n_features = feat_extractor.len_last_layer * n_windows
classifier_head = nn.Sequential(
    nn.Flatten(start_dim=1),
    nn.Dropout(0.5),
    nn.Linear(n_features, n_classes)
)
model = nn.Sequential(
    TimeDistributed(feat_extractor),  # apply model on each 30-s window
    classifier_head,
)
if cuda:
    model = model.cuda()
#%% ##########################################################################
# Create train loop
# -----------------
lr = 1e-3
batch_size = 32
n_epochs = 10

# balanced accuracy, scored once per epoch on the train and validation data
bal_acc_kwargs = dict(scoring='balanced_accuracy', lower_is_better=False)
train_bal_acc = EpochScoring(
    on_train=True, name='train_bal_acc', **bal_acc_kwargs)
valid_bal_acc = EpochScoring(
    on_train=False, name='valid_bal_acc', **bal_acc_kwargs)
callbacks = [
    ('train_bal_acc', train_bal_acc),
    ('valid_bal_acc', valid_bal_acc)
]

clf_params = dict(
    criterion=torch.nn.CrossEntropyLoss,
    criterion__weight=torch.Tensor(class_weights).to(device),
    optimizer=torch.optim.Adam,
    # ordering is left to the samplers below
    iterator_train__shuffle=False,
    iterator_train__sampler=train_sampler,
    iterator_valid__sampler=valid_sampler,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    batch_size=batch_size,
    callbacks=callbacks,
    device=device,
    classes=np.unique(y_train),
)
clf = EEGClassifier(model, **clf_params)

#%% ##########################################################################
# Run training
# ------------
# Train for the configured number of epochs. `y` is None since the targets
# are already supplied by the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)
#%% ##########################################################################
# The rest of the code (e.g. plotting) is same as in original script.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment