Skip to content

Instantly share code, notes, and snippets.

@el-hult
Created February 26, 2022 15:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save el-hult/327d93397d70c3fa6a404319e3ae2021 to your computer and use it in GitHub Desktop.
Save el-hult/327d93397d70c3fa6a404319e3ae2021 to your computer and use it in GitHub Desktop.
A code sample showing how to create an interactive visualization in matplotlib using matplotlib widgets. In this case, it shows how a quadratic whose vertex location is unknown can be bounded from above.
# %% set up the plot
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button, Slider
# Largest admissible |location| of the parabola's vertex; the location slider
# spans [-LOC_MAX, LOC_MAX].
LOC_MAX = 2


def f1(t, amplitude, location):
    """Return the "true" parabola ``amplitude * (t - location)**2``.

    Parameters
    ----------
    t : array_like or float
        Points at which to evaluate the parabola.
    amplitude : float
        Curvature factor applied to the squared distance.
    location : float
        Position of the vertex (the minimum when ``amplitude >= 0``).
    """
    return amplitude * (t - location) ** 2


def f2(t, amplitude, location):
    """Return an upper bound on ``f1`` that does not depend on ``location``.

    By the triangle inequality ``|t - location| <= |t| + |location|``, so for
    every ``|location| <= LOC_MAX`` and ``amplitude >= 0``::

        f1(t, amplitude, location) <= amplitude * (|t| + LOC_MAX)**2

    ``location`` is intentionally unused: the bound holds for every admissible
    location. It stays in the signature so ``f1`` and ``f2`` can be called
    with identical arguments.
    """
    return amplitude * (np.abs(t) + LOC_MAX) ** 2
# Evaluation grid for both curves: it reaches one unit past the admissible
# vertex range on each side.
t = np.linspace(-(LOC_MAX + 1), LOC_MAX + 1, num=1000)

# Starting slider positions: unit amplitude, vertex at the origin.
init_amplitude, init_location = 1, 0
# Create the figure and the two curves that the sliders will update.
fig, ax = plt.subplots()
(line1,) = ax.plot(t, f1(t, init_amplitude, init_location), lw=2, label="True")
(line2,) = ax.plot(t, f2(t, init_amplitude, init_location), lw=2, label="Bound")
# Both curves carry a label; without a legend those labels are never shown.
ax.legend()
ax.set_xlabel("Time [s]")
slider_bg_color = "lightgoldenrodyellow"
ax.margins(x=0)
# Shrink the main plot to leave room for the sliders (bottom and left).
fig.subplots_adjust(left=0.25, bottom=0.25)
# Horizontal slider below the plot that moves the vertex of the "True" curve.
# The "Bound" curve does not depend on it, by construction of f2.
axfreq = fig.add_axes([0.25, 0.1, 0.65, 0.03], facecolor=slider_bg_color)
location_slider = Slider(
    ax=axfreq,
    label="Minimum location",
    valmin=-LOC_MAX,
    valmax=LOC_MAX,
    valinit=init_location,
)
# Vertical slider to the left of the plot; both curves are scaled by its value.
axamp = fig.add_axes([0.1, 0.25, 0.0225, 0.63], facecolor=slider_bg_color)
amp_slider = Slider(
    axamp,
    "Amplitude",
    0,
    10,
    valinit=init_amplitude,
    orientation="vertical",
)
def update(val):
    """Recompute both curves from the current slider positions and redraw.

    ``val`` is the new value of whichever slider moved. It is not needed,
    because both slider values are read directly.
    """
    amplitude = amp_slider.val
    location = location_slider.val
    for curve, func in ((line1, f1), (line2, f2)):
        curve.set_ydata(func(t, amplitude, location))
    # Schedule a redraw rather than forcing one on every slider event.
    fig.canvas.draw_idle()


# Re-run update() whenever either slider moves.
amp_slider.on_changed(update)
location_slider.on_changed(update)
def reset(event):
    """Restore both sliders to their initial values.

    ``event`` is the mouse event supplied by the button callback and is
    unused. Slider.reset() re-sets the value, which notifies the sliders'
    on_changed observers, so the curves are redrawn as well.
    """
    location_slider.reset()
    amp_slider.reset()


# Attach reset() to a button in the bottom-right corner so it is reachable
# from the UI. Keep a module-level reference to the widget: matplotlib
# widgets stop responding if they are garbage-collected.
reset_ax = fig.add_axes([0.8, 0.025, 0.1, 0.04])
reset_button = Button(reset_ax, "Reset", hovercolor="0.975")
reset_button.on_clicked(reset)

plt.show()
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment