Created
February 15, 2016 12:56
-
-
Save mbillingr/688daaa842ae98860548 to your computer and use it in GitHub Desktop.
Compare ICA methods
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.decomposition import FastICA, PCA | |
from sklearn.cluster import AffinityPropagation | |
import matplotlib.pyplot as plt | |
from scot.eegtopo.topoplot import Topoplot | |
from scot.external.infomax_ import infomax | |
from scot.csp import csp | |
from scot.datatools import cut_segments, cat_trials | |
import scotdata.motorimagery as midata | |
from time import time | |
def _check_binica_binary(binary):
    """Make sure *binary* exists, trying to download and unpack BINICA if it does not.

    The archive is fetched from the SCCN server into the binary's directory and
    extracted one directory level above it; the binary is then marked executable
    for its owner.

    Raises
    ------
    RuntimeError
        If the download fails or *binary* is still missing after extraction.
    """
    import os
    import stat
    import zipfile
    from urllib.request import urlretrieve

    if os.path.exists(binary):
        print(binary, 'found')
        return

    url = 'http://sccn.ucsd.edu/eeglab/binica/binica.zip'
    print(binary + ' not found. Trying to download from ' + url)

    # A bare file name has an empty dirname; use the current directory instead
    # of calling os.makedirs('') (which raises).
    path = os.path.dirname(binary) or os.curdir
    os.makedirs(path, exist_ok=True)

    archive = os.path.join(path, 'binica.zip')
    urlretrieve(url, archive)
    if not os.path.exists(archive):
        raise RuntimeError('Error downloading binica.zip.')

    print('unzipping', archive)
    with zipfile.ZipFile(archive) as zf:
        # Extracted one level up; presumably the archive contains a top-level
        # folder matching the binary's directory (e.g. binica/) -- TODO confirm.
        zf.extractall(os.path.join(path, os.pardir))
    if not os.path.exists(binary):
        raise RuntimeError(binary + ' not found, even after extracting binica.zip.')

    mode = os.stat(binary).st_mode
    os.chmod(binary, mode | stat.S_IXUSR)


def _run_binica(binary, scriptfile):
    """Run *binary* with *scriptfile* as stdin; print and return its merged stdout/stderr.

    Raises
    ------
    RuntimeError
        If the binary exists but cannot be started.
    """
    import subprocess

    with open(scriptfile) as sc:
        try:
            proc = subprocess.Popen(binary, stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT, stdin=sc)
        except FileNotFoundError as err:
            raise RuntimeError(
                'The BINICA binary %s exists in the file system but could not be executed. '
                'This indicates that 32 bit libraries are not installed on the system.' % binary
            ) from err
        print('waiting for binica to finish...')
        # communicate() keeps draining the pipe while waiting; wait() followed
        # by read() deadlocks once BINICA prints more than the pipe buffer holds.
        output, _ = proc.communicate()
    output = output.decode(errors='replace')
    print('binica output:')
    print(output)
    return output


def _read_square_matrix(filename, n):
    """Read an (n, n) float32 matrix stored as raw row-major bytes in *filename*."""
    return np.fromfile(filename, dtype=np.float32).reshape(n, n)


def binica(data, binary='./ica.exe'):
    """Simple wrapper for the BINICA binary.

    This function calculates the ICA transformation using the (extended)
    Infomax algorithm implemented in BINICA. BINICA is bundled with EEGLAB,
    or can be downloaded from here: http://sccn.ucsd.edu/eeglab/binica/

    If *binary* does not exist, this function attempts to download and
    extract the BINICA archive next to it.

    BINICA's script, data, weights and sphere files are written to the current
    working directory under unique ``binica-<uuid>.*`` names. They are removed
    again before returning, also when BINICA fails.

    Parameters
    ----------
    data : array-like, shape = [n_samples, n_channels]
        EEG data set; converted to float32 before being written out.
    binary : str
        Path to the BINICA executable. Defaults to ``'./ica.exe'`` (relative to
        the current working directory); on Linux use e.g. ``'binica/ica_linux'``.

    Returns
    -------
    w : array, shape = [n_channels, n_channels]
        ICA weights matrix
    s : array, shape = [n_channels, n_channels]
        Sphering matrix

    Raises
    ------
    RuntimeError
        If the binary cannot be obtained or executed, or if it does not
        produce its weights/sphere output files.

    Notes
    -----
    The unmixing matrix is obtained by multiplying U = dot(s, w)
    """
    import os
    import sys
    from uuid import uuid4

    _check_binica_binary(binary)

    data = np.array(data, dtype=np.float32)
    nframes, nchans = data.shape

    uid = uuid4()
    scriptfile = 'binica-%s.sc' % uid
    datafile = 'binica-%s.fdt' % uid
    weightsfile = 'binica-%s.wts' % uid
    spherefile = 'binica-%s.sph' % uid

    config = {'DataFile': datafile,
              'WeightsOutFile': weightsfile,
              'SphereFile': spherefile,
              'chans': nchans,
              'frames': nframes,
              'extended': 1}

    try:
        with open(datafile, 'wb') as f:
            data.tofile(f)
        # BINICA reads its configuration as "key value" lines from stdin.
        with open(scriptfile, 'wt') as f:
            for key, value in config.items():
                print(key, value, file=f)

        # flush output streams otherwise things printed before might appear after the ICA output.
        sys.stdout.flush()
        sys.stderr.flush()

        output = _run_binica(binary, scriptfile)

        if not (os.path.exists(weightsfile) and os.path.exists(spherefile)):
            raise RuntimeError('BINICA did not produce its output files. Output was:\n' + output)
        weights = _read_square_matrix(weightsfile, nchans)
        sphere = _read_square_matrix(spherefile, nchans)
    finally:
        # Always clean up, so failed runs do not leave stray files behind.
        for filename in (scriptfile, datafile, weightsfile, spherefile):
            if os.path.exists(filename):
                os.remove(filename)

    return weights, sphere
def fastica(x):
    """Decompose *x* with scikit-learn's FastICA.

    *x* is passed straight to ``FastICA.fit`` (rows are samples, columns are
    features). The unmixing matrix is initialised with the identity, so repeated
    runs on the same data start from the same point; at most 500 iterations run.

    Returns
    -------
    mixing : array
        ``FastICA.mixing_`` of the fitted model.
    unmixing : array
        ``FastICA.components_`` of the fitted model.
    """
    n_features = x.shape[1]
    model = FastICA(max_iter=500, w_init=np.eye(n_features))
    model.fit(x)
    return model.mixing_, model.components_
def mne_infomax(x):
    """Infomax ICA using the ``infomax`` implementation bundled with SCoT.

    Returns
    -------
    mixing : array
        Moore-Penrose pseudo-inverse of the unmixing matrix.
    unmixing : array
        Matrix returned by ``infomax(x)``.
    """
    unmixing = infomax(x)
    mixing = np.linalg.pinv(unmixing)
    return mixing, unmixing
def binica_infomax(x):
    """Extended Infomax ICA via the BINICA binary, in the (mixing, unmixing) form of do_ica.

    ``do_ica`` unpacks ``w_mix, w_unmix = algorithm(y)``. This function therefore
    returns a pair, like ``fastica`` and ``mne_infomax``, instead of a single matrix.

    Returns
    -------
    mix : array, shape (n_channels, n_channels)
        Inverse of ``unmix``.
    unmix : array, shape (n_channels, n_channels)
        Unmixing matrix with one component per row, so that sources are
        ``dot(x, unmix.T)``.
    """
    w, s = binica(x)
    # binica() documents U = dot(s, w). With samples along the rows of x, the
    # row-major read of BINICA's weights file makes U map x -> sources as
    # dot(x, U) (one component per column). Transpose it to the
    # one-component-per-row layout do_ica expects. NOTE(review): derived from
    # EEGLAB's column-major file layout -- confirm against a reference decomposition.
    unmix = s.dot(w).T
    return np.linalg.inv(unmix), unmix
def do_ica(algorithm, x, preunmix, sort_by_variance=False):
    """Run *algorithm* on pre-transformed data and map the result back to channel space.

    Parameters
    ----------
    algorithm : callable
        Called as ``algorithm(y)`` on the pre-transformed data. It must return a
        ``(mixing, unmixing)`` pair expressed in that pre-transformed space.
    x : array-like
        Segmented data; concatenated over trials with ``cat_trials`` before use.
    preunmix : array
        Pre-transformation applied before ICA: ``y = dot(x, preunmix.T)``.
    sort_by_variance : bool, optional
        If True, reorder components by decreasing normalized source variance.
        Defaults to False, which keeps the order produced by *algorithm*.

    Returns
    -------
    w_mix : array
        Mixing matrix in channel space; column i belongs to component i.
    w_unmix : array
        Unmixing matrix in channel space; row i belongs to component i.
    """
    x = cat_trials(x)
    y = np.dot(x, preunmix.T)
    premix = np.linalg.pinv(preunmix)

    w_mix, w_unmix = algorithm(y)

    # Compose with the pre-transformation so both matrices act on raw channels.
    w_mix = np.dot(premix, w_mix)
    w_unmix = np.dot(w_unmix, preunmix)

    if sort_by_variance:
        # Source variance divided by the squared norm of its unmixing filter,
        # so rescaling a filter does not change its rank.
        variance = np.var(np.dot(x, w_unmix.T), axis=0) / np.sum(w_unmix ** 2, axis=1)
        order = np.argsort(variance)[::-1]
        w_mix = w_mix[:, order]
        w_unmix = w_unmix[order, :]

    return w_mix, w_unmix
def plot_ica(topo, w_mix, order=None, labels=None, grid=(5, 9)):
    """Draw the scalp map of every component of *w_mix* into one subplot each.

    Parameters
    ----------
    topo : Topoplot
        Object whose ``set_values``, ``plot_circles`` and ``plot_head`` are used
        to draw each map.
    w_mix : array, shape (n_channels, n_components)
        Mixing matrix; column i is drawn as the map of component i.
    order : array of int, optional
        Permutation applied to the components (and to *labels*, if given)
        before plotting.
    labels : sequence, optional
        One subplot title per component.
    grid : tuple of int, optional
        ``(rows, columns)`` of the subplot grid. Defaults to ``(5, 9)``, which
        holds at most 45 components.

    Raises
    ------
    ValueError
        If there are more components than grid cells.
    """
    w_mix = np.asarray(w_mix)
    if order is not None:
        w_mix = w_mix[:, order]
        # labels are optional; only permute them when they were supplied.
        if labels is not None:
            labels = np.asarray(labels)[order]

    n_rows, n_cols = grid
    n_components = w_mix.shape[1]
    if n_components > n_rows * n_cols:
        raise ValueError('%d components do not fit into a %dx%d subplot grid'
                         % (n_components, n_rows, n_cols))

    for i, component in enumerate(w_mix.T):
        plt.subplot(n_rows, n_cols, i + 1, aspect='equal')
        topo.set_values(component)
        topo.plot_circles(0.2)
        topo.plot_head()
        plt.xlim(-3, 3)
        plt.ylim(-3, 3)
        plt.xticks([])
        plt.yticks([])
        if labels is not None:
            plt.title(labels[i])
def _cluster_similarity(r, labels):
    """Return ``{label: (n_members, mean within-cluster r^2)}`` for each cluster label.

    The mean runs over distinct member pairs (strict lower triangle of *r*).
    A single-member cluster has no pairs and gets NaN.
    """
    stats = {}
    for label in np.unique(labels):
        members = labels == label
        n_members = int(np.sum(members))
        pairs = r[np.ix_(members, members)][np.tri(n_members, k=-1, dtype=bool)]
        mean_r2 = np.mean(pairs ** 2) if pairs.size else np.nan
        stats[label] = (n_members, mean_r2)
    return stats


def cluster_icas(icas, x, topo, titles=None):
    """Cluster the sources of several ICA decompositions by correlation and plot them.

    Sources of all decompositions are extracted from *x*. Their squared pairwise
    correlations serve as the precomputed affinity for AffinityPropagation.
    Cluster statistics are printed, and several figures are created:
    a PCA scatter of the component features, the correlation matrix, the
    per-cluster similarity, and one topography figure per decomposition.

    Parameters
    ----------
    icas : iterable of (mixing, unmixing) pairs
        As returned by ``do_ica``. All decompositions are assumed to have the
        same number of components -- the count is taken from the last one.
    x : array
        Data the sources are computed from as ``dot(x, unmixing.T)``.
    topo : Topoplot
        Passed on to ``plot_ica``.
    titles : sequence of str, optional
        One figure title per decomposition.

    Raises
    ------
    ValueError
        If *icas* is empty.
    """
    icas = list(icas)
    if not icas:
        raise ValueError('cluster_icas needs at least one ICA decomposition')

    n_icas = len(icas)
    n_components = icas[-1][1].shape[0]
    projections = np.concatenate([mix for mix, _ in icas], axis=1).T
    weights = np.concatenate([unmix for _, unmix in icas], axis=0)
    features = np.abs(np.concatenate([projections, weights], axis=1))

    r = np.corrcoef(np.dot(x, weights.T).T)
    labels = AffinityPropagation(affinity='precomputed').fit(r ** 2).labels_

    stats = _cluster_similarity(r, labels)
    cluster_ids = list(stats)
    ar2 = np.array([stats[k][1] for k in cluster_ids])
    ar2_by_label = dict(zip(cluster_ids, ar2))

    for idx in np.argsort(ar2):
        k = cluster_ids[idx]
        print('cluster ', k, ': ', stats[k][0], 'sources, average r^2:', ar2[idx])

    wt = PCA(n_components=2).fit(features).transform(features)
    plt.figure()
    plt.scatter(wt[:, 0], wt[:, 1], c=labels)
    plt.colorbar()
    plt.title('stupid scatterplot')

    plt.figure()
    plt.imshow(r, interpolation='none')
    plt.title('source correlations (all ICAs)')

    plt.figure()
    plt.stem(np.sort(ar2))
    plt.xticks([])
    plt.xlabel('cluster')
    plt.ylabel('average $r^2$')
    plt.title('cluster similarity')

    for i in range(n_icas):
        a, b = i * n_components, (i + 1) * n_components
        ica_labels = labels[a:b]
        # Components of this decomposition ordered by the similarity of their cluster.
        ica_order = np.argsort([ar2_by_label[k] for k in ica_labels])
        plt.figure()
        plot_ica(topo, projections[a:b].T, order=ica_order, labels=ica_labels)
        if titles is not None:
            plt.suptitle(titles[i])
def pre_raw(x, y):
    """Identity pre-transform: ICA runs directly on the channels.

    *y* is accepted only so all pre-transforms share the ``(x, y)`` signature.
    """
    n_channels = x.shape[1]
    return np.identity(n_channels)
def pre_pca(x, y):
    """PCA pre-transform: returns the principal axes as rows (``PCA.components_``).

    The trials in *x* are concatenated before fitting. *y* is ignored and
    accepted only for signature compatibility with the other pre-transforms.
    """
    concatenated = cat_trials(x)
    pca = PCA(n_components=None)
    pca.fit(concatenated)
    return pca.components_
def pre_csp(x, y):
    """CSP pre-transform computed from segmented data *x* and class labels *y*.

    ``csp`` returns two values; only the first one (the weights) is used.
    """
    weights, _ = csp(x, y)
    return weights
def main():
    """Compare ICA variants on the motor-imagery data set and plot the clustered sources.

    Each combination of pre-transform (pre_raw, pre_pca, pre_csp) and ICA
    algorithm (mne_infomax, fastica) is run once and timed. The sources of all
    decompositions are then clustered and plotted by cluster_icas.
    """
    raweeg = midata.eeg
    labels = midata.classes
    fs = midata.samplerate

    topo = Topoplot()
    topo.set_locations(midata.locations)

    # Presumably the window from 3*fs to 4*fs samples after each trigger -- confirm
    # against cut_segments.
    x = cut_segments(raweeg, midata.triggers, 3 * fs, 4 * fs)

    np.random.seed(42)
    icas = []
    titles = []
    for pre_trans in [pre_raw, pre_pca, pre_csp]:
        # The pre-transform depends only on the data, not on the ICA algorithm,
        # so compute it once per pre_trans. NOTE(review): assumes pre_trans draws
        # nothing from the global RNG, otherwise the seeded ICA runs would shift.
        pretransform = pre_trans(x, labels)
        for ica in [mne_infomax, fastica]:
            titles.append(pre_trans.__name__ + ', ' + ica.__name__)
            print(titles[-1], end=': ')
            t0 = time()
            icas.append(do_ica(ica, x, pretransform))
            print(np.round(time() - t0, 1), 'seconds taken')

    cluster_icas(icas, cat_trials(x), topo, titles)
    plt.show()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment