# Scraped from a GitHub page; author: Jamie Townsend (j-towns).
import jax.numpy as np
import jax.linear_util as lu
from jax.util import unzip2, safe_zip, safe_map
from jax.experimental import stax
from jax.interpreters import partial_eval as pe
from jax.interpreters.batching import get_aval
from jax.api_util import (
wraps, pytree_to_jaxtupletree, pytree_fun_to_jaxtupletree_fun)
import jax.core as jc
from jax import random