Created
February 4, 2020 00:01
-
-
Save skuschel/ca31559327f7e3eb9dee47e789e50b74 to your computer and use it in GitHub Desktop.
Watch a function while it is being evaluated.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
# Stephan Kuschel, May 2019 | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def addcolorbar(ax, im, pos='right', size='5%', pad=0.05, orientation='vertical',
                stub=False, max_ticks=None, label=None):
    '''
    Attach a colorbar for the image `im` next to the axes `ax`.

    ax          -- the axes object the image is drawn in
    im          -- the image (return value of ax.imshow(...))
    pos         -- side of `ax` the colorbar axes is appended to
                   (passed to append_axes, e.g. 'right')
    size        -- width of the colorbar axes relative to `ax` (e.g. '5%')
    pad         -- gap between `ax` and the colorbar axes
    orientation -- orientation handed to plt.colorbar
    stub        -- if True, only reserve the space: the appended axes is
                   hidden and no colorbar is drawn
    max_ticks   -- if given, limit the colorbar to about this many ticks
                   (MaxNLocator nbins)
    label       -- if given, label text for the colorbar

    Returns the axes holding the colorbar (or the hidden stub axes).

    When changed, please update:
    https://gist.github.com/skuschel/85f0645bd6e37509164510290435a85a
    Stephan Kuschel, 2018
    '''
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    import matplotlib.pyplot as plt

    colorbar_ax = make_axes_locatable(ax).append_axes(pos, size=size, pad=pad)

    if stub:
        # keep the space a colorbar would take, but leave it empty
        colorbar_ax.set_visible(False)
        return colorbar_ax

    colorbar = plt.colorbar(im, cax=colorbar_ax, orientation=orientation)

    if max_ticks is not None:
        from matplotlib.ticker import MaxNLocator
        colorbar.locator = MaxNLocator(nbins=max_ticks)
        colorbar.update_ticks()

    if label is not None:
        colorbar.set_label(label)

    return colorbar_ax
class watchit():
    '''
    Decorator that live-plots a function of two variables while it is
    being evaluated.

    Each call ``wrapped(x, y)`` evaluates ``f(x, y)``, records the sample and
    returns the result unchanged. Once more than three samples exist, all
    samples are linearly interpolated onto a regular grid (scipy ``griddata``)
    and shown as an image with a colorbar. The sample positions are drawn as
    red dots on top.

    f   -- the function to watch, called as ``f(x, y)``. Its return value is
           fed to ``griddata``, so presumably it must be a real scalar.
    res -- number of interpolation grid points along each axis (default 50).

    Stephan Kuschel, 2019
    '''

    def __init__(self, f, res=50):
        import functools
        # Make the decorated object carry f's __name__, __doc__, ... the way a
        # regular function decorator would.
        functools.update_wrapper(self, f)
        self.fig, self.ax = plt.subplots()
        self.f = f
        self.res = res
        # sample history: x and y coordinates and the corresponding f(x, y)
        self.xs = []
        self.ys = []
        self.zs = []
        self.im = None      # image artist, created on the first successful plot
        self.cbar = None    # colorbar axes returned by addcolorbar
        self.points = None  # a single Line2D holding all sample positions

    def __call__(self, x, y):
        '''
        Evaluate the watched function at (x, y), record the sample and
        redraw once more than three samples exist. Returns ``f(x, y)``.
        '''
        ret = self.f(x, y)
        self.xs.append(x)
        self.ys.append(y)
        self.zs.append(ret)
        # linear interpolation needs a triangulation, so wait for a few points
        if len(self.xs) > 3:
            self.plotdata()
        return ret

    def plotdata(self):
        '''
        Interpolate all samples collected so far onto a ``res`` x ``res``
        grid spanning their bounding box and redraw the figure.

        If the samples cannot be triangulated yet (e.g. they all lie on one
        line), the frame is skipped.
        '''
        from scipy.interpolate import griddata
        from scipy.spatial import QhullError
        xi = np.linspace(np.min(self.xs), np.max(self.xs), self.res, endpoint=True)
        yi = np.linspace(np.min(self.ys), np.max(self.ys), self.res, endpoint=True)
        try:
            # grid points outside the convex hull of the samples become NaN
            zi = griddata((self.xs, self.ys), self.zs,
                          (xi[None, :], yi[:, None]), method='linear')
        except QhullError:
            # Degenerate sample set: nothing to interpolate yet. Skip this
            # frame rather than letting a plotting problem break the
            # evaluation of the watched function.
            return
        ext = np.min(xi), np.max(xi), np.min(yi), np.max(yi)
        if self.im is None:
            self.im = self.ax.imshow(zi, extent=ext, origin='lower', aspect='auto')
            self.cbar = addcolorbar(self.ax, self.im)
            self.fig.tight_layout()
            # Create the sample markers once. Later frames update this artist
            # in place instead of stacking a new line per call.
            self.points, = self.ax.plot(self.xs, self.ys, 'o', c='r')
        else:
            self.im.set_data(zi)
            self.im.set_extent(ext)
            self.im.autoscale()
            self.points.set_data(self.xs, self.ys)
        self.fig.show()
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
def test():
    '''
    Demo: watch sin(3x + 6y) while it is sampled at random points with
    x, y in [0, 1), printing each value and pausing 0.3 s between samples.
    Loops forever; stop it with Ctrl-C.
    '''
    import time

    def testfunc(x, y):
        return np.sin(3*x + 6*y)

    testfunc = watchit(testfunc)

    while True:
        # draw x before y, matching the argument evaluation order
        x = np.random.random()
        y = np.random.random()
        value = testfunc(x, y)
        print(value)
        time.sleep(0.3)


if __name__ == '__main__':
    # run the interactive demo when executed as a script
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment