Skip to content

Instantly share code, notes, and snippets.

@jsosulski
Created May 10, 2021 06:53
Show Gist options
  • Save jsosulski/2c5660c185e4b20f453d0fe928a54358 to your computer and use it in GitHub Desktop.
Save jsosulski/2c5660c185e4b20f453d0fe928a54358 to your computer and use it in GitHub Desktop.
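Script that saves diagnostic ERP figures (butterfly, topomap, joint, sensor-layout and peak-to-peak amplitude plots) for every subject — and, for multi-session recordings, every session — of each P300 dataset.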
import warnings
from pathlib import Path
import mne
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from moabb.paradigms import P300
from spot.datasets.utils import get_p300_datasets
from spot.visualization.utils import add_baseline, add_color_spans
sns.set_style("whitegrid")
def _peak_to_peak_uv(epochs):
    """Return per-epoch, per-channel peak-to-peak amplitude in microvolts.

    ``epochs.get_data()`` is (n_epochs, n_channels, n_times) in volts per the
    MNE API, so the result has shape (n_epochs, n_channels).
    """
    data = epochs.get_data() * 1e6  # volts -> microvolts
    return data.max(axis=2) - data.min(axis=2)


def _save_evoked_plots(evkd, evkd_nt, plot_opts, path, fmt):
    """Save butterfly, topomap, joint and sensor-layout plots of the two averages."""
    dpi = plot_opts["dpi"]

    # Butterfly plot: Target (top) vs NonTarget (bottom) on shared axes.
    fig, axes = plt.subplots(2, 1, figsize=(6, 6), sharey="all", sharex="all")
    evkd.plot(spatial_colors=True, show=False, axes=axes[0])
    axes[0].set_title("Target response")
    evkd_nt.plot(spatial_colors=True, show=False, axes=axes[1])
    axes[1].set_title("NonTarget response")
    with warnings.catch_warnings():
        # Silence any warnings tight_layout emits for this MNE-populated figure.
        warnings.simplefilter("ignore")
        fig.tight_layout()
    fig.savefig(path / f"butterflyplot.{fmt}", dpi=dpi)

    # Topomaps of the Target average at evenly spaced time points.
    topo = plot_opts["topo"]
    n_times = topo["timepoints"]
    times = np.linspace(topo["tmin"], topo["tmax"], n_times)
    fig = evkd.plot_topomap(times=times, colorbar=True, show=False)
    fig.savefig(path / f"topomap_{n_times}timepoints.{fmt}", dpi=dpi)

    # Joint plot of the Target average.
    fig = evkd.plot_joint(show=False)
    fig.savefig(path / f"jointmap.{fmt}", dpi=dpi)

    # Per-sensor Target vs NonTarget comparison over 0-0.6 s. Evoked.crop works
    # in place, so crop copies to leave the caller's averages untouched.
    figs = mne.viz.plot_compare_evokeds(
        [evkd.copy().crop(0, 0.6), evkd_nt.copy().crop(0, 0.6)],
        axes="topo",
        show=False,
    )
    figs[0].savefig(path / f"sensorplot.{fmt}", dpi=dpi)


def _save_minmax_plot(epo, epo_t, epo_nt, path, fmt, dpi, n_worst=8):
    """Save per-channel KDEs of peak-to-peak amplitude, labelling the noisiest channels."""
    # Rank channels by mean peak-to-peak amplitude over *all* epochs; the
    # n_worst largest get a legend entry, every other channel stays unlabelled.
    ranking = np.argsort(_peak_to_peak_uv(epo).mean(axis=0))
    worst = set(ranking[::-1][:n_worst].tolist())

    minmax_t = _peak_to_peak_uv(epo_t)
    minmax_nt = _peak_to_peak_uv(epo_nt)

    fig, ax = plt.subplots(2, 1, figsize=(8, 6), sharex="all", sharey="all")
    for i, ch_name in enumerate(epo_t.ch_names):
        label = ch_name if i in worst else None
        sns.kdeplot(minmax_t[:, i], ax=ax[0], label=label, clip=(0, 300))
        sns.kdeplot(minmax_nt[:, i], ax=ax[1], label=label, clip=(0, 300))
    ax[0].set_xlim(0, 200)
    ax[0].set_title("Target minmax")
    ax[1].set_title("NonTarget minmax")
    # Raw string: "\m" is an invalid escape sequence in a normal string literal.
    ax[1].set_xlabel(r"Minmax in $\mu$V")
    ax[1].legend(title="Worst channels")
    fig.tight_layout()
    fig.savefig(path / f"minmax.{fmt}", dpi=dpi)


def create_plot_overview(epo, plot_opts=None, path=None, *, fmt=None):
    """Save a fixed set of diagnostic figures for one collection of P300 epochs.

    Writes ``butterflyplot``, ``topomap_<N>timepoints``, ``jointmap``,
    ``sensorplot`` and ``minmax`` (each with extension ``fmt``) into ``path``.

    Parameters
    ----------
    epo : mne.Epochs
        Epochs that can be indexed by the event names ``"Target"`` and
        ``"NonTarget"``.
    plot_opts : dict, optional
        ``{"dpi": int, "topo": {"timepoints": int, "tmin": float, "tmax": float}}``.
        Defaults to dpi 120 and 10 topomap time points between 0 and 0.6 s,
        matching the options the script in this file uses.
    path : str or pathlib.Path, optional
        Existing output directory. Defaults to the current working directory.
    fmt : str, optional
        Figure file extension/format. Defaults to the module-level
        ``plot_format`` setting.
    """
    if plot_opts is None:
        plot_opts = {"dpi": 120, "topo": {"timepoints": 10, "tmin": 0, "tmax": 0.6}}
    if fmt is None:
        fmt = plot_format  # module-level output format configured by the script
    path = Path(".") if path is None else Path(path)

    epo_t = epo["Target"]
    epo_nt = epo["NonTarget"]
    evkd = epo_t.average()
    evkd_nt = epo_nt.average()
    try:
        _save_evoked_plots(evkd, evkd_nt, plot_opts, path, fmt)
        _save_minmax_plot(epo, epo_t, epo_nt, path, fmt, plot_opts["dpi"])
    finally:
        # Close even on failure so a batch over many subjects does not leak figures.
        plt.close("all")
# Root output folder; figures go to <root>/<dataset>/subject_<n>[/session_<s>].
FIGURES_PATH = Path("/home/jan/bci_data/figures/moabb")
# File extension used for every saved figure (read by create_plot_overview).
plot_format = "png"

# Preprocessing handed to the MOABB P300 paradigm.
baseline = None
highpass = 0.5
lowpass = 16
paradigm = P300(
    resample=100,
    fmin=highpass,
    fmax=lowpass,
    baseline=baseline,
)
# Epoch window assigned to every dataset's ``interval`` attribute.
ival = [-0.3, 1]
plot_opts = {
    "dpi": 120,
    "topo": {
        "timepoints": 10,
        "tmin": 0,
        "tmax": 0.6,
    },
}


if __name__ == "__main__":
    dsets = get_p300_datasets()
    plt.ioff()  # figures are only saved to disk, never displayed
    for dset in dsets:
        print(f"Processing dataset: {dset}")
        dset.interval = ival
        dset_name = dset.__class__.__name__
        data_path = FIGURES_PATH / dset_name  # Path of the dataset folder
        data_path.mkdir(parents=True, exist_ok=True)
        for subject in dset.subject_list:
            print(f" Subject: {subject}")
            try:
                # NOTE(review): expects get_data to return five values with the
                # metadata at index 2 and the epochs at index 3 — confirm against
                # the installed MOABB version.
                _, _, meta, epos, _ = paradigm.get_data(dset, [subject], return_epochs=True)
            except Exception as err:  # one broken subject must not abort the batch
                print(f"Failed to get data for {dset_name}-{subject}: {err}")
                continue
            subject_path = data_path / f"subject_{subject}"
            subject_path.mkdir(parents=True, exist_ok=True)
            create_plot_overview(epos, plot_opts=plot_opts, path=subject_path)

            # Multi-session subjects additionally get one folder per session.
            sessions = meta["session"].unique()
            if len(sessions) > 1:
                for session in sessions:
                    session_path = subject_path / f"session_{session}"
                    session_path.mkdir(parents=True, exist_ok=True)
                    in_session = (meta["session"] == session).to_numpy()
                    create_plot_overview(
                        epos[in_session], plot_opts=plot_opts, path=session_path
                    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment