# Created May 19, 2024 09:28
# PhysioNet, edf -> model(x, y)
# -*- coding: utf-8 -*-
"""Minimal-code rewriting of BrainDecode's Sleep Physionet sleep-staging
example (Chambon et al. 2018 architecture):

- Eliminates all (explicit) dependence on `braindecode`.
- Minimizes dependence on `mne`.

My motivation was understandable code that decodes PhysioNet into `(x, y)`
that can be fed to a model. I found the original example too high-level, and
its source code too dense to manipulate for my purposes.

The rewritten code is:
  - heavily cut down in total number of lines;
  - tested to completely reproduce the original BrainDecode script's results
    (note: requires `torch.use_deterministic_algorithms(True)`).
    Caveat: that is with one preprocessing step (lowpass filtering) omitted;
    just comment it out in the original source (or reimplement it here);
  - includes some fixes and extensions of functionality;
  - commented, explaining/justifying the omission of code with respect to the
    `braindecode` source code (meant to be read alongside it);
  - split into "helpers" (function definitions) and "execution"; the former
    could be moved into its own file and imported;
  - a rewriting of `braindecode` v0.8.0.

Tip: download PhysioNet from its Google Cloud bucket, then set `data_loaddir`;
it was ~10x faster for me than downloading directly from PhysioNet (which is
what the original script does).
Disclaimer: while it was free for me, the bucket is listed as paid, so I can't
guarantee that. Instructions: (link missing in this copy — TODO restore URL).

Note: this example uses different `subject_ids` and `recording_ids` than the
original script.
"""
# ############################################################################
# User config
# -----------
# where PhysioNet source data is already stored, if exists; if None, will install
# from scratch
# Directory holding an existing local copy of the PhysioNet sleep data;
# leave as None to install it from scratch.
data_loaddir = None
# Output directory for the processed `.npz` files; None falls back to the
# current working directory.
data_savedir = None
# Which PhysioNet subjects and recordings (nights) to process.
subject_ids = list(range(4))
recording_ids = [1, 2]
# Train on the GPU when one is available.
use_gpu = True
#%% ##########################################################################
# Imports
# -------
# should set env var before running other imports
import os
# Point the PhysioNet sleep-data lookup at the user's local copy; this has to
# happen before `mne` is imported below.
if data_loaddir is not None:
    os.environ.update({'PHYSIONET_SLEEP_PATH': data_loaddir})
import random
import bisect
import numpy as np
import torch
import torch.nn as nn
# torch.use_deterministic_algorithms(True)
import mne
from mne.datasets.sleep_physionet.age import fetch_data
from sklearn.preprocessing import scale as standard_scale
# ****************
#%% ##########################################################################
# Loading & saving data
# ---------------------
# Files are named e.g.
# SC4001E0-PSG.edf
# 01234567
# where
# `34` = `subject_id` (i.e. 0)
# `5` = `recording_id` (i.e. 1)
# and,
# SC4001EC-Hypnogram.edf
# where
# `PSG` = data
# `Hypnogram` = metadata (including labels)
# Notes:
# - "Data" includes stuff other than EEG, exclude via `exclude_chs`.
# - `p = paths[0]` is a pair of paths, where `p[0]` = PSG, `p[1]` = Hypnogram.
# define a function to reuse later
def process_path(p, exclude_chs, data_savedir):
    """Load, clean and save one PhysioNet recording as a `.npz` file.

    Parameters
    ----------
    p : tuple of (str, str)
        `(PSG path, Hypnogram path)` pair, as returned by `fetch_data`.
    exclude_chs : sequence of str
        Channel names dropped when reading the PSG file.
    data_savedir : str
        Directory to write the `.npz` file into; created if it doesn't exist.

    Returns
    -------
    data : dict
        The arrays that were saved (keys: 'data', 'onset', 'description',
        'duration', '_first_time').
    savepath : str
        Path of the written file, e.g. `<data_savedir>/SC4001E0.npz` for
        `SC4001E0-PSG.edf`.
    """
    # subject / recording numbers are not needed here (ids are tracked by the
    # caller), so discard them
    raw, annots, _, _ = _load_data(p, exclude_chs)
    raw = _drop_unlabeled_data(raw, annots)
    raw = _preprocess(raw)

    # by reading further code (window processing), we determine that we need
    # to keep only the following:
    #   for `raw`:
    #       `raw._data`
    #   for `annotations` (== `raw.annotations`):
    #       `annotations.onset`
    #       `annotations.description`
    #       `annotations.duration`
    data = {
        'data': raw._data,
        'onset': raw.annotations.onset,
        'description': raw.annotations.description,
        'duration': raw.annotations.duration,
        '_first_time': raw._first_time,  # for `braindecode_ver = True`
    }

    # `np.savez` fails if the target directory doesn't exist
    os.makedirs(data_savedir, exist_ok=True)
    savename = (os.path.basename(p[0])
                .replace('-PSG', '').replace('.edf', '.npz'))
    savepath = os.path.join(data_savedir, savename)
    np.savez(savepath, **data)
    return data, savepath
def _load_data(p, exclude_chs):
    """Read one recording's signals and hypnogram annotations.

    Parameters
    ----------
    p : tuple of (str, str)
        `(PSG path, Hypnogram path)`.
    exclude_chs : sequence of str
        Channel names to leave out of the PSG data.

    Returns
    -------
    raw : mne.io.Raw
        PSG data, preloaded into memory.
    annots : mne.Annotations
        Hypnogram annotations (labels).
    subj_num, sess_num : int
        Subject and recording number parsed from the PSG filename, e.g.
        `SC4001E0-PSG.edf` -> subject 0, recording 1.
    """
    # load data & labels -----------------------------------------------------
    p_data, p_meta = p
    # (`preload` to load the data into RAM)
    raw = mne.io.read_raw_edf(p_data, preload=True, exclude=exclude_chs)
    annots = mne.read_annotations(p_meta)

    # Get subject and recording number ---------------------------------------
    # filename layout `SC4ssN...`: `ss` = subject, `N` = recording
    basename = os.path.basename(p_data)
    subj_num = int(basename[3:5])
    sess_num = int(basename[5])
    return raw, annots, subj_num, sess_num
def _drop_unlabeled_data(raw, annots, *, braindecode_ver=True,
                         crop_wake_mins=30):
    """Attach the hypnogram to `raw` and crop away unlabeled/wake margins.

    Modifies `raw` in place (annotations, crop, channel renaming) and
    returns it.

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded PSG recording.
    annots : mne.Annotations
        Hypnogram annotations for `raw`.
    braindecode_ver : bool
        True (default): reproduce braindecode's cropping, keeping from
        `crop_wake_mins` minutes before the first non-wake sleep stage to
        `crop_wake_mins` minutes after the end of the last one.
        False: keep exactly the labeled span, dropping only the trailing
        "Sleep stage ?" segment.
    crop_wake_mins : float
        Minutes of margin kept on each side (only used if `braindecode_ver`).

    Raises
    ------
    ValueError
        If no annotation qualifies as a kept sleep stage.
    """
    # set `raw.annotations` from `annots` ------------------------------------
    # - each `a = annots[0]` has `a['onset']` and `a['duration']`
    # - `a` are cropped such that they remain within (inclusive) `tmin = 0` and
    #   `tmax = raw.times[-1] + 1 / raw.info['sfreq']`, where sfreq = sampling
    #   freq = 100 Hz. "Cropped" means their `'onset'` and `'duration'` are
    #   adjusted.
    raw.set_annotations(annots, emit_warning=False)

    # crop data to exclude unlabeled segments --------------------------------
    if braindecode_ver:
        # first/last *sleep* stages, i.e. wake ("W") and "?" excluded
        mask = [x[-1] in ['1', '2', '3', '4', 'R'] for x in annots.description]
    else:
        # "labels" are over data's time segments. E.g. `x[:20000]` can be
        # "Sleep stage W" (wake), and `x[20000:25000]` be "Sleep stage 1", etc.
        # "Sleep stage ?" is unlabeled.
        mask = [x[-1] != '?' for x in annots.description]
        # this assumes there's only one such sleep stage, and that it's the
        # last one; check both assumptions
        assert mask.count(False) == 1, mask
        assert not mask[-1], mask
    sleep_event_inds = np.where(mask)[0]
    if len(sleep_event_inds) == 0:
        raise ValueError("No sleep-stage annotations found; cannot crop.")

    # Crop raw (also crops labels)
    # determine `tmax` as first timestamp of last kept stage, plus the
    # duration of that stage, minus one sample (`dT`) (otherwise we end up at
    # the first timestamp of the next stage, and `crop` is inclusive on
    # `tmax`).
    dT = 1 / raw.info["sfreq"]
    a_tmin = annots[int(sleep_event_inds[0])]
    a_tmax = annots[int(sleep_event_inds[-1])]
    tmin = a_tmin['onset']
    tmax = a_tmax['onset'] + a_tmax['duration'] - dT
    if braindecode_ver:
        tmin -= crop_wake_mins * 60
        tmax += crop_wake_mins * 60

    # internally, this converts `tmin`, `tmax` to indices, and does something
    # like `x = x[tmin:tmax]`.
    # it also correspondingly crops labels, via `set_annotations`.
    # Internals (for labels):
    #   - `tmin` for `set_annotations` is set from `_first_time` (and another
    #     var we can treat as zero).
    #     This is updated in `crop` from the `tmin` we specify (see
    #     `_first_time` definition as `@property`).
    #   - `tmax` for `set_annotations` is set from `times[-1]` (+ dT).
    #     This is updated in `crop` from the `tmax` we specify (see `times`
    #     definition as `@property`).
    #   - This drops annotations that are completely out of range of `tmin`,
    #     `tmax`. Here, it amounts to dropping the annotation for
    #     "Sleep stage ?".
    # (clamped, since the wake margins may extend past the recording)
    raw.crop(tmin=max(tmin, raw.times[0]),
             tmax=min(tmax, raw.times[-1]))

    # Rename EEG channels ----------------------------------------------------
    raw.rename_channels({nm: nm.replace('EEG ', '') for nm in raw.ch_names})
    return raw
def _preprocess(raw):
# internally, this justs updates `raw._data`, with
# - handling for multiprocessing
# - checking that out.shape == in.shape
# - loading `_data` if it isn't (for us, it is, via `preload=True`)
# V -> uV
raw._data = raw._data * 1e6
return raw
#%% ##########################################################################
# Creating windows
# ----------------
def _create_windows_from_events(p, mapping, sfreq, window_size_samples=3000,
                                window_stride_samples=3000,
                                drop_last_window=False):
    """Cut one processed recording into labeled, fixed-size windows.

    Parameters
    ----------
    p : str
        Path to a `.npz` file written by `process_path`.
    mapping : dict
        Annotation description -> integer label. Annotations whose
        description is not a key are ignored.
    sfreq : float
        Sampling frequency in Hz, used to convert seconds to sample indices.
    window_size_samples, window_stride_samples : int
        Window length and stride in samples (defaults: 3000, i.e. 30 s at
        100 Hz, non-overlapping).
    drop_last_window : bool
        If False, add a final window ending exactly at a labeled segment's
        stop whenever the strided windows don't reach it.

    Returns
    -------
    WindowsDataset
        With `_preprocess_windows` already applied.
    """
    # load data; `with` closes the underlying zip archive
    with np.load(p) as d:
        # `a_` for `annotations_`
        data, a_description, a_onset, a_duration, _first_time = [
            d[nm] for nm in
            ('data', 'description', 'onset', 'duration', '_first_time')]

    events = _events_from_annotations(
        a_description, a_onset, _first_time, mapping, sfreq)
    # the lib method also returns `event_id_`, which is redundant for us
    #   - (it is a copy of `mapping`, unless `mapping` has some keys that
    #     `annotations.description` doesn't, but we then only use `event_id_`
    #     to check whether what's in `annotations.description` is in
    #     `event_id_`, which is circular and same as checking directly
    #     against `mapping`)
    onsets = events[:, 0]
    # Onsets are relative to the beginning of the (cropped) recording
    filtered_durations = np.array([
        dur for dur, desc in zip(a_duration, a_description)
        if desc in mapping])
    stops = onsets + (filtered_durations * sfreq).astype(int)

    # sanity check; note, `stops` is used exclusively, i.e. `start:stop`.
    # No `raw.first_samp` offset is needed since `_events_from_annotations`
    # skips `+= raw.first_samp`, and `raw.n_times == data.shape[-1]`
    # (i.e. the full check is `stops[-1] <= raw.first_samp + raw.n_times`)
    assert len(stops) == 0 or stops[-1] <= data.shape[-1]

    # generate windows
    i_trials, starts, stops = _compute_window_inds(
        onsets, stops, window_size_samples, window_stride_samples,
        drop_last_window)

    # label of each window = label of the event (trial) it was cut from
    description = events[:, -1]
    description_windows = description[np.asarray(i_trials, dtype=int)]

    windows_ds = WindowsDataset(data, description_windows, starts, stops)
    # channel-wise standardization + float32 cast (the model's dtype)
    _preprocess_windows(windows_ds)
    return windows_ds
def _events_from_annotations(description, onset, _first_time, mapping, fs):
# minimally implements `_select_annotations_based_on_description` --------
event_sel = [i for i, d in enumerate(description) if d in mapping]
# Convert onsets to sample indices. Internals: ===========================
# - `annotations.onset` = timestamps, in seconds, of start times of labels
# (where each "label", again, is over a time interval over data, e.g.
# `x[20000:25000]`).
# - `len(annotations.onset) == len(labels)`.
# - `annotations.orig_time` = `["meas_date"]`, as long as
# `annotations = mne.read_annotations(path)` is used.
# - `["meas_date"]` = measurement date, a `datetime` object created
# from metadata in the (-PSG) EDF file.
# minimally implements ---------------------------------------------------
# inds = raw.time_as_index(times=annotations.onset, use_rounding=True,
# origin=annotations.orig_time)
# we won't end up needing `origin` (see below), so don't fetch it.
# origin = annotations.orig_time
times = onset
# Internals:
# - `self._first_time = self.first_samp /['sfreq']`
# - `self.first_samp = self._cropped_samp`
# - `self._cropped_samp = first_samps[0]` if `raw.crop()` wasn't used
# with `tmin != 0`, else it's modified (in this case to
# `tmin *['sfreq']`)
# - `first_samps = (0,)`
# - (`braindecode_ver = False` only) Hence, `raw._first_time == 0`, and
# since `origin ==["meas_date"]` (see above), `delta == 0`,
# so we can skip all below code (and `times` is already 1d)
# Since `braindecode_ver = True` is supported, execute the relevant portion
# in `raw.time_as_index`, but rewritten:
# `origin - first_samp_in_abs_time`
# <=>
# `["meas_date"] - (["meas_date"] + raw._first_time)`
# <=>
# `- raw._first_time`
delta = - _first_time
times += delta
# `raw.times[0]` is always zero in our case (RawEDF loaded the way we have)
# index = (np.atleast_1d(times) - raw.times[0]) * fs
index = times * fs
inds = np.round(index).astype(int)
# ========================================================================
# Executes if `if annotations.orig_time is not None:`, which is the case here.
# Do not execute this, so we don't have to `-= raw.first_samp` later, so
# we don't have to store `raw.first_samp` anywhere
# inds += raw.first_samp
# `annotations.description` -> numeric values based on `mapping`. E.g. if
# mapping == {'Sleep stage W': 0, 'Sleep stage 1': 1}`
# annotations.description == ['Sleep stage 1', 'Sleep stage W',
# 'Sleep stage W']
# then
# values == [1, 0, 0]
# but ignoring `event_sel` (indices of selected labels based on whether they
# were in `mapping`).
values = [mapping[kk] for kk in description[event_sel]]
# Apply `event_sel` to `inds`
inds = inds[event_sel]
# This simply concatenates the arrays into `(n_events, 3)`, and casts to int
events = np.c_[inds, np.zeros(len(inds)), values].astype(int)
return events
def _compute_window_inds(starts, stops, size, stride, drop_last_window):
assert not any(size > (stops-starts))
i_trials, window_starts = [], []
for start_i, (start, stop) in enumerate(zip(starts, stops)):
# Generate possible window starts, with given stride, between
# starts and stops (i.e. original trial onsets and stops, shifted by
# start_offset and stop_offset, respectively)
possible_starts = np.arange(start, stop, stride)
# Possible window start is actually a start, if window size fits in
# trial start and stop
for i_window, s in enumerate(possible_starts):
if (s + size) <= stop:
# If the last window start + window size is not the same as
# stop + stop_offset, create another window that overlaps and stops
# at onset + stop_offset
if not drop_last_window:
if window_starts[-1] + size != stop:
window_starts.append(stop - size)
# Set window stops to be event stops (rather than trial stops)
window_stops = np.array(window_starts) + size
assert len(i_trials) == len(window_starts) == len(window_stops)
return i_trials, window_starts, window_stops
def _preprocess_windows(windows_ds):
    """Standardize each channel of `windows_ds.data` and cast to float32.

    `windows_ds.data` is replaced (the attribute is rebound). Each row
    (channel, given the `data[:, start:stop]` indexing in `WindowsDataset`)
    is scaled to zero mean / unit variance over the whole recording.
    Returns None.
    """
    windows_ds.data = standard_scale(windows_ds.data, axis=1)
    # float32 to match the model's parameter dtype
    windows_ds.data = windows_ds.data.astype('float32')
class WindowsDataset():
    """Labeled windows over one recording's continuous data.

    Window `i` is `data[:, i_start_in_trial[i]:i_stop_in_trial[i]]` with
    label `target[i]`.

    Parameters
    ----------
    data : np.ndarray
        Continuous data, indexed as `data[:, start:stop]` (channels first).
    target : array-like of int
        One label per window.
    i_start_in_trial, i_stop_in_trial : array-like of int
        Start and (exclusive) stop sample index of each window.
    """

    def __init__(self, data, target, i_start_in_trial, i_stop_in_trial):
        self.data = data
        self.y = np.asarray(target, dtype='int64')  # skorch expects int64
        # kept as given; `Sampler` only uses its length
        self.i_start_in_trial = i_start_in_trial
        self.inds = np.c_[i_start_in_trial, i_stop_in_trial]

    def __getitem__(self, index):
        """Return `(X, y)` for window `index`."""
        i_start, i_end = self.inds[index]
        X = self.data[:, i_start:i_end]
        y = self.y[index]
        return X, y

    def __len__(self):
        return len(self.y)
#%% ##########################################################################
# Creating dataset objects
# ------------------------
class Sampler():
    """Yield sequences of `n_windows` consecutive window indices.

    Indices refer to the concatenation of several recordings (see
    `ConcatDataset`); sequences never cross a recording boundary.

    Parameters
    ----------
    i_start_in_trial_all : list of sequences
        One entry per recording; only its length (= number of windows) is
        used.
    n_windows : int
        Windows per sequence.
    n_windows_stride : int
        Step between the first windows of consecutive sequences.
    randomize : bool
        Shuffle the sequence order on each iteration (uses `np.random`).
    """

    def __init__(self, i_start_in_trial_all, n_windows, n_windows_stride,
                 randomize=False):
        self.n_windows = n_windows
        self.n_windows_stride = n_windows_stride
        self.randomize = randomize

        # braindecode applies `groupby` by `'subject'`, `'recording'` upon a
        # dataframe, then resets index, and operates on said indices below,
        # meaning we first generate the global indices of each recording's
        # windows, operate on those, then concatenate.
        lengths = [len(starts) for starts in i_start_in_trial_all]
        offsets = np.cumsum([0] + lengths[:-1])
        # last valid sequence start is `n_windows - 1` before the end; for
        # `n_windows == 1` the slice end must be None, not 0 (which would
        # select nothing)
        end_offset = 1 - n_windows if n_windows > 1 else None
        per_recording = [
            np.arange(offset, offset + length)[:end_offset:n_windows_stride]
            for offset, length in zip(offsets, lengths)]
        self.start_inds = (np.concatenate(per_recording) if per_recording
                           else np.array([], dtype=int))

    def __len__(self):
        return len(self.start_inds)

    def __iter__(self):
        if self.randomize:
            start_inds = np.random.permutation(self.start_inds)
        else:
            start_inds = self.start_inds
        for start_ind in start_inds:
            yield tuple(range(start_ind, start_ind + self.n_windows))
class ConcatDataset(torch.utils.data.Dataset):
    """Concatenation of per-recording datasets, indexed by index sequences.

    Merges `BaseConcatDataset` and `ConcatDataset` classes.
    `skorch` requires this to be an instance of `Dataset`.

    Parameters
    ----------
    datasets : list
        Datasets whose items are `(X, y)` pairs (e.g. `WindowsDataset`).
    target_transform : callable, optional
        Applied to the array of targets returned by `__getitem__`; identity
        if None.
    """

    def __init__(self, datasets, target_transform=None):
        self.datasets = datasets
        # for script readability, we assign this later
        self.target_transform = (target_transform
                                 if target_transform is not None else
                                 self._identity)
        self.cumulative_sizes = np.cumsum([len(ds) for ds in self.datasets])

    @staticmethod
    def _identity(x):
        # named function (not a lambda) so instances remain picklable
        return x

    def __getitem__(self, idxs):
        """
        Parameters
        ----------
        idxs : tuple / list
            Indices of windows and targets to return (concatenated).
            The target output can be modified on the fly by the
            ``target_transform`` parameter.

        Returns
        -------
        X : np.ndarray
            Windows stacked along a new first axis.
        y
            `target_transform(np.array(targets))`.
        """
        X, y = [], []
        for idx in idxs:
            X_i, y_i = self._getitem(idx)
            X.append(X_i)
            y.append(y_i)
        X = np.stack(X, axis=0)
        y = self.target_transform(np.array(y))
        return X, y

    def _getitem(self, idx):
        """Return item `idx` of the concatenation (non-negative indices)."""
        if not 0 <= idx < len(self):
            raise IndexError(
                f"index {idx} out of range for dataset of length {len(self)}")
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        sample_idx = (idx if dataset_idx == 0 else
                      idx - self.cumulative_sizes[dataset_idx - 1])
        return self.datasets[dataset_idx][sample_idx]

    def __len__(self):
        if len(self.cumulative_sizes) == 0:
            return 0
        return int(self.cumulative_sizes[-1])
#%% ##########################################################################
# Creating model
# --------------
class SleepStagerChambon2018(nn.Module):
    """Feature extractor only.

    Convolutional part of the Chambon et al. (2018) sleep stager. It applies
    an optional spatial convolution across channels, followed by two
    (temporal conv -> batch norm or identity -> ReLU -> max-pool) stages.
    There is no classification head; `n_outputs` is stored but unused here.

    Parameters
    ----------
    n_chans : int
        Number of EEG channels. With a single channel the spatial
        convolution is skipped.
    n_outputs : int
        Number of classes (kept for reference only).
    n_times : int
        Samples per window; used to compute `len_last_layer`.
    sfreq : float
        Sampling frequency in Hz; kernel/pool/padding sizes derive from it.
    n_conv_chs : int
        Output channels of the temporal convolutions.
    apply_batch_norm : bool
        Use `BatchNorm2d` after each temporal convolution.

    Attributes
    ----------
    len_last_layer : int
        Number of (flattened) features produced per window.
    """

    def __init__(self, n_chans, n_outputs, n_times, sfreq,
                 n_conv_chs=8, apply_batch_norm=False):
        super().__init__()
        self.n_chans = n_chans
        self.n_outputs = n_outputs
        self.n_times = n_times
        self.n_conv_chs = n_conv_chs
        self.apply_batch_norm = apply_batch_norm

        # handle params: sizes in seconds -> samples
        time_conv_size_s = 0.5
        max_pool_size_s = time_conv_size_s / 4
        pad_size_s = time_conv_size_s / 2
        time_conv_size, max_pool_size, pad_size = (
            int(np.ceil(x * sfreq)) for x in
            (time_conv_size_s, max_pool_size_s, pad_size_s))

        # handle certain layers
        if n_chans > 1:
            self.spatial_conv = nn.Conv2d(1, n_chans, (n_chans, 1))
        # `nn.Identity` ignores its argument, so both options take n_conv_chs
        batch_norm = nn.BatchNorm2d if apply_batch_norm else nn.Identity

        # make feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(
                1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
            batch_norm(n_conv_chs),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
            nn.Conv2d(
                n_conv_chs, n_conv_chs, (1, time_conv_size),
                padding=(0, pad_size)),
            batch_norm(n_conv_chs),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
        )

        # length of last layer (for later). Probe with zeros in eval mode so
        # BatchNorm running statistics aren't updated by the dummy input.
        was_training = self.feature_extractor.training
        self.feature_extractor.eval()
        with torch.no_grad():
            probe = self.feature_extractor(
                torch.zeros(1, 1, self.n_chans, self.n_times))
        self.feature_extractor.train(was_training)
        self.len_last_layer = len(probe.flatten())

    def forward(self, x):
        """x: batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns feature maps of shape
        (batch_size, n_conv_chs, n_channels, n_times_out), not flattened.
        """
        x = x.unsqueeze(1)                 # (B, 1, C, T)
        if self.n_chans > 1:
            x = self.spatial_conv(x)       # (B, C, 1, T)
            x = x.transpose(1, 2)          # (B, 1, C, T)
        return self.feature_extractor(x)
class TimeDistributed(nn.Module):
    """Apply module on a sequence of windows and return their concatenation
    (see `forward`):
        `(batch_size, seq_len, n_channels, n_times)` ->
        `(batch_size, seq_len, output_size)`

    Useful with sequence-to-prediction models (e.g. sleep stager which must
    map a sequence of consecutive windows to the label of the middle window
    in the sequence).
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """
        x: sequence of windows of shape
            (batch_size, seq_len, n_channels, n_times)

        Returns output of shape
            (batch_size, seq_len, output_size)
        """
        b, s, c, t = x.shape
        # `reshape` (not `view`) so non-contiguous inputs also work
        out = self.module(x.reshape(b * s, c, t))
        return out.reshape(b, s, -1)
#%% ##########################################################################
# Creating train loop
# -------------------
from skorch.helper import predefined_split
from skorch.classifier import NeuralNetClassifier
from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
from skorch.utils import noop, train_loss_score, valid_loss_score
class EEGClassifier(NeuralNetClassifier):
    """Is `NeuralNetClassifier`, with `_default_callbacks` overridden.

    All arguments are passed straight into `NeuralNetClassifier`.
    Note, `NeuralNetClassifier` doesn't assume softmax activation and calls
    the loss function directly (without applying e.g. log).

    Parameter note: `iterator_train__shuffle` (default True) defines whether
    the train dataset will be shuffled. As `skorch` does not shuffle the train
    dataset by default, this one overwrites that option. Pass False when
    supplying `iterator_train__sampler`.
    """

    def __init__(self, *args, iterator_train__shuffle=True, **kwargs):
        super().__init__(
            *args, iterator_train__shuffle=iterator_train__shuffle, **kwargs)

    @property
    def _default_callbacks(self):
        # skorch reads `_default_callbacks` as an attribute, hence `@property`
        return [
            ("epoch_timer", EpochTimer()),
            ("train_loss", BatchScoring(
                train_loss_score, name="train_loss", on_train=True,
                target_extractor=noop,
            )),
            ("valid_loss", BatchScoring(
                valid_loss_score, name="valid_loss", target_extractor=noop,
            )),
            ("print_log", PrintLog()),
        ]
        # `skorch` default is
        # return [
        #     ('epoch_timer', EpochTimer()),
        #     ('train_loss', PassthroughScoring(
        #         name='train_loss',
        #         on_train=True,
        #     )),
        #     ('valid_loss', PassthroughScoring(
        #         name='valid_loss',
        #     )),
        #     ('print_log', PrintLog()),
        # ]
        # this unites `_EEGNeuralNet._default_callbacks` and
        # `EEGClassifier._default_callbacks` (latter excluding the fact that
        # it appends to former)

    # Excluded inherited classes explanation ---------------------------------
    # _EEGNeuralNet:
    #   Running original script with and without this class changed nothing.
    #   Inspecting the code, it appears to handle certain configurations that
    #   aren't used here (e.g. "cropping").

    # Excluded arguments explanation -----------------------------------------
    # aggregate_predictions:
    #   Was only used in `predict_proba`, which was dropped

    # Excluded methods explanation -------------------------------------------
    # get_iterator:
    #   only does something via `ThrowAwayIndexLoader`, which only does
    #   something if iterator returns x.ndim==3, which doesn't happen here
    # predict_proba:
    #   only does something if `cropped=True`, which isn't the case here
    # get_loss:
    #   only does something if
    #   `isinstance(self.criterion_, torch.nn.NLLLoss)`, which isn't the case
    #   here
    # predict:
    #   completely identical to inherited class's definition, is likely
    #   redefined for docs clarity or future changes
    # predict_trials:
    #   meant to be used with `cropped=True`, which isn't the case here
    # _get_n_outputs:
    #   unused during training here, checked by inserting `1/0` into it

    # Excluded attributes explanation ----------------------------------------
    # _last_window_inds_:
    #   TL;DR unused
#%% ##########################################################################
# *********
#%% ##########################################################################
# Convert and save data as numpy arrays
# -------------------------------------
# handle configs
if data_savedir is None:
    data_savedir = os.getcwd()

# set excluded channels (non-EEG signals stored in the PSG files)
exclude_chs = ('EOG horizontal', 'Resp oro-nasal', 'EMG submental',
               'Temp rectal', 'Event marker')

# merge stages 3 and 4 following AASM standards
mapping = {
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,
    'Sleep stage R': 4,
}

# sampling freq, obtained via `raw.info['sfreq']`
# (hard-coded here for cleaner code)
sfreq = 100

# Fetch paths; each `p` is a `(PSG path, Hypnogram path)` pair
paths = fetch_data(subject_ids, recording=recording_ids, on_missing='warn')

# For rest of this script, generalize original script's example to any
# `subject_ids` and `recording_ids`.
# With `on_missing='warn'`, unavailable recordings are skipped, so `paths`
# can be shorter than `subject_ids x recording_ids`; zipping against that
# grid would attach the wrong id to every later recording. Instead, read
# `(subject_id, recording_id)` from each PSG filename (`SC4ssN...`, see
# `_load_data`).
ids = [(int(os.path.basename(p[0])[3:5]), int(os.path.basename(p[0])[5]))
       for p in paths]
# Map ids to paths
ipaths = dict(zip(ids, paths))

# Store processed data paths
isavepaths = {}
for id_, p in ipaths.items():
    _, isavepaths[id_] = process_path(p, exclude_chs, data_savedir)
#%% ##########################################################################
# Create windows
# --------------
# one WindowsDataset per (subject_id, recording_id)
iwindows_ds = {
    rec_id: _create_windows_from_events(npz_path, mapping, sfreq)
    for rec_id, npz_path in isavepaths.items()
}
#%% ##########################################################################
# Create dataset objects, make train-test split
# ---------------------------------------------
# split by subject, so train and validation have different subjects
# split by subject, so train and validation have different subjects.
# Iterate over the recordings that actually exist (keys of `iwindows_ds`),
# so a missing recording doesn't raise KeyError.
split_ids = dict(
    train=[id_ for id_ in iwindows_ds if id_[0] in subject_ids[::2]],
    valid=[id_ for id_ in iwindows_ds if id_[0] in subject_ids[1::2]],
)
train_set = ConcatDataset([iwindows_ds[id_] for id_ in split_ids['train']])
valid_set = ConcatDataset([iwindows_ds[id_] for id_ in split_ids['valid']])

# make samplers
n_windows = 3
n_windows_stride = 3
train_sampler = Sampler(
    [windows_ds.i_start_in_trial for windows_ds in train_set.datasets],
    n_windows, n_windows_stride, randomize=True,
)
valid_sampler = Sampler(
    [windows_ds.i_start_in_trial for windows_ds in valid_set.datasets],
    n_windows, n_windows_stride,
)
#%% Make label transformer
# Use label of center window in the sequence
def get_center_label(x):
    """Reduce a sequence of window labels to a single label.

    Plain ints pass through, and length-0/1 inputs are returned unchanged.
    Otherwise the element at index `ceil(len(x) / 2)` is returned.

    NOTE(review): for odd lengths that is *not* the middle element (for
    len 3 it is index 2, the last one). This presumably mirrors braindecode's
    helper of the same name, which this script aims to reproduce exactly.
    Use `len(x) // 2` only if exact reproduction is not needed.
    """
    if isinstance(x, int):
        return x
    if len(x) <= 1:
        return x
    return x[int(np.ceil(len(x) / 2))]
train_set.target_transform = get_center_label
valid_set.target_transform = get_center_label
#%% Compute class weights
# label of every training sequence the sampler yields
y_train = [train_set[seq_inds][1] for seq_inds in train_sampler]
# replicate `sklearn.utils.compute_class_weight` for `class_weight='balanced'`:
#   n_samples / (n_classes * count_of_each_class)
classes, class_counts = np.unique(y_train, return_counts=True)
class_weights = len(y_train) / (len(classes) * class_counts)
#%% ##########################################################################
# Create model
# ------------
# check if GPU is available, assuming the user wants it
cuda = use_gpu and torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    # faster but lowers reproducibility
    torch.backends.cudnn.benchmark = True

# set seeds for reproducibility (Python, NumPy, torch CPU and, if used, CUDA)
seed = 31
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed_all(seed)
# Reproducibility caveats:
#   - See the PyTorch "Reproducibility" notes for more info.
#   - In some cases, may need to set `PYTHONHASHSEED` env var before running
#     the script.
#   - `torch.use_deterministic_algorithms(True)` isn't used, also plays a role

n_classes = 5
# Extract number of channels and time steps from dataset
n_channels, input_size_samples = train_set[(0,)][0][0].shape

# NOTE(review): the constructor arguments and the Flatten/Dropout layers were
# lost from this copy and are restored to mirror BrainDecode's example —
# confirm against the original.
feat_extractor = SleepStagerChambon2018(
    n_chans=n_channels,
    n_outputs=n_classes,
    n_times=input_size_samples,
    sfreq=sfreq,
)
model = nn.Sequential(
    TimeDistributed(feat_extractor),  # apply model on each 30-s window
    nn.Sequential(  # apply linear layer on concatenated feature vectors
        nn.Flatten(start_dim=1),
        nn.Dropout(0.5),
        nn.Linear(feat_extractor.len_last_layer * n_windows, n_classes),
    ),
)

if cuda:
    model = model.cuda()
#%% ##########################################################################
# Create train loop
# -----------------
lr = 1e-3
batch_size = 32
n_epochs = 10

# higher balanced accuracy is better
train_bal_acc = EpochScoring(
    scoring='balanced_accuracy', on_train=True, name='train_bal_acc',
    lower_is_better=False)
valid_bal_acc = EpochScoring(
    scoring='balanced_accuracy', on_train=False, name='valid_bal_acc',
    lower_is_better=False)
callbacks = [
    ('train_bal_acc', train_bal_acc),
    ('valid_bal_acc', valid_bal_acc),
]

# NOTE(review): most constructor arguments were lost from this copy and are
# restored to mirror BrainDecode's example — confirm against the original.
clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    criterion__weight=torch.Tensor(class_weights).to(device),
    optimizer=torch.optim.Adam,
    # a DataLoader can't shuffle when given a sampler; the train sampler
    # already randomizes
    iterator_train__shuffle=False,
    iterator_train__sampler=train_sampler,
    iterator_valid__sampler=valid_sampler,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    batch_size=batch_size,
    callbacks=callbacks,
    device=device,
    classes=np.unique(y_train),
)

#%% ##########################################################################
# Run training
# ------------
# Model training for a specified number of epochs. `y` is None as it is
# already supplied in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)
#%% ##########################################################################
# The rest of the code (e.g. plotting) is same as in original script.
