# Created June 5, 2017 18:17
# Gist: theodoregoetz/a0e1eaa4bf0e9eb5250f313401323913
#
# Matplotlib line plot with a 1D spline and interactively adjustable control
# points, using wxPython.
#
# GitHub notice: this file may contain bidirectional Unicode text that could be
# interpreted or compiled differently than what appears here. To review it,
# open the file in an editor that reveals hidden Unicode characters.
import matplotlib as mpl

# Select the wxAgg backend up front, before the wx backend module is imported.
mpl.use('wxAgg')

import numpy as np
import numpy.random as rand
import wx
from matplotlib.backends.backend_wxagg import FigureCanvasWxAgg as Canvas
from matplotlib.figure import Figure
from scipy import interpolate

# Fixed seed so the random initial control points are identical on every run.
rand.seed(1)
class LinePlot(wx.Panel):
    """wx panel showing a matplotlib line with mouse-editable control points.

    Mouse interaction inside the axes (matplotlib button numbers):

    * button 1 near an existing point picks it up and drags it;
    * button 1 elsewhere inserts a new point and drags it;
    * button 3 near a point deletes it (an endpoint is only deleted when
      another point shares its x value);
    * leaving the figure during a drag restores the points to the state they
      had when the drag started.

    After every ``update()`` the points are sorted by x, and there is always a
    point at ``xmin`` and at ``xmax``.
    """

    # Hit-test radius, in pixels, for picking an existing control point.
    pick_radius = 10

    def __init__(self, parent):
        super().__init__(parent)
        xdpi, ydpi = wx.ScreenDC().GetPPI()
        self.figure = Figure(figsize=(5, 4), dpi=xdpi)
        self.axes = self.figure.add_subplot(1, 1, 1)

        # Fixed data window; dragged points are clamped into it.
        self.xmin, self.xmax = 0, 10
        self.ymin, self.ymax = -10, 20

        # Initial control points: 4 evenly spaced x values, y = x plus
        # normal noise (mean 0, std 3) from the numpy random module.
        self.x = np.linspace(self.xmin, self.xmax, 4)
        self.y = self.x + rand.normal(0, 3, len(self.x))

        self.axes.set_xlim(self.xmin, self.xmax)
        self.axes.set_ylim(self.ymin, self.ymax)
        self.axes.autoscale(False)

        # Kept for backward compatibility only: Transform.inverted() returns a
        # snapshot that does not follow later changes (e.g. the canvas being
        # resized by the sizer), so the handlers recompute the inverse instead.
        self.transAxesInv = self.axes.transAxes.inverted()
        self.transDataInv = self.axes.transData.inverted()

        self.plt, = self.axes.plot(self.x, self.y, marker='o')
        self.canvas = Canvas(self, wx.ID_ANY, self.figure)
        self.layout()
        self.connect()

    def layout(self):
        """Let the canvas fill the panel and size the panel to fit it."""
        vbox = wx.BoxSizer(wx.VERTICAL)
        vbox.Add(self.canvas, 1, wx.ALL | wx.EXPAND, 0)
        self.SetSizerAndFit(vbox)

    def connect(self):
        """Reset the drag state and attach the mouse handlers to the canvas."""
        self._dragging = False
        self._index = None  # index of the point being dragged, if any
        # Lambdas rather than bound methods: mpl_connect keeps only weak
        # references to bound methods, whereas these closures are held strongly.
        self.canvas.mpl_connect('button_press_event',
                                lambda e: self.on_button_press(e))
        self.canvas.mpl_connect('motion_notify_event',
                                lambda e: self.on_motion_notify(e))
        self.canvas.mpl_connect('button_release_event',
                                lambda e: self.on_button_release(e))
        self.canvas.mpl_connect('figure_leave_event',
                                lambda e: self.on_figure_leave(e))

    def on_button_press(self, event):
        """Start a drag (button 1) or delete the clicked point (button 3)."""
        if not event.inaxes:
            return
        if event.button == 1:
            # Snapshot so on_figure_leave can undo the whole drag.
            self._cache = {'x': self.x.copy(), 'y': self.y.copy()}
            self._index = self.pick(event.x, event.y, event.xdata, event.ydata)
            self._dragging = True
            if self._index is None:
                # Empty spot: create a point there and drag it.
                self.add(event.xdata, event.ydata)
            else:
                # Snap the picked point to the cursor immediately.
                self.on_motion_notify(event)
        elif event.button == 3 and not self._dragging:
            # Deleting mid-drag would shift the points under self._index,
            # leaving it pointing at the wrong point or past the end.
            idx = self.pick(event.x, event.y, event.xdata, event.ydata)
            if idx is not None:
                self.delete(idx)

    def add(self, x, y):
        """Insert a control point at (x, y) and redraw.

        The point is prepended and ``update()`` re-sorts the arrays. While
        dragging, the new point becomes the dragged one.
        """
        if self._dragging:
            self._index = 0
        self.x = np.concatenate([[x], self.x])
        self.y = np.concatenate([[y], self.y])
        self.update()

    def delete(self, idx):
        """Remove control point ``idx`` and redraw.

        The first or last point is kept unless its neighbour has the same x,
        so that a point always remains at each end of the x range.
        """
        last = len(self.x) - 1
        if idx == 0 and self.x[0] != self.x[1]:
            return
        if idx == last and self.x[last] != self.x[last - 1]:
            return
        self.x = np.delete(self.x, idx)
        self.y = np.delete(self.y, idx)
        self.update()

    def on_motion_notify(self, event):
        """Move the dragged point to the cursor, clamped to the data window."""
        if not self._dragging or self._index is None:
            return
        if event.inaxes:
            x, y = event.xdata, event.ydata
        else:
            # Outside the axes matplotlib gives no data coordinates, so convert
            # the pixel position with the *current* transform (a cached
            # inverse would be stale after a resize).
            x, y = self.axes.transData.inverted().transform([event.x, event.y])
        x = min(max(x, self.xmin), self.xmax)
        y = min(max(y, self.ymin), self.ymax)
        self.x[self._index] = x
        self.y[self._index] = y
        self.update()

    def on_button_release(self, event):
        """Finish a drag, applying the final cursor position."""
        if self._dragging:
            self.on_motion_notify(event)
            self._dragging = False
            self._index = None

    def on_figure_leave(self, event):
        """Cancel an ongoing drag and restore the pre-drag points."""
        if self._dragging:
            self._dragging = False
            self._index = None
            self.x = self._cache['x'].copy()
            self.y = self._cache['y'].copy()
            self.update()
            del self._cache

    def pick(self, x, y, xdata, ydata):
        """Return the index of the control point under the cursor, or None.

        ``x``/``y`` are the cursor's display (pixel) coordinates. ``xdata`` and
        ``ydata`` are accepted for backward compatibility. The hit test is done
        entirely in pixels, because data-space distance mixes the differently
        scaled x and y axes.
        """
        if len(self.x) == 0:
            return None
        points_px = self.axes.transData.transform(
            np.column_stack([self.x, self.y]))
        dist_sq = (points_px[:, 0] - x) ** 2 + (points_px[:, 1] - y) ** 2
        idx = int(dist_sq.argmin())
        if dist_sq[idx] < self.pick_radius ** 2:
            return idx
        return None

    def update_points(self):
        """Sort the points by x, re-add missing endpoints and refresh the line.

        While dragging, ``self._index`` is remapped so that it keeps following
        the dragged point through the re-sort and any prepended endpoint.
        """
        if np.any(self.x[:-1] > self.x[1:]):
            order = np.argsort(self.x)
            if self._dragging and self._index is not None:
                self._index = int(np.flatnonzero(order == self._index)[0])
            self.x = self.x[order]
            self.y = self.y[order]
        # A re-inserted endpoint takes its y from the most recent drag snapshot,
        # or from the current outermost point if no snapshot exists.
        cache = getattr(self, '_cache', None)
        if self.x[0] != self.xmin:
            first_y = cache['y'][:1] if cache is not None else self.y[:1]
            self.x = np.concatenate([[self.xmin], self.x])
            self.y = np.concatenate([first_y, self.y])
            if self._dragging and self._index is not None:
                self._index += 1
        if self.x[-1] != self.xmax:
            last_y = cache['y'][-1:] if cache is not None else self.y[-1:]
            self.x = np.concatenate([self.x, [self.xmax]])
            self.y = np.concatenate([self.y, last_y])
        self.plt.set_data(self.x, self.y)

    def update(self):
        """Normalise the points and redraw the canvas."""
        self.update_points()
        self.canvas.draw()

    def OnPaint(self, event):
        """Redraw the canvas and let wx continue processing the event.

        NOTE(review): nothing in this class binds this to wx.EVT_PAINT.
        """
        self.canvas.draw()
        event.Skip()
class SplinePlot(LinePlot):
    """LinePlot that also draws an interpolating spline through its points."""

    def __init__(self, parent):
        super().__init__(parent)
        # Fixed dense grid across the data window on which the spline is drawn.
        self.spline_x = np.linspace(self.xmin, self.xmax, 200)
        self.spline_plt, = self.axes.plot(self.spline_x,
                                          np.zeros(self.spline_x.shape))
        self.update_spline()

    def update_spline(self):
        """Refit the interpolating spline and refresh its line.

        ``UnivariateSpline`` with ``s=0`` needs strictly increasing x and more
        points than the spline degree. Duplicate x values (e.g. a point clamped
        onto an endpoint) are therefore collapsed, keeping the first occurrence
        in the current order. The degree also drops from cubic towards linear
        when fewer than four distinct points remain.
        """
        xs, first = np.unique(self.x, return_index=True)
        if len(xs) < 2:
            # Not enough distinct points for any spline; keep the old curve.
            return
        ys = np.asarray(self.y)[first]
        degree = min(3, len(xs) - 1)
        self.spline = interpolate.UnivariateSpline(xs, ys, k=degree, s=0)
        self.spline_plt.set_data(self.spline_x, self.spline(self.spline_x))

    def update(self):
        """Normalise the points, then refit the spline, then redraw.

        The points are sorted and given their endpoints first, so the spline
        is always fitted to the same points that are drawn.
        """
        self.update_points()
        self.update_spline()
        self.canvas.draw()
class MainFrame(wx.Frame):
    """Top-level window whose client area is filled by one SplinePlot."""

    def __init__(self):
        super().__init__(None, wx.ID_ANY, 'Main Window', size=(500, 400))
        self.plot = SplinePlot(self)
        self._layout()

    def _layout(self):
        """Stretch the plot panel over the whole frame."""
        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(self.plot, 1, wx.EXPAND | wx.ALL)
        self.SetSizer(sizer)
        self.Layout()
class Application(wx.App):
    """wx application that opens the main window on start-up."""

    def OnInit(self):
        """Create and show the main frame; True tells wx to keep running."""
        frame = MainFrame()
        frame.Show()
        return True
if __name__ == '__main__':
    # redirect=False: keep stdout/stderr on the console instead of a wx window.
    app = Application(redirect=False)
    app.MainLoop()
# Sign up for free to join this conversation on GitHub. Already have an
# account? Sign in to comment.