Skip to content

Instantly share code, notes, and snippets.

@wmvanvliet
Created May 10, 2024 08:06
Show Gist options
  • Save wmvanvliet/1cf49422c385d298867c2363c77525d4 to your computer and use it in GitHub Desktop.
Save wmvanvliet/1cf49422c385d298867c2363c77525d4 to your computer and use it in GitHub Desktop.
Plotting clusters coming out of the cluster-based permutation tests.
"""Plot a cluster.
Plot the spatial extent of a cluster (such as those returned by the cluster-based
permutation stats) on a brain.
Author: Marijn van Vliet <w.m.vanvliet@gmail.com>
"""
import mne
import numpy as np
def _cluster_vertices_at_time(cluster, time_index, lh_vertno, rh_vertno):
    """Return the vertex numbers of a cluster at a single time index.

    Parameters
    ----------
    cluster : tuple (time_idx, vertex_idx)
        The cluster, as parallel sequences of time indices and vertex indices.
    time_index : int
        The time index for which to select the cluster's vertices.
    lh_vertno, rh_vertno : array-like of int
        The vertex numbers of the left and right hemisphere source spaces.

    Returns
    -------
    verts : dict
        Maps ``"lh"`` and ``"rh"`` to a sorted array of unique vertex numbers.
    """
    cluster_time_index, cluster_vertex_index = cluster
    cluster_time_index = np.asarray(cluster_time_index)
    # Force an integer dtype so that an empty cluster can still be used for indexing.
    cluster_vertex_index = np.asarray(cluster_vertex_index, dtype=int)
    active = cluster_vertex_index[cluster_time_index == time_index]

    # Vertex *indices* run over the left hemisphere first, then the right one, so
    # indices >= n_lh belong to the right hemisphere (offset by n_lh).
    n_lh = len(lh_vertno)
    in_lh = active < n_lh
    # Vertices in a label must be unique and in increasing order: np.unique does both.
    return {
        "lh": np.unique(np.asarray(lh_vertno)[active[in_lh]]),
        "rh": np.unique(np.asarray(rh_vertno)[active[~in_lh] - n_lh]),
    }


def _next_cluster_index(brain):
    """Return the first index N for which no "cluster-N" label is on the brain yet."""
    highest = -1
    for label in brain.labels["lh"] + brain.labels["rh"]:
        name = label.name or ""  # labels added elsewhere may not have a name
        if not name.startswith("cluster-"):
            continue
        try:
            highest = max(highest, int(name.split("-", 1)[1]))
        except ValueError:
            # Something like "cluster-foo": not one of ours, ignore it.
            continue
    return highest + 1


def plot_cluster(cluster, src, brain, time_index=None, color="magenta", width=1):
    """Plot the spatial extent of a cluster on top of a brain.

    The outline is redrawn whenever the brain figure emits a "time_change" event, so
    it follows the cluster as the displayed time changes.

    Parameters
    ----------
    cluster : tuple (time_idx, vertex_idx)
        The cluster to plot.
    src : SourceSpaces
        The source space that was used for the inverse computation.
    brain : Brain
        The brain figure on which to plot the cluster.
    time_index : int | None
        The index of the time at which to plot the spatial extent of the cluster.
        By default (None), the time of maximal spatial extent is chosen.
    color : str
        A matplotlib-style color specification indicating the color to use when
        plotting the spatial extent of the cluster.
    width : int
        The width of the lines used to draw the outlines.

    Returns
    -------
    brain : Brain
        The brain figure, now with the cluster plotted on top of it.

    Raises
    ------
    ValueError
        If ``time_index`` is None and the cluster is empty.
    """
    cluster_time_index, _ = cluster
    cluster_time_index = np.asarray(cluster_time_index)

    # A cluster is defined both in space and time. If we want to plot the boundaries
    # of the cluster in space, we must choose a specific time for which to show the
    # boundaries (as they change over time).
    if time_index is None:
        if cluster_time_index.size == 0:
            raise ValueError("Cannot choose a time to plot for an empty cluster.")
        times, n_vertices = np.unique(cluster_time_index, return_counts=True)
        time_index = times[np.argmax(n_vertices)]

    # Let's create an anatomical label containing the cluster's vertices.
    # Problem 1): a label must be defined for either the left or right hemisphere. It
    # cannot span both hemispheres. So we must filter the vertices based on their
    # hemisphere.
    # Problem 2): we have vertex *indices* that need to be transformed into proper
    # vertex numbers. Not every vertex in the original high-resolution brain mesh is
    # a source point in the source estimate. To draw nice smooth curves, we need to
    # interpolate the vertex indices.
    # Both problems can be solved by accessing the vertices defined in the source
    # space object. The source space object is actually a list of two source spaces.
    src_lh, src_rh = src
    lh_verts, rh_verts = src_lh["vertno"], src_rh["vertno"]

    # Give this cluster a name that no label currently on the brain uses, so several
    # clusters can be plotted on one figure without replacing each other's outline.
    # NOTE(review): an index can be reused if an earlier cluster currently has no
    # label on the brain (no vertices at the displayed time) -- confirm if relevant.
    label_name = f"cluster-{_next_cluster_index(brain)}"

    def draw_outline(at_time_index, verbose=None):
        """Add the outline of the cluster at the given time index to the brain."""
        verts = _cluster_vertices_at_time(cluster, at_time_index, lh_verts, rh_verts)
        for hemi in ("lh", "rh"):
            if len(verts[hemi]) == 0:
                continue
            label = mne.Label(verts[hemi], hemi=hemi, name=label_name)
            # Interpolate the vertices in the label to the full resolution mesh
            label = label.smooth(
                smooth=3,
                subject=brain._subject,
                subjects_dir=brain._subjects_dir,
                verbose=verbose,
            )
            brain.add_label(label, borders=width, color=color)

    def on_time_change(event):
        """Replace the outline with the cluster's extent at the newly shown time."""
        # NOTE(review): assumes the cluster's time indices index into brain._times
        # -- confirm against how the cluster and the brain figure were created.
        new_time_index = np.searchsorted(brain._times, event.time)
        for hemi in brain._hemis:
            labels = brain.labels[hemi]
            # Build the filtered list first instead of deleting while iterating,
            # which would skip the element following each deleted one.
            kept = [label for label in labels if label.name != label_name]
            if len(kept) != len(labels):
                labels[:] = kept
                brain._layered_meshes[hemi].remove_overlay(label_name)
        # This runs on every time step, so keep the smoothing quiet.
        draw_outline(new_time_index, verbose=False)

    draw_outline(time_index)
    mne.viz.ui_events.subscribe(brain, "time_change", on_time_change)
    return brain
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment