Skip to content

Instantly share code, notes, and snippets.

@jasmainak
Created December 4, 2020 03:35
Show Gist options
  • Save jasmainak/d459dd128ded1e93b54050c7fab08e79 to your computer and use it in GitHub Desktop.
Save jasmainak/d459dd128ded1e93b54050c7fab08e79 to your computer and use it in GitHub Desktop.
ransac first PR
diff --git a/bench_ransac.py b/bench_ransac.py
new file mode 100644
index 0000000..8ee500b
--- /dev/null
+++ b/bench_ransac.py
@@ -0,0 +1,200 @@
+import matplotlib.pyplot as plt
+import os.path as op
+import numpy as np
+
+import mne
+from mne import io
+from mne.datasets import testing
+
+from mne.channels.interpolation import _make_interpolation_matrix
+
+mne.set_log_level('WARNING')
+
+
+def _modified_median(data):
+    """Return the PREP-style "median" along the last axis.
+
+    Instead of averaging the two central values for an even number of
+    samples, the data are sorted and a single element is picked: the true
+    median for an odd size, the lower of the two central values for an
+    even size.
+
+    Parameters
+    ----------
+    data : ndarray
+        Input array; the median is taken along its last axis.
+
+    Returns
+    -------
+    ndarray
+        Array with the last axis removed.
+    """
+    data = np.sort(data, axis=-1)  # sorted copy; the input is left untouched
+    # Integer form of "round half up" of n / 2, shifted to 0-based indexing.
+    # The built-in round() rounds halves to the nearest even number in
+    # Python 3 (round(2.5) == 2), which selected the element *below* the
+    # median for sizes such as 5 or 9.
+    mid = (data.shape[-1] + 1) // 2 - 1
+    return data[..., mid]
+
+
+def _randsample(X, num):
+    """Draw ``num`` items from ``X`` without replacement (as done in PREP code).
+
+    Parameters
+    ----------
+    X : sequence
+        Items to sample from. It is copied, so the caller's object is never
+        modified and immutable sequences such as ``range`` are accepted.
+    num : int
+        Number of items to draw.
+
+    Returns
+    -------
+    Y : list
+        The drawn items, in the order they were picked.
+
+    Raises
+    ------
+    IndexError
+        If ``num`` is larger than ``len(X)``.
+    """
+    # Work on a private list: ``del`` on a Python 3 ``range`` raises
+    # TypeError, and deleting from the caller's list is a hidden side effect.
+    pool = list(X)
+    Y = []
+    for _ in range(num):
+        # 1-based pick followed by "- 1", kept in this form so that the
+        # sequence of draws stays the same for a given random seed.
+        pick = int(round(1 + (len(pool) - 1) * np.random.random())) - 1
+        Y.append(pool.pop(pick))
+    return Y
+
+
+class Ransac(object):
+    """Flag bad sensors with RANSAC, modelled on the PREP code.
+
+    Every sensor is predicted from random subsets of sensors; a sensor
+    whose prediction correlates poorly with its real signal in too many
+    time windows is reported as bad.
+
+    After ``fit_predict`` the following attributes are available:
+    ``corr_`` (correlations, windows x channels), ``bad_log`` (1 where the
+    correlation is below ``min_corr``) and ``bad_chs_`` (bad channel names).
+    """
+
+    def __init__(self, n_resample=50, min_channels=0.25, min_corr=0.75,
+                 unbroken_time=0.4, window_size=5, ch_type='eeg'):
+        """
+        Parameters
+        ----------
+        n_resample : int
+            Number of times the sensors are resampled.
+        min_channels : float
+            Fraction of sensors for robust reconstruction.
+        min_corr : float
+            Cut-off correlation for abnormal wrt neighbours.
+        unbroken_time : float
+            Cut-off fraction of time sensor can have poor RANSAC
+            predictability.
+        window_size : float
+            Correlation window for RANSAC.
+        ch_type : str
+            'meg' | 'eeg'
+        """
+        self.n_resample = n_resample
+        self.min_channels = min_channels
+        self.min_corr = min_corr
+        self.unbroken_time = unbroken_time
+        self.window_size = window_size
+        self.ch_type = ch_type
+
+    def _get_random_subsets(self, info):
+        """Draw ``n_resample`` random subsets of channel names.
+
+        Parameters
+        ----------
+        info : dict
+            Measurement info; only ``info['ch_names']`` is read.
+
+        Returns
+        -------
+        ch_subsets : list of list of str
+            One list of channel names per resample.
+        """
+        # have to set the seed here; because the global NumPy seed is reset
+        # on every call, every call draws the same subsets
+        np.random.seed(435656)
+        ch_names = info['ch_names']
+        n_channels = len(ch_names)
+
+        # number of channels to interpolate from
+        n_samples = int(np.round(self.min_channels * n_channels))
+
+        ch_subsets = []
+        for _ in range(self.n_resample):
+            # Hand over a fresh list each time: a Python 3 ``range`` does
+            # not support the item deletion done while sampling.
+            picks = _randsample(list(range(n_channels)), n_samples)
+            ch_subsets.append([ch_names[p] for p in picks])
+
+        return ch_subsets
+
+    def _get_mappings(self, inst):
+        """Build the interpolation matrices, one per channel subset.
+
+        Parameters
+        ----------
+        inst : object
+            Instance providing ``info['chs']`` and ``info['ch_names']``.
+
+        Returns
+        -------
+        mappings : ndarray, shape (n_resample * n_channels, n_channels)
+            The per-subset matrices stacked along the first axis. Columns
+            of channels that are not in a subset are zero.
+        """
+        from utils import _fast_map_meg_channels
+        from progressbar import ProgressBar, SimpleProgress
+
+        ch_subsets = self.ch_subsets_
+        pbar = ProgressBar(widgets=[SimpleProgress()])
+        pos = np.array([ch['loc'][:3] for ch in inst.info['chs']])
+        ch_names = inst.info['ch_names']
+        n_channels = len(ch_names)
+        pick_to = range(n_channels)
+        mappings = []
+        print('Trying channel subset: ')
+        for idx in pbar(range(len(ch_subsets))):
+            # don't do the following as it will sort the channels!
+            # pick_from = pick_channels(ch_names, ch_subsets[idx])
+            pick_from = np.array([ch_names.index(name)
+                                  for name in ch_subsets[idx]])
+            # NOTE(review): any ch_type other than 'meg' / 'eeg' silently
+            # leaves the mapping at zero.
+            mapping = np.zeros((n_channels, n_channels))
+            if self.ch_type == 'meg':
+                mapping[:, pick_from] = _fast_map_meg_channels(
+                    inst, pick_from, pick_to)
+            elif self.ch_type == 'eeg':
+                mapping[:, pick_from] = _make_interpolation_matrix(
+                    pos[pick_from], pos[pick_to], alpha=1e-5)
+            mappings.append(mapping)
+        mappings = np.concatenate(mappings)
+        return mappings
+
+    def _compute_correlations(self, inst):
+        """Compute correlation between prediction and real data.
+
+        Parameters
+        ----------
+        inst : object
+            Instance whose ``_data`` array has shape (n_channels, n_times).
+
+        Returns
+        -------
+        corr : ndarray, shape (n_channels,)
+            Normalized (not mean-centered) inner product between each
+            channel and its robust prediction.
+        """
+        mappings = self.mappings_
+        data = inst._data
+        n_channels, n_times = data.shape
+
+        # start pooling all the predictions (can be done more efficiently)
+        print('Pooling predictions')
+        y_pred = data.T.dot(mappings.T)
+        # Fortran order undoes the stacking of the per-resample mappings,
+        # giving one prediction per (time, channel, resample).
+        y_pred = y_pred.reshape((n_times, n_channels, self.n_resample),
+                                order='F')
+        print('[Done]')
+
+        print('Robustifying the predictions')
+        y_pred = _modified_median(y_pred)
+        print('[Done]')
+
+        print('Computing correlations')
+        num = np.sum(data.T * y_pred, axis=0)
+        denom = (np.sqrt(np.sum(data.T ** 2, axis=0)) *
+                 np.sqrt(np.sum(y_pred ** 2, axis=0)))
+        # no mean is subtracted, so this is *not* equivalent to
+        # np.corrcoef(data[idx], y_pred[:, idx])[0, 1]
+        corr = num / denom
+        print('[Done]')
+        return corr
+
+    def fit(self, raw):
+        """Draw the channel subsets and compute their mappings.
+
+        Sets ``ch_subsets_`` and ``mappings_`` and returns ``self``.
+        """
+        self.ch_subsets_ = self._get_random_subsets(raw.info)
+        self.mappings_ = self._get_mappings(raw)
+        return self
+
+    def predict(self, raw):
+        """Store the prediction/data correlations of ``raw`` in ``corr_``.
+        """
+        self.corr_ = self._compute_correlations(raw)
+
+    def fit_predict(self, raw):
+        """Run RANSAC on consecutive windows and flag the bad channels.
+
+        Parameters
+        ----------
+        raw : object
+            Continuous recording; it is not modified.
+
+        Sets ``corr_``, ``bad_log`` and ``bad_chs_``.
+        """
+        window_size = self.window_size
+        n_windows = int(raw.times[-1] // window_size)
+        n_channels = len(raw.ch_names)
+        corrs = np.zeros((n_windows, n_channels))
+        for idx in range(n_windows):
+            # Crop a copy: if ``crop`` operates in place, cropping ``raw``
+            # itself truncates the caller's data and every window after
+            # the first falls outside the remaining time range.
+            raw_crop = raw.copy().crop(
+                idx * window_size,
+                (idx + 1) * window_size - 1 / raw.info['sfreq'])
+            print('RANSAC on window %d/%d' % (idx + 1, n_windows))
+            self.fit(raw_crop).predict(raw_crop)
+            corrs[idx, :] = self.corr_
+        self.corr_ = corrs
+
+        # compute how many windows is a sensor RANSAC-bad
+        self.bad_log = np.zeros_like(self.corr_)
+        self.bad_log[self.corr_ < self.min_corr] = 1
+        bad_log = self.bad_log.sum(axis=0)
+
+        # NOTE(review): the threshold uses (1 - unbroken_time), whereas the
+        # parameter is documented as the tolerated fraction of bad time;
+        # confirm which one is intended.
+        bad_idx = np.where(bad_log > (1 - self.unbroken_time) * n_windows)[0]
+        self.bad_chs_ = [raw.info['ch_names'][p] for p in bad_idx]
+
+# Benchmark driver: runs Ransac on the EEGLAB file of the MNE testing
+# dataset. A `dataset` switch ('sample', 'egi', 'somato') is not wired up.
+
+# download=False: the testing dataset is expected to be on disk already
+base_dir = op.join(testing.data_path(download=False), 'EEGLAB')
+raw_fname = op.join(base_dir, 'test_raw.set')
+montage = op.join(base_dir, 'test_chans.locs')
+raw = io.read_raw_eeglab(input_fname=raw_fname, montage=montage, preload=True,
+ eog=[])
+
+# clear the bad-channel list, then keep only the EEG channels
+# (MEG, stim and EOG channels are all excluded)
+raw.info['bads'] = []
+raw.pick_types(meg=False, eeg=True, stim=False, eog=False,
+ include=[], exclude=[])
+
+# low-pass at 45 Hz with the same filter parameters as the PREP pipeline;
+# the results are close to, but not *exactly* the same as, PREP's
+raw.filter(l_freq=None, h_freq=45, h_trans_bandwidth=5, method='fft')
+
+# detect the bad channels with the default RANSAC settings
+ransac = Ransac()
+ransac.fit_predict(raw)
+
+# mark the detected channels as bad and show the data together with the
+# per-window bad log (channels along the rows after the transpose)
+raw.info['bads'] = ransac.bad_chs_
+raw.plot()
+plt.matshow(ransac.bad_log.T)
+plt.show()
diff --git a/utils.py b/utils.py
index 31466db..8380dd1 100644
--- a/utils.py
+++ b/utils.py
@@ -35,8 +35,7 @@ def load_data(dataset):
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
raw = io.Raw(raw_fname, preload=True)
- projs, _ = mne.preprocessing.compute_proj_ecg(raw, n_eeg=1, average=True,
- verbose=True)
+ projs, _ = mne.preprocessing.compute_proj_ecg(raw, n_eeg=1, average=True)
raw.add_proj(projs).apply_proj()
event_fname = data_path + ('/MEG/sample/sample_audvis_filt-0-40_raw-'
@@ -191,8 +190,7 @@ def _fast_map_meg_channels(inst, pick_from, pick_to, mode='fast'):
from mne.io.pick import pick_info
from mne.forward._field_interpolation import _setup_dots
from mne.forward._field_interpolation import _compute_mapping_matrix
- from mne.io.constants import FIFF
- from mne.forward._make_forward import _create_coils
+ from mne.forward._make_forward import _create_meg_coils, _read_coil_defs
from mne.forward._lead_dots import _do_self_dots, _do_cross_dots
miss = 1e-4 # Smoothing criterion for MEG
@@ -200,8 +198,9 @@ def _fast_map_meg_channels(inst, pick_from, pick_to, mode='fast'):
def _compute_dots(info, mode='fast'):
"""Compute all-to-all dots.
"""
- coils = _create_coils(info['chs'], FIFF.FWD_COIL_ACCURACY_NORMAL,
- info['dev_head_t'], 'meg')
+ templates = _read_coil_defs()
+ coils = _create_meg_coils(info['chs'], 'normal', info['dev_head_t'],
+ templates)
my_origin, int_rad, noise, lut_fun, n_fact = _setup_dots(mode, coils,
'meg')
self_dots = _do_self_dots(int_rad, False, coils, my_origin, 'meg',
@@ -215,8 +214,9 @@ def _compute_dots(info, mode='fast'):
info['bads'] = [] # if bads is different, hash will be different
info_from = pick_info(info, pick_from, copy=True)
- coils_from = _create_coils(info_from['chs'], FIFF.FWD_COIL_ACCURACY_NORMAL,
- info_from['dev_head_t'], 'meg')
+ templates = _read_coil_defs()
+ coils_from = _create_meg_coils(info_from['chs'], 'normal',
+ info_from['dev_head_t'], templates)
my_origin, int_rad, noise, lut_fun, n_fact = _setup_dots(mode, coils_from,
'meg')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment