Thanks for this gist and the post on your blog. Both were invaluable for getting me thinking about vectorized code and how the shape of the initial state drives the number of sampled chains in TensorFlow. One thing I noticed was that, at least on my machine (on CPU), the code as-is only executes on one processor and takes some time to run (
Running this instead:

    import tensorflow as tf
    import tensorflow_probability as tfp
    import matplotlib.pyplot as plt

    tfd = tfp.distributions

    num_burnin_steps = 500
    num_results = 500
    num_chains = 256

    # One initial step size and one starting value per chain.
    init_step_size = tf.fill([num_chains], 0.25)
    startvals_ = tf.zeros(num_chains)

    # Define a batch of normal log-pdfs: one for each chain.
    target_log_prob_fn_batched = tfd.Normal(
        loc=tf.fill([num_chains], 0.),
        scale=tf.fill([num_chains], 0.1)).log_prob

    # tf.function traces the sampler into a graph instead of running it eagerly.
    @tf.function
    def sampler(initial_state):
        nkernel = tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn_batched,
            num_leapfrog_steps=8,
            step_size=init_step_size)
        # Tune the step size toward the target acceptance rate during burn-in.
        adapt_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
            inner_kernel=nkernel,
            num_adaptation_steps=num_burnin_steps,
            target_accept_prob=0.6)
        samples, stats = tfp.mcmc.sample_chain(
            num_results=num_results,
            num_burnin_steps=num_burnin_steps,
            current_state=initial_state,
            kernel=adapt_kernel,
            # Record the adapted step size and log acceptance ratio at each step.
            trace_fn=lambda _, pkr: [pkr.new_step_size,
                                     pkr.inner_results.log_accept_ratio])
        return samples, stats

    samples, stats = sampler(startvals_)

This runs in roughly 3.5 seconds. I know the intent of the post wasn't fast code, but all the same I wanted to point this out for others who might stumble on this.
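For anyone who wants to see the "shape of the initial state sets the number of chains" idea without TensorFlow, here is a minimal NumPy sketch. It swaps in plain random-walk Metropolis for TFP's HMC with step-size adaptation, so it only illustrates the vectorization pattern and is not the gist's method. The names `log_prob` and `rw_metropolis` and the value `step_size=0.05` are hypothetical choices for illustration; the target mirrors the batched Normal(0, 0.1) above.

```python
import numpy as np


def log_prob(x, loc=0.0, scale=0.1):
    # Elementwise Normal(loc, scale) log-density: one value per chain.
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale * np.sqrt(2.0 * np.pi))


def rw_metropolis(initial_state, num_results, step_size=0.05, seed=0):
    """Hypothetical vectorized random-walk Metropolis sampler.

    The number of chains is simply initial_state.shape[0]; every operation
    below acts on all chains at once.
    """
    rng = np.random.default_rng(seed)
    state = np.asarray(initial_state, dtype=float)
    current_lp = log_prob(state)
    samples = np.empty((num_results,) + state.shape)
    accepted = np.zeros(state.shape)
    for i in range(num_results):
        # Propose a move for every chain simultaneously.
        proposal = state + step_size * rng.standard_normal(state.shape)
        proposal_lp = log_prob(proposal)
        # Per-chain Metropolis accept/reject on the log scale.
        accept = np.log(rng.uniform(size=state.shape)) < proposal_lp - current_lp
        state = np.where(accept, proposal, state)
        current_lp = np.where(accept, proposal_lp, current_lp)
        accepted += accept
        samples[i] = state
    return samples, accepted / num_results


num_chains = 256
samples, accept_rate = rw_metropolis(np.zeros(num_chains), num_results=500)
print(samples.shape)  # (500, 256): one column per chain
```

Changing `np.zeros(num_chains)` to a different length changes the number of chains without touching the sampler, which is the same property the batched TFP code relies on.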