Created
September 20, 2017 17:31
-
-
Save kingjr/8e1467f122360d4efdc79d49c788b594 to your computer and use it in GitHub Desktop.
example generalization std single trial
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mne | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.pipeline import make_pipeline | |
from sklearn.model_selection import StratifiedKFold | |
from mne.decoding import SlidingEstimator | |
# Simulate a fake MEG dataset: Gaussian noise plus an SNR-dependent offset.
n_trials, n_channels, n_times, sfreq, max_snr = 200, 32, 50, 500.0, 10
# One integer SNR level per trial, drawn from 0..max_snr inclusive.
snr = np.random.randint(0, max_snr + 1, n_trials)
info = mne.create_info(n_channels, sfreq)
X = np.random.randn(n_trials, n_channels, n_times)
# Inject the signal: each trial gets its own SNR level added to the first
# 10 channels over the second half of the time axis. Broadcasting over the
# trial axis replaces an explicit loop over the distinct SNR levels.
late = slice(n_times // 2, None)
X[:, :10, late] += snr[:, np.newaxis, np.newaxis]
# Decoder: standardize the features, then apply a logistic regression.
# Z-scoring matters on real recordings, where gradiometers and
# magnetometers have different scales.
steps = (StandardScaler(), LogisticRegression())
clf = make_pipeline(*steps)
# Wrap the pipeline in a sliding estimator; n_jobs=-1 asks for the work
# to be run in parallel.
sliding = SlidingEstimator(clf, n_jobs=-1)
# Trials at the two extreme SNR levels (0 and max_snr) are the only ones
# used to train the classifiers.
train_cond = np.where(np.logical_or(snr == 0, snr == max_snr))[0]
# All remaining trials (intermediate SNR levels) are kept aside for
# generalization.
gen_cond = np.setdiff1d(np.arange(n_trials), train_cond)
# A single cross-validation scheme, applied separately to the training and
# generalization conditions. No random_state is passed: the folds are not
# shuffled, so they are already deterministic, and recent scikit-learn
# releases raise a ValueError when random_state is set with shuffle=False.
cv = StratifiedKFold(n_splits=10)
# Split each condition into folds. The generalization trials are never
# used for fitting, so they all share one dummy label and only their
# test folds are kept below.
cv_train = cv.split(X[train_cond], snr[train_cond])
cv_gen = cv.split(X[gen_cond], np.ones_like(snr[gen_cond]))
# Map the fold-local positions back to indices along the full trial axis.
trains, tests = zip(*[(train_cond[train], train_cond[test])
                      for train, test in cv_train])
gens = [gen_cond[test] for _, test in cv_gen]
# Cross-validated single-trial predictions: one row per trial, one column
# per time sample.
y_pred = np.zeros((n_trials, n_times))
for fold_train, fold_test, fold_gen in zip(trains, tests, gens):
    # Sanity check: the training fold only holds the two extreme SNR levels
    assert set(snr[fold_train]) == {0, max_snr}
    # Fit on this fold's training trials
    sliding.fit(X=X[fold_train], y=snr[fold_train])
    # Predict the held-out training-condition trials first, then the
    # generalization trials assigned to this fold
    for held_out in (fold_test, fold_gen):
        y_pred[held_out] = sliding.decision_function(X[held_out])
# Left panel: single-trial predictions, rows ordered by increasing SNR.
fig, (ax1, ax2) = plt.subplots(1, 2)
trial_order = np.argsort(snr)
ax1.matshow(y_pred[trial_order], aspect='auto')
# Right panel: for each SNR level, the standard deviation across trials of
# the prediction at the last time sample.
std = np.array([np.std(y_pred[snr == level, -1])
                for level in range(max_snr + 1)])
ax2.plot(std)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment