@ColCarroll
Created July 28, 2019 15:12
roblem commented Feb 26, 2020

Thanks for this gist and the post on your blog. They were invaluable for getting me thinking about vectorized code and how the shape of the initial state drives the number of chains sampled in TensorFlow. One thing I noticed: at least on my machine (CPU only), the code as-is executed on only one processor and took a while to run (28.9 s ± 5.74 s per loop, mean ± std. dev. of 7 runs, 1 loop each).

Running this instead:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_burnin_steps = 500
num_results = 500
num_chains = 256

# One initial step size and one starting value per chain.
init_step_size = tf.fill([num_chains], 0.25)
startvals_ = tf.zeros(num_chains)

# Define a batch of normal log-pdfs: one for each chain.
target_log_prob_fn_batched = tfd.Normal(loc=tf.fill([num_chains], 0.),
                                        scale=tf.fill([num_chains], 0.1)).log_prob

# tf.function traces the whole sampler into a graph,
# instead of executing every MCMC step eagerly in Python.
@tf.function
def sampler(initial_state):
    nkernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn_batched,
        num_leapfrog_steps=8,
        step_size=init_step_size)

    # Tune the step size during burn-in toward a 0.6 acceptance probability.
    adapt_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=nkernel,
        num_adaptation_steps=num_burnin_steps,
        target_accept_prob=0.6)

    # The shape of initial_state ([num_chains]) sets the number of chains.
    samples, stats = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=adapt_kernel,
        trace_fn=lambda _, pkr: [pkr.new_step_size,
                                 pkr.inner_results.log_accept_ratio])
    return samples, stats

samples, stats = sampler(startvals_)
```
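For anyone without TensorFlow installed, the same batching idea can be sketched in plain NumPy. This is only a minimal sketch under stated assumptions: it swaps HMC for simple random-walk Metropolis, skips step-size adaptation entirely, and the helper names `batched_normal_log_prob` and `random_walk_metropolis` are hypothetical. What it shows is that each entry of the initial state is an independent chain, so a state of shape `(num_chains,)` yields samples of shape `(num_results, num_chains)`, the same layout as `samples` from the TFP code.

```python
import numpy as np

# Hypothetical helpers; random-walk Metropolis stands in for HMC here.
def batched_normal_log_prob(x, loc=0.0, scale=0.1):
    # Elementwise log density of N(loc, scale): one independent target per chain.
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

def random_walk_metropolis(log_prob_fn, initial_state, num_results,
                           num_burnin_steps, proposal_scale=0.25, seed=0):
    # Every entry of initial_state is its own chain, so its shape sets the chain count.
    rng = np.random.default_rng(seed)
    state = np.asarray(initial_state, dtype=float)
    log_p = log_prob_fn(state)
    samples = np.empty((num_results,) + state.shape)
    log_accept_ratio = np.empty((num_results,) + state.shape)
    for step in range(num_burnin_steps + num_results):
        # Propose a move for all chains at once.
        proposal = state + proposal_scale * rng.standard_normal(state.shape)
        proposal_log_p = log_prob_fn(proposal)
        ratio = proposal_log_p - log_p
        # Accept each chain's proposal independently with probability min(1, exp(ratio)).
        accept = np.log(rng.uniform(size=state.shape)) < ratio
        state = np.where(accept, proposal, state)
        log_p = np.where(accept, proposal_log_p, log_p)
        # Only keep draws after burn-in, like sample_chain does.
        if step >= num_burnin_steps:
            samples[step - num_burnin_steps] = state
            log_accept_ratio[step - num_burnin_steps] = ratio
    return samples, log_accept_ratio

num_chains = 256
samples, log_accept_ratio = random_walk_metropolis(
    batched_normal_log_prob, np.zeros(num_chains),
    num_results=500, num_burnin_steps=500)
print(samples.shape)  # (500, 256)
```

The loop runs over steps, not chains: all 256 chains advance together in each array operation, which is the vectorization the blog post is about.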

This runs in roughly 3.5 seconds. I know the intent of the post wasn't fast code, but I wanted to point this out for others who may stumble on it.
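To check how close the sampler got to the `target_accept_prob=0.6` requested in the snippet, the traced `log_accept_ratio` can be turned into acceptance probabilities. Here is a minimal NumPy sketch; `acceptance_summary` is a hypothetical helper, and it assumes the traced array has shape `[num_results, num_chains]` (one entry per draw and chain).

```python
import numpy as np

def acceptance_summary(log_accept_ratio, target_accept_prob=0.6):
    # Hypothetical helper. log_accept_ratio is assumed to have shape
    # [num_results, num_chains]. A Metropolis-Hastings proposal is
    # accepted with probability min(1, exp(log_accept_ratio)).
    accept_prob = np.exp(np.minimum(log_accept_ratio, 0.0))
    # Average over draws: one acceptance rate per chain.
    per_chain = accept_prob.mean(axis=0)
    # Gap between the overall mean acceptance and the target.
    return per_chain, per_chain.mean() - target_accept_prob

# Toy input: two draws from two chains.
demo = np.array([[0.0, np.log(0.5)],
                 [1.0, np.log(0.5)]])
per_chain, gap = acceptance_summary(demo)
print(per_chain, gap)
```

With the TFP code, something like `acceptance_summary(stats[1].numpy())` should work, since the second traced quantity is `pkr.inner_results.log_accept_ratio`.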
