Skip to content

Instantly share code, notes, and snippets.

@roualdes
Created November 7, 2022 21:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save roualdes/b5990892e8e9c9b067a0f092809401fa to your computer and use it in GitHub Desktop.
Save roualdes/b5990892e8e9c9b067a0f092809401fa to your computer and use it in GitHub Desktop.
BridgeStan: Rosenbrock
import bridgestan as bs
import optax
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib import cm
# Load the compiled Rosenbrock Stan model through BridgeStan.
sm = bs.StanModel("rosenbrock_model.so")

# Starting point of the optimisation (unconstrained parameter vector).
q = np.asarray([0.0, 2.5])
D = sm.param_unc_num()  # number of unconstrained parameters
grad = np.zeros(D)  # buffer that log_density_gradient fills via `out=`

steps = 1000
states = np.zeros(shape=(steps, D))  # one row per visited point

# Learning-rate schedule: warm-up to the peak value, then cosine decay
# down to `end_value`.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.5,
    warmup_steps=50,
    decay_steps=500,
    end_value=1e-2,
)

# Clip the incoming gradient before handing it to AdamW.
optimizer = optax.chain(
    optax.clip(1.0),
    optax.adamw(learning_rate=schedule),
)

# The optimiser's parameters must start at the same point as `q`.
# Initialising them to zeros would evaluate the first gradient at `q` but
# apply the resulting update (and AdamW's weight decay) to the origin.
# A copy is taken so later updates never alias the initial `q` array.
params = {"rstate": np.array(q, dtype=np.float64)}
opt_state = optimizer.init(params)
# Progress is reported roughly ten times over the run. The interval is
# computed once, and clamped to 1 so that `steps < 10` does not turn the
# modulo below into a division by zero.
report_every = max(1, steps // 10)

for step in range(steps):
    # Evaluate the log density and its gradient at the current point; the
    # gradient is written into the preallocated `grad` buffer.
    ld, _ = sm.log_density_gradient(q, out=grad)

    # Record the point at which the gradient was evaluated.
    states[step, :] = q

    # NOTE(review): optax steps *against* the supplied gradient, so this
    # minimises the model's log density. That is the intended direction only
    # if the Stan model's target is the Rosenbrock function itself -- confirm.
    updates, opt_state = optimizer.update({"rstate": grad}, opt_state, params)
    params = optax.apply_updates(params, updates)

    # Convert the updated parameters to a float64 NumPy array for the next
    # log_density_gradient call.
    q = np.asarray(params["rstate"], dtype=np.float64)

    if step % report_every == 0:
        print(f"step {step}, log density: {ld}")
# Evaluate the model's log density on a regular grid for the contour plot.
grid_x, grid_y = np.meshgrid(
    np.arange(-2, 2, 0.05),
    np.arange(-1, 3, 0.05),
)
surface = np.asarray(
    [
        sm.log_density(np.asarray([px, py]))
        for px, py in zip(grid_x.ravel(), grid_y.ravel())
    ]
).reshape(grid_x.shape)

# Coordinates of every recorded optimisation step.
path_x = states[:, 0]
path_y = states[:, 1]

# Draw the contours first, then overlay the optimiser's trajectory.
plt.cla()
plt.contour(
    grid_x,
    grid_y,
    surface,
    np.logspace(-4, 3, 19),
    cmap=cm.coolwarm,
    alpha=0.5,
)
plt.plot(path_x, path_y, color="black", label="Adam w/ weight decay")
# Empty label keeps the scatter markers out of the legend.
plt.scatter(path_x, path_y, color="black", label="", s=10)
plt.title("Rosenbrock function")
plt.xlabel("x")
plt.ylabel("y")
plt.legend(loc="lower left")
plt.savefig("rosenbrock.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment