Last active
December 11, 2019 05:48
-
-
Save Allgoerithm/8225b69f33aff56e56e47687dfc9e6bb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Observed data: total count recorded for each move.
total_rock = tf.constant(5., tf.float32)
total_paper = tf.constant(0., tf.float32)
total_scissors = tf.constant(0., tf.float32)

# Sampler settings.
number_of_steps = 10000  # number of samples kept after burn-in
burnin = 5000  # number of initial samples discarded
# Fraction of the burn-in phase during which the step size is adapted. Adaptation
# should end before burn-in does, so the last burn-in samples (and all kept
# samples) are drawn with a fixed step size.
adaptation_fraction = 0.8

# Start state of the chain: both probabilities start at 1/3 (scalar float32 tensors).
initial_chain_state = [
    1 / 3 * tf.ones([], dtype=tf.float32, name="init_p_rock"),
    1 / 3 * tf.ones([], dtype=tf.float32, name="init_p_paper"),
]

# HMC moves on unconstrained real numbers. Each Sigmoid bijector maps the real line
# onto the constrained parameter space (0, 1) of one parameter.
# NOTE(review): p_rock and p_paper are constrained independently, so a state with
# p_rock + p_paper > 1 is reachable; presumably joint_log_prob assigns such states
# zero probability (-inf log prob) — confirm.
unconstraining_bijectors = [
    tfp.bijectors.Sigmoid(),  # bijector for p_rock
    tfp.bijectors.Sigmoid(),  # bijector for p_paper
]

# Fix the data arguments of the joint log prob to the actually observed values, so the
# resulting function depends only on the unknown model parameters. Compile it with
# tf.function (fixed scalar float32 signature) for speed.
joint_log_prob_for_opt = tf.function(
    func=lambda p_rock, p_paper: joint_log_prob(
        total_rock=total_rock,
        total_paper=total_paper,
        total_scissors=total_scissors,
        p_rock=p_rock,
        p_paper=p_paper,
    ),
    input_signature=2 * (tf.TensorSpec(shape=[], dtype=tf.float32),),
)

# Build the Hamiltonian Monte Carlo kernel by successively wrapping an inner kernel:
# plain HMC -> sampling in unconstrained space -> step-size adaptation.
kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=joint_log_prob_for_opt,
    num_leapfrog_steps=2,
    step_size=tf.constant(0.5, dtype=tf.float32),
    state_gradients_are_stopped=True,
)
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=kernel,
    bijector=unconstraining_bijectors,
)
kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=int(burnin * adaptation_fraction),
)

# Sample from the chain. The trace_fn digits through the two wrapper layers
# (step-size adaptation -> transformed kernel -> HMC) to record whether each HMC
# proposal was accepted. Despite its name, `kernel_results` therefore holds only
# these traced `is_accepted` values, not the full kernel results.
(posterior_p_rock, posterior_p_paper), kernel_results = tfp.mcmc.sample_chain(
    num_results=number_of_steps,
    num_burnin_steps=burnin,
    current_state=initial_chain_state,
    trace_fn=lambda _, results: results.inner_results.inner_results.is_accepted,
    kernel=kernel,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment