Last active: August 29, 2015 14:05
Save LeoHuckvale/4bc76d16eabdbc77c933 to your computer and use it in GitHub Desktop.
Select corresponding points in multiple subplots
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot as plt | |
class MultiPlotSelector:
    """Highlight corresponding points across several subplots.

    ``data`` holds one (x, y) pair per subplot, i.e. something shaped like
    ``M x (2 x N)``, and ``len(ax) == M``. Picking a point in any subplot
    highlights the point with the same index in every subplot.
    """

    def __init__(self, data, fig, ax):
        """
        data: paired columns per subplot, i.e. M x (2 x N); data[i][0] are the
              x values and data[i][1] the y values plotted on ax[i].
        fig:  parent figure whose canvas is redrawn after a selection.
        ax:   sequence of M axes, one per entry of ``data``.
        """
        self.fig = fig
        self.ax = ax
        self.data = data
        # Each plotted data line maps to its highlight marker (linesel) and to
        # the (x, y) data it was drawn from (linedata).
        self.linesel = {}
        self.linedata = {}
        self.plot()

    def plot(self):
        """Draw every subplot's points plus a hidden highlight marker."""
        for axi, datai in zip(self.ax, self.data):
            # picker=5 makes the points pickable within a 5-point radius.
            line, = axi.plot(datai[0], datai[1], 'ko', ms=5, picker=5)
            # Placeholder marker; onpick moves it onto the chosen point.
            selected, = axi.plot(0, 0, 'yo', ms=12, alpha=0.4, visible=False)
            self.linesel[line] = selected
            self.linedata[line] = datai

    def onpick(self, event):
        """Handle a matplotlib pick event.

        Returns the index of the selected point. Returns True when the event
        is ignored (the artist is not one of our lines, or no points were
        picked).
        """
        if event.artist not in self.linesel:
            return True
        ind = np.asarray(event.ind)
        if not ind.size:
            return True

        data = self.linedata[event.artist]
        xs, ys = np.asarray(data[0]), np.asarray(data[1])
        x = event.mouseevent.xdata
        y = event.mouseevent.ydata
        if x is None or y is None:
            # No data coordinates for the click; take the first candidate.
            dataind = ind[0]
        else:
            # Several points can fall inside the pick radius; take the closest.
            distances = np.hypot(x - xs[ind], y - ys[ind])
            dataind = ind[distances.argmin()]

        # Look the data up by line so each marker is paired with its own data,
        # independent of dict iteration order.
        for line, marker in self.linesel.items():
            dat = self.linedata[line]
            # set_data requires sequences; newer matplotlib rejects scalars.
            marker.set_data([dat[0][dataind]], [dat[1][dataind]])
            marker.set_visible(True)
        self.fig.canvas.draw_idle()
        return dataind
class Foo:
    """Example consumer that prints the data at a selected point index."""

    def __init__(self, data):
        """Store ``data``; onselect indexes it as ``data[:, :, dataind]``."""
        self.data = data

    def onselect(self, dataind):
        """Print the values at index ``dataind`` along the last axis."""
        # Call form of print: valid on Python 3 and prints identically on 2.
        print(self.data[:, :, dataind])
if __name__ == '__main__':
    # Set up the figure and a column of subplots.
    fig, ax = plt.subplots(3, 1, figsize=(5, 5))
    # Example data in paired columns per subplot: 3 subplots x (x, y) x 20 points.
    data = np.random.rand(3, 2, 20)
    # Point selector instance.
    selector = MultiPlotSelector(data, fig, ax)
    # Some other object that does something with the selected data.
    foo = Foo(data)

    def onpick(event):
        """Forward a pick event to the selector; report real selections to foo."""
        dataind = selector.onpick(event)
        # selector.onpick returns True when it ignored the event. Passing that
        # on would index data[:, :, True], so only forward real indices.
        if dataind is None or isinstance(dataind, bool):
            return
        foo.onselect(dataind)

    fig.canvas.mpl_connect('pick_event', onpick)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.