Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Created February 15, 2016 12:56
Show Gist options
  • Save mbillingr/688daaa842ae98860548 to your computer and use it in GitHub Desktop.
Save mbillingr/688daaa842ae98860548 to your computer and use it in GitHub Desktop.
Compare ICA methods
import numpy as np
from sklearn.decomposition import FastICA, PCA
from sklearn.cluster import AffinityPropagation
import matplotlib.pyplot as plt
from scot.eegtopo.topoplot import Topoplot
from scot.external.infomax_ import infomax
from scot.csp import csp
from scot.datatools import cut_segments, cat_trials
import scotdata.motorimagery as midata
from time import time
def binica(data, binary='./ica.exe'):
    """ Simple wrapper for the BINICA binary.

    This function calculates the ICA transformation using the Infomax
    algorithm implemented in BINICA (run with ``extended 1``).

    The data and a small configuration script are written to uniquely named
    temporary files in the current working directory, the binary is started
    with the script on its standard input, and the weight and sphering
    matrices it writes are read back. All temporary files are removed
    afterwards, also when the run fails.

    BINICA is bundled with EEGLAB, or can be downloaded from here:
        http://sccn.ucsd.edu/eeglab/binica/

    If `binary` does not exist, this function attempts to download
    binica.zip into the directory of `binary` and to extract it into that
    directory's parent.

    Parameters
    ----------
    data : array-like, shape = [n_samples, n_channels]
        EEG data set
    binary : str
        Path to the binica binary (default: "./ica.exe")

    Returns
    -------
    w : array, shape = [n_channels, n_channels]
        ICA weights matrix
    s : array, shape = [n_channels, n_channels]
        Sphering matrix

    Raises
    ------
    RuntimeError
        If the binary cannot be downloaded/extracted or cannot be executed.

    Notes
    -----
    The unmixing matrix is obtained by multiplying U = dot(s, w)
    """
    import os
    import sys
    from uuid import uuid4
    import subprocess

    def check_binary_(binary):
        """check if binary is available, and try to download it if not"""
        if os.path.exists(binary):
            print(binary, 'found')
            return

        url = 'http://sccn.ucsd.edu/eeglab/binica/binica.zip'
        print(binary+' not found. Trying to download from '+url)

        # dirname() is '' for a bare file name; fall back to the current
        # directory so makedirs() and the archive path below stay valid.
        path = os.path.dirname(binary) or '.'
        if not os.path.exists(path):
            os.makedirs(path)

        try:
            # Python 3
            from urllib.request import urlretrieve as urlretrieve
        except ImportError:
            # Python 2.7
            from urllib import urlretrieve as urlretrieve
        import zipfile
        import stat

        archive = path + '/binica.zip'
        urlretrieve(url, archive)
        if not os.path.exists(archive):
            raise RuntimeError('Error downloading binica.zip.')

        print('unzipping', archive)
        with zipfile.ZipFile(archive) as tgz:
            tgz.extractall(path + '/..')
        if not os.path.exists(binary):
            raise RuntimeError(binary + ' not found, even after extracting binica.zip.')

        # make sure the extracted binary is executable by the current user
        mode = os.stat(binary).st_mode
        os.chmod(binary, mode | stat.S_IXUSR)

    check_binary_(binary)

    data = np.array(data, dtype=np.float32)
    nframes, nchans = data.shape

    # a unique id keeps concurrent runs in the same directory apart
    uid = uuid4()
    scriptfile = 'binica-%s.sc' % uid
    datafile = 'binica-%s.fdt' % uid
    weightsfile = 'binica-%s.wts' % uid
    spherefile = 'binica-%s.sph' % uid

    config = {'DataFile': datafile,
              'WeightsOutFile': weightsfile,
              'SphereFile': spherefile,
              'chans': nchans,
              'frames': nframes,
              'extended': 1}

    try:
        # create data file
        with open(datafile, 'wb') as f:
            data.tofile(f)

        # create script file
        with open(scriptfile, 'wt') as f:
            for key, value in config.items():
                print(key, value, file=f)

        # flush output streams otherwise things printed before might appear
        # after the ICA output.
        sys.stdout.flush()
        sys.stderr.flush()

        if not os.path.exists(binary):
            raise RuntimeError('the binary is not there!?')

        with open(scriptfile) as sc:
            try:
                proc = subprocess.Popen(binary, stdout=subprocess.PIPE,
                                        stderr=subprocess.STDOUT, stdin=sc)
            except FileNotFoundError:
                raise RuntimeError('The BINICA binary ica_linux exists in the file system but could not be executed. '
                                   'This indicates that 32 bit libraries are not installed on the system.')
            print('waiting for binica to finish...')
            # communicate() drains the pipe while waiting; calling wait()
            # first can deadlock once the child fills the pipe buffer.
            output, _ = proc.communicate()
            print('binica output:')
            print(output.decode())

        # read weights
        with open(weightsfile, 'rb') as f:
            weights = np.fromfile(f, dtype=np.float32)
        weights = np.reshape(weights, (nchans, nchans))

        # read sphering matrix
        with open(spherefile, 'rb') as f:
            sphere = np.fromfile(f, dtype=np.float32)
        sphere = np.reshape(sphere, (nchans, nchans))
    finally:
        # remove whatever temporary files exist, on success and on failure
        for tmpfile in (scriptfile, datafile, weightsfile, spherefile):
            if os.path.exists(tmpfile):
                os.remove(tmpfile)

    return weights, sphere
def fastica(x):
    """Decompose `x` with scikit-learn's FastICA.

    The estimator runs for at most 500 iterations and starts from an
    identity matrix whose size is the number of columns of `x`.

    Returns
    -------
    mixing, unmixing
        The fitted estimator's ``mixing_`` and ``components_`` attributes.
    """
    n_columns = x.shape[1]
    estimator = FastICA(max_iter=500, w_init=np.eye(n_columns))
    estimator.fit(x)
    return estimator.mixing_, estimator.components_
def mne_infomax(x):
    """Decompose `x` with the bundled ``infomax`` implementation.

    Returns
    -------
    mixing, unmixing
        The matrix returned by ``infomax(x)`` as unmixing matrix, preceded
        by its pseudo-inverse as mixing matrix.
    """
    unmixing = infomax(x)
    mixing = np.linalg.pinv(unmixing)
    return mixing, unmixing
def binica_infomax(x):
    """Decompose `x` with the external BINICA binary.

    Returns a ``(mixing, unmixing)`` pair like the other ICA wrappers in
    this file, so the result can be unpacked by ``do_ica``. (Previously a
    single matrix was returned, which ``do_ica`` cannot unpack.)

    Returns
    -------
    mixing, unmixing
        ``inv(U).T`` and ``U.T`` with ``U = dot(s, w)`` built from the
        sphering and weight matrices returned by ``binica``.
    """
    w, s = binica(x)
    u = s.dot(w)
    # NOTE(review): assumes U maps data to sources as ``dot(data, U)`` with
    # data shaped [n_samples, n_channels]; do_ica projects with
    # ``dot(x, w_unmix.T)``, hence the transposes. Confirm against BINICA's
    # output layout before relying on the orientation.
    return np.linalg.inv(u).T, u.T
def do_ica(algorithm, x, preunmix):
    """Run an ICA algorithm on pre-transformed data and undo the pre-transform.

    Parameters
    ----------
    algorithm : callable
        Called with the pre-transformed data; must return a
        ``(mixing, unmixing)`` pair.
    x : array
        Data that is first passed through ``cat_trials``.
    preunmix : array
        Pre-transformation applied to the data as ``dot(x, preunmix.T)``
        before the ICA is run.

    Returns
    -------
    w_mix, w_unmix
        Mixing and unmixing matrices expressed relative to the original
        (not pre-transformed) data: ``dot(pinv(preunmix), mixing)`` and
        ``dot(unmixing, preunmix)``.
    """
    x = cat_trials(x)
    y = np.dot(x, preunmix.T)
    premix = np.linalg.pinv(preunmix)

    w_mix, w_unmix = algorithm(y)

    # fold the pre-transformation into the ICA result
    w_mix = np.dot(premix, w_mix)
    w_unmix = np.dot(w_unmix, preunmix)
    return w_mix, w_unmix
def plot_ica(topo, w_mix, order=None, labels=None):
    """Draw one small topographic plot per column of `w_mix`.

    Parameters
    ----------
    topo : object
        Plotter providing ``set_values``, ``plot_circles`` and ``plot_head``.
    w_mix : array, shape = [n_channels, n_components]
        Each column is passed to ``topo.set_values`` and drawn in its own
        subplot of the current figure.
    order : sequence of int, optional
        If given, columns (and labels) are rearranged with this index.
    labels : sequence, optional
        If given, ``labels[i]`` becomes the title of the i-th subplot.
    """
    if order is not None:
        w_mix = w_mix[:, order]
        if labels is not None:
            # asarray allows plain lists to be reordered by an index array
            labels = np.asarray(labels)[order]

    n_maps = w_mix.shape[1]
    n_cols = 9
    # keep the 5x9 layout, but add rows when there are more than 45 maps
    n_rows = max(5, -(-n_maps // n_cols))

    for i in range(n_maps):
        plt.subplot(n_rows, n_cols, i + 1, aspect='equal')
        topo.set_values(w_mix[:, i])
        topo.plot_circles(0.2)
        topo.plot_head()
        plt.xlim(-3, 3)
        plt.ylim(-3, 3)
        plt.xticks([])
        plt.yticks([])
        if labels is not None:
            plt.title(labels[i])
def cluster_icas(icas, x, topo, titles=None):
    """Cluster the components of several ICA decompositions and plot them.

    Parameters
    ----------
    icas : iterable of (mixing, unmixing) pairs
        One pair per decomposition. The slicing below assumes every
        decomposition has as many components as the last one.
    x : array, shape = [n_samples, n_channels]
        Data used to compute the source signals ``dot(x, unmixing.T)``.
    topo : object
        Topographic plotter handed on to ``plot_ica``.
    titles : sequence of str, optional
        Figure title for each decomposition; no titles are drawn if None.
    """
    weights = []
    projections = []
    for m, u in icas:
        weights.append(u)
        projections.append(m)
    n_components = u.shape[0]
    n_icas = len(weights)

    # one row per component, pooled over all decompositions
    weights = np.concatenate(weights, axis=0)
    projections = np.concatenate(projections, axis=1).T
    features = np.concatenate([projections, weights], axis=1)

    # correlation between the source signals of all pooled components;
    # the squared correlation serves as precomputed affinity
    r = np.corrcoef(np.dot(x, weights.T).T)
    cluster = AffinityPropagation(affinity='precomputed').fit(r**2)

    ar2 = []  # average squared correlation within each cluster
    ns = []   # number of sources in each cluster
    for k in np.unique(cluster.labels_):
        members = cluster.labels_ == k
        size = np.sum(members)
        # strict lower triangle: every pair of distinct members once.
        # For a single-member cluster the selection is empty and the
        # mean is NaN.
        pairs = np.tri(size, k=-1, dtype=bool)
        ar2.append(np.mean(r[members, :][:, members][pairs]**2))
        ns.append(size)

    for k in np.argsort(ar2):
        print('cluster ', k, ': ', ns[k], 'sources, average r^2:', ar2[k])

    # 2-D PCA projection of the absolute component features
    features = np.abs(features)
    pca = PCA(n_components=2).fit(features)
    wt = pca.transform(features)

    plt.figure()
    plt.scatter(wt[:, 0], wt[:, 1], c=cluster.labels_)
    plt.colorbar()
    plt.title('stupid scatterplot')

    plt.figure()
    plt.imshow(r, interpolation='none')
    plt.title('source correlations (all ICAs)')

    plt.figure()
    plt.stem(np.sort(ar2))
    plt.xticks([])
    plt.xlabel('cluster')
    plt.ylabel('average $r^2$')
    plt.title('cluster similarity')

    # one figure per decomposition, components ordered by the average r^2
    # of the cluster they belong to and titled with their cluster label
    for i in range(n_icas):
        a = i * n_components
        b = (1 + i) * n_components
        l = cluster.labels_[a:b]
        o = np.argsort([ar2[k] for k in l])
        plt.figure()
        plot_ica(topo, projections[a:b].T, order=o, labels=l)
        if titles is not None:
            plt.suptitle(titles[i])
def pre_raw(x, y):
    """Return the identity pre-transform (no preprocessing).

    The matrix size is taken from the second dimension of `x`;
    `y` is accepted for a uniform signature and ignored.
    """
    n_channels = x.shape[1]
    return np.identity(n_channels)
def pre_pca(x, y):
    """Return the PCA components of the concatenated trials as pre-transform.

    All components are kept (``n_components=None``); `y` is accepted for a
    uniform signature and ignored.
    """
    stacked = cat_trials(x)
    model = PCA(n_components=None)
    model.fit(stacked)
    return model.components_
def pre_csp(x, y):
    """Return the first of the two results of ``csp(x, y)`` as pre-transform."""
    filters, _ = csp(x, y)
    return filters
def main():
    """Run every pre-transform/ICA combination on the motor imagery data.

    Each combination is timed, and all decompositions are finally clustered
    and plotted together.
    """
    raweeg = midata.eeg
    labels = midata.classes
    locs = midata.locations
    fs = midata.samplerate

    topo = Topoplot()
    topo.set_locations(locs)

    x = cut_segments(raweeg, midata.triggers, 3 * fs, 4 * fs)

    # fixed seed so that repeated runs give the same decompositions
    np.random.seed(42)

    pre_transforms = (pre_raw, pre_pca, pre_csp)
    algorithms = (mne_infomax, fastica)

    results = []
    titles = []
    for pre_trans in pre_transforms:
        for ica in algorithms:
            title = '%s, %s' % (pre_trans.__name__, ica.__name__)
            titles.append(title)
            print(title, end=': ')

            pretransform = pre_trans(x, labels)
            started = time()
            results.append(do_ica(ica, x, pretransform))
            elapsed = time() - started
            print(np.round(elapsed, 1), 'seconds taken')

    cluster_icas(results, cat_trials(x), topo, titles)
    plt.show()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment