Skip to content

Instantly share code, notes, and snippets.

@sidravi1
Created June 15, 2019 05:19
Show Gist options
Save sidravi1/a7965d57c63e71f9b9ff47098cd774df to your computer and use it in GitHub Desktop.
HMC animation
import autograd.numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib as mpl
import seaborn as sns
from minimc.minimc.minimc_slow import hamiltonian_monte_carlo as hmc_slow
from minimc.minimc import neg_log_normal, mixture
# NOTE(review): FIGSIZE is not referenced anywhere in this script.
FIGSIZE = (10, 7)

# Seed numpy's global RNG *before* any random draws, so the HMC run below and
# the reference samples drawn later are reproducible. (Previously the seed was
# set only after sampling had already happened.)
# NOTE(review): assumes minimc draws from numpy's global RNG -- confirm.
np.random.seed(100)

# Target distribution: an equal-weight mixture built from (mean, sd) pairs.
mixture_params = [(0, 0.1), (0.5, 0.2), (-0.5, 0.2)]
n_mix = len(mixture_params)
mixture_norm = [neg_log_normal(mu, sd) for mu, sd in mixture_params]
p_mix = [1 / n_mix] * n_mix
mixture_logp = mixture(mixture_norm, p_mix)

# Run HMC. The first argument is presumably the number of samples (50) --
# confirm against minimc. The per-sample positions/momenta are kept so the
# trajectories can be replayed frame by frame.
samples, positions, momentums, accepted = hmc_slow(
    50,
    mixture_logp,
    initial_position=0.,
    path_len=1.0,
    step_size=0.01,
)
# Concatenate all trajectories into one position stream and one momentum stream.
pos_vec, mom_vec = np.hstack(positions), np.hstack(momentums)
def init():
    """Reset the animation to an empty first frame.

    Used as ``init_func`` for ``FuncAnimation``. It fixes the axis limits and
    clears the data of every animated artist.

    Returns:
        Tuple of every artist this function modified. FuncAnimation needs
        all of them to redraw correctly when blitting is enabled.
    """
    ax.set_ylim(-3.5, 3.5)
    ax.set_xlim(-1.0, 1.0)
    artists = (line, star, sample_points, selected_pos_mom)
    for artist in artists:
        artist.set_data([], [])
    return artists
def run(data):
    """Draw one frame of the animation.

    Args:
        data: 4-tuple ``(p, m, s, pm)`` as yielded by ``data_gen``:
            p, m -- recent trajectory positions and momenta,
            s    -- sample values selected so far,
            pm   -- array whose rows are the (position, momentum) points at
                    which samples were selected; it may be empty.

    Returns:
        Tuple of all artists updated by this frame.
    """
    p, m, s, pm = data
    # Trajectory tail in (position, momentum) space.
    line.set_data(p, m)
    # Selected samples, drawn as a row of dots at y = -3.3.
    sample_points.set_data(s, [-3.3] * len(s))
    if pm.shape[0] > 0:
        selected_pos_mom.set_data(pm[:, 0], pm[:, 1])
    if len(p) > 0:
        # set_data needs sequences. Passing scalars is deprecated since
        # Matplotlib 3.7 and an error in newer releases, so wrap the point.
        star.set_data([p[-1]], [m[-1]])
    else:
        star.set_data(p, m)
    return line, star, sample_points, selected_pos_mom
def data_gen(window=100):
    """Yield per-frame data for the HMC animation.

    Walks the concatenated trajectory (``pos_vec``/``mom_vec``) one point at a
    time. A new sample is recorded whenever the position equals the previous
    position; the recorded value comes from ``samples``, taken in order.

    Args:
        window: how far back the displayed tail reaches. Each frame shows at
            most ``window + 1`` points ending at the current one.

    Yields:
        Tuple ``(p, m, s, pm)``, where:
            p, m -- the recent positions and momenta,
            s    -- a snapshot list of the samples selected so far,
            pm   -- an array of the (position, momentum) pairs where those
                    samples were selected.
    """
    i = 0
    sample_selected = []
    sample_pos_mom = []
    for cnt in range(mom_vec.shape[0]):
        # Builtin max: np.max(x, 0) would treat 0 as the *axis* argument,
        # not as a lower bound, leaving a negative start index.
        low_val = max(cnt - window, 0)
        if cnt > 0 and pos_vec[cnt] == pos_vec[cnt - 1]:
            sample_selected.append(samples[i])
            sample_pos_mom.append([pos_vec[cnt], mom_vec[cnt]])
            i += 1
        # Yield a copy of the sample list. Frames cached by FuncAnimation
        # (e.g. for saving) must not see later appends.
        yield (pos_vec[low_val:cnt + 1], mom_vec[low_val:cnt + 1],
               list(sample_selected), np.array(sample_pos_mom))
with plt.style.context('Solarize_Light2'):
    print(mpl.__version__)
    fig, ax = plt.subplots()
    mus = [x[0] for x in mixture_params]
    sds = [x[1] for x in mixture_params]
    ax.set_xlabel("Position")
    ax.set_ylabel("Momentum")
    ax.grid(ls="--")

    # Reference density: draw 10000 points from each mixture component and
    # plot their KDE on a twin y-axis behind the phase-space trajectory.
    actual_samples = st.norm(mus, sds).rvs([10000, n_mix])
    ax2 = ax.twinx()
    ax2.grid(False)
    # NOTE(review): seaborn >= 0.11 deprecates ``shade`` in favour of
    # ``fill``. It is kept here for compatibility with older seaborn.
    sns.kdeplot(actual_samples.ravel(), shade=True, ax=ax2)

    # Animated artists, updated by init() and run().
    line, = ax.plot([], [], lw=2, color='firebrick')
    star, = ax.plot([], [], "*r")
    sample_points, = ax.plot([], [], ".k")
    selected_pos_mom, = ax.plot([], [], "x", color='orange')

    # save_count must be an int: FuncAnimation uses it as a slice/islice
    # bound for its frame cache, where a float can raise TypeError.
    ani = animation.FuncAnimation(
        fig, run, data_gen, blit=False, interval=1, repeat=False,
        init_func=init, save_count=int(mom_vec.shape[0] * 1.5),
    )
    # To write the animation to a file instead, use one of:
    # ani.save('animation.gif', writer='imagemagick', fps=60)
    # ani.save('anim_hmc.mp4', 'ffmpeg', fps=40)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment