Skip to content

Instantly share code, notes, and snippets.

@maxentile
Last active March 27, 2020 14:50
Show Gist options
  • Save maxentile/5497a1a87cf51fb132962ca26462dc0a to your computer and use it in GitHub Desktop.
Save maxentile/5497a1a87cf51fb132962ca26462dc0a to your computer and use it in GitHub Desktop.
implementation of informed discrete proposals following [Zanella, 2017] https://arxiv.org/abs/1711.07424
# Implementation of balanced locally informed proposals, following:
# [Zanella, 2017] "Informed proposals for local MCMC in discrete spaces"
# https://arxiv.org/abs/1711.07424
# Usage: define a neighborhood_fxn(x) that returns a list of x's neighbors.
# Where you would have made uninformed proposals using this neighborhood, you
# can now make target-informed proposals using this neighborhood.
# Assumption: neighborhood_fxn is deterministic
# Assumption: if y is in neighborhood_fxn(x), then x is in neighborhood_fxn(y)
from functools import partial
import numpy as np
from scipy.special import logsumexp
from tqdm import tqdm
def uniform_proposal(x, neighborhood_fxn):
    """Uniform distribution over x's neighbors.

    Parameters
    ----------
    x
        Current state.
    neighborhood_fxn
        Callable mapping a state to an iterable of its neighbors.

    Returns
    -------
    neighbors : list
        The neighbors of x, materialized into a list.
    proposal_log_probs : np.ndarray
        One entry per neighbor, each equal to -log(number of neighbors).
    """
    # Materialize first: generators and other iterables have no len().
    neighbors = list(neighborhood_fxn(x))
    n_neighbors = len(neighbors)
    log_p_forward = -np.log(n_neighbors)
    proposal_log_probs = log_p_forward + np.zeros(n_neighbors)
    return neighbors, proposal_log_probs
# balancing function must satisfy g(t) = t g(1/t)
# some balancing functions from paper
def barker(t):
    """Barker balancing function, g(t) = t / (1 + t)."""
    denominator = 1 + t
    return t / denominator
# Balancing functions g implemented here; both satisfy g(t) = t g(1/t).
# Not listed: min(1, t), max(1, t), ...
balancing_functions = [np.sqrt, barker]
def informed_proposal(x, proposal_distribution, log_target, balancing_fxn=barker):
    """Reweight proposal_distribution over x's neighbors by
    balancing_fxn(target(proposal) / target(x)).

    Parameters
    ----------
    x
        Current state.
    proposal_distribution
        Callable mapping a state to (neighbors, proposal_log_probs).
    log_target
        Callable mapping a state to its (unnormalized) log target value.
    balancing_fxn
        Callable applied to each scalar ratio target(neighbor) / target(x).

    Returns
    -------
    neighbors
        The neighbors exactly as returned by proposal_distribution.
    informed_log_probs : np.ndarray
        Normalized log probabilities of the reweighted proposal. This is a
        new array; the array returned by proposal_distribution is not modified.
    """
    neighbors, base_log_probs = proposal_distribution(x)

    # Work on a float copy so the base proposal's own array is never mutated
    # (it may be cached or shared), and so list / integer inputs also work.
    informed_log_probs = np.array(base_log_probs, dtype=float)

    log_p_x = log_target(x)
    for i, neighbor in enumerate(neighbors):
        ratio = np.exp(log_target(neighbor) - log_p_x)
        weight = balancing_fxn(ratio)
        # A zero weight gives log(0) = -inf, which is the intended value for
        # a neighbor that should never be proposed; silence only that warning.
        with np.errstate(divide='ignore'):
            # note: now unnormalized!
            informed_log_probs[i] += np.log(weight)

    # normalize
    informed_log_probs -= logsumexp(informed_log_probs)
    return neighbors, informed_log_probs
def sample_proposal(x, proposal_distribution):
    """Draw one neighbor of x from the proposal.

    Returns the sampled neighbor together with the log probability the
    proposal assigned to it (the forward log proposal probability).
    """
    candidates, log_probs = proposal_distribution(x)
    candidate_inds = np.arange(len(candidates))
    probs = np.exp(log_probs)
    chosen = np.random.choice(candidate_inds, p=probs)
    return candidates[chosen], log_probs[chosen]
def mh_step(x, proposal_distribution, log_target):
    """Take one Metropolis-Hastings step from state x.

    Parameters
    ----------
    x
        Current state.
    proposal_distribution
        Callable mapping a state to (neighbors, proposal_log_probs).
    log_target
        Callable mapping a state to its (unnormalized) log target value.

    Returns
    -------
    (next_state, accepted)
        (proposal, True) if the proposal was accepted, otherwise (x, False).

    Raises
    ------
    ValueError
        If x is not among the neighbors of the sampled proposal, so the
        reverse proposal probability cannot be looked up.
    """
    # sample forward proposal (proposal | x)
    proposal, forward_log_prob = sample_proposal(x, proposal_distribution)

    # look up reverse proposal probability (x | proposal)
    reverse_neighbors, reverse_log_probs = proposal_distribution(proposal)
    try:
        # list(...) so any iterable of neighbors works, not only those
        # that provide an .index method
        reverse_ind = list(reverse_neighbors).index(x)
    except ValueError as err:
        # explicit raise rather than assert: asserts are stripped under -O
        raise ValueError(
            f'{x!r} is not a neighbor of its proposal {proposal!r}; '
            'the neighborhood must be symmetric'
        ) from err
    reverse_log_prob = reverse_log_probs[reverse_ind]

    # compute acceptance probability
    log_target_ratio = log_target(proposal) - log_target(x)
    log_p_forward_over_reverse = forward_log_prob - reverse_log_prob
    log_accept_ratio = log_target_ratio - log_p_forward_over_reverse
    # Clamp in log space so np.exp never overflows. Argument order matters:
    # min(0.0, nan) is 0.0, which keeps A == 1 for a nan ratio, the same
    # value min(1, np.exp(nan)) produces.
    A = np.exp(min(0.0, log_accept_ratio))

    # accept / reject
    if np.random.rand() < A:
        return proposal, True
    return x, False
def run_mh(x0, proposal_distribution, log_target, n_steps=10000):
    """Run n_steps Metropolis-Hastings steps starting from x0.

    Returns the visited trajectory (a list beginning with x0, one entry
    appended per step) and the fraction of steps that were accepted.
    """
    trajectory = [x0]
    accepted_count = 0
    current = x0
    for _ in tqdm(range(n_steps)):
        current, was_accepted = mh_step(current, proposal_distribution, log_target)
        trajectory.append(current)
        if was_accepted:
            accepted_count += 1
    return trajectory, accepted_count / n_steps
if __name__ == '__main__':

    # a gaussian log pdf, restricted to integers [-7, ..., +7]
    def log_target(x):
        """Return -(x / 3)**2 for integers with |x| <= 7, else -inf."""
        # isinstance rather than an exact type comparison, so numpy integer
        # states are not silently assigned zero probability
        if not isinstance(x, (int, np.integer)) or abs(x) > 7:
            return -np.inf
        return -(x / 3) ** 2

    # neighborhood: successor and predecessor
    def int_neighbors(x: int):
        """Return the predecessor and successor of x."""
        return [x - 1, x + 1]

    # define naive proposal
    naive_int_proposal = partial(
        uniform_proposal,
        neighborhood_fxn=int_neighbors,
    )

    # define informed proposal
    informed_int_proposal = partial(
        informed_proposal,
        proposal_distribution=naive_int_proposal,
        log_target=log_target,
        balancing_fxn=barker,
    )

    # run some mcmc
    naive_traj, naive_accept_rate = run_mh(0, naive_int_proposal, log_target)
    informed_traj, informed_accept_rate = run_mh(0, informed_int_proposal, log_target)

    # report acceptance rates
    print(f'naive_accept_rate: {naive_accept_rate}')
    print(f'informed_accept_rate: {informed_accept_rate}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment