Skip to content

Instantly share code, notes, and snippets.

@kingjr
Created January 17, 2020 10:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kingjr/3c24666903a7ba954084a88643afdf76 to your computer and use it in GitHub Desktop.
Save kingjr/3c24666903a7ba954084a88643afdf76 to your computer and use it in GitHub Desktop.
extract_2d.py
import mne
import numpy as np
import matplotlib.pyplot as plt
import scipy
import torch
from itertools import product
import pandas as pd
import seaborn as sns
def cart_to_pol(x, y):
    """Return the polar coordinates ``(rho, phi)`` of the point(s) ``(x, y)``.

    ``rho`` is the Euclidean distance to the origin and ``phi`` the angle in
    radians as given by ``np.arctan2(y, x)`` (range [-pi, pi]). Works on
    scalars or element-wise on broadcastable arrays.
    """
    return np.sqrt(x ** 2 + y ** 2), np.arctan2(y, x)
def pol_to_cart(rho, phi):
    """Map polar coordinates ``(rho, phi)`` back to Cartesian ones.

    The x and y coordinates are stacked as columns with ``np.c_``, so 1-D
    inputs of length n give an array of shape (n, 2) and scalars give
    shape (1, 2).
    """
    return np.c_[rho * np.cos(phi), rho * np.sin(phi)]
def get_border_channels(xy_chan, scale=2.):
    """Create virtual channels by pushing the convex-hull channels outwards.

    Parameters
    ----------
    xy_chan : array, shape (n_chans, 2)
        Channel positions, normalized so that the global minimum is exactly
        0 and the global maximum exactly 1.
    scale : float
        Factor applied to each hull channel's offset from the center
        (0.5, 0.5). The default of 2. reproduces the original behavior
        (distance to the center doubled, direction unchanged).

    Returns
    -------
    xy_border : array, shape (n_border, 2)
        Positions of the virtual border channels.
    sel : array of int, shape (n_border,)
        Sorted indices into ``xy_chan`` of the channels lying on the convex
        hull; ``xy_border[i]`` is the extension of ``xy_chan[sel[i]]``.

    Raises
    ------
    ValueError
        If ``xy_chan`` is not normalized to min 0 / max 1.
    """
    from scipy.spatial import ConvexHull
    xy_chan = np.asarray(xy_chan)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if xy_chan.min() != 0 or xy_chan.max() != 1:
        raise ValueError('xy_chan must be normalized so that its minimum is 0 '
                         'and its maximum is 1.')
    # center on 0
    centered = xy_chan - .5
    # indices of the channels on the convex hull (sorted, unique)
    sel = np.unique(ConvexHull(centered).simplices)
    # Scaling rho at a fixed phi is the same as scaling the Cartesian offset;
    # doing it directly avoids the rounding error of a polar round trip.
    xy_border = centered[sel] * scale + .5  # decenter
    return xy_border, sel
def get_chan2pix_matrices(xy_chan, n_pixels=20, method='cubic',
                          dtype=torch.float32, borders=False):
    """Build linear operators between channel values and a square pixel grid.

    Parameters
    ----------
    xy_chan : array-like, shape (n_chans, 2)
        2D channel positions. They are rescaled (on a float copy, so the
        caller's array is left untouched) to span [0, 1] along each axis.
    n_pixels : int
        Number of pixels along each side of the grid. The grid spans
        [-0.1, 1.1] on both axes.
    method : str
        Interpolation method passed to ``scipy.interpolate.griddata``.
    dtype : torch.dtype
        dtype of the returned tensors.
    borders : bool
        NOTE(review): despite its name, virtual border channels are added
        when this is False (the default) and omitted when it is True. Kept
        unchanged for backward compatibility -- confirm the intent.

    Returns
    -------
    xy_pix : array, shape (n_pixels ** 2, 2)
        Pixel coordinates, in the same order as the flattened pixel axes.
    chan_to_pix : tensor, shape (n_chans, n_pixels, n_pixels)
        ``chan_to_pix[c]`` is the map obtained by interpolating a one-hot
        vector on channel ``c`` onto the grid (0 outside the data hull).
    pix_to_chan : tensor, shape (n_pixels, n_pixels, n_chans)
        ``pix_to_chan[i, j]`` holds the channel values obtained by
        interpolating a one-hot image on pixel ``(i, j)`` onto the channels.
    """
    from scipy.interpolate import griddata
    # Per-axis rescaling to [0, 1]. ``np.ptp`` is used because the
    # ``ndarray.ptp`` method was removed in NumPy 2.0; the explicit float
    # conversion also accepts integer positions.
    xy_chan = np.asarray(xy_chan, dtype=float)
    xy_chan = xy_chan - xy_chan.min(0)
    xy_chan = xy_chan / np.ptp(xy_chan, axis=0)
    n_chans = len(xy_chan)

    # Define 2d map mesh
    offset = .1  # to avoid border cropping
    grid = np.linspace(-offset, 1. + offset, n_pixels)
    x_new, y_new = np.meshgrid(grid, grid)
    xy_pix = (x_new.ravel(), y_new.ravel())
    n_pix = n_pixels * n_pixels

    # Channels -> pixels. griddata interpolates each column of a 2D
    # ``values`` array independently, so a single call with one one-hot
    # column per channel replaces a per-channel loop (one triangulation
    # instead of n_chans).
    if not borders:
        # Add virtual border channels: virtual channel j copies the value of
        # the hull channel border_idx[j] it extends.
        xy_border, border_idx = get_border_channels(xy_chan)
        src = np.r_[xy_chan, xy_border]
        onehots = np.zeros((len(src), n_chans))
        onehots[np.arange(n_chans), np.arange(n_chans)] = 1.
        onehots[n_chans + np.arange(len(border_idx)), border_idx] = 1.
    else:
        src = xy_chan
        onehots = np.eye(n_chans)
    maps = griddata(src, onehots, xy_pix, method, fill_value=0)  # (n_pix, n_chans)
    chan_to_pix = torch.as_tensor(
        np.ascontiguousarray(maps.T).reshape(n_chans, n_pixels, n_pixels),
        dtype=dtype)

    # Pixels -> channels, same trick with one one-hot column per pixel.
    # Columns are processed in chunks so the dense one-hot block stays at
    # most (n_pix, chunk) even for large grids.
    chunk = 1024
    flat = np.empty((n_pix, n_chans))
    for start in range(0, n_pix, chunk):
        idx = np.arange(start, min(start + chunk, n_pix))
        onehots = np.zeros((n_pix, len(idx)))
        onehots[idx, np.arange(len(idx))] = 1.
        flat[idx] = griddata(xy_pix, onehots, xy_chan, method,
                             fill_value=0).T
    pix_to_chan = torch.as_tensor(
        flat.reshape(n_pixels, n_pixels, n_chans), dtype=dtype)
    return np.array(xy_pix).T, chan_to_pix, pix_to_chan
if __name__ == '__main__':
    # Demo: project one time sample of the MNE sample evoked data onto a 2D
    # pixel grid and back, and plot original / pixel / reconstructed values.
    mne.set_log_level(False)
    path = mne.datasets.sample.data_path()
    # ``data_path()`` may return a ``pathlib.Path`` (it does in recent MNE),
    # for which ``str + Path`` raises TypeError; an f-string handles both.
    fname = f'{path}/MEG/sample/sample_audvis-ave.fif'
    evoked = mne.read_evokeds(fname)[0]
    # Get channel 2D position (flatten 3D)
    evoked = evoked.pick_types(meg='mag', ref_meg=False)
    layout = mne.find_layout(evoked.info)
    # ``layout.pos[:, :2]`` is a view: copy it so the in-place normalization
    # below does not modify the layout itself.
    xy_chan = layout.pos[:, :2].copy()
    xy_chan -= xy_chan.min(0)
    xy_chan /= xy_chan.max(0)
    # if want to explore 3d conv:
    xyz_chan = np.array([ch['loc'][:3] for ch in evoked.info['chs']])
    # main
    xy_pix, chan_to_pix, pix_to_chan = get_chan2pix_matrices(
        xy_chan, n_pixels=20, method='cubic', borders=False)
    # plot
    t = 20
    data = torch.tensor(evoked.data[:, t] * 1e12, dtype=torch.float32)
    n_chans = len(data)
    pix = data @ chan_to_pix.reshape(n_chans, -1)
    back = pix @ pix_to_chan.reshape(-1, n_chans)
    fig, axes = plt.subplots(1, 3, sharex=True, sharey=True)
    panels = zip((data, pix, back), axes, (xy_chan, xy_pix, xy_chan),
                 ('channels', 'pixels', 'reconstructed'))
    for d, ax, pos, title in panels:
        ax.scatter(*pos.T, c=d, cmap='coolwarm')
        ax.set_title(title)
    plt.show()
if __name__ == '__test__':
    # Parameter sweep of the channel -> pixel -> channel round-trip error.
    # NOTE(review): this guard does not match when the file is run or
    # imported normally, which looks like a deliberate way of disabling the
    # sweep. The block loads its own data so that it works once enabled.

    def evaluate(data, xy_chan, n_pixels, method, borders):
        """Evaluate reconstruction error.

        Parameters
        ----------
        data : torch.Tensor, shape (n_chans, n_times)
            Channel data; must share the dtype of the interpolation
            matrices (float32 by default).
        xy_chan : array, shape (n_chans, 2)
            Channel positions.
        n_pixels, method, borders
            Forwarded to ``get_chan2pix_matrices``.

        Returns
        -------
        err : array, shape (n_times,)
            Per time sample, the sum over channels of the squared difference
            between the original and the round-tripped values.
        """
        n_chans, n_times = data.shape
        xy_pix, chan_to_pix, pix_to_chan = get_chan2pix_matrices(
            xy_chan, n_pixels=n_pixels, method=method, borders=borders)
        # All time samples at once:
        # (n_times, n_chans) -> (n_times, n_pix) -> (n_times, n_chans)
        orig = data.T
        pix = orig @ chan_to_pix.reshape(n_chans, -1)
        back = pix @ pix_to_chan.reshape(-1, n_chans)
        return ((orig - back) ** 2).sum(1).numpy()

    # Load the magnetometer data (same dataset and scaling as the demo).
    mne.set_log_level(False)
    path = mne.datasets.sample.data_path()
    fname = f'{path}/MEG/sample/sample_audvis-ave.fif'
    evoked = mne.read_evokeds(fname)[0]
    evoked = evoked.pick_types(meg='mag', ref_meg=False)
    layout = mne.find_layout(evoked.info)
    xy_chan = layout.pos[:, :2].copy()
    data = torch.tensor(evoked.data * 1e12, dtype=torch.float32)

    # Vary across parameters
    rows = list()
    for n_pixels, method, borders in product([5, 15, 25], ('cubic', 'linear'),
                                             (False, True)):
        print('.', end='')
        err = evaluate(data, xy_chan, n_pixels, method, borders)
        rows.extend([dict(err=e, n_pixels=n_pixels, method=method,
                          borders=borders)
                     for e in err])
    df = pd.DataFrame(rows)
    fig, axes = plt.subplots(1, 3)
    for metrics, ax in zip(('n_pixels', 'method', 'borders'), axes):
        sns.boxplot(x=metrics, y='err', data=df, ax=ax)
    fig, axes = plt.subplots(1, 2)
    for metrics, ax in zip(('n_pixels', 'method'), axes):
        sns.boxplot(x=metrics, y='err',
                    data=df.query('borders==False and n_pixels>5'),
                    ax=ax, hue='n_pixels')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment