Skip to content

Instantly share code, notes, and snippets.

@hanjinliu
Last active August 22, 2021 14:21
Show Gist options
  • Save hanjinliu/b2b84ec6774b79d2027fa8a5ed945411 to your computer and use it in GitHub Desktop.
Save hanjinliu/b2b84ec6774b79d2027fa8a5ed945411 to your computer and use it in GitHub Desktop.
Make image analysis protocols in napari from a function
from __future__ import annotations
import inspect
from typing import Callable
import napari
from magicgui.widgets import Label, Table, create_widget, Container
import numpy as np
from skimage.measure import regionprops_table
# The GUI class turns a generator function into an image analysis protocol via the @gui.bind_protocol decorator.
class GUI:
    """Run an image analysis protocol in napari, one step per key press.

    A protocol is a generator function that receives this object and yields
    step functions. Each press of ``key1``/``key2`` calls the current step with
    the values of the "Parameter Container" dock widget and then advances the
    generator. Register a protocol with :meth:`bind_protocol`.
    """

    # Name of the dock widget that hosts the parameter widgets.
    _WIDGET_NAME = "Parameter Container"

    def __init__(self, viewer: "napari.Viewer"):
        self.viewer = viewer
        # Set to False by ``key1`` and to True by ``key2`` right before a step runs.
        self.proceed = False
        # Step function that the next key press will call.
        self._yielded_func = None

    def bind_protocol(self, func=None, key1: str = "F1", key2: str = "F2",
                      allowed_dims: "int | tuple[int, ...]" = (1, 2, 3),
                      exit_with_error: bool = False):
        """
        Make protocol from a generator of functions.

        Parameters
        ----------
        func : function generator, optional
            Protocol function. This function must accept ``func(self)`` and yield
            functions that accept ``f(self, **kwargs)``. The docstring of each yielded
            function is shown as a label at the top of the parameter container, so it is
            useful to write the procedure of each step as its docstring. If omitted, a
            decorator is returned.
        key1 : str, default is "F1"
            First key binding. When this key is pushed ``self.proceed`` will be False.
        key2 : str, default is "F2"
            Second key binding. When this key is pushed ``self.proceed`` will be True.
        allowed_dims : int or tuple of int, default is (1, 2, 3)
            Key presses are ignored unless ``viewer.dims.ndisplay`` is one of these.
        exit_with_error : bool, default is False
            If True, the protocol quits and both key bindings are released whenever a
            step raises. If False, the protocol continues from the same step on the next
            key press.

        Returns
        -------
        The protocol function unchanged, or a decorator when ``func`` is None.

        Notes
        -----
        The protocol also quits (releasing both keys) when the generator is exhausted
        or when the generator body itself raises, because a generator cannot be resumed
        after raising.
        """
        allowed_dims = (allowed_dims,) if isinstance(allowed_dims, int) else tuple(allowed_dims)

        def wrapper(protocol):
            if not callable(protocol):
                raise TypeError("func must be callable.")
            gen = protocol(self)  # prepare generator from protocol function
            # Run the protocol up to its first step and show that step's parameters.
            self.proceed = False
            self._yielded_func = next(gen)
            self._add_parameter_container(self._yielded_func)

            def _exit(viewer: "napari.Viewer"):
                # Release *both* key bindings. Popping with a default keeps a
                # repeated exit from raising.
                viewer.keymap.pop(key1, None)
                viewer.keymap.pop(key2, None)
                # The container only exists if some step had a docstring or parameters.
                dock = viewer.window._dock_widgets.get(self._WIDGET_NAME)
                if dock is not None:
                    viewer.window.remove_dock_widget(dock)
                return None

            def _base(viewer: "napari.Viewer"):
                if viewer.dims.ndisplay not in allowed_dims:
                    return None
                try:
                    # call the current step
                    self._yielded_func(self, **self.params)
                except Exception:
                    if exit_with_error:
                        _exit(viewer)
                    raise
                try:
                    next_func = next(gen)
                except StopIteration:
                    _exit(viewer)
                except Exception:
                    # The generator body raised, so it can never advance again.
                    _exit(viewer)
                    raise
                else:
                    # Only rebuild the container when the step actually changes;
                    # this keeps the user's parameter values across repeated steps.
                    if next_func != self._yielded_func:
                        self._yielded_func = next_func
                        self._add_parameter_container(next_func)
                # update all the layers
                for layer in viewer.layers:
                    layer.refresh()
                return None

            @self.viewer.bind_key(key1, overwrite=True)
            def _1(viewer: "napari.Viewer"):
                self.proceed = False
                return _base(viewer)

            @self.viewer.bind_key(key2, overwrite=True)
            def _2(viewer: "napari.Viewer"):
                self.proceed = True
                return _base(viewer)

            return protocol

        return wrapper if func is None else wrapper(func)

    def _clear_container(self):
        """Remove every widget from the existing parameter container."""
        self._container.clear()
        # Also take out any layout items left behind by ``clear()``. The original
        # author did this explicitly; presumably not every item is tracked by magicgui.
        layout = self._container.native.layout()
        while layout.count() > 0:
            layout.takeAt(0)

    def _add_parameter_container(self, f: Callable):
        """Show ``f``'s docstring and one widget per parameter (except the first).

        Creates the floating dock widget on first use and reuses it afterwards.
        """
        widget_name = self._WIDGET_NAME
        params = inspect.signature(f).parameters
        dock_widgets = self.viewer.window._dock_widgets
        if not f.__doc__ and len(params) <= 1:
            # Nothing to show. Empty any existing container so that the previous
            # step's widgets are not passed to ``f`` as keyword arguments by
            # ``self.params``.
            if widget_name in dock_widgets:
                self._clear_container()
                dock_widgets[widget_name].hide()
            return None
        if widget_name in dock_widgets:
            self._clear_container()
        else:
            self._container = Container(name=widget_name)
            wid = self.viewer.window.add_dock_widget(self._container, area="right", name=widget_name)
            wid.resize(140, 100)
            wid.setFloating(True)
        if f.__doc__:
            self._container.append(Label(value=f.__doc__))
        # The first parameter receives the GUI object itself, so it gets no widget.
        for name, param in list(params.items())[1:]:
            value = None if param.default is inspect.Parameter.empty else param.default
            widget = create_widget(value=value, annotation=param.annotation,
                                   name=name, param_kind=param.kind)
            self._container.append(widget)
        dock_widgets[widget_name].show()
        return None

    @property
    def params(self) -> dict:
        """
        Get parameter values from the container as ``{widget name: value}``.

        Docstring labels are skipped. Returns an empty dict before any container exists.
        """
        if not hasattr(self, "_container"):
            return {}
        return {wid.name: wid.value for wid in self._container if not isinstance(wid, Label)}
# Defining a class is not required, but it makes protocols easier to write.
# Here we define a Measure class that runs regionprops around manually picked points.
class Measure:
    """Protocol steps: pick markers, choose an image, label around markers, measure.

    The step methods' docstrings are shown to the user in the parameter
    container, so they are written as user instructions.
    """

    def __init__(self):
        self.image_layer = None
        self.labels_layer = None
        self.points_layer = None

    def select_molecules(self, gui: "GUI"):
        """
        Add markers with "F1".
        Go to next step with "F2".
        """
        # "F2" (proceed=True) only ends this step; it must not add a marker.
        if gui.proceed:
            return
        pos = gui.viewer.cursor.position
        if self.points_layer is None:
            # First marker: create the points layer (transparent face, green edge).
            self.points_layer = gui.viewer.add_points(
                pos,
                face_color=[0, 0, 0, 0],
                edge_color=[0, 1, 0, 1],
            )
        else:
            self.points_layer.add(pos)

    def select_image(self, gui: "GUI"):
        """
        Select target image and push "F1".
        """
        # NOTE(review): with several layers selected, the one used depends on the
        # selection's iteration order. Confirm whether the active layer was meant.
        selected = next(iter(gui.viewer.layers.selection), None)
        if selected is None:
            raise ValueError("No layer is selected.")
        if not isinstance(selected, napari.layers.Image):
            raise TypeError("Selected layer is not an image.")
        self.image_layer = selected
        # Empty label image with the same shape as the target image.
        labels = np.zeros(self.image_layer.data.shape, dtype=np.uint32)
        self.labels_layer = gui.viewer.add_labels(labels, opacity=0.5)

    def label(self, gui: "GUI", radius=3):
        """
        Set proper radius to label around markers.
        Push "F1" to preview.
        Push "F2" to apply.
        """
        lbl = self.labels_layer.data
        # Redraw from scratch so that previewing a smaller radius replaces the
        # boxes drawn by an earlier, larger preview.
        lbl[...] = 0
        if self.points_layer is None:
            # No marker was placed in the first step (F2 pushed right away).
            return
        coords = np.asarray(self.points_layer.data)
        if len(coords) == 0:
            return
        # Round instead of truncating so each box is centred on the nearest pixel.
        coords = np.round(coords).astype(np.intp)
        ndim = lbl.ndim
        if coords.shape[1] < ndim:
            raise ValueError(
                f"Markers have {coords.shape[1]} coordinates but the labels "
                f"array has {ndim} dimensions."
            )
        # NOTE(review): assumes layers of different dimensionality are aligned on
        # their trailing axes (as napari does), so extra leading coords are dropped.
        for i, crds in enumerate(coords[:, -ndim:]):
            # Box covering [c - radius, c + radius] on every axis, clipped to the
            # array. The stop is clamped at 0 so that a marker lying before the
            # array cannot produce a negative stop, which would wrap around.
            box = tuple(
                slice(max(c - radius, 0), min(max(c + radius + 1, 0), size))
                for c, size in zip(crds, lbl.shape)
            )
            lbl[box] = i + 1

    def measure(self, gui: "GUI"):
        """Measure mean intensity and area of every label and show them in a table."""
        d = regionprops_table(self.labels_layer.data, self.image_layer.data,
                              properties=("mean_intensity", "area"))
        table = Table(d)
        gui.viewer.window.add_dock_widget(table.native, name="Measurement", area="right")
if __name__ == "__main__":
    viewer = napari.Viewer()
    gui = GUI(viewer)

    # Decorating ``func`` turns it into an image analysis protocol driven by
    # the F1/F2 keys of this viewer.
    @gui.bind_protocol
    def func(gui):
        """Pick markers, choose the target image, label around them, then measure."""
        measure = Measure()

        # Step 1: F1 adds a marker each time; F2 leaves the loop.
        gui.proceed = False
        while True:
            yield measure.select_molecules
            if gui.proceed:
                break

        # Step 2: one key press picks the target image.
        yield measure.select_image

        # Step 3: F1 previews the labels; F2 applies them and leaves the loop.
        gui.proceed = False
        while True:
            yield measure.label
            if gui.proceed:
                break

        # Step 4: measuring happens right after step 3, without another key press.
        measure.measure(gui)

    napari.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment