Skip to content

Instantly share code, notes, and snippets.

Created January 17, 2020 10:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kingjr/3c24666903a7ba954084a88643afdf76 to your computer and use it in GitHub Desktop.
Save kingjr/3c24666903a7ba954084a88643afdf76 to your computer and use it in GitHub Desktop.
import mne
import numpy as np
import matplotlib.pyplot as plt
import scipy
import torch
from itertools import product
import pandas as pd
import seaborn as sns
def cart_to_pol(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Parameters
    ----------
    x, y : float | array
        Cartesian coordinates (broadcast together by NumPy).

    Returns
    -------
    rho : float | array
        Euclidean distance to the origin, ``sqrt(x**2 + y**2)``.
    phi : float | array
        Angle in radians, as given by ``np.arctan2(y, x)``.
    """
    angle = np.arctan2(y, x)
    radius = np.sqrt(x ** 2 + y ** 2)
    return (radius, angle)
def pol_to_cart(rho, phi):
    """Convert polar coordinates to Cartesian coordinates.

    Parameters
    ----------
    rho : float | array
        Distance to the origin.
    phi : float | array
        Angle in radians.

    Returns
    -------
    xy : array
        ``x`` and ``y`` stacked as columns with ``np.c_``.
    """
    return np.c_[rho * np.cos(phi), rho * np.sin(phi)]
def get_border_channels(xy_chan, scale=2.):
    """Create virtual channels outside the convex hull of a 2D layout.

    The channels lying on the convex hull of the layout are pushed radially
    away from the layout center (0.5, 0.5) by a factor ``scale``.

    Parameters
    ----------
    xy_chan : array, shape (n_chans, 2)
        Channel positions, rescaled so that the smallest coordinate is
        exactly 0 and the largest exactly 1.
    scale : float
        Radial extension factor, relative to the layout center. Defaults to
        2, the value that was previously hard-coded.

    Returns
    -------
    xy_border : array, shape (n_border, 2)
        Positions of the virtual border channels.
    sel : array of int, shape (n_border,)
        Sorted indices (into ``xy_chan``) of the hull channels;
        ``xy_border[i]`` is the extension of ``xy_chan[sel[i]]``.

    Raises
    ------
    ValueError
        If ``xy_chan`` does not span exactly [0, 1].
    """
    from scipy.spatial import ConvexHull
    xy_chan = np.asarray(xy_chan)
    # Explicit check rather than ``assert``: asserts vanish under ``python -O``.
    if xy_chan.min() != 0 or xy_chan.max() != 1:
        raise ValueError('xy_chan must be rescaled to span [0, 1], got '
                         'min={} and max={}'.format(xy_chan.min(),
                                                    xy_chan.max()))
    # center on 0
    xy_centered = xy_chan - .5
    # select border channels: every vertex of the convex hull
    sel = np.unique(ConvexHull(xy_centered).simplices)
    # Extend them. Scaling the radius while keeping the angle is a plain
    # multiplication of the centered coordinates (no polar round-trip).
    xy_border = xy_centered[sel] * scale + .5  # decenter
    return xy_border, sel
def get_chan2pix_matrices(xy_chan, n_pixels=20, method='cubic',
                          dtype=torch.float32, borders=False):
    """Build linear maps between channel space and a square pixel grid.

    Parameters
    ----------
    xy_chan : array, shape (n_chans, 2)
        2D channel positions in any units; each axis is rescaled to [0, 1].
    n_pixels : int
        Number of pixels along each side of the grid.
    method : str
        Interpolation method forwarded to ``scipy.interpolate.griddata``
        (e.g. 'linear', 'nearest', 'cubic').
    dtype : torch.dtype
        dtype of the returned matrices.
    borders : bool
        Kept for backward compatibility; it currently has no effect on the
        output (see the NOTE in the body).

    Returns
    -------
    xy_pix : array, shape (n_pixels ** 2, 2)
        Pixel positions in the rescaled channel coordinates. The grid spans
        [-0.1, 1.1] on both axes.
    chan_to_pix : torch.Tensor, shape (n_chans, n_pixels, n_pixels)
        ``chan_to_pix[c]`` is the image interpolated from a unit value at
        channel ``c`` (0 at all other channels); pixels outside the
        channels' convex hull are 0.
    pix_to_chan : torch.Tensor, shape (n_pixels, n_pixels, n_chans)
        ``pix_to_chan[i, j]`` holds the channel values interpolated from a
        unit image at pixel ``(i, j)``; channels outside the pixel grid's
        hull get 0.

    Raises
    ------
    ValueError
        If all channels share the same x or the same y coordinate.
    """
    from scipy.interpolate import griddata

    # NOTE(review): an earlier version also interpolated with virtual border
    # channels (``get_border_channels``) when ``borders`` was False, but then
    # unconditionally overwrote that map with the plain channel interpolation
    # below, so ``borders`` never changed the output. The dead computation is
    # dropped here; decide whether the border variant should be restored.

    # Make sure xy are between 0 and 1. Work on a float copy: integer input
    # would break the in-place division, and the caller's array stays intact.
    xy_chan = np.array(xy_chan, dtype=float)
    xy_chan -= xy_chan.min(0)
    span = np.ptp(xy_chan, axis=0)  # ``ndarray.ptp`` was removed in NumPy 2
    if np.any(span == 0):
        raise ValueError('channel positions must vary along both axes')
    xy_chan /= span
    n_chans = len(xy_chan)

    # Define 2d map mesh
    offset = .1  # to avoid border cropping
    grid = np.linspace(-offset, 1. + offset, n_pixels)
    x_new, y_new = np.meshgrid(grid, grid)
    xy_pix = (x_new.ravel(), y_new.ravel())
    n_pix = n_pixels * n_pixels

    # griddata interpolates each column of ``values`` independently, so
    # interpolating an identity matrix yields the map of every unit impulse
    # at once: one triangulation instead of one per impulse.

    # Channels -> pixels: griddata returns (n_pix, n_chans).
    chan_maps = griddata(xy_chan, np.eye(n_chans), xy_pix, method,
                         fill_value=0)
    chan_to_pix = torch.as_tensor(
        np.ascontiguousarray(chan_maps.T).reshape(n_chans, n_pixels, n_pixels),
        dtype=dtype)

    # Pixels -> channels: griddata returns (n_chans, n_pix).
    pix_maps = griddata(xy_pix, np.eye(n_pix), xy_chan, method,
                        fill_value=0)
    pix_to_chan = torch.as_tensor(
        np.ascontiguousarray(pix_maps.T).reshape(n_pixels, n_pixels, n_chans),
        dtype=dtype)

    return np.array(xy_pix).T, chan_to_pix, pix_to_chan
if __name__ == '__main__':
    # Read the averaged responses of the MNE "sample" dataset.
    path = mne.datasets.sample.data_path()
    # Recent MNE versions return a ``pathlib.Path``, which does not support
    # ``+`` with a str; string formatting works for both str and Path.
    fname = f'{path}/MEG/sample/sample_audvis-ave.fif'
    evoked = mne.read_evokeds(fname)[0]

    # Get channel 2D position (flatten 3D), rescaled to [0, 1] per axis.
    evoked = evoked.pick_types(meg='mag', ref_meg=False)
    # NOTE(review): the call arguments were lost in the scraped source;
    # ``evoked.info`` is the reconstruction. This assumes ``layout.pos`` rows
    # follow the channel order of ``evoked`` -- confirm.
    layout = mne.find_layout(evoked.info)
    xy_chan = layout.pos[:, :2]
    xy_chan -= xy_chan.min(0)
    xy_chan /= xy_chan.max(0)
    # if want to explore 3d conv: 3D sensor locations
    xyz_chan = np.array([ch['loc'][:3] for ch in evoked.info['chs']])

    # main: channel <-> pixel interpolation matrices
    xy_pix, chan_to_pix, pix_to_chan = get_chan2pix_matrices(
        xy_chan, n_pixels=20, method='cubic', borders=False)

    # Project one time sample to pixels and back to channels.
    t = 20
    # x1e12 rescales the sensor values (presumably T -> pT; confirm).
    data = torch.tensor(evoked.data[:, t] * 1e12, dtype=torch.float32)
    n_chans = len(data)
    pix = data @ chan_to_pix.reshape(n_chans, -1)
    back = pix @ pix_to_chan.reshape(-1, n_chans)

    # plot: original channels, pixel map, reconstructed channels
    fig, axes = plt.subplots(1, 3, sharex=True, sharey=True)
    for d, ax, pos in zip((data, pix, back), axes, (xy_chan, xy_pix, xy_chan)):
        ax.scatter(*pos.T, c=d, cmap='coolwarm')
    plt.show()
if __name__ == '__test__':
    # Parameter sweep of the reconstruction error. ``__name__`` is never
    # '__test__' on a normal run, so this only executes when the module
    # source is exec'd under that name.
    # NOTE(review): it relies on ``evoked`` and ``xy_chan`` already existing
    # in that namespace (e.g. from a prior ``__main__`` run) -- confirm.
    def evaluate(data, xy_chan, n_pixels, method, borders):
        """Evaluate reconstruction error.

        Each column of ``data`` (n_chans, n_times) is projected to pixels and
        back to channels; the summed squared difference per column is
        returned as a NumPy array of length n_times.
        """
        n_chans, n_times = data.shape
        xy_pix, chan_to_pix, pix_to_chan = get_chan2pix_matrices(
            xy_chan, n_pixels=n_pixels, method=method, borders=borders)
        # Flatten the pixel dimensions once, outside the time loop.
        chan2pix = chan_to_pix.reshape(n_chans, -1)
        pix2chan = pix_to_chan.reshape(-1, n_chans)
        err = list()
        for orig in data.T:
            back = (orig @ chan2pix) @ pix2chan
            # ``.item()`` gives a Python float, so ``np.asarray`` below
            # builds a plain float array.
            err.append(((orig - back) ** 2).sum().item())
        return np.asarray(err)

    # ``evaluate`` needs all time samples (n_chans, n_times), not the single
    # sample plotted in the ``__main__`` block.
    data = torch.tensor(evoked.data * 1e12, dtype=torch.float32)

    # Vary across parameters.
    # NOTE(review): in get_chan2pix_matrices as written, ``borders`` does not
    # change the output, so the borders=True/False groups are identical.
    df = list()
    for n_pixels, method, borders in product([5, 15, 25], ('cubic', 'linear'),
                                             (False, True)):
        print('.', end='')
        err = evaluate(data, xy_chan, n_pixels, method, borders)
        df.extend([dict(err=e, n_pixels=n_pixels, method=method,
                        borders=borders)
                   for e in err])
    df = pd.DataFrame(df)

    fig, axes = plt.subplots(1, 3)
    for metrics, ax in zip(('n_pixels', 'method', 'borders'), axes):
        sns.boxplot(x=metrics, y='err', data=df, ax=ax)

    fig, axes = plt.subplots(1, 2)
    for metrics, ax in zip(('n_pixels', 'method'), axes):
        sns.boxplot(x=metrics, y='err',
                    data=df.query('borders==False and n_pixels>5'), ax=ax,
                    hue='n_pixels')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment