Skip to content

Instantly share code, notes, and snippets.

@HGangloff
Created August 3, 2021 09:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save HGangloff/85987b1b80acde534c7f930bb1e9f91c to your computer and use it in GitHub Desktop.
Save HGangloff/85987b1b80acde534c7f930bb1e9f91c to your computer and use it in GitHub Desktop.
Efficient chromatic Gibbs sampler for a binary Ising Markov random field with JAX jit, vmap and lax.scan
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
from jax import vmap, jit
from jax.scipy.signal import convolve2d
def color_image_graph(lx, ly, neigh_list):
    """Greedily color the sites of an ``lx`` x ``ly`` grid graph.

    Two sites are adjacent when one is reached from the other by an offset
    ``[dx, dy]`` of ``neigh_list``, with periodic (toroidal) wrap-around on
    both axes. Sites are visited in row-major order and each one receives the
    smallest non-negative color not already held by its neighbours. For a
    symmetric offset list (every ``[dx, dy]`` paired with ``[-dx, -dy]``),
    no two adjacent sites end up with the same color.

    Args:
        lx, ly: Grid dimensions.
        neigh_list: Iterable of ``(dx, dy)`` integer offsets.

    Returns:
        ``np.ndarray`` of shape ``(lx, ly)`` holding the color of each site.
    """
    offsets = [(int(i), int(j)) for i, j in neigh_list]
    colored_graph = np.full((lx, ly), -1)
    for x in range(lx):
        for y in range(ly):
            # Modulo wrapping is correct for offsets of any sign and magnitude
            # (the former `0 + (i - 1)` trick only worked for |offset| <= 1).
            # Not-yet-colored neighbours contribute -1, which never blocks a color.
            taken = {colored_graph[(x + i) % lx, (y + j) % ly]
                     for i, j in offsets}
            color = 0
            while color in taken:
                color += 1
            colored_graph[x, y] = color
    return colored_graph
def gibbs_sampler_one_step(xy, sub_key, f_params):
    """Draw the binary label of a single site.

    Args:
        xy: Pair ``(x, y)`` giving the site position.
        sub_key: JAX PRNG key consumed by this draw.
        f_params: Dict whose ``"logits"`` entry holds two arrays indexable as
            ``[x, y]``; entry ``c`` is the unnormalised log-probability of
            label ``c``.

    Returns:
        Scalar integer array (0 or 1) drawn with ``jax.random.categorical``.
    """
    row, col = xy
    label_logits = f_params["logits"]
    site_logits = jnp.array([label_logits[0][row, col],
                             label_logits[1][row, col]])
    return jax.random.categorical(sub_key, site_logits)
def gibbs_sampler_on_color(update_sites, rng, f_params):
    """Sample, one after the other, every site listed in ``update_sites``.

    ``update_sites`` is transposed so that ``jax.lax.scan`` visits one
    ``(x, y)`` pair per step. It is therefore expected as two stacked
    coordinate rows, like the tuple returned by ``np.where`` on a 2-D mask.
    The PRNG key is split at every step and threaded through the scan carry.
    The logits in ``f_params`` are only read; they are not updated between
    sites.

    Returns:
        Array with one sampled label per site, in the order of the columns
        of ``update_sites``.
    """
    def _draw_site(carry_key, site):
        # First half of the split is carried to the next site, second is used now.
        next_key, draw_key = jax.random.split(carry_key)
        label = gibbs_sampler_one_step(site, draw_key, f_params)
        return next_key, label

    sites_per_step = update_sites.T
    _, labels = jax.lax.scan(_draw_site, rng, sites_per_step)
    return labels
# Batched, jit-compiled variant of ``gibbs_sampler_on_color``: axis 0 of
# ``update_sites`` and of the key array indexes the color, while ``f_params``
# is shared (in_axes=None, not batched). Every color in the batch is therefore
# sampled from the same logits within a single call.
vmap_gibbs_sampler_on_color = jit(
    vmap(gibbs_sampler_on_color, in_axes=(0, 0, None))
)
def chromatic_gibbs_sampler(lx, ly, key, init_field, f_params,
                            colored_graph=None, nb_it=100, neigh_list=8,
                            verbose=True):
    """Run a chromatic (graph-colored) Gibbs sampler on a binary Ising field.

    Sites are grouped by color so that no two sites of one color are
    neighbours. Each iteration visits the colors one after the other:

    1. The conditional logits are recomputed from the *current* field.
    2. Every site of the active color is resampled in one vectorised draw.
    3. The field is updated before the next color is processed.

    This sequential sweep over colors is what makes the scheme a valid Gibbs
    sampler. Resampling all colors at once from the same potentials would be
    a synchronous update, which does not target the Ising distribution.

    For a site with ``n1`` neighbours equal to 1 and ``n0`` neighbours equal
    to 0, the logits are ``2*beta*n0`` (label 0) and ``2*beta*n1`` (label 1).
    Neighbour counts use zero padding: positions outside the image are simply
    not counted (free boundary), while ``color_image_graph`` colors the grid
    as a torus. A toroidal coloring is also a valid coloring for the free
    boundary graph.

    Args:
        lx, ly: Field dimensions.
        key: JAX PRNG key.
        init_field: Initial binary field of shape ``(lx, ly)``; not modified.
        f_params: Dict with at least ``"beta"``; not modified.
        colored_graph: Optional ``(lx, ly)`` integer coloring. It must be a
            proper coloring for the chosen neighbourhood; it is computed
            with ``color_image_graph`` when ``None``.
        nb_it: Number of full sweeps over all colors.
        neigh_list: ``4``, ``8`` or an explicit list of ``[dx, dy]`` offsets.
            The same neighbourhood defines both the coloring and the model.
        verbose: Print the iteration number at each sweep.

    Returns:
        ``(field, colored_graph)`` as NumPy arrays; ``field`` keeps the dtype
        of ``init_field``.

    Raises:
        ValueError: If ``neigh_list`` is an int other than 4 or 8.
    """
    if neigh_list == 4:
        neigh_list = [[0, -1], [-1, 0], [0, 1], [1, 0]]
    elif neigh_list == 8:
        neigh_list = [[0, -1], [-1, -1], [-1, 0], [-1, 1],
                      [0, 1], [1, 1], [1, 0], [1, -1]]
    elif isinstance(neigh_list, int):
        raise ValueError(
            f"neigh_list must be 4, 8 or a list of offsets, got {neigh_list}")

    if colored_graph is None:
        colored_graph = color_image_graph(lx, ly, neigh_list)
    colored_graph = np.asarray(colored_graph)
    color_nb = int(colored_graph.max()) + 1
    # One boolean mask per color, shape (color_nb, lx, ly); works for
    # colors of unequal sizes (unlike a stacked np.where).
    masks = jnp.asarray(
        colored_graph[None, :, :] == np.arange(color_nb)[:, None, None])

    # Neighbourhood kernel built from the same offsets as the coloring.
    # Convolution flips the kernel, hence the (radius - i, radius - j) index.
    radius = max((max(abs(i), abs(j)) for i, j in neigh_list), default=0)
    kernel_np = np.zeros((2 * radius + 1, 2 * radius + 1), dtype=np.float32)
    for i, j in neigh_list:
        kernel_np[radius - i, radius - j] = 1.0
    kernel = jnp.asarray(kernel_np)
    two_beta = 2 * f_params["beta"]

    def ising_logits(field):
        """Return logits of shape (2, lx, ly) for labels 0 and 1."""
        ones = (field != 0).astype(kernel.dtype)
        # mode="same" keeps the output centred on each site; the default
        # "full" mode is shifted by `radius` and includes the site itself.
        n1 = convolve2d(ones, kernel, mode="same")
        n0 = convolve2d(1.0 - ones, kernel, mode="same")
        return two_beta * jnp.stack([n0, n1])

    @jit
    def sweep(field, sweep_key, color_masks):
        """Update every color once, in order, against the current field."""
        color_keys = jax.random.split(sweep_key, color_masks.shape[0])
        for c in range(color_masks.shape[0]):
            draw = jax.random.categorical(color_keys[c], ising_logits(field),
                                          axis=0)
            field = jnp.where(color_masks[c], draw.astype(field.dtype), field)
        return field

    init = np.asarray(init_field)
    field = jnp.asarray(init, dtype=jnp.int32)  # copy: caller's array untouched
    for k in range(nb_it):
        if verbose:
            print("Gibbs iterations number", k)
        key, sweep_key = jax.random.split(key)
        field = sweep(field, sweep_key, masks)
    return np.asarray(field).astype(init.dtype), colored_graph
if __name__ == "__main__":
    # Demo: sample a 512x512 Ising field and show it next to its coloring.
    lx, ly = 512, 512
    np.random.seed(0)
    init_field = np.random.randint(0, 2, size=(lx, ly))
    rng = jax.random.PRNGKey(1)
    rng, key = jax.random.split(rng)
    f_params = {"beta": .5}
    # Consume the freshly split subkey; `rng` stays available for other draws.
    ising_field, colored_graph = chromatic_gibbs_sampler(
        lx, ly, key, init_field, f_params,
        colored_graph=None, nb_it=100, neigh_list=8)
    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(colored_graph)
    axes[0].set_title("graph coloring")
    axes[1].imshow(ising_field)
    axes[1].set_title("Ising sample")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment