Thanks for this gist and the post on your blog. They were invaluable for getting me thinking about vectorized code and about how the shape of the initial state drives the number of sampled chains in TensorFlow. One thing I noticed: at least on my machine (on CPU), the code as-is only executed on one processor and was slow to run: 28.9 s ± 5.74 s per loop (mean ± std. dev. of 7 runs, 1 loop each).
Running this instead:
```python
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

tfd = tfp.distributions

num_burnin_steps = 500
num_results = 500
num_chains = 256

# One initial step size and one starting value per chain.
init_step_size = tf.fill([num_chains], 0.25)
startvals_ = tf.zeros(num_chains)

# Define a batch of normal log-pdfs: one for each chain
target_log_prob_fn_batched = tfd.Normal(loc=tf.fill([num_chains], 0.),
                                        scale=tf.fill([num_chains], 0.1)).log_prob

# Compiling the whole sampler into a TensorFlow graph with @tf.function
# avoids running every MCMC step in slow eager mode.
@tf.function
def sampler(initial_state):
    nkernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn_batched,
        num_leapfrog_steps=8,
        step_size=init_step_size)
    adapt_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=nkernel, num_adaptation_steps=num_burnin_steps,
        target_accept_prob=0.6)
    samples, stats = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=adapt_kernel,
        trace_fn=lambda _, pkr: [pkr.new_step_size,
                                 pkr.inner_results.log_accept_ratio])
    return samples, stats

samples, stats = sampler(startvals_)
```
This runs in about 3.5 seconds. I know the intent of the post wasn't fast code, but I wanted to point this out for others who may stumble on this.
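The idea that the shape of the initial state sets the number of chains can also be seen without TensorFlow. Below is a minimal NumPy sketch, assuming a plain random-walk Metropolis kernel with a fixed step size (not TFP's HMC, and with no step-size adaptation); `rw_metropolis` and `normal_log_prob` are hypothetical names used only for illustration. A length-256 initial state gives 256 independent chains, all advanced by the same array operations.

```python
import numpy as np


def normal_log_prob(x, loc=0.0, scale=0.1):
    # Unnormalized Normal(loc, scale) log-density; the constant cancels
    # in the Metropolis acceptance ratio.
    return -0.5 * ((x - loc) / scale) ** 2


def rw_metropolis(log_prob, initial_state, num_results, num_burnin_steps,
                  step_size=0.25, seed=0):
    """Hypothetical vectorized random-walk Metropolis sampler.

    Every element of `initial_state` is treated as its own chain, so the
    number of chains is simply the shape of the initial state.
    """
    rng = np.random.default_rng(seed)
    state = np.asarray(initial_state, dtype=float)
    current_lp = log_prob(state)
    samples = np.empty((num_results,) + state.shape)
    for step in range(num_burnin_steps + num_results):
        # Propose a move for every chain at once.
        proposal = state + step_size * rng.standard_normal(state.shape)
        proposal_lp = log_prob(proposal)
        # Accept or reject each chain independently.
        accept = np.log(rng.uniform(size=state.shape)) < proposal_lp - current_lp
        state = np.where(accept, proposal, state)
        current_lp = np.where(accept, proposal_lp, current_lp)
        # Keep only post-burn-in draws.
        if step >= num_burnin_steps:
            samples[step - num_burnin_steps] = state
    return samples


num_chains = 256
samples = rw_metropolis(normal_log_prob, np.zeros(num_chains),
                        num_results=500, num_burnin_steps=500)
print(samples.shape)  # (500, 256): one column per chain
```

Changing `np.zeros(num_chains)` to a different length changes the number of chains without touching the sampler itself, which mirrors how `startvals_` and the batched target distribution line up in the TFP code.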