Thanks for this gist and the post on your blog. Invaluable for getting me thinking about vectorized code and how the shape of the initial state drives the number of sampled chains in TensorFlow. One thing I noticed was that, at least on my machine (on CPU), the code as-is only executed on one processor and took some time to run:

28.9 s ± 5.74 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Running this instead:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
tfd = tfp.distributions

num_burnin_steps = 500
num_results = 500
num_chains = 256
init_step_size = tf.fill([num_chains], 0.25)

# Define a batch of normal logpdf's: 1 for each chain (unit scale assumed; the line was truncated)
target_log_prob_fn_batched = tfd.Normal(
    loc=tf.fill([num_chains], 0.), scale=tf.fill([num_chains], 1.)).log_prob

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob_fn_batched,
    step_size=init_step_size, num_leapfrog_steps=3)  # num_leapfrog_steps is an assumed value
adapt_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel, num_adaptation_steps=int(0.8 * num_burnin_steps))

@tf.function
def sampler(init_state):
    samples, stats = tfp.mcmc.sample_chain(
        num_results=num_results, num_burnin_steps=num_burnin_steps,
        current_state=init_state, kernel=adapt_kernel,
        trace_fn=lambda _, pkr: [pkr.new_step_size,
                                 pkr.inner_results.is_accepted])  # second traced quantity assumed
    return samples, stats

startvals_ = tf.zeros([num_chains])  # one starting value per chain (assumed)
samples, stats = sampler(startvals_)
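For what it's worth, the chain dimension falls straight out of the shape of the initial state: with a [num_chains]-shaped start and a batched log-prob, TFP runs all 256 chains in one vectorized pass. A quick sanity check on the outputs (assuming the variable names and trace_fn from the snippet above) looks like this:

# samples: one draw per row, one chain per column
print(samples.shape)             # expected: (num_results, num_chains) == (500, 256)
step_sizes, is_accepted = stats  # order matches the trace_fn above
print(step_sizes.shape)          # (num_results, num_chains)
# overall acceptance rate pooled across all 256 chains
print(tf.reduce_mean(tf.cast(is_accepted, tf.float32)).numpy())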
Runs in about 3.5 seconds. I know the intent of the post wasn't fast code, but all the same I wanted to point this out for others who might stumble on this.
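(The readouts above are IPython %timeit-style output; to reproduce the comparison in a notebook, something like the line below should work. The -r 7 -n 1 flags are my assumption to match "7 runs, 1 loop each".)

%timeit -r 7 -n 1 sampler(startvals_)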