Created November 7, 2022 21:25
-
-
Save roualdes/b5990892e8e9c9b067a0f092809401fa to your computer and use it in GitHub Desktop.
BridgeStan: Rosenbrock
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bridgestan as bs
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from matplotlib import cm
# Compiled BridgeStan model; its log density is the objective optimized below.
sm = bs.StanModel("rosenbrock_model.so")

# Starting point of the optimization (unconstrained scale).
q = np.asarray([0.0, 2.5])
D = sm.param_unc_num()
# Reusable buffer that log_density_gradient fills in place on every call.
grad = np.zeros(D)
steps = 1000
# Row i holds the iterate at which the gradient of step i was evaluated.
states = np.zeros(shape=(steps, D))

# Linear warmup from 0.0 to 0.5 over the first 50 steps, then cosine decay
# down to 1e-2 (optax counts the warmup inside decay_steps, so the decay ends
# at step 500 and the rate stays at 1e-2 afterwards).
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.5,
    warmup_steps=50,
    decay_steps=500,
    end_value=1e-2,
)
# Clip each gradient element to [-1, 1], then apply AdamW (optax's default
# weight decay) with the schedule above.
optimizer = optax.chain(
    optax.clip(1.0),
    optax.adamw(learning_rate=schedule),
)

# Start the optimizer at q. The previous zero initialization silently threw
# q away: the step-0 learning rate is 0.0, so the first update is zero and
# q jumped to the origin after the first iteration.
params = {"rstate": np.array(q, dtype=np.float64)}
opt_state = optimizer.init(params)

# Print roughly ten progress lines; max() avoids a modulo-by-zero if steps < 10.
log_every = max(1, steps // 10)
for step in range(steps):
    # Log density and its gradient at the current iterate; the gradient is
    # written into `grad` in place.
    ld, _ = sm.log_density_gradient(q, out=grad)
    states[step, :] = q
    # optax produces descent updates that apply_updates adds to params, so
    # feeding the raw gradient moves q toward LOWER log density.
    # NOTE(review): the positive contour levels used for plotting suggest the
    # Stan model's target is the (positive) Rosenbrock function itself, making
    # minimization intended -- confirm; to maximize a true log density,
    # pass -grad instead.
    updates, opt_state = optimizer.update({"rstate": grad}, opt_state, params)
    params = optax.apply_updates(params, updates)
    # BridgeStan expects a float64 NumPy array, so convert the JAX result back.
    q = np.asarray(params["rstate"], dtype=np.float64)
    if step % log_every == 0:
        print(f"step {step}, log density: {ld}")
# Background: the model's log density evaluated on a regular grid covering
# x in [-2, 2) and y in [-1, 3) with spacing 0.05.
X = np.arange(-2, 2, 0.05)
Y = np.arange(-1, 3, 0.05)
x, y = np.meshgrid(X, Y)
z = [
    sm.log_density(np.asarray([px, py]))
    for px, py in zip(x.flatten(), y.flatten())
]
l = np.asarray(z).reshape(x.shape)

# Foreground: every recorded optimizer iterate, drawn as a line plus markers.
idx = np.arange(steps)
path_x = states[idx, 0]
path_y = states[idx, 1]

plt.cla()
plt.contour(x, y, l, np.logspace(-4, 3, 19), cmap=cm.coolwarm, alpha=0.5)
plt.plot(path_x, path_y, color="black", label="Adam w/ weight decay")
# An empty label keeps the scatter markers out of the legend.
plt.scatter(path_x, path_y, color="black", label="", s=10)
plt.title("Rosenbrock function")
plt.xlabel("x")
plt.ylabel("y")
plt.legend(loc="lower left")
plt.savefig("rosenbrock.png")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.