Skip to content

Instantly share code, notes, and snippets.

@alexandrebarachant
Created July 1, 2015 13:31
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save alexandrebarachant/c929834f1a68c2728d35 to your computer and use it in GitHub Desktop.
Save alexandrebarachant/c929834f1a68c2728d35 to your computer and use it in GitHub Desktop.
CSP for Grasp and lift challenge
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 29 14:00:37 2015
@author: alexandrebarachant
"""
import numpy as np
import pandas as pd
from mne.io import RawArray
from mne.channels import read_montage
from mne.epochs import concatenate_epochs
from mne import create_info, find_events, Epochs
from mne.viz.topomap import _prepare_topo_plot, plot_topomap
from mne.decoding import CSP
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.cross_validation import cross_val_score, LeaveOneLabelOut
from glob import glob
import matplotlib.pyplot as plt
from scipy.signal import welch
def creat_mne_raw_object(fname):
    """Create an MNE raw instance from an EEG csv file and its events file.

    The events file name is derived from ``fname`` by replacing the last
    occurrence of ``'_data'`` with ``'_events'``. In both files the first
    column is skipped; the remaining data columns become EEG channels
    (values multiplied by 1e-6) and the remaining event columns become
    stim channels appended after the EEG channels.

    Parameters
    ----------
    fname : str
        Path to the ``*_data.csv`` file.

    Returns
    -------
    raw : RawArray
        Raw object with a 500 Hz sampling frequency, the 'standard_1005'
        montage, and ``info['filename']`` set to ``fname``.
    """
    # Read EEG file; the first column is not a channel
    eeg = pd.read_csv(fname)
    ch_names = list(eeg.columns[1:])

    # Read the standard EEG montage from MNE for these channel names
    montage = read_montage('standard_1005', ch_names)

    # Locate the events file. Only the last '_data' is substituted so that
    # a directory name containing '_data' is left untouched.
    ev_fname = '_events'.join(fname.rsplit('_data', 1))
    events = pd.read_csv(ev_fname)
    events_names = list(events.columns[1:])
    events_data = np.array(events[events_names]).T

    # Stack the scaled EEG rows on top of the event rows
    # (the 1e-6 factor presumably converts microvolts to volts -- confirm)
    data = np.concatenate((1e-6 * np.array(eeg[ch_names]).T, events_data))

    # One 'eeg' type per data column and one 'stim' type per event column;
    # derived from the file instead of assuming exactly 6 event columns.
    ch_type = ['eeg'] * len(ch_names) + ['stim'] * len(events_names)

    # Create and populate the MNE info structure
    info = create_info(ch_names + events_names, sfreq=500.0,
                       ch_types=ch_type, montage=montage)
    info['filename'] = fname

    # Create the raw object
    return RawArray(data, info)
# Build the epoch list for one subject: for every series file, one window
# before each 'Replace' event (label 1) and one window after it (label -1).
subject = 1
epochs_tot = []
session = []
y = []

# glob returns paths in arbitrary order; sort them so that the session
# indices and the sample order are reproducible from run to run.
fnames = sorted(glob('data/train/subj%d_series*_data.csv' % (subject)))

# Options shared by both Epochs constructions below
epochs_kwargs = dict(proj=False, picks=range(32), baseline=None,
                     preload=True, add_eeg_ref=False, verbose=False)

for i, fname in enumerate(fnames):
    # Read data
    raw = creat_mne_raw_object(fname)

    # Band-pass filter the first 32 channels between 5 and 35 Hz
    # (alpha and beta bands).
    # Note that MNE implements a zero-phase (filtfilt) filtering, which is
    # not compatible with the rule of not using future data.
    raw.filter(5, 35, picks=range(32), method='iir', n_jobs=-1)

    # Get event positions corresponding to 'Replace'
    events = find_events(raw, stim_channel='Replace')

    # Epoch the signal from 2 s to 0.5 s before the event (1.5 s window)
    epochs = Epochs(raw, events, {'during' : 1}, -2, -0.5, **epochs_kwargs)
    epochs_tot.append(epochs)
    session.extend([i] * len(epochs))
    y.extend([1] * len(epochs))

    # Epoch the signal from 0.5 s to 2 s after the event (1.5 s window);
    # this corresponds to the rest period.
    epochs_rest = Epochs(raw, events, {'after' : 1}, 0.5, 2, **epochs_kwargs)

    # Workaround to be able to concatenate epochs: give both sets the
    # same time axis.
    epochs_rest.times = epochs.times

    epochs_tot.append(epochs_rest)
    session.extend([i] * len(epochs_rest))
    y.extend([-1] * len(epochs_rest))
# Merge the per-series epochs into a single Epochs object
epochs = concatenate_epochs(epochs_tot)

# Labels and epoch data as arrays
y = np.array(y)
X = epochs.get_data()

# Fit CSP with 'lws' regularization on all epochs
csp = CSP(reg='lws')
csp.fit(X, y)

# Only the second value returned by the topo-plot helper is needed;
# it is passed to plot_topomap as the position argument.
_, epos, _, _, _ = _prepare_topo_plot(epochs, 'eeg', None)

# Topographic map of the first CSP pattern
first_pattern = csp.patterns_[0, :]
plot_topomap(first_pattern, epos)
# Compute the spectrum of every epoch after projection on the first CSP
# filter. np.dot of the 1-D filter with the 3-D epoch array contracts the
# channel axis for all epochs at once (assumes X is
# (n_epochs, n_channels, n_times), as the original per-epoch dot product
# implies); welch then runs along the last (time) axis.
f, po = welch(np.dot(csp.filters_[0, :], X), 500, nperseg=512)
po = np.array(po)

# Keep only the frequencies strictly between 5 and 35 Hz
fix = (f > 5) & (f < 35)

# Plot the class-averaged spectrum in decibels (10*log10), so that the
# plotted values actually match the 'Power (dB)' axis label.
plt.plot(f[fix], 10 * np.log10(po[y == 1][:, fix].mean(axis=0).T), '-r', lw=2)
plt.plot(f[fix], 10 * np.log10(po[y == -1][:, fix].mean(axis=0).T), '-b', lw=2)
plt.legend(['during', 'after'])
plt.grid()
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power (dB)')
plt.show()
# Cross-validate a CSP + logistic regression pipeline (10 folds, ROC AUC)
# NOTE(review): `session` is collected and LeaveOneLabelOut is imported,
# but neither is used here -- confirm whether leave-one-series-out was
# the intended scheme.
clf = make_pipeline(CSP(), LogisticRegression())
scores = cross_val_score(clf, X, y, cv=10, scoring='roc_auc')
auc = scores.mean()
print("AUC cross val score : %.3f" % auc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment