Skip to content

Instantly share code, notes, and snippets.

@antotocar34
Last active March 7, 2023 17:55
Show Gist options
  • Save antotocar34/3e6a762df1427a7db6105cfe72f66185 to your computer and use it in GitHub Desktop.
Save antotocar34/3e6a762df1427a7db6105cfe72f66185 to your computer and use it in GitHub Desktop.
import jax
from jax import vmap, grad
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_structure
import functools
from typing import Callable, NamedTuple, Any, Dict, List, Tuple
from blackjax.types import Array, PyTree
from blackjax.base import Optimizer
# Type aliases. Neither is referenced by the code shown here; presumably kept
# for parity with other blackjax modules — TODO confirm, or remove if unused.
Position = PyTree
State = NamedTuple
from jax.flatten_util import ravel_pytree
class SVGDState(NamedTuple):
    """State carried from one Stein Variational Gradient Descent step to the next.

    `step` unpacks this tuple positionally as
    ``particles, kernel_params, opt_state``, so the field order is significant.
    """

    # Current set of particles, as a PyTree.
    particles: PyTree
    # Keyword arguments bound to the kernel function on every step.
    kernel_parameters: Dict[str, Any]
    # Optimizer state, as produced by `optimizer.init` / `optimizer.update`.
    opt_state: Any
def init(initial_particles: PyTree, kernel_parameters: Dict[str, Any], optimizer: Optimizer) -> SVGDState:
    """
    Initializes Stein Variational Gradient Descent Algorithm.

    Parameters
    ----------
    initial_particles
        Initial set of particles to start the optimization. The array
        conversion below supports PyTrees whose leaves each hold one scalar
        per particle, i.e. leaves of shape ``(num_particles,)`` (or
        ``(num_particles, 1)``).
    kernel_parameters
        Keyword arguments passed to the kernel function on every step
    optimizer
        Optax compatible optimizer, which conforms to the `Optimizer` protocol

    Returns
    -------
    SVGDState holding `initial_particles` unchanged, `kernel_parameters`, and
    the optimizer state created for the flattened particle array.
    """
    # The optimizer works on a flat array of shape (num_particles, num_leaves).
    # `.squeeze()` drops every size-1 axis, so with a single leaf the array is
    # 1-D, (num_particles,). `step` flattens the particles in exactly the same
    # way before calling `optimizer.update`, so the two layouts must agree.
    leaves, _ = tree_flatten(initial_particles)
    particle_array = jnp.stack(leaves).squeeze().T
    opt_state = optimizer.init(particle_array)
    return SVGDState(initial_particles, kernel_parameters, opt_state)
def step(state: SVGDState, logdensity_fn: Callable, kernel: Callable, optimizer: Optimizer) -> SVGDState:
    """
    Performs one step of Stein Variational Gradient Descent.

    See Algorithm 1 of [1]. Every particle x_i is moved along the functional
    gradient

        phi*(x_i) = 1/n * sum_j [ k(x_j, x_i) * grad log p(x_j) + grad_{x_j} k(x_j, x_i) ]

    which is fed to `optimizer` to produce the actual update.

    Parameters
    ----------
    state
        SVGDState object containing information about previous iteration
    logdensity_fn
        (un-normalized) log density function of the target distribution to take
        approximate samples from. It is called on a single particle: a PyTree
        with the same structure as `state.particles`, one scalar per leaf.
    kernel
        positive semi-definite kernel, called as
        ``kernel(x, y, **state.kernel_parameters)`` and differentiated with
        respect to its first argument, so it must return a scalar.
    optimizer
        Optax compatible optimizer, which conforms to the `Optimizer` protocol

    Returns
    -------
    SVGDState containing new particles, optimizer state and kernel parameters.
    The new particles have the same PyTree structure and leaf shapes as
    `state.particles`.

    Notes
    -----
    The array conversion only supports particle PyTrees whose leaves hold one
    scalar per particle, i.e. leaves of shape ``(num_particles,)`` (or
    ``(num_particles, 1)``).

    References
    ----------
    .. [1]: Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm
        Qiang Liu et al., arXiv:1608.04471
    """
    particles, kernel_params, opt_state = state
    kernel = functools.partial(kernel, **kernel_params)

    # Destructure particle PyTree into an array of shape (num_particles, num_leaves).
    # `.squeeze()` drops size-1 axes, so with a single leaf the array is 1-D,
    # (num_particles,). This layout must match the one `init` used to build
    # `opt_state`.
    leaves, pytree_def = tree_flatten(particles)
    particle_array = jnp.stack(leaves).squeeze().T
    num_particles = particle_array.shape[0]

    # Redefine logdensity so that it accepts a single particle in array form:
    # iterating the 1-D array yields one scalar per leaf of the particle PyTree.
    single_particle = tree_map(lambda array: array[0], particles)
    single_particle_structure = tree_structure(single_particle)

    def log_p(x):
        return logdensity_fn(tree_unflatten(single_particle_structure, jnp.atleast_1d(x)))

    def phi_star_summand(sum_index, arg_index, particle_array, gradients):
        # Contribution of particle x_j (j = sum_index) to phi*(x_i) (i = arg_index).
        # grad_k is the kernel gradient w.r.t. its first argument, x_j.
        x_j = particle_array[sum_index]
        x = particle_array[arg_index]
        k, grad_k = jax.value_and_grad(kernel, argnums=0)(x_j, x)
        return (k * gradients[sum_index]) + grad_k

    def phi_star_i(arg_index, particle_array, gradients):
        # phi*(x_i): mean of the summands over every particle j.
        partialled_fn = functools.partial(
            phi_star_summand, arg_index=arg_index, particle_array=particle_array, gradients=gradients
        )
        return vmap(partialled_fn)(jnp.arange(num_particles)).mean(0)

    # Precompute grad log p at every particle once per step instead of inside
    # each of the num_particles**2 summands. TODO: Allow user to use pmap?
    gradients = vmap(grad(log_p))(particle_array)
    phi_star = vmap(functools.partial(phi_star_i, particle_array=particle_array, gradients=gradients))
    functional_gradient = phi_star(jnp.arange(num_particles))

    # NOTE(review): under the optax convention, `update` is already scaled by
    # -learning_rate (it is meant to be *added* to the parameters), so
    # subtracting it moves the particles along +phi*, the SVGD ascent
    # direction. Confirm this holds for the `Optimizer` protocol in use.
    update, opt_state = optimizer.update(functional_gradient, opt_state)
    new_particle_array = particle_array - update

    # Rebuild the particle PyTree instead of returning the raw transposed
    # array: otherwise `state.particles` changes structure after one step, so
    # the next call hands `logdensity_fn` the wrong structure (or fails in
    # `tree_unflatten`), and the state cannot be used as a `lax.scan` carry.
    # Undo `.squeeze().T`: one row per leaf, then restore each leaf's shape.
    leaf_rows = jnp.reshape(new_particle_array.T, (len(leaves), -1))
    new_leaves = [row.reshape(jnp.shape(leaf)) for row, leaf in zip(leaf_rows, leaves)]
    new_particles = tree_unflatten(pytree_def, new_leaves)

    return SVGDState(new_particles, kernel_params, opt_state)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment