Created
August 3, 2021 09:50
-
-
Save HGangloff/85987b1b80acde534c7f930bb1e9f91c to your computer and use it in GitHub Desktop.
Efficient chromatic Gibbs sampler for a binary Ising Markov random field, implemented with JAX `jit`, `vmap` and `lax.scan`.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import jax.numpy as jnp | |
import jax | |
from jax import vmap, jit | |
from jax.scipy.signal import convolve2d | |
def color_image_graph(lx, ly, neigh_list):
    """Greedily colour the sites of an ``lx`` x ``ly`` periodic lattice.

    Sites are visited in row-major order and each one receives the smallest
    non-negative colour that none of its already-coloured neighbours uses.
    With a symmetric ``neigh_list`` this yields a colouring in which two
    neighbouring sites never share a colour.

    Args:
        lx: Number of rows of the lattice.
        ly: Number of columns of the lattice.
        neigh_list: Iterable of ``(di, dj)`` integer offsets describing the
            neighbourhood of a site.

    Returns:
        Integer ``numpy`` array of shape ``(lx, ly)`` with the colour index
        of every site.
    """
    def first_available(used_colors):
        # Smallest non-negative integer absent from ``used_colors``; the
        # ``-1`` marker of not-yet-coloured sites is never a candidate.
        n = 0
        while n in used_colors:
            n += 1
        return n

    colored_graph = np.full((lx, ly), -1)
    for x in range(lx):
        for y in range(ly):
            # Periodic (toroidal) wrapping. The modulo is correct for offsets
            # of any sign and magnitude, not only for -1/0/+1.
            neigh_colors = {
                int(colored_graph[(x + i) % lx, (y + j) % ly])
                for i, j in neigh_list
            }
            colored_graph[x, y] = first_available(neigh_colors)
    return colored_graph
def gibbs_sampler_one_step(xy, sub_key, f_params):
    """Draw the new state of a single site.

    Args:
        xy: Pair ``(x, y)`` of indices of the site to sample.
        sub_key: JAX PRNG key consumed by this draw.
        f_params: Mapping whose ``"logits"`` entry holds two arrays; array
            ``k`` indexed at ``[x, y]`` is the logit of state ``k`` there.

    Returns:
        The sampled state index (0 or 1), as returned by
        ``jax.random.categorical``.
    """
    row, col = xy
    # One logit per candidate state, gathered at the requested site.
    site_logits = jnp.array(
        [f_params["logits"][state][row, col] for state in (0, 1)]
    )
    return jax.random.categorical(sub_key, site_logits)
def gibbs_sampler_on_color(update_sites, rng, f_params):
    """Sample every site of one colour class with ``jax.lax.scan``.

    Args:
        update_sites: Array whose transpose is scanned along its leading
            axis; each scanned element is handed to
            ``gibbs_sampler_one_step`` as the ``(x, y)`` pair of a site.
        rng: JAX PRNG key; it is split once per scanned site.
        f_params: Forwarded unchanged to ``gibbs_sampler_one_step``.

    Returns:
        The per-site samples stacked in scan order.
    """
    def draw_site(carry_key, site):
        # The first half of the split is carried to the next site, the
        # second half is consumed by the draw at the current site.
        next_key, site_key = jax.random.split(carry_key)
        return next_key, gibbs_sampler_one_step(site, site_key, f_params)

    _final_key, samples = jax.lax.scan(draw_site, rng, update_sites.T)
    return samples
# Batched, compiled version of the per-colour sampler: axis 0 of the site
# array and axis 0 of the key array are mapped together (one entry per
# colour), while f_params is shared by every mapped call.
vmap_gibbs_sampler_on_color = jax.jit(
    jax.vmap(gibbs_sampler_on_color, in_axes=(0, 0, None))
)
def chromatic_gibbs_sampler(lx, ly, key, init_field, f_params,
                            colored_graph=None, nb_it=100, neigh_list=8):
    """Run a chromatic Gibbs sampler on a binary Ising field.

    Each sweep visits the colour classes one after the other. The sites of
    one colour are sampled together from conditionals computed on the
    current field, and the field is updated before the next colour is
    processed, so every draw is conditioned on up-to-date neighbours.

    Args:
        lx: Number of rows (used to build the colouring when none is given).
        ly: Number of columns (used to build the colouring when none is given).
        key: JAX PRNG key driving all the draws.
        init_field: Initial binary field; it is copied, never modified.
        f_params: Mapping with at least a ``"beta"`` entry (interaction
            strength). It is not modified; a copy extended with the
            ``"logits"`` entry is passed to the sampler.
        colored_graph: Optional precomputed colour index of every site. When
            ``None`` it is computed with ``color_image_graph``.
        nb_it: Number of full sweeps.
        neigh_list: ``4`` or ``8`` for the usual neighbourhoods, or an
            explicit iterable of ``(di, dj)`` offsets.

    Returns:
        Tuple ``(field, colored_graph)`` of ``numpy`` arrays: the field after
        ``nb_it`` sweeps and the colouring that was used.

    Raises:
        ValueError: If ``neigh_list`` is an integer other than 4 or 8.
    """
    def neighbourhood_kernel(offsets):
        # Convolution kernel that sums the field over the given offsets.
        radius = max((max(abs(i), abs(j)) for i, j in offsets), default=0)
        kernel = np.zeros((2 * radius + 1, 2 * radius + 1))
        for i, j in offsets:
            # A convolution mirrors its kernel, so the weight of offset
            # (i, j) is stored at the mirrored position.
            kernel[radius - i, radius - j] = 1
        return kernel

    def get_ising_potentials(current_field, f_params):
        # Per-site logits of states 0 and 1: 2 * beta * (number of
        # neighbours currently in that state).
        beta = f_params["beta"]
        ones = jnp.asarray(current_field).astype(bool)
        # mode="same" keeps the output aligned with the field; the default
        # "full" mode returns a larger array shifted by the kernel radius.
        # convolve2d zero-pads, so sites outside the lattice count for
        # neither state (free boundary in the potentials).
        nb_nei_1 = convolve2d(ones.astype(int), kernel, mode="same")
        nb_nei_0 = convolve2d(jnp.logical_not(ones).astype(int), kernel,
                              mode="same")
        # Return a new mapping so the caller's f_params stays untouched.
        return {**f_params, "logits": [nb_nei_0 * 2 * beta,
                                       nb_nei_1 * 2 * beta]}

    if isinstance(neigh_list, (int, np.integer)):
        if neigh_list == 4:
            neigh_list = [[0, -1], [-1, 0], [0, 1], [1, 0]]
        elif neigh_list == 8:
            neigh_list = [[0, -1], [-1, -1], [-1, 0], [-1, 1],
                          [0, 1], [1, 1], [1, 0], [1, -1]]
        else:
            raise ValueError(
                "neigh_list must be 4, 8 or an iterable of (di, dj) offsets")
    # The interaction kernel follows the same neighbourhood as the colouring,
    # otherwise sites sharing a colour could still interact.
    kernel = neighbourhood_kernel(neigh_list)

    if colored_graph is None:
        colored_graph = color_image_graph(lx, ly, neigh_list)
    colored_graph = np.asarray(colored_graph)
    color_nb = int(colored_graph.max()) + 1

    # Work on a private copy: the caller's initial field must not change.
    current_field = np.array(init_field)
    # One (2, n_sites) index array per colour; kept as a list so colour
    # classes are allowed to have different sizes.
    update_sites = [np.stack(np.nonzero(colored_graph == c))
                    for c in range(color_nb)]

    for k in range(nb_it):
        print("Gibbs iterations number", k)
        keys = jax.random.split(key, num=color_nb + 1)
        key = keys[0]
        for c, sites in enumerate(update_sites):
            if sites.shape[1] == 0:
                # Colour index not used by any site: nothing to sample.
                continue
            # Recomputed for every colour so that the draws see the colours
            # already refreshed during this sweep.
            step_params = get_ising_potentials(current_field, f_params)
            # The batched sampler expects a leading colour axis; feed it a
            # batch of exactly one colour.
            samples = vmap_gibbs_sampler_on_color(sites[None],
                                                  keys[c + 1][None],
                                                  step_params)
            current_field[sites[0], sites[1]] = np.asarray(samples[0])
    return current_field, colored_graph
if __name__ == "__main__":
    # Demo: sample a 512 x 512 Ising field from a random binary start and
    # display the site colouring next to the sampled field.
    lx = ly = 512
    np.random.seed(0)
    init_field = np.random.randint(0, 2, size=(lx, ly))

    rng, key = jax.random.split(jax.random.PRNGKey(1))
    f_params = {"beta": .5}
    ising_field, colored_graph = chromatic_gibbs_sampler(
        lx, ly, rng, init_field, f_params,
        colored_graph=None, nb_it=100, neigh_list=8)

    fig, axes = plt.subplots(1, 2)
    for ax, image in zip(axes, (colored_graph, ising_field)):
        ax.imshow(image)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment