Skip to content

Instantly share code, notes, and snippets.

@ycanerol
Created March 12, 2019 16:56
Show Gist options
  • Save ycanerol/04e6a2dc64565b301c1dbd80372d4054 to your computer and use it in GitHub Desktop.
Save ycanerol/04e6a2dc64565b301c1dbd80372d4054 to your computer and use it in GitHub Desktop.
Multi-STA browser with slider
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.widgets import Slider
import iofuncs as iof
import plotfuncs as plf
def stabrowser(stas, frame_duration=None, cmap=None, centerzero=True):
    """
    Return an interactive plot to browse multiple spatiotemporal
    STAs at the same time.

    One subplot is drawn per cell and a single slider selects the time
    frame that is displayed in all subplots simultaneously.

    Parameters
    ----------
    stas:
        Numpy array containing STAs. First dimension should index
        individual cells, last dimension should index time.
    frame_duration:
        Time between each frame. (optional) When given, the figure title
        shows ``frame * frame_duration * 1000`` formatted as milliseconds
        each time the slider moves.
    cmap:
        Colormap to use. Defaults to ``iof.config('colormap')``.
    centerzero:
        Whether to center the colormap around zero for diverging
        colormaps.

    Returns
    -------
    fig:
        The created matplotlib figure.
    slider_t:
        The frame slider widget.

    Raises
    ------
    ValueError
        If the active matplotlib backend is not a Qt backend.

    Example
    -------
    >>> print(stas.shape) # (nrcells, xpixels, ypixels, time)
    (36, 75, 100, 40)
    >>> fig, slider = stabrowser(stas, frame_duration=1/60)

    Notes
    -----
    When calling the function, the slider is returned to prevent the
    reference to it getting destroyed and to keep it interactive.
    The dummy variable `_` can also be used.
    """
    # Compare case-insensitively so differently-cased spellings of the
    # backend name (e.g. 'Qt5Agg', 'qtagg') are all accepted.
    if not mpl.get_backend().lower().startswith('qt'):
        raise ValueError('Switch to a Qt backend to see the animation.')

    if cmap is None:
        cmap = iof.config('colormap')

    # NaN-aware limits in both branches so that a single NaN pixel does
    # not turn the colour limits into NaN.
    if centerzero:
        vmax = np.nanmax(np.abs(stas))
        vmin = -vmax
    else:
        vmax, vmin = np.nanmax(stas), np.nanmin(stas)
    imshowkwargs = dict(cmap=cmap, vmax=vmax, vmin=vmin)

    nrcells = stas.shape[0]
    nrframes = stas.shape[-1]

    # presumably returns a (rows, cols) grid with rows*cols >= nrcells;
    # verify against plotfuncs.numsubplots
    rows, cols = plf.numsubplots(nrcells)

    # squeeze=False keeps `axes` two-dimensional even for a single row,
    # a single column or a single subplot.
    fig, axes = plt.subplots(rows, cols, sharex=True, sharey=True,
                             squeeze=False)

    # Clamp so that recordings with very few frames do not index out of
    # range on the first draw.
    initial_frame = min(5, nrframes - 1)

    axsl = fig.add_axes([0.25, 0.05, 0.65, 0.03])
    # For the slider to remain interactive, a reference to it should
    # be kept, so it is returned by the function
    slider_t = Slider(axsl, 'Frame before spike',
                      0, nrframes - 1,
                      valinit=initial_frame,
                      valstep=1,
                      valfmt='%2.0f')

    # Walk the grid in row-major order; the position in that order is the
    # cell index (equivalent to i*cols + j). Grid slots beyond the last
    # cell are left empty instead of raising IndexError.
    images = []
    for cell, ax in enumerate(axes.flat):
        if cell < nrcells:
            images.append(
                ax.imshow(stas[cell, ..., initial_frame], **imshowkwargs))
        ax.set_axis_off()

    def update(frame):
        # The slider reports a float; it is used as an array index below.
        frame = int(frame)
        for cell, im in enumerate(images):
            im.set_data(stas[cell, ..., frame])
        if frame_duration is not None:
            fig.suptitle(f'{frame*frame_duration*1000:4.0f} ms')
        fig.canvas.draw_idle()

    slider_t.on_changed(update)

    plt.tight_layout()
    plt.subplots_adjust(wspace=.01, hspace=.01)
    return fig, slider_t
# Example usage: load one recording and open the STA browser for it.
# NOTE(review): the arguments to iof.load look like an experiment date and
# a stimulus number — confirm against iofuncs.load.
data = iof.load('20180802', 6) # Frozen noise
# Convert the stored STAs to an array; stabrowser expects cells along the
# first axis and time along the last.
stas = np.array(data['stas'])
frame_duration = data['frame_duration']
# The slider is bound to `_` so a reference to it stays alive at module
# level, which keeps the widget interactive.
fig, _ = stabrowser(stas, frame_duration)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment