import functools
from typing import Any, Callable, Dict, NamedTuple

import jax
import jax.numpy as jnp
from jax import grad, vmap
from jax.tree_util import tree_flatten, tree_map, tree_structure, tree_unflatten

from blackjax.base import Optimizer
from blackjax.types import PyTree
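
# For reference, the empirical Stein variational gradient from Algorithm 1 of
# Liu & Wang (arXiv:1608.04471), which `step` below computes for each particle:
#
#   phi_star(x) = (1/n) * sum_j [ k(x_j, x) * grad_{x_j} log p(x_j)
#                                 + grad_{x_j} k(x_j, x) ]
#
# Each particle is then moved along this direction by the optimizer.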

class SVGDState(NamedTuple):
    particles: PyTree
    kernel_parameters: Dict[str, Any]
    opt_state: Any

def init(
    initial_particles: PyTree,
    kernel_parameters: Dict[str, Any],
    optimizer: Optimizer,
) -> SVGDState:
    """Initialize the Stein Variational Gradient Descent algorithm.

    Parameters
    ----------
    initial_particles
        Initial set of particles to start the optimization from.
    kernel_parameters
        Arguments to the kernel function.
    optimizer
        Optax-compatible optimizer, which conforms to the `Optimizer` protocol.
    """
    # Flatten the particle PyTree into an array of shape
    # (num_particles, param_dimension) so the optimizer can track it.
    particle_array = jnp.stack(tree_flatten(initial_particles)[0]).squeeze().T
    opt_state = optimizer.init(particle_array)
    return SVGDState(initial_particles, kernel_parameters, opt_state)
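
# Illustrative only: a simple RBF kernel of the kind `init` and `step` expect.
# The name `rbf_kernel` and its `bandwidth` parameter are assumptions of this
# sketch, not part of blackjax; `bandwidth` would be passed in through
# `kernel_parameters`.
def rbf_kernel(x, y, bandwidth=1.0):
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2 * bandwidth**2))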

def step(
    state: SVGDState,
    logdensity_fn: Callable,
    kernel: Callable,
    optimizer: Optimizer,
) -> SVGDState:
    """Perform one step of Stein Variational Gradient Descent.

    See Algorithm 1 of [1].

    Parameters
    ----------
    state
        SVGDState object containing information about the previous iteration.
    logdensity_fn
        (Un-normalized) log density function of the target distribution to
        take approximate samples from.
    kernel
        Positive semi-definite kernel function.
    optimizer
        Optax-compatible optimizer, which conforms to the `Optimizer` protocol.

    Returns
    -------
    SVGDState containing the new particles, kernel parameters and optimizer
    state.

    References
    ----------
    .. [1]: Qiang Liu, Dilin Wang. "Stein Variational Gradient Descent: A
            General Purpose Bayesian Inference Algorithm." arXiv:1608.04471.
    """
    particles, kernel_params, opt_state = state
    kernel = functools.partial(kernel, **kernel_params)

    # Destructure the particle PyTree into an array of shape
    # (num_particles, param_dimension).
    _particle_array, pytree_def = tree_flatten(particles)
    particle_array = jnp.stack(_particle_array).squeeze().T
    num_particles = particle_array.shape[0]

    # Redefine the log density so that it accepts a single particle in array
    # form, by unflattening it back into the particle PyTree structure.
    single_particle = tree_map(lambda array: array[0], particles)
    single_particle_structure = tree_structure(single_particle)
    log_p = lambda x: logdensity_fn(
        tree_unflatten(single_particle_structure, jnp.atleast_1d(x))
    )

    def phi_star_summand(sum_index, arg_index, particle_array, gradients):
        # One term of the sum in Algorithm 1: k(x_j, x) * grad log p(x_j)
        # plus the gradient of the kernel with respect to x_j.
        x_j = particle_array[sum_index]
        x = particle_array[arg_index]
        k, grad_k = jax.value_and_grad(kernel, argnums=0)(x_j, x)
        return (k * gradients[sum_index]) + grad_k

    def phi_star_i(arg_index, particle_array, gradients):
        # Average the summands over all particles to obtain phi_star(x_i).
        partialled_fn = functools.partial(
            phi_star_summand,
            arg_index=arg_index,
            particle_array=particle_array,
            gradients=gradients,
        )
        return vmap(partialled_fn)(jnp.arange(num_particles)).mean(0)

    # Precompute the score grad log p at every particle for this step.
    # TODO Allow the user to use pmap?
    gradients = vmap(grad(log_p))(particle_array)
    phi_star = vmap(
        functools.partial(phi_star_i, particle_array=particle_array, gradients=gradients)
    )
    functional_gradient = phi_star(jnp.arange(num_particles))
    update, opt_state = optimizer.update(functional_gradient, opt_state)
    # Optax updates are built to be *added* for gradient descent; subtracting
    # them instead moves the particles along phi_star (gradient ascent), as
    # SVGD requires.
    new_particle_array = particle_array - update
    # Rebuild the particle PyTree from the updated array.
    # TODO make this work with different kinds of particle PyTrees
    new_particles = tree_unflatten(
        pytree_def, list(jnp.atleast_2d(new_particle_array.T))
    )
    return SVGDState(new_particles, kernel_params, opt_state)
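

# A minimal end-to-end usage sketch. Assumptions: the standard-normal target,
# the `rbf_kernel` defined above and `optax.sgd` are illustrative choices, not
# prescribed by this file; any optax optimizer satisfies the `Optimizer`
# protocol used here.
if __name__ == "__main__":
    import optax

    key = jax.random.PRNGKey(0)
    # 50 particles, stored as a PyTree with a single (num_particles,) leaf.
    initial_particles = {"x": jax.random.normal(key, (50,))}

    def logdensity_fn(particle):
        # Standard normal target, up to an additive constant.
        return -0.5 * jnp.sum(particle["x"] ** 2)

    optimizer = optax.sgd(1e-1)
    state = init(initial_particles, {"bandwidth": 1.0}, optimizer)
    for _ in range(100):
        state = step(state, logdensity_fn, rbf_kernel, optimizer)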