Created
April 18, 2026 20:37
-
-
Save bruAristimunha/27d74c8410fe9d0db258a03f42efa7c6 to your computer and use it in GitHub Desktop.
Braindecode tutorial checkpoint trainer/pusher (PR #985)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| import argparse | |
| import copy | |
| import json | |
| import sys | |
| from dataclasses import dataclass | |
| from numbers import Integral | |
| from pathlib import Path | |
# Repository root: this script is expected to live one directory below it.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Optional vendored MOABB checkout inside the repository.
LOCAL_MOABB_ROOT = REPO_ROOT / "moabb"
if LOCAL_MOABB_ROOT.joinpath("moabb", "__init__.py").exists():
    # Prefer the vendored MOABB package over an installed version.
    sys.path.insert(0, str(LOCAL_MOABB_ROOT))
# Inserted last so the repository root ends up first on ``sys.path`` and the
# local braindecode sources take precedence over an installed copy.
sys.path.insert(0, str(REPO_ROOT))
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from joblib import parallel_backend | |
| from numpy import multiply | |
| from sklearn.metrics import balanced_accuracy_score, r2_score | |
| from sklearn.model_selection import GridSearchCV, KFold, train_test_split | |
| from sklearn.preprocessing import robust_scale | |
| from sklearn.preprocessing import scale as standard_scale | |
| from sklearn.utils import compute_class_weight | |
| from skorch.callbacks import ( | |
| EarlyStopping, | |
| EpochScoring, | |
| GradientNormClipping, | |
| LRScheduler, | |
| ) | |
| from skorch.dataset import ValidSplit | |
| from skorch.helper import SliceDataset, predefined_split | |
| from torch import nn | |
| from braindecode import EEGClassifier, EEGRegressor | |
| from braindecode._tutorial_hub import ( | |
| save_tutorial_checkpoint, | |
| tutorial_repo_id, | |
| upload_tutorial_artifacts, | |
| ) | |
| from braindecode.augmentation import ( | |
| AugmentedDataLoader, | |
| ChannelsDropout, | |
| FTSurrogate, | |
| IdentityTransform, | |
| SmoothTimeMask, | |
| ) | |
| from braindecode.datasets import BCICompetitionIVDataset4, MOABBDataset, SleepPhysionet | |
| from braindecode.datautil import infer_signal_properties | |
| from braindecode.models import ( | |
| AttnSleep, | |
| EEGNeX, | |
| ShallowFBCSPNet, | |
| SleepStagerChambon2018, | |
| USleep, | |
| ) | |
| from braindecode.modules import TimeDistributed | |
| from braindecode.preprocessing import ( | |
| EEGPrep, | |
| Preprocessor, | |
| create_fixed_length_windows, | |
| create_windows_from_events, | |
| create_windows_from_target_channels, | |
| exponential_moving_standardize, | |
| preprocess, | |
| ) | |
| from braindecode.samplers import SequenceSampler | |
| from braindecode.training import ( | |
| CroppedLoss, | |
| CroppedTimeSeriesEpochScoring, | |
| TimeSeriesLoss, | |
| ) | |
| from braindecode.util import set_random_seeds | |
# Tutorial names this script knows how to train, grouped by data source.
AVAILABLE_TUTORIALS = (
    # BNCI2014_001 (BCI Competition IV 2a) via MOABB
    "plot_bcic_iv_2a_moabb_trial",
    "plot_bcic_iv_2a_moabb_cropped",
    "plot_bcic_iv_2a_eegprep_cleaning",
    # BCI Competition IV dataset 4 (ECoG regression)
    "bcic_iv_4_ecog_trial",
    "bcic_iv_4_ecog_cropped",
    # Sleep Physionet sleep staging
    "plot_sleep_staging_usleep",
    "plot_sleep_staging_eldele2021",
    "plot_sleep_staging_chambon2018",
    # Augmentation hyper-parameter search
    "plot_data_augmentation_search",
)
@dataclass
class TutorialArtifacts:
    """Result of one tutorial trainer, ready to be checkpointed and uploaded.

    Attributes
    ----------
    clf : EEGClassifier | EEGRegressor
        The fitted braindecode estimator.
    repo_id : str
        Hub repository id, as returned by ``tutorial_repo_id``.
    metadata : dict
        Summary of the training run built by the trainer (metrics, epoch
        counts, display hints such as ``display_metric_key``).
    """

    # Fitted estimator.
    clf: EEGClassifier | EEGRegressor
    # Target hub repository.
    repo_id: str
    # Run summary (metrics and display hints).
    metadata: dict
| # Global wandb state, set by main() when --wandb is passed. | |
| _WANDB_PROJECT: str | None = None | |
| def _make_wandb_callback(tutorial_name: str, config: dict): | |
| """Create a WandbLogger callback if wandb is enabled, else return None.""" | |
| if _WANDB_PROJECT is None: | |
| return None | |
| import wandb # isort: skip | |
| from skorch.callbacks import WandbLogger # isort: skip | |
| run = wandb.init( | |
| project=_WANDB_PROJECT, | |
| name=tutorial_name, | |
| config=config, | |
| reinit=True, | |
| ) | |
| return ("wandb_logger", WandbLogger(run, save_model=False)), run | |
def _save_loss_curve(clf, output_dir: Path, tutorial_name: str):
    """Plot the curves stored in ``clf.history`` and save them as a PNG.

    The top panel always shows train/valid loss. If the first history row
    contains ``valid_bal_acc``, a second panel with train/valid balanced
    accuracy is added below it.

    Parameters
    ----------
    clf : skorch net
        Fitted estimator whose ``history`` is plotted.
    output_dir : Path
        Directory receiving ``loss_curve.png``; created if missing.
    tutorial_name : str
        Title of the top panel.

    Returns
    -------
    Path
        Location of the written PNG.
    """
    import matplotlib

    # Non-interactive backend: the figure is only written to disk.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    hist = clf.history
    epoch_ids = [row["epoch"] for row in hist]
    # Each panel: (train column, valid column, y-axis label).
    panels = [("train_loss", "valid_loss", "Loss")]
    if "valid_bal_acc" in hist[0]:
        panels.append(("train_bal_acc", "valid_bal_acc", "Balanced accuracy"))

    n_panels = len(panels)
    fig, axes = plt.subplots(n_panels, 1, figsize=(8, 3.5 * n_panels), sharex=True)
    if n_panels == 1:
        axes = [axes]
    for ax, (train_key, valid_key, y_label) in zip(axes, panels):
        ax.plot(epoch_ids, hist[:, train_key], "r-", label="Train")
        ax.plot(epoch_ids, hist[:, valid_key], "b-", label="Valid")
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(alpha=0.3)
    axes[0].set_title(tutorial_name)
    axes[-1].set_xlabel("Epoch")
    fig.tight_layout()

    output_dir.mkdir(parents=True, exist_ok=True)
    plot_path = output_dir / "loss_curve.png"
    fig.savefig(plot_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved loss curve to {plot_path}")
    return plot_path
# CSV file name for search result tables.
SEARCH_RESULTS_FILENAME = "search_results.csv"

# Sleep Physionet annotation description -> class index. Stages 3 and 4 share
# class 3, which leaves five classes (W, 1, 2, 3/4, R).
SLEEP_MAPPING = dict(
    [
        ("Sleep stage W", 0),
        ("Sleep stage 1", 1),
        ("Sleep stage 2", 2),
        ("Sleep stage 3", 3),
        ("Sleep stage 4", 3),
        ("Sleep stage R", 4),
    ]
)
def _common_preprocessing(
    dataset,
    *,
    use_eegprep: bool = False,
    n_jobs: int = 1,
):
    """Apply the shared BNCI2014_001 preprocessing chain to ``dataset`` in place.

    Steps, in order: keep EEG channels only; multiply by 1e6; optionally run
    EEGPrep (resampling to 128 Hz); band-pass 4-38 Hz; exponential moving
    standardisation (``factor_new=1e-3``, ``init_block_size=1000``).

    Parameters
    ----------
    dataset : braindecode concat dataset
        Raw recordings to preprocess.
    use_eegprep : bool
        Insert the EEGPrep cleaning step before filtering.
    n_jobs : int
        Parallel jobs passed to ``preprocess``.
    """
    steps = [
        Preprocessor("pick_types", eeg=True, meg=False, stim=False),
        # Scale by 1e6 (presumably V -> uV; MNE stores EEG in volts).
        Preprocessor(lambda data: multiply(data, 1e6)),
    ]
    if use_eegprep:
        steps += [EEGPrep(resample_to=128, bad_window_max_bad_channels=None)]
    steps += [
        Preprocessor("filter", l_freq=4.0, h_freq=38.0),
        Preprocessor(
            exponential_moving_standardize,
            factor_new=1e-3,
            init_block_size=1000,
        ),
    ]
    preprocess(dataset, steps, n_jobs=n_jobs)
def _ecog_preprocessing(dataset, *, n_jobs: int = 1):
    """Preprocess BCIC IV 4 ECoG recordings in place.

    ECoG and misc channels are kept. Only the ECoG channels are divided by
    1e6 and standardised; the band-pass filter (1-200 Hz) is applied with
    ``Preprocessor``'s default channel selection.

    Parameters
    ----------
    dataset : braindecode concat dataset
        Raw recordings to preprocess.
    n_jobs : int
        Parallel jobs passed to ``preprocess``.
    """
    ecog_only = {"picks": "ecog"}
    steps = [
        Preprocessor("pick_types", ecog=True, misc=True),
        Preprocessor(lambda x: x / 1e6, **ecog_only),
        Preprocessor("filter", l_freq=1.0, h_freq=200.0),
        Preprocessor(
            exponential_moving_standardize,
            factor_new=1e-3,
            init_block_size=1000,
            **ecog_only,
        ),
    ]
    preprocess(dataset, steps, n_jobs=n_jobs)
def _trialwise_windows(dataset):
    """Create preloaded trial-wise windows that start 0.5 s before each event.

    The sample offset is computed from the sampling frequency of the first
    recording in ``dataset``.
    """
    first_raw = dataset.datasets[0].raw
    start_offset = int(-0.5 * first_raw.info["sfreq"])
    return create_windows_from_events(
        dataset,
        trial_start_offset_samples=start_offset,
        trial_stop_offset_samples=0,
        preload=True,
    )
def _training_callbacks(epochs: int, patience: int):
    """Return the default skorch callbacks for the BNCI2014_001 classifiers.

    Accuracy scoring, cosine-annealing LR over ``epochs - 1`` steps (at least
    1, so single-epoch runs still work) and early stopping that restores the
    best epoch.
    """
    cosine_steps = max(1, epochs - 1)
    scheduler = LRScheduler("CosineAnnealingLR", T_max=cosine_steps)
    stopper = EarlyStopping(patience=patience, load_best=True)
    return ["accuracy", ("lr_scheduler", scheduler), ("early_stopping", stopper)]
def _device_and_seed(seed: int = 20200220):
    """Seed the RNGs and pick the best available torch device.

    When CUDA is available, ``cudnn.benchmark`` is enabled before seeding.

    Returns
    -------
    tuple[bool, str]
        Whether CUDA is available, and the device string: ``"cuda"``, else
        ``"mps"`` if available, else ``"cpu"``.
    """
    use_cuda = torch.cuda.is_available()
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if use_cuda:
        torch.backends.cudnn.benchmark = True
        device = "cuda"
    elif has_mps:
        device = "mps"
    else:
        device = "cpu"
    set_random_seeds(seed=seed, cuda=use_cuda)
    return use_cuda, device
def _pearson_r_score(net, dataset, y):
    """skorch scoring function: mean Pearson r over target columns.

    ``y`` is indexed as 2-D (samples x targets) and compared column by column
    with ``net.predict(dataset)``. NaN correlations (e.g. from a constant
    column) are ignored by ``nanmean``.
    """
    predictions = net.predict(dataset)
    n_targets = y.shape[1]
    coeffs = np.empty(n_targets)
    for col in range(n_targets):
        coeffs[col] = np.corrcoef(y[:, col], predictions[:, col])[0, 1]
    return float(np.nanmean(coeffs))
def _trialwise_shallow(
    tutorial_name: str,
    *,
    subject_id: int,
    epochs: int,
    patience: int,
) -> TutorialArtifacts:
    """Train ShallowFBCSPNet trial-wise on one BNCI2014_001 subject.

    Session ``"0train"`` is used for training and session ``"1test"`` for
    validation. Early stopping (via ``_training_callbacks``) restores the
    best validation epoch.

    Parameters
    ----------
    tutorial_name : str
        Tutorial identifier, used for the hub repo id and in the metadata.
    subject_id : int
        MOABB subject to load.
    epochs : int
        Maximum number of training epochs.
    patience : int
        Early-stopping patience.

    Returns
    -------
    TutorialArtifacts
        Fitted classifier, repo id and run metadata.
    """
    moabb_data = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])
    _common_preprocessing(moabb_data, n_jobs=-1)
    by_session = _trialwise_windows(moabb_data).split("session")
    train_set, valid_set = by_session["0train"], by_session["1test"]

    props = infer_signal_properties(train_set, mode="classification")
    n_outputs = props["n_outputs"]
    _, device = _device_and_seed()
    model = ShallowFBCSPNet(
        n_chans=props["n_chans"],
        n_outputs=n_outputs,
        n_times=props["n_times"],
        final_conv_length="auto",
    )
    clf = EEGClassifier(
        model,
        criterion=torch.nn.CrossEntropyLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=0.0625 * 0.01,
        optimizer__weight_decay=0,
        batch_size=64,
        callbacks=_training_callbacks(epochs, patience),
        device=device,
        classes=list(range(n_outputs)),
    )
    clf.fit(train_set, y=None, epochs=epochs)

    history = clf.history
    metadata = {
        "best_valid_accuracy": max(history[:, "valid_accuracy"]),
        "chance_level": 0.25,
        "display_metric_key": "best_valid_accuracy",
        "display_metric_name": "accuracy",
        "display_split_name": "held-out session",
        "epochs_ran": len(history),
        "epochs_requested": epochs,
        "final_valid_accuracy": history[-1, "valid_accuracy"],
        "patience": patience,
        "short_run_epochs": 4,
        "subject_id": subject_id,
        "tutorial": tutorial_name,
    }
    return TutorialArtifacts(
        clf=clf,
        repo_id=tutorial_repo_id(tutorial_name),
        metadata=metadata,
    )
def _ecog_trialwise(*, epochs: int, patience: int) -> TutorialArtifacts:
    """Train a trial-wise ShallowFBCSPNet regressor on BCIC IV 4 ECoG (subject 1).

    Windows of 1000 samples are built with ``last_target_only=True`` and only
    the first target channel is kept. The ``"train"`` session is split
    chronologically (no shuffling) into 80% train / 20% validation. The
    ``"test"`` session is scored once at the end with the per-dimension
    Pearson correlation.

    Returns
    -------
    TutorialArtifacts
        Fitted regressor, repo id and run metadata.
    """
    tutorial_name = "bcic_iv_4_ecog_trial"
    ecog = BCICompetitionIVDataset4(subject_ids=[1])
    _ecog_preprocessing(ecog, n_jobs=1)

    windows = create_windows_from_target_channels(
        ecog, window_size_samples=1000, preload=False, last_target_only=True
    )
    # Regress onto the first target channel only.
    windows.target_transform = lambda x: x[0:1]
    by_session = windows.split("session")
    full_train = by_session["train"]
    test_set = by_session["test"]

    # Chronological 80/20 split (shuffle=False keeps time order).
    train_idx, valid_idx = train_test_split(
        np.arange(len(full_train)),
        random_state=100,
        test_size=0.2,
        shuffle=False,
    )
    valid_set = torch.utils.data.Subset(full_train, valid_idx)
    train_set = torch.utils.data.Subset(full_train, train_idx)

    _, device = _device_and_seed()
    first_sample = train_set[0]
    model = ShallowFBCSPNet(
        first_sample[0].shape[0],
        first_sample[1].shape[0],
        n_times=1000,
        final_conv_length="auto",
    )

    pearson_scorers = [
        (
            label,
            EpochScoring(
                _pearson_r_score,
                lower_is_better=False,
                on_train=on_train,
                name=label,
            ),
        )
        for label, on_train in (("valid_pearson_r", False), ("train_pearson_r", True))
    ]
    callbacks = [
        "r2",
        *pearson_scorers,
        (
            "lr_scheduler",
            LRScheduler("CosineAnnealingLR", T_max=max(1, epochs - 1)),
        ),
        ("early_stopping", EarlyStopping(patience=patience, load_best=True)),
    ]
    regressor = EEGRegressor(
        model,
        criterion=torch.nn.MSELoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=0.0625 * 0.01,
        optimizer__weight_decay=0,
        batch_size=64,
        callbacks=callbacks,
        device=device,
    )
    regressor.fit(train_set, y=None, epochs=epochs)

    preds_test = regressor.predict(test_set)
    y_test = np.stack([sample[1] for sample in test_set])
    corr_coeffs = [
        float(np.corrcoef(preds_test[:, dim], y_test[:, dim])[0, 1])
        for dim in range(y_test.shape[1])
    ]

    history = regressor.history
    metadata = {
        "best_valid_pearson_r": float(max(history[:, "valid_pearson_r"])),
        "best_valid_r2": float(max(history[:, "valid_r2"])),
        "display_metric_as_percentage": False,
        "display_metric_key": "test_mean_pearson_r",
        "display_metric_name": "mean Pearson r",
        "display_split_name": "held-out test session",
        "epochs_ran": len(history),
        "epochs_requested": epochs,
        "final_valid_pearson_r": float(history[-1, "valid_pearson_r"]),
        "final_valid_r2": float(history[-1, "valid_r2"]),
        "patience": patience,
        "reference_recording_scope": "whole recording",
        "reference_uses_full_recordings": True,
        "short_run_epochs": 2,
        "subject_id": 1,
        "test_mean_pearson_r": float(np.nanmean(corr_coeffs)),
        "test_pearson_r_per_dim": corr_coeffs,
        "tutorial": tutorial_name,
    }
    return TutorialArtifacts(
        clf=regressor,
        repo_id=tutorial_repo_id(tutorial_name),
        metadata=metadata,
    )
def _pad_and_select_predictions(preds, y):
    """Align cropped predictions with their targets and drop NaN time points.

    ``preds`` and ``y`` are indexed as (trials, channels, time), and the first
    axis is squeezed, so a single trial is expected. ``preds`` is left-padded
    with zeros along time to ``y``'s length. Time points where the first
    target channel of the first trial is NaN are then removed from both.

    Returns
    -------
    tuple[ndarray, ndarray]
        ``(y, preds)`` -- NOTE the order: targets first -- each transposed to
        (time, channels).
    """
    n_missing = y.shape[2] - preds.shape[2]
    padded = np.pad(
        preds,
        ((0, 0), (0, 0), (n_missing, 0)),
        "constant",
        constant_values=0,
    )
    keep = ~np.isnan(y[0, 0, :])
    preds_kept = np.squeeze(padded[..., keep], 0).T
    y_kept = np.squeeze(y[..., keep], 0).T
    return y_kept, preds_kept
def _ecog_cropped(*, epochs: int, patience: int) -> TutorialArtifacts:
    """Train a cropped ShallowFBCSPNet regressor on BCIC IV 4 ECoG (subject 1).

    The last 20% (in time) of the ``"train"`` session is held out for
    validation. The ``"test"`` session is scored once at the end with the
    per-dimension Pearson correlation between the predicted and true target
    time series.

    Returns
    -------
    TutorialArtifacts
        Fitted regressor, repo id and run metadata.
    """
    tutorial_name = "bcic_iv_4_ecog_cropped"
    dataset = BCICompetitionIVDataset4(subject_ids=[1])
    dataset_split = dataset.split("session")
    train_set = dataset_split["train"]
    test_set = dataset_split["test"]

    # Time-based split of the training recording: [0, 80%] train, [80%, end] valid.
    train_duration_s = float(train_set.datasets[0].raw.times[-1])
    valid_tmin_s = 0.8 * train_duration_s
    valid_set = preprocess(
        copy.deepcopy(train_set),
        [Preprocessor("crop", tmin=valid_tmin_s, tmax=None)],
        n_jobs=1,
    )
    preprocess(
        train_set,
        [Preprocessor("crop", tmin=0, tmax=valid_tmin_s)],
        n_jobs=1,
    )
    for recordings in (train_set, valid_set, test_set):
        _ecog_preprocessing(recordings, n_jobs=1)

    _, device = _device_and_seed()
    n_times = 1000
    # NOTE(review): assumes the last 5 channels are the target (misc) channels
    # kept by ``pick_types(misc=True)``; confirm against the dataset.
    n_chans = train_set[0][0].shape[0] - 5
    model = ShallowFBCSPNet(
        n_chans,
        1,
        n_times=n_times,
        final_conv_length=2,
    )
    n_preds_per_input = model.get_output_shape()[2]

    def _make_windows(recordings):
        """Cut ``n_times``-sample windows with stride ``n_preds_per_input``."""
        windows = create_fixed_length_windows(
            recordings,
            start_offset_samples=0,
            stop_offset_samples=None,
            window_size_samples=n_times,
            window_stride_samples=n_preds_per_input,
            drop_last_window=False,
            targets_from="channels",
            last_target_only=False,
            preload=False,
        )
        # Keep only the first target channel.
        windows.target_transform = lambda x: x[0:1]
        return windows

    train_windows = _make_windows(train_set)
    valid_windows = _make_windows(valid_set)
    test_windows = _make_windows(test_set)

    regressor = EEGRegressor(
        model,
        cropped=True,
        aggregate_predictions=False,
        criterion=TimeSeriesLoss,
        criterion__loss_function=torch.nn.functional.mse_loss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_windows),
        optimizer__lr=0.0625 * 0.01,
        optimizer__weight_decay=0,
        iterator_train__shuffle=True,
        batch_size=27,
        callbacks=[
            (
                "lr_scheduler",
                LRScheduler("CosineAnnealingLR", T_max=max(1, epochs - 1)),
            ),
            (
                "r2_train",
                CroppedTimeSeriesEpochScoring(
                    r2_score,
                    lower_is_better=False,
                    on_train=True,
                    name="r2_train",
                ),
            ),
            (
                "r2_valid",
                CroppedTimeSeriesEpochScoring(
                    r2_score,
                    lower_is_better=False,
                    on_train=False,
                    name="r2_valid",
                ),
            ),
            ("early_stopping", EarlyStopping(patience=patience, load_best=True)),
        ],
        device=device,
    )
    regressor.fit(train_windows, y=None, epochs=epochs)

    preds_trials, y_trials = regressor.predict_trials(test_windows, return_targets=True)
    # ``_pad_and_select_predictions`` returns ``(targets, predictions)``; unpack
    # in that order so each name holds what it says. The previous
    # ``preds_test, y_test = ...`` swapped them. Pearson r is symmetric, so the
    # metric was unaffected, but any later use of the names would not be.
    y_test, preds_test = _pad_and_select_predictions(preds_trials, y_trials)
    corr_coeffs = [
        float(np.corrcoef(preds_test[:, dim], y_test[:, dim])[0, 1])
        for dim in range(y_test.shape[1])
    ]

    history = regressor.history
    metadata = {
        "best_valid_r2": float(max(history[:, "r2_valid"])),
        "display_metric_as_percentage": False,
        "display_metric_key": "test_mean_pearson_r",
        "display_metric_name": "mean Pearson r",
        "display_split_name": "held-out test session",
        "epochs_ran": len(history),
        "epochs_requested": epochs,
        "final_valid_r2": float(history[-1, "r2_valid"]),
        "patience": patience,
        "reference_recording_scope": "whole recording",
        "reference_uses_full_recordings": True,
        "short_run_epochs": 8,
        "subject_id": 1,
        "test_mean_pearson_r": float(np.nanmean(corr_coeffs)),
        "test_pearson_r_per_dim": corr_coeffs,
        "tutorial": tutorial_name,
        "valid_fraction": 0.2,
    }
    return TutorialArtifacts(
        clf=regressor,
        repo_id=tutorial_repo_id(tutorial_name),
        metadata=metadata,
    )
def _get_center_label(x):
    """Return the label of the central window of a window-label sequence.

    Integer labels are returned unchanged. Sequences longer than one return
    the element at index ``len(x) // 2`` (index 1 for three windows). Shorter
    sequences are returned as-is.
    """
    if isinstance(x, Integral):
        return x
    # ``len(x) // 2`` is the true middle for odd lengths. The previous
    # ``ceil(len(x) / 2)`` picked the element after the middle, which is the
    # last window for the 3-window sequences used here. Even lengths are
    # unchanged, since ceil(2k / 2) == 2k // 2.
    # NOTE(review): tutorials consuming these checkpoints should derive the
    # target label the same way.
    return x[len(x) // 2] if len(x) > 1 else x
def _sleep_physionet_windows(
    *,
    crop=None,
    picks=None,
    recording_ids=None,
    subject_ids=None,
    train_subject_ids=None,
    valid_subject_ids=None,
    raw_preprocessors: list[Preprocessor],
    window_preprocessors: list[Preprocessor] | None = None,
):
    """Load Sleep Physionet, cut 3000-sample windows and split by subject.

    Parameters
    ----------
    crop : optional
        Forwarded to ``SleepPhysionet``.
    picks : optional
        Channels forwarded to ``create_windows_from_events``.
    recording_ids : list of int, optional
        Recordings to load per subject; defaults to ``[2]``.
    subject_ids : list of int, optional
        Subjects to load; defaults to ``[0, 1]``.
    train_subject_ids, valid_subject_ids : list of int, optional
        Subject ids for each split. They default to the even/odd positions of
        ``subject_ids``. Every recording of a listed subject goes to that split.
    raw_preprocessors : list of Preprocessor
        Applied to the continuous recordings.
    window_preprocessors : list of Preprocessor, optional
        Applied to the windows when non-empty.

    Returns
    -------
    tuple
        ``(train_set, valid_set, train_sampler, valid_sampler)``. The samplers
        are ``SequenceSampler`` instances over 3 windows with stride 3; the
        training sampler is randomized.

    Raises
    ------
    ValueError
        If a split contains no recording.
    """
    if recording_ids is None:
        recording_ids = [2]
    if subject_ids is None:
        subject_ids = [0, 1]
    if train_subject_ids is None:
        train_subject_ids = subject_ids[::2]
    if valid_subject_ids is None:
        valid_subject_ids = subject_ids[1::2]
    dataset = SleepPhysionet(
        subject_ids=subject_ids,
        recording_ids=recording_ids,
        crop_wake_mins=30,
        crop=crop,
    )
    preprocess(dataset, raw_preprocessors, n_jobs=-1)
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=0,
        trial_stop_offset_samples=0,
        window_size_samples=30 * 100,
        window_stride_samples=30 * 100,
        picks=picks,
        preload=True,
        mapping=SLEEP_MAPPING,
    )
    if window_preprocessors:
        preprocess(windows_dataset, window_preprocessors, n_jobs=-1)

    # ``BaseConcatDataset.split`` with a dict expects *dataset (recording)
    # indices*, not subject ids. The two only coincide with one recording per
    # subject and ids starting at 0, so map subject ids to the indices of all
    # their recordings via the description table.
    subject_of_recording = [int(s) for s in windows_dataset.description["subject"]]
    wanted_subjects = {
        "train": set(train_subject_ids),
        "valid": set(valid_subject_ids),
    }
    split_indices = {
        split_name: [
            ds_idx
            for ds_idx, subject in enumerate(subject_of_recording)
            if subject in subjects
        ]
        for split_name, subjects in wanted_subjects.items()
    }
    empty_splits = [name for name, indices in split_indices.items() if not indices]
    if empty_splits:
        raise ValueError(
            f"No recordings found for split(s) {empty_splits}; check the "
            f"train/valid subject ids against subject_ids={subject_ids}."
        )
    splits = windows_dataset.split(split_indices)
    train_set, valid_set = splits["train"], splits["valid"]
    train_sampler = SequenceSampler(train_set.get_metadata(), 3, 3, randomize=True)
    valid_sampler = SequenceSampler(valid_set.get_metadata(), 3, 3)
    return train_set, valid_set, train_sampler, valid_sampler
def _balanced_accuracy_multi(model, X, y):
    """skorch scorer: balanced accuracy over flattened sequence labels.

    Targets and predictions are flattened, so every window of every predicted
    sequence counts as one sample.
    """
    predicted = model.predict(X)
    y_true = y.reshape(-1)
    y_hat = predicted.reshape(-1)
    return balanced_accuracy_score(y_true, y_hat)
def _sleep_usleep(*, epochs: int, patience: int) -> TutorialArtifacts:
    """Train U-Sleep on cropped Sleep Physionet data (default subjects/recordings).

    Sequences of windows come from ``SequenceSampler``. The loss uses
    balanced class weights computed from the middle label of each training
    sequence, and balanced accuracy is tracked on train and validation.

    Returns
    -------
    TutorialArtifacts
        Fitted classifier, repo id and run metadata.
    """
    tutorial_name = "plot_sleep_staging_usleep"
    train_set, valid_set, train_sampler, valid_sampler = _sleep_physionet_windows(
        crop=(0, 30 * 400),
        raw_preprocessors=[Preprocessor(robust_scale, channel_wise=True)],
    )
    # Label at position 1 of each sampled sequence.
    y_train = [train_set[seq_idx][1][1] for seq_idx in train_sampler]
    class_weights = compute_class_weight(
        "balanced", classes=np.unique(y_train), y=y_train
    )

    _, device = _device_and_seed(seed=31)
    n_classes = 5
    n_chans, n_times = train_set[0][0].shape
    model = USleep(
        n_chans=n_chans,
        sfreq=100,
        depth=12,
        with_skip_connection=True,
        n_outputs=n_classes,
        n_times=n_times,
    )
    if device != "cpu":
        model.to(device)

    bal_acc_scorers = [
        (
            scorer_name,
            EpochScoring(
                scoring=_balanced_accuracy_multi,
                on_train=on_train,
                name=scorer_name,
                lower_is_better=False,
            ),
        )
        for scorer_name, on_train in (("train_bal_acc", True), ("valid_bal_acc", False))
    ]
    callbacks = [
        *bal_acc_scorers,
        ("early_stopping", EarlyStopping(patience=patience, load_best=True)),
    ]
    clf = EEGClassifier(
        model,
        criterion=torch.nn.CrossEntropyLoss,
        criterion__weight=torch.Tensor(class_weights).to(device),
        optimizer=torch.optim.Adam,
        iterator_train__shuffle=False,
        iterator_train__sampler=train_sampler,
        iterator_valid__sampler=valid_sampler,
        train_split=predefined_split(valid_set),
        optimizer__lr=1e-3,
        batch_size=32,
        callbacks=callbacks,
        device=device,
        classes=list(range(n_classes)),
    )
    # Disable the inherited ``valid_acc`` callback; the sequence-aware
    # balanced-accuracy scorers above are used instead.
    clf.set_params(callbacks__valid_acc=None)
    clf.fit(train_set, y=None, epochs=epochs)

    history = clf.history
    metadata = {
        "best_valid_bal_acc": max(history[:, "valid_bal_acc"]),
        "chance_level": 0.20,
        "display_metric_key": "best_valid_bal_acc",
        "display_metric_name": "balanced accuracy",
        "display_split_name": "held-out recording",
        "epochs_ran": len(history),
        "epochs_requested": epochs,
        "final_valid_bal_acc": history[-1, "valid_bal_acc"],
        "patience": patience,
        "short_run_epochs": 3,
        "use_safetensors": True,
        "tutorial": tutorial_name,
    }
    return TutorialArtifacts(
        clf=clf,
        repo_id=tutorial_repo_id(tutorial_name),
        metadata=metadata,
    )
def _sleep_attnsleep(*, epochs: int, patience: int) -> TutorialArtifacts:
    """Train AttnSleep on 3-window Sleep Physionet sequences (channel Fpz-Cz).

    Subjects 0-3 are used for training and subjects 4-5 for validation
    (recordings 1 and 2). The target of each sequence is chosen by
    ``_get_center_label``. When W&B logging is enabled, the run is always
    finished, even if training raises.

    Returns
    -------
    TutorialArtifacts
        Fitted classifier, repo id and run metadata (plus ``wandb_run_url``
        when W&B logging is enabled).
    """
    tutorial_name = "plot_sleep_staging_eldele2021"
    # The split and hyper-parameters are defined once and reused for training
    # and for the W&B config, so the logged values cannot drift from the ones
    # actually used.
    recording_ids = [1, 2]
    train_subject_ids = [0, 1, 2, 3]
    valid_subject_ids = [4, 5]
    sfreq = 100
    drop_prob = 0.3
    lr = 1e-3
    weight_decay = 1e-3
    label_smoothing = 0.1
    batch_size = 32

    train_set, valid_set, train_sampler, valid_sampler = _sleep_physionet_windows(
        picks="Fpz-Cz",
        recording_ids=recording_ids,
        subject_ids=train_subject_ids + valid_subject_ids,
        train_subject_ids=train_subject_ids,
        valid_subject_ids=valid_subject_ids,
        raw_preprocessors=[
            Preprocessor(lambda data: multiply(data, 1e6), apply_on_array=True),
            Preprocessor("filter", l_freq=None, h_freq=30),
        ],
        window_preprocessors=[Preprocessor(standard_scale, channel_wise=True)],
    )
    train_set.target_transform = _get_center_label
    valid_set.target_transform = _get_center_label
    y_train = [train_set[idx][1] for idx in train_sampler]
    class_weights = compute_class_weight(
        "balanced", classes=np.unique(y_train), y=y_train
    )

    _, device = _device_and_seed(seed=31)
    n_classes = 5
    feat_extractor = AttnSleep(
        sfreq=sfreq,
        n_outputs=n_classes,
        n_times=train_set[0][0].shape[1],
        drop_prob=drop_prob,
        return_feats=True,
    )
    # Features of the 3 windows in a sequence are concatenated and classified.
    model = nn.Sequential(
        TimeDistributed(feat_extractor),
        nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Dropout(0.5),
            nn.Linear(feat_extractor.len_last_layer * 3, n_classes),
        ),
    )
    if device != "cpu":
        model.to(device)
    callbacks = [
        (
            "train_bal_acc",
            EpochScoring(
                scoring="balanced_accuracy",
                on_train=True,
                name="train_bal_acc",
                lower_is_better=False,
            ),
        ),
        (
            "valid_bal_acc",
            EpochScoring(
                scoring="balanced_accuracy",
                on_train=False,
                name="valid_bal_acc",
                lower_is_better=False,
            ),
        ),
        (
            "lr_scheduler",
            LRScheduler("CosineAnnealingLR", T_max=max(1, epochs - 1)),
        ),
        (
            "grad_clip",
            GradientNormClipping(gradient_clip_value=1.0),
        ),
        (
            "early_stopping",
            EarlyStopping(
                monitor="valid_bal_acc",
                lower_is_better=False,
                patience=patience,
                load_best=True,
            ),
        ),
    ]
    wandb_result = _make_wandb_callback(
        tutorial_name,
        config={
            "model": "AttnSleep",
            "sfreq": sfreq,
            "drop_prob": drop_prob,
            "n_windows": 3,
            "lr": lr,
            "weight_decay": weight_decay,
            "label_smoothing": label_smoothing,
            "batch_size": batch_size,
            "epochs": epochs,
            "patience": patience,
            "train_subjects": list(train_subject_ids),
            "valid_subjects": list(valid_subject_ids),
            "recording_ids": list(recording_ids),
        },
    )
    wandb_run = None
    if wandb_result is not None:
        wandb_callback, wandb_run = wandb_result
        callbacks.append(wandb_callback)

    try:
        clf = EEGClassifier(
            model,
            criterion=torch.nn.CrossEntropyLoss,
            criterion__weight=torch.Tensor(class_weights).to(device),
            criterion__label_smoothing=label_smoothing,
            optimizer=torch.optim.Adam,
            iterator_train__shuffle=False,
            iterator_train__sampler=train_sampler,
            iterator_valid__sampler=valid_sampler,
            train_split=predefined_split(valid_set),
            optimizer__lr=lr,
            optimizer__weight_decay=weight_decay,
            batch_size=batch_size,
            callbacks=callbacks,
            device=device,
            classes=np.unique(y_train),
        )
        clf.fit(train_set, y=None, epochs=epochs)
        history = clf.history
        metadata = {
            "best_valid_bal_acc": max(history[:, "valid_bal_acc"]),
            "chance_level": 0.20,
            "display_metric_key": "best_valid_bal_acc",
            "display_metric_name": "balanced accuracy",
            "display_split_name": "held-out recording",
            "epochs_ran": len(history),
            "epochs_requested": epochs,
            "final_valid_bal_acc": history[-1, "valid_bal_acc"],
            "patience": patience,
            "short_run_epochs": 3,
            "use_safetensors": False,
            "tutorial": tutorial_name,
        }
        if wandb_run is not None:
            metadata["wandb_run_url"] = wandb_run.get_url()
    finally:
        # Always close the W&B run, even if building or fitting the model failed.
        if wandb_run is not None:
            import wandb  # isort: skip

            wandb.finish()
    return TutorialArtifacts(
        clf=clf,
        repo_id=tutorial_repo_id(tutorial_name),
        metadata=metadata,
    )
def _sleep_chambon(*, epochs: int, patience: int) -> TutorialArtifacts:
    """Train the Chambon et al. (2018) sleep stager on 3-window sequences.

    Uses the default Sleep Physionet subjects/recordings of
    ``_sleep_physionet_windows``. The per-window feature extractor is wrapped
    in ``TimeDistributed`` and followed by a linear classifier over the
    concatenated features. Targets are chosen by ``_get_center_label``.

    Returns
    -------
    TutorialArtifacts
        Fitted classifier, repo id and run metadata.
    """
    tutorial_name = "plot_sleep_staging_chambon2018"
    train_set, valid_set, train_sampler, valid_sampler = _sleep_physionet_windows(
        raw_preprocessors=[
            Preprocessor(lambda data: multiply(data, 1e6), apply_on_array=True),
            Preprocessor("filter", l_freq=None, h_freq=30),
        ],
        window_preprocessors=[Preprocessor(standard_scale, channel_wise=True)],
    )
    for split in (train_set, valid_set):
        split.target_transform = _get_center_label
    y_train = [train_set[idx][1] for idx in train_sampler]
    class_weights = compute_class_weight(
        "balanced", classes=np.unique(y_train), y=y_train
    )

    _, device = _device_and_seed(seed=31)
    n_classes = 5
    n_channels, n_times = train_set[0][0].shape
    feat_extractor = SleepStagerChambon2018(
        n_channels,
        100,
        n_outputs=n_classes,
        n_times=n_times,
        return_feats=True,
    )
    classifier_head = nn.Sequential(
        nn.Flatten(start_dim=1),
        nn.Dropout(0.5),
        nn.Linear(feat_extractor.len_last_layer * 3, n_classes),
    )
    model = nn.Sequential(TimeDistributed(feat_extractor), classifier_head)
    if device != "cpu":
        model.to(device)

    bal_acc_scorers = [
        (
            scorer_name,
            EpochScoring(
                scoring="balanced_accuracy",
                on_train=on_train,
                name=scorer_name,
                lower_is_better=False,
            ),
        )
        for scorer_name, on_train in (("train_bal_acc", True), ("valid_bal_acc", False))
    ]
    callbacks = [
        *bal_acc_scorers,
        ("early_stopping", EarlyStopping(patience=patience, load_best=True)),
    ]
    clf = EEGClassifier(
        model,
        criterion=torch.nn.CrossEntropyLoss,
        criterion__weight=torch.Tensor(class_weights).to(device),
        optimizer=torch.optim.Adam,
        iterator_train__shuffle=False,
        iterator_train__sampler=train_sampler,
        iterator_valid__sampler=valid_sampler,
        train_split=predefined_split(valid_set),
        optimizer__lr=1e-3,
        batch_size=32,
        callbacks=callbacks,
        device=device,
        classes=np.unique(y_train),
    )
    clf.fit(train_set, y=None, epochs=epochs)

    history = clf.history
    metadata = {
        "best_valid_bal_acc": max(history[:, "valid_bal_acc"]),
        "chance_level": 0.20,
        "display_metric_key": "best_valid_bal_acc",
        "display_metric_name": "balanced accuracy",
        "display_split_name": "held-out recording",
        "epochs_ran": len(history),
        "epochs_requested": epochs,
        "final_valid_bal_acc": history[-1, "valid_bal_acc"],
        "patience": patience,
        "short_run_epochs": 10,
        "tutorial": tutorial_name,
    }
    return TutorialArtifacts(
        clf=clf,
        repo_id=tutorial_repo_id(tutorial_name),
        metadata=metadata,
    )
def _cropped_shallow(
    *, subject_id: int, epochs: int, patience: int
) -> TutorialArtifacts:
    """Train ShallowFBCSPNet with cropped decoding on one BNCI2014_001 subject.

    The model is turned into a dense-prediction model, and windows of 1000
    samples are cut with a stride equal to the number of predictions per
    window. Session ``"0train"`` trains and ``"1test"`` validates.

    Returns
    -------
    TutorialArtifacts
        Fitted classifier, repo id and run metadata.
    """
    tutorial_name = "plot_bcic_iv_2a_moabb_cropped"
    n_times = 1000
    n_classes = 4
    bci_data = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])
    _common_preprocessing(bci_data, n_jobs=-1)
    _, device = _device_and_seed()

    model = ShallowFBCSPNet(22, n_classes, n_times=n_times, final_conv_length=30)
    # The network then emits several predictions per input window.
    model.to_dense_prediction_model()
    n_preds_per_input = model.get_output_shape()[2]

    start_offset = int(-0.5 * bci_data.datasets[0].raw.info["sfreq"])
    by_session = create_windows_from_events(
        bci_data,
        trial_start_offset_samples=start_offset,
        trial_stop_offset_samples=0,
        window_size_samples=n_times,
        window_stride_samples=n_preds_per_input,
        drop_last_window=False,
        preload=True,
    ).split("session")
    train_set, valid_set = by_session["0train"], by_session["1test"]

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.cross_entropy,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=0.0625 * 0.01,
        optimizer__weight_decay=0,
        iterator_train__shuffle=True,
        batch_size=64,
        callbacks=_training_callbacks(epochs, patience),
        device=device,
        classes=list(range(n_classes)),
    )
    clf.fit(train_set, y=None, epochs=epochs)

    history = clf.history
    metadata = {
        "best_valid_accuracy": max(history[:, "valid_accuracy"]),
        "chance_level": 0.25,
        "display_metric_key": "best_valid_accuracy",
        "display_metric_name": "accuracy",
        "display_split_name": "held-out session",
        "epochs_ran": len(history),
        "epochs_requested": epochs,
        "final_valid_accuracy": history[-1, "valid_accuracy"],
        "patience": patience,
        "short_run_epochs": 2,
        "subject_id": subject_id,
        "tutorial": tutorial_name,
    }
    return TutorialArtifacts(
        clf=clf,
        repo_id=tutorial_repo_id(tutorial_name),
        metadata=metadata,
    )
def _eegprep_eegnex(
    *, subject_id: int, epochs: int, patience: int
) -> TutorialArtifacts:
    """Train EEGNeX trial-wise on EEGPrep-cleaned BNCI2014_001 data.

    Same split as the trial-wise tutorial (``"0train"`` trains, ``"1test"``
    validates); preprocessing additionally runs EEGPrep.

    Returns
    -------
    TutorialArtifacts
        Fitted classifier, repo id and run metadata.
    """
    tutorial_name = "plot_bcic_iv_2a_eegprep_cleaning"
    n_classes = 4
    bci_data = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])
    _common_preprocessing(bci_data, use_eegprep=True, n_jobs=-1)
    by_session = _trialwise_windows(bci_data).split("session")
    train_set, valid_set = by_session["0train"], by_session["1test"]

    _, device = _device_and_seed()
    sample_shape = train_set[0][0].shape
    model = EEGNeX(
        n_chans=sample_shape[0],
        n_outputs=n_classes,
        n_times=sample_shape[1],
    )
    clf = EEGClassifier(
        model,
        criterion=torch.nn.CrossEntropyLoss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=1e-3,
        optimizer__weight_decay=0,
        batch_size=64,
        callbacks=_training_callbacks(epochs, patience),
        device=device,
        classes=list(range(n_classes)),
    )
    clf.fit(train_set, y=None, epochs=epochs)

    history = clf.history
    metadata = {
        "best_valid_accuracy": max(history[:, "valid_accuracy"]),
        "chance_level": 0.25,
        "display_metric_key": "best_valid_accuracy",
        "display_metric_name": "accuracy",
        "display_split_name": "held-out session",
        "epochs_ran": len(history),
        "epochs_requested": epochs,
        "final_valid_accuracy": history[-1, "valid_accuracy"],
        "patience": patience,
        "short_run_epochs": 4,
        "subject_id": subject_id,
        "tutorial": tutorial_name,
    }
    return TutorialArtifacts(
        clf=clf,
        repo_id=tutorial_repo_id(tutorial_name),
        metadata=metadata,
    )
| def _make_search_candidate( | |
| transform, | |
| *, | |
| augmentation: str, | |
| magnitude: float, | |
| display_magnitude: float, | |
| axis_label: str, | |
| candidate_label: str, | |
| sort_order: int, | |
| ): | |
| transform._tutorial_candidate_label = candidate_label | |
| transform._tutorial_augmentation = augmentation | |
| transform._tutorial_magnitude = magnitude | |
| transform._tutorial_display_magnitude = display_magnitude | |
| transform._tutorial_axis_label = axis_label | |
| transform._tutorial_sort_order = sort_order | |
| return transform | |
def _data_augmentation_candidates(sfreq: float, seed: int):
    """Build the tagged augmentation candidates explored by the grid search.

    The list starts with an ``IdentityTransform`` baseline, followed by five
    magnitudes each of ``FTSurrogate``, ``SmoothTimeMask`` and
    ``ChannelsDropout``. Every transform is applied with probability 0.5 and
    seeded with ``seed``.

    Parameters
    ----------
    sfreq : float
        Sampling frequency, used to express mask lengths in seconds for
        display.
    seed : int
        ``random_state`` given to every stochastic transform.

    Returns
    -------
    list
        Transforms tagged via ``_make_search_candidate``.
    """
    baseline = _make_search_candidate(
        IdentityTransform(),
        augmentation="IdentityTransform",
        magnitude=0.0,
        display_magnitude=0.0,
        axis_label="Identity baseline",
        candidate_label="IdentityTransform()",
        sort_order=0,
    )
    surrogates = [
        _make_search_candidate(
            FTSurrogate(
                probability=0.5,
                phase_noise_magnitude=noise,
                random_state=seed,
            ),
            augmentation="FTSurrogate",
            magnitude=noise,
            display_magnitude=noise,
            axis_label="Phase noise magnitude",
            candidate_label=f"FTSurrogate(phase_noise_magnitude={noise:.1f})",
            sort_order=1,
        )
        for noise in (0.1, 0.3, 0.5, 0.7, 0.9)
    ]
    # Mask lengths are searched in samples but plotted in seconds.
    time_masks = [
        _make_search_candidate(
            SmoothTimeMask(
                probability=0.5,
                mask_len_samples=n_samples,
                random_state=seed,
            ),
            augmentation="SmoothTimeMask",
            magnitude=n_samples,
            display_magnitude=n_samples / sfreq,
            axis_label="Mask length (s)",
            candidate_label=f"SmoothTimeMask(mask_len_samples={n_samples})",
            sort_order=2,
        )
        for n_samples in (100, 200, 300, 400, 500)
    ]
    channel_drops = [
        _make_search_candidate(
            ChannelsDropout(probability=0.5, p_drop=drop, random_state=seed),
            augmentation="ChannelsDropout",
            magnitude=drop,
            display_magnitude=drop,
            axis_label="Drop probability",
            candidate_label=f"ChannelsDropout(p_drop={drop:.1f})",
            sort_order=3,
        )
        for drop in (0.2, 0.4, 0.6, 0.8, 1.0)
    ]
    return [baseline, *surrogates, *time_masks, *channel_drops]
| def _data_augmentation_search_table(cv_results: dict) -> pd.DataFrame: | |
| rows = [] | |
| for index, params in enumerate(cv_results["params"]): | |
| transform = params["iterator_train__transforms"] | |
| rows.append( | |
| { | |
| "candidate_label": transform._tutorial_candidate_label, | |
| "augmentation": transform._tutorial_augmentation, | |
| "magnitude": transform._tutorial_magnitude, | |
| "display_magnitude": transform._tutorial_display_magnitude, | |
| "axis_label": transform._tutorial_axis_label, | |
| "sort_order": transform._tutorial_sort_order, | |
| "mean_training_accuracy": float(cv_results["mean_train_score"][index]), | |
| "std_training_accuracy": float(cv_results["std_train_score"][index]), | |
| "mean_validation_accuracy": float(cv_results["mean_test_score"][index]), | |
| "std_validation_accuracy": float(cv_results["std_test_score"][index]), | |
| "rank_validation_accuracy": int(cv_results["rank_test_score"][index]), | |
| } | |
| ) | |
| search_results = pd.DataFrame(rows).sort_values(["sort_order", "display_magnitude"]) | |
| identity_validation_score = float( | |
| search_results.loc[ | |
| search_results["augmentation"] == "IdentityTransform", | |
| "mean_validation_accuracy", | |
| ].iloc[0] | |
| ) | |
| identity_training_score = float( | |
| search_results.loc[ | |
| search_results["augmentation"] == "IdentityTransform", | |
| "mean_training_accuracy", | |
| ].iloc[0] | |
| ) | |
| search_results["relative_validation_improvement"] = ( | |
| search_results["mean_validation_accuracy"] / identity_validation_score - 1 | |
| ) | |
| search_results["relative_training_improvement"] = ( | |
| search_results["mean_training_accuracy"] / identity_training_score - 1 | |
| ) | |
| search_results["relative_validation_improvement_pct"] = ( | |
| search_results["relative_validation_improvement"] * 100 | |
| ) | |
| search_results["relative_training_improvement_pct"] = ( | |
| search_results["relative_training_improvement"] * 100 | |
| ) | |
| return search_results.reset_index(drop=True) | |
def _data_augmentation_search(*, subject_id: int, epochs: int, patience: int):
    """Run the offline augmentation grid search for ``plot_data_augmentation_search``.

    A ``ShallowFBCSPNet`` classifier is cross-validated with a shuffled
    2-fold ``KFold`` on the ``"0train"`` session, once per candidate from
    ``_data_augmentation_candidates``. GridSearchCV refits the best candidate
    and scores it on the held-out ``"1test"`` session.

    Parameters
    ----------
    subject_id : int
        Subject loaded from MOABB's ``BNCI2014_001`` dataset.
    epochs : int
        Maximum number of epochs per fit. It also sets the cosine
        schedule's ``T_max``.
    patience : int
        Early-stopping patience.

    Returns
    -------
    tuple
        ``(repo_id, metadata, search_results)``. ``search_results`` is the
        table built by ``_data_augmentation_search_table``.
    """
    tutorial_name = "plot_data_augmentation_search"
    dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])
    _common_preprocessing(dataset, n_jobs=-1)
    windows_dataset = _trialwise_windows(dataset)
    splitted = windows_dataset.split("session")
    train_set = splitted["0train"]
    eval_set = splitted["1test"]
    seed = 20200220
    n_cv_splits = 2
    sfreq = dataset.datasets[0].raw.info["sfreq"]
    search_candidates = _data_augmentation_candidates(sfreq, seed)
    _, device = _device_and_seed(seed=seed)
    model = ShallowFBCSPNet(
        n_chans=train_set[0][0].shape[0],
        n_outputs=4,
        n_times=train_set[0][0].shape[1],
        final_conv_length="auto",
    )
    if device != "cpu":
        model.to(device)
    clf = EEGClassifier(
        model,
        iterator_train=AugmentedDataLoader,
        # Placeholder; GridSearchCV overrides it with each candidate.
        iterator_train__transforms=[IdentityTransform()],
        criterion=torch.nn.CrossEntropyLoss,
        optimizer=torch.optim.AdamW,
        train_split=ValidSplit(0.2, stratified=True, random_state=seed),
        optimizer__lr=0.0625 * 0.01,
        optimizer__weight_decay=0,
        batch_size=64,
        callbacks=[
            "accuracy",
            (
                "lr_scheduler",
                LRScheduler("CosineAnnealingLR", T_max=max(1, epochs - 1)),
            ),
            ("early_stopping", EarlyStopping(patience=patience, load_best=True)),
        ],
        device=device,
        classes=list(range(4)),
    )
    clf.verbose = 0
    train_X = SliceDataset(train_set, idx=0)
    train_y = np.array(list(SliceDataset(train_set, idx=1)))
    search = GridSearchCV(
        estimator=clf,
        param_grid={"iterator_train__transforms": search_candidates},
        cv=KFold(n_splits=n_cv_splits, shuffle=True, random_state=seed),
        n_jobs=-1,
        return_train_score=True,
        scoring="accuracy",
        refit=True,
        verbose=1,
        error_score="raise",
    )
    # The candidate fits run on joblib's threading backend (in-process threads).
    with parallel_backend("threading", n_jobs=-1):
        search.fit(train_X, train_y, epochs=epochs)
    search_results = _data_augmentation_search_table(search.cv_results_)
    # Report the candidate GridSearchCV actually refit (best_params_) instead
    # of re-sorting the table. With tied mean validation scores an unstable
    # sort could name a different candidate than the estimator that produces
    # eval_accuracy below. Candidate labels are unique per candidate.
    best_label = search.best_params_[
        "iterator_train__transforms"
    ]._tutorial_candidate_label
    best_run = search_results.loc[
        search_results["candidate_label"] == best_label
    ].iloc[0]
    identity_validation_score = float(
        search_results.loc[
            search_results["augmentation"] == "IdentityTransform",
            "mean_validation_accuracy",
        ].iloc[0]
    )
    # search.score delegates to the refit best estimator.
    eval_accuracy = float(
        search.score(SliceDataset(eval_set, idx=0), SliceDataset(eval_set, idx=1))
    )
    metadata = {
        "best_augmentation": best_run["augmentation"],
        "best_candidate": best_run["candidate_label"],
        "best_magnitude": float(best_run["magnitude"]),
        "best_relative_validation_improvement": float(
            best_run["relative_validation_improvement"]
        ),
        "chance_level": 0.25,
        "cv_splits": n_cv_splits,
        "display_metric_key": "eval_accuracy",
        "display_metric_name": "accuracy",
        "display_split_name": "held-out session",
        "epochs_requested": epochs,
        "eval_accuracy": eval_accuracy,
        "identity_validation_score": identity_validation_score,
        "patience": patience,
        "search_candidates": len(search_candidates),
        "search_magnitudes_per_augmentation": 5,
        "short_run_epochs": 2,
        "training_score": float(best_run["mean_training_accuracy"]),
        "tutorial": tutorial_name,
        "validation_score": float(best_run["mean_validation_accuracy"]),
    }
    return tutorial_repo_id(tutorial_name), metadata, search_results
def train_tutorial(
    tutorial_name: str,
    *,
    subject_id: int,
    epochs: int,
    patience: int,
) -> TutorialArtifacts:
    """Train the checkpoint for ``tutorial_name`` and return its artifacts.

    BCIC IV 2a tutorials receive ``subject_id``. The ECoG and sleep-staging
    tutorials receive only ``epochs`` and ``patience``.
    ``plot_data_augmentation_search`` has no entry here, because ``main``
    runs it separately through ``_data_augmentation_search``.

    Raises
    ------
    ValueError
        If ``tutorial_name`` has no registered trainer.
    """
    with_subject = {"subject_id": subject_id, "epochs": epochs, "patience": patience}
    schedule_only = {"epochs": epochs, "patience": patience}
    # Lambdas keep dispatch lazy, so only the selected trainer is invoked.
    trainers = {
        "plot_bcic_iv_2a_moabb_trial": lambda: _trialwise_shallow(
            tutorial_name, **with_subject
        ),
        "plot_bcic_iv_2a_moabb_cropped": lambda: _cropped_shallow(**with_subject),
        "plot_bcic_iv_2a_eegprep_cleaning": lambda: _eegprep_eegnex(**with_subject),
        "bcic_iv_4_ecog_trial": lambda: _ecog_trialwise(**schedule_only),
        "bcic_iv_4_ecog_cropped": lambda: _ecog_cropped(**schedule_only),
        "plot_sleep_staging_usleep": lambda: _sleep_usleep(**schedule_only),
        "plot_sleep_staging_eldele2021": lambda: _sleep_attnsleep(**schedule_only),
        "plot_sleep_staging_chambon2018": lambda: _sleep_chambon(**schedule_only),
    }
    trainer = trainers.get(tutorial_name)
    if trainer is None:
        raise ValueError(f"Unsupported tutorial: {tutorial_name}")
    return trainer()
def _tutorial_example_path(tutorial_name: str) -> str:
    """Return the repo-relative path of the tutorial's example script.

    Searches ``REPO_ROOT / "examples"`` recursively for
    ``<tutorial_name>.py``. If several files match, the lexicographically
    first path is used so the result does not depend on filesystem
    iteration order. The path is rendered with forward slashes on every
    platform because it is embedded in Markdown READMEs.

    Returns
    -------
    str
        ``examples/.../<tutorial_name>.py``. Falls back to
        ``examples/<tutorial_name>.py`` when nothing matches.
    """
    matches = sorted((REPO_ROOT / "examples").rglob(f"{tutorial_name}.py"))
    if not matches:
        return f"examples/{tutorial_name}.py"
    return matches[0].relative_to(REPO_ROOT).as_posix()
def _build_readme(artifacts: TutorialArtifacts) -> str:
    """Render the README.md text for a tutorial checkpoint repo.

    The title is the last ``/``-separated component of
    ``artifacts.repo_id``. The tutorial script path is resolved from
    ``artifacts.metadata["tutorial"]``.
    """
    example_path = _tutorial_example_path(artifacts.metadata["tutorial"])
    repo_name = artifacts.repo_id.split("/")[-1]
    # NOTE(review): this always advertises `params.safetensors`, but main()
    # can save with use_safetensors=False (e.g. for
    # plot_sleep_staging_eldele2021). Confirm which file
    # save_tutorial_checkpoint writes in that case.
    pieces = [
        f"# {repo_name}\n\n",
        "Pretrained artifacts for the Braindecode tutorial ",
        f"`{example_path}`.\n\n",
        "These files are meant to be loaded by the tutorial so the docs can ",
        "show stable predictions without retraining the model from scratch.\n\n",
        "## Stored files\n\n",
        "- `params.safetensors`: classifier parameters\n",
        "- `history.json`: Skorch training history used by the tutorial plots\n",
        "- `metadata.json`: summary metrics for the stored checkpoint\n",
    ]
    return "".join(pieces)
def _build_search_readme(repo_id: str, tutorial_name: str) -> str:
    """Render the README.md text for the augmentation-search results repo.

    The title is the last ``/``-separated component of ``repo_id``. The
    listed results file name comes from ``SEARCH_RESULTS_FILENAME``.
    """
    example_path = _tutorial_example_path(tutorial_name)
    repo_name = repo_id.split("/")[-1]
    pieces = [
        f"# {repo_name}\n\n",
        "Saved search results for the Braindecode tutorial ",
        f"`{example_path}`.\n\n",
        "These files are meant to be loaded by the tutorial so the docs can ",
        "plot the offline augmentation search without rerunning the full ",
        "GridSearchCV procedure.\n\n",
        "## Stored files\n\n",
        f"- `{SEARCH_RESULTS_FILENAME}`: tidy cross-validation search summary\n",
        "- `metadata.json`: summary metrics for the saved search\n",
    ]
    return "".join(pieces)
def _save_search_artifacts(
    output_dir: Path,
    *,
    search_results: pd.DataFrame,
    metadata: dict,
    readme_text: str,
) -> None:
    """Write the search table, metadata and README into ``output_dir``.

    Creates ``output_dir`` (and its parents) if needed. Writes the table as
    CSV without the index to ``SEARCH_RESULTS_FILENAME``, the metadata as
    key-sorted, 2-space-indented JSON with a trailing newline to
    ``metadata.json``, and ``readme_text`` verbatim to ``README.md``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    table_path = output_dir / SEARCH_RESULTS_FILENAME
    metadata_path = output_dir / "metadata.json"
    readme_path = output_dir / "README.md"
    search_results.to_csv(table_path, index=False)
    serialized = json.dumps(metadata, indent=2, sort_keys=True)
    metadata_path.write_text(f"{serialized}\n")
    readme_path.write_text(readme_text)
def _parse_args():
    """Parse and validate the command-line options of the trainer script.

    Returns
    -------
    argparse.Namespace
        Parsed options. ``epochs`` and ``patience`` are guaranteed to be at
        least 1.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Train Braindecode tutorial checkpoints and optionally push them to "
            "Hugging Face Hub."
        )
    )
    parser.add_argument(
        "--tutorial",
        choices=("all",) + AVAILABLE_TUTORIALS,
        default="all",
        help="Tutorial checkpoint to train.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="Maximum number of epochs to train.",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=10,
        help="Early stopping patience.",
    )
    parser.add_argument(
        "--subject-id",
        type=int,
        default=3,
        help="BCIC IV 2a subject id used by the tutorials.",
    )
    parser.add_argument(
        "--output-root",
        type=Path,
        default=Path("tutorial_artifacts"),
        help="Directory where the generated artifacts are written.",
    )
    parser.add_argument(
        "--push",
        action="store_true",
        help="Upload the generated artifacts to Hugging Face Hub.",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Create private Hugging Face repos when pushing.",
    )
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="Log training metrics to Weights & Biases.",
    )
    parser.add_argument(
        "--wandb-project",
        type=str,
        default="braindecode-tutorials",
        help="Weights & Biases project name.",
    )
    args = parser.parse_args()
    # The trainers summarise clf.history with max(...) and history[-1, ...],
    # which fail on an empty history. Reject bad values here instead of
    # after the (slow) dataset download and preprocessing.
    if args.epochs < 1:
        parser.error(f"--epochs must be a positive integer, got {args.epochs}")
    # Waiting zero or fewer epochs for an improvement is not a usable
    # early-stopping window.
    if args.patience < 1:
        parser.error(f"--patience must be a positive integer, got {args.patience}")
    return args
def _maybe_push_tutorial_artifacts(args, repo_id, output_dir):
    """Upload ``output_dir`` to the Hub repo ``repo_id`` when ``--push`` is set."""
    if not args.push:
        return
    url = upload_tutorial_artifacts(
        repo_id=repo_id,
        artifact_dir=output_dir,
        private=args.private,
    )
    print(f"Pushed artifacts to {url}")


def _export_search_tutorial(args, tutorial_name, output_dir):
    """Run the augmentation search, save its table/metadata/README, maybe push."""
    repo_id, metadata, search_results = _data_augmentation_search(
        subject_id=args.subject_id,
        epochs=args.epochs,
        patience=args.patience,
    )
    _save_search_artifacts(
        output_dir,
        search_results=search_results,
        metadata=metadata,
        readme_text=_build_search_readme(repo_id, tutorial_name),
    )
    print(f"Saved artifacts to {output_dir}")
    _maybe_push_tutorial_artifacts(args, repo_id, output_dir)


def _export_checkpoint_tutorial(args, tutorial_name, output_dir):
    """Train a tutorial classifier, save checkpoint and loss curve, maybe push."""
    artifacts = train_tutorial(
        tutorial_name,
        subject_id=args.subject_id,
        epochs=args.epochs,
        patience=args.patience,
    )
    # eldele2021 is always stored without safetensors. Other tutorials can
    # opt out through their metadata.
    if tutorial_name == "plot_sleep_staging_eldele2021":
        use_safetensors = False
    else:
        use_safetensors = artifacts.metadata.get("use_safetensors", True)
    save_tutorial_checkpoint(
        artifacts.clf,
        output_dir,
        metadata=artifacts.metadata,
        readme_text=_build_readme(artifacts),
        use_safetensors=use_safetensors,
    )
    _save_loss_curve(artifacts.clf, output_dir, tutorial_name)
    print(f"Saved artifacts to {output_dir}")
    _maybe_push_tutorial_artifacts(args, artifacts.repo_id, output_dir)


def main():
    """Command-line entry point: train (and optionally push) tutorial artifacts.

    Each selected tutorial writes into ``<output-root>/<tutorial_name>``.
    ``--wandb`` sets the module-level ``_WANDB_PROJECT`` before any training
    starts.
    """
    global _WANDB_PROJECT
    args = _parse_args()
    if args.wandb:
        _WANDB_PROJECT = args.wandb_project
    if args.tutorial == "all":
        tutorial_names = AVAILABLE_TUTORIALS
    else:
        tutorial_names = (args.tutorial,)
    for tutorial_name in tutorial_names:
        print(f"Training {tutorial_name}...")
        output_dir = args.output_root / tutorial_name
        # The augmentation search yields a results table, not a classifier
        # checkpoint, so it has its own export path.
        if tutorial_name == "plot_data_augmentation_search":
            _export_search_tutorial(args, tutorial_name, output_dir)
        else:
            _export_checkpoint_tutorial(args, tutorial_name, output_dir)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment