Skip to content

Instantly share code, notes, and snippets.

@theodoregoetz
Created June 5, 2017 18:17
Show Gist options
Save theodoregoetz/a0e1eaa4bf0e9eb5250f313401323913 to your computer and use it in GitHub Desktop.
Matplotlib line plot with a 1-D spline through interactively adjustable control points, embedded in a wxPython window.
import matplotlib as mpl
mpl.use('wxAgg')
import numpy as np
import wx
from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg as Canvas
from matplotlib.figure import Figure
from numpy import random as rand
from scipy import interpolate
# Fixed seed so the randomly jittered starting points are identical on every run.
np.random.seed(seed=1)
class LinePlot(wx.Panel):
    """wx panel showing a matplotlib polyline whose vertices are mouse-editable.

    Interaction (see ``on_button_press``):

    * left-click on empty axes space adds a point and starts dragging it;
    * left-click on an existing point (within ``pick``'s pixel radius)
      snaps it to the cursor and starts dragging it;
    * right-click on a point deletes it (an end point is only deleted when
      another point shares its x value);
    * leaving the figure while dragging cancels the drag and restores the
      points recorded when the button was pressed.

    ``update_points`` keeps ``self.x``/``self.y`` sorted by x and makes sure
    there is a point at ``xmin`` and at ``xmax``.
    """

    def __init__(self, parent):
        super().__init__(parent)
        xdpi, ydpi = wx.ScreenDC().GetPPI()
        self.figure = Figure(figsize=(5, 4), dpi=xdpi)
        self.axes = self.figure.add_subplot(1, 1, 1)
        self.xmin, self.xmax = 0, 10
        self.ymin, self.ymax = -10, 20
        # Initial control points: evenly spaced in x, noisy around y = x.
        self.x = np.linspace(self.xmin, self.xmax, 4)
        self.y = self.x + rand.normal(0, 3, len(self.x))
        self.axes.set_xlim(self.xmin, self.xmax)
        self.axes.set_ylim(self.ymin, self.ymax)
        self.axes.autoscale(False)
        # NOTE: inverted() of these affine transforms returns a frozen
        # snapshot of the current matrix; it goes stale once the canvas is
        # resized. Kept only for backward compatibility -- the handlers
        # below compute the inverse at the time of use instead.
        self.transAxesInv = self.axes.transAxes.inverted()
        self.transDataInv = self.axes.transData.inverted()
        self.plt, = self.axes.plot(self.x, self.y, marker='o')
        self.canvas = Canvas(self, wx.ID_ANY, self.figure)
        self.layout()
        self.connect()

    def layout(self):
        """Let the canvas fill the whole panel."""
        vbox = wx.BoxSizer(wx.VERTICAL)
        vbox.Add(self.canvas, 1, wx.ALL | wx.EXPAND, 0)
        self.SetSizerAndFit(vbox)

    def connect(self):
        """Reset the drag state and hook the mouse handlers to the canvas."""
        self._dragging = False
        self._index = None  # index of the point being dragged, if any
        self.canvas.mpl_connect('button_press_event', self.on_button_press)
        self.canvas.mpl_connect('motion_notify_event', self.on_motion_notify)
        self.canvas.mpl_connect('button_release_event', self.on_button_release)
        self.canvas.mpl_connect('figure_leave_event', self.on_figure_leave)

    def on_button_press(self, event):
        """Start a drag (left button) or delete a point (right button)."""
        if not event.inaxes:
            return
        if event.button == 1:
            # Snapshot so the drag can be undone if the cursor leaves the figure.
            self._cache = {'x': self.x.copy(), 'y': self.y.copy()}
            self._index = self.pick(event.x, event.y, event.xdata, event.ydata)
            self._dragging = True
            if self._index is None:
                # Nothing under the cursor: create a new point and drag it.
                self.add(event.xdata, event.ydata)
            else:
                # Snap the picked point to the exact cursor position.
                self.on_motion_notify(event)
        elif event.button == 3:
            idx = self.pick(event.x, event.y, event.xdata, event.ydata)
            if idx is not None:
                self.delete(idx)

    def add(self, x, y):
        """Insert the point (x, y); during a drag it becomes the dragged point.

        The point is prepended and ``update`` re-sorts the arrays, remapping
        ``_index`` so it keeps following the new point.
        """
        if self._dragging:
            self._index = 0
        self.x = np.concatenate([[x], self.x])
        self.y = np.concatenate([[y], self.y])
        self.update()

    def delete(self, idx):
        """Remove point ``idx``.

        The first/last point is only removed when its neighbour has the same
        x value, i.e. when the end point is duplicated.
        """
        last = len(self.x) - 1
        if idx == 0 and self.x[0] != self.x[1]:
            return
        if idx == last and self.x[last] != self.x[last - 1]:
            return
        self.x = np.delete(self.x, idx)
        self.y = np.delete(self.y, idx)
        self.update()

    def on_motion_notify(self, event):
        """Move the dragged point to the cursor, clamped to the axes limits."""
        if not (self._dragging and self._index is not None):
            return
        if event.inaxes:
            x, y = event.xdata, event.ydata
        else:
            # Outside the axes there are no data coordinates; convert the
            # pixel position with the *current* inverse transform, since a
            # cached one is wrong after the canvas has been resized.
            x, y = self.axes.transData.inverted().transform([event.x, event.y])
        self.x[self._index] = min(max(x, self.xmin), self.xmax)
        self.y[self._index] = min(max(y, self.ymin), self.ymax)
        self.update()

    def on_button_release(self, event):
        """Finish an active drag at the release position."""
        if self._dragging:
            self.on_motion_notify(event)
            self._dragging = False
            self._index = None

    def on_figure_leave(self, event):
        """Cancel an active drag and restore the points saved at button press."""
        if self._dragging:
            self._dragging = False
            self._index = None
            self.x = self._cache['x'].copy()
            self.y = self._cache['y'].copy()
            self.update()
            del self._cache

    def pick(self, x, y, xdata, ydata, radius=10):
        """Return the index of the point nearest to (x, y) if within ``radius`` px.

        Parameters
        ----------
        x, y : float
            Cursor position in display (pixel) coordinates.
        xdata, ydata : float
            Cursor position in data coordinates. Unused: the search is done
            in pixel space so that axes with different x/y scales pick the
            point that is visually closest. Kept for backward compatibility.
        radius : float, optional
            Pick tolerance in pixels; the default of 10 matches the original
            ``distsq < 100`` threshold.

        Returns
        -------
        int or None
            Index into ``self.x``/``self.y``, or None if no point is close enough.
        """
        if len(self.x) == 0:
            return None
        points_px = self.axes.transData.transform(np.column_stack([self.x, self.y]))
        distsq = (points_px[:, 0] - x) ** 2 + (points_px[:, 1] - y) ** 2
        idx = int(distsq.argmin())
        if distsq[idx] < radius ** 2:
            return idx
        return None

    def update_points(self):
        """Normalize the point arrays and push them to the line artist.

        Sorts the points by x, remapping ``_index`` so the dragged point stays
        selected. If the end points were moved away from ``xmin``/``xmax``,
        new ones are inserted there, with the end y values recorded in
        ``_cache`` at button press.
        """
        dragging_point = self._dragging and self._index is not None
        if np.any(self.x[:-1] > self.x[1:]):
            order = np.argsort(self.x)
            if dragging_point:
                # New position of the dragged point after sorting.
                self._index = int(np.flatnonzero(order == self._index)[0])
            self.x = self.x[order]
            self.y = self.y[order]
        # Defensive fallback: without a press snapshot, reuse current end values.
        cache = getattr(self, '_cache', None)
        edge_y = self.y if cache is None else cache['y']
        if self.x[0] != self.xmin:
            self.x = np.concatenate([[self.xmin], self.x])
            self.y = np.concatenate([edge_y[:1], self.y])
            if dragging_point:
                self._index += 1
        if self.x[-1] != self.xmax:
            self.x = np.concatenate([self.x, [self.xmax]])
            self.y = np.concatenate([self.y, edge_y[-1:]])
        self.plt.set_data(self.x, self.y)

    def update(self):
        """Normalize the points and redraw the canvas."""
        self.update_points()
        self.canvas.draw()

    def OnPaint(self, event):
        """Redraw the canvas and let wx continue processing the paint event.

        NOTE(review): nothing in this class binds this handler (there is no
        ``Bind(wx.EVT_PAINT, ...)``); it looks unused -- confirm.
        """
        self.canvas.draw()
        event.Skip()
class SplinePlot(LinePlot):
    """LinePlot that also draws an interpolating spline through its points.

    The spline is refit every time the points are normalized
    (``update_points``), i.e. AFTER sorting and end-point insertion, so the
    fit always sees the same points that are drawn. Fitting before
    normalization handed unsorted x values to ``UnivariateSpline``.
    """

    def __init__(self, parent):
        super().__init__(parent)
        # Dense, fixed x grid on which the spline is evaluated for drawing.
        self.spline_x = np.linspace(self.xmin, self.xmax, 200)
        self.spline_plt, = self.axes.plot(self.spline_x, np.zeros(self.spline_x.shape))
        self.update_spline()

    def update_spline(self):
        """Fit ``self.spline`` through the current points and update its line.

        ``UnivariateSpline`` with ``s=0`` requires strictly increasing x and
        more data points than the spline degree, so:

        * points sharing an x value (e.g. duplicated end points) are merged,
          averaging their y values; ``np.unique`` also sorts x;
        * the degree is cubic when possible and drops to ``n_unique - 1``
          when fewer than four distinct x values remain.

        With fewer than two distinct x values no spline can be fit:
        ``self.spline`` is set to None and the curve is hidden (all NaN).
        """
        xs, inverse = np.unique(self.x, return_inverse=True)
        inverse = inverse.ravel()
        ys = np.bincount(inverse, weights=self.y) / np.bincount(inverse)
        if len(xs) < 2:
            self.spline = None
            curve = np.full(self.spline_x.shape, np.nan)
        else:
            degree = min(3, len(xs) - 1)
            self.spline = interpolate.UnivariateSpline(xs, ys, k=degree, s=0)
            curve = self.spline(self.spline_x)
        self.spline_plt.set_data(self.spline_x, curve)

    def update_points(self):
        """Normalize the control points as LinePlot does, then refit the spline."""
        super().update_points()
        self.update_spline()
class MainFrame(wx.Frame):
    """Top-level 500x400 window whose only child is a SplinePlot panel."""

    def __init__(self):
        super().__init__(None, id=wx.ID_ANY, title='Main Window', size=(500, 400))
        self.plot = SplinePlot(self)
        self._layout()

    def _layout(self):
        """Stretch the plot panel over the frame's whole client area."""
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(self.plot, proportion=1, flag=wx.EXPAND | wx.ALL)
        self.SetSizer(sizer)
        self.Layout()
class Application(wx.App):
    """wxPython application object; its start-up hook opens the main window."""

    def OnInit(self):
        """Create and show the MainFrame, then report successful start-up."""
        window = MainFrame()
        window.Show()  # Show() defaults to show=True
        return True
if __name__ == '__main__':
    # redirect=False keeps stdout/stderr on the console instead of a wx window.
    app = Application(redirect=False)
    app.MainLoop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment