Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kingjr/7defd2b5c0841398cb68 to your computer and use it in GitHub Desktop.
Save kingjr/7defd2b5c0841398cb68 to your computer and use it in GitHub Desktop.
Why are probabilistic outputs better than non-probabilistic ones?
"""
==========================
Better with probabilities?
==========================
Comparing classification performance of SVC versus SVC+Platt
using an MEG example from MNE-python.
"""
# Authors: Jean-Remi King <jeanremi.king@gmail.com>
#
# License: BSD (3-clause)
import os

import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.datasets import sample
from mne.decoding import time_generalization
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, roc_auc_score
# Load and preprocess data
data_path = sample.data_path()
# os.path.join accepts both str and os.PathLike, so path construction works
# whether data_path() returns a plain string or a pathlib.Path (string
# concatenation with '+' would raise TypeError for a Path).
raw_fname = os.path.join(data_path, 'MEG', 'sample',
                         'sample_audvis_filt-0-40_raw.fif')
events_fname = os.path.join(data_path, 'MEG', 'sample',
                            'sample_audvis_filt-0-40_raw-eve.fif')
raw = mne.io.Raw(raw_fname, preload=True)
# Keep MEG channels only, excluding channels marked as bad.
picks = mne.pick_types(raw.info, meg=True, exclude='bads')
# Band-pass filter between 1 and 30 (MNE's l_freq / h_freq) using an IIR filter.
raw.filter(1, 30, method='iir')
events = mne.read_events(events_fname)
event_id = {'AudL': 1, 'VisL': 3, 'AudR': 2, 'VisR': 4}
# Epoch window -0.1 to 0.5, no baseline correction, decimated by 4, with
# magnetometer-amplitude rejection.
epochs = mne.Epochs(raw, events, event_id, -0.1, 0.5, proj=True,
                    picks=picks, baseline=None, preload=True,
                    reject=dict(mag=5e-12), decim=4)
# Only the two left-sided conditions are contrasted.
epochs_list = [epochs[k] for k in ['AudL', 'VisL']]
# Called for its side effect on the list items (return value is unused), so
# both classes contribute the same number of trials.
mne.epochs.equalize_epoch_counts(epochs_list)
# Decoding parameters
# NOTE: this single StandardScaler instance is placed in every pipeline that
# decod() builds.
scaler = StandardScaler()


def decod(svc, scorer):
    """Run temporal-generalization decoding on ``epochs_list``.

    Parameters
    ----------
    svc : sklearn classifier
        Estimator placed after the module-level ``scaler`` in the pipeline.
    scorer : callable
        Forwarded to ``time_generalization`` as its ``scoring`` argument.

    Returns
    -------
    scores
        The ``'scores'`` entry of the decoding results.
    train_times
        The ``'train_times'`` entry multiplied by 1e3 (presumably seconds to
        milliseconds, matching the "(ms)" axis labels used when plotting).
    """
    pipeline = Pipeline(steps=[('scaler', scaler), ('svc', svc)])
    decoding = time_generalization(epochs_list, clf=pipeline,
                                   scoring=scorer, n_jobs=1)
    scores = decoding['scores']
    train_times_ms = 1e3 * decoding['train_times']
    return scores, train_times_ms
# Scores on decision_function
svc = SVC(C=1, kernel='linear')  # normal SVC
# needs_threshold=True makes the scorer hand decision_function() output to
# roc_auc_score. A plain make_scorer(roc_auc_score) would call predict()
# instead, so the AUC would be computed on hard class labels rather than on
# the signed distance to the hyperplane (assuming time_generalization invokes
# the scorer through scikit-learn's scorer(estimator, X, y) protocol).
scorer = make_scorer(roc_auc_score, needs_threshold=True)
scores_distance, times = decod(svc, scorer)

# Scores on probabilities
svc = SVC(C=1, kernel='linear', probability=True)  # SVC + Platt


def roc_auc_scorer(y_true, y_pred):
    """Return the ROC AUC computed on the positive-class probability.

    Depending on the scikit-learn version, a ``needs_proba`` scorer passes
    either the full ``(n_samples, n_classes)`` predict_proba matrix or, for
    binary problems, only the positive-class column. Both forms are accepted.
    """
    y_pred = np.asarray(y_pred)
    if y_pred.ndim == 2:
        # Column 1 holds the probability of the second (positive) class.
        y_pred = y_pred[:, 1]
    return roc_auc_score(y_true, y_pred)


scorer = make_scorer(roc_auc_scorer, needs_proba=True)
scores_proba, times = decod(svc, scorer)
# Visualize
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# With one row and two columns, `ax` is already a flat sequence of two Axes.
ax1, ax2 = ax


def plot_time_gen(ax, scores, title, times_ms=None):
    """Draw a temporal-generalization score matrix on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    scores : 2D array-like
        Score matrix; rows map to the y axis ("Times Train") and columns to
        the x axis ("Times Test").
    title : str
        Axes title.
    times_ms : sequence of float, optional
        Times (ms) whose first and last values set the extent of both axes.
        Defaults to the module-level ``times`` returned by ``decod``.
    """
    if times_ms is None:
        times_ms = times
    t_first, t_last = times_ms[0], times_ms[-1]
    # Scores are drawn on a fixed 0-1 color scale so both panels are comparable.
    im = ax.imshow(scores, interpolation='nearest', origin='lower',
                   extent=[t_first, t_last, t_first, t_last],
                   vmin=0., vmax=1.)
    ax.set_xlabel('Times Test (ms)')
    ax.set_ylabel('Times Train (ms)')
    ax.set_title(title)
    # Mark time zero on both axes.
    ax.axvline(0, color='k')
    ax.axhline(0, color='k')
    plt.colorbar(im, ax=ax)


plot_time_gen(ax1, scores_distance, 'Distance')
plot_time_gen(ax2, scores_proba, 'Probabilities')
mne.viz.tight_layout(fig=fig)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment