@ColCarroll
Created July 28, 2019 15:12

roblem commented Feb 26, 2020

Thanks for this gist and the post on your blog. They were invaluable for getting me thinking about vectorized code and how the shape of the initial state determines the number of chains TensorFlow samples. One thing I noticed, at least on my machine (on CPU), is that the code as-is ran on only one processor and took a while: 28.9 s ± 5.74 s per loop (mean ± std. dev. of 7 runs, 1 loop each).
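To make the shape point concrete without TensorFlow, here is a minimal NumPy sketch. It swaps HMC for plain random-walk Metropolis, and the names `batched_rw_metropolis` and `normal_log_prob` are hypothetical, written only for illustration. Every chain is updated in the same array operations, so the number of chains is nothing more than the leading dimension of `initial_state`.

```python
import numpy as np


def batched_rw_metropolis(log_prob, initial_state, num_results, scale=0.25, seed=0):
    """Random-walk Metropolis run on every chain at once (illustrative sketch).

    The number of chains is initial_state.shape[0]; log_prob must map an
    array of states to an array of per-chain log densities of the same shape.
    """
    rng = np.random.default_rng(seed)
    state = np.asarray(initial_state, dtype=float)
    current_lp = log_prob(state)
    samples = np.empty((num_results,) + state.shape)
    accepted = np.zeros(state.shape)
    for i in range(num_results):
        # One proposal per chain, drawn in a single vectorized call.
        proposal = state + scale * rng.standard_normal(state.shape)
        proposal_lp = log_prob(proposal)
        # Each chain accepts or rejects independently in one comparison.
        accept = np.log(rng.uniform(size=state.shape)) < proposal_lp - current_lp
        state = np.where(accept, proposal, state)
        current_lp = np.where(accept, proposal_lp, current_lp)
        accepted += accept
        samples[i] = state
    return samples, accepted / num_results


def normal_log_prob(x, loc=0.0, scale=0.1):
    # Elementwise Normal log density: a batch of states gives a batch of log probs.
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale * np.sqrt(2 * np.pi))


samples, accept_rate = batched_rw_metropolis(normal_log_prob, np.zeros(256), num_results=500)
print(samples.shape)  # (500, 256): one column per chain
```

Passing `np.zeros(4)` instead of `np.zeros(256)` would give four chains with no other change, which mirrors how `current_state` controls the chain count in `tfp.mcmc.sample_chain`.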

Running the following instead of the gist's code:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_burnin_steps = 500
num_results = 500
num_chains = 256

# One step size and one starting value per chain; the leading dimension
# of the initial state is what sets the number of chains.
init_step_size = tf.fill([num_chains], 0.25)
startvals_ = tf.zeros(num_chains)

# Define a batch of normal log-pdfs: one for each chain
target_log_prob_fn_batched = tfd.Normal(loc=tf.fill([num_chains], 0.),
                                        scale=tf.fill([num_chains], 0.1)).log_prob

# tf.function traces the sampler into a graph instead of running it
# op-by-op in eager mode.
@tf.function
def sampler(initial_state):
    nkernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn_batched,
        num_leapfrog_steps=8,
        step_size=init_step_size)

    # Tune the step size during burn-in toward a 60% acceptance rate.
    adapt_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=nkernel,
        num_adaptation_steps=num_burnin_steps,
        target_accept_prob=0.6)

    samples, stats = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=adapt_kernel,
        # Record the adapted step size and the log acceptance ratio per step.
        trace_fn=lambda _, pkr: [pkr.new_step_size,
                                 pkr.inner_results.log_accept_ratio])
    return samples, stats

samples, stats = sampler(startvals_)
```
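Since `trace_fn` records `log_accept_ratio`, one way to check whether the adaptation hit its `target_accept_prob=0.6` is to turn those ratios into acceptance probabilities, min(1, exp(ratio)), and average over steps. The sketch below assumes the traced tensor has been converted with `.numpy()` and has shape `[num_results, num_chains]`; `mean_accept_prob` is a hypothetical helper, not part of TFP.

```python
import numpy as np


def mean_accept_prob(log_accept_ratio):
    """Mean Metropolis acceptance probability per chain (illustrative sketch).

    log_accept_ratio: array of shape [num_results, num_chains],
    e.g. the second element of the traced stats, converted with .numpy().
    """
    log_accept_ratio = np.asarray(log_accept_ratio, dtype=float)
    # Acceptance probability is min(1, exp(ratio)); clipping at 0 in log
    # space first also avoids overflow for large positive ratios.
    accept_prob = np.exp(np.minimum(log_accept_ratio, 0.0))
    return accept_prob.mean(axis=0)
```

Assuming the unpacking `step_sizes, log_accept_ratio = stats`, calling `mean_accept_prob(log_accept_ratio.numpy())` gives one value per chain that can be compared with 0.6, and `step_sizes[-1]` shows the final adapted step sizes.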

This runs in about 3.5 seconds. I know the intent of the post wasn't fast code, but I wanted to point this out for others who might stumble on this.
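When comparing timings like these, note that the first call to a `tf.function` also traces the Python code into a graph. The stdlib sketch below (the helper name `time_call` is hypothetical) makes one untimed warm-up call and then reports mean ± std. dev. over repeated runs, the same summary `%timeit` prints.

```python
import statistics
import timeit


def time_call(fn, repeat=7, number=1):
    """Return (mean, stdev) seconds per call of fn() after one warm-up call.

    The warm-up keeps one-off costs, such as tf.function tracing on the
    first call, out of the measurements. Requires repeat >= 2 for stdev.
    """
    fn()  # untimed warm-up
    runs = timeit.repeat(fn, repeat=repeat, number=number)
    per_call = [t / number for t in runs]
    return statistics.mean(per_call), statistics.stdev(per_call)
```

For the code above this would be used as `time_call(lambda: sampler(startvals_))`.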
