-
-
Save jasmainak/d459dd128ded1e93b54050c7fab08e79 to your computer and use it in GitHub Desktop.
ransac first PR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/bench_ransac.py b/bench_ransac.py | |
new file mode 100644 | |
index 0000000..8ee500b | |
--- /dev/null | |
+++ b/bench_ransac.py | |
@@ -0,0 +1,200 @@ | |
+import matplotlib.pyplot as plt | |
+import os.path as op | |
+import numpy as np | |
+ | |
+import mne | |
+from mne import io | |
+from mne.datasets import testing | |
+ | |
+from mne.channels.interpolation import _make_interpolation_matrix | |
+ | |
+mne.set_log_level('WARNING') | |
+ | |
+ | |
def _modified_median(data):
    """Return the PREP-style "median" along the last axis.

    Unlike ``np.median``, no averaging is done for an even number of
    samples: the values are sorted along the last axis and a single
    element is picked, namely the middle one when the length ``n`` is odd
    and the lower of the two central ones when ``n`` is even.

    Parameters
    ----------
    data : array-like, shape (..., n)
        Values to reduce. The reduction runs over the last axis.

    Returns
    -------
    med : ndarray, shape (...)
        The selected order statistic for every position of the leading axes.

    Notes
    -----
    The selected 0-based position is ``(n + 1) // 2 - 1``, i.e. ``n / 2``
    rounded half *up*, minus one. The previous ``round(n / 2) - 1`` relied
    on the built-in ``round``, which rounds half to even on Python 3 and
    therefore picked the wrong element for n = 5, 9, 13, ...
    (presumably the intent was to mirror MATLAB's ``round`` -- TODO confirm
    against the PREP sources).
    """
    data = np.sort(data, axis=-1)
    n = data.shape[-1]
    # Integer arithmetic keeps the index exact and independent of the
    # interpreter's rounding mode.
    idx = (n + 1) // 2 - 1
    # Ellipsis instead of [:, :, idx] so inputs need not be exactly 3-D.
    return data[..., idx]
+ | |
+ | |
def _randsample(X, num):
    """Draw ``num`` items from ``X`` without replacement (as done in PREP).

    Parameters
    ----------
    X : sequence
        Items to draw from. It is copied first, so the caller's object is
        left untouched and immutable sequences such as ``range`` work.
    num : int
        Number of items to draw.

    Returns
    -------
    Y : list
        The drawn items, in the order they were drawn.

    Raises
    ------
    ValueError
        If ``num`` is larger than the number of available items.

    Notes
    -----
    Exactly one ``np.random.random()`` call is consumed per drawn item and
    the index formula is unchanged, so results for a given seed match the
    earlier implementation.
    """
    # Work on a private copy: the earlier code deleted from X directly,
    # which emptied the caller's list and failed outright on a Python 3
    # ``range``.
    pool = list(X)
    if num > len(pool):
        raise ValueError('Cannot draw %d items from a pool of %d.'
                         % (num, len(pool)))
    Y = []
    for _ in range(num):
        # Maps a uniform number in [0, 1) to a 1-based position in
        # [1, len(pool)] and shifts it to a 0-based index.
        pick = int(round(1 + (len(pool) - 1) * np.random.random())) - 1
        Y.append(pool.pop(pick))
    return Y
+ | |
+ | |
class Ransac(object):
    """RANSAC-style detection of bad sensors (modelled on the PREP code).

    Each sensor is repeatedly predicted from random subsets of the sensors;
    a sensor whose signal correlates poorly with its robust prediction in
    too many time windows is reported as bad.

    Parameters
    ----------
    n_resample : int
        Number of times the sensors are resampled.
    min_channels : float
        Fraction of sensors for robust reconstruction.
    min_corr : float
        Cut-off correlation for abnormal wrt neighbours.
    unbroken_time : float
        Cut-off fraction of time sensor can have poor RANSAC
        predictability.
    window_size : float
        Correlation window for RANSAC (handed to ``raw.crop`` as a time
        span, so presumably seconds -- confirm against the MNE version in
        use).
    ch_type : str
        'meg' | 'eeg'
    random_state : int | None
        Seed handed to ``np.random.seed`` before the sensor subsets are
        drawn. The default reproduces the previously hard-coded seed.
        ``None`` leaves the global NumPy random state untouched.

    Attributes
    ----------
    ch_subsets_ : list of list of str
        Channel names of every random subset (set by ``fit``).
    mappings_ : ndarray, shape (n_resample * n_channels, n_channels)
        Stacked interpolation matrices (set by ``fit``).
    corr_ : ndarray
        Correlations, shape (n_channels,) after ``predict`` and
        (n_windows, n_channels) after ``fit_predict``.
    bad_log : ndarray, shape (n_windows, n_channels)
        1 where a channel fell below ``min_corr`` in a window, else 0
        (set by ``fit_predict``).
    bad_chs_ : list of str
        Names of the channels judged bad (set by ``fit_predict``).
    """

    def __init__(self, n_resample=50, min_channels=0.25, min_corr=0.75,
                 unbroken_time=0.4, window_size=5, ch_type='eeg',
                 random_state=435656):
        self.n_resample = n_resample
        self.min_channels = min_channels
        self.min_corr = min_corr
        self.unbroken_time = unbroken_time
        self.window_size = window_size
        self.ch_type = ch_type
        self.random_state = random_state

    def _get_random_subsets(self, info):
        """Draw ``n_resample`` random channel subsets.

        Returns a list with one entry per resample; each entry is a list
        of channel names in the order they were drawn.
        """
        # NOTE: this reseeds NumPy's *global* generator, so every call
        # (i.e. every window in fit_predict) draws the same subsets.
        if self.random_state is not None:
            np.random.seed(self.random_state)
        ch_names = info['ch_names']
        n_channels = len(ch_names)

        # number of channels to interpolate from
        n_samples = int(np.round(self.min_channels * n_channels))

        ch_subsets = []
        for _ in range(self.n_resample):
            # Hand over a fresh list each time: the sampler may consume
            # its input, and a bare ``range`` cannot be modified on
            # Python 3.
            picks = _randsample(list(range(n_channels)), n_samples)
            ch_subsets.append([ch_names[p] for p in picks])
        return ch_subsets

    def _get_mappings(self, inst):
        """Build one interpolation matrix per channel subset.

        For every subset a (n_channels, n_channels) matrix is created whose
        columns belonging to the subset hold the weights predicting *all*
        channels from that subset; the remaining columns stay zero. The
        matrices are stacked along the first axis.

        Raises
        ------
        ValueError
            If ``ch_type`` is neither 'meg' nor 'eeg'.
        """
        from progressbar import ProgressBar, SimpleProgress

        if self.ch_type == 'meg':
            # Project-local helper, only required on the MEG path.
            from utils import _fast_map_meg_channels
        elif self.ch_type != 'eeg':
            # Previously an unknown type silently produced all-zero
            # mappings (and hence NaN correlations).
            raise ValueError("ch_type must be 'meg' or 'eeg', got %r"
                             % (self.ch_type,))

        ch_subsets = self.ch_subsets_
        pbar = ProgressBar(widgets=[SimpleProgress()])
        pos = np.array([ch['loc'][:3] for ch in inst.info['chs']])
        ch_names = inst.info['ch_names']
        n_channels = len(ch_names)
        pick_to = np.arange(n_channels)
        mappings = []
        print('Trying channel subset: ')
        for idx in pbar(range(len(ch_subsets))):
            # Resolve names by hand rather than with pick_channels, which
            # would sort the channels and lose the drawing order.
            pick_from = np.array([ch_names.index(name)
                                  for name in ch_subsets[idx]])
            mapping = np.zeros((n_channels, n_channels))
            if self.ch_type == 'meg':
                mapping[:, pick_from] = _fast_map_meg_channels(
                    inst, pick_from, pick_to)
            else:
                mapping[:, pick_from] = _make_interpolation_matrix(
                    pos[pick_from], pos[pick_to], alpha=1e-5)
            mappings.append(mapping)
        return np.concatenate(mappings)

    def _compute_correlations(self, inst):
        """Correlate every channel with its robust RANSAC prediction.

        Returns
        -------
        corr : ndarray, shape (n_channels,)
            Uncentred correlation between the data and the prediction.
        """
        mappings = self.mappings_
        data = inst._data
        n_channels, n_times = data.shape

        # start pooling all the predictions (can be done more efficiently)
        print('Pooling predictions')
        # The product has one column per (resample, channel) pair with the
        # channel index varying fastest, hence the Fortran-order reshape
        # into (n_times, n_channels, n_resample).
        y_pred = data.T.dot(mappings.T).reshape(
            (n_times, n_channels, self.n_resample), order='F')
        print('[Done]')

        print('Robustifying the predictions')
        y_pred = _modified_median(y_pred)  # collapse the resample axis
        print('[Done]')

        print('Computing correlations')
        # No mean is subtracted, so this is *not* what np.corrcoef would
        # return.
        # NOTE: an all-zero channel or prediction gives 0 / 0 = NaN, and
        # NaN never compares below min_corr, so it is never flagged.
        num = np.sum(data.T * y_pred, axis=0)
        denom = (np.sqrt(np.sum(data.T ** 2, axis=0)) *
                 np.sqrt(np.sum(y_pred ** 2, axis=0)))
        corr = num / denom
        print('[Done]')
        return corr

    def _flag_bad_channels(self, corrs, ch_names):
        """Turn per-window correlations into the list of bad channels.

        Stores ``corr_``, ``bad_log`` and ``bad_chs_`` on the instance.

        Parameters
        ----------
        corrs : ndarray, shape (n_windows, n_channels)
            Correlation of every channel in every window.
        ch_names : list of str
            Channel names matching the columns of ``corrs``.

        Returns
        -------
        bad_chs : list of str
            Channels that were poorly predicted in more than
            ``unbroken_time`` of the windows.
        """
        self.corr_ = corrs
        n_windows = corrs.shape[0]

        # compute how many windows is a sensor RANSAC-bad
        self.bad_log = np.zeros_like(self.corr_)
        self.bad_log[self.corr_ < self.min_corr] = 1
        n_bad_windows = self.bad_log.sum(axis=0)

        # unbroken_time is the tolerated fraction of poorly predicted
        # windows; the earlier ``(1 - unbroken_time)`` inverted that
        # meaning (0.4 behaved like 0.6).
        bad_idx = np.where(n_bad_windows > self.unbroken_time * n_windows)[0]
        self.bad_chs_ = [ch_names[p] for p in bad_idx]
        return self.bad_chs_

    def fit(self, raw):
        """Draw the channel subsets and compute their interpolation maps.

        Returns ``self`` so that calls can be chained.
        """
        self.ch_subsets_ = self._get_random_subsets(raw.info)
        self.mappings_ = self._get_mappings(raw)
        return self

    def predict(self, raw):
        """Compute ``corr_`` for ``raw`` using the fitted mappings.

        Returns the correlations that are also stored in ``corr_``.
        """
        self.corr_ = self._compute_correlations(raw)
        return self.corr_

    def fit_predict(self, raw):
        """Run RANSAC window by window and collect the bad channels.

        The recording is cut into consecutive windows of ``window_size``;
        a trailing partial window is ignored. Results are stored in
        ``corr_``, ``bad_log`` and ``bad_chs_``. ``raw`` itself is not
        modified.
        """
        window_size = self.window_size
        n_windows = int(raw.times[-1] // window_size)
        n_channels = len(raw.ch_names)
        sfreq = raw.info['sfreq']
        corrs = np.zeros((n_windows, n_channels))
        for idx in range(n_windows):
            tmin = idx * window_size
            # stop one sample short of the next window's first sample
            tmax = (idx + 1) * window_size - 1. / sfreq
            # Crop a copy: on MNE versions where crop() works in place,
            # cropping ``raw`` itself would destroy the caller's data and
            # break every window after the first.
            raw_crop = raw.copy().crop(tmin, tmax)
            print('RANSAC on window %d/%d' % (idx + 1, n_windows))
            self.fit(raw_crop).predict(raw_crop)
            corrs[idx, :] = self.corr_
        self._flag_bad_channels(corrs, raw.info['ch_names'])
+ | |
# Benchmark driver: run RANSAC bad-channel detection on the EEGLAB recording
# from the MNE testing dataset and visualise the outcome.

# Locate the recording and its channel-location file (download=False, so the
# testing data is expected to be on disk already).
base_dir = op.join(testing.data_path(download=False), 'EEGLAB')
raw_fname, montage = [op.join(base_dir, fname)
                      for fname in ('test_raw.set', 'test_chans.locs')]
raw = io.read_raw_eeglab(
    input_fname=raw_fname,
    montage=montage,
    preload=True,
    eog=[],
)

# Start without any pre-marked bad channels, then keep the EEG channels only
# (MEG, stim and EOG are switched off; nothing is excluded).
raw.info['bads'] = []
raw.pick_types(
    meg=False,
    eeg=True,
    stim=False,
    eog=False,
    include=[],
    exclude=[],
)

# Low-pass settings copied from the PREP authors; the results are close to,
# but not exactly, theirs.
raw.filter(
    l_freq=None,
    h_freq=45,
    h_trans_bandwidth=5,
    method='fft',
)

# Detect the bad channels with the default RANSAC parameters.
ransac = Ransac()
ransac.fit_predict(raw)

# Show the recording with the detected channels marked as bad, plus the
# channel-by-window map of RANSAC failures.
raw.info['bads'] = ransac.bad_chs_
raw.plot()
plt.matshow(ransac.bad_log.T)
plt.show()
diff --git a/utils.py b/utils.py | |
index 31466db..8380dd1 100644 | |
--- a/utils.py | |
+++ b/utils.py | |
@@ -35,8 +35,7 @@ def load_data(dataset): | |
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif' | |
raw = io.Raw(raw_fname, preload=True) | |
- projs, _ = mne.preprocessing.compute_proj_ecg(raw, n_eeg=1, average=True, | |
- verbose=True) | |
+ projs, _ = mne.preprocessing.compute_proj_ecg(raw, n_eeg=1, average=True) | |
raw.add_proj(projs).apply_proj() | |
event_fname = data_path + ('/MEG/sample/sample_audvis_filt-0-40_raw-' | |
@@ -191,8 +190,7 @@ def _fast_map_meg_channels(inst, pick_from, pick_to, mode='fast'): | |
from mne.io.pick import pick_info | |
from mne.forward._field_interpolation import _setup_dots | |
from mne.forward._field_interpolation import _compute_mapping_matrix | |
- from mne.io.constants import FIFF | |
- from mne.forward._make_forward import _create_coils | |
+ from mne.forward._make_forward import _create_meg_coils, _read_coil_defs | |
from mne.forward._lead_dots import _do_self_dots, _do_cross_dots | |
miss = 1e-4 # Smoothing criterion for MEG | |
@@ -200,8 +198,9 @@ def _fast_map_meg_channels(inst, pick_from, pick_to, mode='fast'): | |
def _compute_dots(info, mode='fast'): | |
"""Compute all-to-all dots. | |
""" | |
- coils = _create_coils(info['chs'], FIFF.FWD_COIL_ACCURACY_NORMAL, | |
- info['dev_head_t'], 'meg') | |
+ templates = _read_coil_defs() | |
+ coils = _create_meg_coils(info['chs'], 'normal', info['dev_head_t'], | |
+ templates) | |
my_origin, int_rad, noise, lut_fun, n_fact = _setup_dots(mode, coils, | |
'meg') | |
self_dots = _do_self_dots(int_rad, False, coils, my_origin, 'meg', | |
@@ -215,8 +214,9 @@ def _compute_dots(info, mode='fast'): | |
info['bads'] = [] # if bads is different, hash will be different | |
info_from = pick_info(info, pick_from, copy=True) | |
- coils_from = _create_coils(info_from['chs'], FIFF.FWD_COIL_ACCURACY_NORMAL, | |
- info_from['dev_head_t'], 'meg') | |
+ templates = _read_coil_defs() | |
+ coils_from = _create_meg_coils(info_from['chs'], 'normal', | |
+ info_from['dev_head_t'], templates) | |
my_origin, int_rad, noise, lut_fun, n_fact = _setup_dots(mode, coils_from, | |
'meg') | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment