Skip to content

Instantly share code, notes, and snippets.

@kingjr
Created May 31, 2014 16:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kingjr/0807ebd6f458fdfd8bef to your computer and use it in GitHub Desktop.
Save kingjr/0807ebd6f458fdfd8bef to your computer and use it in GitHub Desktop.
def plot_evoked_img(evoked, picks=None, exclude='bads', unit=True, show=True,
                    ylim=None, proj=False, xlim='tight', hline=None,
                    units=None, scalings=None, titles=None, axes=None):
    """Plot evoked data as an image (channels x time) colored by amplitude.

    One image panel is drawn per channel type present among the picked
    channels. Rows are channels, columns are time samples, and color encodes
    the amplitude after multiplication by ``scalings``.

    Parameters
    ----------
    evoked : instance of Evoked
        The evoked data.
    picks : array-like of int | None
        The indices of channels to plot. If None, show all channels. The
        given order is preserved (duplicates and excluded channels are
        dropped).
    exclude : list of str | 'bads' | None
        Channel names to exclude from being shown. If 'bads', the channels
        listed in ``evoked.info['bads']`` are excluded. None or an empty
        list excludes nothing.
    unit : bool
        Currently unused; kept for backward compatibility.
    show : bool
        Call pyplot.show() at the end or not (skipped on the agg backend).
    ylim : dict | None
        Color limits ``(vmin, vmax)`` per channel type, in scaled units
        (i.e. after multiplication by ``scalings``), e.g.
        ``ylim = dict(eeg=[-20, 20])``. Channel types missing from the dict,
        or ``ylim=None``, use the default color limits of ``imshow``.
    proj : bool | 'interactive'
        If True, SSP projections are applied before display. 'interactive'
        is accepted, but interactive projector selection is not implemented
        by this function: no projection is applied, and it cannot be
        combined with ``axes``.
    xlim : 'tight' | tuple | None
        x-axis limits, in milliseconds. 'tight' spans the full evoked time
        range; None leaves the limits set by ``imshow``.
    hline : list of floats | None
        Currently unused; kept for API compatibility with line plots.
    units : dict | None
        The units of the channel types, used to label the colorbars. If
        None, defaults to ``dict(eeg='uV', grad='fT/cm', mag='fT')``.
    scalings : dict | None
        The scalings of the channel types to be applied for plotting. If
        None, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``.
    titles : dict | None
        The titles associated with the channels. If None, defaults to
        ``dict(eeg='EEG', grad='Gradiometers', mag='Magnetometers')``.
    axes : instance of Axes | list | None
        The axes to plot to. If list, the list must be a list of Axes of
        the same length as the number of channel types. If instance of
        Axes, there must be only one channel type plotted.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure containing the plot(s).

    Raises
    ------
    RuntimeError
        If ``axes`` is given together with ``proj='interactive'``.
    ValueError
        If ``exclude`` is invalid, if none of the picked channels has a
        plottable channel type, or if the number of axes does not match the
        number of channel types.
    """
    import matplotlib.pyplot as plt
    if axes is not None and proj == 'interactive':
        raise RuntimeError('Currently only single axis figures are supported'
                           ' for interactive SSP selection.')
    scalings, titles, units = _mutable_defaults(('scalings', scalings),
                                                ('titles', titles),
                                                ('units', units))
    # Every channel type that any of the three dicts knows about.
    channel_types = set(key for d in [scalings, titles, units] for key in d)

    picks = _resolve_img_picks(evoked, picks, exclude)
    types = [channel_type(evoked.info, idx) for idx in picks]

    # Iterating ``channel_types`` (a set of strings) directly would give a
    # run-dependent panel order because of string hash randomization, so
    # the types are taken in order of first appearance among the picks.
    ch_types_used = []
    for this_type in types:
        if this_type in channel_types and this_type not in ch_types_used:
            ch_types_used.append(this_type)
    n_channel_types = len(ch_types_used)
    if n_channel_types == 0:
        raise ValueError('None of the picked channels has a channel type '
                         'that can be plotted.')

    axes_init = axes  # remember if axes were given as input
    fig = None
    if axes is None:
        fig, axes = plt.subplots(n_channel_types, 1)
    if isinstance(axes, plt.Axes):
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = list(axes)
    # Check the count before indexing ``axes[0]`` so an empty list gives
    # the informative error instead of an IndexError.
    if len(axes) != n_channel_types:
        raise ValueError('Number of axes (%g) must match number of channel '
                         'types (%g)' % (len(axes), n_channel_types))
    if axes_init is not None:
        fig = axes[0].get_figure()

    # Apply the projections once on a copy rather than per channel type.
    if proj is True and evoked.proj is not True:
        evoked = evoked.copy()
        evoked.apply_proj()

    times = 1e3 * evoked.times  # time in milliseconds
    # Resolve 'tight' once; the string guard avoids an ambiguous
    # elementwise comparison when xlim is an array.
    if isinstance(xlim, string_types) and xlim == 'tight':
        xlim = (times[0], times[-1])

    for ax, this_type in zip(axes, ch_types_used):
        idx = [pick for pick, pick_type in zip(picks, types)
               if pick_type == this_type]
        data = scalings[this_type] * evoked.data[idx, :]
        # ylim is expressed in scaled units, so it maps directly onto the
        # color limits of the scaled image.
        color_limits = {}
        if ylim is not None and this_type in ylim:
            color_limits = dict(vmin=ylim[this_type][0],
                                vmax=ylim[this_type][1])
        im = ax.imshow(data, interpolation='nearest', origin='lower',
                       extent=[times[0], times[-1], 0, data.shape[0]],
                       aspect='auto', **color_limits)
        if xlim is not None:
            ax.set_xlim(xlim)
        n_ch = len(data)
        ax.set_title(titles[this_type] + ' (%d channel%s)' % (
            n_ch, 's' if n_ch > 1 else ''))
        ax.set_xlabel('time (ms)')
        ax.set_ylabel('channels')
        cbar = plt.colorbar(im, ax=ax)
        if this_type in units:
            cbar.set_label(units[this_type])

    if axes_init is None:
        plt.subplots_adjust(0.175, 0.08, 0.94, 0.94, 0.2, 0.63)
    # Finish layout before showing: with a blocking show() these calls would
    # otherwise only run after the window has been closed.
    fig.canvas.draw()  # for axes plots update axes.
    tight_layout(fig=fig)
    if show and plt.get_backend().lower() != 'agg':
        plt.show()
    return fig


def _resolve_img_picks(evoked, picks, exclude):
    """Return the channel indices to plot, in the caller's order, minus
    excluded channels and duplicates.

    ``exclude`` may be 'bads' (channels in ``evoked.info['bads']`` that
    exist in ``evoked.ch_names``), a list/tuple of channel names, or
    None / empty (nothing excluded). Any other value raises ValueError.
    """
    if picks is None:
        picks = list(range(evoked.info['nchan']))
    if exclude is None:
        exclude = []
    if len(exclude) > 0:
        if isinstance(exclude, string_types) and exclude == 'bads':
            exclude = [evoked.ch_names.index(ch) for ch in evoked.info['bads']
                       if ch in evoked.ch_names]
        elif (isinstance(exclude, (list, tuple)) and
              all(isinstance(ch, string_types) for ch in exclude)):
            exclude = [evoked.ch_names.index(ch) for ch in exclude]
        else:
            raise ValueError('exclude has to be a list of channel names or '
                             '"bads"')
    excluded = set(exclude)
    # An explicit loop (not a set difference) keeps the caller's order.
    seen = set()
    kept = []
    for pick in picks:
        if pick in excluded or pick in seen:
            continue
        seen.add(pick)
        kept.append(pick)
    return kept
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment