Skip to content

Instantly share code, notes, and snippets.

@KeAWang
Last active April 10, 2023 18:48
Show Gist options
  • Save KeAWang/04b4f9ce2235d1767d6587005906a555 to your computer and use it in GitHub Desktop.
Save KeAWang/04b4f9ce2235d1767d6587005906a555 to your computer and use it in GitHub Desktop.
Parameterized Arrays in JAX
# %%
import jax.numpy as jnp
import jax
import equinox as eqx
from typing import Union, Any
from abc import ABC, abstractmethod
# Type alias: either a concrete jax.Array or a ParameterizedArray whose
# eval() produces one. The forward reference is quoted because the class is
# defined below this line.
MaybeParameterizedArray = Union[jax.Array, "ParameterizedArray"]
class ParameterizedArray(ABC, eqx.Module):
    """Abstract base for objects that stand in for a ``jax.Array``.

    A subclass stores its own state as module fields and implements
    :meth:`eval` to produce the concrete array it represents.
    """

    @abstractmethod
    def eval(self) -> jax.Array:
        """Return the concrete array represented by this object.

        Raises:
            NotImplementedError: always, in this base implementation;
                subclasses must override it.
        """
        raise NotImplementedError

    @staticmethod
    def tree_eval(pytree: Any) -> Any:
        """Replace every ParameterizedArray found in ``pytree`` by its ``eval()``.

        ParameterizedArray instances are treated as leaves, so the traversal
        stops at the outermost one and calls its ``eval()``; any other leaf is
        returned unchanged. Evaluating ParameterizedArrays nested *inside*
        another ParameterizedArray is the job of that object's own ``eval()``.

        Args:
            pytree: Any pytree, possibly containing ParameterizedArray objects.

        Returns:
            A pytree with the same outer structure in which each
            ParameterizedArray has been replaced by the result of ``eval()``.
        """

        def _is_parameterized(x: Any) -> bool:
            return isinstance(x, ParameterizedArray)

        def _eval(x: Any) -> Any:
            return x.eval() if _is_parameterized(x) else x

        # Use the stable jax.tree_util spelling: the top-level ``jax.tree_map``
        # alias was deprecated and is not available in newer JAX releases.
        return jax.tree_util.tree_map(_eval, pytree, is_leaf=_is_parameterized)
class PositiveArray(ParameterizedArray):
    """Array kept positive by storing its inverse-softplus and applying softplus on eval."""

    # Value in unconstrained space; softplus of it is the represented array.
    unconstrained_array: MaybeParameterizedArray

    @staticmethod
    def inv_softplus(x: jax.Array) -> jax.Array:
        """Return the inverse of softplus at ``x``.

        Computes ``x + log(1 - exp(-x))``, which equals ``log(exp(x) - 1)``.
        The result is ``-inf`` at ``x == 0`` and ``nan`` for ``x < 0``, so
        ``x`` is expected to be strictly positive.
        """
        return x + jnp.log(-jnp.expm1(-x))

    def __init__(self, val: jax.Array):
        """Store ``val`` in unconstrained space.

        Args:
            val: The desired (positive) value of the array.
        """
        self.unconstrained_array = PositiveArray.inv_softplus(val)

    def eval(self) -> jax.Array:
        """Return ``softplus`` of the (evaluated) unconstrained array."""
        unconstrained_array = ParameterizedArray.tree_eval(self.unconstrained_array)
        return jax.nn.softplus(unconstrained_array)

    # The attribute name was previously misspelled ("unconstrainted_array"),
    # which made both properties raise AttributeError.
    @property
    def shape(self):
        """Shape of the underlying unconstrained array."""
        return self.unconstrained_array.shape

    @property
    def dtype(self):
        """Dtype of the underlying unconstrained array."""
        return self.unconstrained_array.dtype
class MaybeLearnableArray(ParameterizedArray):
    """Wraps an array and blocks its gradient unless ``learnable`` is true."""

    array: MaybeParameterizedArray
    # Declared as a static field, i.e. not a differentiable leaf.
    learnable: bool = eqx.static_field()

    def eval(self) -> jax.Array:
        """Return the evaluated array, passed through ``stop_gradient`` when not learnable."""

        def _passthrough(value):
            return value

        def _frozen(value):
            return jax.lax.stop_gradient(value)

        evaluated = ParameterizedArray.tree_eval(self.array)
        return jax.lax.cond(self.learnable, _passthrough, _frozen, evaluated)

    @property
    def shape(self):
        """Shape reported by the wrapped array."""
        return self.array.shape

    @property
    def dtype(self):
        """Dtype reported by the wrapped array."""
        return self.array.dtype
class PartiallyLearnableArray(ParameterizedArray):
    """Array whose gradient flows only through the entries selected by a mask."""

    array: MaybeParameterizedArray
    # 1 if learnable, 0 if not.
    # NOTE(review): this array is declared as a static field; static fields
    # are presumably expected to be hashable metadata rather than arrays --
    # confirm this works under jit for the intended use.
    learnable_mask: jax.Array = eqx.static_field()

    def eval(self) -> jax.Array:
        """Return the evaluated array with gradients stopped where the mask is 0.

        Both branches of the ``where`` hold the same values; they differ only
        in whether a gradient can flow back through them.
        """
        array = ParameterizedArray.tree_eval(self.array)
        # Stop the gradient of the *evaluated* array: ``self.array`` may itself
        # be a ParameterizedArray, which ``jnp.where`` cannot consume.
        return jnp.where(self.learnable_mask, array, jax.lax.stop_gradient(array))

    @property
    def shape(self):
        """Shape reported by the wrapped array."""
        return self.array.shape

    @property
    def dtype(self):
        """Dtype reported by the wrapped array."""
        return self.array.dtype
class BoundedArray(ParameterizedArray):
    """Array confined to ``[lb, ub]`` via a scaled, clipped sigmoid.

    The value is stored as a logit in unconstrained space; ``eval()`` maps it
    back with ``lb + sigmoid(u) * (ub - lb)``.
    """

    unconstrained_array: MaybeParameterizedArray
    lb: float = eqx.static_field()
    ub: float = eqx.static_field()
    # Margin kept away from 0 and 1 so that logit() stays finite.
    # NOTE(review): in float32, 1.0 - 1e-8 rounds to 1.0, so the upper margin
    # may be a no-op at that precision -- confirm the intended dtype.
    CONSTRAINT_EPS: float = eqx.static_field(default=1e-8)

    def _clip_unit(self, x):
        """Clip ``x`` to ``[EPS, 1 - EPS]``."""
        # Bounds are passed positionally: the ``a_min``/``a_max`` keyword
        # names are deprecated in newer JAX releases, while the positional
        # form is accepted by both old and new versions.
        return jnp.clip(x, 0.0 + self.CONSTRAINT_EPS, 1.0 - self.CONSTRAINT_EPS)

    def _constrain(self, x):
        """Constrain (-inf, inf) to [0+EPS, 1-EPS]"""
        return self._clip_unit(jax.nn.sigmoid(x))

    def _unconstrain(self, x):
        """Unconstrain [0, 1] to (-inf, inf)"""
        assert jnp.all(x >= 0.0) and jnp.all(x <= 1.0), "x must lie in [0, 1]"
        # Clip first in case we get exactly 0 or 1, where logit is infinite.
        return jax.scipy.special.logit(self._clip_unit(x))

    # TODO: separate this into a different initializer
    def __init__(self, val, lb, ub):
        """Store ``val`` in unconstrained (logit) space.

        Args:
            val: Initial value(s); every entry must lie in ``[lb, ub]``.
            lb: Lower bound.
            ub: Upper bound; must be strictly greater than ``lb``.

        Raises:
            AssertionError: if ``lb >= ub`` or ``val`` falls outside the bounds.
        """
        assert lb < ub, "lb must be strictly less than ub"
        assert jnp.all(val >= lb) and jnp.all(val <= ub), "val must lie in [lb, ub]"
        self.unconstrained_array = self._unconstrain((val - lb) / (ub - lb))
        self.lb = lb
        self.ub = ub

    def eval(self) -> jax.Array:
        """Return the array mapped back into ``[lb, ub]``."""
        unconstrained_array = ParameterizedArray.tree_eval(self.unconstrained_array)
        return self.lb + self._constrain(unconstrained_array) * (self.ub - self.lb)

    # These previously read ``self.array``, a field this class does not
    # define, so both properties raised AttributeError.
    @property
    def shape(self):
        """Shape of the underlying unconstrained array."""
        return self.unconstrained_array.shape

    @property
    def dtype(self):
        """Dtype of the underlying unconstrained array."""
        return self.unconstrained_array.dtype
if __name__ == "__main__":
    # ---- Usage example ----

    class InnerArray(ParameterizedArray):
        """Multiplies a (possibly parameterized) array by a constant factor."""

        a: MaybeParameterizedArray
        multiplier: float = 2

        def eval(self) -> jax.Array:
            # A field typed MaybeParameterizedArray has to go through
            # tree_eval() before it can be used as a plain array.
            return ParameterizedArray.tree_eval(self.a) * self.multiplier

    class OuterArray(ParameterizedArray):
        """Multiplies a (possibly parameterized) array by a gradient-stopped array."""

        b: MaybeParameterizedArray
        multiplier: jax.Array

        def eval(self) -> jax.Array:
            frozen_multiplier = jax.lax.stop_gradient(self.multiplier)
            return ParameterizedArray.tree_eval(self.b) * frozen_multiplier

    inner = InnerArray(jnp.array([1, 2, 3]))
    outer = OuterArray(b=inner, multiplier=jnp.array([1, 0, 1]))
    # A pytree mixing parameterized arrays, a plain array and a non-array leaf.
    model = (outer, inner, jnp.array([4, 5, 6]), True)

    report = (
        ("Outer", outer.eval()),
        ("Inner", inner.eval()),
        ("Evaluate PyTree of Parameterized Arrays", ParameterizedArray.tree_eval(model)),
    )
    for label, value in report:
        print(label, value)
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment