Created
January 17, 2020 10:08
-
-
Save kingjr/3c24666903a7ba954084a88643afdf76 to your computer and use it in GitHub Desktop.
extract_2d.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
from itertools import product

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
import torch
def cart_to_pol(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Parameters
    ----------
    x, y : float or array
        Cartesian coordinates (broadcast together by NumPy).

    Returns
    -------
    (rho, phi) : tuple
        ``rho`` is the Euclidean distance to the origin and ``phi`` the angle
        in radians, as returned by ``np.arctan2(y, x)``.
    """
    angle = np.arctan2(y, x)
    radius = np.sqrt(x ** 2 + y ** 2)
    return radius, angle
def pol_to_cart(rho, phi):
    """Convert polar coordinates back to Cartesian ones.

    Parameters
    ----------
    rho : float or array
        Distance to the origin.
    phi : float or array
        Angle in radians.

    Returns
    -------
    xy : array
        ``np.c_`` column stack of the x and y coordinates; for 1-D inputs the
        shape is ``(n, 2)``.
    """
    return np.c_[rho * np.cos(phi), rho * np.sin(phi)]
def get_border_channels(xy_chan, scale=2.):
    """Create virtual channels just outside the sensor array.

    The channels lying on the convex hull of ``xy_chan`` are pushed radially
    away from the centre of the unit square, (0.5, 0.5), by a factor
    ``scale``.

    Parameters
    ----------
    xy_chan : array, shape (n_chans, 2)
        Channel positions, normalised so that the smallest coordinate is
        exactly 0 and the largest exactly 1.
    scale : float
        Radial expansion factor applied to the hull channels. The default,
        2, places each virtual channel twice as far from the centre as its
        hull channel.

    Returns
    -------
    xy_border : array, shape (n_border, 2)
        Positions of the virtual border channels.
    sel : array of int, shape (n_border,)
        Sorted indices into ``xy_chan`` of the hull channels;
        ``xy_border[i]`` is the expansion of ``xy_chan[sel[i]]``.

    Raises
    ------
    ValueError
        If ``xy_chan`` is not normalised to [0, 1].
    """
    from scipy.spatial import ConvexHull

    xy_chan = np.asarray(xy_chan)
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if xy_chan.min() != 0 or xy_chan.max() != 1:
        raise ValueError('xy_chan must be normalised so that its minimum is 0 '
                         'and its maximum is 1, got min=%s, max=%s'
                         % (xy_chan.min(), xy_chan.max()))
    # Work relative to the centre of the unit square.
    centered = xy_chan - .5
    # Indices of the channels on the convex hull (np.unique sorts them).
    sel = np.unique(ConvexHull(centered).simplices)
    # Scaling centred coordinates keeps each angle and multiplies each radius
    # by ``scale``. This is the same radial expansion as a polar round-trip,
    # without the rounding error.
    xy_border = centered[sel, :] * scale + .5  # back to [0, 1]-based frame
    return xy_border, sel
def get_chan2pix_matrices(xy_chan, n_pixels=20, method='cubic',
                          dtype=torch.float32, borders=False):
    """Build linear maps from channel values to a 2D pixel image and back.

    Parameters
    ----------
    xy_chan : array, shape (n_chans, 2)
        2D channel positions. Each coordinate is rescaled internally to span
        [0, 1]; the caller's array is not modified.
    n_pixels : int
        Number of pixels along each side of the square image.
    method : str
        Interpolation method forwarded to ``scipy.interpolate.griddata``.
    dtype : torch.dtype
        dtype of the returned matrices.
    borders : bool
        If False (default), virtual channels placed outside the convex hull
        (see ``get_border_channels``) copy the value of their hull channel,
        so the interpolated image extends beyond the sensor array. If True,
        only the real channels are used, and pixels outside their convex
        hull are filled with 0.

    Returns
    -------
    xy_pix : array, shape (n_pixels ** 2, 2)
        Pixel centre coordinates in the rescaled channel frame.
    chan_to_pix : torch.Tensor, shape (n_chans, n_pixels, n_pixels)
        ``data @ chan_to_pix.reshape(n_chans, -1)`` maps channel values to
        flattened pixels.
    pix_to_chan : torch.Tensor, shape (n_pixels, n_pixels, n_chans)
        ``pix @ pix_to_chan.reshape(-1, n_chans)`` maps flattened pixels back
        to channels.
    """
    # Rescale each coordinate to [0, 1]. Cast to float first so that integer
    # positions survive the in-place division. Use np.ptp because the
    # ndarray.ptp method was removed in NumPy 2.0.
    xy_chan = np.asarray(xy_chan, dtype=float)
    xy_chan = xy_chan - xy_chan.min(0)
    xy_chan /= np.ptp(xy_chan, axis=0)

    # Square pixel grid, slightly wider than [0, 1] to avoid border cropping.
    offset = .1
    grid = np.linspace(-offset, 1. + offset, n_pixels)
    x_new, y_new = np.meshgrid(grid, grid)
    xy_pix = (x_new.ravel(), y_new.ravel())

    chan_to_pix = _chan_to_pix_matrix(xy_chan, xy_pix, n_pixels, method,
                                      dtype, borders)
    pix_to_chan = _pix_to_chan_matrix(xy_chan, xy_pix, n_pixels, method,
                                      dtype)
    return np.array(xy_pix).T, chan_to_pix, pix_to_chan


def _chan_to_pix_matrix(xy_chan, xy_pix, n_pixels, method, dtype, borders):
    """Interpolate a one-hot vector per channel onto the pixel grid."""
    from scipy.interpolate import griddata

    n_chans = len(xy_chan)
    if borders:
        points, mirror = xy_chan, {}
    else:
        # Virtual channels outside the array. Each one mirrors a hull channel;
        # ``mirror`` maps that hull channel to the virtual channel's index in
        # ``points``.
        xy_border, border_idx = get_border_channels(xy_chan)
        points = np.r_[xy_chan, xy_border]
        mirror = {int(chan): n_chans + pos
                  for pos, chan in enumerate(border_idx)}

    chan_to_pix = torch.zeros((n_chans, n_pixels, n_pixels), dtype=dtype)
    for c in range(n_chans):
        d = np.zeros(len(points))
        d[c] = 1.
        if c in mirror:
            # Give the virtual copy of a hull channel the same value.
            d[mirror[c]] = 1.
        map2d = griddata(points, d, xy_pix, method, fill_value=0)
        chan_to_pix[c] = torch.from_numpy(map2d.reshape(n_pixels, n_pixels))
    return chan_to_pix


def _pix_to_chan_matrix(xy_chan, xy_pix, n_pixels, method, dtype):
    """Interpolate a one-hot image per pixel back onto the channel positions."""
    from scipy.interpolate import griddata

    n_chans = len(xy_chan)
    pix_to_chan = torch.zeros((n_pixels, n_pixels, n_chans), dtype=dtype)
    for pix1, pix2 in product(range(n_pixels), repeat=2):
        d = np.zeros((n_pixels, n_pixels))
        d[pix1, pix2] = 1
        # d.ravel() follows the same row-major order as ``xy_pix``.
        map1d = griddata(xy_pix, d.ravel(), xy_chan, method, fill_value=0)
        pix_to_chan[pix1, pix2] = torch.from_numpy(map1d)
    return pix_to_chan
if __name__ == '__main__':
    # Read the first evoked response of the MNE "sample" dataset.
    mne.set_log_level(False)
    path = mne.datasets.sample.data_path()
    # Recent MNE versions return a pathlib.Path here (older ones a str).
    # os.path.join accepts both, whereas ``path + '/...'`` fails on a Path.
    fname = os.path.join(path, 'MEG', 'sample', 'sample_audvis-ave.fif')
    evoked = mne.read_evokeds(fname)[0]

    # Keep magnetometers only and flatten their positions to a 2D layout
    # rescaled to [0, 1] along each axis.
    evoked = evoked.pick_types(meg='mag', ref_meg=False)
    layout = mne.find_layout(evoked.info)
    xy_chan = layout.pos[:, :2]
    xy_chan -= xy_chan.min(0)
    xy_chan /= xy_chan.max(0)
    # 3D sensor positions, kept for exploring 3D convolutions (unused below).
    xyz_chan = np.array([ch['loc'][:3] for ch in evoked.info['chs']])

    xy_pix, chan_to_pix, pix_to_chan = get_chan2pix_matrices(
        xy_chan, n_pixels=20, method='cubic', borders=False)

    # Project one time sample to pixels and back, then plot all three.
    t = 20
    # 1e12 rescales the data; presumably Tesla -> picoTesla, verify.
    data = torch.tensor(evoked.data[:, t] * 1e12, dtype=torch.float32)
    n_chans = len(data)
    pix = data @ chan_to_pix.reshape(n_chans, -1)
    back = pix @ pix_to_chan.reshape(-1, n_chans)
    fig, axes = plt.subplots(1, 3, sharex=True, sharey=True)
    titles = ('channels', 'pixels', 'back to channels')
    for d, ax, pos, title in zip((data, pix, back), axes,
                                 (xy_chan, xy_pix, xy_chan), titles):
        ax.scatter(*pos.T, c=d, cmap='coolwarm')
        ax.set_title(title)
    # Without this, nothing is displayed when run as a plain script.
    plt.show()
if __name__ == '__test__':
    # NOTE: ``__name__`` is never '__test__' when the file is run or imported
    # normally, so this parameter sweep only runs if the guard is edited or
    # the section is pasted into an interactive session.

    def evaluate(data, xy_chan, n_pixels, method, borders):
        """Evaluate reconstruction error.

        Projects every time sample to pixels and back, and returns the summed
        squared channel error of each sample as an array of shape
        (n_times,). ``data`` must be 2D, (n_chans, n_times).
        """
        n_chans, _ = data.shape
        xy_pix, chan_to_pix, pix_to_chan = get_chan2pix_matrices(
            xy_chan, n_pixels=n_pixels, method=method, borders=borders)
        # Vectorized over time: rows are time samples.
        orig = torch.as_tensor(data, dtype=chan_to_pix.dtype).T
        pix = orig @ chan_to_pix.reshape(n_chans, -1)
        back = pix @ pix_to_chan.reshape(-1, n_chans)
        return ((orig - back) ** 2).sum(1).numpy()

    # Load the magnetometer data here, because the names defined in the
    # __main__ section are not available (and ``data`` there is a single,
    # 1-D time sample).
    mne.set_log_level(False)
    fname = os.path.join(mne.datasets.sample.data_path(), 'MEG', 'sample',
                         'sample_audvis-ave.fif')
    evoked = mne.read_evokeds(fname)[0].pick_types(meg='mag', ref_meg=False)
    xy_chan = mne.find_layout(evoked.info).pos[:, :2]
    data = torch.tensor(evoked.data * 1e12, dtype=torch.float32)

    # Vary across parameters
    df = list()
    for n_pixels, method, borders in product([5, 15, 25], ('cubic', 'linear'),
                                             (False, True)):
        print('.', end='')
        err = evaluate(data, xy_chan, n_pixels, method, borders)
        df.extend(dict(err=e, n_pixels=n_pixels, method=method,
                       borders=borders) for e in err)
    df = pd.DataFrame(df)
    fig, axes = plt.subplots(1, 3)
    for metrics, ax in zip(('n_pixels', 'method', 'borders'), axes):
        sns.boxplot(x=metrics, y='err', data=df, ax=ax)
    fig, axes = plt.subplots(1, 2)
    for metrics, ax in zip(('n_pixels', 'method'), axes):
        sns.boxplot(x=metrics, y='err',
                    data=df.query('borders==False and n_pixels>5'),
                    ax=ax, hue='n_pixels')
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment