Skip to content

Instantly share code, notes, and snippets.

@Allgoerithm
Last active December 11, 2019 05:48
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Allgoerithm/8225b69f33aff56e56e47687dfc9e6bb to your computer and use it in GitHub Desktop.
Save Allgoerithm/8225b69f33aff56e56e47687dfc9e6bb to your computer and use it in GitHub Desktop.
# Observed data: the totals for rock, paper and scissors, as float32 scalars.
total_rock, total_paper, total_scissors = (
    tf.constant(total, dtype=tf.float32) for total in (5., 0., 0.)
)
# Sampler settings: how many samples to keep (num_results) and how many
# initial steps to discard as burn-in (num_burnin_steps).
number_of_steps, burnin = 10000, 5000
# Starting point of the Markov chain: p_rock and p_paper both begin at 1/3,
# each as a named scalar float32 tensor.
initial_chain_state = [
    1 / 3 * tf.ones([], dtype=tf.float32, name=init_name)
    for init_name in ("init_p_rock", "init_p_paper")
]
# Bijectors relating each parameter's constrained space to the unconstrained
# real line, so the sampler can move freely in unconstrained space: Sigmoid
# maps the reals onto the open interval (0, 1) (and its inverse maps back).
# NOTE(review): each probability is constrained to (0, 1) independently, so
# states with p_rock + p_paper > 1 are still reachable. If p_scissors is
# derived as 1 - p_rock - p_paper, presumably joint_log_prob assigns such
# states zero probability (-inf log prob) — confirm.
unconstraining_bijectors = [
    tfp.bijectors.Sigmoid(),  # bijector for p_rock
    tfp.bijectors.Sigmoid(),  # bijector for p_paper
]
# Bind the actually observed totals into the joint log probability, so that the
# result depends only on the unknown model parameters (p_rock, p_paper). This
# function is the target log-probability that the HMC sampler draws from (the
# "_for_opt" suffix is historical; it is not used for optimization here).
# Wrapping it in tf.function with a fixed input signature of two scalar float32
# tensors converts it to a TensorFlow graph function for speed.
joint_log_prob_for_opt = tf.function(
    func=lambda x, y: joint_log_prob(
        total_rock=total_rock,
        total_paper=total_paper,
        total_scissors=total_scissors,
        p_rock=x,
        p_paper=y,
    ),
    input_signature=2 * (tf.TensorSpec(shape=[], dtype=tf.float32),),
)
# Build the Hamiltonian Monte Carlo (HMC) transition kernel by successively
# wrapping an inner kernel.
# 1) Innermost: plain HMC targeting the joint log probability of the
#    unknown parameters. The step size given here is only the initial value;
#    it is tuned by the adaptation wrapper in step 3.
kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=joint_log_prob_for_opt,
    num_leapfrog_steps=2,
    step_size=tf.constant(0.5, dtype=tf.float32),
    state_gradients_are_stopped=True,
)
# 2) Run HMC in unconstrained space; states are mapped to the constrained
#    (0, 1) parameter space through the Sigmoid bijectors.
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=kernel,
    bijector=unconstraining_bijectors,
)
# 3) Adapt the step size only during the first 80% of the burn-in steps, so
#    that adaptation has stopped before the samples that are kept are drawn.
kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=int(burnin * 0.8),
)
# Draw posterior samples: discard `burnin` steps, then keep `number_of_steps`
# samples of (p_rock, p_paper). Only the HMC acceptance flags are traced. They
# are reached by unwrapping the results of SimpleStepSizeAdaptation, then
# TransformedTransitionKernel, then the innermost HamiltonianMonteCarlo.
(posterior_p_rock, posterior_p_paper), kernel_results = tfp.mcmc.sample_chain(
    kernel=kernel,
    current_state=initial_chain_state,
    num_burnin_steps=burnin,
    num_results=number_of_steps,
    trace_fn=lambda _state, results: (
        results.inner_results.inner_results.is_accepted
    ),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment