Skip to content

Instantly share code, notes, and snippets.

@henryroe
Last active August 26, 2020 13:02
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save henryroe/885604eaa68ccf104594 to your computer and use it in GitHub Desktop.
Save henryroe/885604eaa68ccf104594 to your computer and use it in GitHub Desktop.
The demo of TraitsUI and matplotlib (including a pop-up menu) that I wish I'd found.
import wx
import matplotlib
matplotlib.use('WXAgg')
from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib.image import AxesImage
from matplotlib.axes import Axes
from matplotlib.widgets import AxesWidget
import matplotlib.pyplot as plt
from matplotlib import cm
from traits.api import Instance
from traitsui.wx.editor import Editor
from traitsui.wx.basic_editor_factory import BasicEditorFactory
import numpy as np
from traits.api import HasTraits, Array, CInt, Int, CFloat, Float, Str, on_trait_change, Range, Enum, List, Dict
from traitsui.api import View, Item, Handler, HGroup, VGroup, StatusItem, TextEditor
from traitsui.ui_info import UIInfo
import pdb # NOQA
# TODO: someday come back and experiment more with resizing, including:
# - maintaining the proportions (aspect ratio) of the image panel
# - enforcing a minimum size for each element (i.e. no clipping of the UI and no resizing of the plot figure) so the window can't be made tiny
# TODO: someday revisit how the GUI shuts down. When run from an interactive Python session the window does not close itself, although control returns to the Python prompt and the demo can be relaunched without problems. I suspect something is left dangling by wx or matplotlib, so the likely fix is to clean up and delete the relevant object/event connections.
def _clear_ticks_and_frame_from_axes(ax=None):
    """Hide the x/y axes (ticks and labels) and every spine of a matplotlib Axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to clean up. Defaults to the current axes (``plt.gca()``).
    """
    if ax is None:
        ax = plt.gca()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    # ``ax.spines`` is a dict in older matplotlib and a dict-like container in
    # newer releases; ``.values()`` works for both, whereas ``.itervalues()``
    # only ever existed on Python 2 dicts.
    for spine in ax.spines.values():
        spine.set_visible(False)
# Much credit to Gael Varoquaux for his [tutorial on using matplotlib within traitsui](http://docs.enthought.com/traitsui/tutorials/traits_ui_scientific_app.html) from which the following Editor classes are derived.
class _ResizingMPLFigureEditor(Editor):
    """TraitsUI editor that embeds a matplotlib Figure in a scrollable wx panel.

    The edited trait value (``self.value``) is the ``Figure`` to display. The
    canvas is added to a vertical box sizer with a non-zero proportion and
    ``wx.GROW``, so it fills the panel.
    """

    scrollable = True
    # ``wx.Panel`` is the public name; ``wx._windows.Panel`` is a private
    # module path that only exists in classic wxPython.
    panel = Instance(wx.Panel)
    canvas = Instance(FigureCanvas)

    def init(self, parent):
        """Create this editor's wx control as a child of *parent*."""
        self.control = self._create_canvas(parent)

    def update_editor(self):
        """Intentionally a no-op: nothing is re-synced from the trait here."""
        pass

    def _create_canvas(self, parent):
        """Create the MPL canvas inside a new panel and return the panel."""
        self.panel = panel = wx.Panel(parent, -1, style=wx.CLIP_CHILDREN)
        sizer = wx.BoxSizer(wx.VERTICAL)
        panel.SetSizer(sizer)
        self.canvas = FigureCanvas(panel, -1, self.value)
        # Sizer proportion is an integer; a float (1.0) may be rejected by
        # newer Python/wxPython combinations.
        sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        return panel
class ResizingMPLFigureEditor(BasicEditorFactory):
    """Editor factory for a scrollable, embedded matplotlib Figure.

    Example: ``Item('image_figure', editor=ResizingMPLFigureEditor())``.
    """

    # Editor class that TraitsUI instantiates for each use of this factory.
    klass = _ResizingMPLFigureEditor
class _NonResizingMPLFigureEditor(Editor):
    """TraitsUI editor that embeds a matplotlib Figure in a non-scrollable wx panel.

    The edited trait value (``self.value``) is the ``Figure`` to display. The
    canvas is added to a vertical box sizer with proportion 1 and ``wx.GROW``.
    """

    scrollable = False
    # ``wx.Panel`` is the public name; ``wx._windows.Panel`` is a private
    # module path that only exists in classic wxPython.
    panel = Instance(wx.Panel)
    canvas = Instance(FigureCanvas)

    def init(self, parent):
        """Create this editor's wx control as a child of *parent*."""
        self.control = self._create_canvas(parent)

    def update_editor(self):
        """Intentionally a no-op: nothing is re-synced from the trait here."""
        pass

    def _create_canvas(self, parent):
        """Create the MPL canvas inside a new panel and return the panel."""
        self.panel = panel = wx.Panel(parent, -1, style=wx.CLIP_CHILDREN)
        sizer = wx.BoxSizer(wx.VERTICAL)
        panel.SetSizer(sizer)
        self.canvas = FigureCanvas(panel, -1, self.value)
        sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        return panel
class NonResizingMPLFigureEditor(BasicEditorFactory):
    """Editor factory for a non-scrollable, embedded matplotlib Figure.

    Example: ``Item('plot_figure', editor=NonResizingMPLFigureEditor())``.
    """

    # Editor class that TraitsUI instantiates for each use of this factory.
    klass = _NonResizingMPLFigureEditor
class MPLInitHandler(Handler):
    """Handler that asks the edited object to hook up its matplotlib events."""

    # The UIInfo passed to ``init``; kept for later use.
    ui_info = Instance(UIInfo)

    def init(self, info):
        """Store *info* and call ``setup_mpl_events()`` on the edited object.

        TraitsUI calls this after the controls have all been created but
        before they are displayed, so the figure canvases already exist.
        Returns True (TraitsUI treats a True result as successful init).
        """
        self.ui_info = info
        edited_object = info.object
        edited_object.setup_mpl_events()
        return True
class Test(HasTraits):
    """Demo: an image with a right-click pop-up menu next to a live-updating plot.

    The left panel shows ``image`` (a horizontal ramp raised to
    ``image_ramp_power``). Right-clicking it opens a menu to add/remove noise
    and to pick a colormap. The right panel plots ideal points against
    Poisson-noised observations, redrawn whenever a plot trait changes.
    """

    # --- image panel ------------------------------------------------------
    image = Array()
    image_figure = Instance(Figure, ())
    image_axes = Instance(Axes)
    image_axesimage = Instance(AxesImage)
    image_xsize = Int(256)
    image_ysize = Int(256)
    # Exponent applied to the 0..1 horizontal ramp in ``_fresh_image``.
    image_ramp_power = Range(value=1.0, low=0.0, high=None)
    image_cmap = Str('gray')
    # Built lazily on the first right-click (see ``on_right_down``).
    image_popup_menu = Instance(wx.Menu)
    available_cmaps = List(['gray', 'hot', 'hsv', 'jet'])
    # wx menu id -> colormap name, and the inverse mapping.
    eventID_to_cmap = Dict()
    cmap_to_eventID = Dict()
    # Colormap name -> preview bitmap shown next to its menu entry.
    cmap_bitmaps = Dict()

    # --- plot panel -------------------------------------------------------
    plot_figure = Instance(Figure, ())
    plot_axes = Instance(Axes)
    plot_num_points = Range(value=20, low=10, high=100)
    plot_coeff0 = Range(value=1.0, low=0.0, high=20.0)
    plot_coeff1 = Range(value=1.0, low=0.0, high=5.0)
    plot_gain = Range(value=1.0, low=0.1, high=3.0)

    # --- status bar -------------------------------------------------------
    status_string_left = Str('')
    status_string_right = Str('')

    def __init__(self):
        super(Test, self).__init__()
        self.eventID_to_cmap = {wx.NewId(): name for name in self.available_cmaps}
        self.cmap_to_eventID = {name: wx_id for wx_id, name in self.eventID_to_cmap.items()}
        self._build_cmap_bitmaps()

        self.image_axes = self.image_figure.add_axes([0., 0., 1., 1.])
        self.image = self._fresh_image()
        self.image_axesimage = self.image_axes.imshow(self.image, cmap=self.image_cmap,
                                                      interpolation='nearest')
        # Flip so row 0 is at the bottom. Pixel centres sit on integer
        # coordinates, so the full image spans -0.5 .. ysize - 0.5 (the same
        # half-pixel padding imshow uses for the x limits); (0, ysize) would
        # clip half of the bottom row and leave a blank half-row on top.
        self.image_axes.set_ylim(-0.5, self.image_ysize - 0.5)
        _clear_ticks_and_frame_from_axes(self.image_axes)

        self.plot_axes = self.plot_figure.add_subplot(111)
        self._draw_plot_lines()

    def _build_cmap_bitmaps(self, height=15, width=100):
        """Store a ``width`` x ``height`` gradient preview bitmap for each colormap."""
        # Each row is 0, 1, ..., width-1; ScalarMappable normalises it to 0..1.
        ramp = np.outer(np.ones(height, dtype=np.uint8),
                        np.arange(width, dtype=np.uint8))
        for cmap in self.available_cmaps:
            rgba = cm.ScalarMappable(cmap=cmap).to_rgba(ramp)
            self.cmap_bitmaps[cmap] = wx.BitmapFromBufferRGBA(
                width, height, np.uint8(np.round(rgba * 255)))

    def default_traits_view(self):
        """Lay out the image panel on the left and the plot plus controls on the right."""
        return View(
            HGroup(
                VGroup(
                    Item('image_figure', editor=ResizingMPLFigureEditor(),
                         show_label=False, width=400, height=400),
                    HGroup(Item('image_ramp_power', label='Power',
                                editor=TextEditor(auto_set=False, enter_set=True,
                                                  evaluate=float))),
                    show_border=False),
                VGroup(
                    HGroup(Item('plot_figure', editor=NonResizingMPLFigureEditor(),
                                show_label=False, width=400, height=200)),
                    Item('plot_num_points'),
                    Item('plot_coeff0'),
                    Item('plot_coeff1'),
                    Item('plot_gain'),
                    show_border=False)),
            resizable=True,
            statusbar=[StatusItem(name='status_string_left', width=0.67),
                       StatusItem(name='status_string_right', width=0.33)],
            title='Demo of traitsui with matplotlib and popup menus',
            handler=MPLInitHandler)

    def _fresh_image(self):
        """Return an (ysize, xsize) array whose rows ramp 0..1 left to right, raised to ``image_ramp_power``."""
        ramp = np.arange(self.image_xsize) / (self.image_xsize - 1.0)
        return np.outer(np.ones(self.image_ysize), ramp) ** self.image_ramp_power

    def _fresh_plot_data(self):
        """Return ``(x, perfect_y, observed_y)`` for the current plot settings.

        ``perfect_y`` is the line ``coeff0 + coeff1 * x``; ``observed_y`` is a
        Poisson draw of ``perfect_y * gain`` scaled back by ``gain``.
        """
        x = np.arange(1, self.plot_num_points + 1)
        perfect_y = self.plot_coeff0 + self.plot_coeff1 * x
        observed_y = np.random.poisson(perfect_y * self.plot_gain) / self.plot_gain
        return x, perfect_y, observed_y

    def _draw_plot_lines(self):
        """Plot fresh ideal points (blue dots) and observed points (magenta +)."""
        x, perfect_y, observed_y = self._fresh_plot_data()
        self.plot_axes.plot(x, perfect_y, 'bo', markersize=3.0)
        self.plot_axes.plot(x, observed_y, 'm+', markersize=10.0)

    @on_trait_change('plot_num_points, plot_coeff+, plot_gain')
    def update_plot_data(self):
        """Regenerate and redraw the plot whenever a plot parameter changes."""
        # ``Axes.lines`` cannot be reassigned in matplotlib >= 3.5, so remove
        # each existing line individually (iterate a copy while removing).
        for line in list(self.plot_axes.lines):
            line.remove()
        self._draw_plot_lines()
        self.plot_axes.relim()
        self.plot_axes.autoscale_view()
        self.plot_figure.canvas.draw()

    def _redraw_image(self, status):
        """Push ``self.image`` to the displayed AxesImage, redraw, and report *status*."""
        self.image_axesimage.set_data(self.image)
        # Fixed colour scale regardless of the current data range.
        self.image_axesimage.set_clim(0.0, 1.0)
        self.image_figure.canvas.draw()
        self.status_string_right = status

    @on_trait_change('image_ramp_power')
    def reset_image(self):
        """Rebuild the noise-free ramp image after ``image_ramp_power`` changes."""
        self.image = self._fresh_image()
        self._redraw_image('Image reset with changed power')

    def increase_image_noise(self, event):
        """Menu callback: add Gaussian noise (sigma 0.1) to every pixel."""
        self.image += np.random.normal(0., 0.1, (self.image_ysize, self.image_xsize))
        self._redraw_image('Image noise increased')

    def decrease_image_noise(self, event):
        """Menu callback: average the image with a fresh noise-free ramp."""
        self.image += self._fresh_image()
        self.image /= 2.0
        self._redraw_image('Image noise decreased')

    @on_trait_change('image_cmap')
    def update_image_cmap(self):
        """Apply the newly selected colormap to the image."""
        self.image_axesimage.set_cmap(self.image_cmap)
        self.image_figure.canvas.draw()

    def setup_mpl_events(self):
        """Connect matplotlib and wx mouse events for the image canvas.

        Called from ``MPLInitHandler.init`` once the figure canvases exist.
        """
        self.image_axeswidget = AxesWidget(self.image_axes)
        self.image_axeswidget.connect_event('motion_notify_event', self.image_on_motion)
        self.image_axeswidget.connect_event('figure_leave_event', self.on_cursor_leave)
        self.image_axeswidget.connect_event('figure_enter_event', self.on_cursor_enter)
        # ``Bind`` works in both classic and Phoenix wxPython, unlike the
        # functional ``wx.EVT_RIGHT_DOWN(window, handler)`` form.
        self.image_figure.canvas.Bind(wx.EVT_RIGHT_DOWN, self.on_right_down)

    def _append_menu_item(self, menu, wx_id, title, fxn):
        """Append *title* to *menu* and bind its selection to *fxn*.

        A new id is allocated when *wx_id* is None.
        """
        if wx_id is None:
            wx_id = wx.NewId()
        menu.Append(wx_id, title)
        menu.Bind(wx.EVT_MENU, fxn, id=wx_id)

    def _build_image_popup_menu(self):
        """Create the image pop-up menu: noise actions plus a colormap submenu."""
        menu = wx.Menu()
        self._append_menu_item(menu, None, "Increase Noise", self.increase_image_noise)
        self._append_menu_item(menu, None, "Decrease Noise", self.decrease_image_noise)
        menu.AppendSeparator()
        cmap_submenu = wx.Menu()
        for cmap in self.available_cmaps:
            wx_id = self.cmap_to_eventID[cmap]
            menu_item = cmap_submenu.AppendCheckItem(wx_id, cmap)
            cmap_submenu.Bind(wx.EVT_MENU, self._on_change_image_cmap_event, id=wx_id)
            menu_item.SetBitmap(self.cmap_bitmaps[cmap])
        menu.AppendSubMenu(cmap_submenu, 'Color Maps')
        return menu

    def on_right_down(self, event):
        """Show the image pop-up menu near the click, with the active colormap checked."""
        if self.image_popup_menu is None:
            self.image_popup_menu = self._build_image_popup_menu()
        # Check exactly the active colormap; a colormap outside
        # ``available_cmaps`` simply leaves every entry unchecked.
        for cmap in self.available_cmaps:
            self.image_popup_menu.Check(self.cmap_to_eventID[cmap], cmap == self.image_cmap)
        position = wx.Point(event.GetX() + 8, event.GetY() + 8)
        self.image_figure.canvas.PopupMenu(self.image_popup_menu, position)

    def _on_change_image_cmap_event(self, event):
        """Menu callback: switch ``image_cmap`` to the colormap for the chosen id."""
        self.image_cmap = self.eventID_to_cmap[event.GetId()]

    def image_on_motion(self, event):
        """Show the pixel coordinates and value under the cursor in the status bar."""
        if event.xdata is None or event.ydata is None:
            return
        x = int(np.round(event.xdata))
        y = int(np.round(event.ydata))
        if 0 <= x < self.image.shape[1] and 0 <= y < self.image.shape[0]:
            self.status_string_left = "x,y={},{} {:.5g}".format(x, y, self.image[y, x])
        else:
            self.status_string_left = ""

    def on_cursor_leave(self, event):
        """Restore the cursor saved on entry and clear the left status field."""
        if getattr(self, 'saved_cursor', None) is not None:
            self.image_figure.canvas.SetCursor(self.saved_cursor)
        self.saved_cursor = None
        self.status_string_left = ''

    def on_cursor_enter(self, event):
        """Remember the current cursor and switch to a crosshair."""
        self.saved_cursor = self.image_figure.canvas.GetCursor()
        self.image_figure.canvas.SetCursor(wx.StockCursor(wx.CURSOR_CROSS))
if __name__ == "__main__":
    # Launch the demo window when run as a script.
    demo = Test()
    demo.configure_traits()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment