Skip to content

Instantly share code, notes, and snippets.

@epignatelli
Last active October 28, 2020 14:15
Show Gist options
  • Save epignatelli/42662667f7be1ea0daa3e34a375f027a to your computer and use it in GitHub Desktop.
a module decorator for jax.experimental.stax
from typing import Tuple, NamedTuple, Callable, Any
import functools
import jax.numpy as jnp
# Type aliases describing the pieces of a stax-style layer.
Params = Any  # pytree of layer parameters (opaque to this module)
RNGKey = jnp.ndarray  # a jax.random.PRNGKey
# A shape is a tuple of ints of arbitrary rank. The original alias was
# Tuple[int], which types a 1-tuple only; the variadic form is correct.
Shape = Tuple[int, ...]


class Module(NamedTuple):
    """A stax-style layer packaged as an ``(init, apply)`` pair.

    Attributes:
        init: maps ``(rng, input_shape) -> (output_shape, params)``.
        apply: maps ``(params, inputs) -> outputs``.

    Being a NamedTuple, the pair stays unpackable/indexable exactly like
    the bare ``(init, apply)`` tuple that plain stax constructors return,
    while also exposing the two functions as named attributes.
    """

    init: Callable[[RNGKey, Shape], Tuple[Shape, Params]]
    apply: Callable[[Params, jnp.ndarray], jnp.ndarray]
def module(module_maker):
    """Decorator turning a stax-style constructor into one returning a Module.

    The decorated constructor must return an ``(init, apply)`` pair; the
    wrapper repackages that pair as a :class:`Module` namedtuple so both
    functions are reachable by attribute as well as by index.
    """

    @functools.wraps(module_maker)
    def wrapper(*args, **kwargs):
        init_fun, apply_fun = module_maker(*args, **kwargs)
        return Module(init_fun, apply_fun)

    return wrapper
# example usage on a Dense function
if __name__ == "__main__":
    import jax
    # NOTE(review): jax.experimental.stax was later moved to
    # jax.example_libraries.stax; this path matches the jax version this
    # gist targeted (2020). ``Dense`` is deliberately NOT imported here:
    # the original imported it and immediately shadowed it with the
    # locally-defined constructor below, so that import was dead code.
    from jax.experimental.stax import glorot_normal, normal

    @module
    def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
        """Layer constructor function for a dense (fully-connected) layer."""

        def init_fun(rng, input_shape):
            # Replace the trailing feature axis with out_dim; any leading
            # (e.g. batch) axes pass through unchanged.
            output_shape = input_shape[:-1] + (out_dim,)
            k1, k2 = jax.random.split(rng)
            W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
            return output_shape, (W, b)

        def apply_fun(params, inputs, **kwargs):
            W, b = params
            return jnp.dot(inputs, W) + b

        return init_fun, apply_fun

    seed = 0
    rng = jax.random.PRNGKey(seed)
    input_shape = (-1, 4)  # -1: unknown batch dimension, 4 input features
    dense = Dense(8, W_init=glorot_normal(), b_init=normal())
    print(dense)
    # The Module namedtuple exposes the pair both by name and by index.
    assert dense.init == dense[0]
    assert dense.apply == dense[1]
    out_shape, params = dense.init(rng, input_shape)
    noise = jax.random.normal(rng, (input_shape[-1],))
    output = dense.apply(params, noise)
    print(output.shape)  # expected: (8,)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment